Commit 2f2b6e55 authored by David Symonds's avatar David Symonds

encoding/base64: ignore new line characters during decode.

Fixes #2541.

R=golang-dev, r
CC=golang-dev
https://golang.org/cl/5610045
parent 804f1882
...@@ -208,22 +208,30 @@ func (e CorruptInputError) Error() string { ...@@ -208,22 +208,30 @@ func (e CorruptInputError) Error() string {
// decode is like Decode but returns an additional 'end' value, which // decode is like Decode but returns an additional 'end' value, which
// indicates if end-of-message padding was encountered and thus any // indicates if end-of-message padding was encountered and thus any
// additional data is an error. decode also assumes len(src)%4==0, // additional data is an error.
// since it is meant for internal use.
func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
for i := 0; i < len(src)/4 && !end; i++ { osrc := src
for len(src) > 0 && !end {
// Decode quantum using the base64 alphabet // Decode quantum using the base64 alphabet
var dbuf [4]byte var dbuf [4]byte
dlen := 4 dlen := 4
dbufloop: dbufloop:
for j := 0; j < 4; j++ { for j := 0; j < 4; {
in := src[i*4+j] if len(src) == 0 {
if in == '=' && j >= 2 && i == len(src)/4-1 { return n, false, CorruptInputError(len(osrc) - len(src) - j)
}
in := src[0]
src = src[1:]
if in == '\r' || in == '\n' {
// Ignore this character.
continue
}
if in == '=' && j >= 2 && len(src) < 4 {
// We've reached the end and there's // We've reached the end and there's
// padding // padding
if src[i*4+3] != '=' { if len(src) > 0 && src[0] != '=' {
return n, false, CorruptInputError(i*4 + 2) return n, false, CorruptInputError(len(osrc) - len(src) - 1)
} }
dlen = j dlen = j
end = true end = true
...@@ -231,22 +239,24 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { ...@@ -231,22 +239,24 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
} }
dbuf[j] = enc.decodeMap[in] dbuf[j] = enc.decodeMap[in]
if dbuf[j] == 0xFF { if dbuf[j] == 0xFF {
return n, false, CorruptInputError(i*4 + j) return n, false, CorruptInputError(len(osrc) - len(src) - 1)
} }
j++
} }
// Pack 4x 6-bit source blocks into 3 byte destination // Pack 4x 6-bit source blocks into 3 byte destination
// quantum // quantum
switch dlen { switch dlen {
case 4: case 4:
dst[i*3+2] = dbuf[2]<<6 | dbuf[3] dst[2] = dbuf[2]<<6 | dbuf[3]
fallthrough fallthrough
case 3: case 3:
dst[i*3+1] = dbuf[1]<<4 | dbuf[2]>>2 dst[1] = dbuf[1]<<4 | dbuf[2]>>2
fallthrough fallthrough
case 2: case 2:
dst[i*3+0] = dbuf[0]<<2 | dbuf[1]>>4 dst[0] = dbuf[0]<<2 | dbuf[1]>>4
} }
dst = dst[3:]
n += dlen - 1 n += dlen - 1
} }
...@@ -257,11 +267,8 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { ...@@ -257,11 +267,8 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
// DecodedLen(len(src)) bytes to dst and returns the number of bytes // DecodedLen(len(src)) bytes to dst and returns the number of bytes
// written. If src contains invalid base64 data, it will return the // written. If src contains invalid base64 data, it will return the
// number of bytes successfully written and CorruptInputError. // number of bytes successfully written and CorruptInputError.
// New line characters (\r and \n) are ignored.
func (enc *Encoding) Decode(dst, src []byte) (n int, err error) { func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
if len(src)%4 != 0 {
return 0, CorruptInputError(len(src) / 4 * 4)
}
n, _, err = enc.decode(dst, src) n, _, err = enc.decode(dst, src)
return return
} }
......
...@@ -197,3 +197,29 @@ func TestBig(t *testing.T) { ...@@ -197,3 +197,29 @@ func TestBig(t *testing.T) {
t.Errorf("Decode(Encode(%d-byte string)) failed at offset %d", n, i) t.Errorf("Decode(Encode(%d-byte string)) failed at offset %d", n, i)
} }
} }
func TestNewLineCharacters(t *testing.T) {
// Each of these should decode to the string "sure", without errors.
const expected = "sure"
examples := []string{
"c3VyZQ==",
"c3VyZQ==\r",
"c3VyZQ==\n",
"c3VyZQ==\r\n",
"c3VyZ\r\nQ==",
"c3V\ryZ\nQ==",
"c3V\nyZ\rQ==",
"c3VyZ\nQ==",
"c3VyZQ\n==",
}
for _, e := range examples {
buf, err := StdEncoding.DecodeString(e)
if err != nil {
t.Errorf("Decode(%q) failed: %v", e, err)
continue
}
if s := string(buf); s != expected {
t.Errorf("Decode(%q) = %q, want %q", e, s, expected)
}
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment