// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package tar

import (
	"bytes"
	"fmt"
	"io"
	"path"
	"sort"
	"strconv"
	"strings"
	"time"
)

// A Writer provides sequential writing of a tar archive in POSIX.1 format.
// A tar archive consists of a sequence of files.
// Call WriteHeader to begin a new file, and then call Write to supply that file's data,
// writing at most hdr.Size bytes in total.
// Call Close after the last file to write the archive trailer.
type Writer struct {
	w    io.Writer  // Underlying writer that receives the archive bytes
	pad  int64      // Amount of padding to write after current file entry
	curr fileWriter // Writer for current file entry
	hdr  Header     // Shallow copy of Header that is safe for mutations
	blk  block      // Buffer to use as temporary local storage

	// err is a persistent error.
	// Once set, Flush, Write, fillZeros, and Close return it without
	// performing any further work. Only the exported methods of Writer are
	// responsible for ensuring that this error is sticky; the unexported
	// header helpers simply return their errors to the caller.
	err error
}

// NewWriter returns a Writer that emits its archive to w.
// The Writer starts out with an empty current entry, so no file data
// can be written until WriteHeader has been called.
func NewWriter(w io.Writer) *Writer {
	tw := new(Writer)
	tw.w = w
	tw.curr = &regFileWriter{w: w, nb: 0}
	return tw
}

// fileWriter writes the body of a single archive entry.
// Within this file it is implemented by regFileWriter and sparseFileWriter.
type fileWriter interface {
	io.Writer
	fileState // NOTE(review): declared outside this file; the implementations here provide Remaining() int64

	// FillZeros writes n zero bytes to the entry and reports how many
	// bytes were accounted for.
	FillZeros(n int64) (int64, error)
}

// Flush writes out the block padding that follows the current file.
// It fails if the current file still has unwritten bytes, and it returns
// the Writer's sticky error without doing anything if one is already set.
// An error from writing the padding becomes the Writer's sticky error.
//
// Deprecated: This is unnecessary as the next call to WriteHeader or Close
// will implicitly flush out the file's padding.
func (tw *Writer) Flush() error {
	if tw.err != nil {
		return tw.err
	}

	missing := tw.curr.Remaining()
	if missing > 0 {
		// Not sticky: the caller may still supply the missing bytes.
		return fmt.Errorf("archive/tar: missed writing %d bytes", missing)
	}

	_, tw.err = tw.w.Write(zeroBlock[:tw.pad])
	if tw.err == nil {
		tw.pad = 0
	}
	return tw.err
}

// WriteHeader writes hdr and prepares to accept the file's contents.
// It first calls Flush to finish the previous entry, so the previous file
// must have been written completely.
// Calling after a Close will return ErrWriteAfterClose.
//
// The header is written in the first of the USTAR, PAX, and GNU formats
// that allowedFormats reports as usable. An error while writing the header
// becomes sticky; if no format is usable, the error from allowedFormats is
// returned and the Writer remains usable.
func (tw *Writer) WriteHeader(hdr *Header) error {
	if err := tw.Flush(); err != nil {
		return err
	}

	// Operate on a shallow copy so the caller's Header is never modified.
	tw.hdr = *hdr
	formats, paxRecords, err := tw.hdr.allowedFormats()

	if formats.has(FormatUSTAR) {
		tw.err = tw.writeUSTARHeader(&tw.hdr)
		return tw.err
	}
	if formats.has(FormatPAX) {
		tw.err = tw.writePAXHeader(&tw.hdr, paxRecords)
		return tw.err
	}
	if formats.has(FormatGNU) {
		tw.err = tw.writeGNUHeader(&tw.hdr)
		return tw.err
	}
	return err // Non-fatal error
}

// writeUSTARHeader writes hdr as a single USTAR header block.
// If hdr.Name can be split by splitUSTARPath, the leading part is stored in
// the USTAR prefix field and hdr.Name is replaced by the remaining part.
func (tw *Writer) writeUSTARHeader(hdr *Header) error {
	// Check if we can use USTAR prefix/suffix splitting.
	var prefixField string
	if head, tail, ok := splitUSTARPath(hdr.Name); ok {
		prefixField = head
		hdr.Name = tail
	}

	// Pack the main header.
	var f formatter
	blk := tw.templateV7Plus(hdr, f.formatString, f.formatOctal)
	f.formatString(blk.USTAR().Prefix(), prefixField)
	blk.SetFormat(FormatUSTAR)
	if err := f.err; err != nil {
		return err // Should never happen since header is validated
	}
	return tw.writeRawHeader(blk, hdr.Size, hdr.Typeflag)
}

// writePAXHeader writes hdr using the PAX format.
// Any records in paxHdrs are first written as an extended header file
// (TypeXHeader), followed by the main header block.
//
// If hdr.SparseHoles is non-empty, the entry is encoded as a sparse file
// whose sparse map is stored at the start of the file data: hdr.Name and
// hdr.Size are replaced by their encoded values, the matching GNU sparse
// records are added to paxHdrs, and tw.curr is wrapped in a
// sparseFileWriter. Both hdr and paxHdrs are mutated by this method.
func (tw *Writer) writePAXHeader(hdr *Header, paxHdrs map[string]string) error {
	// Remember the caller-visible values before they are rewritten below.
	realName, realSize := hdr.Name, hdr.Size

	// Handle sparse files.
	var spd sparseDatas
	var spb []byte
	if len(hdr.SparseHoles) > 0 {
		sph := append([]SparseEntry{}, hdr.SparseHoles...) // Copy sparse map
		sph = alignSparseEntries(sph, hdr.Size)
		spd = invertSparseEntries(sph, hdr.Size)

		// Format the sparse map.
		// The map is a newline-terminated decimal count of fragments,
		// followed by the offset and length of each fragment.
		hdr.Size = 0 // Replace with encoded size
		spb = append(strconv.AppendInt(spb, int64(len(spd)), 10), '\n')
		for _, s := range spd {
			hdr.Size += s.Length
			spb = append(strconv.AppendInt(spb, s.Offset, 10), '\n')
			spb = append(strconv.AppendInt(spb, s.Length, 10), '\n')
		}
		// Extend the encoded map with blockPadding zero bytes.
		pad := blockPadding(int64(len(spb)))
		spb = append(spb, zeroBlock[:pad]...)
		hdr.Size += int64(len(spb)) // Accounts for encoded sparse map

		// Add and modify appropriate PAX records.
		dir, file := path.Split(realName)
		hdr.Name = path.Join(dir, "GNUSparseFile.0", file)
		paxHdrs[paxGNUSparseMajor] = "1"
		paxHdrs[paxGNUSparseMinor] = "0"
		paxHdrs[paxGNUSparseName] = realName
		paxHdrs[paxGNUSparseRealSize] = strconv.FormatInt(realSize, 10)
		paxHdrs[paxSize] = strconv.FormatInt(hdr.Size, 10)
		delete(paxHdrs, paxPath) // Recorded by paxGNUSparseName
	}

	// Write PAX records to the output.
	if len(paxHdrs) > 0 {
		// Sort keys for deterministic ordering.
		var keys []string
		for k := range paxHdrs {
			keys = append(keys, k)
		}
		sort.Strings(keys)

		// Write each record to a buffer.
		var buf bytes.Buffer
		for _, k := range keys {
			rec, err := formatPAXRecord(k, paxHdrs[k])
			if err != nil {
				return err
			}
			buf.WriteString(rec)
		}

		// Write the extended header file.
		// It is named after the original entry, not the rewritten sparse name.
		dir, file := path.Split(realName)
		name := path.Join(dir, "PaxHeaders.0", file)
		data := buf.String()
		if err := tw.writeRawFile(name, data, TypeXHeader, FormatPAX); err != nil {
			return err
		}
	}

	// Pack the main header.
	// Strings are passed through toASCII before being formatted.
	var f formatter // Ignore errors since they are expected
	fmtStr := func(b []byte, s string) { f.formatString(b, toASCII(s)) }
	blk := tw.templateV7Plus(hdr, fmtStr, f.formatOctal)
	blk.SetFormat(FormatPAX)
	if err := tw.writeRawHeader(blk, hdr.Size, hdr.Typeflag); err != nil {
		return err
	}

	// Write the sparse map and setup the sparse writer if necessary.
	if len(spd) > 0 {
		// Use tw.curr since the sparse map is accounted for in hdr.Size.
		if _, err := tw.curr.Write(spb); err != nil {
			return err
		}
		tw.curr = &sparseFileWriter{tw.curr, spd, 0}
	}
	return nil
}

// writeGNUHeader writes hdr using the GNU format.
// A Name or Linkname longer than nameSize is first written in full,
// NUL-terminated, as the body of a separate "././@LongLink" file
// (TypeGNULongName or TypeGNULongLink respectively).
//
// For TypeGNUSparse entries, the data fragments derived from
// hdr.SparseHoles are stored in the header block's sparse array, and any
// fragments that do not fit are written in extra blocks directly after the
// header. hdr.Size is replaced by the sum of the data fragment lengths and
// tw.curr is wrapped in a sparseFileWriter.
func (tw *Writer) writeGNUHeader(hdr *Header) error {
	// Use long-link files if Name or Linkname exceeds the field size.
	const longName = "././@LongLink"
	if len(hdr.Name) > nameSize {
		data := hdr.Name + "\x00"
		if err := tw.writeRawFile(longName, data, TypeGNULongName, FormatGNU); err != nil {
			return err
		}
	}
	if len(hdr.Linkname) > nameSize {
		data := hdr.Linkname + "\x00"
		if err := tw.writeRawFile(longName, data, TypeGNULongLink, FormatGNU); err != nil {
			return err
		}
	}

	// Pack the main header.
	var f formatter // Ignore errors since they are expected
	var spd sparseDatas
	var spb []byte
	blk := tw.templateV7Plus(hdr, f.formatString, f.formatNumeric)
	// Access and change times are only recorded when they are set.
	if !hdr.AccessTime.IsZero() {
		f.formatNumeric(blk.GNU().AccessTime(), hdr.AccessTime.Unix())
	}
	if !hdr.ChangeTime.IsZero() {
		f.formatNumeric(blk.GNU().ChangeTime(), hdr.ChangeTime.Unix())
	}
	if hdr.Typeflag == TypeGNUSparse {
		sph := append([]SparseEntry{}, hdr.SparseHoles...) // Copy sparse map
		sph = alignSparseEntries(sph, hdr.Size)
		spd = invertSparseEntries(sph, hdr.Size)

		// Format the sparse map.
		// formatSPD fills sa with as many leading entries of sp as it can
		// hold, marks sa as extended if entries are left over, and returns
		// the entries that did not fit.
		formatSPD := func(sp sparseDatas, sa sparseArray) sparseDatas {
			for i := 0; len(sp) > 0 && i < sa.MaxEntries(); i++ {
				f.formatNumeric(sa.Entry(i).Offset(), sp[0].Offset)
				f.formatNumeric(sa.Entry(i).Length(), sp[0].Length)
				sp = sp[1:]
			}
			if len(sp) > 0 {
				sa.IsExtended()[0] = 1
			}
			return sp
		}
		// The first entries go into the header block itself; the rest are
		// packed into additional blocks accumulated in spb.
		sp2 := formatSPD(spd, blk.GNU().Sparse())
		for len(sp2) > 0 {
			var spHdr block
			sp2 = formatSPD(sp2, spHdr.Sparse())
			spb = append(spb, spHdr[:]...)
		}

		// Update size fields in the header block.
		realSize := hdr.Size
		hdr.Size = 0 // Encoded size; does not account for encoded sparse map
		for _, s := range spd {
			hdr.Size += s.Length
		}
		copy(blk.V7().Size(), zeroBlock[:]) // Reset field
		f.formatNumeric(blk.V7().Size(), hdr.Size)
		f.formatNumeric(blk.GNU().RealSize(), realSize)
	}
	blk.SetFormat(FormatGNU)
	if err := tw.writeRawHeader(blk, hdr.Size, hdr.Typeflag); err != nil {
		return err
	}

	// Write the extended sparse map and setup the sparse writer if necessary.
	if len(spd) > 0 {
		// Use tw.w since the sparse map is not accounted for in hdr.Size.
		if _, err := tw.w.Write(spb); err != nil {
			return err
		}
		tw.curr = &sparseFileWriter{tw.curr, spd, 0}
	}
	return nil
}

// Callback types accepted by templateV7Plus, which let each header format
// choose how string and numeric fields are encoded.
type (
	// stringFormatter encodes a string into the given header field.
	stringFormatter func([]byte, string)
	// numberFormatter encodes an integer into the given header field.
	numberFormatter func([]byte, int64)
)

// templateV7Plus resets the Writer's scratch block and populates it from hdr.
// It fills in the V7 fields and also the fields (uname, gname, devmajor,
// devminor) shared by the USTAR, PAX, and GNU formats, encoding every value
// with the supplied formatters.
//
// The returned block aliases the Writer's scratch block, so it is only
// valid until the next call to templateV7Plus or writeRawFile.
func (tw *Writer) templateV7Plus(hdr *Header, fmtStr stringFormatter, fmtNum numberFormatter) *block {
	blk := &tw.blk
	blk.Reset()

	// Substitute the Unix epoch when no modification time is set.
	mtime := time.Unix(0, 0)
	if !hdr.ModTime.IsZero() {
		mtime = hdr.ModTime
	}

	// Fields present in the V7 layout.
	base := blk.V7()
	base.TypeFlag()[0] = hdr.Typeflag
	fmtStr(base.Name(), hdr.Name)
	fmtStr(base.LinkName(), hdr.Linkname)
	fmtNum(base.Mode(), hdr.Mode)
	fmtNum(base.UID(), int64(hdr.Uid))
	fmtNum(base.GID(), int64(hdr.Gid))
	fmtNum(base.Size(), hdr.Size)
	fmtNum(base.ModTime(), mtime.Unix())

	// Fields accessed through the USTAR layout.
	ext := blk.USTAR()
	fmtStr(ext.UserName(), hdr.Uname)
	fmtStr(ext.GroupName(), hdr.Gname)
	fmtNum(ext.DevMajor(), hdr.Devmajor)
	fmtNum(ext.DevMinor(), hdr.Devminor)

	return blk
}

// writeRawFile writes a minimal entry consisting of a header block with the
// given name and type flag, followed by data as the entry's body.
// The header is marked with format, and every other field is left at a
// default value (as BSD and GNU tar do).
//
// The name is converted with toASCII and truncated to nameSize bytes.
// This method overwrites the Writer's scratch block.
func (tw *Writer) writeRawFile(name, data string, flag byte, format Format) error {
	tw.blk.Reset()

	// Best effort for the filename.
	if name = toASCII(name); len(name) > nameSize {
		name = name[:nameSize]
	}
	size := int64(len(data))

	var f formatter
	hdr := tw.blk.V7()
	hdr.TypeFlag()[0] = flag
	f.formatString(hdr.Name(), name)
	f.formatOctal(hdr.Mode(), 0)
	f.formatOctal(hdr.UID(), 0)
	f.formatOctal(hdr.GID(), 0)
	f.formatOctal(hdr.Size(), size) // Must be < 8GiB
	f.formatOctal(hdr.ModTime(), 0)
	tw.blk.SetFormat(format)
	if f.err != nil {
		return f.err // Only occurs if size condition is violated
	}

	// Emit the header first, then stream the body through the Writer itself
	// so that the usual size accounting applies.
	if err := tw.writeRawHeader(&tw.blk, size, flag); err != nil {
		return err
	}
	_, err := io.WriteString(tw, data)
	return err
}

// writeRawHeader flushes the previous entry and then writes blk verbatim,
// whatever it contains.
// On success the Writer is set up to accept a file body of the given size.
// For header-only type flags the size is treated as zero.
func (tw *Writer) writeRawHeader(blk *block, size int64, flag byte) error {
	err := tw.Flush()
	if err == nil {
		_, err = tw.w.Write(blk[:])
	}
	if err != nil {
		return err
	}

	if isHeaderOnlyType(flag) {
		size = 0
	}
	tw.pad = blockPadding(size)
	tw.curr = &regFileWriter{w: tw.w, nb: size}
	return nil
}

// splitUSTARPath divides name at a '/' into a USTAR prefix and suffix.
// Only ASCII names longer than nameSize are split. The separator is
// searched for in the first prefixSize+1 bytes (ignoring a trailing '/'),
// and the resulting suffix must be non-empty and at most nameSize bytes.
// If the path is not splittable, then it will return ("", "", false).
func splitUSTARPath(name string) (prefix, suffix string, ok bool) {
	total := len(name)
	if total <= nameSize || !isASCII(name) {
		return "", "", false
	}

	// Restrict the separator search so the prefix cannot exceed prefixSize
	// and a trailing slash is never chosen as the split point.
	window := total
	switch {
	case total > prefixSize+1:
		window = prefixSize + 1
	case name[total-1] == '/':
		window = total - 1
	}

	sep := strings.LastIndexByte(name[:window], '/')
	if sep <= 0 {
		return "", "", false
	}
	head, tail := name[:sep], name[sep+1:]
	if len(tail) == 0 || len(tail) > nameSize || len(head) > prefixSize {
		return "", "", false
	}
	return head, tail, true
}

// Write writes b to the current entry in the tar archive.
// It returns ErrWriteTooLong if more than Header.Size bytes are written
// after WriteHeader; that error is not sticky, whereas any other error
// from the entry writer is remembered and returned by later calls.
//
// If the current file is sparse, then the regions marked as a sparse hole
// must be written as NUL-bytes.
//
// Calling Write on special types like TypeLink, TypeSymLink, TypeChar,
// TypeBlock, TypeDir, and TypeFifo returns (0, ErrWriteTooLong) regardless
// of what the Header.Size claims.
func (tw *Writer) Write(b []byte) (int, error) {
	if tw.err != nil {
		return 0, tw.err
	}

	written, err := tw.curr.Write(b)
	switch err {
	case nil, ErrWriteTooLong:
		// Nothing to remember.
	default:
		tw.err = err
	}
	return written, err
}

// TODO(dsnet): Export the Writer.FillZeros method to assist in quickly zeroing
// out sections of a file. This is especially useful for efficiently
// skipping over large holes in a sparse file.

// fillZeros writes n bytes of zeros to the current file,
// returning the number of bytes written.
// If fewer than n bytes are written, it returns a non-nil error,
// which may be ErrWriteTooLong if the current file is complete.
// Any error other than ErrWriteTooLong becomes the Writer's sticky error.
func (tw *Writer) fillZeros(n int64) (int64, error) {
	if tw.err != nil {
		return 0, tw.err
	}

	filled, err := tw.curr.FillZeros(n)
	switch err {
	case nil, ErrWriteTooLong:
		// Nothing to remember.
	default:
		tw.err = err
	}
	return filled, err
}

// Close finishes the tar archive by flushing the current entry's padding
// and writing the trailer to the underlying writer.
// Closing an already closed Writer is a no-op that returns nil, while a
// Writer that previously failed returns its sticky error unchanged.
// After Close, every other operation fails with ErrWriteAfterClose.
func (tw *Writer) Close() error {
	switch tw.err {
	case nil:
		// Proceed with closing.
	case ErrWriteAfterClose:
		return nil
	default:
		return tw.err
	}

	// Trailer: two zero blocks.
	err := tw.Flush()
	if err == nil {
		_, err = tw.w.Write(zeroBlock[:])
	}
	if err == nil {
		_, err = tw.w.Write(zeroBlock[:])
	}

	// Ensure all future actions are invalid.
	tw.err = ErrWriteAfterClose
	return err // Report IO errors
}

// regFileWriter is a fileWriter for writing data to a regular file entry.
// It passes data straight to the underlying writer while enforcing the
// size limit of the entry.
type regFileWriter struct {
	w  io.Writer // Underlying Writer
	nb int64     // Number of remaining bytes to write; reduced by Write as data is accepted
}

// Write writes b to the underlying writer, accepting at most the number of
// bytes that remain in the entry.
// If b is longer than that, only the leading bytes that fit are written and
// ErrWriteTooLong is returned. An error from the underlying writer takes
// precedence over ErrWriteTooLong.
//
// The underlying writer is never called with an empty slice: io.Writer
// permits zero-length writes, but implementations do not always handle
// them well, and nothing would be written anyway.
func (fw *regFileWriter) Write(b []byte) (n int, err error) {
	overwrite := int64(len(b)) > fw.nb
	if overwrite {
		b = b[:fw.nb]
	}
	if len(b) > 0 {
		n, err = fw.w.Write(b)
		fw.nb -= int64(n) // Count only what the underlying writer accepted
	}
	switch {
	case err != nil:
		return n, err
	case overwrite:
		return n, ErrWriteTooLong
	default:
		return n, nil
	}
}

// FillZeros writes n zero bytes to the entry by copying them from a
// zeroReader through Write, so the entry's size limit still applies.
// It returns the number of bytes copied and the first error encountered.
func (fw *regFileWriter) FillZeros(n int64) (int64, error) {
	var zeros zeroReader
	copied, err := io.CopyN(fw, zeros, n)
	return copied, err
}

// Remaining reports how many bytes may still be written to the entry.
func (fw regFileWriter) Remaining() int64 {
	left := fw.nb
	return left
}

// sparseFileWriter is a fileWriter for writing data to a sparse file entry.
// Callers write the complete logical file; only the bytes that fall inside
// a data fragment are passed on to the underlying fileWriter.
type sparseFileWriter struct {
	fw  fileWriter  // Underlying fileWriter
	sp  sparseDatas // Normalized list of data fragments; consumed from the front, but never emptied
	pos int64       // Current position in sparse file
}

// Write writes b to the sparse file at the current logical position.
// Bytes that fall inside a hole are only checked by zeroWriter, which
// rejects anything other than NUL bytes; bytes that fall inside a data
// fragment are written to the underlying fileWriter.
// If b is longer than Remaining, only the leading bytes that fit are
// consumed and ErrWriteTooLong is returned.
func (sw *sparseFileWriter) Write(b []byte) (n int, err error) {
	overwrite := int64(len(b)) > sw.Remaining()
	if overwrite {
		b = b[:sw.Remaining()]
	}

	// b0 keeps the original (truncated) input so that the number of
	// consumed bytes can be computed after b has been advanced.
	b0 := b
	endPos := sw.pos + int64(len(b))
	for endPos > sw.pos && err == nil {
		var nf int // Bytes written in fragment
		dataStart, dataEnd := sw.sp[0].Offset, sw.sp[0].endOffset()
		if sw.pos < dataStart { // In a hole fragment
			bf := b[:min(int64(len(b)), dataStart-sw.pos)]
			nf, err = zeroWriter{}.Write(bf)
		} else { // In a data fragment
			bf := b[:min(int64(len(b)), dataEnd-sw.pos)]
			nf, err = sw.fw.Write(bf)
		}
		b = b[nf:]
		sw.pos += int64(nf)
		// Move on to the next fragment once the current one is complete.
		if sw.pos >= dataEnd && len(sw.sp) > 1 {
			sw.sp = sw.sp[1:] // Ensure last fragment always remains
		}
	}

	n = len(b0) - len(b)
	switch {
	case err == ErrWriteTooLong:
		return n, errMissData // Not possible; implies bug in validation logic
	case err != nil:
		return n, err
	case sw.Remaining() == 0 && sw.fw.Remaining() > 0:
		return n, errUnrefData // Not possible; implies bug in validation logic
	case overwrite:
		return n, ErrWriteTooLong
	default:
		return n, nil
	}
}

// FillZeros advances the logical position by n bytes of zeros.
// Only the part of that range that overlaps data fragments is written to
// the underlying fileWriter, in a single FillZeros call after the range
// has been walked; holes are skipped without writing anything.
// If n exceeds Remaining, it is reduced to Remaining and ErrWriteTooLong is
// returned. The returned count is the (possibly reduced) n, even when the
// underlying fileWriter reports an error.
func (sw *sparseFileWriter) FillZeros(n int64) (int64, error) {
	overwrite := n > sw.Remaining()
	if overwrite {
		n = sw.Remaining()
	}

	var realFill int64 // Number of real data bytes to fill
	endPos := sw.pos + n
	for endPos > sw.pos {
		var nf int64 // Size of fragment
		dataStart, dataEnd := sw.sp[0].Offset, sw.sp[0].endOffset()
		if sw.pos < dataStart { // In a hole fragment
			nf = min(endPos-sw.pos, dataStart-sw.pos)
		} else { // In a data fragment
			nf = min(endPos-sw.pos, dataEnd-sw.pos)
			realFill += nf
		}
		sw.pos += nf
		// Move on to the next fragment once the current one is complete.
		if sw.pos >= dataEnd && len(sw.sp) > 1 {
			sw.sp = sw.sp[1:] // Ensure last fragment always remains
		}
	}

	_, err := sw.fw.FillZeros(realFill)
	switch {
	case err == ErrWriteTooLong:
		return n, errMissData // Not possible; implies bug in validation logic
	case err != nil:
		return n, err
	case sw.Remaining() == 0 && sw.fw.Remaining() > 0:
		return n, errUnrefData // Not possible; implies bug in validation logic
	case overwrite:
		return n, ErrWriteTooLong
	default:
		return n, nil
	}
}

// Remaining reports the number of logical bytes between the current
// position and the end of the last data fragment.
func (sw sparseFileWriter) Remaining() int64 {
	last := sw.sp[len(sw.sp)-1]
	return last.endOffset() - sw.pos
}

// zeroWriter is a writer that accepts nothing but NUL bytes and stores
// nothing; writing any other byte fails with errWriteHole.
type zeroWriter struct{}

// Write reports len(b) if every byte of b is NUL. Otherwise it returns the
// index of the first non-NUL byte together with errWriteHole.
func (zeroWriter) Write(b []byte) (int, error) {
	for i := 0; i < len(b); i++ {
		if b[i] != 0 {
			return i, errWriteHole
		}
	}
	return len(b), nil
}