Commit a26ab29a authored by Rob Pike's avatar Rob Pike

gob: allow transmission of things other than structs at the top level.

also fix a bug handling nil maps: before, would needlessly send empty map

R=rsc
CC=golang-dev
https://golang.org/cl/1739043
parent e0334ab7
...@@ -1039,7 +1039,7 @@ func TestInvalidField(t *testing.T) { ...@@ -1039,7 +1039,7 @@ func TestInvalidField(t *testing.T) {
type Indirect struct { type Indirect struct {
a ***[3]int a ***[3]int
s ***[]int s ***[]int
m ***map[string]int m ****map[string]int
} }
type Direct struct { type Direct struct {
...@@ -1059,10 +1059,11 @@ func TestIndirectSliceMapArray(t *testing.T) { ...@@ -1059,10 +1059,11 @@ func TestIndirectSliceMapArray(t *testing.T) {
*i.s = new(*[]int) *i.s = new(*[]int)
**i.s = new([]int) **i.s = new([]int)
***i.s = []int{4, 5, 6} ***i.s = []int{4, 5, 6}
i.m = new(**map[string]int) i.m = new(***map[string]int)
*i.m = new(*map[string]int) *i.m = new(**map[string]int)
**i.m = new(map[string]int) **i.m = new(*map[string]int)
***i.m = map[string]int{"one": 1, "two": 2, "three": 3} ***i.m = new(map[string]int)
****i.m = map[string]int{"one": 1, "two": 2, "three": 3}
b := new(bytes.Buffer) b := new(bytes.Buffer)
NewEncoder(b).Encode(i) NewEncoder(b).Encode(i)
dec := NewDecoder(b) dec := NewDecoder(b)
...@@ -1093,12 +1094,12 @@ func TestIndirectSliceMapArray(t *testing.T) { ...@@ -1093,12 +1094,12 @@ func TestIndirectSliceMapArray(t *testing.T) {
t.Error("error: ", err) t.Error("error: ", err)
} }
if len(***i.a) != 3 || (***i.a)[0] != 11 || (***i.a)[1] != 22 || (***i.a)[2] != 33 { if len(***i.a) != 3 || (***i.a)[0] != 11 || (***i.a)[1] != 22 || (***i.a)[2] != 33 {
t.Errorf("indirect to direct: ***i.a is %v not %v", ***i.a, d.a) t.Errorf("direct to indirect: ***i.a is %v not %v", ***i.a, d.a)
} }
if len(***i.s) != 3 || (***i.s)[0] != 44 || (***i.s)[1] != 55 || (***i.s)[2] != 66 { if len(***i.s) != 3 || (***i.s)[0] != 44 || (***i.s)[1] != 55 || (***i.s)[2] != 66 {
t.Errorf("indirect to direct: ***i.s is %v not %v", ***i.s, ***i.s) t.Errorf("direct to indirect: ***i.s is %v not %v", ***i.s, ***i.s)
} }
if len(***i.m) != 3 || (***i.m)["four"] != 4 || (***i.m)["five"] != 5 || (***i.m)["six"] != 6 { if len(****i.m) != 3 || (****i.m)["four"] != 4 || (****i.m)["five"] != 5 || (****i.m)["six"] != 6 {
t.Errorf("indirect to direct: ***i.m is %v not %v", ***i.m, d.m) t.Errorf("direct to indirect: ****i.m is %v not %v", ****i.m, d.m)
} }
} }
...@@ -13,15 +13,13 @@ import ( ...@@ -13,15 +13,13 @@ import (
"math" "math"
"os" "os"
"reflect" "reflect"
"runtime"
"unsafe" "unsafe"
) )
var ( var (
errBadUint = os.ErrorString("gob: encoded unsigned integer out of range") errBadUint = os.ErrorString("gob: encoded unsigned integer out of range")
errBadType = os.ErrorString("gob: unknown type id or corrupted data") errBadType = os.ErrorString("gob: unknown type id or corrupted data")
errRange = os.ErrorString("gob: internal error: field numbers out of bounds") errRange = os.ErrorString("gob: internal error: field numbers out of bounds")
errNotStruct = os.ErrorString("gob: TODO: can only handle structs")
) )
// The global execution state of an instance of the decoder. // The global execution state of an instance of the decoder.
...@@ -389,18 +387,44 @@ type decEngine struct { ...@@ -389,18 +387,44 @@ type decEngine struct {
numInstr int // the number of active instructions numInstr int // the number of active instructions
} }
func decodeStruct(engine *decEngine, rtyp *reflect.StructType, b *bytes.Buffer, p uintptr, indir int) os.Error { // allocate makes sure storage is available for an object of underlying type rtyp
if indir > 0 { // that is indir levels of indirection through p.
up := unsafe.Pointer(p) func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr {
if indir > 1 { if indir == 0 {
up = decIndirect(up, indir) return p
} }
if *(*unsafe.Pointer)(up) == nil { up := unsafe.Pointer(p)
// Allocate object. if indir > 1 {
*(*unsafe.Pointer)(up) = unsafe.New((*runtime.StructType)(unsafe.Pointer(rtyp))) up = decIndirect(up, indir)
}
p = *(*uintptr)(up)
} }
if *(*unsafe.Pointer)(up) == nil {
// Allocate object.
*(*unsafe.Pointer)(up) = unsafe.New(rtyp)
}
return *(*uintptr)(up)
}
func decodeSingle(engine *decEngine, rtyp reflect.Type, b *bytes.Buffer, p uintptr, indir int) os.Error {
p = allocate(rtyp, p, indir)
state := newDecodeState(b)
state.fieldnum = singletonField
basep := p
delta := int(decodeUint(state))
if delta != 0 {
state.err = os.ErrorString("gob decode: corrupted data: non-zero delta for singleton")
return state.err
}
instr := &engine.instr[singletonField]
ptr := unsafe.Pointer(basep) // offset will be zero
if instr.indir > 1 {
ptr = decIndirect(ptr, instr.indir)
}
instr.op(instr, state, ptr)
return state.err
}
func decodeStruct(engine *decEngine, rtyp *reflect.StructType, b *bytes.Buffer, p uintptr, indir int) os.Error {
p = allocate(rtyp, p, indir)
state := newDecodeState(b) state := newDecodeState(b)
state.fieldnum = -1 state.fieldnum = -1
basep := p basep := p
...@@ -468,12 +492,7 @@ func decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uint ...@@ -468,12 +492,7 @@ func decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uint
func decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) os.Error { func decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) os.Error {
if indir > 0 { if indir > 0 {
up := unsafe.Pointer(p) p = allocate(atyp, p, 1) // All but the last level has been allocated by dec.Indirect
if *(*unsafe.Pointer)(up) == nil {
// Allocate object.
*(*unsafe.Pointer)(up) = unsafe.New(atyp)
}
p = *(*uintptr)(up)
} }
if n := decodeUint(state); n != uint64(length) { if n := decodeUint(state); n != uint64(length) {
return os.ErrorString("gob: length mismatch in decodeArray") return os.ErrorString("gob: length mismatch in decodeArray")
...@@ -493,12 +512,7 @@ func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, o ...@@ -493,12 +512,7 @@ func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, o
func decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) os.Error { func decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) os.Error {
if indir > 0 { if indir > 0 {
up := unsafe.Pointer(p) p = allocate(mtyp, p, 1) // All but the last level has been allocated by dec.Indirect
if *(*unsafe.Pointer)(up) == nil {
// Allocate object.
*(*unsafe.Pointer)(up) = unsafe.New(mtyp)
}
p = *(*uintptr)(up)
} }
up := unsafe.Pointer(p) up := unsafe.Pointer(p)
if *(*unsafe.Pointer)(up) == nil { // maps are represented as a pointer in the runtime if *(*unsafe.Pointer)(up) == nil { // maps are represented as a pointer in the runtime
...@@ -806,18 +820,34 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool { ...@@ -806,18 +820,34 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId) bool {
return true return true
} }
func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) {
engine = new(decEngine)
engine.instr = make([]decInstr, 1) // one item
name := rt.String() // best we can do
if !dec.compatibleType(rt, remoteId) {
return nil, os.ErrorString("gob: wrong type received for local value " + name)
}
op, indir, err := dec.decOpFor(remoteId, rt, name)
if err != nil {
return nil, err
}
ovfl := os.ErrorString(`value for "` + name + `" out of range`)
engine.instr[singletonField] = decInstr{op, singletonField, indir, 0, ovfl}
engine.numInstr = 1
return
}
func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) { func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) {
srt, ok1 := rt.(*reflect.StructType) srt, ok := rt.(*reflect.StructType)
if !ok {
return dec.compileSingle(remoteId, rt)
}
var wireStruct *structType var wireStruct *structType
// Builtin types can come from global pool; the rest must be defined by the decoder // Builtin types can come from global pool; the rest must be defined by the decoder
if t, ok := builtinIdToType[remoteId]; ok { if t, ok := builtinIdToType[remoteId]; ok {
wireStruct = t.(*structType) wireStruct = t.(*structType)
} else { } else {
w, ok2 := dec.wireType[remoteId] wireStruct = dec.wireType[remoteId].structT
if !ok1 || !ok2 {
return nil, errNotStruct
}
wireStruct = w.structT
} }
engine = new(decEngine) engine = new(decEngine)
engine.instr = make([]decInstr, len(wireStruct.field)) engine.instr = make([]decInstr, len(wireStruct.field))
...@@ -891,20 +921,19 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er ...@@ -891,20 +921,19 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er
func (dec *Decoder) decode(wireId typeId, e interface{}) os.Error { func (dec *Decoder) decode(wireId typeId, e interface{}) os.Error {
// Dereference down to the underlying struct type. // Dereference down to the underlying struct type.
rt, indir := indirect(reflect.Typeof(e)) rt, indir := indirect(reflect.Typeof(e))
st, ok := rt.(*reflect.StructType)
if !ok {
return os.ErrorString("gob: decode can't handle " + rt.String())
}
enginePtr, err := dec.getDecEnginePtr(wireId, rt) enginePtr, err := dec.getDecEnginePtr(wireId, rt)
if err != nil { if err != nil {
return err return err
} }
engine := *enginePtr engine := *enginePtr
if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].structT.field) > 0 { if st, ok := rt.(*reflect.StructType); ok {
name := rt.Name() if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].structT.field) > 0 {
return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name) name := rt.Name()
return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name)
}
return decodeStruct(engine, st, dec.state.b, uintptr(reflect.NewValue(e).Addr()), indir)
} }
return decodeStruct(engine, st, dec.state.b, uintptr(reflect.NewValue(e).Addr()), indir) return decodeSingle(engine, rt, dec.state.b, uintptr(reflect.NewValue(e).Addr()), indir)
} }
func init() { func init() {
......
...@@ -108,8 +108,9 @@ func (dec *Decoder) Decode(e interface{}) os.Error { ...@@ -108,8 +108,9 @@ func (dec *Decoder) Decode(e interface{}) os.Error {
} }
// No, it's a value. // No, it's a value.
// Make sure the type has been defined already. // Make sure the type has been defined already or is a builtin type (for
if dec.wireType[id] == nil { // top-level singleton values).
if dec.wireType[id] == nil && builtinIdToType[id] == nil {
dec.state.err = errBadType dec.state.err = errBadType
break break
} }
......
This diff is collapsed.
...@@ -68,7 +68,7 @@ func (enc *Encoder) send() { ...@@ -68,7 +68,7 @@ func (enc *Encoder) send() {
} }
} }
func (enc *Encoder) sendType(origt reflect.Type) { func (enc *Encoder) sendType(origt reflect.Type) (sent bool) {
// Drill down to the base type. // Drill down to the base type.
rt, _ := indirect(origt) rt, _ := indirect(origt)
...@@ -147,11 +147,6 @@ func (enc *Encoder) Encode(e interface{}) os.Error { ...@@ -147,11 +147,6 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
enc.state.err = nil enc.state.err = nil
rt, _ := indirect(reflect.Typeof(e)) rt, _ := indirect(reflect.Typeof(e))
// Must be a struct
if _, ok := rt.(*reflect.StructType); !ok {
enc.badType(rt)
return enc.state.err
}
// Sanity check only: encoder should never come in with data present. // Sanity check only: encoder should never come in with data present.
if enc.state.b.Len() > 0 || enc.countState.b.Len() > 0 { if enc.state.b.Len() > 0 || enc.countState.b.Len() > 0 {
...@@ -163,10 +158,23 @@ func (enc *Encoder) Encode(e interface{}) os.Error { ...@@ -163,10 +158,23 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
// First, have we already sent this type? // First, have we already sent this type?
if _, alreadySent := enc.sent[rt]; !alreadySent { if _, alreadySent := enc.sent[rt]; !alreadySent {
// No, so send it. // No, so send it.
enc.sendType(rt) sent := enc.sendType(rt)
if enc.state.err != nil { if enc.state.err != nil {
return enc.state.err return enc.state.err
} }
// If the type info has still not been transmitted, it means we have
// a singleton basic type (int, []byte etc.) at top level. We don't
// need to send the type info but we do need to update enc.sent.
if !sent {
typeLock.Lock()
info, err := getTypeInfo(rt)
typeLock.Unlock()
if err != nil {
enc.setError(err)
return err
}
enc.sent[rt] = info.id
}
} }
// Identify the type of this top-level value. // Identify the type of this top-level value.
......
...@@ -131,17 +131,10 @@ func TestBadData(t *testing.T) { ...@@ -131,17 +131,10 @@ func TestBadData(t *testing.T) {
corruptDataCheck("\x03now is the time for all good men", errBadType, t) corruptDataCheck("\x03now is the time for all good men", errBadType, t)
} }
// Types not supported by the Encoder (only structs work at the top level). // Types not supported by the Encoder.
// Basic types work implicitly.
var unsupportedValues = []interface{}{ var unsupportedValues = []interface{}{
3,
"hi",
7.2,
[]int{1, 2, 3},
[3]int{1, 2, 3},
make(chan int), make(chan int),
func(a int) bool { return true }, func(a int) bool { return true },
make(map[string]int),
new(interface{}), new(interface{}),
} }
...@@ -275,3 +268,59 @@ func TestDefaultsInArray(t *testing.T) { ...@@ -275,3 +268,59 @@ func TestDefaultsInArray(t *testing.T) {
t.Error(err) t.Error(err)
} }
} }
var testInt int
var testFloat32 float32
var testString string
var testSlice []string
var testMap map[string]int
type SingleTest struct {
in interface{}
out interface{}
err string
}
var singleTests = []SingleTest{
SingleTest{17, &testInt, ""},
SingleTest{float32(17.5), &testFloat32, ""},
SingleTest{"bike shed", &testString, ""},
SingleTest{[]string{"bike", "shed", "paint", "color"}, &testSlice, ""},
SingleTest{map[string]int{"seven": 7, "twelve": 12}, &testMap, ""},
// Decode errors
SingleTest{172, &testFloat32, "wrong type"},
}
func TestSingletons(t *testing.T) {
b := new(bytes.Buffer)
enc := NewEncoder(b)
dec := NewDecoder(b)
for _, test := range singleTests {
b.Reset()
err := enc.Encode(test.in)
if err != nil {
t.Errorf("error encoding %v: %s", test.in, err)
continue
}
err = dec.Decode(test.out)
switch {
case err != nil && test.err == "":
t.Errorf("error decoding %v: %s", test.in, err)
continue
case err == nil && test.err != "":
t.Errorf("expected error decoding %v: %s", test.in, test.err)
continue
case err != nil && test.err != "":
if strings.Index(err.String(), test.err) < 0 {
t.Errorf("wrong error decoding %v: wanted %s, got %v", test.in, test.err, err)
}
continue
}
// Get rid of the pointer in the rhs
val := reflect.NewValue(test.out).(*reflect.PtrValue).Elem().Interface()
if !reflect.DeepEqual(test.in, val) {
t.Errorf("decoding int: expected %v got %v", test.in, val)
}
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment