Commit 9d2de778 authored by Gustav Westling's avatar Gustav Westling Committed by Brad Fitzpatrick

encoding/base32: support custom and disabled padding when decoding

CL 38634 added support for custom (and disabled) padding characters
when encoding, but didn't update the decoding paths. This adds
decoding support.

Fixes #20854

Change-Id: I9fb1a0aaebb27f1204c9f726a780d5784eb71024
Reviewed-on: https://go-review.googlesource.com/47341
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent f3b5a2bc
...@@ -279,19 +279,28 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { ...@@ -279,19 +279,28 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
dlen := 8 dlen := 8
for j := 0; j < 8; { for j := 0; j < 8; {
if len(src) == 0 {
// We have reached the end and are missing padding
if len(src) == 0 && enc.padChar != NoPadding {
return n, false, CorruptInputError(olen - len(src) - j) return n, false, CorruptInputError(olen - len(src) - j)
} }
// We have reached the end and are not expecing any padding
if len(src) == 0 && enc.padChar == NoPadding {
dlen, end = j, true
break
}
in := src[0] in := src[0]
src = src[1:] src = src[1:]
if in == '=' && j >= 2 && len(src) < 8 { if in == byte(enc.padChar) && j >= 2 && len(src) < 8 {
// We've reached the end and there's padding // We've reached the end and there's padding
if len(src)+j < 8-1 { if len(src)+j < 8-1 {
// not enough padding // not enough padding
return n, false, CorruptInputError(olen) return n, false, CorruptInputError(olen)
} }
for k := 0; k < 8-1-j; k++ { for k := 0; k < 8-1-j; k++ {
if len(src) > k && src[k] != '=' { if len(src) > k && src[k] != byte(enc.padChar) {
// incorrect padding // incorrect padding
return n, false, CorruptInputError(olen - len(src) + k - 1) return n, false, CorruptInputError(olen - len(src) + k - 1)
} }
...@@ -484,4 +493,11 @@ func NewDecoder(enc *Encoding, r io.Reader) io.Reader { ...@@ -484,4 +493,11 @@ func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
// DecodedLen returns the maximum length in bytes of the decoded data // DecodedLen returns the maximum length in bytes of the decoded data
// corresponding to n bytes of base32-encoded data. // corresponding to n bytes of base32-encoded data.
func (enc *Encoding) DecodedLen(n int) int { return n / 8 * 5 } func (enc *Encoding) DecodedLen(n int) int {
if enc.padChar == NoPadding {
// +6 represents the missing padding
return (n + 6) / 8 * 5
}
return n / 8 * 5
}
...@@ -490,3 +490,91 @@ func TestWithoutPadding(t *testing.T) { ...@@ -490,3 +490,91 @@ func TestWithoutPadding(t *testing.T) {
} }
} }
} }
func TestDecodeWithPadding(t *testing.T) {
encodings := []*Encoding{
StdEncoding,
StdEncoding.WithPadding('-'),
StdEncoding.WithPadding(NoPadding),
}
for i, enc := range encodings {
for _, pair := range pairs {
input := pair.decoded
encoded := enc.EncodeToString([]byte(input))
decoded, err := enc.DecodeString(encoded)
if err != nil {
t.Errorf("DecodeString Error for encoding %d (%q): %v", i, input, err)
}
if input != string(decoded) {
t.Errorf("Unexpected result for encoding %d: got %q; want %q", i, decoded, input)
}
}
}
}
func TestDecodeWithWrongPadding(t *testing.T) {
encoded := StdEncoding.EncodeToString([]byte("foobar"))
_, err := StdEncoding.WithPadding('-').DecodeString(encoded)
if err == nil {
t.Error("expected error")
}
_, err = StdEncoding.WithPadding(NoPadding).DecodeString(encoded)
if err == nil {
t.Error("expected error")
}
}
func TestEncodedDecodedLen(t *testing.T) {
type test struct {
in int
wantEnc int
wantDec int
}
data := bytes.Repeat([]byte("x"), 100)
for _, test := range []struct {
name string
enc *Encoding
cases []test
}{
{"StdEncoding", StdEncoding, []test{
{0, 0, 0},
{1, 8, 5},
{5, 8, 5},
{6, 16, 10},
{10, 16, 10},
}},
{"NoPadding", StdEncoding.WithPadding(NoPadding), []test{
{0, 0, 0},
{1, 2, 5},
{2, 4, 5},
{5, 8, 5},
{6, 10, 10},
{7, 12, 10},
{10, 16, 10},
{11, 18, 15},
}},
} {
t.Run(test.name, func(t *testing.T) {
for _, tc := range test.cases {
encLen := test.enc.EncodedLen(tc.in)
decLen := test.enc.DecodedLen(encLen)
enc := test.enc.EncodeToString(data[:tc.in])
if len(enc) != encLen {
t.Fatalf("EncodedLen(%d) = %d but encoded to %q (%d)", tc.in, encLen, enc, len(enc))
}
if encLen != tc.wantEnc {
t.Fatalf("EncodedLen(%d) = %d; want %d", tc.in, encLen, tc.wantEnc)
}
if decLen != tc.wantDec {
t.Fatalf("DecodedLen(%d) = %d; want %d", encLen, decLen, tc.wantDec)
}
}
})
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment