Commit 904e1136 authored by Russ Cox's avatar Russ Cox

encoding/xml: support generic encoding interfaces

Remove custom support for time.Time.
No new tests: the tests for the time.Time special case
now test the general case.

R=golang-dev, bradfitz
CC=golang-dev
https://golang.org/cl/12751045
parent 3ec0427a
...@@ -7,12 +7,12 @@ package xml ...@@ -7,12 +7,12 @@ package xml
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"encoding"
"fmt" "fmt"
"io" "io"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"time"
) )
const ( const (
...@@ -319,6 +319,7 @@ func (p *printer) popPrefix() { ...@@ -319,6 +319,7 @@ func (p *printer) popPrefix() {
var ( var (
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem() marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
marshalerAttrType = reflect.TypeOf((*MarshalerAttr)(nil)).Elem() marshalerAttrType = reflect.TypeOf((*MarshalerAttr)(nil)).Elem()
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
) )
// marshalValue writes one or more XML elements representing val. // marshalValue writes one or more XML elements representing val.
...@@ -348,14 +349,25 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat ...@@ -348,14 +349,25 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
} }
// Check for marshaler. // Check for marshaler.
if typ.Name() != "" && val.CanAddr() { if val.CanInterface() && typ.Implements(marshalerType) {
return p.marshalInterface(val.Interface().(Marshaler), defaultStart(typ, finfo, startTemplate))
}
if val.CanAddr() {
pv := val.Addr() pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(marshalerType) { if pv.CanInterface() && pv.Type().Implements(marshalerType) {
return p.marshalInterface(pv.Interface().(Marshaler), pv.Type(), finfo, startTemplate) return p.marshalInterface(pv.Interface().(Marshaler), defaultStart(pv.Type(), finfo, startTemplate))
} }
} }
if val.CanInterface() && typ.Implements(marshalerType) {
return p.marshalInterface(val.Interface().(Marshaler), typ, finfo, startTemplate) // Check for text marshaler.
if val.CanInterface() && typ.Implements(textMarshalerType) {
return p.marshalTextInterface(val.Interface().(encoding.TextMarshaler), defaultStart(typ, finfo, startTemplate))
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler), defaultStart(pv.Type(), finfo, startTemplate))
}
} }
// Slices and arrays iterate over the elements. They do not have an enclosing tag. // Slices and arrays iterate over the elements. They do not have an enclosing tag.
...@@ -416,6 +428,21 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat ...@@ -416,6 +428,21 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
continue continue
} }
if fv.Kind() == reflect.Interface && fv.IsNil() {
continue
}
if fv.CanInterface() && fv.Type().Implements(marshalerAttrType) {
attr, err := fv.Interface().(MarshalerAttr).MarshalXMLAttr(name)
if err != nil {
return err
}
if attr.Name.Local != "" {
start.Attr = append(start.Attr, attr)
}
continue
}
if fv.CanAddr() { if fv.CanAddr() {
pv := fv.Addr() pv := fv.Addr()
if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) { if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) {
...@@ -430,20 +457,27 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat ...@@ -430,20 +457,27 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
} }
} }
if fv.CanInterface() && fv.Type().Implements(marshalerAttrType) { if fv.CanInterface() && fv.Type().Implements(textMarshalerType) {
if fv.Kind() == reflect.Interface && fv.IsNil() { text, err := fv.Interface().(encoding.TextMarshaler).MarshalText()
continue
}
attr, err := fv.Interface().(MarshalerAttr).MarshalXMLAttr(name)
if err != nil { if err != nil {
return err return err
} }
if attr.Name.Local != "" { start.Attr = append(start.Attr, Attr{name, string(text)})
start.Attr = append(start.Attr, attr)
}
continue continue
} }
if fv.CanAddr() {
pv := fv.Addr()
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
text, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return err
}
start.Attr = append(start.Attr, Attr{name, string(text)})
continue
}
}
// Dereference or skip nil pointer, interface values. // Dereference or skip nil pointer, interface values.
switch fv.Kind() { switch fv.Kind() {
case reflect.Ptr, reflect.Interface: case reflect.Ptr, reflect.Interface:
...@@ -490,10 +524,10 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat ...@@ -490,10 +524,10 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
return p.cachedWriteError() return p.cachedWriteError()
} }
// marshalInterface marshals a Marshaler interface value. // defaultStart returns the default start element to use,
func (p *printer) marshalInterface(val Marshaler, typ reflect.Type, finfo *fieldInfo, startTemplate *StartElement) error { // given the reflect type, field info, and start template.
func defaultStart(typ reflect.Type, finfo *fieldInfo, startTemplate *StartElement) StartElement {
var start StartElement var start StartElement
// Precedence for the XML element name is as above, // Precedence for the XML element name is as above,
// except that we do not look inside structs for the first field. // except that we do not look inside structs for the first field.
if startTemplate != nil { if startTemplate != nil {
...@@ -509,7 +543,11 @@ func (p *printer) marshalInterface(val Marshaler, typ reflect.Type, finfo *field ...@@ -509,7 +543,11 @@ func (p *printer) marshalInterface(val Marshaler, typ reflect.Type, finfo *field
// since it has the Marshaler methods. // since it has the Marshaler methods.
start.Name.Local = typ.Elem().Name() start.Name.Local = typ.Elem().Name()
} }
return start
}
// marshalInterface marshals a Marshaler interface value.
func (p *printer) marshalInterface(val Marshaler, start StartElement) error {
// Push a marker onto the tag stack so that MarshalXML // Push a marker onto the tag stack so that MarshalXML
// cannot close the XML tags that it did not open. // cannot close the XML tags that it did not open.
p.tags = append(p.tags, Name{}) p.tags = append(p.tags, Name{})
...@@ -528,6 +566,19 @@ func (p *printer) marshalInterface(val Marshaler, typ reflect.Type, finfo *field ...@@ -528,6 +566,19 @@ func (p *printer) marshalInterface(val Marshaler, typ reflect.Type, finfo *field
return nil return nil
} }
// marshalTextInterface marshals a TextMarshaler interface value.
func (p *printer) marshalTextInterface(val encoding.TextMarshaler, start StartElement) error {
if err := p.writeStart(&start); err != nil {
return err
}
text, err := val.MarshalText()
if err != nil {
return err
}
EscapeText(p, text)
return p.writeEnd(start.Name)
}
// writeStart writes the given start element. // writeStart writes the given start element.
func (p *printer) writeStart(start *StartElement) error { func (p *printer) writeStart(start *StartElement) error {
if start.Name.Local == "" { if start.Name.Local == "" {
...@@ -591,13 +642,7 @@ func (p *printer) writeEnd(name Name) error { ...@@ -591,13 +642,7 @@ func (p *printer) writeEnd(name Name) error {
return nil return nil
} }
var timeType = reflect.TypeOf(time.Time{})
func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) (string, []byte, error) { func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) (string, []byte, error) {
// Normally we don't see structs, but this can happen for an attribute.
if val.Type() == timeType {
return val.Interface().(time.Time).Format(time.RFC3339Nano), nil, nil
}
switch val.Kind() { switch val.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(val.Int(), 10), nil, nil return strconv.FormatInt(val.Int(), 10), nil, nil
...@@ -629,10 +674,6 @@ func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) (string, [] ...@@ -629,10 +674,6 @@ func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) (string, []
var ddBytes = []byte("--") var ddBytes = []byte("--")
func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
if val.Type() == timeType {
_, err := p.WriteString(val.Interface().(time.Time).Format(time.RFC3339Nano))
return err
}
s := parentStack{p: p} s := parentStack{p: p}
for i := range tinfo.fields { for i := range tinfo.fields {
finfo := &tinfo.fields[i] finfo := &tinfo.fields[i]
...@@ -651,6 +692,25 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { ...@@ -651,6 +692,25 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
switch finfo.flags & fMode { switch finfo.flags & fMode {
case fCharData: case fCharData:
if vf.CanInterface() && vf.Type().Implements(textMarshalerType) {
data, err := vf.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return err
}
Escape(p, data)
continue
}
if vf.CanAddr() {
pv := vf.Addr()
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
data, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return err
}
Escape(p, data)
continue
}
}
var scratch [64]byte var scratch [64]byte
switch vf.Kind() { switch vf.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
...@@ -671,10 +731,6 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { ...@@ -671,10 +731,6 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
return err return err
} }
} }
case reflect.Struct:
if vf.Type() == timeType {
Escape(p, []byte(vf.Interface().(time.Time).Format(time.RFC3339Nano)))
}
} }
continue continue
......
...@@ -6,12 +6,12 @@ package xml ...@@ -6,12 +6,12 @@ package xml
import ( import (
"bytes" "bytes"
"encoding"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"time"
) )
// BUG(rsc): Mapping between XML elements and data structures is inherently flawed: // BUG(rsc): Mapping between XML elements and data structures is inherently flawed:
...@@ -178,8 +178,7 @@ func receiverType(val interface{}) string { ...@@ -178,8 +178,7 @@ func receiverType(val interface{}) string {
return "(" + t.String() + ")" return "(" + t.String() + ")"
} }
// unmarshalInterface unmarshals a single XML element into val, // unmarshalInterface unmarshals a single XML element into val.
// which is known to implement Unmarshaler.
// start is the opening tag of the element. // start is the opening tag of the element.
func (p *Decoder) unmarshalInterface(val Unmarshaler, start *StartElement) error { func (p *Decoder) unmarshalInterface(val Unmarshaler, start *StartElement) error {
// Record that decoder must stop at end tag corresponding to start. // Record that decoder must stop at end tag corresponding to start.
...@@ -200,6 +199,31 @@ func (p *Decoder) unmarshalInterface(val Unmarshaler, start *StartElement) error ...@@ -200,6 +199,31 @@ func (p *Decoder) unmarshalInterface(val Unmarshaler, start *StartElement) error
return nil return nil
} }
// unmarshalTextInterface unmarshals a single XML element into val.
// The chardata contained in the element (but not its children)
// is passed to the text unmarshaler.
func (p *Decoder) unmarshalTextInterface(val encoding.TextUnmarshaler, start *StartElement) error {
var buf []byte
depth := 1
for depth > 0 {
t, err := p.Token()
if err != nil {
return err
}
switch t := t.(type) {
case CharData:
if depth == 1 {
buf = append(buf, t...)
}
case StartElement:
depth++
case EndElement:
depth--
}
}
return val.UnmarshalText(buf)
}
// unmarshalAttr unmarshals a single XML attribute into val. // unmarshalAttr unmarshals a single XML attribute into val.
func (p *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error { func (p *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error {
if val.Kind() == reflect.Ptr { if val.Kind() == reflect.Ptr {
...@@ -221,7 +245,18 @@ func (p *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error { ...@@ -221,7 +245,18 @@ func (p *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error {
} }
} }
// TODO: Check for and use encoding.TextUnmarshaler. // Not an UnmarshalerAttr; try encoding.TextUnmarshaler.
if val.CanInterface() && val.Type().Implements(textUnmarshalerType) {
// This is an unmarshaler with a non-pointer receiver,
// so it's likely to be incorrect, but we do what we're told.
return val.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value))
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
return pv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value))
}
}
copyValue(val, []byte(attr.Value)) copyValue(val, []byte(attr.Value))
return nil return nil
...@@ -230,6 +265,7 @@ func (p *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error { ...@@ -230,6 +265,7 @@ func (p *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error {
var ( var (
unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem() unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
unmarshalerAttrType = reflect.TypeOf((*UnmarshalerAttr)(nil)).Elem() unmarshalerAttrType = reflect.TypeOf((*UnmarshalerAttr)(nil)).Elem()
textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
) )
// Unmarshal a single XML element into val. // Unmarshal a single XML element into val.
...@@ -268,7 +304,16 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error { ...@@ -268,7 +304,16 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
} }
} }
// TODO: Check for and use encoding.TextUnmarshaler. if val.CanInterface() && val.Type().Implements(textUnmarshalerType) {
return p.unmarshalTextInterface(val.Interface().(encoding.TextUnmarshaler), start)
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
return p.unmarshalTextInterface(pv.Interface().(encoding.TextUnmarshaler), start)
}
}
var ( var (
data []byte data []byte
...@@ -332,10 +377,6 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error { ...@@ -332,10 +377,6 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
v.Set(reflect.ValueOf(start.Name)) v.Set(reflect.ValueOf(start.Name))
break break
} }
if typ == timeType {
saveData = v
break
}
sv = v sv = v
tinfo, err = getTypeInfo(typ) tinfo, err = getTypeInfo(typ)
...@@ -464,6 +505,23 @@ Loop: ...@@ -464,6 +505,23 @@ Loop:
} }
} }
if saveData.IsValid() && saveData.CanInterface() && saveData.Type().Implements(textUnmarshalerType) {
if err := saveData.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
return err
}
saveData = reflect.Value{}
}
if saveData.IsValid() && saveData.CanAddr() {
pv := saveData.Addr()
if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
if err := pv.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
return err
}
saveData = reflect.Value{}
}
}
if err := copyValue(saveData, data); err != nil { if err := copyValue(saveData, data); err != nil {
return err return err
} }
...@@ -486,6 +544,8 @@ Loop: ...@@ -486,6 +544,8 @@ Loop:
} }
func copyValue(dst reflect.Value, src []byte) (err error) { func copyValue(dst reflect.Value, src []byte) (err error) {
dst0 := dst
if dst.Kind() == reflect.Ptr { if dst.Kind() == reflect.Ptr {
if dst.IsNil() { if dst.IsNil() {
dst.Set(reflect.New(dst.Type().Elem())) dst.Set(reflect.New(dst.Type().Elem()))
...@@ -496,9 +556,9 @@ func copyValue(dst reflect.Value, src []byte) (err error) { ...@@ -496,9 +556,9 @@ func copyValue(dst reflect.Value, src []byte) (err error) {
// Save accumulated data. // Save accumulated data.
switch dst.Kind() { switch dst.Kind() {
case reflect.Invalid: case reflect.Invalid:
// Probably a commendst. // Probably a comment.
default: default:
return errors.New("cannot happen: unknown type " + dst.Type().String()) return errors.New("cannot unmarshal into " + dst0.Type().String())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
itmp, err := strconv.ParseInt(string(src), 10, dst.Type().Bits()) itmp, err := strconv.ParseInt(string(src), 10, dst.Type().Bits())
if err != nil { if err != nil {
...@@ -531,14 +591,6 @@ func copyValue(dst reflect.Value, src []byte) (err error) { ...@@ -531,14 +591,6 @@ func copyValue(dst reflect.Value, src []byte) (err error) {
src = []byte{} src = []byte{}
} }
dst.SetBytes(src) dst.SetBytes(src)
case reflect.Struct:
if dst.Type() == timeType {
tv, err := time.Parse(time.RFC3339, string(src))
if err != nil {
return err
}
dst.Set(reflect.ValueOf(tv))
}
} }
return nil return nil
} }
......
...@@ -67,6 +67,11 @@ func (e StartElement) Copy() StartElement { ...@@ -67,6 +67,11 @@ func (e StartElement) Copy() StartElement {
return e return e
} }
// End returns the corresponding XML end element.
func (e StartElement) End() EndElement {
return EndElement{e.Name}
}
// An EndElement represents an XML end element. // An EndElement represents an XML end element.
type EndElement struct { type EndElement struct {
Name Name Name Name
......
...@@ -200,7 +200,7 @@ var pkgDeps = map[string][]string{ ...@@ -200,7 +200,7 @@ var pkgDeps = map[string][]string{
"encoding/hex": {"L4"}, "encoding/hex": {"L4"},
"encoding/json": {"L4", "encoding"}, "encoding/json": {"L4", "encoding"},
"encoding/pem": {"L4"}, "encoding/pem": {"L4"},
"encoding/xml": {"L4"}, "encoding/xml": {"L4", "encoding"},
"flag": {"L4", "OS"}, "flag": {"L4", "OS"},
"go/build": {"L4", "OS", "GOPARSER"}, "go/build": {"L4", "OS", "GOPARSER"},
"html": {"L4"}, "html": {"L4"},
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment