Commit ea04619f authored by Kamil Kisiel's avatar Kamil Kisiel Committed by GitHub

Merge pull request #75 from navytux/y/pydict

Add PyDict mode
parents 773b7a86 3bf6c92d
......@@ -11,12 +11,6 @@ jobs:
strategy:
matrix:
go:
- "1.12.x"
- "1.13.x"
- "1.14.x"
- "1.15.x"
- "1.16.x"
- "1.17.x"
- "1.18.x"
- "1.19.x"
- "1.20.x"
......
This diff is collapsed.
This diff is collapsed.
......@@ -22,12 +22,29 @@
// long ↔ *big.Int
// float ↔ float64
// float ← floatX
// list ↔ []interface{}
// list ↔ []any
// tuple ↔ ogórek.Tuple
// dict ↔ map[interface{}]interface{}
//
//
// For strings there are two modes. In the first, default, mode both py2/py3
// For dicts there are two modes. In the first, default, mode Python dicts are
// decoded into standard Go map. This mode tries to use builtin Go type, but
// cannot mirror py behaviour fully because e.g. int(1), big.Int(1) and
// float64(1.0) are all treated as different keys by Go, while Python treats
// them as being equal. It also does not support decoding dicts with tuple
// used in keys:
//
// dict ↔ map[any]any PyDict=n mode, default
// ← ogórek.Dict
//
// With PyDict=y mode, however, Python dicts are decoded as ogórek.Dict which
// mirrors behaviour of Python dict with respect to keys equality, and with
// respect to which types are allowed to be used as keys.
//
// dict ↔ ogórek.Dict PyDict=y mode
// ← map[any]any
//
//
// For strings there are also two modes. In the first, default, mode both py2/py3
// str and py2 unicode are decoded into string with py2 str being considered
// as UTF-8 encoded. Correspondingly for protocol ≤ 2 Go string is encoded as
// UTF-8 encoded py2 str, and for protocol ≥ 3 as py3 str / py2 unicode.
......@@ -155,6 +172,11 @@
// Using the helpers fits into Python3 strings/bytes model but also allows to
// handle the data generated from under Python2.
//
// Similarly Dict considers ByteString to be equal to both string and Bytes
// with the same underlying content. This allows programs to access Dict via
// string/bytes keys following Python3 model, while still being able to handle
// dictionaries generated from under Python2.
//
//
// --------
//
......
......@@ -505,6 +505,41 @@ func (e *Encoder) encodeMap(m reflect.Value) error {
return e.emit(opDict)
}
func (e *Encoder) encodeDict(d Dict) error {
l := d.Len()
// protocol >= 1: ø dict -> EMPTY_DICT
if e.config.Protocol >= 1 && l == 0 {
return e.emit(opEmptyDict)
}
// MARK + ... + DICT
// TODO cycles + sort keys (see encodeMap for details)
err := e.emit(opMark)
if err != nil {
return err
}
d.Iter()(func(k, v any) bool {
err = e.encode(reflectValueOf(k))
if err != nil {
return false
}
err = e.encode(reflectValueOf(v))
if err != nil {
return false
}
return true
})
if err != nil {
return err
}
return e.emit(opDict)
}
func (e *Encoder) encodeCall(v *Call) error {
err := e.encodeClass(&v.Callable)
if err != nil {
......@@ -578,6 +613,8 @@ func (e *Encoder) encodeStruct(st reflect.Value) error {
return e.encodeRef(&v)
case big.Int:
return e.encodeLong(&v)
case Dict:
return e.encodeDict(v)
}
structTags := getStructTags(st)
......
......@@ -5,24 +5,27 @@ package ogórek
import (
"bytes"
"fmt"
"reflect"
)
func Fuzz(data []byte) int {
f1 := fuzz(data, false)
f2 := fuzz(data, true)
f := 0
f += fuzz(data, false, false)
f += fuzz(data, false, true)
f += fuzz(data, true, false)
f += fuzz(data, true, true)
f := f1+f2
if f > 1 {
f = 1
}
return f
}
func fuzz(data []byte, strictUnicode bool) int {
func fuzz(data []byte, pyDict, strictUnicode bool) int {
// obj = decode(data) - this tests things like stack overflow in Decoder
buf := bytes.NewBuffer(data)
dec := NewDecoderWithConfig(buf, &DecoderConfig{
PyDict: pyDict,
StrictUnicode: strictUnicode,
})
obj, err := dec.Decode()
......@@ -38,7 +41,7 @@ func fuzz(data []byte, strictUnicode bool) int {
// because obj - as we got it as decoding from input - is known not to
// contain arbitrary Go structs.
for proto := 0; proto <= highestProtocol; proto++ {
subj := fmt.Sprintf("strictUnicode %v: protocol %d", strictUnicode, proto)
subj := fmt.Sprintf("pyDict %v: strictUnicode %v: protocol %d", pyDict, strictUnicode, proto)
buf.Reset()
enc := NewEncoderWithConfig(buf, &EncoderConfig{
......@@ -67,6 +70,7 @@ func fuzz(data []byte, strictUnicode bool) int {
encoded := buf.String()
dec = NewDecoderWithConfig(bytes.NewBufferString(encoded), &DecoderConfig{
PyDict: pyDict,
StrictUnicode: strictUnicode,
})
obj2, err := dec.Decode()
......@@ -75,7 +79,7 @@ func fuzz(data []byte, strictUnicode bool) int {
panic(fmt.Sprintf("%s: decode back error: %s\npickle: %q", subj, err, encoded))
}
if !reflect.DeepEqual(obj, obj2) {
if !deepEqual(obj, obj2) {
panic(fmt.Sprintf("%s: decode·encode != identity:\nhave: %#v\nwant: %#v", subj, obj2, obj))
}
}
......
module github.com/kisielk/og-rek
go 1.12
go 1.18
require github.com/aristanetworks/gomap v0.0.0-20230726210543-f4e41046dced
require golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 // indirect
github.com/aristanetworks/gomap v0.0.0-20230726210543-f4e41046dced h1:HxlRMDx/VeRqzj3nvqX9k4tjeBcEIkoNHDJPsS389hs=
github.com/aristanetworks/gomap v0.0.0-20230726210543-f4e41046dced/go.mod h1:p7lmI+ecoe1RTyD11SPXWsSQ3H+pJ4cp5y7vtKW4QdM=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
......@@ -189,6 +189,11 @@ type DecoderConfig struct {
// decoded into ByteString in this mode. See StrictUnicode mode
// documentation in top-level package overview for details.
StrictUnicode bool
// PyDict, when true, requests to decode Python dicts as ogórek.Dict
// instead of builtin map. See PyDict mode documentation in top-level
// package overview for details.
PyDict bool
}
// NewDecoder returns a new Decoder with the default configuration.
......@@ -1009,29 +1014,75 @@ func mapTryAssign(m map[interface{}]interface{}, key, value interface{}) (ok boo
return
}
// dictTryAssign is like mapTryAssign but for Dict.
func dictTryAssign(d Dict, key, value interface{}) (ok bool) {
defer func() {
if r := recover(); r != nil {
ok = false
}
}()
d.Set(key, value)
ok = true
return
}
func (d *Decoder) loadDict() error {
k, err := d.marker()
if err != nil {
return err
}
m := make(map[interface{}]interface{}, 0)
items := d.stack[k+1:]
if len(items) % 2 != 0 {
return fmt.Errorf("pickle: loadDict: odd # of elements")
}
var m interface{}
if d.config.PyDict {
m, err = d.loadDictDict(items)
} else {
m, err = d.loadDictMap(items)
}
if err != nil {
return err
}
d.stack = append(d.stack[:k], m)
return nil
}
func (d *Decoder) loadDictMap(items []interface{}) (map[interface{}]interface{}, error) {
m := make(map[interface{}]interface{}, len(items)/2)
for i := 0; i < len(items); i += 2 {
key := items[i]
if !mapTryAssign(m, key, items[i+1]) {
return fmt.Errorf("pickle: loadDict: invalid key type %T", key)
return nil, fmt.Errorf("pickle: loadDict: map: invalid key type %T", key)
}
}
d.stack = append(d.stack[:k], m)
return nil
return m, nil
}
func (d *Decoder) loadDictDict(items []interface{}) (Dict, error) {
m := NewDictWithSizeHint(len(items)/2)
for i := 0; i < len(items); i += 2 {
key := items[i]
if !dictTryAssign(m, key, items[i+1]) {
return Dict{}, fmt.Errorf("pickle: loadDict: Dict: invalid key type %T", key)
}
}
return m, nil
}
func (d *Decoder) loadEmptyDict() error {
m := make(map[interface{}]interface{}, 0)
var m interface{}
if d.config.PyDict {
m = NewDict()
} else {
m = make(map[interface{}]interface{}, 0)
}
d.push(m)
return nil
}
......@@ -1218,10 +1269,14 @@ func (d *Decoder) loadSetItem() error {
switch m := m.(type) {
case map[interface{}]interface{}:
if !mapTryAssign(m, k, v) {
return fmt.Errorf("pickle: loadSetItem: invalid key type %T", k)
return fmt.Errorf("pickle: loadSetItem: map: invalid key type %T", k)
}
case Dict:
if !dictTryAssign(m, k, v) {
return fmt.Errorf("pickle: loadSetItem: Dict: invalid key type %T", k)
}
default:
return fmt.Errorf("pickle: loadSetItem: expected a map, got %T", m)
return fmt.Errorf("pickle: loadSetItem: expected a map or Dict, got %T", m)
}
return nil
}
......@@ -1234,23 +1289,31 @@ func (d *Decoder) loadSetItems() error {
if k < 1 {
return errStackUnderflow
}
if (len(d.stack) - (k + 1)) % 2 != 0 {
return fmt.Errorf("pickle: loadSetItems: odd # of elements")
}
l := d.stack[k-1]
switch m := l.(type) {
case map[interface{}]interface{}:
if (len(d.stack) - (k + 1)) % 2 != 0 {
return fmt.Errorf("pickle: loadSetItems: odd # of elements")
}
for i := k + 1; i < len(d.stack); i += 2 {
key := d.stack[i]
if !mapTryAssign(m, key, d.stack[i+1]) {
return fmt.Errorf("pickle: loadSetItems: invalid key type %T", key)
return fmt.Errorf("pickle: loadSetItems: map: invalid key type %T", key)
}
}
d.stack = append(d.stack[:k-1], m)
case Dict:
for i := k + 1; i < len(d.stack); i += 2 {
key := d.stack[i]
if !dictTryAssign(m, key, d.stack[i+1]) {
return fmt.Errorf("pickle: loadSetItems: Dict: invalid key type %T", key)
}
}
default:
return fmt.Errorf("pickle: loadSetItems: expected a map, got %T", m)
return fmt.Errorf("pickle: loadSetItems: expected a map or Dict, got %T", m)
}
d.stack = append(d.stack[:k-1], l)
return nil
}
......
This diff is collapsed.
......@@ -45,7 +45,7 @@ func TestAsInt64(t *testing.T) {
}
}
if !reflect.DeepEqual(out, tt.outOK) {
if !deepEqual(out, tt.outOK) {
t.Errorf("%T %#v -> %T %#v ; want %T %#v",
tt.in, tt.in, out, out, tt.outOK, tt.outOK)
}
......@@ -100,12 +100,12 @@ func TestAsBytesString(t *testing.T) {
serrOK = Estring(tt.in)
}
if !(bout == boutOK && reflect.DeepEqual(berr, berrOK)) {
if !(bout == boutOK && deepEqual(berr, berrOK)) {
t.Errorf("%#v: AsBytes:\nhave %#v %#v\nwant %#v %#v",
tt.in, bout, berr, boutOK, berrOK)
}
if !(sout == soutOK && reflect.DeepEqual(serr, serrOK)) {
if !(sout == soutOK && deepEqual(serr, serrOK)) {
t.Errorf("%#v: AsString:\nhave %#v %#v\nwant %#v %#v",
tt.in, sout, serr, soutOK, serrOK)
}
......
//go:build !go1.21
package ogórek
import (
"math/big"
)
func bigInt_Float64(b *big.Int) (float64, big.Accuracy) {
return new(big.Float).SetInt(b).Float64()
}
//go:build go1.21
package ogórek
import (
"math/big"
)
func bigInt_Float64(b *big.Int) (float64, big.Accuracy) {
return b.Float64()
}
//go:build !go1.19
package ogórek
import (
"hash/maphash"
)
func maphash_String(seed maphash.Seed, s string) uint64 {
var h maphash.Hash
h.SetSeed(seed)
h.WriteString(s)
return h.Sum64()
}
//go:build go1.19
package ogórek
import (
"hash/maphash"
)
func maphash_String(seed maphash.Seed, s string) uint64 {
return maphash.String(seed, s)
}
package ogórek
// Utilities that complement std reflect package.
import (
"reflect"
)
// deepEqual is like reflect.DeepEqual but also supports Dict.
//
// It is needed because reflect.DeepEqual considers two Dicts not-equal because
// each Dict is made with its own seed.
//
// XXX only top-level Dict is supported currently.
// For example comparing Dict inside list with the same won't work.
func deepEqual(a, b any) bool {
da, ok := a.(Dict)
if !ok {
return reflect.DeepEqual(a, b)
}
db, ok := b.(Dict)
if !ok {
return false // Dict != non-dict
}
if da.Len() != db.Len() {
return false
}
// XXX O(n^2) because we want to compare keys exactly and so cannot use
// db.Get(ka) because Dict.Get uses general equality that would match e.g. int == int64
eq := true
da.Iter()(func(ka, va any) bool {
keq := false
db.Iter()(func(kb, vb any) bool {
// NOTE don't use reflect.Equal(ka,kb) because it does not handle e.g. big.Int
if reflect.TypeOf(ka) == reflect.TypeOf(kb) && equal(ka,kb) {
if reflect.DeepEqual(va, vb) {
keq = true
}
return false
}
return true
})
if !keq {
eq = false
return false
}
return true
})
return eq
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment