Commit 1cb847a8 authored by Kamil Kisiel's avatar Kamil Kisiel Committed by GitHub

Merge pull request #33 from navytux/fix2

Make decode(encode(v)) to be identity + tuple
parents da5f0342 f98f54b1
......@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"math"
"math/big"
"reflect"
)
......@@ -52,6 +53,8 @@ func (e *Encoder) encode(rv reflect.Value) error {
case reflect.Array, reflect.Slice:
if rv.Type().Elem().Kind() == reflect.Uint8 {
return e.encodeBytes(rv.Bytes())
} else if _, ok := rv.Interface().(Tuple); ok {
return e.encodeTuple(rv.Interface().(Tuple))
} else {
return e.encodeArray(rv)
}
......@@ -91,6 +94,36 @@ func (e *Encoder) encode(rv reflect.Value) error {
return nil
}
func (e *Encoder) encodeTuple(t Tuple) error {
l := len(t)
switch l {
case 0:
_, err := e.w.Write([]byte{opEmptyTuple})
return err
// TODO this are protocol 2 opcodes - check e.protocol before using them
//case 1:
//case 2:
//case 3:
}
_, err := e.w.Write([]byte{opMark})
if err != nil {
return err
}
for i := 0; i < l; i++ {
err = e.encode(reflectValueOf(t[i]))
if err != nil {
return err
}
}
_, err = e.w.Write([]byte{opTuple})
return err
}
func (e *Encoder) encodeArray(arr reflect.Value) error {
l := arr.Len()
......@@ -195,6 +228,12 @@ func (e *Encoder) encodeInt(k reflect.Kind, i int64) error {
return err
}
func (e *Encoder) encodeLong(b *big.Int) error {
// TODO if e.protocol >= 2 use opLong1 & opLong4
_, err := fmt.Fprintf(e.w, "%c%dL\n", opLong, b)
return err
}
func (e *Encoder) encodeMap(m reflect.Value) error {
keys := m.MapKeys()
......@@ -238,14 +277,32 @@ func (e *Encoder) encodeString(s string) error {
return e.encodeBytes([]byte(s))
}
func (e *Encoder) encodeCall(v *Call) error {
_, err := fmt.Fprintf(e.w, "%c%s\n%s\n", opGlobal, v.Callable.Module, v.Callable.Name)
if err != nil {
return err
}
err = e.encodeTuple(v.Args)
if err != nil {
return err
}
_, err = e.w.Write([]byte{opReduce})
return err
}
func (e *Encoder) encodeStruct(st reflect.Value) error {
typ := st.Type()
// first test if it's one of our internal python structs
if _, ok := st.Interface().(None); ok {
switch v := st.Interface().(type) {
case None:
_, err := e.w.Write([]byte{opNone})
return err
case Call:
return e.encodeCall(&v)
case big.Int:
return e.encodeLong(&v)
}
structTags := getStructTags(st)
......
......@@ -104,6 +104,9 @@ type mark struct{}
// None is a representation of Python's None.
type None struct{}
// Tuple is a representation of Python's tuple.
type Tuple []interface{}
// Decoder is a decoder for pickle streams.
type Decoder struct {
r *bufio.Reader
......@@ -227,7 +230,7 @@ loop:
case opTuple3:
err = d.loadTuple3()
case opEmptyTuple:
d.push([]interface{}{})
d.push(Tuple{})
case opSetitems:
err = d.loadSetItems()
case opBinfloat:
......@@ -403,11 +406,15 @@ func (d *Decoder) loadLong() error {
if err != nil {
return err
}
if len(line) < 1 {
l := len(line)
if l < 1 || line[l-1] != 'L' {
return io.ErrUnexpectedEOF
}
v := new(big.Int)
v.SetString(string(line[:len(line)-1]), 10)
_, ok := v.SetString(string(line[:l-1]), 10)
if !ok {
return fmt.Errorf("pickle: loadLong: invalid string")
}
d.push(v)
return nil
}
......@@ -465,7 +472,7 @@ func (d *Decoder) loadBinPersid() error {
type Call struct {
Callable Class
Args []interface{}
Args Tuple
}
func (d *Decoder) reduce() error {
......@@ -474,7 +481,7 @@ func (d *Decoder) reduce() error {
}
xargs := d.xpop()
xclass := d.xpop()
args, ok := xargs.([]interface{})
args, ok := xargs.(Tuple)
if !ok {
return fmt.Errorf("pickle: reduce: invalid args: %T", xargs)
}
......@@ -763,7 +770,7 @@ func (d *Decoder) loadTuple() error {
return err
}
v := append([]interface{}{}, d.stack[k+1:]...)
v := append(Tuple{}, d.stack[k+1:]...)
d.stack = append(d.stack[:k], v)
return nil
}
......@@ -773,7 +780,7 @@ func (d *Decoder) loadTuple1() error {
return errStackUnderflow
}
k := len(d.stack) - 1
v := append([]interface{}{}, d.stack[k:]...)
v := append(Tuple{}, d.stack[k:]...)
d.stack = append(d.stack[:k], v)
return nil
}
......@@ -783,7 +790,7 @@ func (d *Decoder) loadTuple2() error {
return errStackUnderflow
}
k := len(d.stack) - 2
v := append([]interface{}{}, d.stack[k:]...)
v := append(Tuple{}, d.stack[k:]...)
d.stack = append(d.stack[:k], v)
return nil
}
......@@ -793,7 +800,7 @@ func (d *Decoder) loadTuple3() error {
return errStackUnderflow
}
k := len(d.stack) - 3
v := append([]interface{}{}, d.stack[k:]...)
v := append(Tuple{}, d.stack[k:]...)
d.stack = append(d.stack[:k], v)
return nil
}
......
......@@ -13,7 +13,10 @@ import (
func bigInt(s string) *big.Int {
i := new(big.Int)
i.SetString(s, 10)
_, ok := i.SetString(s, 10)
if !ok {
panic("bigInt")
}
return i
}
......@@ -44,13 +47,13 @@ func TestDecode(t *testing.T) {
{"float", "F1.23\n.", float64(1.23)},
{"long", "L12321231232131231231L\n.", bigInt("12321231232131231231")},
{"None", "N.", None{}},
{"empty tuple", "(t.", []interface{}{}},
{"tuple of two ints", "(I1\nI2\ntp0\n.", []interface{}{int64(1), int64(2)}},
{"empty tuple", "(t.", Tuple{}},
{"tuple of two ints", "(I1\nI2\ntp0\n.", Tuple{int64(1), int64(2)}},
{"nested tuples", "((I1\nI2\ntp0\n(I3\nI4\ntp1\ntp2\n.",
[]interface{}{[]interface{}{int64(1), int64(2)}, []interface{}{int64(3), int64(4)}}},
{"tuple with top 1 items from stack", "I0\n\x85.", []interface{}{int64(0)}},
{"tuple with top 2 items from stack", "I0\nI1\n\x86.", []interface{}{int64(0), int64(1)}},
{"tuple with top 3 items from stack", "I0\nI1\nI2\n\x87.", []interface{}{int64(0), int64(1), int64(2)}},
Tuple{Tuple{int64(1), int64(2)}, Tuple{int64(3), int64(4)}}},
{"tuple with top 1 items from stack", "I0\n\x85.", Tuple{int64(0)}},
{"tuple with top 2 items from stack", "I0\nI1\n\x86.", Tuple{int64(0), int64(1)}},
{"tuple with top 3 items from stack", "I0\nI1\nI2\n\x87.", Tuple{int64(0), int64(1), int64(2)}},
{"empty list", "(lp0\n.", []interface{}{}},
{"list of numbers", "(lp0\nI1\naI2\naI3\naI4\na.", []interface{}{int64(1), int64(2), int64(3), int64(4)}},
{"string", "S'abc'\np0\n.", string("abc")},
......@@ -68,6 +71,7 @@ func TestDecode(t *testing.T) {
{"SHORTBINUNICODE opcode", "\x8c\t\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e\x94.", "日本語"},
}
for _, test := range tests {
// decode(input) -> expected
buf := bytes.NewBufferString(test.input)
dec := NewDecoder(buf)
v, err := dec.Decode()
......@@ -76,12 +80,31 @@ func TestDecode(t *testing.T) {
}
if !reflect.DeepEqual(v, test.expected) {
t.Errorf("%s: got\n%q\n expected\n%q", test.name, v, test.expected)
t.Errorf("%s: decode:\nhave: %#v\nwant: %#v", test.name, v, test.expected)
}
// decode more -> EOF
v, err = dec.Decode()
if !(v == nil && err == io.EOF) {
t.Errorf("decode: no EOF at end: v = %#v err = %#v", v, err)
t.Errorf("%s: decode: no EOF at end: v = %#v err = %#v", test.name, v, err)
}
// expected (= decoded(input)) -> encode -> decode = identity
buf.Reset()
enc := NewEncoder(buf)
err = enc.Encode(test.expected)
if err != nil {
t.Errorf("%s: encode(expected): %v", test.name, err)
} else {
dec := NewDecoder(buf)
v, err := dec.Decode()
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(v, test.expected) {
t.Errorf("%s: expected -> decode -> encode != identity\nhave: %#v\nwant: %#v", test.name, v, test.expected)
}
}
// for truncated input io.ErrUnexpectedEOF must be returned
......@@ -263,6 +286,10 @@ func TestDecodeError(t *testing.T) {
"}g1\n.",
"}h\x01.",
"}j\x01\x02\x03\x04.",
// invalid long format
"L123\n.",
"L12qL\n.",
}
for _, tt := range testv {
buf := bytes.NewBufferString(tt)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment