Commit ffbd31e9 authored by Augusto Roman, committed by Brad Fitzpatrick

encoding/json: allow non-string type keys for (un-)marshal

This CL allows JSON-encoding & -decoding maps whose keys are types that
implement encoding.TextMarshaler / TextUnmarshaler.

During encode, the map keys are marshaled upfront so that they can be
sorted.

Fixes #12146

Change-Id: I43809750a7ad82a3603662f095c7baf75fd172da
Reviewed-on: https://go-review.googlesource.com/20356
Run-TryBot: Caleb Spare <cespare@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
parent acefcb73
...@@ -61,10 +61,11 @@ import ( ...@@ -61,10 +61,11 @@ import (
// If the JSON array is smaller than the Go array, // If the JSON array is smaller than the Go array,
// the additional Go array elements are set to zero values. // the additional Go array elements are set to zero values.
// //
// To unmarshal a JSON object into a string-keyed map, Unmarshal first // To unmarshal a JSON object into a map, Unmarshal first establishes a map to
// establishes a map to use. If the map is nil, Unmarshal allocates a new map. // use. If the map is nil, Unmarshal allocates a new map. Otherwise Unmarshal
// Otherwise Unmarshal reuses the existing map, keeping existing entries. // reuses the existing map, keeping existing entries. Unmarshal then stores key-
// Unmarshal then stores key-value pairs from the JSON object into the map. // value pairs from the JSON object into the map. The map's key type must
// either be a string or implement encoding.TextUnmarshaler.
// //
// If a JSON value is not appropriate for a given target type, // If a JSON value is not appropriate for a given target type,
// or if a JSON number overflows the target type, Unmarshal // or if a JSON number overflows the target type, Unmarshal
...@@ -549,6 +550,7 @@ func (d *decodeState) array(v reflect.Value) { ...@@ -549,6 +550,7 @@ func (d *decodeState) array(v reflect.Value) {
} }
var nullLiteral = []byte("null") var nullLiteral = []byte("null")
var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem()
// object consumes an object from d.data[d.off-1:], decoding into the value v. // object consumes an object from d.data[d.off-1:], decoding into the value v.
// the first byte ('{') of the object has been read already. // the first byte ('{') of the object has been read already.
...@@ -577,12 +579,15 @@ func (d *decodeState) object(v reflect.Value) { ...@@ -577,12 +579,15 @@ func (d *decodeState) object(v reflect.Value) {
return return
} }
// Check type of target: struct or map[string]T // Check type of target:
// struct or
// map[string]T or map[encoding.TextUnmarshaler]T
switch v.Kind() { switch v.Kind() {
case reflect.Map: case reflect.Map:
// map must have string kind // Map key must either have string kind or be an encoding.TextUnmarshaler.
t := v.Type() t := v.Type()
if t.Key().Kind() != reflect.String { if t.Key().Kind() != reflect.String &&
!reflect.PtrTo(t.Key()).Implements(textUnmarshalerType) {
d.saveError(&UnmarshalTypeError{"object", v.Type(), int64(d.off)}) d.saveError(&UnmarshalTypeError{"object", v.Type(), int64(d.off)})
d.off-- d.off--
d.next() // skip over { } in input d.next() // skip over { } in input
...@@ -687,7 +692,18 @@ func (d *decodeState) object(v reflect.Value) { ...@@ -687,7 +692,18 @@ func (d *decodeState) object(v reflect.Value) {
// Write value back to map; // Write value back to map;
// if using struct, subv points into struct already. // if using struct, subv points into struct already.
if v.Kind() == reflect.Map { if v.Kind() == reflect.Map {
kv := reflect.ValueOf(key).Convert(v.Type().Key()) kt := v.Type().Key()
var kv reflect.Value
switch {
case kt.Kind() == reflect.String:
kv = reflect.ValueOf(key).Convert(v.Type().Key())
case reflect.PtrTo(kt).Implements(textUnmarshalerType):
kv = reflect.New(v.Type().Key())
d.literalStore(item, kv, true)
kv = kv.Elem()
default:
panic("json: Unexpected key type") // should never occur
}
v.SetMapIndex(kv, subv) v.SetMapIndex(kv, subv)
} }
......
...@@ -7,6 +7,7 @@ package json ...@@ -7,6 +7,7 @@ package json
import ( import (
"bytes" "bytes"
"encoding" "encoding"
"errors"
"fmt" "fmt"
"image" "image"
"net" "net"
...@@ -68,16 +69,20 @@ type ustruct struct { ...@@ -68,16 +69,20 @@ type ustruct struct {
} }
type unmarshalerText struct { type unmarshalerText struct {
T bool A, B string
} }
// needed for re-marshaling tests // needed for re-marshaling tests
func (u *unmarshalerText) MarshalText() ([]byte, error) { func (u unmarshalerText) MarshalText() ([]byte, error) {
return []byte(""), nil return []byte(u.A + ":" + u.B), nil
} }
func (u *unmarshalerText) UnmarshalText(b []byte) error { func (u *unmarshalerText) UnmarshalText(b []byte) error {
*u = unmarshalerText{true} // All we need to see that UnmarshalText is called. pos := bytes.Index(b, []byte(":"))
if pos == -1 {
return errors.New("missing separator")
}
u.A, u.B = string(b[:pos]), string(b[pos+1:])
return nil return nil
} }
...@@ -95,12 +100,16 @@ var ( ...@@ -95,12 +100,16 @@ var (
umslicep = new([]unmarshaler) umslicep = new([]unmarshaler)
umstruct = ustruct{unmarshaler{true}} umstruct = ustruct{unmarshaler{true}}
um0T, um1T unmarshalerText // target2 of unmarshaling um0T, um1T unmarshalerText // target2 of unmarshaling
umpT = &um1T umpType = &um1T
umtrueT = unmarshalerText{true} umtrueXY = unmarshalerText{"x", "y"}
umsliceT = []unmarshalerText{{true}} umsliceXY = []unmarshalerText{{"x", "y"}}
umslicepT = new([]unmarshalerText) umslicepType = new([]unmarshalerText)
umstructT = ustructText{unmarshalerText{true}} umstructType = new(ustructText)
umstructXY = ustructText{unmarshalerText{"x", "y"}}
ummapType = map[unmarshalerText]bool{}
ummapXY = map[unmarshalerText]bool{unmarshalerText{"x", "y"}: true}
) )
// Test data structures for anonymous fields. // Test data structures for anonymous fields.
...@@ -302,14 +311,19 @@ var unmarshalTests = []unmarshalTest{ ...@@ -302,14 +311,19 @@ var unmarshalTests = []unmarshalTest{
{in: `{"T":false}`, ptr: &ump, out: &umtrue}, {in: `{"T":false}`, ptr: &ump, out: &umtrue},
{in: `[{"T":false}]`, ptr: &umslice, out: umslice}, {in: `[{"T":false}]`, ptr: &umslice, out: umslice},
{in: `[{"T":false}]`, ptr: &umslicep, out: &umslice}, {in: `[{"T":false}]`, ptr: &umslicep, out: &umslice},
{in: `{"M":{"T":false}}`, ptr: &umstruct, out: umstruct}, {in: `{"M":{"T":"x:y"}}`, ptr: &umstruct, out: umstruct},
// UnmarshalText interface test // UnmarshalText interface test
{in: `"X"`, ptr: &um0T, out: umtrueT}, // use "false" so test will fail if custom unmarshaler is not called {in: `"x:y"`, ptr: &um0T, out: umtrueXY},
{in: `"X"`, ptr: &umpT, out: &umtrueT}, {in: `"x:y"`, ptr: &umpType, out: &umtrueXY},
{in: `["X"]`, ptr: &umsliceT, out: umsliceT}, {in: `["x:y"]`, ptr: &umsliceXY, out: umsliceXY},
{in: `["X"]`, ptr: &umslicepT, out: &umsliceT}, {in: `["x:y"]`, ptr: &umslicepType, out: &umsliceXY},
{in: `{"M":"X"}`, ptr: &umstructT, out: umstructT}, {in: `{"M":"x:y"}`, ptr: umstructType, out: umstructXY},
// Map keys can be encoding.TextUnmarshalers
{in: `{"x:y":true}`, ptr: &ummapType, out: ummapXY},
// If multiple values for the same key exist, only the most recent value is used.
{in: `{"x:y":false,"x:y":true}`, ptr: &ummapType, out: ummapXY},
// Overwriting of data. // Overwriting of data.
// This is different from package xml, but it's what we've always done. // This is different from package xml, but it's what we've always done.
...@@ -426,11 +440,23 @@ var unmarshalTests = []unmarshalTest{ ...@@ -426,11 +440,23 @@ var unmarshalTests = []unmarshalTest{
out: "hello\ufffd\ufffd\ufffd\ufffd\ufffd\ufffdworld", out: "hello\ufffd\ufffd\ufffd\ufffd\ufffd\ufffdworld",
}, },
// issue 8305 // Used to be issue 8305, but time.Time implements encoding.TextUnmarshaler so this works now.
{ {
in: `{"2009-11-10T23:00:00Z": "hello world"}`, in: `{"2009-11-10T23:00:00Z": "hello world"}`,
ptr: &map[time.Time]string{}, ptr: &map[time.Time]string{},
err: &UnmarshalTypeError{"object", reflect.TypeOf(map[time.Time]string{}), 1}, out: map[time.Time]string{time.Date(2009, 11, 10, 23, 0, 0, 0, time.UTC): "hello world"},
},
// issue 8305
{
in: `{"2009-11-10T23:00:00Z": "hello world"}`,
ptr: &map[Point]string{},
err: &UnmarshalTypeError{"object", reflect.TypeOf(map[Point]string{}), 1},
},
{
in: `{"asdf": "hello world"}`,
ptr: &map[unmarshaler]string{},
err: &UnmarshalTypeError{"object", reflect.TypeOf(map[unmarshaler]string{}), 1},
}, },
} }
......
...@@ -116,8 +116,8 @@ import ( ...@@ -116,8 +116,8 @@ import (
// an anonymous struct field in both current and earlier versions, give the field // an anonymous struct field in both current and earlier versions, give the field
// a JSON tag of "-". // a JSON tag of "-".
// //
// Map values encode as JSON objects. // Map values encode as JSON objects. The map's key type must either be a string
// The map's key type must be string; the map keys are used as JSON object // or implement encoding.TextMarshaler. The map keys are used as JSON object
// keys, subject to the UTF-8 coercion described for string values above. // keys, subject to the UTF-8 coercion described for string values above.
// //
// Pointer values encode as the value pointed to. // Pointer values encode as the value pointed to.
...@@ -611,21 +611,31 @@ func (me *mapEncoder) encode(e *encodeState, v reflect.Value, _ bool) { ...@@ -611,21 +611,31 @@ func (me *mapEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
return return
} }
e.WriteByte('{') e.WriteByte('{')
var sv stringValues = v.MapKeys()
sort.Sort(sv) // Extract and sort the keys.
for i, k := range sv { keys := v.MapKeys()
sv := make([]reflectWithString, len(keys))
for i, v := range keys {
sv[i].v = v
if err := sv[i].resolve(); err != nil {
e.error(&MarshalerError{v.Type(), err})
}
}
sort.Sort(byString(sv))
for i, kv := range sv {
if i > 0 { if i > 0 {
e.WriteByte(',') e.WriteByte(',')
} }
e.string(k.String()) e.string(kv.s)
e.WriteByte(':') e.WriteByte(':')
me.elemEnc(e, v.MapIndex(k), false) me.elemEnc(e, v.MapIndex(kv.v), false)
} }
e.WriteByte('}') e.WriteByte('}')
} }
func newMapEncoder(t reflect.Type) encoderFunc { func newMapEncoder(t reflect.Type) encoderFunc {
if t.Key().Kind() != reflect.String { if t.Key().Kind() != reflect.String && !t.Key().Implements(textMarshalerType) {
return unsupportedTypeEncoder return unsupportedTypeEncoder
} }
me := &mapEncoder{typeEncoder(t.Elem())} me := &mapEncoder{typeEncoder(t.Elem())}
...@@ -775,14 +785,29 @@ func typeByIndex(t reflect.Type, index []int) reflect.Type { ...@@ -775,14 +785,29 @@ func typeByIndex(t reflect.Type, index []int) reflect.Type {
return t return t
} }
// stringValues is a slice of reflect.Value holding *reflect.StringValue. type reflectWithString struct {
v reflect.Value
s string
}
// resolve fills in w.s, the string form of the wrapped value w.v.
//
// A value of string kind supplies its contents directly. Any other value is
// type-asserted to encoding.TextMarshaler and the bytes returned by
// MarshalText become the string; the MarshalText error, if any, is returned
// unchanged (w.s is still assigned from whatever bytes came back).
//
// The type assertion panics if w.v is neither of string kind nor an
// encoding.TextMarshaler.
func (w *reflectWithString) resolve() error {
	if w.v.Kind() != reflect.String {
		marshaler := w.v.Interface().(encoding.TextMarshaler)
		text, err := marshaler.MarshalText()
		w.s = string(text)
		return err
	}
	w.s = w.v.String()
	return nil
}
// byString is a slice of reflectWithString where the reflect.Value is either
// a string or an encoding.TextMarshaler.
// It implements the methods to sort by string. // It implements the methods to sort by string.
type stringValues []reflect.Value type byString []reflectWithString
func (sv stringValues) Len() int { return len(sv) } func (sv byString) Len() int { return len(sv) }
func (sv stringValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] } func (sv byString) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] }
func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) } func (sv byString) Less(i, j int) bool { return sv[i].s < sv[j].s }
func (sv stringValues) get(i int) string { return sv[i].String() }
// NOTE: keep in sync with stringBytes below. // NOTE: keep in sync with stringBytes below.
func (e *encodeState) string(s string) int { func (e *encodeState) string(s string) int {
......
...@@ -536,3 +536,19 @@ func TestEncodeString(t *testing.T) { ...@@ -536,3 +536,19 @@ func TestEncodeString(t *testing.T) {
} }
} }
} }
func TestTextMarshalerMapKeysAreSorted(t *testing.T) {
b, err := Marshal(map[unmarshalerText]int{
{"x", "y"}: 1,
{"y", "x"}: 2,
{"a", "z"}: 3,
{"z", "a"}: 4,
})
if err != nil {
t.Fatalf("Failed to Marshal text.Marshaler: %v", err)
}
const want = `{"a:z":3,"x:y":1,"y:x":2,"z:a":4}`
if string(b) != want {
t.Errorf("Marshal map with text.Marshaler keys: got %#q, want %#q", b, want)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment