Commit c8b3d029 authored by Rob Pike's avatar Rob Pike

gob: make robust when decoding a struct with non-struct data.

The decoder was crashing when handling an rpc that expected
a struct but was delivered something else.  This diagnoses the
problem.  The other direction (expecting non-struct but getting
one) was already handled.

R=rsc
CC=golang-dev
https://golang.org/cl/2246041
parent 42a61b92
...@@ -843,12 +843,17 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng ...@@ -843,12 +843,17 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
return dec.compileSingle(remoteId, rt) return dec.compileSingle(remoteId, rt)
} }
var wireStruct *structType var wireStruct *structType
// Builtin types can come from global pool; the rest must be defined by the decoder // Builtin types can come from global pool; the rest must be defined by the decoder.
// Also we know we're decoding a struct now, so the client must have sent one.
if t, ok := builtinIdToType[remoteId]; ok { if t, ok := builtinIdToType[remoteId]; ok {
wireStruct = t.(*structType) wireStruct, _ = t.(*structType)
} else { } else {
wireStruct = dec.wireType[remoteId].structT wireStruct = dec.wireType[remoteId].structT
} }
if wireStruct == nil {
return nil, os.ErrorString("gob: type mismatch in decoder: want struct type " +
rt.String() + "; got non-struct")
}
engine = new(decEngine) engine = new(decEngine)
engine.instr = make([]decInstr, len(wireStruct.field)) engine.instr = make([]decInstr, len(wireStruct.field))
// Loop over the fields of the wire type. // Loop over the fields of the wire type.
......
...@@ -327,3 +327,31 @@ func TestSingletons(t *testing.T) { ...@@ -327,3 +327,31 @@ func TestSingletons(t *testing.T) {
} }
} }
} }
func TestStructNonStruct(t *testing.T) {
type Struct struct {
a string
}
type NonStruct string
s := Struct{"hello"}
var sp Struct
if err := encAndDec(s, &sp); err != nil {
t.Error(err)
}
var ns NonStruct
if err := encAndDec(s, &ns); err == nil {
t.Error("should get error for struct/non-struct")
} else if strings.Index(err.String(), "type") < 0 {
t.Error("for struct/non-struct expected type error; got", err)
}
// Now try the other way
var nsp NonStruct
if err := encAndDec(ns, &nsp); err != nil {
t.Error(err)
}
if err := encAndDec(ns, &s); err == nil {
t.Error("should get error for non-struct/struct")
} else if strings.Index(err.String(), "type") < 0 {
t.Error("for non-struct/struct expected type error; got", err)
}
}
...@@ -52,10 +52,20 @@ func (t typeId) gobType() gobType { ...@@ -52,10 +52,20 @@ func (t typeId) gobType() gobType {
} }
// string returns the string representation of the type associated with the typeId. // string returns the string representation of the type associated with the typeId.
func (t typeId) string() string { return t.gobType().string() } func (t typeId) string() string {
if t.gobType() == nil {
return "<nil>"
}
return t.gobType().string()
}
// Name returns the name of the type associated with the typeId. // Name returns the name of the type associated with the typeId.
func (t typeId) Name() string { return t.gobType().Name() } func (t typeId) Name() string {
if t.gobType() == nil {
return "<nil>"
}
return t.gobType().Name()
}
// Common elements of all types. // Common elements of all types.
type commonType struct { type commonType struct {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment