Commit f5f0e40e authored by Rémy Oudompheng's avatar Rémy Oudompheng

compress/flate: implement Reset method on Writer.

Fixes #6138.

R=golang-dev, bradfitz
CC=golang-dev
https://golang.org/cl/12953048
parent 2fb8022e
......@@ -416,6 +416,50 @@ func (d *compressor) init(w io.Writer, level int) (err error) {
return nil
}
var zeroes [32]int
var bzeroes [256]byte
func (d *compressor) reset(w io.Writer) {
d.w.reset(w)
d.sync = false
d.err = nil
switch d.compressionLevel.chain {
case 0:
// level was NoCompression.
for i := range d.window {
d.window[i] = 0
}
d.windowEnd = 0
default:
d.chainHead = -1
for s := d.hashHead; len(s) > 0; {
n := copy(s, zeroes[:])
s = s[n:]
}
for s := d.hashPrev; len(s) > 0; s = s[len(zeroes):] {
copy(s, zeroes[:])
}
d.hashOffset = 1
d.index, d.windowEnd = 0, 0
for s := d.window; len(s) > 0; {
n := copy(s, bzeroes[:])
s = s[n:]
}
d.blockStart, d.byteAvailable = 0, false
d.tokens = d.tokens[:maxFlateBlockTokens+1]
for i := 0; i <= maxFlateBlockTokens; i++ {
d.tokens[i] = 0
}
d.tokens = d.tokens[:0]
d.length = minMatchLength - 1
d.offset = 0
d.hash = 0
d.maxInsertIndex = 0
}
}
func (d *compressor) close() error {
d.sync = true
d.step(d)
......@@ -439,7 +483,6 @@ func (d *compressor) close() error {
// If level is in the range [-1, 9] then the error returned will be nil.
// Otherwise the error returned will be non-nil.
func NewWriter(w io.Writer, level int) (*Writer, error) {
const logWindowSize = logMaxOffsetSize
var dw Writer
if err := dw.d.init(w, level); err != nil {
return nil, err
......@@ -462,6 +505,7 @@ func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) {
zw.Write(dict)
zw.Flush()
dw.enabled = true
zw.dict = append(zw.dict, dict...) // duplicate dictionary for Reset method.
return zw, err
}
......@@ -481,6 +525,7 @@ func (w *dictWriter) Write(b []byte) (n int, err error) {
// form of that data to an underlying writer (see NewWriter).
type Writer struct {
d compressor
dict []byte
}
// Write writes data to w, which will eventually write the
......@@ -506,3 +551,21 @@ func (w *Writer) Flush() error {
func (w *Writer) Close() error {
return w.d.close()
}
// Reset discards the writer's state and makes it equivalent to
// the result of NewWriter or NewWriterDict called with w
// and w's level and dictionary.
func (w *Writer) Reset(dst io.Writer) {
if dw, ok := w.d.w.w.(*dictWriter); ok {
// w was created with NewWriterDict
dw.w = dst
w.d.reset(dw)
dw.enabled = false
w.Write(w.dict)
w.Flush()
dw.enabled = true
} else {
// w was created with NewWriter
w.d.reset(dst)
}
}
......@@ -9,6 +9,7 @@ import (
"fmt"
"io"
"io/ioutil"
"reflect"
"sync"
"testing"
)
......@@ -424,3 +425,66 @@ func TestRegression2508(t *testing.T) {
}
w.Close()
}
func TestWriterReset(t *testing.T) {
for level := 0; level <= 9; level++ {
if testing.Short() && level > 1 {
break
}
w, err := NewWriter(ioutil.Discard, level)
if err != nil {
t.Fatalf("NewWriter: %v", err)
}
buf := []byte("hello world")
for i := 0; i < 1024; i++ {
w.Write(buf)
}
w.Reset(ioutil.Discard)
wref, err := NewWriter(ioutil.Discard, level)
if err != nil {
t.Fatalf("NewWriter: %v", err)
}
// DeepEqual doesn't compare functions.
w.d.fill, wref.d.fill = nil, nil
w.d.step, wref.d.step = nil, nil
if !reflect.DeepEqual(w, wref) {
t.Errorf("level %d Writer not reset after Reset", level)
}
}
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, NoCompression) })
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, DefaultCompression) })
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, BestCompression) })
dict := []byte("we are the world")
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, NoCompression, dict) })
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, DefaultCompression, dict) })
testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, BestCompression, dict) })
}
func testResetOutput(t *testing.T, newWriter func(w io.Writer) (*Writer, error)) {
buf := new(bytes.Buffer)
w, err := newWriter(buf)
if err != nil {
t.Fatalf("NewWriter: %v", err)
}
b := []byte("hello world")
for i := 0; i < 1024; i++ {
w.Write(b)
}
w.Close()
out1 := buf.String()
buf2 := new(bytes.Buffer)
w.Reset(buf2)
for i := 0; i < 1024; i++ {
w.Write(b)
}
w.Close()
out2 := buf2.String()
if out1 != out2 {
t.Errorf("got %q, expected %q", out2, out1)
}
t.Logf("got %d bytes", len(out1))
}
......@@ -97,6 +97,31 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
}
}
func (w *huffmanBitWriter) reset(writer io.Writer) {
w.w = writer
w.bits, w.nbits, w.nbytes, w.err = 0, 0, 0, nil
w.bytes = [64]byte{}
for i := range w.codegen {
w.codegen[i] = 0
}
for _, s := range [...][]int32{w.literalFreq, w.offsetFreq, w.codegenFreq} {
for i := range s {
s[i] = 0
}
}
for _, enc := range [...]*huffmanEncoder{
w.literalEncoding,
w.offsetEncoding,
w.codegenEncoding} {
for i := range enc.code {
enc.code[i] = 0
}
for i := range enc.codeBits {
enc.codeBits[i] = 0
}
}
}
func (w *huffmanBitWriter) flushBits() {
if w.err != nil {
w.nbits = 0
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment