Commit 904e1136 authored by Russ Cox's avatar Russ Cox

encoding/xml: support generic encoding interfaces

Remove custom support for time.Time.
No new tests: the tests for the time.Time special case
now test the general case.

R=golang-dev, bradfitz
CC=golang-dev
https://golang.org/cl/12751045
parent 3ec0427a
......@@ -7,12 +7,12 @@ package xml
import (
"bufio"
"bytes"
"encoding"
"fmt"
"io"
"reflect"
"strconv"
"strings"
"time"
)
const (
......@@ -319,6 +319,7 @@ func (p *printer) popPrefix() {
var (
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
marshalerAttrType = reflect.TypeOf((*MarshalerAttr)(nil)).Elem()
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
)
// marshalValue writes one or more XML elements representing val.
......@@ -348,14 +349,25 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
}
// Check for marshaler.
if typ.Name() != "" && val.CanAddr() {
if val.CanInterface() && typ.Implements(marshalerType) {
return p.marshalInterface(val.Interface().(Marshaler), defaultStart(typ, finfo, startTemplate))
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(marshalerType) {
return p.marshalInterface(pv.Interface().(Marshaler), pv.Type(), finfo, startTemplate)
return p.marshalInterface(pv.Interface().(Marshaler), defaultStart(pv.Type(), finfo, startTemplate))
}
}
if val.CanInterface() && typ.Implements(marshalerType) {
return p.marshalInterface(val.Interface().(Marshaler), typ, finfo, startTemplate)
// Check for text marshaler.
if val.CanInterface() && typ.Implements(textMarshalerType) {
return p.marshalTextInterface(val.Interface().(encoding.TextMarshaler), defaultStart(typ, finfo, startTemplate))
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler), defaultStart(pv.Type(), finfo, startTemplate))
}
}
// Slices and arrays iterate over the elements. They do not have an enclosing tag.
......@@ -416,6 +428,21 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
continue
}
if fv.Kind() == reflect.Interface && fv.IsNil() {
continue
}
if fv.CanInterface() && fv.Type().Implements(marshalerAttrType) {
attr, err := fv.Interface().(MarshalerAttr).MarshalXMLAttr(name)
if err != nil {
return err
}
if attr.Name.Local != "" {
start.Attr = append(start.Attr, attr)
}
continue
}
if fv.CanAddr() {
pv := fv.Addr()
if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) {
......@@ -430,20 +457,27 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
}
}
if fv.CanInterface() && fv.Type().Implements(marshalerAttrType) {
if fv.Kind() == reflect.Interface && fv.IsNil() {
continue
}
attr, err := fv.Interface().(MarshalerAttr).MarshalXMLAttr(name)
if fv.CanInterface() && fv.Type().Implements(textMarshalerType) {
text, err := fv.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return err
}
if attr.Name.Local != "" {
start.Attr = append(start.Attr, attr)
}
start.Attr = append(start.Attr, Attr{name, string(text)})
continue
}
if fv.CanAddr() {
pv := fv.Addr()
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
text, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return err
}
start.Attr = append(start.Attr, Attr{name, string(text)})
continue
}
}
// Dereference or skip nil pointer, interface values.
switch fv.Kind() {
case reflect.Ptr, reflect.Interface:
......@@ -490,10 +524,10 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
return p.cachedWriteError()
}
// marshalInterface marshals a Marshaler interface value.
func (p *printer) marshalInterface(val Marshaler, typ reflect.Type, finfo *fieldInfo, startTemplate *StartElement) error {
// defaultStart returns the default start element to use,
// given the reflect type, field info, and start template.
func defaultStart(typ reflect.Type, finfo *fieldInfo, startTemplate *StartElement) StartElement {
var start StartElement
// Precedence for the XML element name is as above,
// except that we do not look inside structs for the first field.
if startTemplate != nil {
......@@ -509,7 +543,11 @@ func (p *printer) marshalInterface(val Marshaler, typ reflect.Type, finfo *field
// since it has the Marshaler methods.
start.Name.Local = typ.Elem().Name()
}
return start
}
// marshalInterface marshals a Marshaler interface value.
func (p *printer) marshalInterface(val Marshaler, start StartElement) error {
// Push a marker onto the tag stack so that MarshalXML
// cannot close the XML tags that it did not open.
p.tags = append(p.tags, Name{})
......@@ -528,6 +566,19 @@ func (p *printer) marshalInterface(val Marshaler, typ reflect.Type, finfo *field
return nil
}
// marshalTextInterface marshals a TextMarshaler interface value.
func (p *printer) marshalTextInterface(val encoding.TextMarshaler, start StartElement) error {
if err := p.writeStart(&start); err != nil {
return err
}
text, err := val.MarshalText()
if err != nil {
return err
}
EscapeText(p, text)
return p.writeEnd(start.Name)
}
// writeStart writes the given start element.
func (p *printer) writeStart(start *StartElement) error {
if start.Name.Local == "" {
......@@ -591,13 +642,7 @@ func (p *printer) writeEnd(name Name) error {
return nil
}
var timeType = reflect.TypeOf(time.Time{})
func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) (string, []byte, error) {
// Normally we don't see structs, but this can happen for an attribute.
if val.Type() == timeType {
return val.Interface().(time.Time).Format(time.RFC3339Nano), nil, nil
}
switch val.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(val.Int(), 10), nil, nil
......@@ -629,10 +674,6 @@ func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) (string, []
var ddBytes = []byte("--")
func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
if val.Type() == timeType {
_, err := p.WriteString(val.Interface().(time.Time).Format(time.RFC3339Nano))
return err
}
s := parentStack{p: p}
for i := range tinfo.fields {
finfo := &tinfo.fields[i]
......@@ -651,6 +692,25 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
switch finfo.flags & fMode {
case fCharData:
if vf.CanInterface() && vf.Type().Implements(textMarshalerType) {
data, err := vf.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return err
}
Escape(p, data)
continue
}
if vf.CanAddr() {
pv := vf.Addr()
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
data, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return err
}
Escape(p, data)
continue
}
}
var scratch [64]byte
switch vf.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
......@@ -671,10 +731,6 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
return err
}
}
case reflect.Struct:
if vf.Type() == timeType {
Escape(p, []byte(vf.Interface().(time.Time).Format(time.RFC3339Nano)))
}
}
continue
......
......@@ -6,12 +6,12 @@ package xml
import (
"bytes"
"encoding"
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"time"
)
// BUG(rsc): Mapping between XML elements and data structures is inherently flawed:
......@@ -178,8 +178,7 @@ func receiverType(val interface{}) string {
return "(" + t.String() + ")"
}
// unmarshalInterface unmarshals a single XML element into val,
// which is known to implement Unmarshaler.
// unmarshalInterface unmarshals a single XML element into val.
// start is the opening tag of the element.
func (p *Decoder) unmarshalInterface(val Unmarshaler, start *StartElement) error {
// Record that decoder must stop at end tag corresponding to start.
......@@ -200,6 +199,31 @@ func (p *Decoder) unmarshalInterface(val Unmarshaler, start *StartElement) error
return nil
}
// unmarshalTextInterface unmarshals a single XML element into val.
// The chardata contained in the element (but not its children)
// is passed to the text unmarshaler.
func (p *Decoder) unmarshalTextInterface(val encoding.TextUnmarshaler, start *StartElement) error {
var buf []byte
depth := 1
for depth > 0 {
t, err := p.Token()
if err != nil {
return err
}
switch t := t.(type) {
case CharData:
if depth == 1 {
buf = append(buf, t...)
}
case StartElement:
depth++
case EndElement:
depth--
}
}
return val.UnmarshalText(buf)
}
// unmarshalAttr unmarshals a single XML attribute into val.
func (p *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error {
if val.Kind() == reflect.Ptr {
......@@ -221,7 +245,18 @@ func (p *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error {
}
}
// TODO: Check for and use encoding.TextUnmarshaler.
// Not an UnmarshalerAttr; try encoding.TextUnmarshaler.
if val.CanInterface() && val.Type().Implements(textUnmarshalerType) {
// This is an unmarshaler with a non-pointer receiver,
// so it's likely to be incorrect, but we do what we're told.
return val.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value))
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
return pv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value))
}
}
copyValue(val, []byte(attr.Value))
return nil
......@@ -230,6 +265,7 @@ func (p *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error {
var (
unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
unmarshalerAttrType = reflect.TypeOf((*UnmarshalerAttr)(nil)).Elem()
textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
)
// Unmarshal a single XML element into val.
......@@ -268,7 +304,16 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
}
}
// TODO: Check for and use encoding.TextUnmarshaler.
if val.CanInterface() && val.Type().Implements(textUnmarshalerType) {
return p.unmarshalTextInterface(val.Interface().(encoding.TextUnmarshaler), start)
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
return p.unmarshalTextInterface(pv.Interface().(encoding.TextUnmarshaler), start)
}
}
var (
data []byte
......@@ -332,10 +377,6 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
v.Set(reflect.ValueOf(start.Name))
break
}
if typ == timeType {
saveData = v
break
}
sv = v
tinfo, err = getTypeInfo(typ)
......@@ -464,6 +505,23 @@ Loop:
}
}
if saveData.IsValid() && saveData.CanInterface() && saveData.Type().Implements(textUnmarshalerType) {
if err := saveData.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
return err
}
saveData = reflect.Value{}
}
if saveData.IsValid() && saveData.CanAddr() {
pv := saveData.Addr()
if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
if err := pv.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
return err
}
saveData = reflect.Value{}
}
}
if err := copyValue(saveData, data); err != nil {
return err
}
......@@ -486,6 +544,8 @@ Loop:
}
func copyValue(dst reflect.Value, src []byte) (err error) {
dst0 := dst
if dst.Kind() == reflect.Ptr {
if dst.IsNil() {
dst.Set(reflect.New(dst.Type().Elem()))
......@@ -496,9 +556,9 @@ func copyValue(dst reflect.Value, src []byte) (err error) {
// Save accumulated data.
switch dst.Kind() {
case reflect.Invalid:
// Probably a commendst.
// Probably a comment.
default:
return errors.New("cannot happen: unknown type " + dst.Type().String())
return errors.New("cannot unmarshal into " + dst0.Type().String())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
itmp, err := strconv.ParseInt(string(src), 10, dst.Type().Bits())
if err != nil {
......@@ -531,14 +591,6 @@ func copyValue(dst reflect.Value, src []byte) (err error) {
src = []byte{}
}
dst.SetBytes(src)
case reflect.Struct:
if dst.Type() == timeType {
tv, err := time.Parse(time.RFC3339, string(src))
if err != nil {
return err
}
dst.Set(reflect.ValueOf(tv))
}
}
return nil
}
......
......@@ -67,6 +67,11 @@ func (e StartElement) Copy() StartElement {
return e
}
// End returns the corresponding XML end element.
func (e StartElement) End() EndElement {
return EndElement{e.Name}
}
// An EndElement represents an XML end element.
type EndElement struct {
Name Name
......
......@@ -200,7 +200,7 @@ var pkgDeps = map[string][]string{
"encoding/hex": {"L4"},
"encoding/json": {"L4", "encoding"},
"encoding/pem": {"L4"},
"encoding/xml": {"L4"},
"encoding/xml": {"L4", "encoding"},
"flag": {"L4", "OS"},
"go/build": {"L4", "OS", "GOPARSER"},
"html": {"L4"},
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment