Commit f5f0e40e authored by Rémy Oudompheng's avatar Rémy Oudompheng

compress/flate: implement Reset method on Writer.

Fixes #6138.

R=golang-dev, bradfitz
CC=golang-dev
https://golang.org/cl/12953048
parent 2fb8022e
...@@ -416,6 +416,50 @@ func (d *compressor) init(w io.Writer, level int) (err error) { ...@@ -416,6 +416,50 @@ func (d *compressor) init(w io.Writer, level int) (err error) {
return nil return nil
} }
var zeroes [32]int
var bzeroes [256]byte
func (d *compressor) reset(w io.Writer) {
d.w.reset(w)
d.sync = false
d.err = nil
switch d.compressionLevel.chain {
case 0:
// level was NoCompression.
for i := range d.window {
d.window[i] = 0
}
d.windowEnd = 0
default:
d.chainHead = -1
for s := d.hashHead; len(s) > 0; {
n := copy(s, zeroes[:])
s = s[n:]
}
for s := d.hashPrev; len(s) > 0; s = s[len(zeroes):] {
copy(s, zeroes[:])
}
d.hashOffset = 1
d.index, d.windowEnd = 0, 0
for s := d.window; len(s) > 0; {
n := copy(s, bzeroes[:])
s = s[n:]
}
d.blockStart, d.byteAvailable = 0, false
d.tokens = d.tokens[:maxFlateBlockTokens+1]
for i := 0; i <= maxFlateBlockTokens; i++ {
d.tokens[i] = 0
}
d.tokens = d.tokens[:0]
d.length = minMatchLength - 1
d.offset = 0
d.hash = 0
d.maxInsertIndex = 0
}
}
func (d *compressor) close() error { func (d *compressor) close() error {
d.sync = true d.sync = true
d.step(d) d.step(d)
...@@ -439,7 +483,6 @@ func (d *compressor) close() error { ...@@ -439,7 +483,6 @@ func (d *compressor) close() error {
// If level is in the range [-1, 9] then the error returned will be nil. // If level is in the range [-1, 9] then the error returned will be nil.
// Otherwise the error returned will be non-nil. // Otherwise the error returned will be non-nil.
func NewWriter(w io.Writer, level int) (*Writer, error) { func NewWriter(w io.Writer, level int) (*Writer, error) {
const logWindowSize = logMaxOffsetSize
var dw Writer var dw Writer
if err := dw.d.init(w, level); err != nil { if err := dw.d.init(w, level); err != nil {
return nil, err return nil, err
...@@ -462,6 +505,7 @@ func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) { ...@@ -462,6 +505,7 @@ func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) {
zw.Write(dict) zw.Write(dict)
zw.Flush() zw.Flush()
dw.enabled = true dw.enabled = true
zw.dict = append(zw.dict, dict...) // duplicate dictionary for Reset method.
return zw, err return zw, err
} }
...@@ -480,7 +524,8 @@ func (w *dictWriter) Write(b []byte) (n int, err error) { ...@@ -480,7 +524,8 @@ func (w *dictWriter) Write(b []byte) (n int, err error) {
// A Writer takes data written to it and writes the compressed // A Writer takes data written to it and writes the compressed
// form of that data to an underlying writer (see NewWriter). // form of that data to an underlying writer (see NewWriter).
type Writer struct { type Writer struct {
d compressor d compressor
dict []byte
} }
// Write writes data to w, which will eventually write the // Write writes data to w, which will eventually write the
...@@ -506,3 +551,21 @@ func (w *Writer) Flush() error { ...@@ -506,3 +551,21 @@ func (w *Writer) Flush() error {
func (w *Writer) Close() error { func (w *Writer) Close() error {
return w.d.close() return w.d.close()
} }
// Reset discards the writer's state and makes it equivalent to
// the result of NewWriter or NewWriterDict called with w
// and w's level and dictionary.
func (w *Writer) Reset(dst io.Writer) {
if dw, ok := w.d.w.w.(*dictWriter); ok {
// w was created with NewWriterDict
dw.w = dst
w.d.reset(dw)
dw.enabled = false
w.Write(w.dict)
w.Flush()
dw.enabled = true
} else {
// w was created with NewWriter
w.d.reset(dst)
}
}
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"reflect"
"sync" "sync"
"testing" "testing"
) )
...@@ -424,3 +425,66 @@ func TestRegression2508(t *testing.T) { ...@@ -424,3 +425,66 @@ func TestRegression2508(t *testing.T) {
} }
w.Close() w.Close()
} }
func TestWriterReset(t *testing.T) {
for level := 0; level <= 9; level++ {
if testing.Short() && level > 1 {
break
}
w, err := NewWriter(ioutil.Discard, level)
if err != nil {
t.Fatalf("NewWriter: %v", err)
}
buf := []byte("hello world")
for i := 0; i < 1024; i++ {
w.Write(buf)
}
w.Reset(ioutil.Discard)
wref, err := NewWriter(ioutil.Discard, level)
if err != nil {
t.Fatalf("NewWriter: %v", err)
}
// DeepEqual doesn't compare functions.
w.d.fill, wref.d.fill = nil, nil
w.d.step, wref.d.step = nil, nil
if !reflect.DeepEqual(w, wref) {
t.Errorf("level %d Writer not reset after Reset", level)
}
}
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, NoCompression) })
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, DefaultCompression) })
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, BestCompression) })
dict := []byte("we are the world")
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, NoCompression, dict) })
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, DefaultCompression, dict) })
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, BestCompression, dict) })
}
func testResetOutput(t *testing.T, newWriter func(w io.Writer) (*Writer, error)) {
buf := new(bytes.Buffer)
w, err := newWriter(buf)
if err != nil {
t.Fatalf("NewWriter: %v", err)
}
b := []byte("hello world")
for i := 0; i < 1024; i++ {
w.Write(b)
}
w.Close()
out1 := buf.String()
buf2 := new(bytes.Buffer)
w.Reset(buf2)
for i := 0; i < 1024; i++ {
w.Write(b)
}
w.Close()
out2 := buf2.String()
if out1 != out2 {
t.Errorf("got %q, expected %q", out2, out1)
}
t.Logf("got %d bytes", len(out1))
}
...@@ -97,6 +97,31 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter { ...@@ -97,6 +97,31 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
} }
} }
func (w *huffmanBitWriter) reset(writer io.Writer) {
w.w = writer
w.bits, w.nbits, w.nbytes, w.err = 0, 0, 0, nil
w.bytes = [64]byte{}
for i := range w.codegen {
w.codegen[i] = 0
}
for _, s := range [...][]int32{w.literalFreq, w.offsetFreq, w.codegenFreq} {
for i := range s {
s[i] = 0
}
}
for _, enc := range [...]*huffmanEncoder{
w.literalEncoding,
w.offsetEncoding,
w.codegenEncoding} {
for i := range enc.code {
enc.code[i] = 0
}
for i := range enc.codeBits {
enc.codeBits[i] = 0
}
}
}
func (w *huffmanBitWriter) flushBits() { func (w *huffmanBitWriter) flushBits() {
if w.err != nil { if w.err != nil {
w.nbits = 0 w.nbits = 0
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment