Commit 3bf6c92d authored by Kirill Smelkov's avatar Kirill Smelkov

Add PyDict mode

Similarly to StrictUnicode mode (see b28613c2) add new opt-in mode that
requests to decode Python dicts as ogórek.Dict instead of builtin map.
As explained in recent patch "Add custom Dict that mirrors Python dict
behaviour" this is needed to fix decoding issues that can be there due
to different behaviour of Python dict and builtin Go map:

    ---- 8< ----
    Ogórek currently represents unpickled dict via map[any]any, which is
    logical, but also exhibits issues because builtin Go map behaviour is
    different from Python's dict behaviour. For example:

    - Python's dict allows tuples to be used in keys, while Go map
      does not (https://github.com/kisielk/og-rek/issues/50),

    - Python's dict allows both long and int to be used interchangeable as
      keys, while Go map does not handle *big.Int as key with the same
      semantic (https://github.com/kisielk/og-rek/issues/55)

    - Python's dict allows to use numbers interchangeable in keys - all int
      and float, but on Go side int(1) and float64(1.0) are considered by
      builtin map as different keys.

    - In Python world bytestring (str from py2) is considered to be related
      to both unicode (str on py3) and bytes, but builtin map considers all
      string, Bytes and ByteString as different keys.

    - etc...

    All in all there are many differences in behaviour in builtin Python
    dict and Go map that result in generally different semantics when
    decoding pickled data. Those differences can be fixed only if we add
    custom dict implementation that mirrors what Python does.

    -> Do that: add custom Dict that implements key -> value mapping with
       mirroring Python behaviour.

    For now we are only adding the Dict class itself and its tests.
    Later we will use this new Dict to handle decoding dictionaries from the pickles.
    ---- 8< ----

In this patch we add new Decoder option to activate PyDict mode
decoding, teach encoder to also support encoding of Dict and adjust
tests.

The behaviour of new system is explained by the following doc.go
excerpt:

    For dicts there are two modes. In the first, default, mode Python dicts are
    decoded into standard Go map. This mode tries to use builtin Go type, but
    cannot mirror py behaviour fully because e.g. int(1), big.Int(1) and
    float64(1.0) are all treated as different keys by Go, while Python treats
    them as being equal. It also does not support decoding dicts with tuple
    used in keys:

         dict      map[any]any                       PyDict=n mode, default
                 ←  ogórek.Dict

    With PyDict=y mode, however, Python dicts are decoded as ogórek.Dict which
    mirrors behaviour of Python dict with respect to keys equality, and with
    respect to which types are allowed to be used as keys.

         dict      ogórek.Dict                       PyDict=y mode
                 ←  map[any]any
parent 8be3fcab
...@@ -15,7 +15,7 @@ import ( ...@@ -15,7 +15,7 @@ import (
"github.com/aristanetworks/gomap" "github.com/aristanetworks/gomap"
) )
// Dict represents dict from Python. // Dict represents dict from Python in PyDict mode.
// //
// It mirrors Python with respect to which types are allowed to be used as // It mirrors Python with respect to which types are allowed to be used as
// keys, and with respect to keys equality. For example Tuple is allowed to be // keys, and with respect to keys equality. For example Tuple is allowed to be
...@@ -27,6 +27,8 @@ import ( ...@@ -27,6 +27,8 @@ import (
// underlying content ByteString, because it represents str type from Python2, // underlying content ByteString, because it represents str type from Python2,
// is treated equal to both Bytes and string. // is treated equal to both Bytes and string.
// //
// See PyDict mode documentation in top-level package overview for details.
//
// Note: similarly to builtin map Dict is pointer-like type: its zero-value // Note: similarly to builtin map Dict is pointer-like type: its zero-value
// represents nil dictionary that is empty and invalid to use Set on. // represents nil dictionary that is empty and invalid to use Set on.
type Dict struct { type Dict struct {
......
...@@ -22,12 +22,29 @@ ...@@ -22,12 +22,29 @@
// long ↔ *big.Int // long ↔ *big.Int
// float ↔ float64 // float ↔ float64
// float ← floatX // float ← floatX
// list ↔ []interface{} // list ↔ []any
// tuple ↔ ogórek.Tuple // tuple ↔ ogórek.Tuple
// dict ↔ map[interface{}]interface{}
// //
// //
// For strings there are two modes. In the first, default, mode both py2/py3 // For dicts there are two modes. In the first, default, mode Python dicts are
// decoded into standard Go map. This mode tries to use builtin Go type, but
// cannot mirror py behaviour fully because e.g. int(1), big.Int(1) and
// float64(1.0) are all treated as different keys by Go, while Python treats
// them as being equal. It also does not support decoding dicts with tuple
// used in keys:
//
// dict ↔ map[any]any PyDict=n mode, default
// ← ogórek.Dict
//
// With PyDict=y mode, however, Python dicts are decoded as ogórek.Dict which
// mirrors behaviour of Python dict with respect to keys equality, and with
// respect to which types are allowed to be used as keys.
//
// dict ↔ ogórek.Dict PyDict=y mode
// ← map[any]any
//
//
// For strings there are also two modes. In the first, default, mode both py2/py3
// str and py2 unicode are decoded into string with py2 str being considered // str and py2 unicode are decoded into string with py2 str being considered
// as UTF-8 encoded. Correspondingly for protocol ≤ 2 Go string is encoded as // as UTF-8 encoded. Correspondingly for protocol ≤ 2 Go string is encoded as
// UTF-8 encoded py2 str, and for protocol ≥ 3 as py3 str / py2 unicode. // UTF-8 encoded py2 str, and for protocol ≥ 3 as py3 str / py2 unicode.
...@@ -155,6 +172,11 @@ ...@@ -155,6 +172,11 @@
// Using the helpers fits into Python3 strings/bytes model but also allows to // Using the helpers fits into Python3 strings/bytes model but also allows to
// handle the data generated from under Python2. // handle the data generated from under Python2.
// //
// Similarly Dict considers ByteString to be equal to both string and Bytes
// with the same underlying content. This allows programs to access Dict via
// string/bytes keys following Python3 model, while still being able to handle
// dictionaries generated from under Python2.
//
// //
// -------- // --------
// //
......
...@@ -505,6 +505,41 @@ func (e *Encoder) encodeMap(m reflect.Value) error { ...@@ -505,6 +505,41 @@ func (e *Encoder) encodeMap(m reflect.Value) error {
return e.emit(opDict) return e.emit(opDict)
} }
func (e *Encoder) encodeDict(d Dict) error {
l := d.Len()
// protocol >= 1: ø dict -> EMPTY_DICT
if e.config.Protocol >= 1 && l == 0 {
return e.emit(opEmptyDict)
}
// MARK + ... + DICT
// TODO cycles + sort keys (see encodeMap for details)
err := e.emit(opMark)
if err != nil {
return err
}
d.Iter()(func(k, v any) bool {
err = e.encode(reflectValueOf(k))
if err != nil {
return false
}
err = e.encode(reflectValueOf(v))
if err != nil {
return false
}
return true
})
if err != nil {
return err
}
return e.emit(opDict)
}
func (e *Encoder) encodeCall(v *Call) error { func (e *Encoder) encodeCall(v *Call) error {
err := e.encodeClass(&v.Callable) err := e.encodeClass(&v.Callable)
if err != nil { if err != nil {
...@@ -578,6 +613,8 @@ func (e *Encoder) encodeStruct(st reflect.Value) error { ...@@ -578,6 +613,8 @@ func (e *Encoder) encodeStruct(st reflect.Value) error {
return e.encodeRef(&v) return e.encodeRef(&v)
case big.Int: case big.Int:
return e.encodeLong(&v) return e.encodeLong(&v)
case Dict:
return e.encodeDict(v)
} }
structTags := getStructTags(st) structTags := getStructTags(st)
......
...@@ -8,20 +8,24 @@ import ( ...@@ -8,20 +8,24 @@ import (
) )
func Fuzz(data []byte) int { func Fuzz(data []byte) int {
f1 := fuzz(data, false) f := 0
f2 := fuzz(data, true)
f += fuzz(data, false, false)
f += fuzz(data, false, true)
f += fuzz(data, true, false)
f += fuzz(data, true, true)
f := f1+f2
if f > 1 { if f > 1 {
f = 1 f = 1
} }
return f return f
} }
func fuzz(data []byte, strictUnicode bool) int { func fuzz(data []byte, pyDict, strictUnicode bool) int {
// obj = decode(data) - this tests things like stack overflow in Decoder // obj = decode(data) - this tests things like stack overflow in Decoder
buf := bytes.NewBuffer(data) buf := bytes.NewBuffer(data)
dec := NewDecoderWithConfig(buf, &DecoderConfig{ dec := NewDecoderWithConfig(buf, &DecoderConfig{
PyDict: pyDict,
StrictUnicode: strictUnicode, StrictUnicode: strictUnicode,
}) })
obj, err := dec.Decode() obj, err := dec.Decode()
...@@ -37,7 +41,7 @@ func fuzz(data []byte, strictUnicode bool) int { ...@@ -37,7 +41,7 @@ func fuzz(data []byte, strictUnicode bool) int {
// because obj - as we got it as decoding from input - is known not to // because obj - as we got it as decoding from input - is known not to
// contain arbitrary Go structs. // contain arbitrary Go structs.
for proto := 0; proto <= highestProtocol; proto++ { for proto := 0; proto <= highestProtocol; proto++ {
subj := fmt.Sprintf("strictUnicode %v: protocol %d", strictUnicode, proto) subj := fmt.Sprintf("pyDict %v: strictUnicode %v: protocol %d", pyDict, strictUnicode, proto)
buf.Reset() buf.Reset()
enc := NewEncoderWithConfig(buf, &EncoderConfig{ enc := NewEncoderWithConfig(buf, &EncoderConfig{
...@@ -66,6 +70,7 @@ func fuzz(data []byte, strictUnicode bool) int { ...@@ -66,6 +70,7 @@ func fuzz(data []byte, strictUnicode bool) int {
encoded := buf.String() encoded := buf.String()
dec = NewDecoderWithConfig(bytes.NewBufferString(encoded), &DecoderConfig{ dec = NewDecoderWithConfig(bytes.NewBufferString(encoded), &DecoderConfig{
PyDict: pyDict,
StrictUnicode: strictUnicode, StrictUnicode: strictUnicode,
}) })
obj2, err := dec.Decode() obj2, err := dec.Decode()
......
...@@ -189,6 +189,11 @@ type DecoderConfig struct { ...@@ -189,6 +189,11 @@ type DecoderConfig struct {
// decoded into ByteString in this mode. See StrictUnicode mode // decoded into ByteString in this mode. See StrictUnicode mode
// documentation in top-level package overview for details. // documentation in top-level package overview for details.
StrictUnicode bool StrictUnicode bool
// PyDict, when true, requests to decode Python dicts as ogórek.Dict
// instead of builtin map. See PyDict mode documentation in top-level
// package overview for details.
PyDict bool
} }
// NewDecoder returns a new Decoder with the default configuration. // NewDecoder returns a new Decoder with the default configuration.
...@@ -1009,29 +1014,75 @@ func mapTryAssign(m map[interface{}]interface{}, key, value interface{}) (ok boo ...@@ -1009,29 +1014,75 @@ func mapTryAssign(m map[interface{}]interface{}, key, value interface{}) (ok boo
return return
} }
// dictTryAssign is like mapTryAssign but for Dict.
func dictTryAssign(d Dict, key, value interface{}) (ok bool) {
defer func() {
if r := recover(); r != nil {
ok = false
}
}()
d.Set(key, value)
ok = true
return
}
func (d *Decoder) loadDict() error { func (d *Decoder) loadDict() error {
k, err := d.marker() k, err := d.marker()
if err != nil { if err != nil {
return err return err
} }
m := make(map[interface{}]interface{}, 0)
items := d.stack[k+1:] items := d.stack[k+1:]
if len(items) % 2 != 0 { if len(items) % 2 != 0 {
return fmt.Errorf("pickle: loadDict: odd # of elements") return fmt.Errorf("pickle: loadDict: odd # of elements")
} }
var m interface{}
if d.config.PyDict {
m, err = d.loadDictDict(items)
} else {
m, err = d.loadDictMap(items)
}
if err != nil {
return err
}
d.stack = append(d.stack[:k], m)
return nil
}
func (d *Decoder) loadDictMap(items []interface{}) (map[interface{}]interface{}, error) {
m := make(map[interface{}]interface{}, len(items)/2)
for i := 0; i < len(items); i += 2 { for i := 0; i < len(items); i += 2 {
key := items[i] key := items[i]
if !mapTryAssign(m, key, items[i+1]) { if !mapTryAssign(m, key, items[i+1]) {
return fmt.Errorf("pickle: loadDict: invalid key type %T", key) return nil, fmt.Errorf("pickle: loadDict: map: invalid key type %T", key)
} }
} }
d.stack = append(d.stack[:k], m) return m, nil
return nil
} }
func (d *Decoder) loadDictDict(items []interface{}) (Dict, error) {
m := NewDictWithSizeHint(len(items)/2)
for i := 0; i < len(items); i += 2 {
key := items[i]
if !dictTryAssign(m, key, items[i+1]) {
return Dict{}, fmt.Errorf("pickle: loadDict: Dict: invalid key type %T", key)
}
}
return m, nil
}
func (d *Decoder) loadEmptyDict() error { func (d *Decoder) loadEmptyDict() error {
m := make(map[interface{}]interface{}, 0) var m interface{}
if d.config.PyDict {
m = NewDict()
} else {
m = make(map[interface{}]interface{}, 0)
}
d.push(m) d.push(m)
return nil return nil
} }
...@@ -1218,10 +1269,14 @@ func (d *Decoder) loadSetItem() error { ...@@ -1218,10 +1269,14 @@ func (d *Decoder) loadSetItem() error {
switch m := m.(type) { switch m := m.(type) {
case map[interface{}]interface{}: case map[interface{}]interface{}:
if !mapTryAssign(m, k, v) { if !mapTryAssign(m, k, v) {
return fmt.Errorf("pickle: loadSetItem: invalid key type %T", k) return fmt.Errorf("pickle: loadSetItem: map: invalid key type %T", k)
}
case Dict:
if !dictTryAssign(m, k, v) {
return fmt.Errorf("pickle: loadSetItem: Dict: invalid key type %T", k)
} }
default: default:
return fmt.Errorf("pickle: loadSetItem: expected a map, got %T", m) return fmt.Errorf("pickle: loadSetItem: expected a map or Dict, got %T", m)
} }
return nil return nil
} }
...@@ -1234,23 +1289,31 @@ func (d *Decoder) loadSetItems() error { ...@@ -1234,23 +1289,31 @@ func (d *Decoder) loadSetItems() error {
if k < 1 { if k < 1 {
return errStackUnderflow return errStackUnderflow
} }
if (len(d.stack) - (k + 1)) % 2 != 0 {
return fmt.Errorf("pickle: loadSetItems: odd # of elements")
}
l := d.stack[k-1] l := d.stack[k-1]
switch m := l.(type) { switch m := l.(type) {
case map[interface{}]interface{}: case map[interface{}]interface{}:
if (len(d.stack) - (k + 1)) % 2 != 0 {
return fmt.Errorf("pickle: loadSetItems: odd # of elements")
}
for i := k + 1; i < len(d.stack); i += 2 { for i := k + 1; i < len(d.stack); i += 2 {
key := d.stack[i] key := d.stack[i]
if !mapTryAssign(m, key, d.stack[i+1]) { if !mapTryAssign(m, key, d.stack[i+1]) {
return fmt.Errorf("pickle: loadSetItems: invalid key type %T", key) return fmt.Errorf("pickle: loadSetItems: map: invalid key type %T", key)
} }
} }
d.stack = append(d.stack[:k-1], m) case Dict:
for i := k + 1; i < len(d.stack); i += 2 {
key := d.stack[i]
if !dictTryAssign(m, key, d.stack[i+1]) {
return fmt.Errorf("pickle: loadSetItems: Dict: invalid key type %T", key)
}
}
default: default:
return fmt.Errorf("pickle: loadSetItems: expected a map, got %T", m) return fmt.Errorf("pickle: loadSetItems: expected a map or Dict, got %T", m)
} }
d.stack = append(d.stack[:k-1], l)
return nil return nil
} }
......
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment