Commit 8cd04da7 authored by Joe Tsai's avatar Joe Tsai Committed by Joe Tsai

compress/flate: make huffmanBitWriter errors persistent

For persistent error handling, the methods of huffmanBitWriter have to be
consistent about how they check errors. They must either consistently
check the error *before* every operation OR immediately *after* every
operation. Since most of the current logic uses the former approach,
we apply the same style of error checking to writeBits and all calls
to Write such that they only operate if w.err is already nil going
into them.

The error handling approach is brittle and easily broken by future commits to
the code. In the near future, we should switch the logic to use panic at the
lowest levels and a recover at the edge of the public API to ensure
that errors are always persistent.

Fixes #16749

Change-Id: Ie1d83e4ed8842f6911a31e23311cd3cbf38abe8c
Reviewed-on: https://go-review.googlesource.com/27200
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
parent c10f8700
...@@ -724,7 +724,7 @@ func (w *Writer) Close() error { ...@@ -724,7 +724,7 @@ func (w *Writer) Close() error {
// the result of NewWriter or NewWriterDict called with dst // the result of NewWriter or NewWriterDict called with dst
// and w's level and dictionary. // and w's level and dictionary.
func (w *Writer) Reset(dst io.Writer) { func (w *Writer) Reset(dst io.Writer) {
if dw, ok := w.d.w.w.(*dictWriter); ok { if dw, ok := w.d.w.writer.(*dictWriter); ok {
// w was created with NewWriterDict // w was created with NewWriterDict
dw.w = dst dw.w = dst
w.d.reset(dw) w.d.reset(dw)
......
...@@ -6,6 +6,7 @@ package flate ...@@ -6,6 +6,7 @@ package flate
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"internal/testenv" "internal/testenv"
"io" "io"
...@@ -631,3 +632,52 @@ func TestBestSpeed(t *testing.T) { ...@@ -631,3 +632,52 @@ func TestBestSpeed(t *testing.T) {
} }
} }
} }
// errIO is the sentinel error injected by failWriter.
var errIO = errors.New("IO error")

// failWriter is an io.Writer that injects a single failure.
//
// n is decremented on every call to Write. The call that takes n from 0 to -1
// returns (0, errIO); every other call reports that all of b was written.
// Consequently, for an initial n >= 0 the first n calls succeed, the next one
// fails, and later calls succeed again, i.e. the error is not sticky.
type failWriter struct{ n int }

// Write consumes one unit of the countdown and fails only when the countdown
// was exactly zero on entry. The contents of b are never inspected or stored.
func (w *failWriter) Write(b []byte) (int, error) {
	remaining := w.n
	w.n = remaining - 1
	if remaining == 0 {
		return 0, errIO
	}
	return len(b), nil
}
// TestWriterPersistentError checks that once the destination writer has
// returned an error, the flate Writer keeps reporting that error from Close,
// even though the destination itself would succeed on later calls.
func TestWriterPersistentError(t *testing.T) {
	data, err := ioutil.ReadFile("../testdata/Mark.Twain-Tom.Sawyer.txt")
	if err != nil {
		t.Fatalf("ReadFile: %v", err)
	}
	data = data[:10000] // A short prefix keeps the sweep below fast.

	zw, err := NewWriter(nil, DefaultCompression)
	if err != nil {
		t.Fatalf("NewWriter: %v", err)
	}

	// Sweep the failure threshold upward. For threshold k, failWriter lets k
	// calls to Write succeed and fails the one after that with errIO. Because
	// failWriter succeeds again afterwards, any errIO seen from Close must
	// have been remembered by the flate Writer itself.
	for threshold := 0; threshold < 1000; threshold++ {
		fw := &failWriter{threshold}
		zw.Reset(fw)

		_, writeErr := zw.Write(data)
		closeErr := zw.Close()

		// fw.n drops below zero only once the failing call has happened.
		injected := fw.n < 0

		// Write may return nil or errIO, but nothing else.
		if writeErr != nil && writeErr != errIO {
			t.Errorf("test %d, mismatching Write error: got %v, want %v", threshold, writeErr, errIO)
		}
		// If the failure was injected, Close must report it.
		if injected && closeErr != errIO {
			t.Errorf("test %d, mismatching Close error: got %v, want %v", threshold, closeErr, errIO)
		}
		if !injected {
			// The threshold exceeded the number of Write calls needed for
			// the whole stream, so no error was injected; the sweep is done.
			return
		}
	}
}
...@@ -77,7 +77,11 @@ var offsetBase = []uint32{ ...@@ -77,7 +77,11 @@ var offsetBase = []uint32{
var codegenOrder = []uint32{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15} var codegenOrder = []uint32{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}
type huffmanBitWriter struct { type huffmanBitWriter struct {
w io.Writer // writer is the underlying writer.
// Do not use it directly; use the write method, which ensures
// that Write errors are sticky.
writer io.Writer
// Data waiting to be written is bytes[0:nbytes] // Data waiting to be written is bytes[0:nbytes]
// and then the low nbits of bits. // and then the low nbits of bits.
bits uint64 bits uint64
...@@ -96,7 +100,7 @@ type huffmanBitWriter struct { ...@@ -96,7 +100,7 @@ type huffmanBitWriter struct {
func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter { func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
return &huffmanBitWriter{ return &huffmanBitWriter{
w: w, writer: w,
literalFreq: make([]int32, maxNumLit), literalFreq: make([]int32, maxNumLit),
offsetFreq: make([]int32, offsetCodeCount), offsetFreq: make([]int32, offsetCodeCount),
codegen: make([]uint8, maxNumLit+offsetCodeCount+1), codegen: make([]uint8, maxNumLit+offsetCodeCount+1),
...@@ -107,7 +111,7 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter { ...@@ -107,7 +111,7 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
} }
func (w *huffmanBitWriter) reset(writer io.Writer) { func (w *huffmanBitWriter) reset(writer io.Writer) {
w.w = writer w.writer = writer
w.bits, w.nbits, w.nbytes, w.err = 0, 0, 0, nil w.bits, w.nbits, w.nbytes, w.err = 0, 0, 0, nil
w.bytes = [bufferSize]byte{} w.bytes = [bufferSize]byte{}
} }
...@@ -129,11 +133,21 @@ func (w *huffmanBitWriter) flush() { ...@@ -129,11 +133,21 @@ func (w *huffmanBitWriter) flush() {
n++ n++
} }
w.bits = 0 w.bits = 0
_, w.err = w.w.Write(w.bytes[:n]) w.write(w.bytes[:n])
w.nbytes = 0 w.nbytes = 0
} }
// write passes b to the underlying writer and records the resulting error in
// w.err. If w.err is already non-nil the call is a no-op, so the first
// recorded error is preserved and no further data reaches the underlying
// writer.
func (w *huffmanBitWriter) write(b []byte) {
	if w.err == nil {
		_, w.err = w.writer.Write(b)
	}
}
func (w *huffmanBitWriter) writeBits(b int32, nb uint) { func (w *huffmanBitWriter) writeBits(b int32, nb uint) {
if w.err != nil {
return
}
w.bits |= uint64(b) << w.nbits w.bits |= uint64(b) << w.nbits
w.nbits += nb w.nbits += nb
if w.nbits >= 48 { if w.nbits >= 48 {
...@@ -150,7 +164,7 @@ func (w *huffmanBitWriter) writeBits(b int32, nb uint) { ...@@ -150,7 +164,7 @@ func (w *huffmanBitWriter) writeBits(b int32, nb uint) {
bytes[5] = byte(bits >> 40) bytes[5] = byte(bits >> 40)
n += 6 n += 6
if n >= bufferFlushSize { if n >= bufferFlushSize {
_, w.err = w.w.Write(w.bytes[:n]) w.write(w.bytes[:n])
n = 0 n = 0
} }
w.nbytes = n w.nbytes = n
...@@ -173,13 +187,10 @@ func (w *huffmanBitWriter) writeBytes(bytes []byte) { ...@@ -173,13 +187,10 @@ func (w *huffmanBitWriter) writeBytes(bytes []byte) {
n++ n++
} }
if n != 0 { if n != 0 {
_, w.err = w.w.Write(w.bytes[:n]) w.write(w.bytes[:n])
if w.err != nil {
return
}
} }
w.nbytes = 0 w.nbytes = 0
_, w.err = w.w.Write(bytes) w.write(bytes)
} }
// RFC 1951 3.2.7 specifies a special run-length encoding for specifying // RFC 1951 3.2.7 specifies a special run-length encoding for specifying
...@@ -341,7 +352,7 @@ func (w *huffmanBitWriter) writeCode(c hcode) { ...@@ -341,7 +352,7 @@ func (w *huffmanBitWriter) writeCode(c hcode) {
bytes[5] = byte(bits >> 40) bytes[5] = byte(bits >> 40)
n += 6 n += 6
if n >= bufferFlushSize { if n >= bufferFlushSize {
_, w.err = w.w.Write(w.bytes[:n]) w.write(w.bytes[:n])
n = 0 n = 0
} }
w.nbytes = n w.nbytes = n
...@@ -572,6 +583,9 @@ func (w *huffmanBitWriter) indexTokens(tokens []token) (numLiterals, numOffsets ...@@ -572,6 +583,9 @@ func (w *huffmanBitWriter) indexTokens(tokens []token) (numLiterals, numOffsets
// writeTokens writes a slice of tokens to the output. // writeTokens writes a slice of tokens to the output.
// codes for literal and offset encoding must be supplied. // codes for literal and offset encoding must be supplied.
func (w *huffmanBitWriter) writeTokens(tokens []token, leCodes, oeCodes []hcode) { func (w *huffmanBitWriter) writeTokens(tokens []token, leCodes, oeCodes []hcode) {
if w.err != nil {
return
}
for _, t := range tokens { for _, t := range tokens {
if t < matchType { if t < matchType {
w.writeCode(leCodes[t.literal()]) w.writeCode(leCodes[t.literal()])
...@@ -676,9 +690,9 @@ func (w *huffmanBitWriter) writeBlockHuff(eof bool, input []byte) { ...@@ -676,9 +690,9 @@ func (w *huffmanBitWriter) writeBlockHuff(eof bool, input []byte) {
if n < bufferFlushSize { if n < bufferFlushSize {
continue continue
} }
_, w.err = w.w.Write(w.bytes[:n]) w.write(w.bytes[:n])
if w.err != nil { if w.err != nil {
return return // Return early in the event of write failures
} }
n = 0 n = 0
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment