Commit a509026f authored by Rui Ueyama's avatar Rui Ueyama Committed by Brad Fitzpatrick

strings, bytes: fix Reader.UnreadRune

UnreadRune should return an error if previous operation is not
ReadRune.

Fixes #7579.

LGTM=bradfitz
R=golang-codereviews, bradfitz
CC=golang-codereviews
https://golang.org/cl/77580046
parent 8a908efd
...@@ -30,6 +30,7 @@ func (r *Reader) Len() int { ...@@ -30,6 +30,7 @@ func (r *Reader) Len() int {
} }
func (r *Reader) Read(b []byte) (n int, err error) { func (r *Reader) Read(b []byte) (n int, err error) {
r.prevRune = -1
if len(b) == 0 { if len(b) == 0 {
return 0, nil return 0, nil
} }
...@@ -38,11 +39,11 @@ func (r *Reader) Read(b []byte) (n int, err error) { ...@@ -38,11 +39,11 @@ func (r *Reader) Read(b []byte) (n int, err error) {
} }
n = copy(b, r.s[r.i:]) n = copy(b, r.s[r.i:])
r.i += n r.i += n
r.prevRune = -1
return return
} }
func (r *Reader) ReadAt(b []byte, off int64) (n int, err error) { func (r *Reader) ReadAt(b []byte, off int64) (n int, err error) {
r.prevRune = -1
if off < 0 { if off < 0 {
return 0, errors.New("bytes: invalid offset") return 0, errors.New("bytes: invalid offset")
} }
...@@ -57,12 +58,12 @@ func (r *Reader) ReadAt(b []byte, off int64) (n int, err error) { ...@@ -57,12 +58,12 @@ func (r *Reader) ReadAt(b []byte, off int64) (n int, err error) {
} }
func (r *Reader) ReadByte() (b byte, err error) { func (r *Reader) ReadByte() (b byte, err error) {
r.prevRune = -1
if r.i >= len(r.s) { if r.i >= len(r.s) {
return 0, io.EOF return 0, io.EOF
} }
b = r.s[r.i] b = r.s[r.i]
r.i++ r.i++
r.prevRune = -1
return return
} }
...@@ -77,6 +78,7 @@ func (r *Reader) UnreadByte() error { ...@@ -77,6 +78,7 @@ func (r *Reader) UnreadByte() error {
func (r *Reader) ReadRune() (ch rune, size int, err error) { func (r *Reader) ReadRune() (ch rune, size int, err error) {
if r.i >= len(r.s) { if r.i >= len(r.s) {
r.prevRune = -1
return 0, 0, io.EOF return 0, 0, io.EOF
} }
r.prevRune = r.i r.prevRune = r.i
...@@ -100,6 +102,7 @@ func (r *Reader) UnreadRune() error { ...@@ -100,6 +102,7 @@ func (r *Reader) UnreadRune() error {
// Seek implements the io.Seeker interface. // Seek implements the io.Seeker interface.
func (r *Reader) Seek(offset int64, whence int) (int64, error) { func (r *Reader) Seek(offset int64, whence int) (int64, error) {
r.prevRune = -1
var abs int64 var abs int64
switch whence { switch whence {
case 0: case 0:
......
...@@ -133,6 +133,33 @@ func TestReaderLen(t *testing.T) { ...@@ -133,6 +133,33 @@ func TestReaderLen(t *testing.T) {
} }
} }
var UnreadRuneErrorTests = []struct {
name string
f func(*Reader)
}{
{"Read", func(r *Reader) { r.Read([]byte{}) }},
{"ReadAt", func(r *Reader) { r.ReadAt([]byte{}, 0) }},
{"ReadByte", func(r *Reader) { r.ReadByte() }},
{"UnreadRune", func(r *Reader) { r.UnreadRune() }},
{"Seek", func(r *Reader) { r.Seek(0, 1) }},
{"WriteTo", func(r *Reader) { r.WriteTo(&Buffer{}) }},
}
func TestUnreadRuneError(t *testing.T) {
for _, tt := range UnreadRuneErrorTests {
reader := NewReader([]byte("0123456789"))
if _, _, err := reader.ReadRune(); err != nil {
// should not happen
t.Fatal(err)
}
tt.f(reader)
err := reader.UnreadRune()
if err == nil {
t.Errorf("Unreading after %s: expected error", tt.name)
}
}
}
func TestReaderDoubleUnreadRune(t *testing.T) { func TestReaderDoubleUnreadRune(t *testing.T) {
buf := NewBuffer([]byte("groucho")) buf := NewBuffer([]byte("groucho"))
if _, _, err := buf.ReadRune(); err != nil { if _, _, err := buf.ReadRune(); err != nil {
......
...@@ -29,6 +29,7 @@ func (r *Reader) Len() int { ...@@ -29,6 +29,7 @@ func (r *Reader) Len() int {
} }
func (r *Reader) Read(b []byte) (n int, err error) { func (r *Reader) Read(b []byte) (n int, err error) {
r.prevRune = -1
if len(b) == 0 { if len(b) == 0 {
return 0, nil return 0, nil
} }
...@@ -37,11 +38,11 @@ func (r *Reader) Read(b []byte) (n int, err error) { ...@@ -37,11 +38,11 @@ func (r *Reader) Read(b []byte) (n int, err error) {
} }
n = copy(b, r.s[r.i:]) n = copy(b, r.s[r.i:])
r.i += n r.i += n
r.prevRune = -1
return return
} }
func (r *Reader) ReadAt(b []byte, off int64) (n int, err error) { func (r *Reader) ReadAt(b []byte, off int64) (n int, err error) {
r.prevRune = -1
if off < 0 { if off < 0 {
return 0, errors.New("strings: invalid offset") return 0, errors.New("strings: invalid offset")
} }
...@@ -56,12 +57,12 @@ func (r *Reader) ReadAt(b []byte, off int64) (n int, err error) { ...@@ -56,12 +57,12 @@ func (r *Reader) ReadAt(b []byte, off int64) (n int, err error) {
} }
func (r *Reader) ReadByte() (b byte, err error) { func (r *Reader) ReadByte() (b byte, err error) {
r.prevRune = -1
if r.i >= len(r.s) { if r.i >= len(r.s) {
return 0, io.EOF return 0, io.EOF
} }
b = r.s[r.i] b = r.s[r.i]
r.i++ r.i++
r.prevRune = -1
return return
} }
...@@ -76,6 +77,7 @@ func (r *Reader) UnreadByte() error { ...@@ -76,6 +77,7 @@ func (r *Reader) UnreadByte() error {
func (r *Reader) ReadRune() (ch rune, size int, err error) { func (r *Reader) ReadRune() (ch rune, size int, err error) {
if r.i >= len(r.s) { if r.i >= len(r.s) {
r.prevRune = -1
return 0, 0, io.EOF return 0, 0, io.EOF
} }
r.prevRune = r.i r.prevRune = r.i
...@@ -99,6 +101,7 @@ func (r *Reader) UnreadRune() error { ...@@ -99,6 +101,7 @@ func (r *Reader) UnreadRune() error {
// Seek implements the io.Seeker interface. // Seek implements the io.Seeker interface.
func (r *Reader) Seek(offset int64, whence int) (int64, error) { func (r *Reader) Seek(offset int64, whence int) (int64, error) {
r.prevRune = -1
var abs int64 var abs int64
switch whence { switch whence {
case 0: case 0:
......
...@@ -858,6 +858,33 @@ func TestReadRune(t *testing.T) { ...@@ -858,6 +858,33 @@ func TestReadRune(t *testing.T) {
} }
} }
var UnreadRuneErrorTests = []struct {
name string
f func(*Reader)
}{
{"Read", func(r *Reader) { r.Read([]byte{}) }},
{"ReadAt", func(r *Reader) { r.ReadAt([]byte{}, 0) }},
{"ReadByte", func(r *Reader) { r.ReadByte() }},
{"UnreadRune", func(r *Reader) { r.UnreadRune() }},
{"Seek", func(r *Reader) { r.Seek(0, 1) }},
{"WriteTo", func(r *Reader) { r.WriteTo(&bytes.Buffer{}) }},
}
func TestUnreadRuneError(t *testing.T) {
for _, tt := range UnreadRuneErrorTests {
reader := NewReader("0123456789")
if _, _, err := reader.ReadRune(); err != nil {
// should not happen
t.Fatal(err)
}
tt.f(reader)
err := reader.UnreadRune()
if err == nil {
t.Errorf("Unreading after %s: expected error", tt.name)
}
}
}
var ReplaceTests = []struct { var ReplaceTests = []struct {
in string in string
old, new string old, new string
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment