Commit f0745651 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

bytes, strings: allow Reader.Seek past 1<<31

Fixes #7654

LGTM=rsc
R=rsc, dan.kortschak
CC=golang-codereviews
https://golang.org/cl/81530043
parent e150ca9c
...@@ -16,17 +16,17 @@ import ( ...@@ -16,17 +16,17 @@ import (
// Unlike a Buffer, a Reader is read-only and supports seeking. // Unlike a Buffer, a Reader is read-only and supports seeking.
type Reader struct { type Reader struct {
s []byte s []byte
i int // current reading index i int64 // current reading index
prevRune int // index of previous rune; or < 0 prevRune int // index of previous rune; or < 0
} }
// Len returns the number of bytes of the unread portion of the // Len returns the number of bytes of the unread portion of the
// slice. // slice.
func (r *Reader) Len() int { func (r *Reader) Len() int {
if r.i >= len(r.s) { if r.i >= int64(len(r.s)) {
return 0 return 0
} }
return len(r.s) - r.i return int(int64(len(r.s)) - r.i)
} }
func (r *Reader) Read(b []byte) (n int, err error) { func (r *Reader) Read(b []byte) (n int, err error) {
...@@ -34,11 +34,11 @@ func (r *Reader) Read(b []byte) (n int, err error) { ...@@ -34,11 +34,11 @@ func (r *Reader) Read(b []byte) (n int, err error) {
if len(b) == 0 { if len(b) == 0 {
return 0, nil return 0, nil
} }
if r.i >= len(r.s) { if r.i >= int64(len(r.s)) {
return 0, io.EOF return 0, io.EOF
} }
n = copy(b, r.s[r.i:]) n = copy(b, r.s[r.i:])
r.i += n r.i += int64(n)
return return
} }
...@@ -49,7 +49,7 @@ func (r *Reader) ReadAt(b []byte, off int64) (n int, err error) { ...@@ -49,7 +49,7 @@ func (r *Reader) ReadAt(b []byte, off int64) (n int, err error) {
if off >= int64(len(r.s)) { if off >= int64(len(r.s)) {
return 0, io.EOF return 0, io.EOF
} }
n = copy(b, r.s[int(off):]) n = copy(b, r.s[off:])
if n < len(b) { if n < len(b) {
err = io.EOF err = io.EOF
} }
...@@ -58,7 +58,7 @@ func (r *Reader) ReadAt(b []byte, off int64) (n int, err error) { ...@@ -58,7 +58,7 @@ func (r *Reader) ReadAt(b []byte, off int64) (n int, err error) {
func (r *Reader) ReadByte() (b byte, err error) { func (r *Reader) ReadByte() (b byte, err error) {
r.prevRune = -1 r.prevRune = -1
if r.i >= len(r.s) { if r.i >= int64(len(r.s)) {
return 0, io.EOF return 0, io.EOF
} }
b = r.s[r.i] b = r.s[r.i]
...@@ -76,17 +76,17 @@ func (r *Reader) UnreadByte() error { ...@@ -76,17 +76,17 @@ func (r *Reader) UnreadByte() error {
} }
func (r *Reader) ReadRune() (ch rune, size int, err error) { func (r *Reader) ReadRune() (ch rune, size int, err error) {
if r.i >= len(r.s) { if r.i >= int64(len(r.s)) {
r.prevRune = -1 r.prevRune = -1
return 0, 0, io.EOF return 0, 0, io.EOF
} }
r.prevRune = r.i r.prevRune = int(r.i)
if c := r.s[r.i]; c < utf8.RuneSelf { if c := r.s[r.i]; c < utf8.RuneSelf {
r.i++ r.i++
return rune(c), 1, nil return rune(c), 1, nil
} }
ch, size = utf8.DecodeRune(r.s[r.i:]) ch, size = utf8.DecodeRune(r.s[r.i:])
r.i += size r.i += int64(size)
return return
} }
...@@ -94,7 +94,7 @@ func (r *Reader) UnreadRune() error { ...@@ -94,7 +94,7 @@ func (r *Reader) UnreadRune() error {
if r.prevRune < 0 { if r.prevRune < 0 {
return errors.New("bytes.Reader: previous operation was not ReadRune") return errors.New("bytes.Reader: previous operation was not ReadRune")
} }
r.i = r.prevRune r.i = int64(r.prevRune)
r.prevRune = -1 r.prevRune = -1
return nil return nil
} }
...@@ -116,17 +116,14 @@ func (r *Reader) Seek(offset int64, whence int) (int64, error) { ...@@ -116,17 +116,14 @@ func (r *Reader) Seek(offset int64, whence int) (int64, error) {
if abs < 0 { if abs < 0 {
return 0, errors.New("bytes: negative position") return 0, errors.New("bytes: negative position")
} }
if abs >= 1<<31 { r.i = abs
return 0, errors.New("bytes: position out of range")
}
r.i = int(abs)
return abs, nil return abs, nil
} }
// WriteTo implements the io.WriterTo interface. // WriteTo implements the io.WriterTo interface.
func (r *Reader) WriteTo(w io.Writer) (n int64, err error) { func (r *Reader) WriteTo(w io.Writer) (n int64, err error) {
r.prevRune = -1 r.prevRune = -1
if r.i >= len(r.s) { if r.i >= int64(len(r.s)) {
return 0, nil return 0, nil
} }
b := r.s[r.i:] b := r.s[r.i:]
...@@ -134,7 +131,7 @@ func (r *Reader) WriteTo(w io.Writer) (n int64, err error) { ...@@ -134,7 +131,7 @@ func (r *Reader) WriteTo(w io.Writer) (n int64, err error) {
if m > len(b) { if m > len(b) {
panic("bytes.Reader.WriteTo: invalid Write count") panic("bytes.Reader.WriteTo: invalid Write count")
} }
r.i += m r.i += int64(m)
n = int64(m) n = int64(m)
if m != len(b) && err == nil { if m != len(b) && err == nil {
err = io.ErrShortWrite err = io.ErrShortWrite
......
...@@ -27,8 +27,8 @@ func TestReader(t *testing.T) { ...@@ -27,8 +27,8 @@ func TestReader(t *testing.T) {
{seek: os.SEEK_SET, off: 1, n: 1, want: "1"}, {seek: os.SEEK_SET, off: 1, n: 1, want: "1"},
{seek: os.SEEK_CUR, off: 1, wantpos: 3, n: 2, want: "34"}, {seek: os.SEEK_CUR, off: 1, wantpos: 3, n: 2, want: "34"},
{seek: os.SEEK_SET, off: -1, seekerr: "bytes: negative position"}, {seek: os.SEEK_SET, off: -1, seekerr: "bytes: negative position"},
{seek: os.SEEK_SET, off: 1<<31 - 1}, {seek: os.SEEK_SET, off: 1 << 33, wantpos: 1 << 33},
{seek: os.SEEK_CUR, off: 1, seekerr: "bytes: position out of range"}, {seek: os.SEEK_CUR, off: 1, wantpos: 1<<33 + 1},
{seek: os.SEEK_SET, n: 5, want: "01234"}, {seek: os.SEEK_SET, n: 5, want: "01234"},
{seek: os.SEEK_CUR, n: 5, want: "56789"}, {seek: os.SEEK_CUR, n: 5, want: "56789"},
{seek: os.SEEK_END, off: -1, n: 1, wantpos: 9, want: "9"}, {seek: os.SEEK_END, off: -1, n: 1, wantpos: 9, want: "9"},
...@@ -60,6 +60,16 @@ func TestReader(t *testing.T) { ...@@ -60,6 +60,16 @@ func TestReader(t *testing.T) {
} }
} }
func TestReadAfterBigSeek(t *testing.T) {
r := NewReader([]byte("0123456789"))
if _, err := r.Seek(1<<31+5, os.SEEK_SET); err != nil {
t.Fatal(err)
}
if n, err := r.Read(make([]byte, 10)); n != 0 || err != io.EOF {
t.Errorf("Read = %d, %v; want 0, EOF", n, err)
}
}
func TestReaderAt(t *testing.T) { func TestReaderAt(t *testing.T) {
r := NewReader([]byte("0123456789")) r := NewReader([]byte("0123456789"))
tests := []struct { tests := []struct {
......
...@@ -15,17 +15,17 @@ import ( ...@@ -15,17 +15,17 @@ import (
// from a string. // from a string.
type Reader struct { type Reader struct {
s string s string
i int // current reading index i int64 // current reading index
prevRune int // index of previous rune; or < 0 prevRune int // index of previous rune; or < 0
} }
// Len returns the number of bytes of the unread portion of the // Len returns the number of bytes of the unread portion of the
// string. // string.
func (r *Reader) Len() int { func (r *Reader) Len() int {
if r.i >= len(r.s) { if r.i >= int64(len(r.s)) {
return 0 return 0
} }
return len(r.s) - r.i return int(int64(len(r.s)) - r.i)
} }
func (r *Reader) Read(b []byte) (n int, err error) { func (r *Reader) Read(b []byte) (n int, err error) {
...@@ -33,11 +33,11 @@ func (r *Reader) Read(b []byte) (n int, err error) { ...@@ -33,11 +33,11 @@ func (r *Reader) Read(b []byte) (n int, err error) {
if len(b) == 0 { if len(b) == 0 {
return 0, nil return 0, nil
} }
if r.i >= len(r.s) { if r.i >= int64(len(r.s)) {
return 0, io.EOF return 0, io.EOF
} }
n = copy(b, r.s[r.i:]) n = copy(b, r.s[r.i:])
r.i += n r.i += int64(n)
return return
} }
...@@ -48,7 +48,7 @@ func (r *Reader) ReadAt(b []byte, off int64) (n int, err error) { ...@@ -48,7 +48,7 @@ func (r *Reader) ReadAt(b []byte, off int64) (n int, err error) {
if off >= int64(len(r.s)) { if off >= int64(len(r.s)) {
return 0, io.EOF return 0, io.EOF
} }
n = copy(b, r.s[int(off):]) n = copy(b, r.s[off:])
if n < len(b) { if n < len(b) {
err = io.EOF err = io.EOF
} }
...@@ -57,7 +57,7 @@ func (r *Reader) ReadAt(b []byte, off int64) (n int, err error) { ...@@ -57,7 +57,7 @@ func (r *Reader) ReadAt(b []byte, off int64) (n int, err error) {
func (r *Reader) ReadByte() (b byte, err error) { func (r *Reader) ReadByte() (b byte, err error) {
r.prevRune = -1 r.prevRune = -1
if r.i >= len(r.s) { if r.i >= int64(len(r.s)) {
return 0, io.EOF return 0, io.EOF
} }
b = r.s[r.i] b = r.s[r.i]
...@@ -75,17 +75,17 @@ func (r *Reader) UnreadByte() error { ...@@ -75,17 +75,17 @@ func (r *Reader) UnreadByte() error {
} }
func (r *Reader) ReadRune() (ch rune, size int, err error) { func (r *Reader) ReadRune() (ch rune, size int, err error) {
if r.i >= len(r.s) { if r.i >= int64(len(r.s)) {
r.prevRune = -1 r.prevRune = -1
return 0, 0, io.EOF return 0, 0, io.EOF
} }
r.prevRune = r.i r.prevRune = int(r.i)
if c := r.s[r.i]; c < utf8.RuneSelf { if c := r.s[r.i]; c < utf8.RuneSelf {
r.i++ r.i++
return rune(c), 1, nil return rune(c), 1, nil
} }
ch, size = utf8.DecodeRuneInString(r.s[r.i:]) ch, size = utf8.DecodeRuneInString(r.s[r.i:])
r.i += size r.i += int64(size)
return return
} }
...@@ -93,7 +93,7 @@ func (r *Reader) UnreadRune() error { ...@@ -93,7 +93,7 @@ func (r *Reader) UnreadRune() error {
if r.prevRune < 0 { if r.prevRune < 0 {
return errors.New("strings.Reader: previous operation was not ReadRune") return errors.New("strings.Reader: previous operation was not ReadRune")
} }
r.i = r.prevRune r.i = int64(r.prevRune)
r.prevRune = -1 r.prevRune = -1
return nil return nil
} }
...@@ -115,17 +115,14 @@ func (r *Reader) Seek(offset int64, whence int) (int64, error) { ...@@ -115,17 +115,14 @@ func (r *Reader) Seek(offset int64, whence int) (int64, error) {
if abs < 0 { if abs < 0 {
return 0, errors.New("strings: negative position") return 0, errors.New("strings: negative position")
} }
if abs >= 1<<31 { r.i = abs
return 0, errors.New("strings: position out of range")
}
r.i = int(abs)
return abs, nil return abs, nil
} }
// WriteTo implements the io.WriterTo interface. // WriteTo implements the io.WriterTo interface.
func (r *Reader) WriteTo(w io.Writer) (n int64, err error) { func (r *Reader) WriteTo(w io.Writer) (n int64, err error) {
r.prevRune = -1 r.prevRune = -1
if r.i >= len(r.s) { if r.i >= int64(len(r.s)) {
return 0, nil return 0, nil
} }
s := r.s[r.i:] s := r.s[r.i:]
...@@ -133,7 +130,7 @@ func (r *Reader) WriteTo(w io.Writer) (n int64, err error) { ...@@ -133,7 +130,7 @@ func (r *Reader) WriteTo(w io.Writer) (n int64, err error) {
if m > len(s) { if m > len(s) {
panic("strings.Reader.WriteTo: invalid WriteString count") panic("strings.Reader.WriteTo: invalid WriteString count")
} }
r.i += m r.i += int64(m)
n = int64(m) n = int64(m)
if m != len(s) && err == nil { if m != len(s) && err == nil {
err = io.ErrShortWrite err = io.ErrShortWrite
......
...@@ -27,8 +27,8 @@ func TestReader(t *testing.T) { ...@@ -27,8 +27,8 @@ func TestReader(t *testing.T) {
{seek: os.SEEK_SET, off: 1, n: 1, want: "1"}, {seek: os.SEEK_SET, off: 1, n: 1, want: "1"},
{seek: os.SEEK_CUR, off: 1, wantpos: 3, n: 2, want: "34"}, {seek: os.SEEK_CUR, off: 1, wantpos: 3, n: 2, want: "34"},
{seek: os.SEEK_SET, off: -1, seekerr: "strings: negative position"}, {seek: os.SEEK_SET, off: -1, seekerr: "strings: negative position"},
{seek: os.SEEK_SET, off: 1<<31 - 1}, {seek: os.SEEK_SET, off: 1 << 33, wantpos: 1 << 33},
{seek: os.SEEK_CUR, off: 1, seekerr: "strings: position out of range"}, {seek: os.SEEK_CUR, off: 1, wantpos: 1<<33 + 1},
{seek: os.SEEK_SET, n: 5, want: "01234"}, {seek: os.SEEK_SET, n: 5, want: "01234"},
{seek: os.SEEK_CUR, n: 5, want: "56789"}, {seek: os.SEEK_CUR, n: 5, want: "56789"},
{seek: os.SEEK_END, off: -1, n: 1, wantpos: 9, want: "9"}, {seek: os.SEEK_END, off: -1, n: 1, wantpos: 9, want: "9"},
...@@ -60,6 +60,16 @@ func TestReader(t *testing.T) { ...@@ -60,6 +60,16 @@ func TestReader(t *testing.T) {
} }
} }
func TestReadAfterBigSeek(t *testing.T) {
r := strings.NewReader("0123456789")
if _, err := r.Seek(1<<31+5, os.SEEK_SET); err != nil {
t.Fatal(err)
}
if n, err := r.Read(make([]byte, 10)); n != 0 || err != io.EOF {
t.Errorf("Read = %d, %v; want 0, EOF", n, err)
}
}
func TestReaderAt(t *testing.T) { func TestReaderAt(t *testing.T) {
r := strings.NewReader("0123456789") r := strings.NewReader("0123456789")
tests := []struct { tests := []struct {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment