Commit 4cac89cc authored by Larz Conwell's avatar Larz Conwell

Handle write and read errors.

parent b555ed17
......@@ -31,53 +31,48 @@ func (e *Encoder) encode(rv reflect.Value) error {
switch rk := rv.Kind(); rk {
case reflect.Bool:
e.encodeBool(rv.Bool())
return e.encodeBool(rv.Bool())
case reflect.Int, reflect.Int8, reflect.Int64, reflect.Int32, reflect.Int16:
e.encodeInt(reflect.Int, rv.Int())
return e.encodeInt(reflect.Int, rv.Int())
case reflect.Uint8, reflect.Uint64, reflect.Uint, reflect.Uint32, reflect.Uint16:
e.encodeInt(reflect.Uint, int64(rv.Uint()))
return e.encodeInt(reflect.Uint, int64(rv.Uint()))
case reflect.String:
e.encodeString(rv.String())
return e.encodeString(rv.String())
case reflect.Array, reflect.Slice:
if rv.Type().Elem().Kind() == reflect.Uint8 {
e.encodeBytes(rv.Bytes())
return e.encodeBytes(rv.Bytes())
} else {
e.encodeArray(rv)
return e.encodeArray(rv)
}
case reflect.Map:
e.encodeMap(rv)
return e.encodeMap(rv)
case reflect.Struct:
e.encodeStruct(rv)
return e.encodeStruct(rv)
case reflect.Float32, reflect.Float64:
e.encodeFloat(float64(rv.Float()))
return e.encodeFloat(float64(rv.Float()))
case reflect.Interface:
// recurse until we get a concrete type
// could be optmized into a tail call
var err error
err = e.encode(rv.Elem())
if err != nil {
return err
}
return e.encode(rv.Elem())
case reflect.Ptr:
if rv.Elem().Kind() == reflect.Struct {
switch rv.Elem().Interface().(type) {
case None:
e.encodeStruct(rv.Elem())
return nil
return e.encodeStruct(rv.Elem())
}
}
e.encode(rv.Elem())
return e.encode(rv.Elem())
case reflect.Invalid:
e.w.Write([]byte{opNone})
_, err := e.w.Write([]byte{opNone})
return err
default:
panic(fmt.Sprintf("no support for type '%s'", rk.String()))
}
......@@ -85,114 +80,181 @@ func (e *Encoder) encode(rv reflect.Value) error {
return nil
}
func (e *Encoder) encodeArray(arr reflect.Value) {
func (e *Encoder) encodeArray(arr reflect.Value) error {
l := arr.Len()
e.w.Write([]byte{opEmptyList, opMark})
_, err := e.w.Write([]byte{opEmptyList, opMark})
if err != nil {
return err
}
for i := 0; i < l; i++ {
v := arr.Index(i)
e.encode(v)
err = e.encode(v)
if err != nil {
return err
}
}
e.w.Write([]byte{opAppends})
_, err = e.w.Write([]byte{opAppends})
return err
}
func (e *Encoder) encodeBool(b bool) {
func (e *Encoder) encodeBool(b bool) error {
var err error
if b {
e.w.Write([]byte(opTrue))
_, err = e.w.Write([]byte(opTrue))
} else {
e.w.Write([]byte(opFalse))
_, err = e.w.Write([]byte(opFalse))
}
return err
}
func (e *Encoder) encodeBytes(byt []byte) {
func (e *Encoder) encodeBytes(byt []byte) error {
l := len(byt)
if l < 256 {
e.w.Write([]byte{opShortBinstring, byte(l)})
_, err := e.w.Write([]byte{opShortBinstring, byte(l)})
if err != nil {
return err
}
} else {
e.w.Write([]byte{opBinstring})
_, err := e.w.Write([]byte{opBinstring})
if err != nil {
return err
}
var b [4]byte
binary.LittleEndian.PutUint32(b[:], uint32(l))
e.w.Write(b[:])
_, err = e.w.Write(b[:])
if err != nil {
return err
}
}
e.w.Write(byt)
_, err := e.w.Write(byt)
return err
}
func (e *Encoder) encodeFloat(f float64) {
func (e *Encoder) encodeFloat(f float64) error {
var u uint64
u = math.Float64bits(f)
e.w.Write([]byte{opBinfloat})
_, err := e.w.Write([]byte{opBinfloat})
if err != nil {
return err
}
var b [8]byte
binary.BigEndian.PutUint64(b[:], uint64(u))
e.w.Write(b[:])
_, err = e.w.Write(b[:])
return err
}
func (e *Encoder) encodeInt(k reflect.Kind, i int64) {
func (e *Encoder) encodeInt(k reflect.Kind, i int64) error {
var err error
// FIXME: need support for 64-bit ints
switch {
case i > 0 && i < math.MaxUint8:
e.w.Write([]byte{opBinint1, byte(i)})
_, err = e.w.Write([]byte{opBinint1, byte(i)})
case i > 0 && i < math.MaxUint16:
e.w.Write([]byte{opBinint2, byte(i), byte(i >> 8)})
_, err = e.w.Write([]byte{opBinint2, byte(i), byte(i >> 8)})
case i >= math.MinInt32 && i <= math.MaxInt32:
e.w.Write([]byte{opBinint})
_, err = e.w.Write([]byte{opBinint})
if err != nil {
return err
}
var b [4]byte
binary.LittleEndian.PutUint32(b[:], uint32(i))
e.w.Write(b[:])
_, err = e.w.Write(b[:])
default: // int64, but as a string :/
e.w.Write([]byte{opInt})
_, err = e.w.Write([]byte{opInt})
if err != nil {
return err
}
fmt.Fprintf(e.w, "%d\n", i)
}
return err
}
func (e *Encoder) encodeMap(m reflect.Value) {
func (e *Encoder) encodeMap(m reflect.Value) error {
keys := m.MapKeys()
l := len(keys)
e.w.Write([]byte{opEmptyDict})
_, err := e.w.Write([]byte{opEmptyDict})
if err != nil {
return err
}
if l > 0 {
e.w.Write([]byte{opMark})
_, err := e.w.Write([]byte{opMark})
if err != nil {
return err
}
for _, k := range keys {
e.encode(k)
err = e.encode(k)
if err != nil {
return err
}
v := m.MapIndex(k)
e.encode(v)
err = e.encode(v)
if err != nil {
return err
}
}
_, err = e.w.Write([]byte{opSetitems})
if err != nil {
return err
}
e.w.Write([]byte{opSetitems})
}
return nil
}
func (e *Encoder) encodeString(s string) {
e.encodeBytes([]byte(s))
func (e *Encoder) encodeString(s string) error {
return e.encodeBytes([]byte(s))
}
func (e *Encoder) encodeStruct(st reflect.Value) {
func (e *Encoder) encodeStruct(st reflect.Value) error {
typ := st.Type()
// first test if it's one of our internal python structs
if _, ok := st.Interface().(None); ok {
e.w.Write([]byte{opNone})
return
_, err := e.w.Write([]byte{opNone})
return err
}
structTags := getStructTags(st)
e.w.Write([]byte{opEmptyDict, opMark})
_, err := e.w.Write([]byte{opEmptyDict, opMark})
if err != nil {
return err
}
if structTags != nil {
for f, i := range structTags {
e.encodeString(f)
e.encode(st.Field(i))
err := e.encodeString(f)
if err != nil {
return err
}
err = e.encode(st.Field(i))
if err != nil {
return err
}
}
} else {
l := typ.NumField()
......@@ -201,12 +263,21 @@ func (e *Encoder) encodeStruct(st reflect.Value) {
if fty.PkgPath != "" {
continue // skip unexported names
}
e.encodeString(fty.Name)
e.encode(st.Field(i))
err := e.encodeString(fty.Name)
if err != nil {
return err
}
err = e.encode(st.Field(i))
if err != nil {
return err
}
}
}
e.w.Write([]byte{opSetitems})
_, err = e.w.Write([]byte{opSetitems})
return err
}
func reflectValueOf(v interface{}) reflect.Value {
......
......@@ -218,8 +218,8 @@ func (d Decoder) Decode() (interface{}, error) {
case opBinfloat:
err = d.binFloat()
case opProto:
v, _ := d.r.ReadByte()
if v != 2 {
v, err := d.r.ReadByte()
if err == nil && v != 2 {
err = ErrInvalidPickleVersion
}
......@@ -331,7 +331,10 @@ func (d *Decoder) loadBinInt() error {
// Push a 1-byte unsigned int
func (d *Decoder) loadBinInt1() error {
b, _ := d.r.ReadByte()
b, err := d.r.ReadByte()
if err != nil {
return err
}
d.push(int64(b))
return nil
}
......@@ -458,9 +461,13 @@ func (d *Decoder) loadBinString() error {
}
func (d *Decoder) loadShortBinString() error {
b, _ := d.r.ReadByte()
b, err := d.r.ReadByte()
if err != nil {
return err
}
s := make([]byte, b)
_, err := io.ReadFull(d.r, s)
_, err = io.ReadFull(d.r, s)
if err != nil {
return err
}
......@@ -595,7 +602,11 @@ func (d *Decoder) get() error {
}
func (d *Decoder) binGet() error {
b, _ := d.r.ReadByte()
b, err := d.r.ReadByte()
if err != nil {
return err
}
d.push(d.memo[strconv.Itoa(int(b))])
return nil
}
......@@ -669,7 +680,11 @@ func (d *Decoder) loadPut() error {
}
func (d *Decoder) binPut() error {
b, _ := d.r.ReadByte()
b, err := d.r.ReadByte()
if err != nil {
return err
}
d.memo[strconv.Itoa(int(b))] = d.stack[len(d.stack)-1]
return nil
}
......
......@@ -50,7 +50,7 @@ func TestDecode(t *testing.T) {
{"unicode", "V\\u65e5\\u672c\\u8a9e\np0\n.", string("日本語")},
{"empty dict", "(dp0\n.", make(map[interface{}]interface{})},
{"dict with strings", "(dp0\nS'a'\np1\nS'1'\np2\nsS'b'\np3\nS'2'\np4\ns.", map[interface{}]interface{}{"a": "1", "b": "2"}},
{"GLOBAL and REDUCE opcodes", "cfoo\nbar\nS'bing'\n\x85R.", Call{Callable: Class{Module: "foo", Name: "bar"}, Args: []interface{} {"bing"}}},
{"GLOBAL and REDUCE opcodes", "cfoo\nbar\nS'bing'\n\x85R.", Call{Callable: Class{Module: "foo", Name: "bar"}, Args: []interface{}{"bing"}}},
{"LONG_BINPUT opcode", "(lr0000I17\na.", []interface{}{int64(17)}},
{"graphite message1", string(graphitePickle1), []interface{}{map[interface{}]interface{}{"values": []interface{}{float64(473), float64(497), float64(540), float64(1497), float64(1808), float64(1890), float64(2013), float64(1821), float64(1847), float64(2176), float64(2156), float64(1250), float64(2055), float64(1570), None{}, None{}}, "start": int64(1383782400), "step": int64(86400), "end": int64(1385164800), "name": "ZZZZ.UUUUUUUU.CCCCCCCC.MMMMMMMM.XXXXXXXXX.TTT"}}},
{"graphite message2", string(graphitePickle2), []interface{}{map[interface{}]interface{}{"values": []interface{}{float64(473), float64(497), float64(540), float64(1497), float64(1808), float64(1890), float64(2013), float64(1821), float64(1847), float64(2176), float64(2156), float64(1250), float64(2055), float64(1570), None{}, None{}}, "start": int64(1383782400), "step": int64(86400), "end": int64(1385164800), "name": "user.login.area.machine.metric.minute"}}},
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment