Commit 8bd9242f authored by Robert Griesemer's avatar Robert Griesemer

bufio: fix UnreadByte

Also:
- fix error messages in tests
- make tests more symmetric

Fixes #7607.

LGTM=r
R=r
CC=golang-codereviews
https://golang.org/cl/86180043
parent a8787cd8
...@@ -177,7 +177,7 @@ func (b *Reader) Read(p []byte) (n int, err error) { ...@@ -177,7 +177,7 @@ func (b *Reader) Read(p []byte) (n int, err error) {
// If no byte is available, returns an error. // If no byte is available, returns an error.
func (b *Reader) ReadByte() (c byte, err error) { func (b *Reader) ReadByte() (c byte, err error) {
b.lastRuneSize = -1 b.lastRuneSize = -1
for b.w == b.r { for b.r == b.w {
if b.err != nil { if b.err != nil {
return 0, b.readErr() return 0, b.readErr()
} }
...@@ -191,19 +191,19 @@ func (b *Reader) ReadByte() (c byte, err error) { ...@@ -191,19 +191,19 @@ func (b *Reader) ReadByte() (c byte, err error) {
// UnreadByte unreads the last byte. Only the most recently read byte can be unread. // UnreadByte unreads the last byte. Only the most recently read byte can be unread.
func (b *Reader) UnreadByte() error { func (b *Reader) UnreadByte() error {
b.lastRuneSize = -1 if b.lastByte < 0 || b.r == 0 && b.w > 0 {
if b.r == b.w && b.lastByte >= 0 {
b.w = 1
b.r = 0
b.buf[0] = byte(b.lastByte)
b.lastByte = -1
return nil
}
if b.r <= 0 {
return ErrInvalidUnreadByte return ErrInvalidUnreadByte
} }
// b.r > 0 || b.w == 0
if b.r > 0 {
b.r-- b.r--
} else {
// b.r == 0 && b.w == 0
b.w = 1
}
b.buf[b.r] = byte(b.lastByte)
b.lastByte = -1 b.lastByte = -1
b.lastRuneSize = -1
return nil return nil
} }
...@@ -233,7 +233,7 @@ func (b *Reader) ReadRune() (r rune, size int, err error) { ...@@ -233,7 +233,7 @@ func (b *Reader) ReadRune() (r rune, size int, err error) {
// regard it is stricter than UnreadByte, which will unread the last byte // regard it is stricter than UnreadByte, which will unread the last byte
// from any read operation.) // from any read operation.)
func (b *Reader) UnreadRune() error { func (b *Reader) UnreadRune() error {
if b.lastRuneSize < 0 || b.r == 0 { if b.lastRuneSize < 0 || b.r < b.lastRuneSize {
return ErrInvalidUnreadRune return ErrInvalidUnreadRune
} }
b.r -= b.lastRuneSize b.r -= b.lastRuneSize
......
...@@ -228,66 +228,94 @@ func TestReadRune(t *testing.T) { ...@@ -228,66 +228,94 @@ func TestReadRune(t *testing.T) {
} }
func TestUnreadRune(t *testing.T) { func TestUnreadRune(t *testing.T) {
got := ""
segments := []string{"Hello, world:", "日本語"} segments := []string{"Hello, world:", "日本語"}
data := strings.Join(segments, "")
r := NewReader(&StringReader{data: segments}) r := NewReader(&StringReader{data: segments})
got := ""
want := strings.Join(segments, "")
// Normal execution. // Normal execution.
for { for {
r1, _, err := r.ReadRune() r1, _, err := r.ReadRune()
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
t.Error("unexpected EOF") t.Error("unexpected error on ReadRune:", err)
} }
break break
} }
got += string(r1) got += string(r1)
// Put it back and read it again // Put it back and read it again.
if err = r.UnreadRune(); err != nil { if err = r.UnreadRune(); err != nil {
t.Error("unexpected error on UnreadRune:", err) t.Fatal("unexpected error on UnreadRune:", err)
} }
r2, _, err := r.ReadRune() r2, _, err := r.ReadRune()
if err != nil { if err != nil {
t.Error("unexpected error reading after unreading:", err) t.Fatal("unexpected error reading after unreading:", err)
} }
if r1 != r2 { if r1 != r2 {
t.Errorf("incorrect rune after unread: got %c wanted %c", r1, r2) t.Fatalf("incorrect rune after unread: got %c, want %c", r1, r2)
} }
} }
if got != data { if got != want {
t.Errorf("want=%q got=%q", data, got) t.Errorf("got %q, want %q", got, want)
} }
} }
func TestUnreadByte(t *testing.T) { func TestUnreadByte(t *testing.T) {
want := "Hello, world"
got := ""
segments := []string{"Hello, ", "world"} segments := []string{"Hello, ", "world"}
r := NewReader(&StringReader{data: segments}) r := NewReader(&StringReader{data: segments})
got := ""
want := strings.Join(segments, "")
// Normal execution. // Normal execution.
for { for {
b1, err := r.ReadByte() b1, err := r.ReadByte()
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
t.Fatal("unexpected EOF") t.Error("unexpected error on ReadByte:", err)
} }
break break
} }
got += string(b1) got += string(b1)
// Put it back and read it again // Put it back and read it again.
if err = r.UnreadByte(); err != nil { if err = r.UnreadByte(); err != nil {
t.Fatalf("unexpected error on UnreadByte: %v", err) t.Fatal("unexpected error on UnreadByte:", err)
} }
b2, err := r.ReadByte() b2, err := r.ReadByte()
if err != nil { if err != nil {
t.Fatalf("unexpected error reading after unreading: %v", err) t.Fatal("unexpected error reading after unreading:", err)
} }
if b1 != b2 { if b1 != b2 {
t.Fatalf("incorrect byte after unread: got %c wanted %c", b1, b2) t.Fatalf("incorrect byte after unread: got %q, want %q", b1, b2)
} }
} }
if got != want { if got != want {
t.Errorf("got=%q want=%q", got, want) t.Errorf("got %q, want %q", got, want)
}
}
func TestUnreadByteMultiple(t *testing.T) {
segments := []string{"Hello, ", "world"}
data := strings.Join(segments, "")
for n := 0; n <= len(data); n++ {
r := NewReader(&StringReader{data: segments})
// Read n bytes.
for i := 0; i < n; i++ {
b, err := r.ReadByte()
if err != nil {
t.Fatalf("n = %d: unexpected error on ReadByte: %v", n, err)
}
if b != data[i] {
t.Fatalf("n = %d: incorrect byte returned from ReadByte: got %q, want %q", n, b, data[i])
}
}
// Unread one byte if there is one.
if n > 0 {
if err := r.UnreadByte(); err != nil {
t.Errorf("n = %d: unexpected error on UnreadByte: %v", n, err)
}
}
// Test that we cannot unread any further.
if err := r.UnreadByte(); err == nil {
t.Errorf("n = %d: expected error on UnreadByte", n)
}
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment