Commit 92c2c3bf authored by Kamil Kisiel's avatar Kamil Kisiel

Merge pull request #18 from larzconwell/handle-write-err

Handle write and read errors.
parents b555ed17 4cac89cc
...@@ -31,53 +31,48 @@ func (e *Encoder) encode(rv reflect.Value) error { ...@@ -31,53 +31,48 @@ func (e *Encoder) encode(rv reflect.Value) error {
switch rk := rv.Kind(); rk { switch rk := rv.Kind(); rk {
case reflect.Bool: case reflect.Bool:
e.encodeBool(rv.Bool()) return e.encodeBool(rv.Bool())
case reflect.Int, reflect.Int8, reflect.Int64, reflect.Int32, reflect.Int16: case reflect.Int, reflect.Int8, reflect.Int64, reflect.Int32, reflect.Int16:
e.encodeInt(reflect.Int, rv.Int()) return e.encodeInt(reflect.Int, rv.Int())
case reflect.Uint8, reflect.Uint64, reflect.Uint, reflect.Uint32, reflect.Uint16: case reflect.Uint8, reflect.Uint64, reflect.Uint, reflect.Uint32, reflect.Uint16:
e.encodeInt(reflect.Uint, int64(rv.Uint())) return e.encodeInt(reflect.Uint, int64(rv.Uint()))
case reflect.String: case reflect.String:
e.encodeString(rv.String()) return e.encodeString(rv.String())
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:
if rv.Type().Elem().Kind() == reflect.Uint8 { if rv.Type().Elem().Kind() == reflect.Uint8 {
e.encodeBytes(rv.Bytes()) return e.encodeBytes(rv.Bytes())
} else { } else {
e.encodeArray(rv) return e.encodeArray(rv)
} }
case reflect.Map: case reflect.Map:
e.encodeMap(rv) return e.encodeMap(rv)
case reflect.Struct: case reflect.Struct:
e.encodeStruct(rv) return e.encodeStruct(rv)
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
e.encodeFloat(float64(rv.Float())) return e.encodeFloat(float64(rv.Float()))
case reflect.Interface: case reflect.Interface:
// recurse until we get a concrete type // recurse until we get a concrete type
// could be optmized into a tail call // could be optmized into a tail call
var err error return e.encode(rv.Elem())
err = e.encode(rv.Elem())
if err != nil {
return err
}
case reflect.Ptr: case reflect.Ptr:
if rv.Elem().Kind() == reflect.Struct { if rv.Elem().Kind() == reflect.Struct {
switch rv.Elem().Interface().(type) { switch rv.Elem().Interface().(type) {
case None: case None:
e.encodeStruct(rv.Elem()) return e.encodeStruct(rv.Elem())
return nil
} }
} }
e.encode(rv.Elem()) return e.encode(rv.Elem())
case reflect.Invalid: case reflect.Invalid:
e.w.Write([]byte{opNone}) _, err := e.w.Write([]byte{opNone})
return err
default: default:
panic(fmt.Sprintf("no support for type '%s'", rk.String())) panic(fmt.Sprintf("no support for type '%s'", rk.String()))
} }
...@@ -85,114 +80,181 @@ func (e *Encoder) encode(rv reflect.Value) error { ...@@ -85,114 +80,181 @@ func (e *Encoder) encode(rv reflect.Value) error {
return nil return nil
} }
func (e *Encoder) encodeArray(arr reflect.Value) { func (e *Encoder) encodeArray(arr reflect.Value) error {
l := arr.Len() l := arr.Len()
e.w.Write([]byte{opEmptyList, opMark}) _, err := e.w.Write([]byte{opEmptyList, opMark})
if err != nil {
return err
}
for i := 0; i < l; i++ { for i := 0; i < l; i++ {
v := arr.Index(i) v := arr.Index(i)
e.encode(v) err = e.encode(v)
if err != nil {
return err
}
} }
e.w.Write([]byte{opAppends})
_, err = e.w.Write([]byte{opAppends})
return err
} }
func (e *Encoder) encodeBool(b bool) { func (e *Encoder) encodeBool(b bool) error {
var err error
if b { if b {
e.w.Write([]byte(opTrue)) _, err = e.w.Write([]byte(opTrue))
} else { } else {
e.w.Write([]byte(opFalse)) _, err = e.w.Write([]byte(opFalse))
} }
return err
} }
func (e *Encoder) encodeBytes(byt []byte) { func (e *Encoder) encodeBytes(byt []byte) error {
l := len(byt) l := len(byt)
if l < 256 { if l < 256 {
e.w.Write([]byte{opShortBinstring, byte(l)}) _, err := e.w.Write([]byte{opShortBinstring, byte(l)})
if err != nil {
return err
}
} else { } else {
e.w.Write([]byte{opBinstring}) _, err := e.w.Write([]byte{opBinstring})
if err != nil {
return err
}
var b [4]byte var b [4]byte
binary.LittleEndian.PutUint32(b[:], uint32(l)) binary.LittleEndian.PutUint32(b[:], uint32(l))
e.w.Write(b[:]) _, err = e.w.Write(b[:])
if err != nil {
return err
}
} }
e.w.Write(byt) _, err := e.w.Write(byt)
return err
} }
func (e *Encoder) encodeFloat(f float64) { func (e *Encoder) encodeFloat(f float64) error {
var u uint64 var u uint64
u = math.Float64bits(f) u = math.Float64bits(f)
e.w.Write([]byte{opBinfloat})
_, err := e.w.Write([]byte{opBinfloat})
if err != nil {
return err
}
var b [8]byte var b [8]byte
binary.BigEndian.PutUint64(b[:], uint64(u)) binary.BigEndian.PutUint64(b[:], uint64(u))
e.w.Write(b[:])
_, err = e.w.Write(b[:])
return err
} }
func (e *Encoder) encodeInt(k reflect.Kind, i int64) { func (e *Encoder) encodeInt(k reflect.Kind, i int64) error {
var err error
// FIXME: need support for 64-bit ints // FIXME: need support for 64-bit ints
switch { switch {
case i > 0 && i < math.MaxUint8: case i > 0 && i < math.MaxUint8:
e.w.Write([]byte{opBinint1, byte(i)}) _, err = e.w.Write([]byte{opBinint1, byte(i)})
case i > 0 && i < math.MaxUint16: case i > 0 && i < math.MaxUint16:
e.w.Write([]byte{opBinint2, byte(i), byte(i >> 8)}) _, err = e.w.Write([]byte{opBinint2, byte(i), byte(i >> 8)})
case i >= math.MinInt32 && i <= math.MaxInt32: case i >= math.MinInt32 && i <= math.MaxInt32:
e.w.Write([]byte{opBinint}) _, err = e.w.Write([]byte{opBinint})
if err != nil {
return err
}
var b [4]byte var b [4]byte
binary.LittleEndian.PutUint32(b[:], uint32(i)) binary.LittleEndian.PutUint32(b[:], uint32(i))
e.w.Write(b[:]) _, err = e.w.Write(b[:])
default: // int64, but as a string :/ default: // int64, but as a string :/
e.w.Write([]byte{opInt}) _, err = e.w.Write([]byte{opInt})
if err != nil {
return err
}
fmt.Fprintf(e.w, "%d\n", i) fmt.Fprintf(e.w, "%d\n", i)
} }
return err
} }
func (e *Encoder) encodeMap(m reflect.Value) { func (e *Encoder) encodeMap(m reflect.Value) error {
keys := m.MapKeys() keys := m.MapKeys()
l := len(keys) l := len(keys)
e.w.Write([]byte{opEmptyDict}) _, err := e.w.Write([]byte{opEmptyDict})
if err != nil {
return err
}
if l > 0 { if l > 0 {
e.w.Write([]byte{opMark}) _, err := e.w.Write([]byte{opMark})
if err != nil {
return err
}
for _, k := range keys { for _, k := range keys {
e.encode(k) err = e.encode(k)
if err != nil {
return err
}
v := m.MapIndex(k) v := m.MapIndex(k)
e.encode(v)
err = e.encode(v)
if err != nil {
return err
}
}
_, err = e.w.Write([]byte{opSetitems})
if err != nil {
return err
} }
e.w.Write([]byte{opSetitems})
} }
return nil
} }
func (e *Encoder) encodeString(s string) { func (e *Encoder) encodeString(s string) error {
e.encodeBytes([]byte(s)) return e.encodeBytes([]byte(s))
} }
func (e *Encoder) encodeStruct(st reflect.Value) { func (e *Encoder) encodeStruct(st reflect.Value) error {
typ := st.Type() typ := st.Type()
// first test if it's one of our internal python structs // first test if it's one of our internal python structs
if _, ok := st.Interface().(None); ok { if _, ok := st.Interface().(None); ok {
e.w.Write([]byte{opNone}) _, err := e.w.Write([]byte{opNone})
return return err
} }
structTags := getStructTags(st) structTags := getStructTags(st)
e.w.Write([]byte{opEmptyDict, opMark}) _, err := e.w.Write([]byte{opEmptyDict, opMark})
if err != nil {
return err
}
if structTags != nil { if structTags != nil {
for f, i := range structTags { for f, i := range structTags {
e.encodeString(f) err := e.encodeString(f)
e.encode(st.Field(i)) if err != nil {
return err
}
err = e.encode(st.Field(i))
if err != nil {
return err
}
} }
} else { } else {
l := typ.NumField() l := typ.NumField()
...@@ -201,12 +263,21 @@ func (e *Encoder) encodeStruct(st reflect.Value) { ...@@ -201,12 +263,21 @@ func (e *Encoder) encodeStruct(st reflect.Value) {
if fty.PkgPath != "" { if fty.PkgPath != "" {
continue // skip unexported names continue // skip unexported names
} }
e.encodeString(fty.Name)
e.encode(st.Field(i)) err := e.encodeString(fty.Name)
if err != nil {
return err
}
err = e.encode(st.Field(i))
if err != nil {
return err
}
} }
} }
e.w.Write([]byte{opSetitems}) _, err = e.w.Write([]byte{opSetitems})
return err
} }
func reflectValueOf(v interface{}) reflect.Value { func reflectValueOf(v interface{}) reflect.Value {
......
...@@ -218,8 +218,8 @@ func (d Decoder) Decode() (interface{}, error) { ...@@ -218,8 +218,8 @@ func (d Decoder) Decode() (interface{}, error) {
case opBinfloat: case opBinfloat:
err = d.binFloat() err = d.binFloat()
case opProto: case opProto:
v, _ := d.r.ReadByte() v, err := d.r.ReadByte()
if v != 2 { if err == nil && v != 2 {
err = ErrInvalidPickleVersion err = ErrInvalidPickleVersion
} }
...@@ -331,7 +331,10 @@ func (d *Decoder) loadBinInt() error { ...@@ -331,7 +331,10 @@ func (d *Decoder) loadBinInt() error {
// Push a 1-byte unsigned int // Push a 1-byte unsigned int
func (d *Decoder) loadBinInt1() error { func (d *Decoder) loadBinInt1() error {
b, _ := d.r.ReadByte() b, err := d.r.ReadByte()
if err != nil {
return err
}
d.push(int64(b)) d.push(int64(b))
return nil return nil
} }
...@@ -458,9 +461,13 @@ func (d *Decoder) loadBinString() error { ...@@ -458,9 +461,13 @@ func (d *Decoder) loadBinString() error {
} }
func (d *Decoder) loadShortBinString() error { func (d *Decoder) loadShortBinString() error {
b, _ := d.r.ReadByte() b, err := d.r.ReadByte()
if err != nil {
return err
}
s := make([]byte, b) s := make([]byte, b)
_, err := io.ReadFull(d.r, s) _, err = io.ReadFull(d.r, s)
if err != nil { if err != nil {
return err return err
} }
...@@ -595,7 +602,11 @@ func (d *Decoder) get() error { ...@@ -595,7 +602,11 @@ func (d *Decoder) get() error {
} }
func (d *Decoder) binGet() error { func (d *Decoder) binGet() error {
b, _ := d.r.ReadByte() b, err := d.r.ReadByte()
if err != nil {
return err
}
d.push(d.memo[strconv.Itoa(int(b))]) d.push(d.memo[strconv.Itoa(int(b))])
return nil return nil
} }
...@@ -669,7 +680,11 @@ func (d *Decoder) loadPut() error { ...@@ -669,7 +680,11 @@ func (d *Decoder) loadPut() error {
} }
func (d *Decoder) binPut() error { func (d *Decoder) binPut() error {
b, _ := d.r.ReadByte() b, err := d.r.ReadByte()
if err != nil {
return err
}
d.memo[strconv.Itoa(int(b))] = d.stack[len(d.stack)-1] d.memo[strconv.Itoa(int(b))] = d.stack[len(d.stack)-1]
return nil return nil
} }
......
...@@ -50,7 +50,7 @@ func TestDecode(t *testing.T) { ...@@ -50,7 +50,7 @@ func TestDecode(t *testing.T) {
{"unicode", "V\\u65e5\\u672c\\u8a9e\np0\n.", string("日本語")}, {"unicode", "V\\u65e5\\u672c\\u8a9e\np0\n.", string("日本語")},
{"empty dict", "(dp0\n.", make(map[interface{}]interface{})}, {"empty dict", "(dp0\n.", make(map[interface{}]interface{})},
{"dict with strings", "(dp0\nS'a'\np1\nS'1'\np2\nsS'b'\np3\nS'2'\np4\ns.", map[interface{}]interface{}{"a": "1", "b": "2"}}, {"dict with strings", "(dp0\nS'a'\np1\nS'1'\np2\nsS'b'\np3\nS'2'\np4\ns.", map[interface{}]interface{}{"a": "1", "b": "2"}},
{"GLOBAL and REDUCE opcodes", "cfoo\nbar\nS'bing'\n\x85R.", Call{Callable: Class{Module: "foo", Name: "bar"}, Args: []interface{} {"bing"}}}, {"GLOBAL and REDUCE opcodes", "cfoo\nbar\nS'bing'\n\x85R.", Call{Callable: Class{Module: "foo", Name: "bar"}, Args: []interface{}{"bing"}}},
{"LONG_BINPUT opcode", "(lr0000I17\na.", []interface{}{int64(17)}}, {"LONG_BINPUT opcode", "(lr0000I17\na.", []interface{}{int64(17)}},
{"graphite message1", string(graphitePickle1), []interface{}{map[interface{}]interface{}{"values": []interface{}{float64(473), float64(497), float64(540), float64(1497), float64(1808), float64(1890), float64(2013), float64(1821), float64(1847), float64(2176), float64(2156), float64(1250), float64(2055), float64(1570), None{}, None{}}, "start": int64(1383782400), "step": int64(86400), "end": int64(1385164800), "name": "ZZZZ.UUUUUUUU.CCCCCCCC.MMMMMMMM.XXXXXXXXX.TTT"}}}, {"graphite message1", string(graphitePickle1), []interface{}{map[interface{}]interface{}{"values": []interface{}{float64(473), float64(497), float64(540), float64(1497), float64(1808), float64(1890), float64(2013), float64(1821), float64(1847), float64(2176), float64(2156), float64(1250), float64(2055), float64(1570), None{}, None{}}, "start": int64(1383782400), "step": int64(86400), "end": int64(1385164800), "name": "ZZZZ.UUUUUUUU.CCCCCCCC.MMMMMMMM.XXXXXXXXX.TTT"}}},
{"graphite message2", string(graphitePickle2), []interface{}{map[interface{}]interface{}{"values": []interface{}{float64(473), float64(497), float64(540), float64(1497), float64(1808), float64(1890), float64(2013), float64(1821), float64(1847), float64(2176), float64(2156), float64(1250), float64(2055), float64(1570), None{}, None{}}, "start": int64(1383782400), "step": int64(86400), "end": int64(1385164800), "name": "user.login.area.machine.metric.minute"}}}, {"graphite message2", string(graphitePickle2), []interface{}{map[interface{}]interface{}{"values": []interface{}{float64(473), float64(497), float64(540), float64(1497), float64(1808), float64(1890), float64(2013), float64(1821), float64(1847), float64(2176), float64(2156), float64(1250), float64(2055), float64(1570), None{}, None{}}, "start": int64(1383782400), "step": int64(86400), "end": int64(1385164800), "name": "user.login.area.machine.metric.minute"}}},
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment