Commit 642a1cc7 authored by Nigel Tao's avatar Nigel Tao

compress/lzw: fix hi code overflow.

Change-Id: I2d3c3c715d857305944cd96c45554a16cb7967e9
Reviewed-on: https://go-review.googlesource.com/42032Reviewed-by: default avatarDavid Symonds <dsymonds@golang.org>
parent 4fcceca1
...@@ -57,8 +57,14 @@ type decoder struct { ...@@ -57,8 +57,14 @@ type decoder struct {
// The next two codes mean clear and EOF. // The next two codes mean clear and EOF.
// Other valid codes are in the range [lo, hi] where lo := clear + 2, // Other valid codes are in the range [lo, hi] where lo := clear + 2,
// with the upper bound incrementing on each code seen. // with the upper bound incrementing on each code seen.
// overflow is the code at which hi overflows the code width. //
// overflow is the code at which hi overflows the code width. It always
// equals 1 << width.
//
// last is the most recently seen code, or decoderInvalidCode. // last is the most recently seen code, or decoderInvalidCode.
//
// An invariant is that
// (hi < overflow) || (hi == overflow && last == decoderInvalidCode)
clear, eof, hi, overflow, last uint16 clear, eof, hi, overflow, last uint16
// Each code c in [lo, hi] expands to two or more bytes. For c != hi: // Each code c in [lo, hi] expands to two or more bytes. For c != hi:
...@@ -196,6 +202,10 @@ loop: ...@@ -196,6 +202,10 @@ loop:
if d.hi >= d.overflow { if d.hi >= d.overflow {
if d.width == maxWidth { if d.width == maxWidth {
d.last = decoderInvalidCode d.last = decoderInvalidCode
// Undo the d.hi++ a few lines above, so that (1) we maintain
// the invariant that d.hi <= d.overflow, and (2) d.hi does not
// eventually overflow a uint16.
d.hi--
} else { } else {
d.width++ d.width++
d.overflow <<= 1 d.overflow <<= 1
......
...@@ -120,6 +120,32 @@ func TestReader(t *testing.T) { ...@@ -120,6 +120,32 @@ func TestReader(t *testing.T) {
} }
} }
type devZero struct{}
func (devZero) Read(p []byte) (int, error) {
for i := range p {
p[i] = 0
}
return len(p), nil
}
func TestHiCodeDoesNotOverflow(t *testing.T) {
r := NewReader(devZero{}, LSB, 8)
d := r.(*decoder)
buf := make([]byte, 1024)
oldHi := uint16(0)
for i := 0; i < 100; i++ {
if _, err := io.ReadFull(r, buf); err != nil {
t.Fatalf("i=%d: %v", i, err)
}
// The hi code should never decrease.
if d.hi < oldHi {
t.Fatalf("i=%d: hi=%d decreased from previous value %d", i, d.hi, oldHi)
}
oldHi = d.hi
}
}
func BenchmarkDecoder(b *testing.B) { func BenchmarkDecoder(b *testing.B) {
buf, err := ioutil.ReadFile("../testdata/e.txt") buf, err := ioutil.ReadFile("../testdata/e.txt")
if err != nil { if err != nil {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment