Commit 0442087f authored by Gustavo Niemeyer

encoding/xml: bring API closer to other packages

Includes gofix module. The only case not covered should be
xml.Unmarshal, since it remains with a similar interface, and
would require introspecting the type of its first argument
better.

Fixes #2626.

R=golang-dev, rsc, gustavo
CC=golang-dev
https://golang.org/cl/5574053
parent 6d7e9382
......@@ -587,7 +587,7 @@ func commitPoll(key, pkg string) {
var logStruct struct {
Log []HgLog
}
err = xml.Unmarshal(strings.NewReader("<Top>"+data+"</Top>"), &logStruct)
err = xml.Unmarshal([]byte("<Top>"+data+"</Top>"), &logStruct)
if err != nil {
log.Printf("unmarshal hg log: %v", err)
return
......
......@@ -115,9 +115,9 @@ func loadCodewalk(filename string) (*Codewalk, error) {
}
defer f.Close()
cw := new(Codewalk)
p := xml.NewParser(f)
p.Entity = xml.HTMLEntity
err = p.Unmarshal(cw, nil)
d := xml.NewDecoder(f)
d.Entity = xml.HTMLEntity
err = d.Decode(cw)
if err != nil {
return nil, &os.PathError{"parsing", filename, err}
}
......
......@@ -42,6 +42,7 @@ GOFILES=\
timefileinfo.go\
typecheck.go\
url.go\
xmlapi.go\
include ../../Make.cmd
......
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"go/ast"
)
// init hands xmlapiFix to register when the package is initialized.
func init() {
register(xmlapiFix)
}
// xmlapiFix is the fix value that init passes to register.
// Its positional fields are, in order: the string "xmlapi", the date
// string "2012-01-23", the xmlapi rewrite function, and a free-form
// description. The field names are declared with the fix type, which
// is not part of this file.
var xmlapiFix = fix{
"xmlapi",
"2012-01-23",
xmlapi,
`
Make encoding/xml's API look more like the rest of the encoding packages.
http://codereview.appspot.com/5574053
`,
}
// xmlapiTypeConfig is the configuration xmlapi passes to typecheck.
// It maps the function xml.NewParser to the type string "xml.Parser",
// the same string xmlapi compares against when it looks for Unmarshal
// method selectors to rename.
var xmlapiTypeConfig = &TypeConfig{
Func: map[string]string{
"xml.NewParser": "xml.Parser",
},
}
// xmlapi rewrites uses of the old encoding/xml API found in f and
// reports whether it changed anything. Files that do not import
// encoding/xml are left alone.
//
// Rewrites performed:
//
//	<xml.Parser value>.Unmarshal   -> .DecodeElement
//	xml.Parser                     -> xml.Decoder
//	xml.Marshal(a, b)              -> xml.NewEncoder(a).Encode(b)
//	xml.NewParser(r)               -> xml.NewDecoder(r)
func xmlapi(f *ast.File) bool {
	if !imports(f, "encoding/xml") {
		return false
	}

	typeof, _ := typecheck(xmlapiTypeConfig, f)

	changed := false
	walk(f, func(node interface{}) {
		switch n := node.(type) {
		case *ast.SelectorExpr:
			switch {
			case typeof[n.X] == "xml.Parser" && n.Sel.Name == "Unmarshal":
				// Method call on a value the type checker knows is a Parser.
				n.Sel.Name = "DecodeElement"
				changed = true
			case isPkgDot(n, "xml", "Parser"):
				// Reference to the type itself.
				n.Sel.Name = "Decoder"
				changed = true
			}

		case *ast.CallExpr:
			// Can't fix two-argument xml.Unmarshal calls without further
			// diving into the type of n.Args[0]; a rewrite through
			// xmlUnmarshal(n.Args) is intentionally not performed.
			if len(n.Args) == 2 && isPkgDot(n.Fun, "xml", "Marshal") {
				*n = xmlMarshal(n.Args)
				changed = true
			} else if len(n.Args) == 1 && isPkgDot(n.Fun, "xml", "NewParser") {
				name := n.Fun.(*ast.SelectorExpr).Sel
				name.Name = "NewDecoder"
				changed = true
			}
		}
	})
	return changed
}
// xmlMarshal builds the call expression
// xml.NewEncoder(args[0]).Encode(args[1]) that xmlapi substitutes for
// an old two-argument xml.Marshal call.
func xmlMarshal(args []ast.Expr) ast.CallExpr {
	const constructor, method = "NewEncoder", "Encode"
	return xmlCallChain(constructor, method, args)
}
// xmlUnmarshal builds the call expression
// xml.NewDecoder(args[0]).Decode(args[1]). Within this file it is only
// referenced from the commented-out xml.Unmarshal case in xmlapi.
func xmlUnmarshal(args []ast.Expr) ast.CallExpr {
	const constructor, method = "NewDecoder", "Decode"
	return xmlCallChain(constructor, method, args)
}
// xmlCallChain returns the call expression
//
//	xml.<first>(args[0]).<second>(args[1])
//
// The inner call receives args[:1] and the outer call args[1:2], so
// args must provide two expressions; the callers in this file pass
// the two arguments of the call being rewritten.
func xmlCallChain(first, second string, args []ast.Expr) ast.CallExpr {
	// xml.<first>(args[0])
	constructor := &ast.CallExpr{
		Fun: &ast.SelectorExpr{
			X:   ast.NewIdent("xml"),
			Sel: ast.NewIdent(first),
		},
		Args: args[:1],
	}

	// (...).<second>
	method := &ast.SelectorExpr{
		X:   constructor,
		Sel: ast.NewIdent(second),
	}

	return ast.CallExpr{Fun: method, Args: args[1:2]}
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
// init passes xmlapiTests, together with the xmlapi rewrite function,
// to addTestCases when the package is initialized.
func init() {
addTestCases(xmlapiTests, xmlapi)
}
// xmlapiTests lists input/expected-output source pairs for the xmlapi fix.
// Case xmlapi.0 exercises xml.Marshal, xml.NewParser, a declared
// xml.Parser variable and the Parser Unmarshal method; the two-argument
// xml.Unmarshal call is expected to come out unchanged.
var xmlapiTests = []testCase{
{
Name: "xmlapi.0",
In: `package main
import "encoding/xml"
func f() {
xml.Marshal(a, b)
xml.Unmarshal(a, b)
p1 := xml.NewParser(stream)
p1.Unmarshal(v, start)
var p2 xml.Parser
p2.Unmarshal(v, start)
}
`,
Out: `package main
import "encoding/xml"
func f() {
xml.NewEncoder(a).Encode(b)
xml.Unmarshal(a, b)
p1 := xml.NewDecoder(stream)
p1.DecodeElement(v, start)
var p2 xml.Decoder
p2.DecodeElement(v, start)
}
`,
},
}
......@@ -26,11 +26,7 @@ type Marshaler interface {
MarshalXML() ([]byte, error)
}
type printer struct {
*bufio.Writer
}
// Marshal writes an XML-formatted representation of v to w.
// Marshal returns the XML encoding of v.
//
// If v implements Marshaler, then Marshal calls its MarshalXML method.
// Otherwise, Marshal uses the following procedure to create the XML.
......@@ -76,7 +72,7 @@ type printer struct {
// Age int `xml:"person>age"`
// }
//
// xml.Marshal(w, &Result{Id: 13, FirstName: "John", LastName: "Doe", Age: 42})
// xml.Marshal(&Result{Id: 13, FirstName: "John", LastName: "Doe", Age: 42})
//
// would be marshalled as:
//
......@@ -91,13 +87,38 @@ type printer struct {
// </result>
//
// Marshal will return an error if asked to marshal a channel, function, or map.
func Marshal(w io.Writer, v interface{}) (err error) {
p := &printer{bufio.NewWriter(w)}
err = p.marshalValue(reflect.ValueOf(v), nil)
p.Flush()
func Marshal(v interface{}) ([]byte, error) {
	// Encode into an in-memory buffer and hand back its contents.
	buf := new(bytes.Buffer)
	err := NewEncoder(buf).Encode(v)
	if err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
// An Encoder writes XML data to an output stream.
// It embeds printer, which in turn embeds the *bufio.Writer that
// NewEncoder wraps around the destination.
type Encoder struct {
printer
}
// NewEncoder returns a new encoder that writes to w.
// Output goes through a bufio.Writer wrapped around w.
func NewEncoder(w io.Writer) *Encoder {
	enc := new(Encoder)
	enc.printer = printer{bufio.NewWriter(w)}
	return enc
}
// Encode writes the XML encoding of v to the stream.
//
// See the documentation for Marshal for details about the conversion
// of Go values to XML.
//
// The buffered output is flushed before Encode returns, whether or not
// marshalling succeeded. Encode returns the marshalling error if there
// is one; otherwise it returns the error from flushing, so a failed
// write to the underlying stream is reported rather than dropped.
func (enc *Encoder) Encode(v interface{}) error {
	err := enc.marshalValue(reflect.ValueOf(v), nil)
	// Flush unconditionally, as before, but surface its error when
	// marshalling itself did not fail.
	if flushErr := enc.Flush(); err == nil {
		err = flushErr
	}
	return err
}
// printer embeds a *bufio.Writer and is the receiver of marshalValue.
// Encoder embeds it in turn.
type printer struct {
*bufio.Writer
}
func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error {
if !val.IsValid() {
return nil
......
......@@ -5,7 +5,6 @@
package xml
import (
"bytes"
"reflect"
"strconv"
"strings"
......@@ -619,13 +618,12 @@ func TestMarshal(t *testing.T) {
if test.UnmarshalOnly {
continue
}
buf := bytes.NewBuffer(nil)
err := Marshal(buf, test.Value)
data, err := Marshal(test.Value)
if err != nil {
t.Errorf("#%d: Error: %s", idx, err)
continue
}
if got, want := buf.String(), test.ExpectXML; got != want {
if got, want := string(data), test.ExpectXML; got != want {
if strings.Contains(want, "\n") {
t.Errorf("#%d: marshal(%#v):\nHAVE:\n%s\nWANT:\n%s", idx, test.Value, got, want)
} else {
......@@ -666,8 +664,7 @@ var marshalErrorTests = []struct {
func TestMarshalErrors(t *testing.T) {
for idx, test := range marshalErrorTests {
buf := bytes.NewBuffer(nil)
err := Marshal(buf, test.Value)
_, err := Marshal(test.Value)
if err == nil || err.Error() != test.Err {
t.Errorf("#%d: marshal(%#v) = [error] %v, want %v", idx, test.Value, err, test.Err)
}
......@@ -691,8 +688,7 @@ func TestUnmarshal(t *testing.T) {
vt := reflect.TypeOf(test.Value)
dest := reflect.New(vt.Elem()).Interface()
buffer := bytes.NewBufferString(test.ExpectXML)
err := Unmarshal(buffer, dest)
err := Unmarshal([]byte(test.ExpectXML), dest)
switch fix := dest.(type) {
case *Feed:
......@@ -711,17 +707,14 @@ func TestUnmarshal(t *testing.T) {
}
func BenchmarkMarshal(b *testing.B) {
buf := bytes.NewBuffer(nil)
for i := 0; i < b.N; i++ {
Marshal(buf, atomValue)
buf.Truncate(0)
Marshal(atomValue)
}
}
func BenchmarkUnmarshal(b *testing.B) {
xml := []byte(atomXml)
for i := 0; i < b.N; i++ {
buffer := bytes.NewBuffer(xml)
Unmarshal(buffer, &Feed{})
Unmarshal(xml, &Feed{})
}
}
......@@ -7,7 +7,6 @@ package xml
import (
"bytes"
"errors"
"io"
"reflect"
"strconv"
"strings"
......@@ -20,10 +19,10 @@ import (
// See package json for a textual representation more suitable
// to data structures.
// Unmarshal parses an XML element from r and uses the
// reflect library to fill in an arbitrary struct, slice, or string
// pointed at by val. Well-formed data that does not fit
// into val is discarded.
// Unmarshal parses the XML-encoded data and stores the result in
// the value pointed to by v, which must be an arbitrary struct,
// slice, or string. Well-formed data that does not fit into v is
// discarded.
//
// For example, given these definitions:
//
......@@ -59,7 +58,7 @@ import (
// <address>123 Main Street</address>
// </result>
//
// via Unmarshal(r, &result) is equivalent to assigning
// via Unmarshal(data, &result) is equivalent to assigning
//
// result = Result{
// xml.Name{Local: "result"},
......@@ -157,18 +156,26 @@ import (
// Unmarshal maps an XML element to a pointer by setting the pointer
// to a freshly allocated value and then mapping the element to that value.
//
func Unmarshal(r io.Reader, val interface{}) error {
v := reflect.ValueOf(val)
if v.Kind() != reflect.Ptr {
func Unmarshal(data []byte, v interface{}) error {
	// Wrap the bytes in a reader and decode the first element from it.
	d := NewDecoder(bytes.NewBuffer(data))
	return d.Decode(v)
}
// Decode works like xml.Unmarshal, except it reads the decoder
// stream to find the start element.
func (d *Decoder) Decode(v interface{}) error {
	// A nil start element makes the decoder look for one in the stream.
	var start *StartElement
	return d.DecodeElement(v, start)
}
// DecodeElement works like xml.Unmarshal except that it takes
// a pointer to the start XML element to decode into v.
// It is useful when a client reads some raw XML tokens itself
// but also wants to defer to Unmarshal for some elements.
func (d *Decoder) DecodeElement(v interface{}, start *StartElement) error {
val := reflect.ValueOf(v)
if val.Kind() != reflect.Ptr {
return errors.New("non-pointer passed to Unmarshal")
}
p := NewParser(r)
elem := v.Elem()
err := p.unmarshal(elem, nil)
if err != nil {
return err
}
return nil
return d.unmarshal(val.Elem(), start)
}
// An UnmarshalError represents an error in the unmarshalling process.
......@@ -176,22 +183,8 @@ type UnmarshalError string
func (e UnmarshalError) Error() string { return string(e) }
// The Parser's Unmarshal method is like xml.Unmarshal
// except that it can be passed a pointer to the initial start element,
// useful when a client reads some raw XML tokens itself
// but also defers to Unmarshal for some elements.
// Passing a nil start element indicates that Unmarshal should
// read the token stream to find the start element.
func (p *Parser) Unmarshal(val interface{}, start *StartElement) error {
v := reflect.ValueOf(val)
if v.Kind() != reflect.Ptr {
return errors.New("non-pointer passed to Unmarshal")
}
return p.unmarshal(v.Elem(), start)
}
// Unmarshal a single XML element into val.
func (p *Parser) unmarshal(val reflect.Value, start *StartElement) error {
func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
// Find start element if we need it.
if start == nil {
for {
......@@ -484,9 +477,9 @@ func copyValue(dst reflect.Value, src []byte) (err error) {
// unmarshalPath walks down an XML structure looking for wanted
// paths, and calls unmarshal on them.
// The consumed result tells whether XML elements have been consumed
// from the Parser until start's matching end element, or if it's
// from the Decoder until start's matching end element, or if it's
// still untouched because start is uninteresting for sv's fields.
func (p *Parser) unmarshalPath(tinfo *typeInfo, sv reflect.Value, parents []string, start *StartElement) (consumed bool, err error) {
func (p *Decoder) unmarshalPath(tinfo *typeInfo, sv reflect.Value, parents []string, start *StartElement) (consumed bool, err error) {
recurse := false
Loop:
for i := range tinfo.fields {
......@@ -550,7 +543,7 @@ Loop:
// Read tokens until we find the end element.
// Token is taking care of making sure the
// end element matches the start element we saw.
func (p *Parser) Skip() error {
func (p *Decoder) Skip() error {
for {
tok, err := p.Token()
if err != nil {
......
......@@ -6,7 +6,6 @@ package xml
import (
"reflect"
"strings"
"testing"
)
......@@ -14,7 +13,7 @@ import (
func TestUnmarshalFeed(t *testing.T) {
var f Feed
if err := Unmarshal(strings.NewReader(atomFeedString), &f); err != nil {
if err := Unmarshal([]byte(atomFeedString), &f); err != nil {
t.Fatalf("Unmarshal: %s", err)
}
if !reflect.DeepEqual(f, atomFeed) {
......@@ -281,7 +280,7 @@ var pathTests = []interface{}{
func TestUnmarshalPaths(t *testing.T) {
for _, pt := range pathTests {
v := reflect.New(reflect.TypeOf(pt).Elem()).Interface()
if err := Unmarshal(strings.NewReader(pathTestString), v); err != nil {
if err := Unmarshal([]byte(pathTestString), v); err != nil {
t.Fatalf("Unmarshal: %s", err)
}
if !reflect.DeepEqual(v, pt) {
......@@ -331,7 +330,7 @@ var badPathTests = []struct {
func TestUnmarshalBadPaths(t *testing.T) {
for _, tt := range badPathTests {
err := Unmarshal(strings.NewReader(pathTestString), tt.v)
err := Unmarshal([]byte(pathTestString), tt.v)
if !reflect.DeepEqual(err, tt.e) {
t.Fatalf("Unmarshal with %#v didn't fail properly:\nhave %#v,\nwant %#v", tt.v, err, tt.e)
}
......@@ -350,7 +349,7 @@ type TestThree struct {
func TestUnmarshalWithoutNameType(t *testing.T) {
var x TestThree
if err := Unmarshal(strings.NewReader(withoutNameTypeData), &x); err != nil {
if err := Unmarshal([]byte(withoutNameTypeData), &x); err != nil {
t.Fatalf("Unmarshal: %s", err)
}
if x.Attr != OK {
......
This diff is collapsed.
......@@ -5,7 +5,6 @@
package xml
import (
"bytes"
"io"
"reflect"
"strings"
......@@ -155,8 +154,8 @@ var xmlInput = []string{
}
func TestRawToken(t *testing.T) {
p := NewParser(strings.NewReader(testInput))
testRawToken(t, p, rawTokens)
d := NewDecoder(strings.NewReader(testInput))
testRawToken(t, d, rawTokens)
}
type downCaser struct {
......@@ -179,27 +178,27 @@ func (d *downCaser) Read(p []byte) (int, error) {
func TestRawTokenAltEncoding(t *testing.T) {
sawEncoding := ""
p := NewParser(strings.NewReader(testInputAltEncoding))
p.CharsetReader = func(charset string, input io.Reader) (io.Reader, error) {
d := NewDecoder(strings.NewReader(testInputAltEncoding))
d.CharsetReader = func(charset string, input io.Reader) (io.Reader, error) {
sawEncoding = charset
if charset != "x-testing-uppercase" {
t.Fatalf("unexpected charset %q", charset)
}
return &downCaser{t, input.(io.ByteReader)}, nil
}
testRawToken(t, p, rawTokensAltEncoding)
testRawToken(t, d, rawTokensAltEncoding)
}
func TestRawTokenAltEncodingNoConverter(t *testing.T) {
p := NewParser(strings.NewReader(testInputAltEncoding))
token, err := p.RawToken()
d := NewDecoder(strings.NewReader(testInputAltEncoding))
token, err := d.RawToken()
if token == nil {
t.Fatalf("expected a token on first RawToken call")
}
if err != nil {
t.Fatal(err)
}
token, err = p.RawToken()
token, err = d.RawToken()
if token != nil {
t.Errorf("expected a nil token; got %#v", token)
}
......@@ -213,9 +212,9 @@ func TestRawTokenAltEncodingNoConverter(t *testing.T) {
}
}
func testRawToken(t *testing.T, p *Parser, rawTokens []Token) {
func testRawToken(t *testing.T, d *Decoder, rawTokens []Token) {
for i, want := range rawTokens {
have, err := p.RawToken()
have, err := d.RawToken()
if err != nil {
t.Fatalf("token %d: unexpected error: %s", i, err)
}
......@@ -258,10 +257,10 @@ var nestedDirectivesTokens = []Token{
}
func TestNestedDirectives(t *testing.T) {
p := NewParser(strings.NewReader(nestedDirectivesInput))
d := NewDecoder(strings.NewReader(nestedDirectivesInput))
for i, want := range nestedDirectivesTokens {
have, err := p.Token()
have, err := d.Token()
if err != nil {
t.Fatalf("token %d: unexpected error: %s", i, err)
}
......@@ -272,10 +271,10 @@ func TestNestedDirectives(t *testing.T) {
}
func TestToken(t *testing.T) {
p := NewParser(strings.NewReader(testInput))
d := NewDecoder(strings.NewReader(testInput))
for i, want := range cookedTokens {
have, err := p.Token()
have, err := d.Token()
if err != nil {
t.Fatalf("token %d: unexpected error: %s", i, err)
}
......@@ -287,9 +286,9 @@ func TestToken(t *testing.T) {
func TestSyntax(t *testing.T) {
for i := range xmlInput {
p := NewParser(strings.NewReader(xmlInput[i]))
d := NewDecoder(strings.NewReader(xmlInput[i]))
var err error
for _, err = p.Token(); err == nil; _, err = p.Token() {
for _, err = d.Token(); err == nil; _, err = d.Token() {
}
if _, ok := err.(*SyntaxError); !ok {
t.Fatalf(`xmlInput "%s": expected SyntaxError not received`, xmlInput[i])
......@@ -368,8 +367,7 @@ const testScalarsInput = `<allscalars>
func TestAllScalars(t *testing.T) {
var a allScalars
buf := bytes.NewBufferString(testScalarsInput)
err := Unmarshal(buf, &a)
err := Unmarshal([]byte(testScalarsInput), &a)
if err != nil {
t.Fatal(err)
......@@ -386,8 +384,7 @@ type item struct {
func TestIssue569(t *testing.T) {
data := `<item><Field_a>abcd</Field_a></item>`
var i item
buf := bytes.NewBufferString(data)
err := Unmarshal(buf, &i)
err := Unmarshal([]byte(data), &i)
if err != nil || i.Field_a != "abcd" {
t.Fatal("Expecting abcd")
......@@ -396,9 +393,9 @@ func TestIssue569(t *testing.T) {
func TestUnquotedAttrs(t *testing.T) {
data := "<tag attr=azAZ09:-_\t>"
p := NewParser(strings.NewReader(data))
p.Strict = false
token, err := p.Token()
d := NewDecoder(strings.NewReader(data))
d.Strict = false
token, err := d.Token()
if _, ok := err.(*SyntaxError); ok {
t.Errorf("Unexpected error: %v", err)
}
......@@ -422,9 +419,9 @@ func TestValuelessAttrs(t *testing.T) {
{"<input checked />", "input", "checked"},
}
for _, test := range tests {
p := NewParser(strings.NewReader(test[0]))
p.Strict = false
token, err := p.Token()
d := NewDecoder(strings.NewReader(test[0]))
d.Strict = false
token, err := d.Token()
if _, ok := err.(*SyntaxError); ok {
t.Errorf("Unexpected error: %v", err)
}
......@@ -472,9 +469,9 @@ func TestCopyTokenStartElement(t *testing.T) {
func TestSyntaxErrorLineNum(t *testing.T) {
testInput := "<P>Foo<P>\n\n<P>Bar</>\n"
p := NewParser(strings.NewReader(testInput))
d := NewDecoder(strings.NewReader(testInput))
var err error
for _, err = p.Token(); err == nil; _, err = p.Token() {
for _, err = d.Token(); err == nil; _, err = d.Token() {
}
synerr, ok := err.(*SyntaxError)
if !ok {
......@@ -487,41 +484,41 @@ func TestSyntaxErrorLineNum(t *testing.T) {
func TestTrailingRawToken(t *testing.T) {
input := `<FOO></FOO> `
p := NewParser(strings.NewReader(input))
d := NewDecoder(strings.NewReader(input))
var err error
for _, err = p.RawToken(); err == nil; _, err = p.RawToken() {
for _, err = d.RawToken(); err == nil; _, err = d.RawToken() {
}
if err != io.EOF {
t.Fatalf("p.RawToken() = _, %v, want _, io.EOF", err)
t.Fatalf("d.RawToken() = _, %v, want _, io.EOF", err)
}
}
func TestTrailingToken(t *testing.T) {
input := `<FOO></FOO> `
p := NewParser(strings.NewReader(input))
d := NewDecoder(strings.NewReader(input))
var err error
for _, err = p.Token(); err == nil; _, err = p.Token() {
for _, err = d.Token(); err == nil; _, err = d.Token() {
}
if err != io.EOF {
t.Fatalf("p.Token() = _, %v, want _, io.EOF", err)
t.Fatalf("d.Token() = _, %v, want _, io.EOF", err)
}
}
func TestEntityInsideCDATA(t *testing.T) {
input := `<test><![CDATA[ &val=foo ]]></test>`
p := NewParser(strings.NewReader(input))
d := NewDecoder(strings.NewReader(input))
var err error
for _, err = p.Token(); err == nil; _, err = p.Token() {
for _, err = d.Token(); err == nil; _, err = d.Token() {
}
if err != io.EOF {
t.Fatalf("p.Token() = _, %v, want _, io.EOF", err)
t.Fatalf("d.Token() = _, %v, want _, io.EOF", err)
}
}
// The last three tests (respectively one for characters in attribute
// names and two for character entities) pass not because of code
// changed for issue 1259, but instead pass with the given messages
// from other parts of xml.Parser. I provide these to note the
// from other parts of xml.Decoder. I provide these to note the
// current behavior of situations where one might think that character
// range checking would detect the error, but it does not in fact.
......@@ -541,15 +538,15 @@ var characterTests = []struct {
func TestDisallowedCharacters(t *testing.T) {
for i, tt := range characterTests {
p := NewParser(strings.NewReader(tt.in))
d := NewDecoder(strings.NewReader(tt.in))
var err error
for err == nil {
_, err = p.Token()
_, err = d.Token()
}
synerr, ok := err.(*SyntaxError)
if !ok {
t.Fatalf("input %d p.Token() = _, %v, want _, *SyntaxError", i, err)
t.Fatalf("input %d d.Token() = _, %v, want _, *SyntaxError", i, err)
}
if synerr.Msg != tt.err {
t.Fatalf("input %d synerr.Msg wrong: want '%s', got '%s'", i, tt.err, synerr.Msg)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment