Commit 1cb847a8 authored by Kamil Kisiel, committed by GitHub

Merge pull request #33 from navytux/fix2

Make decode(encode(v)) an identity + add Tuple
parents da5f0342 f98f54b1
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"io" "io"
"math" "math"
"math/big"
"reflect" "reflect"
) )
...@@ -52,6 +53,8 @@ func (e *Encoder) encode(rv reflect.Value) error { ...@@ -52,6 +53,8 @@ func (e *Encoder) encode(rv reflect.Value) error {
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:
if rv.Type().Elem().Kind() == reflect.Uint8 { if rv.Type().Elem().Kind() == reflect.Uint8 {
return e.encodeBytes(rv.Bytes()) return e.encodeBytes(rv.Bytes())
} else if _, ok := rv.Interface().(Tuple); ok {
return e.encodeTuple(rv.Interface().(Tuple))
} else { } else {
return e.encodeArray(rv) return e.encodeArray(rv)
} }
...@@ -91,6 +94,36 @@ func (e *Encoder) encode(rv reflect.Value) error { ...@@ -91,6 +94,36 @@ func (e *Encoder) encode(rv reflect.Value) error {
return nil return nil
} }
// encodeTuple writes t to the output stream as a Python tuple.
//
// An empty tuple is emitted as the single opEmptyTuple opcode. Any other
// tuple is emitted as opMark, then every element encoded in order, then
// opTuple. The first write or element-encoding error is returned.
func (e *Encoder) encodeTuple(t Tuple) error {
	if len(t) == 0 {
		_, err := e.w.Write([]byte{opEmptyTuple})
		return err
	}

	// TODO: tuples of 1, 2 and 3 elements have dedicated protocol 2
	// opcodes - check e.protocol before using them.

	if _, err := e.w.Write([]byte{opMark}); err != nil {
		return err
	}
	for _, item := range t {
		if err := e.encode(reflectValueOf(item)); err != nil {
			return err
		}
	}
	_, err := e.w.Write([]byte{opTuple})
	return err
}
func (e *Encoder) encodeArray(arr reflect.Value) error { func (e *Encoder) encodeArray(arr reflect.Value) error {
l := arr.Len() l := arr.Len()
...@@ -195,6 +228,12 @@ func (e *Encoder) encodeInt(k reflect.Kind, i int64) error { ...@@ -195,6 +228,12 @@ func (e *Encoder) encodeInt(k reflect.Kind, i int64) error {
return err return err
} }
// encodeLong writes b in the textual long form: the opLong opcode, the
// decimal representation of b, a trailing 'L', and a newline.
func (e *Encoder) encodeLong(b *big.Int) error {
	// TODO: if e.protocol >= 2 use opLong1 & opLong4
	if _, err := fmt.Fprintf(e.w, "%c%dL\n", opLong, b); err != nil {
		return err
	}
	return nil
}
func (e *Encoder) encodeMap(m reflect.Value) error { func (e *Encoder) encodeMap(m reflect.Value) error {
keys := m.MapKeys() keys := m.MapKeys()
...@@ -238,14 +277,32 @@ func (e *Encoder) encodeString(s string) error { ...@@ -238,14 +277,32 @@ func (e *Encoder) encodeString(s string) error {
return e.encodeBytes([]byte(s)) return e.encodeBytes([]byte(s))
} }
// encodeCall writes v as a call of a global callable.
//
// The output is opGlobal followed by the callable's module and name, each
// terminated by a newline, then the arguments encoded as a tuple, then
// opReduce. The first error encountered is returned.
func (e *Encoder) encodeCall(v *Call) error {
	callable := v.Callable
	if _, err := fmt.Fprintf(e.w, "%c%s\n%s\n", opGlobal, callable.Module, callable.Name); err != nil {
		return err
	}
	if err := e.encodeTuple(v.Args); err != nil {
		return err
	}
	_, err := e.w.Write([]byte{opReduce})
	return err
}
func (e *Encoder) encodeStruct(st reflect.Value) error { func (e *Encoder) encodeStruct(st reflect.Value) error {
typ := st.Type() typ := st.Type()
// first test if it's one of our internal python structs // first test if it's one of our internal python structs
if _, ok := st.Interface().(None); ok { switch v := st.Interface().(type) {
case None:
_, err := e.w.Write([]byte{opNone}) _, err := e.w.Write([]byte{opNone})
return err return err
case Call:
return e.encodeCall(&v)
case big.Int:
return e.encodeLong(&v)
} }
structTags := getStructTags(st) structTags := getStructTags(st)
......
...@@ -104,6 +104,9 @@ type mark struct{} ...@@ -104,6 +104,9 @@ type mark struct{}
// None is a representation of Python's None. // None is a representation of Python's None.
type None struct{} type None struct{}
// Tuple is a representation of Python's tuple.
//
// The decoder pushes a Tuple for the tuple opcodes, and the encoder writes a
// Tuple value with encodeTuple instead of encodeArray, so tuples stay
// distinct from plain []interface{} values across a decode/encode round trip.
type Tuple []interface{}
// Decoder is a decoder for pickle streams. // Decoder is a decoder for pickle streams.
type Decoder struct { type Decoder struct {
r *bufio.Reader r *bufio.Reader
...@@ -227,7 +230,7 @@ loop: ...@@ -227,7 +230,7 @@ loop:
case opTuple3: case opTuple3:
err = d.loadTuple3() err = d.loadTuple3()
case opEmptyTuple: case opEmptyTuple:
d.push([]interface{}{}) d.push(Tuple{})
case opSetitems: case opSetitems:
err = d.loadSetItems() err = d.loadSetItems()
case opBinfloat: case opBinfloat:
...@@ -403,11 +406,15 @@ func (d *Decoder) loadLong() error { ...@@ -403,11 +406,15 @@ func (d *Decoder) loadLong() error {
if err != nil { if err != nil {
return err return err
} }
if len(line) < 1 { l := len(line)
if l < 1 || line[l-1] != 'L' {
return io.ErrUnexpectedEOF return io.ErrUnexpectedEOF
} }
v := new(big.Int) v := new(big.Int)
v.SetString(string(line[:len(line)-1]), 10) _, ok := v.SetString(string(line[:l-1]), 10)
if !ok {
return fmt.Errorf("pickle: loadLong: invalid string")
}
d.push(v) d.push(v)
return nil return nil
} }
...@@ -465,7 +472,7 @@ func (d *Decoder) loadBinPersid() error { ...@@ -465,7 +472,7 @@ func (d *Decoder) loadBinPersid() error {
type Call struct { type Call struct {
Callable Class Callable Class
Args []interface{} Args Tuple
} }
func (d *Decoder) reduce() error { func (d *Decoder) reduce() error {
...@@ -474,7 +481,7 @@ func (d *Decoder) reduce() error { ...@@ -474,7 +481,7 @@ func (d *Decoder) reduce() error {
} }
xargs := d.xpop() xargs := d.xpop()
xclass := d.xpop() xclass := d.xpop()
args, ok := xargs.([]interface{}) args, ok := xargs.(Tuple)
if !ok { if !ok {
return fmt.Errorf("pickle: reduce: invalid args: %T", xargs) return fmt.Errorf("pickle: reduce: invalid args: %T", xargs)
} }
...@@ -763,7 +770,7 @@ func (d *Decoder) loadTuple() error { ...@@ -763,7 +770,7 @@ func (d *Decoder) loadTuple() error {
return err return err
} }
v := append([]interface{}{}, d.stack[k+1:]...) v := append(Tuple{}, d.stack[k+1:]...)
d.stack = append(d.stack[:k], v) d.stack = append(d.stack[:k], v)
return nil return nil
} }
...@@ -773,7 +780,7 @@ func (d *Decoder) loadTuple1() error { ...@@ -773,7 +780,7 @@ func (d *Decoder) loadTuple1() error {
return errStackUnderflow return errStackUnderflow
} }
k := len(d.stack) - 1 k := len(d.stack) - 1
v := append([]interface{}{}, d.stack[k:]...) v := append(Tuple{}, d.stack[k:]...)
d.stack = append(d.stack[:k], v) d.stack = append(d.stack[:k], v)
return nil return nil
} }
...@@ -783,7 +790,7 @@ func (d *Decoder) loadTuple2() error { ...@@ -783,7 +790,7 @@ func (d *Decoder) loadTuple2() error {
return errStackUnderflow return errStackUnderflow
} }
k := len(d.stack) - 2 k := len(d.stack) - 2
v := append([]interface{}{}, d.stack[k:]...) v := append(Tuple{}, d.stack[k:]...)
d.stack = append(d.stack[:k], v) d.stack = append(d.stack[:k], v)
return nil return nil
} }
...@@ -793,7 +800,7 @@ func (d *Decoder) loadTuple3() error { ...@@ -793,7 +800,7 @@ func (d *Decoder) loadTuple3() error {
return errStackUnderflow return errStackUnderflow
} }
k := len(d.stack) - 3 k := len(d.stack) - 3
v := append([]interface{}{}, d.stack[k:]...) v := append(Tuple{}, d.stack[k:]...)
d.stack = append(d.stack[:k], v) d.stack = append(d.stack[:k], v)
return nil return nil
} }
......
...@@ -13,7 +13,10 @@ import ( ...@@ -13,7 +13,10 @@ import (
func bigInt(s string) *big.Int { func bigInt(s string) *big.Int {
i := new(big.Int) i := new(big.Int)
i.SetString(s, 10) _, ok := i.SetString(s, 10)
if !ok {
panic("bigInt")
}
return i return i
} }
...@@ -44,13 +47,13 @@ func TestDecode(t *testing.T) { ...@@ -44,13 +47,13 @@ func TestDecode(t *testing.T) {
{"float", "F1.23\n.", float64(1.23)}, {"float", "F1.23\n.", float64(1.23)},
{"long", "L12321231232131231231L\n.", bigInt("12321231232131231231")}, {"long", "L12321231232131231231L\n.", bigInt("12321231232131231231")},
{"None", "N.", None{}}, {"None", "N.", None{}},
{"empty tuple", "(t.", []interface{}{}}, {"empty tuple", "(t.", Tuple{}},
{"tuple of two ints", "(I1\nI2\ntp0\n.", []interface{}{int64(1), int64(2)}}, {"tuple of two ints", "(I1\nI2\ntp0\n.", Tuple{int64(1), int64(2)}},
{"nested tuples", "((I1\nI2\ntp0\n(I3\nI4\ntp1\ntp2\n.", {"nested tuples", "((I1\nI2\ntp0\n(I3\nI4\ntp1\ntp2\n.",
[]interface{}{[]interface{}{int64(1), int64(2)}, []interface{}{int64(3), int64(4)}}}, Tuple{Tuple{int64(1), int64(2)}, Tuple{int64(3), int64(4)}}},
{"tuple with top 1 items from stack", "I0\n\x85.", []interface{}{int64(0)}}, {"tuple with top 1 items from stack", "I0\n\x85.", Tuple{int64(0)}},
{"tuple with top 2 items from stack", "I0\nI1\n\x86.", []interface{}{int64(0), int64(1)}}, {"tuple with top 2 items from stack", "I0\nI1\n\x86.", Tuple{int64(0), int64(1)}},
{"tuple with top 3 items from stack", "I0\nI1\nI2\n\x87.", []interface{}{int64(0), int64(1), int64(2)}}, {"tuple with top 3 items from stack", "I0\nI1\nI2\n\x87.", Tuple{int64(0), int64(1), int64(2)}},
{"empty list", "(lp0\n.", []interface{}{}}, {"empty list", "(lp0\n.", []interface{}{}},
{"list of numbers", "(lp0\nI1\naI2\naI3\naI4\na.", []interface{}{int64(1), int64(2), int64(3), int64(4)}}, {"list of numbers", "(lp0\nI1\naI2\naI3\naI4\na.", []interface{}{int64(1), int64(2), int64(3), int64(4)}},
{"string", "S'abc'\np0\n.", string("abc")}, {"string", "S'abc'\np0\n.", string("abc")},
...@@ -68,6 +71,7 @@ func TestDecode(t *testing.T) { ...@@ -68,6 +71,7 @@ func TestDecode(t *testing.T) {
{"SHORTBINUNICODE opcode", "\x8c\t\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\x94.", "日本語"}, {"SHORTBINUNICODE opcode", "\x8c\t\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\x94.", "日本語"},
} }
for _, test := range tests { for _, test := range tests {
// decode(input) -> expected
buf := bytes.NewBufferString(test.input) buf := bytes.NewBufferString(test.input)
dec := NewDecoder(buf) dec := NewDecoder(buf)
v, err := dec.Decode() v, err := dec.Decode()
...@@ -76,12 +80,31 @@ func TestDecode(t *testing.T) { ...@@ -76,12 +80,31 @@ func TestDecode(t *testing.T) {
} }
if !reflect.DeepEqual(v, test.expected) { if !reflect.DeepEqual(v, test.expected) {
t.Errorf("%s: got\n%q\n expected\n%q", test.name, v, test.expected) t.Errorf("%s: decode:\nhave: %#v\nwant: %#v", test.name, v, test.expected)
} }
// decode more -> EOF
v, err = dec.Decode() v, err = dec.Decode()
if !(v == nil && err == io.EOF) { if !(v == nil && err == io.EOF) {
t.Errorf("decode: no EOF at end: v = %#v err = %#v", v, err) t.Errorf("%s: decode: no EOF at end: v = %#v err = %#v", test.name, v, err)
}
// expected (= decoded(input)) -> encode -> decode = identity
buf.Reset()
enc := NewEncoder(buf)
err = enc.Encode(test.expected)
if err != nil {
t.Errorf("%s: encode(expected): %v", test.name, err)
} else {
dec := NewDecoder(buf)
v, err := dec.Decode()
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(v, test.expected) {
t.Errorf("%s: expected -> decode -> encode != identity\nhave: %#v\nwant: %#v", test.name, v, test.expected)
}
} }
// for truncated input io.ErrUnexpectedEOF must be returned // for truncated input io.ErrUnexpectedEOF must be returned
...@@ -263,6 +286,10 @@ func TestDecodeError(t *testing.T) { ...@@ -263,6 +286,10 @@ func TestDecodeError(t *testing.T) {
"}g1\n.", "}g1\n.",
"}h\x01.", "}h\x01.",
"}j\x01\x02\x03\x04.", "}j\x01\x02\x03\x04.",
// invalid long format
"L123\n.",
"L12qL\n.",
} }
for _, tt := range testv { for _, tt := range testv {
buf := bytes.NewBufferString(tt) buf := bytes.NewBufferString(tt)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment