Commit 98607d01 authored by Rob Pike's avatar Rob Pike

handle unsupported types safely.

R=rsc
DELTA=154  (71 added, 6 deleted, 77 changed)
OCL=32483
CL=32492
parent 312bd7a1
......@@ -564,7 +564,7 @@ func TestEndToEnd(t *testing.T) {
b := new(bytes.Buffer);
encode(b, t1);
var _t1 T1;
decode(b, getTypeInfo(reflect.Typeof(_t1)).id, &_t1);
decode(b, getTypeInfoNoError(reflect.Typeof(_t1)).id, &_t1);
if !reflect.DeepEqual(t1, &_t1) {
t.Errorf("encode expected %v got %v", *t1, _t1);
}
......@@ -580,7 +580,7 @@ func TestOverflow(t *testing.T) {
}
var it inputT;
var err os.Error;
id := getTypeInfo(reflect.Typeof(it)).id;
id := getTypeInfoNoError(reflect.Typeof(it)).id;
b := new(bytes.Buffer);
// int8
......@@ -733,7 +733,7 @@ func TestNesting(t *testing.T) {
b := new(bytes.Buffer);
encode(b, rt);
var drt RT;
decode(b, getTypeInfo(reflect.Typeof(drt)).id, &drt);
decode(b, getTypeInfoNoError(reflect.Typeof(drt)).id, &drt);
if drt.a != rt.a {
t.Errorf("nesting: encode expected %v got %v", *rt, drt);
}
......@@ -775,7 +775,7 @@ func TestAutoIndirection(t *testing.T) {
b := new(bytes.Buffer);
encode(b, t1);
var t0 T0;
t0Id := getTypeInfo(reflect.Typeof(t0)).id;
t0Id := getTypeInfoNoError(reflect.Typeof(t0)).id;
decode(b, t0Id, &t0);
if t0.a != 17 || t0.b != 177 || t0.c != 1777 || t0.d != 17777 {
t.Errorf("t1->t0: expected {17 177 1777 17777}; got %v", t0);
......@@ -800,7 +800,7 @@ func TestAutoIndirection(t *testing.T) {
b.Reset();
encode(b, t0);
t1 = T1{};
t1Id := getTypeInfo(reflect.Typeof(t1)).id;
t1Id := getTypeInfoNoError(reflect.Typeof(t1)).id;
decode(b, t1Id, &t1);
if t1.a != 17 || *t1.b != 177 || **t1.c != 1777 || ***t1.d != 17777 {
t.Errorf("t0->t1 expected {17 177 1777 17777}; got {%d %d %d %d}", t1.a, *t1.b, **t1.c, ***t1.d);
......@@ -810,7 +810,7 @@ func TestAutoIndirection(t *testing.T) {
b.Reset();
encode(b, t0);
t2 = T2{};
t2Id := getTypeInfo(reflect.Typeof(t2)).id;
t2Id := getTypeInfoNoError(reflect.Typeof(t2)).id;
decode(b, t2Id, &t2);
if ***t2.a != 17 || **t2.b != 177 || *t2.c != 1777 || t2.d != 17777 {
t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.a, **t2.b, *t2.c, t2.d);
......@@ -848,7 +848,7 @@ func TestReorderedFields(t *testing.T) {
rt0.c = 3.14159;
b := new(bytes.Buffer);
encode(b, rt0);
rt0Id := getTypeInfo(reflect.Typeof(rt0)).id;
rt0Id := getTypeInfoNoError(reflect.Typeof(rt0)).id;
var rt1 RT1;
// Wire type is RT0, local type is RT1.
decode(b, rt0Id, &rt1);
......@@ -886,7 +886,7 @@ func TestIgnoredFields(t *testing.T) {
b := new(bytes.Buffer);
encode(b, it0);
rt0Id := getTypeInfo(reflect.Typeof(it0)).id;
rt0Id := getTypeInfoNoError(reflect.Typeof(it0)).id;
var rt1 RT1;
// Wire type is IT0, local type is RT1.
err := decode(b, rt0Id, &rt1);
......@@ -897,3 +897,20 @@ func TestIgnoredFields(t *testing.T) {
t.Errorf("rt1->rt0: expected %v; got %v", it0, rt1);
}
}
type Bad0 struct {
inter interface{};
c float;
}
func TestInvalidField(t *testing.T) {
var bad0 Bad0;
bad0.inter = 17;
b := new(bytes.Buffer);
err := encode(b, &bad0);
if err == nil {
t.Error("expected error; got none")
} else if strings.Index(err.String(), "interface") < 0 {
t.Error("expected type error; got", err)
}
}
......@@ -335,11 +335,11 @@ var encOpMap = map[reflect.Type] encOp {
valueKind("x"): encString,
}
func getEncEngine(rt reflect.Type) *encEngine
func getEncEngine(rt reflect.Type) (*encEngine, os.Error)
// Return the encoding op for the base type under rt and
// the indirection count to reach it.
func encOpFor(rt reflect.Type) (encOp, int) {
func encOpFor(rt reflect.Type) (encOp, int, os.Error) {
typ, indir := indirect(rt);
op, ok := encOpMap[reflect.Typeof(typ)];
if !ok {
......@@ -352,7 +352,10 @@ func encOpFor(rt reflect.Type) (encOp, int) {
break;
}
// Slices have a header; we decode it to find the underlying array.
elemOp, indir := encOpFor(t.Elem());
elemOp, indir, err := encOpFor(t.Elem());
if err != nil {
return nil, 0, err
}
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
slice := (*reflect.SliceHeader)(p);
if slice.Len == 0 {
......@@ -363,15 +366,21 @@ func encOpFor(rt reflect.Type) (encOp, int) {
};
case *reflect.ArrayType:
// True arrays have size in the type.
elemOp, indir := encOpFor(t.Elem());
elemOp, indir, err := encOpFor(t.Elem());
if err != nil {
return nil, 0, err
}
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
state.update(i);
state.err = encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), t.Len(), indir);
};
case *reflect.StructType:
// Generate a closure that calls out to the engine for the nested type.
engine := getEncEngine(typ);
info := getTypeInfo(typ);
engine, err := getEncEngine(typ);
if err != nil {
return nil, 0, err
}
info := getTypeInfoNoError(typ);
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
state.update(i);
// indirect through info to delay evaluation for recursive structs
......@@ -380,13 +389,13 @@ func encOpFor(rt reflect.Type) (encOp, int) {
}
}
if op == nil {
panicln("can't happen: encode type", rt.String());
return op, indir, os.ErrorString("gob enc: can't happen: encode type" + rt.String());
}
return op, indir
return op, indir, nil
}
// The local Type was compiled from the actual value, so we know it's compatible.
func compileEnc(rt reflect.Type) *encEngine {
func compileEnc(rt reflect.Type) (*encEngine, os.Error) {
srt, ok := rt.(*reflect.StructType);
if !ok {
panicln("can't happen: non-struct");
......@@ -395,23 +404,29 @@ func compileEnc(rt reflect.Type) *encEngine {
engine.instr = make([]encInstr, srt.NumField()+1); // +1 for terminator
for fieldnum := 0; fieldnum < srt.NumField(); fieldnum++ {
f := srt.Field(fieldnum);
op, indir := encOpFor(f.Type);
op, indir, err := encOpFor(f.Type);
if err != nil {
return nil, err
}
engine.instr[fieldnum] = encInstr{op, fieldnum, indir, uintptr(f.Offset)};
}
engine.instr[srt.NumField()] = encInstr{encStructTerminator, 0, 0, 0};
return engine;
return engine, nil;
}
// typeLock must be held (or we're in initialization and guaranteed single-threaded).
// The reflection type must have all its indirections processed out.
func getEncEngine(rt reflect.Type) *encEngine {
info := getTypeInfo(rt);
func getEncEngine(rt reflect.Type) (*encEngine, os.Error) {
info, err := getTypeInfo(rt);
if err != nil {
return nil, err
}
if info.encoder == nil {
// mark this engine as underway before compiling to handle recursive types.
info.encoder = new(encEngine);
info.encoder = compileEnc(rt);
info.encoder, err = compileEnc(rt);
}
return info.encoder;
return info.encoder, err;
}
func encode(b *bytes.Buffer, e interface{}) os.Error {
......@@ -425,7 +440,10 @@ func encode(b *bytes.Buffer, e interface{}) os.Error {
return os.ErrorString("gob: encode can't handle " + v.Type().String())
}
typeLock.Lock();
engine := getEncEngine(rt);
engine, err := getEncEngine(rt);
typeLock.Unlock();
if err != nil {
return err
}
return encodeStruct(engine, b, v.Addr());
}
......@@ -73,7 +73,7 @@
Maps are not supported yet, but they will be. Interfaces, functions, and channels
cannot be sent in a gob. Attempting to encode a value that contains one will
fail. (TODO(r): fix this - it panics now.)
fail.
The rest of this comment documents the encoding, details that are not important
for most users. Details are presented bottom-up.
......@@ -267,8 +267,12 @@ func (enc *Encoder) sendType(origt reflect.Type) {
// Need to send it.
typeLock.Lock();
info := getTypeInfo(rt);
info, err := getTypeInfo(rt);
typeLock.Unlock();
if err != nil {
enc.state.err = err;
return;
}
// Send the pair (-id, type)
// Id:
encodeInt(enc.state, -int64(info.id));
......@@ -285,6 +289,7 @@ func (enc *Encoder) sendType(origt reflect.Type) {
for i := 0; i < st.NumField(); i++ {
enc.sendType(st.Field(i).Type);
}
return
}
// Encode transmits the data item represented by the empty interface value,
......
......@@ -69,7 +69,7 @@ func TestBasicEncoder(t *testing.T) {
if err != nil {
t.Fatal("error decoding ET1 type:", err);
}
info := getTypeInfo(reflect.Typeof(ET1{}));
info := getTypeInfoNoError(reflect.Typeof(ET1{}));
trueWire1 := &wireType{s: info.id.gobType().(*structType)};
if !reflect.DeepEqual(wire1, trueWire1) {
t.Fatalf("invalid wireType for ET1: expected %+v; got %+v\n", *trueWire1, *wire1);
......@@ -90,7 +90,7 @@ func TestBasicEncoder(t *testing.T) {
if err != nil {
t.Fatal("error decoding ET2 type:", err);
}
info = getTypeInfo(reflect.Typeof(ET2{}));
info = getTypeInfoNoError(reflect.Typeof(ET2{}));
trueWire2 := &wireType{s: info.id.gobType().(*structType)};
if !reflect.DeepEqual(wire2, trueWire2) {
t.Fatalf("invalid wireType for ET2: expected %+v; got %+v\n", *trueWire2, *wire2);
......@@ -107,7 +107,7 @@ func TestBasicEncoder(t *testing.T) {
}
// 8) The value of et1
newEt1 := new(ET1);
et1Id := getTypeInfo(reflect.Typeof(*newEt1)).id;
et1Id := getTypeInfoNoError(reflect.Typeof(*newEt1)).id;
err = decode(b, et1Id, newEt1);
if err != nil {
t.Fatal("error decoding ET1 value:", err);
......
......@@ -202,7 +202,7 @@ func newStructType(name string) *structType {
}
// Construction
func newType(name string, rt reflect.Type) gobType
func getType(name string, rt reflect.Type) (gobType, os.Error)
// Step through the indirections on a type to discover the base type.
// Return the number of indirections.
......@@ -219,55 +219,63 @@ func indirect(t reflect.Type) (rt reflect.Type, count int) {
return;
}
func newTypeObject(name string, rt reflect.Type) gobType {
func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
switch t := rt.(type) {
// All basic types are easy: they are predefined.
case *reflect.BoolType:
return tBool.gobType()
return tBool.gobType(), nil
case *reflect.IntType:
return tInt.gobType()
return tInt.gobType(), nil
case *reflect.Int8Type:
return tInt.gobType()
return tInt.gobType(), nil
case *reflect.Int16Type:
return tInt.gobType()
return tInt.gobType(), nil
case *reflect.Int32Type:
return tInt.gobType()
return tInt.gobType(), nil
case *reflect.Int64Type:
return tInt.gobType()
return tInt.gobType(), nil
case *reflect.UintType:
return tUint.gobType()
return tUint.gobType(), nil
case *reflect.Uint8Type:
return tUint.gobType()
return tUint.gobType(), nil
case *reflect.Uint16Type:
return tUint.gobType()
return tUint.gobType(), nil
case *reflect.Uint32Type:
return tUint.gobType()
return tUint.gobType(), nil
case *reflect.Uint64Type:
return tUint.gobType()
return tUint.gobType(), nil
case *reflect.UintptrType:
return tUint.gobType()
return tUint.gobType(), nil
case *reflect.FloatType:
return tFloat.gobType()
return tFloat.gobType(), nil
case *reflect.Float32Type:
return tFloat.gobType()
return tFloat.gobType(), nil
case *reflect.Float64Type:
return tFloat.gobType()
return tFloat.gobType(), nil
case *reflect.StringType:
return tString.gobType()
return tString.gobType(), nil
case *reflect.ArrayType:
return newArrayType(name, newType("", t.Elem()), t.Len());
gt, err := getType("", t.Elem());
if err != nil {
return nil, err
}
return newArrayType(name, gt, t.Len()), nil;
case *reflect.SliceType:
// []byte == []uint8 is a special case
if _, ok := t.Elem().(*reflect.Uint8Type); ok {
return tBytes.gobType()
return tBytes.gobType(), nil
}
gt, err := getType(t.Elem().Name(), t.Elem());
if err != nil {
return nil, err
}
return newSliceType(name, newType(t.Elem().Name(), t.Elem()));
return newSliceType(name, gt), nil;
case *reflect.StructType:
// Install the struct type itself before the fields so recursive
......@@ -283,18 +291,24 @@ func newTypeObject(name string, rt reflect.Type) gobType {
if tname == "" {
tname = f.Type.String();
}
field[i] = &fieldType{ f.Name, newType(tname, f.Type).id() };
gt, err := getType(tname, f.Type);
if err != nil {
return nil, err
}
field[i] = &fieldType{ f.Name, gt.id() };
}
strType.field = field;
return strType;
return strType, nil;
default:
panicln("gob NewTypeObject can't handle type", rt.String()); // TODO(r): panic?
return nil, os.ErrorString("gob NewTypeObject can't handle type: " + rt.String());
}
return nil
return nil, nil
}
func newType(name string, rt reflect.Type) gobType {
// getType returns the Gob type describing the given reflect.Type.
// typeLock must be held.
func getType(name string, rt reflect.Type) (gobType, os.Error) {
// Flatten the data structure by collapsing out pointers
for {
pt, ok := rt.(*reflect.PtrType);
......@@ -305,19 +319,13 @@ func newType(name string, rt reflect.Type) gobType {
}
typ, present := types[rt];
if present {
return typ
return typ, nil
}
typ = newTypeObject(name, rt);
types[rt] = typ;
return typ
}
// getType returns the Gob type describing the given reflect.Type.
// typeLock must be held.
func getType(name string, rt reflect.Type) gobType {
// Set lock; all code running under here is synchronized.
t := newType(name, rt);
return t;
typ, err := newTypeObject(name, rt);
if err == nil {
types[rt] = typ
}
return typ, err
}
func checkId(want, got typeId) {
......@@ -371,7 +379,7 @@ var typeInfoMap = make(map[reflect.Type] *typeInfo) // protected by typeLock
// The reflection type must have all its indirections processed out.
// typeLock must be held.
func getTypeInfo(rt reflect.Type) *typeInfo {
func getTypeInfo(rt reflect.Type) (*typeInfo, os.Error) {
if pt, ok := rt.(*reflect.PtrType); ok {
panicln("pointer type in getTypeInfo:", rt.String())
}
......@@ -379,12 +387,25 @@ func getTypeInfo(rt reflect.Type) *typeInfo {
if !ok {
info = new(typeInfo);
name := rt.Name();
info.id = getType(name, rt).id();
gt, err := getType(name, rt);
if err != nil {
return nil, err
}
info.id = gt.id();
// assume it's a struct type
info.wire = &wireType{info.id.gobType().(*structType)};
typeInfoMap[rt] = info;
}
return info;
return info, nil;
}
// Called only when a panic is acceptable and unexpected.
func getTypeInfoNoError(rt reflect.Type) *typeInfo {
t, err := getTypeInfo(rt);
if err != nil {
panicln("getTypeInfo:", err.String());
}
return t
}
func init() {
......@@ -396,9 +417,9 @@ func init() {
// The string for tBytes is "bytes" not "[]byte" to signify its specialness.
tBytes = bootstrapType("bytes", make([]byte, 0), 5);
tString= bootstrapType("string", "", 6);
tWireType = getTypeInfo(reflect.Typeof(wireType{})).id;
tWireType = getTypeInfoNoError(reflect.Typeof(wireType{})).id;
checkId(7, tWireType);
checkId(8, getTypeInfo(reflect.Typeof(structType{})).id);
checkId(9, getTypeInfo(reflect.Typeof(commonType{})).id);
checkId(10, getTypeInfo(reflect.Typeof(fieldType{})).id);
checkId(8, getTypeInfoNoError(reflect.Typeof(structType{})).id);
checkId(9, getTypeInfoNoError(reflect.Typeof(commonType{})).id);
checkId(10, getTypeInfoNoError(reflect.Typeof(fieldType{})).id);
}
......@@ -27,7 +27,11 @@ var basicTypes = []typeT {
func getTypeUnlocked(name string, rt reflect.Type) gobType {
typeLock.Lock();
defer typeLock.Unlock();
return getType(name, rt);
t, err := getType(name, rt);
if err != nil {
panicln("getTypeUnlocked:", err.String())
}
return t;
}
// Sanity checks
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment