Commit ae4aac00 authored by Hiroshi Ioka's avatar Hiroshi Ioka Committed by Brad Fitzpatrick

encoding/asn1: reduce allocations in Marshal

Current code uses trees of bytes.Buffer as data representation.
Each bytes.Buffer takes 4k bytes at least, so it's waste of memory.
The change introduces trees of lazy-encoder as
alternative one which reduce allocations.

name       old time/op    new time/op    delta
Marshal-4    64.7µs ± 2%    42.0µs ± 1%  -35.07%   (p=0.000 n=9+10)

name       old alloc/op   new alloc/op   delta
Marshal-4    35.1kB ± 0%     7.6kB ± 0%  -78.27%  (p=0.000 n=10+10)

name       old allocs/op  new allocs/op  delta
Marshal-4       503 ± 0%       293 ± 0%  -41.75%  (p=0.000 n=10+10)

Change-Id: I32b96c20b8df00414b282d69743d71a598a11336
Reviewed-on: https://go-review.googlesource.com/27030Reviewed-by: default avatarAdam Langley <agl@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Adam Langley <agl@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent ee3f3a60
...@@ -132,9 +132,9 @@ func TestParseBigInt(t *testing.T) { ...@@ -132,9 +132,9 @@ func TestParseBigInt(t *testing.T) {
if ret.String() != test.base10 { if ret.String() != test.base10 {
t.Errorf("#%d: bad result from %x, got %s want %s", i, test.in, ret.String(), test.base10) t.Errorf("#%d: bad result from %x, got %s want %s", i, test.in, ret.String(), test.base10)
} }
fw := newForkableWriter() e := makeBigInt(ret)
marshalBigInt(fw, ret) result := make([]byte, e.Len())
result := fw.Bytes() e.Encode(result)
if !bytes.Equal(result, test.in) { if !bytes.Equal(result, test.in) {
t.Errorf("#%d: got %x from marshaling %s, want %x", i, result, ret, test.in) t.Errorf("#%d: got %x from marshaling %s, want %x", i, result, ret, test.in)
} }
......
...@@ -5,77 +5,125 @@ ...@@ -5,77 +5,125 @@
package asn1 package asn1
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"io"
"math/big" "math/big"
"reflect" "reflect"
"time" "time"
"unicode/utf8" "unicode/utf8"
) )
// A forkableWriter is an in-memory buffer that can be var (
// 'forked' to create new forkableWriters that bracket the byte00Encoder encoder = byteEncoder(0x00)
// original. After byteFFEncoder encoder = byteEncoder(0xff)
// pre, post := w.fork() )
// the overall sequence of bytes represented is logically w+pre+post.
type forkableWriter struct { // encoder represents a ASN.1 element that is waiting to be marshaled.
*bytes.Buffer type encoder interface {
pre, post *forkableWriter // Len returns the number of bytes needed to marshal this element.
Len() int
// Encode encodes this element by writing Len() bytes to dst.
Encode(dst []byte)
}
type byteEncoder byte
func (c byteEncoder) Len() int {
return 1
}
func (c byteEncoder) Encode(dst []byte) {
dst[0] = byte(c)
} }
func newForkableWriter() *forkableWriter { type bytesEncoder []byte
return &forkableWriter{new(bytes.Buffer), nil, nil}
func (b bytesEncoder) Len() int {
return len(b)
} }
func (f *forkableWriter) fork() (pre, post *forkableWriter) { func (b bytesEncoder) Encode(dst []byte) {
if f.pre != nil || f.post != nil { if copy(dst, b) != len(b) {
panic("have already forked") panic("internal error")
} }
f.pre = newForkableWriter()
f.post = newForkableWriter()
return f.pre, f.post
} }
func (f *forkableWriter) Len() (l int) { type stringEncoder string
l += f.Buffer.Len()
if f.pre != nil { func (s stringEncoder) Len() int {
l += f.pre.Len() return len(s)
}
func (s stringEncoder) Encode(dst []byte) {
if copy(dst, s) != len(s) {
panic("internal error")
} }
if f.post != nil { }
l += f.post.Len()
type multiEncoder []encoder
func (m multiEncoder) Len() int {
var size int
for _, e := range m {
size += e.Len()
} }
return return size
} }
func (f *forkableWriter) writeTo(out io.Writer) (n int, err error) { func (m multiEncoder) Encode(dst []byte) {
n, err = out.Write(f.Bytes()) var off int
if err != nil { for _, e := range m {
return e.Encode(dst[off:])
off += e.Len()
} }
}
var nn int type taggedEncoder struct {
// scratch contains temporary space for encoding the tag and length of
// an element in order to avoid extra allocations.
scratch [8]byte
tag encoder
body encoder
}
if f.pre != nil { func (t *taggedEncoder) Len() int {
nn, err = f.pre.writeTo(out) return t.tag.Len() + t.body.Len()
n += nn }
if err != nil {
return func (t *taggedEncoder) Encode(dst []byte) {
t.tag.Encode(dst)
t.body.Encode(dst[t.tag.Len():])
}
type int64Encoder int64
func (i int64Encoder) Len() int {
n := 1
for i > 127 {
n++
i >>= 8
} }
for i < -128 {
n++
i >>= 8
} }
if f.post != nil { return n
nn, err = f.post.writeTo(out) }
n += nn
func (i int64Encoder) Encode(dst []byte) {
n := i.Len()
for j := 0; j < n; j++ {
dst[j] = byte(i >> uint((n-1-j)*8))
} }
return
} }
func marshalBase128Int(out *forkableWriter, n int64) (err error) { func base128IntLength(n int64) int {
if n == 0 { if n == 0 {
err = out.WriteByte(0) return 1
return
} }
l := 0 l := 0
...@@ -83,54 +131,29 @@ func marshalBase128Int(out *forkableWriter, n int64) (err error) { ...@@ -83,54 +131,29 @@ func marshalBase128Int(out *forkableWriter, n int64) (err error) {
l++ l++
} }
return l
}
func appendBase128Int(dst []byte, n int64) []byte {
l := base128IntLength(n)
for i := l - 1; i >= 0; i-- { for i := l - 1; i >= 0; i-- {
o := byte(n >> uint(i*7)) o := byte(n >> uint(i*7))
o &= 0x7f o &= 0x7f
if i != 0 { if i != 0 {
o |= 0x80 o |= 0x80
} }
err = out.WriteByte(o)
if err != nil {
return
}
}
return nil
}
func marshalInt64(out *forkableWriter, i int64) (err error) {
n := int64Length(i)
for ; n > 0; n-- {
err = out.WriteByte(byte(i >> uint((n-1)*8)))
if err != nil {
return
}
}
return nil
}
func int64Length(i int64) (numBytes int) {
numBytes = 1
for i > 127 {
numBytes++
i >>= 8
}
for i < -128 { dst = append(dst, o)
numBytes++
i >>= 8
} }
return return dst
} }
func marshalBigInt(out *forkableWriter, n *big.Int) (err error) { func makeBigInt(n *big.Int) encoder {
if n.Sign() < 0 { if n.Sign() < 0 {
// A negative number has to be converted to two's-complement // A negative number has to be converted to two's-complement
// form. So we'll subtract 1 and invert. If the // form. So we'll invert and subtract 1. If the
// most-significant-bit isn't set then we'll need to pad the // most-significant-bit isn't set then we'll need to pad the
// beginning with 0xff in order to keep the number negative. // beginning with 0xff in order to keep the number negative.
nMinus1 := new(big.Int).Neg(n) nMinus1 := new(big.Int).Neg(n)
...@@ -140,41 +163,31 @@ func marshalBigInt(out *forkableWriter, n *big.Int) (err error) { ...@@ -140,41 +163,31 @@ func marshalBigInt(out *forkableWriter, n *big.Int) (err error) {
bytes[i] ^= 0xff bytes[i] ^= 0xff
} }
if len(bytes) == 0 || bytes[0]&0x80 == 0 { if len(bytes) == 0 || bytes[0]&0x80 == 0 {
err = out.WriteByte(0xff) return multiEncoder([]encoder{byteFFEncoder, bytesEncoder(bytes)})
if err != nil {
return
}
} }
_, err = out.Write(bytes) return bytesEncoder(bytes)
} else if n.Sign() == 0 { } else if n.Sign() == 0 {
// Zero is written as a single 0 zero rather than no bytes. // Zero is written as a single 0 zero rather than no bytes.
err = out.WriteByte(0x00) return byte00Encoder
} else { } else {
bytes := n.Bytes() bytes := n.Bytes()
if len(bytes) > 0 && bytes[0]&0x80 != 0 { if len(bytes) > 0 && bytes[0]&0x80 != 0 {
// We'll have to pad this with 0x00 in order to stop it // We'll have to pad this with 0x00 in order to stop it
// looking like a negative number. // looking like a negative number.
err = out.WriteByte(0) return multiEncoder([]encoder{byte00Encoder, bytesEncoder(bytes)})
if err != nil {
return
}
} }
_, err = out.Write(bytes) return bytesEncoder(bytes)
} }
return
} }
func marshalLength(out *forkableWriter, i int) (err error) { func appendLength(dst []byte, i int) []byte {
n := lengthLength(i) n := lengthLength(i)
for ; n > 0; n-- { for ; n > 0; n-- {
err = out.WriteByte(byte(i >> uint((n-1)*8))) dst = append(dst, byte(i>>uint((n-1)*8)))
if err != nil {
return
}
} }
return nil return dst
} }
func lengthLength(i int) (numBytes int) { func lengthLength(i int) (numBytes int) {
...@@ -186,123 +199,104 @@ func lengthLength(i int) (numBytes int) { ...@@ -186,123 +199,104 @@ func lengthLength(i int) (numBytes int) {
return return
} }
func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err error) { func appendTagAndLength(dst []byte, t tagAndLength) []byte {
b := uint8(t.class) << 6 b := uint8(t.class) << 6
if t.isCompound { if t.isCompound {
b |= 0x20 b |= 0x20
} }
if t.tag >= 31 { if t.tag >= 31 {
b |= 0x1f b |= 0x1f
err = out.WriteByte(b) dst = append(dst, b)
if err != nil { dst = appendBase128Int(dst, int64(t.tag))
return
}
err = marshalBase128Int(out, int64(t.tag))
if err != nil {
return
}
} else { } else {
b |= uint8(t.tag) b |= uint8(t.tag)
err = out.WriteByte(b) dst = append(dst, b)
if err != nil {
return
}
} }
if t.length >= 128 { if t.length >= 128 {
l := lengthLength(t.length) l := lengthLength(t.length)
err = out.WriteByte(0x80 | byte(l)) dst = append(dst, 0x80|byte(l))
if err != nil { dst = appendLength(dst, t.length)
return
}
err = marshalLength(out, t.length)
if err != nil {
return
}
} else { } else {
err = out.WriteByte(byte(t.length)) dst = append(dst, byte(t.length))
if err != nil {
return
}
} }
return nil return dst
} }
func marshalBitString(out *forkableWriter, b BitString) (err error) { type bitStringEncoder BitString
paddingBits := byte((8 - b.BitLength%8) % 8)
err = out.WriteByte(paddingBits) func (b bitStringEncoder) Len() int {
if err != nil { return len(b.Bytes) + 1
return
}
_, err = out.Write(b.Bytes)
return
} }
func marshalObjectIdentifier(out *forkableWriter, oid []int) (err error) { func (b bitStringEncoder) Encode(dst []byte) {
if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) { dst[0] = byte((8 - b.BitLength%8) % 8)
return StructuralError{"invalid object identifier"} if copy(dst[1:], b.Bytes) != len(b.Bytes) {
panic("internal error")
} }
}
err = marshalBase128Int(out, int64(oid[0]*40+oid[1])) type oidEncoder []int
if err != nil {
return func (oid oidEncoder) Len() int {
l := base128IntLength(int64(oid[0]*40 + oid[1]))
for i := 2; i < len(oid); i++ {
l += base128IntLength(int64(oid[i]))
} }
return l
}
func (oid oidEncoder) Encode(dst []byte) {
dst = appendBase128Int(dst[:0], int64(oid[0]*40+oid[1]))
for i := 2; i < len(oid); i++ { for i := 2; i < len(oid); i++ {
err = marshalBase128Int(out, int64(oid[i])) dst = appendBase128Int(dst, int64(oid[i]))
if err != nil {
return
} }
}
func makeObjectIdentifier(oid []int) (e encoder, err error) {
if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
return nil, StructuralError{"invalid object identifier"}
} }
return return oidEncoder(oid), nil
} }
func marshalPrintableString(out *forkableWriter, s string) (err error) { func makePrintableString(s string) (e encoder, err error) {
b := []byte(s) for i := 0; i < len(s); i++ {
for _, c := range b { if !isPrintable(s[i]) {
if !isPrintable(c) { return nil, StructuralError{"PrintableString contains invalid character"}
return StructuralError{"PrintableString contains invalid character"}
} }
} }
_, err = out.Write(b) return stringEncoder(s), nil
return
} }
func marshalIA5String(out *forkableWriter, s string) (err error) { func makeIA5String(s string) (e encoder, err error) {
b := []byte(s) for i := 0; i < len(s); i++ {
for _, c := range b { if s[i] > 127 {
if c > 127 { return nil, StructuralError{"IA5String contains invalid character"}
return StructuralError{"IA5String contains invalid character"}
} }
} }
_, err = out.Write(b) return stringEncoder(s), nil
return
} }
func marshalUTF8String(out *forkableWriter, s string) (err error) { func makeUTF8String(s string) encoder {
_, err = out.Write([]byte(s)) return stringEncoder(s)
return
} }
func marshalTwoDigits(out *forkableWriter, v int) (err error) { func appendTwoDigits(dst []byte, v int) []byte {
err = out.WriteByte(byte('0' + (v/10)%10)) return append(dst, byte('0'+(v/10)%10), byte('0'+v%10))
if err != nil {
return
}
return out.WriteByte(byte('0' + v%10))
} }
func marshalFourDigits(out *forkableWriter, v int) (err error) { func appendFourDigits(dst []byte, v int) []byte {
var bytes [4]byte var bytes [4]byte
for i := range bytes { for i := range bytes {
bytes[3-i] = '0' + byte(v%10) bytes[3-i] = '0' + byte(v%10)
v /= 10 v /= 10
} }
_, err = out.Write(bytes[:]) return append(dst, bytes[:]...)
return
} }
func outsideUTCRange(t time.Time) bool { func outsideUTCRange(t time.Time) bool {
...@@ -310,80 +304,75 @@ func outsideUTCRange(t time.Time) bool { ...@@ -310,80 +304,75 @@ func outsideUTCRange(t time.Time) bool {
return year < 1950 || year >= 2050 return year < 1950 || year >= 2050
} }
func marshalUTCTime(out *forkableWriter, t time.Time) (err error) { func makeUTCTime(t time.Time) (e encoder, err error) {
dst := make([]byte, 0, 18)
dst, err = appendUTCTime(dst, t)
if err != nil {
return nil, err
}
return bytesEncoder(dst), nil
}
func makeGeneralizedTime(t time.Time) (e encoder, err error) {
dst := make([]byte, 0, 20)
dst, err = appendGeneralizedTime(dst, t)
if err != nil {
return nil, err
}
return bytesEncoder(dst), nil
}
func appendUTCTime(dst []byte, t time.Time) (ret []byte, err error) {
year := t.Year() year := t.Year()
switch { switch {
case 1950 <= year && year < 2000: case 1950 <= year && year < 2000:
err = marshalTwoDigits(out, year-1900) dst = appendTwoDigits(dst, year-1900)
case 2000 <= year && year < 2050: case 2000 <= year && year < 2050:
err = marshalTwoDigits(out, year-2000) dst = appendTwoDigits(dst, year-2000)
default: default:
return StructuralError{"cannot represent time as UTCTime"} return nil, StructuralError{"cannot represent time as UTCTime"}
}
if err != nil {
return
} }
return marshalTimeCommon(out, t) return appendTimeCommon(dst, t), nil
} }
func marshalGeneralizedTime(out *forkableWriter, t time.Time) (err error) { func appendGeneralizedTime(dst []byte, t time.Time) (ret []byte, err error) {
year := t.Year() year := t.Year()
if year < 0 || year > 9999 { if year < 0 || year > 9999 {
return StructuralError{"cannot represent time as GeneralizedTime"} return nil, StructuralError{"cannot represent time as GeneralizedTime"}
}
if err = marshalFourDigits(out, year); err != nil {
return
} }
return marshalTimeCommon(out, t) dst = appendFourDigits(dst, year)
return appendTimeCommon(dst, t), nil
} }
func marshalTimeCommon(out *forkableWriter, t time.Time) (err error) { func appendTimeCommon(dst []byte, t time.Time) []byte {
_, month, day := t.Date() _, month, day := t.Date()
err = marshalTwoDigits(out, int(month)) dst = appendTwoDigits(dst, int(month))
if err != nil { dst = appendTwoDigits(dst, day)
return
}
err = marshalTwoDigits(out, day)
if err != nil {
return
}
hour, min, sec := t.Clock() hour, min, sec := t.Clock()
err = marshalTwoDigits(out, hour) dst = appendTwoDigits(dst, hour)
if err != nil { dst = appendTwoDigits(dst, min)
return dst = appendTwoDigits(dst, sec)
}
err = marshalTwoDigits(out, min)
if err != nil {
return
}
err = marshalTwoDigits(out, sec)
if err != nil {
return
}
_, offset := t.Zone() _, offset := t.Zone()
switch { switch {
case offset/60 == 0: case offset/60 == 0:
err = out.WriteByte('Z') return append(dst, 'Z')
return
case offset > 0: case offset > 0:
err = out.WriteByte('+') dst = append(dst, '+')
case offset < 0: case offset < 0:
err = out.WriteByte('-') dst = append(dst, '-')
}
if err != nil {
return
} }
offsetMinutes := offset / 60 offsetMinutes := offset / 60
...@@ -391,13 +380,10 @@ func marshalTimeCommon(out *forkableWriter, t time.Time) (err error) { ...@@ -391,13 +380,10 @@ func marshalTimeCommon(out *forkableWriter, t time.Time) (err error) {
offsetMinutes = -offsetMinutes offsetMinutes = -offsetMinutes
} }
err = marshalTwoDigits(out, offsetMinutes/60) dst = appendTwoDigits(dst, offsetMinutes/60)
if err != nil { dst = appendTwoDigits(dst, offsetMinutes%60)
return
}
err = marshalTwoDigits(out, offsetMinutes%60) return dst
return
} }
func stripTagAndLength(in []byte) []byte { func stripTagAndLength(in []byte) []byte {
...@@ -408,114 +394,124 @@ func stripTagAndLength(in []byte) []byte { ...@@ -408,114 +394,124 @@ func stripTagAndLength(in []byte) []byte {
return in[offset:] return in[offset:]
} }
func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameters) (err error) { func makeBody(value reflect.Value, params fieldParameters) (e encoder, err error) {
switch value.Type() { switch value.Type() {
case flagType: case flagType:
return nil return bytesEncoder(nil), nil
case timeType: case timeType:
t := value.Interface().(time.Time) t := value.Interface().(time.Time)
if params.timeType == TagGeneralizedTime || outsideUTCRange(t) { if params.timeType == TagGeneralizedTime || outsideUTCRange(t) {
return marshalGeneralizedTime(out, t) return makeGeneralizedTime(t)
} else {
return marshalUTCTime(out, t)
} }
return makeUTCTime(t)
case bitStringType: case bitStringType:
return marshalBitString(out, value.Interface().(BitString)) return bitStringEncoder(value.Interface().(BitString)), nil
case objectIdentifierType: case objectIdentifierType:
return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier)) return makeObjectIdentifier(value.Interface().(ObjectIdentifier))
case bigIntType: case bigIntType:
return marshalBigInt(out, value.Interface().(*big.Int)) return makeBigInt(value.Interface().(*big.Int)), nil
} }
switch v := value; v.Kind() { switch v := value; v.Kind() {
case reflect.Bool: case reflect.Bool:
if v.Bool() { if v.Bool() {
return out.WriteByte(255) return byteFFEncoder, nil
} else {
return out.WriteByte(0)
} }
return byte00Encoder, nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return marshalInt64(out, v.Int()) return int64Encoder(v.Int()), nil
case reflect.Struct: case reflect.Struct:
t := v.Type() t := v.Type()
startingField := 0 startingField := 0
n := t.NumField()
if n == 0 {
return bytesEncoder(nil), nil
}
// If the first element of the structure is a non-empty // If the first element of the structure is a non-empty
// RawContents, then we don't bother serializing the rest. // RawContents, then we don't bother serializing the rest.
if t.NumField() > 0 && t.Field(0).Type == rawContentsType { if t.Field(0).Type == rawContentsType {
s := v.Field(0) s := v.Field(0)
if s.Len() > 0 { if s.Len() > 0 {
bytes := make([]byte, s.Len()) bytes := s.Bytes()
for i := 0; i < s.Len(); i++ {
bytes[i] = uint8(s.Index(i).Uint())
}
/* The RawContents will contain the tag and /* The RawContents will contain the tag and
* length fields but we'll also be writing * length fields but we'll also be writing
* those ourselves, so we strip them out of * those ourselves, so we strip them out of
* bytes */ * bytes */
_, err = out.Write(stripTagAndLength(bytes)) return bytesEncoder(stripTagAndLength(bytes)), nil
return
} else {
startingField = 1
} }
startingField = 1
} }
for i := startingField; i < t.NumField(); i++ { switch n1 := n - startingField; n1 {
var pre *forkableWriter case 0:
pre, out = out.fork() return bytesEncoder(nil), nil
err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag.Get("asn1"))) case 1:
return makeField(v.Field(startingField), parseFieldParameters(t.Field(startingField).Tag.Get("asn1")))
default:
m := make([]encoder, n1)
for i := 0; i < n1; i++ {
m[i], err = makeField(v.Field(i+startingField), parseFieldParameters(t.Field(i+startingField).Tag.Get("asn1")))
if err != nil { if err != nil {
return return nil, err
} }
} }
return
return multiEncoder(m), nil
}
case reflect.Slice: case reflect.Slice:
sliceType := v.Type() sliceType := v.Type()
if sliceType.Elem().Kind() == reflect.Uint8 { if sliceType.Elem().Kind() == reflect.Uint8 {
bytes := make([]byte, v.Len()) return bytesEncoder(v.Bytes()), nil
for i := 0; i < v.Len(); i++ {
bytes[i] = uint8(v.Index(i).Uint())
}
_, err = out.Write(bytes)
return
} }
var fp fieldParameters var fp fieldParameters
for i := 0; i < v.Len(); i++ {
var pre *forkableWriter switch l := v.Len(); l {
pre, out = out.fork() case 0:
err = marshalField(pre, v.Index(i), fp) return bytesEncoder(nil), nil
case 1:
return makeField(v.Index(0), fp)
default:
m := make([]encoder, l)
for i := 0; i < l; i++ {
m[i], err = makeField(v.Index(i), fp)
if err != nil { if err != nil {
return return nil, err
} }
} }
return
return multiEncoder(m), nil
}
case reflect.String: case reflect.String:
switch params.stringType { switch params.stringType {
case TagIA5String: case TagIA5String:
return marshalIA5String(out, v.String()) return makeIA5String(v.String())
case TagPrintableString: case TagPrintableString:
return marshalPrintableString(out, v.String()) return makePrintableString(v.String())
default: default:
return marshalUTF8String(out, v.String()) return makeUTF8String(v.String()), nil
} }
} }
return StructuralError{"unknown Go type"} return nil, StructuralError{"unknown Go type"}
} }
func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) (err error) { func makeField(v reflect.Value, params fieldParameters) (e encoder, err error) {
if !v.IsValid() { if !v.IsValid() {
return fmt.Errorf("asn1: cannot marshal nil value") return nil, fmt.Errorf("asn1: cannot marshal nil value")
} }
// If the field is an interface{} then recurse into it. // If the field is an interface{} then recurse into it.
if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 { if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 {
return marshalField(out, v.Elem(), params) return makeField(v.Elem(), params)
} }
if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty { if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
return return bytesEncoder(nil), nil
} }
if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) { if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) {
...@@ -523,7 +519,7 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) ...@@ -523,7 +519,7 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
defaultValue.SetInt(*params.defaultValue) defaultValue.SetInt(*params.defaultValue)
if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) { if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) {
return return bytesEncoder(nil), nil
} }
} }
...@@ -532,37 +528,36 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) ...@@ -532,37 +528,36 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
// behaviour, but it's what Go has traditionally done. // behaviour, but it's what Go has traditionally done.
if params.optional && params.defaultValue == nil { if params.optional && params.defaultValue == nil {
if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) { if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
return return bytesEncoder(nil), nil
} }
} }
if v.Type() == rawValueType { if v.Type() == rawValueType {
rv := v.Interface().(RawValue) rv := v.Interface().(RawValue)
if len(rv.FullBytes) != 0 { if len(rv.FullBytes) != 0 {
_, err = out.Write(rv.FullBytes) return bytesEncoder(rv.FullBytes), nil
} else {
err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound})
if err != nil {
return
} }
_, err = out.Write(rv.Bytes)
} t := new(taggedEncoder)
return
t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound}))
t.body = bytesEncoder(rv.Bytes)
return t, nil
} }
tag, isCompound, ok := getUniversalType(v.Type()) tag, isCompound, ok := getUniversalType(v.Type())
if !ok { if !ok {
err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())} return nil, StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
return
} }
class := ClassUniversal class := ClassUniversal
if params.timeType != 0 && tag != TagUTCTime { if params.timeType != 0 && tag != TagUTCTime {
return StructuralError{"explicit time type given to non-time member"} return nil, StructuralError{"explicit time type given to non-time member"}
} }
if params.stringType != 0 && tag != TagPrintableString { if params.stringType != 0 && tag != TagPrintableString {
return StructuralError{"explicit string type given to non-string member"} return nil, StructuralError{"explicit string type given to non-string member"}
} }
switch tag { switch tag {
...@@ -574,7 +569,7 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) ...@@ -574,7 +569,7 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
for _, r := range v.String() { for _, r := range v.String() {
if r >= utf8.RuneSelf || !isPrintable(byte(r)) { if r >= utf8.RuneSelf || !isPrintable(byte(r)) {
if !utf8.ValidString(v.String()) { if !utf8.ValidString(v.String()) {
return errors.New("asn1: string not valid UTF-8") return nil, errors.New("asn1: string not valid UTF-8")
} }
tag = TagUTF8String tag = TagUTF8String
break break
...@@ -591,46 +586,46 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) ...@@ -591,46 +586,46 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
if params.set { if params.set {
if tag != TagSequence { if tag != TagSequence {
return StructuralError{"non sequence tagged as set"} return nil, StructuralError{"non sequence tagged as set"}
} }
tag = TagSet tag = TagSet
} }
tags, body := out.fork() t := new(taggedEncoder)
err = marshalBody(body, v, params) t.body, err = makeBody(v, params)
if err != nil { if err != nil {
return return nil, err
} }
bodyLen := body.Len() bodyLen := t.body.Len()
var explicitTag *forkableWriter
if params.explicit { if params.explicit {
explicitTag, tags = tags.fork() t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{class, tag, bodyLen, isCompound}))
}
if !params.explicit && params.tag != nil { tt := new(taggedEncoder)
// implicit tag.
tag = *params.tag
class = ClassContextSpecific
}
err = marshalTagAndLength(tags, tagAndLength{class, tag, bodyLen, isCompound}) tt.body = t
if err != nil {
return
}
if params.explicit { tt.tag = bytesEncoder(appendTagAndLength(tt.scratch[:0], tagAndLength{
err = marshalTagAndLength(explicitTag, tagAndLength{
class: ClassContextSpecific, class: ClassContextSpecific,
tag: *params.tag, tag: *params.tag,
length: bodyLen + tags.Len(), length: bodyLen + t.tag.Len(),
isCompound: true, isCompound: true,
}) }))
return tt, nil
} }
return err if params.tag != nil {
// implicit tag.
tag = *params.tag
class = ClassContextSpecific
}
t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{class, tag, bodyLen, isCompound}))
return t, nil
} }
// Marshal returns the ASN.1 encoding of val. // Marshal returns the ASN.1 encoding of val.
...@@ -643,13 +638,11 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) ...@@ -643,13 +638,11 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
// printable: causes strings to be marshaled as ASN.1, PrintableString strings. // printable: causes strings to be marshaled as ASN.1, PrintableString strings.
// utf8: causes strings to be marshaled as ASN.1, UTF8 strings // utf8: causes strings to be marshaled as ASN.1, UTF8 strings
func Marshal(val interface{}) ([]byte, error) { func Marshal(val interface{}) ([]byte, error) {
var out bytes.Buffer e, err := makeField(reflect.ValueOf(val), fieldParameters{})
v := reflect.ValueOf(val)
f := newForkableWriter()
err := marshalField(f, v, fieldParameters{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, err = f.writeTo(&out) b := make([]byte, e.Len())
return out.Bytes(), err e.Encode(b)
return b, nil
} }
...@@ -173,3 +173,13 @@ func TestInvalidUTF8(t *testing.T) { ...@@ -173,3 +173,13 @@ func TestInvalidUTF8(t *testing.T) {
t.Errorf("invalid UTF8 string was accepted") t.Errorf("invalid UTF8 string was accepted")
} }
} }
func BenchmarkMarshal(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
for _, test := range marshalTests {
Marshal(test.in)
}
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment