Commit ea04619f authored by Kamil Kisiel's avatar Kamil Kisiel Committed by GitHub

Merge pull request #75 from navytux/y/pydict

Add PyDict mode
parents 773b7a86 3bf6c92d
......@@ -11,12 +11,6 @@ jobs:
strategy:
matrix:
go:
- "1.12.x"
- "1.13.x"
- "1.14.x"
- "1.15.x"
- "1.16.x"
- "1.17.x"
- "1.18.x"
- "1.19.x"
- "1.20.x"
......
package ogórek
// Python-like Dict that handles keys by Python-like equality on access.
//
// For example Dict.Get() will access the same element for all keys int(1), float64(1.0) and big.Int(1).
import (
"encoding/binary"
"fmt"
"hash/maphash"
"math"
"math/big"
"reflect"
"sort"
"github.com/aristanetworks/gomap"
)
// Dict represents dict from Python in PyDict mode.
//
// It mirrors Python with respect to which types are allowed to be used as
// keys, and with respect to keys equality. For example Tuple is allowed to be
// used as key, and all int(1), float64(1.0) and big.Int(1) are considered to be
// equal.
//
// For strings, similarly to Python3, Bytes and string are considered to be not
// equal, even if their underlying content is the same. However with same
// underlying content ByteString, because it represents str type from Python2,
// is treated equal to both Bytes and string.
//
// See PyDict mode documentation in top-level package overview for details.
//
// Note: similarly to builtin map Dict is pointer-like type: its zero-value
// represents nil dictionary that is empty and invalid to use Set on.
type Dict struct {
m *gomap.Map[any, any]
}
// NewDict returns new empty dictionary.
func NewDict() Dict {
return NewDictWithSizeHint(0)
}
// NewDictWithSizeHint returns new empty dictionary with preallocated space for size items.
func NewDictWithSizeHint(size int) Dict {
return Dict{m: gomap.NewHint[any, any](size, equal, hash)}
}
// NewDictWithData returns new dictionary with preset data.
//
// kv should be key₁, value₁, key₂, value₂, ...
func NewDictWithData(kv ...any) Dict {
l := len(kv)
if l % 2 != 0 {
panic("odd number of arguments")
}
l /= 2
d := NewDictWithSizeHint(l)
for i := 0; i < l; i++ {
k := kv[2*i]
v := kv[2*i+1]
d.Set(k, v)
}
return d
}
// Get returns value associated with equal key.
//
// An entry with key equal to the query is looked up and corresponding value
// is returned.
//
// nil is returned if no matching key is present in the dictionary.
//
// Get panics if key's type is not allowed to be used as Dict key.
func (d Dict) Get(key any) any {
value, _ := d.Get_(key)
return value
}
// Get_ is comma-ok version of Get.
func (d Dict) Get_(key any) (value any, ok bool) {
return d.m.Get(key)
}
// Set sets key to be associated with value.
//
// Any previous keys, equal to the new key, are removed from the dictionary
// before the assignment.
//
// Set panics if key's type is not allowed to be used as Dict key.
func (d Dict) Set(key, value any) {
// ByteString and container(with ByteString) are non-transitive equal types
// so Set(ByteString) should first remove Bytes and string,
// and Set(Tuple{ByteString) should first remove Tuple{Bytes} and Tuple{string}
d.Del(key)
d.m.Set(key, value)
}
// Del removes equal keys from the dictionary.
//
// All entries with key equal to the query are looked up and removed.
//
// Del panics if key's type is not allowed to be used as Dict key.
func (d Dict) Del(key any) {
// see comment in Set about ByteString and container(with ByteString)
for {
d.m.Delete(key)
_, have := d.Get_(key)
if !have {
break
}
}
}
// Len returns the number of items in the dictionary.
func (d Dict) Len() int {
return d.m.Len()
}
// Iter returns iterator over all elements in the dictionary.
//
// The order to visit entries is arbitrary.
func (d Dict) Iter() /* iter.Seq2 */ func(yield func(any, any) bool) {
it := d.m.Iter()
return func(yield func(any, any) bool) {
for it.Next() {
cont := yield(it.Key(), it.Elem())
if !cont {
break
}
}
}
}
// String returns human-readable representation of the dictionary.
func (d Dict) String() string {
return d.sprintf("%v")
}
// GoString returns detailed human-readable representation of the dictionary.
func (d Dict) GoString() string {
return fmt.Sprintf("%T%s", d, d.sprintf("%#v"))
}
// sprintf serves String and GoString.
func (d Dict) sprintf(format string) string {
type KV struct { k,v string }
vkv := make([]KV, 0, d.Len())
d.Iter()(func(k, v any) bool {
vkv = append(vkv, KV{
k: fmt.Sprintf(format, k),
v: fmt.Sprintf(format, v),
})
return true
})
sort.Slice(vkv, func(i, j int) bool {
return vkv[i].k < vkv[j].k
})
s := "{"
for i, kv := range vkv {
if i > 0 {
s += ", "
}
s += kv.k + ": " + kv.v
}
s += "}"
return s
}
// ---- equal ----
// kind represents to which category a type belongs.
//
// It primarily classifies bool, numbers, slices, structs and maps, and puts
// everything else into "other" category.
type kind uint
const (
kBool = iota
kInt // int + intX
kUint // uint + uintX
kFloat // floatX
kComplex // complexX
kBigInt // *big.Int
kSlice // slice + array
kMap // map
kStruct // struct
kPointer // pointer
kOther // everything else
)
// kindOf returns kind of x.
func kindOf(x any) kind {
r := reflect.ValueOf(x)
switch r.Kind() {
case reflect.Bool:
return kBool
case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
return kInt
case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8:
return kUint
case reflect.Float64, reflect.Float32:
return kFloat
case reflect.Complex128, reflect.Complex64:
return kComplex
case reflect.Slice, reflect.Array:
return kSlice
case reflect.Map:
return kMap
case reflect.Struct:
return kStruct
}
switch x.(type) {
case *big.Int:
return kBigInt
}
switch r.Kind() {
case reflect.Pointer:
return kPointer
}
return kOther
}
// equal implements equality matching what Python would return for a == b.
//
// Equality properties:
//
// 1) equality is extension of Go ==
//
// (a == b) ⇒ equal(a,b)
//
// 2) self equal:
//
// equal(a,a) = y
//
// 3) equality is symmetrical:
//
// equal(a,b) = equal(b,a)
//
// 4) equality is mostly transitive:
//
// EqTransitive = set of all x:
// ∀ a,b,c ∈ EqTransitive:
// equal(a,b) ^ equal(b,c) ⇒ equal(a,c)
//
// EqTransitive = all \ {ByteString + containers with ByteString}
func equal(xa, xb any) bool {
// strings/bytes
switch a := xa.(type) {
case string:
switch b := xb.(type) {
case string: return a == b
case ByteString: return a == string(b)
case Bytes: return false
default: return false
}
case ByteString:
switch b := xb.(type) {
case string: return a == ByteString(b)
case ByteString: return a == b
case Bytes: return a == ByteString(b)
default: return false
}
case Bytes:
switch b := xb.(type) {
case string: return false
case ByteString: return a == Bytes(b)
case Bytes: return a == b
default: return false
}
}
// everything else
a := reflect.ValueOf(xa)
b := reflect.ValueOf(xb)
ak := kindOf(xa)
bk := kindOf(xb)
// since equality is symmetric, we can implement only half of comparison matrix
if ak > bk {
a, b = b, a
ak, bk = bk, ak
xa, xb = xb, xa
}
// ak ≤ bk
handled := true
switch ak {
default:
handled = false
// numbers
case kBool:
// bool compares to numbers as 1 or 0
//
// In [1]: 1.0 == True
// Out[1]: True
//
// In [2]: 0.0 == False
// Out[2]: True
//
// In [3]: d = {1: 'abc'}
//
// In [4]: d[True]
// Out[4]: 'abc'
abint := bint(a.Bool())
switch bk {
case kBool: return eq_Int_Int (abint, bint(b.Bool()))
case kInt: return eq_Int_Int (abint, b.Int())
case kUint: return eq_Int_Uint (abint, b.Uint())
case kFloat: return eq_Int_Float (abint, b.Float())
case kComplex: return eq_Int_Complex (abint, b.Complex())
case kBigInt: return eq_Int_BigInt (abint, xb.(*big.Int))
}
case kInt:
aint := a.Int()
switch bk {
// kBool
case kInt: return eq_Int_Int (aint, b.Int())
case kUint: return eq_Int_Uint (aint, b.Uint())
case kFloat: return eq_Int_Float (aint, b.Float())
case kComplex: return eq_Int_Complex (aint, b.Complex())
case kBigInt: return eq_Int_BigInt (aint, xb.(*big.Int))
}
case kUint:
auint := a.Uint()
switch bk {
// kBool
// kInt
case kUint: return eq_Uint_Uint (auint, b.Uint())
case kFloat: return eq_Uint_Float (auint, b.Float())
case kComplex: return eq_Uint_Complex (auint, b.Complex())
case kBigInt: return eq_Uint_BigInt (auint, xb.(*big.Int))
}
case kFloat:
afloat := a.Float()
switch bk {
// kBool
// kInt
// kUint
case kFloat: return eq_Float_Float (afloat, b.Float())
case kComplex: return eq_Float_Complex (afloat, b.Complex())
case kBigInt: return eq_Float_BigInt (afloat, xb.(*big.Int))
}
case kComplex:
acomplex := a.Complex()
switch bk {
// kBool
// kInt
// kUint
// kFloat
case kComplex: return eq_Complex_Complex (acomplex, b.Complex())
case kBigInt: return eq_Complex_BigInt (acomplex, xb.(*big.Int))
}
case kBigInt:
switch bk {
// kBool
// kInt
// kUint
// kFloat
// kComplex
case kBigInt: return eq_BigInt_BigInt (xa.(*big.Int), xb.(*big.Int))
}
// slices
case kSlice:
switch bk {
case kSlice: return eq_Slice_Slice (a, b)
}
// builtin map
case kMap:
switch bk {
case kMap: return eq_Map_Map (a, b)
}
switch b := xb.(type) {
case Dict: return eq_Map_Dict (a, b)
}
}
if handled {
return false
}
// our types that need special handling
switch a := xa.(type) {
case Dict:
switch b := xb.(type) {
case Dict: return eq_Dict_Dict(a, b)
default: return false
}
}
// structs (also covers None, Class, Call etc...)
switch ak {
case kStruct:
switch bk {
case kStruct: return eq_Struct_Struct (a, b)
default: return false
}
}
return (xa == xb) // fallback to builtin equality
}
// equality matrix. nontrivial elements
func eq_Int_Uint(a int64, b uint64) bool {
if a >= 0 {
return uint64(a) == b
}
return false
}
func eq_Int_BigInt(a int64, b *big.Int) bool {
if b.IsInt64() {
return a == b.Int64()
}
return false
}
func eq_Uint_BigInt(a uint64, b *big.Int) bool {
if b.IsUint64() {
return a == b.Uint64()
}
return false
}
func eq_Float_BigInt(a float64, b *big.Int) bool {
bf, accuracy := bigInt_Float64(b)
if accuracy == big.Exact {
return a == bf
}
return false
}
func eq_Complex_BigInt(a complex128, b *big.Int) bool {
if imag(a) == 0 {
return eq_Float_BigInt(real(a), b)
}
return false
}
func eq_BigInt_BigInt(a, b *big.Int) bool {
return (a.Cmp(b) == 0)
}
func eq_Slice_Slice(a, b reflect.Value) bool {
al := a.Len()
bl := b.Len()
if al != bl {
return false
}
for i := 0; i < al; i++ {
if !equal(a.Index(i).Interface(), b.Index(i).Interface()) {
return false
}
}
return true
}
func eq_Struct_Struct(a, b reflect.Value) bool {
if a.Type() != b.Type() {
return false
}
typ := a.Type()
l := typ.NumField()
for i := 0; i < l; i++ {
af := a.Field(i)
bf := b.Field(i)
// .Interface() is not allowed if the field is private.
// Work around the protection via unsafe. We may need to switch
// to struct copy if it is not addressable because Addr() is
// used in the workaround. https://stackoverflow.com/a/43918797/9456786
ftyp := typ.Field(i)
if !ftyp.IsExported() {
if !af.CanAddr() {
// switch a to addressable copy
a_ := reflect.New(typ).Elem()
a_.Set(a)
a = a_
af = a.Field(i)
}
if !bf.CanAddr() {
// switch b to addressable copy
b_ := reflect.New(typ).Elem()
b_.Set(b)
b = b_
bf = b.Field(i)
}
af = reflect.NewAt(ftyp.Type, af.Addr().UnsafePointer()).Elem()
bf = reflect.NewAt(ftyp.Type, bf.Addr().UnsafePointer()).Elem()
}
if !equal(af.Interface(), bf.Interface()) {
return false
}
}
return true
}
func eq_Dict_Dict(a Dict, b Dict) bool {
// dicts D₁ and D₂ are considered equal if the following is true:
//
// - len(D₁) = len(D₂)
// - ∀ k ∈ D₁ equal(D₁[k], D₂[k]) = y
// - ∀ k ∈ D₂ equal(D₁[k], D₂[k]) = y
//
// this definition is reasonable and fast to implement without additional memory.
// Also if D₁ and D₂ have keys only from equal-transitive subset of all
// keys (i.e. anything without ByteString), it becomes equivalent to the
// following definition:
//
// - (k₁i, v₁i) is set of all key/values from D₁
// - (k₂j, v₂j) is set of all key/values from D₂
// - equal(D₁,D₂):
//
// ∃ 1-1 mapping in between i<->j: equal(k₁i, k₂j) ^ equal(v₁i, v₂j)
if a.Len() != b.Len() {
return false
}
eq := true
a.Iter()(func(k,va any) bool {
vb, ok := b.Get_(k)
if !ok || !equal(va, vb) {
eq = false
return false
}
return true
})
if !eq {
return false
}
b.Iter()(func(k,vb any) bool {
va, ok := a.Get_(k)
if !ok || !equal(va, vb) {
eq = false
return false
}
return true
})
return eq
}
// equal(Map, Dict) and equal(Map, Map) follow semantic of equal(Dict, Dict)
func eq_Map_Dict(a reflect.Value, b Dict) bool {
if a.Len() != b.Len() {
return false
}
aKeyType := a.Type().Key()
ai := a.MapRange()
for ai.Next() {
k := ai.Key().Interface()
va := ai.Value().Interface()
vb, ok := b.Get_(k)
if !ok || !equal(va, vb) {
return false
}
}
eq := true
b.Iter()(func(k,vb any) bool {
xk := reflect.ValueOf(k)
if !xk.Type().AssignableTo(aKeyType) {
eq = false
return false
}
xva := a.MapIndex(xk)
if !(xva.IsValid() && equal(xva.Interface(), vb)) {
eq = false
return false
}
return true
})
return eq
}
func eq_Map_Map(a reflect.Value, b reflect.Value) bool {
if a.Len() != b.Len() {
return false
}
aKeyType := a.Type().Key()
bKeyType := b.Type().Key()
ai := a.MapRange()
for ai.Next() {
k := ai.Key().Interface() // NOTE xk != ai.Key() because that might have type any
xk := reflect.ValueOf(k) // while xk has type of particular contained value
va := ai.Value().Interface()
if !xk.Type().AssignableTo(bKeyType) {
return false
}
xvb := b.MapIndex(xk)
if !(xvb.IsValid() && equal(va, xvb.Interface())) {
return false
}
}
bi := b.MapRange()
for bi.Next() {
k := bi.Key().Interface() // see ^^^
xk := reflect.ValueOf(k)
vb := bi.Value().Interface()
if !xk.Type().AssignableTo(aKeyType) {
return false
}
xva := a.MapIndex(xk)
if !(xva.IsValid() && equal(xva.Interface(), vb)) {
return false
}
}
return true
}
// equality matrix. trivial elements
func eq_Int_Int (a int64, b int64) bool { return a == b }
func eq_Int_Float (a int64, b float64) bool { return float64(a) == b }
func eq_Int_Complex (a int64, b complex128) bool { return complex(float64(a), 0) == b }
func eq_Uint_Uint (a uint64, b uint64) bool { return a == b }
func eq_Uint_Float (a uint64, b float64) bool { return float64(a) == b }
func eq_Uint_Complex (a uint64, b complex128) bool { return complex(float64(a), 0) == b }
func eq_Float_Float (a float64, b float64) bool { return a == b }
func eq_Float_Complex (a float64, b complex128) bool { return complex(a, 0) == b }
func eq_Complex_Complex (a complex128, b complex128) bool { return a == b }
// ---- hash ----
// hash returns hash of x consistent with equality implemented by equal.
//
// equal(a,b) ⇒ hash(a) = hash(b)
//
// hash panics with "unhashable type: ..." if x is not allowed to be used as Dict key.
func hash(seed maphash.Seed, x any) uint64 {
// strings/bytes use standard hash of string
switch v := x.(type) {
case string: return maphash_String(seed, v)
case ByteString: return maphash_String(seed, string(v))
case Bytes: return maphash_String(seed, string(v))
}
// for everything else we implement custom hashing ourselves to match equal
var h maphash.Hash
h.SetSeed(seed)
hash_Uint := func(u uint64) {
var b [8]byte
binary.BigEndian.PutUint64(b[:], u)
h.Write(b[:])
}
hash_Int := func(i int64) {
hash_Uint(uint64(i))
}
hash_Float := func(f float64) {
// if float is in int range and is integer number - hash it as integer
i := int64(f)
f_ := float64(i)
if f_ == f {
hash_Int(i)
// else use raw float64 bytes representation for hashing
} else {
hash_Uint(math.Float64bits(f))
}
}
// numbers
r := reflect.ValueOf(x)
k := kindOf(x)
handled := true
switch k {
default:
handled = false
case kBool: hash_Int(bint(r.Bool()))
case kInt: hash_Int(r.Int())
case kUint: hash_Uint(r.Uint())
case kFloat: hash_Float(r.Float())
case kComplex:
c := r.Complex()
hash_Float(real(c))
if imag(c) != 0 {
hash_Float(imag(c))
}
case kBigInt:
b := x.(*big.Int)
switch {
case b.IsInt64(): hash_Int(b.Int64())
case b.IsUint64(): hash_Uint(b.Uint64())
default:
f, accuracy := bigInt_Float64(b)
if accuracy == big.Exact {
hash_Float(f)
} else {
h.WriteString("bigInt")
h.Write(b.Bytes())
}
}
// kSlice - skip
// kStruct - skip
case kPointer: hash_Uint(uint64(r.Elem().UnsafeAddr()))
}
if handled {
return h.Sum64()
}
// tuple
switch v := x.(type) {
case Tuple:
h.WriteString("tuple")
for _, item := range v {
hash_Uint(hash(seed, item))
}
return h.Sum64()
}
// structs (also covers None, Class, Call etc)
switch k {
case kStruct:
// our types that are handled specially by equal
switch x.(type) {
case Dict:
goto unhashable
}
typ := r.Type()
h.WriteString(typ.Name())
l := typ.NumField()
for i := 0; i < l; i++ {
f := r.Field(i)
// .Interface() is not allowed if the field is private.
// Work it around via unsafe. See eq_Struct_Struct for details.
ftyp := typ.Field(i)
if !ftyp.IsExported() {
if !f.CanAddr() {
// switch r to addressable copy
r_ := reflect.New(typ).Elem()
r_.Set(r)
r = r_
f = r.Field(i)
}
f = reflect.NewAt(ftyp.Type, f.Addr().UnsafePointer()).Elem()
}
hash_Uint(hash(seed, f.Interface()))
}
return h.Sum64()
}
unhashable:
panic(fmt.Sprintf("unhashable type: %T", x))
}
// ---- misc ----
// bint returns int corresponding to bool.
//
// true -> 1
// false -> 0
func bint(x bool) int64 {
if x {
return 1
}
return 0
}
package ogórek
import (
"fmt"
"hash/maphash"
"reflect"
"strings"
"testing"
)
// tStructWithPrivate is used by tests to verify handing of struct with private fields.
type tStructWithPrivate struct {
x, y any
}
// TestEqual verifies equal and hash.
func TestEqual(t *testing.T) {
// tEqualSet represents tested set of values:
// ∀ a ∈ tEqualSet:
// ∀ b ∈ tEqualSet ⇒ equal(a,b) = y
// ∀ c ∉ tEqualSet ⇒ equal(a,c) = n
//
// Intersection in between different tEqualSets is mostly empty: such
// intersections can contain elements only from all \ EqTransitive, i.e. only ByteString.
type tAllEqual []any
// E is shortcut to create tEqualSet
E := func(v ...any) tAllEqual { return tAllEqual(v) }
// D and M are shortcuts to create Dict and map[any]any
D := NewDictWithData
type M = map[any]any
// i1 and i1_ are two integer variables equal to 1 but with different address
// obj and obj_ are similar equal structures located at different memory regions
i1 := 1; i1_ := 1
obj := &Class{"a","b"}; obj_ := &Class{"a","b"}
// testv is vector of all test-cases
testv := []tAllEqual {
// numbers
E(int(0),
int64(0), int32(0), int16(0), int8(0),
uint64(0), uint32(0), uint16(0), uint8(0),
bigInt("0"),
false,
float32 (0), float64 (0),
complex64(0), complex128(0)),
E(int(1),
int64 (1), int32(1), int16(1), int8(1),
uint64(1), uint32(1), uint16(1), uint8(1),
bigInt("1"),
true,
float32 (1), float64 (1),
complex64(1), complex128(1)),
E(int(-1),
int64(-1), int32(-1), int16(-1), int8(-1),
// NOTE no uintX because they ≥ 0 only
bigInt("-1"),
// NOTE no bool because it ∈ {0,1}
float32 (-1), float64 (-1),
complex64(-1), complex128(-1)),
// intX/uintX different range
E(int(0xff),
int64(0xff), int32(0xff), int16(0xff), // int8(overflow),
uint64(0xff), uint32(0xff), uint16(0xff), // uint8(overflow),
bigInt("255"),
bigInt("255"), // two different *big.Int instances
float32 (0xff), float64 (0xff),
complex64(0xff), complex128(0xff)),
E(int(-0x80),
int64(-0x80), int32(-0x80), int16(-0x80), int8(-0x80),
//uint64(), uint32(), uint16(), uint8(), ≥ 0 only
bigInt("-128"),
float32 (-0x80), float64 (-0x80),
complex64(-0x80), complex128(-0x80)),
E(int(0xffff),
int64(0xffff), int32(0xffff), // int16(overflow), int8(overflow),
uint64(0xffff), uint32(0xffff), uint16(0xffff), // uint8(overflow),
bigInt("65535"),
float32 (0xffff), float64 (0xffff),
complex64(0xffff), complex128(0xffff)),
E(int(-0x8000),
int64(-0x8000), int32(-0x8000), int16(-0x8000), // int8(overflow),
//uint64(), uint32(), uint16(), uint8(), ≥ 0 only
bigInt("-32768"),
float32 (-0x8000), float64 (-0x8000),
complex64(-0x8000), complex128(-0x8000)),
E(int(0xffffffff),
int64(0xffffffff), // int32(overflow), int16(overflow), int8(overflow),
uint64(0xffffffff), uint32(0xffffffff), // uint16(overflow), uint8(overflow),
bigInt("4294967295"),
/* float32 (precision loss), */ float64 (0xffffffff),
/* complex64(precision loss), */ complex128(0xffffffff)),
E(int(-0x80000000),
int64(-0x80000000), int32(-0x80000000), // int16(overflow), int8(overflow),
//uint64(), uint32(), uint16(), uint8(), ≥ 0 only
bigInt("-2147483648"),
float32 (-0x80000000), float64 (-0x80000000),
complex64(-0x80000000), complex128(-0x80000000)),
E(// int(overflow),
// int64(overflow), int32(overflow), int16(overflow), int8(overflow),
uint64(0xffffffffffffffff), // uint32(overflow), uint16(overflow), uint8(overflow),
bigInt("18446744073709551615")),
// float32 (precision loss), float64 (precision loss),
// complex64(precision loss), complex128(precision loss)),
E(int(-0x8000000000000000),
int64(-0x8000000000000000), // int32(overflow), int16(overflow), int8(overflow),
//uint64(), uint32(), uint16(), uint8(), ≥ 0 only
bigInt("-9223372036854775808"),
float32 (-0x8000000000000000), float64 (-0x8000000000000000),
complex64(-0x8000000000000000), complex128(-0x8000000000000000)),
E(bigInt("1"+strings.Repeat("0",22)), float64(1e22), complex128(complex(1e22,0))),
E(complex64(complex(0,1)), complex128(complex(0,1))),
E(float64(1.25), float32(1.25), complex64(complex(1.25,0)), complex128(complex(1.25,0))),
// strings/bytes
E("", ByteString("")), E(ByteString(""), Bytes("")),
E("a", ByteString("a")), E(ByteString("a"), Bytes("a")),
E("мир", ByteString("мир")), E(ByteString("мир"), Bytes("мир")),
// none / empty tuple|list
E(None{}),
E(Tuple{}, []any{}),
// sequences
E([]int{}, []float32{}, []any{}, Tuple{}, [0]float64{}),
E([]int{1,2}, []float32{1,2}, []any{1,2}, Tuple{1,2}, [2]float64{1,2}),
E([]any{1,"a"}, Tuple{1,"a"}, [2]any{1,"a"}, Tuple{1,ByteString("a")}),
// Dict, map
E(D(),
M{}, map[int]bool{}),
E(D(1,bigInt("2")),
M{1:2.0}, map[int]int{1:2}),
E(D(1,"a"),
M{1:"a"}, map[int]string{1:"a"}),
E(D("a",1),
M{"a":1}),
E(D("a",1, None{},2),
M{"a":1, None{}:2}),
E(D("a",1, Bytes("a"),1),
M{"a":1, Bytes("a"):1}),
E(D("a",1, Bytes("a"),2),
M{"a":1, Bytes("a"):2}),
E(D("a",1), D(ByteString("a"),1)), E(D(ByteString("a"),1), D(Bytes("a"),1)),
E(D("a",1, Bytes("a"),1, ByteString("b"),2),
D(ByteString("a"),1, "b",2, Bytes("b"),2)),
// structs
E(Class{"mod","cls"}, Class{"mod","cls"}),
E(Call{Class{"mod","cls"}, Tuple{"a","b",3}},
Call{Class{"mod","cls"}, Tuple{ByteString("a"),"b",bigInt("3")}}),
E(Ref{1}, Ref{bigInt("1")}, Ref{1.0}),
E(tStructWithPrivate{"a",1}, tStructWithPrivate{ByteString("a"),bigInt("1")}),
E(tStructWithPrivate{"b",2}, tStructWithPrivate{"b",2.0}),
// pointers, as in builtin ==, are compared only by address
E(&i1), E(&i1_), E(&obj), E(&obj_),
// nil
E(nil),
}
// automatically test equality on Tuples/list from ^^^ data
testvAddSequences := func() {
l := len(testv)
for i := 0; i < l; i++ {
Ex := testv[i]
Ey := testv[(i+1)%l]
x0 := Ex[0]; x1 := Ex[1%len(Ex)]
y0 := Ey[0]; y1 := Ey[1%len(Ey)]
t1 := Tuple{x0,y0}; l1 := []any{x0,y0}
t2 := Tuple{x1,y1}; l2 := []any{x1,y1}
testv = append(testv, E(t1, t2, l1, l2))
}
}
testvAddSequences()
// and sequences of sequences
testvAddSequences()
// thash is used to invoke hash.
// if x is not hashable ok=false is returned instead of panic.
tseed := maphash.MakeSeed()
thash := func(x any) (h uint64, ok bool) {
defer func() {
r := recover()
if r != nil {
s, sok := r.(string)
if sok && strings.HasPrefix(s, "unhashable type: ") {
ok = false
h = 0
} else {
panic(r)
}
}
}()
return hash(tseed, x), true
}
// tequal is used to invoke equal.
// it automatically checks Go-extension, self-equal, symmetry and hash invariants:
//
// a==b ⇒ equal(a,b)
// equal(a,a) = y
// equal(a,b) = equal(b,a)
// equal(a,b) ⇒ hash(a) = hash(b)
tequal := func(a, b any) bool {
aa := equal(a, a)
bb := equal(b, b)
if !aa {
t.Errorf("not self-equal %T %#v", a,a)
}
if !bb {
t.Errorf("not self-equal %T %#v", b,b)
}
eq := equal(a, b)
qe := equal(b, a)
if eq != qe {
t.Errorf("equal not symmetric: %T %#v %T %#v; a == b: %v b == a: %v",
a,a, b,b, eq, qe)
}
ah, ahOk := thash(a)
bh, bhOk := thash(b)
if eq && ahOk && bhOk && !(ah == bh) {
t.Errorf("hash different of equal %T %#v hash:%x %T %#v hash:%x",
a,a,ah, b,b,bh)
}
goeq := false
func() {
// a == b can trigger "comparing uncomparable type ..."
// even if reflect reports both types as comparable
// (see mapTryAssign for details)
defer func() {
recover()
}()
goeq = (a == b)
}()
if goeq && !eq {
t.Errorf("equal is not extension of == %T %#v %T %#v",
a,a, b,b)
}
return eq
}
// EHas returns whether x ∈ E.
EHas := func(E tAllEqual, x any) bool {
for _, a := range E {
if tequal(a, x) {
return true
}
}
return false;
}
// do the tests
for i, E1 := range testv {
// ∀ a,b ∈ tEqualSet ⇒ equal(a,b) = y
for _, a := range E1 {
for _, b := range E1 {
if !tequal(a,b) {
t.Errorf("not equal %T %#v %T %#v", a,a, b,b)
}
}
}
// ∀ a ∈ tEqualSet
// ∀ c ∉ tEqualSet ⇒ equal(a,c) = n
for j, E2 := range testv {
if j == i {
continue
}
for _, a := range E1 {
for _, c := range E2 {
if EHas(E1, c) {
continue
}
if tequal(a,c) {
t.Errorf("equal %T %#v %T %#v", a,a, c,c)
}
}
}
}
}
}
// TestDict verifies Dict.
func TestDict(t *testing.T) {
d := NewDict()
// assertData asserts that d has data exactly as specified by provided key,value pairs.
assertData := func(kvok ...any) {
t.Helper()
if len(kvok) % 2 != 0 {
panic("kvok % 2 != 0")
}
lok := len(kvok)/2
kvokGet := func(k any) (any, bool) {
t.Helper()
for i := 0; i < lok; i++ {
kok := kvok[2*i]
vok := kvok[2*i+1]
if reflect.TypeOf(k) == reflect.TypeOf(kok) &&
equal(k, kok) {
return vok, true
}
}
return nil, false
}
bad := false
badf := func(format string, argv ...any) {
t.Helper()
bad = true
t.Errorf(format, argv...)
}
l := d.Len()
if l != lok {
badf("len: have: %d want: %d", l, lok)
}
d.Iter()(func(k,v any) bool {
t.Helper()
vok, ok := kvokGet(k)
if !ok {
badf("unexpected key %#v", k)
}
if v != vok {
badf("key %T %#v -> value %#T %#v ; want %T %#v", k,k, v,v, vok,vok)
}
return true
})
if bad {
t.Fatalf("\nd: %#v\nkvok: %#v", d, kvok)
}
}
// assertGet asserts that d.Get(k) results in exactly vok or any element from vokExtra.
assertGet := func(k any, vok any, vokExtra ...any) {
t.Helper()
v := d.Get(k)
if v == vok {
return
}
for _, eok := range vokExtra {
if v == eok {
return
}
}
emsg := fmt.Sprintf("get %#v: have: %#v want: %#v", k, v, vok)
for _, eok := range vokExtra {
emsg += fmt.Sprintf(" ∪ %#v", eok)
}
emsg += fmt.Sprintf("\nd: %#v", d)
t.Fatal(emsg)
}
// numbers
assertData()
d.Set(1, "x")
assertData(1,"x")
assertGet(1, "x")
assertGet(1.0, "x")
assertGet(bigInt("1"), "x")
assertGet(complex(1,0), "x")
d.Del(7)
assertData(1,"x")
assertGet(1, "x")
assertGet(1.0, "x")
assertGet(bigInt("1"), "x")
assertGet(complex(1,0), "x")
d.Set(2.5, "y")
assertData(1,"x", 2.5,"y")
assertGet(1, "x")
assertGet(1.0, "x")
assertGet(bigInt("1"), "x")
assertGet(complex(1,0), "x")
assertGet(2, nil)
assertGet(2.5, "y")
assertGet(bigInt("2"), nil)
assertGet(complex(2.5,0), "y")
d.Del(1)
assertData(2.5,"y")
assertGet(1, nil)
assertGet(1.0, nil)
assertGet(bigInt("1"), nil)
assertGet(complex(1,0), nil)
assertGet(2, nil)
assertGet(2.5, "y")
assertGet(bigInt("2"), nil)
assertGet(complex(2.5,0), "y")
d.Del(2.5)
assertData()
assertGet(1, nil)
assertGet(1.0, nil)
assertGet(bigInt("1"), nil)
assertGet(complex(1,0), nil)
assertGet(2, nil)
assertGet(2.5, nil)
assertGet(bigInt("2"), nil)
assertGet(complex(2.5,0), nil)
// strings/bytes
assertData()
assertGet("abc", nil)
d.Set("abc", "a")
assertData("abc","a")
assertGet("abc", "a")
assertGet(Bytes("abc"), nil)
assertGet(ByteString("abc"), "a")
d.Set(Bytes("abc"), "b")
assertData("abc","a", Bytes("abc"),"b")
assertGet("abc", "a")
assertGet(Bytes("abc"), "b")
assertGet(ByteString("abc"), "a", "b")
d.Set(ByteString("abc"), "c")
assertData(ByteString("abc"),"c")
assertGet("abc", "c")
assertGet(Bytes("abc"), "c")
assertGet(ByteString("abc"), "c")
d.Del("abc")
assertData()
assertGet("abc", nil)
assertGet(Bytes("abc"), nil)
assertGet(ByteString("abc"), nil)
d.Set("abc", "a")
assertData("abc","a")
assertGet("abc", "a")
assertGet(Bytes("abc"), nil)
assertGet(ByteString("abc"), "a")
d.Set(Bytes("abc"), "b")
assertData("abc","a", Bytes("abc"),"b")
assertGet("abc", "a")
assertGet(Bytes("abc"), "b")
assertGet(ByteString("abc"), "a", "b")
d.Del(ByteString("abc"))
assertData()
assertGet("abc", nil)
assertGet(Bytes("abc"), nil)
assertGet(ByteString("abc"), nil)
// None, tuple
assertData()
d.Set(None{}, "n")
assertData(None{},"n")
assertGet(None{}, "n")
assertGet(Tuple{}, nil)
d.Set(Tuple{}, "t")
assertData(None{},"n", Tuple{},"t")
assertGet(None{}, "n")
assertGet(Tuple{}, "t")
d.Set(Tuple{1,2,"a"}, "t12a")
assertData(None{},"n", Tuple{},"t", Tuple{1,2,"a"},"t12a")
assertGet(None{}, "n")
assertGet(Tuple{}, "t")
assertGet(Tuple{1,2}, nil)
assertGet(Tuple{1,2,"a"}, "t12a")
assertGet(Tuple{1,2,Bytes("a")}, nil)
assertGet(Tuple{1,2,ByteString("a")}, "t12a")
d.Set(Tuple{1,2,Bytes("a")}, "t12b")
assertData(None{},"n", Tuple{},"t", Tuple{1,2,"a"},"t12a", Tuple{1,2,Bytes("a")},"t12b")
assertGet(None{}, "n")
assertGet(Tuple{}, "t")
assertGet(Tuple{1,2}, nil)
assertGet(Tuple{1,2,"a"}, "t12a")
assertGet(Tuple{1,2,Bytes("a")}, "t12b")
assertGet(Tuple{1,2,ByteString("a")}, "t12a", "t12b")
d.Set(Tuple{1,2,ByteString("a")}, "t12c")
assertData(None{},"n", Tuple{},"t", Tuple{1,2,ByteString("a")},"t12c")
assertGet(None{}, "n")
assertGet(Tuple{}, "t")
assertGet(Tuple{1,2}, nil)
assertGet(Tuple{1,2,"a"}, "t12c")
assertGet(Tuple{1,2,Bytes("a")}, "t12c")
assertGet(Tuple{1,2,ByteString("a")}, "t12c")
d.Set(Tuple{1,2,"a"}, "t12a")
assertData(None{},"n", Tuple{},"t", Tuple{1,2,"a"},"t12a")
assertGet(None{}, "n")
assertGet(Tuple{}, "t")
assertGet(Tuple{1,2}, nil)
assertGet(Tuple{1,2,"a"}, "t12a")
assertGet(Tuple{1,2,Bytes("a")}, nil)
assertGet(Tuple{1,2,ByteString("a")}, "t12a")
d.Set(Tuple{1,2,Bytes("a")}, "t12b")
assertData(None{},"n", Tuple{},"t", Tuple{1,2,"a"},"t12a", Tuple{1,2,Bytes("a")},"t12b")
assertGet(None{}, "n")
assertGet(Tuple{}, "t")
assertGet(Tuple{1,2}, nil)
assertGet(Tuple{1,2,"a"}, "t12a")
assertGet(Tuple{1,2,Bytes("a")}, "t12b")
assertGet(Tuple{1,2,ByteString("a")}, "t12a", "t12b")
d.Del(Tuple{1,2,ByteString("a")})
assertData(None{},"n", Tuple{},"t")
assertGet(None{}, "n")
assertGet(Tuple{}, "t")
assertGet(Tuple{1,2}, nil)
assertGet(Tuple{1,2,"a"}, nil)
assertGet(Tuple{1,2,Bytes("a")}, nil)
assertGet(Tuple{1,2,ByteString("a")}, nil)
// structs
d = NewDict()
d.Set(Class{"a","b"}, 1)
d.Set(Class{"c","d"}, 2)
d.Set(Ref{"a"}, 3)
d.Set(tStructWithPrivate{"x","y"}, 4)
assertData(Class{"a","b"},1, Class{"c","d"},2, Ref{"a"},3, tStructWithPrivate{"x","y"},4)
assertGet(Class{"a","b"}, 1)
assertGet(Class{"c","d"}, 2)
assertGet(Class{"x","y"}, nil)
assertGet(Ref{"a"}, 3)
assertGet(Ref{"x"}, nil)
assertGet(tStructWithPrivate{"x","y"}, 4)
assertGet(tStructWithPrivate{"p","q"}, nil)
// pointers
i := 1
j := 1
k := 1
x := Class{"a","b"}
y := Class{"a","b"}
z := Class{"a","b"}
d = NewDict()
d.Set(&i, 1)
d.Set(&j, 2)
d.Set(&x, 3)
d.Set(&y, 4)
assertData(&i,1, &j,2, &x,3, &y,4)
assertGet(&i, 1)
assertGet(&j, 2)
assertGet(&k, nil)
assertGet(&x, 3)
assertGet(&y, 4)
assertGet(&z, nil)
// NewDictWithSizeHint
d = NewDictWithSizeHint(100)
assertData()
assertGet(1, nil)
assertGet(2, nil)
assertGet("a", nil)
assertGet("b", nil)
// NewDictWithData
d = NewDictWithData("a",1, 2,"b")
assertData("a",1, 2,"b")
assertGet(1, nil)
assertGet(2, "b")
assertGet("a", 1)
assertGet("b", nil)
// unhashable types
vbad := []any{
[]any{},
[]any{1,2,3},
[]int{},
[]int{1,2,3},
NewDict(),
map[any]any{},
map[int]bool{},
Ref{[]any{}},
tStructWithPrivate{1,[]any{}},
tStructWithPrivate{[]any{},1},
tStructWithPrivate{[]any{},[]any{}},
}
assertPanics := func(subj any, errPrefix string, f func()) {
t.Helper()
defer func() {
t.Helper()
r := recover()
if r == nil {
t.Errorf("%#v: no panic", subj)
return
}
s, ok := r.(string)
if ok && strings.HasPrefix(s, errPrefix) {
// ok
} else {
panic(r)
}
}()
f()
}
for _, k := range vbad {
assertUnhashable := func(f func()) {
t.Helper()
assertPanics(k, "unhashable type: ", f)
}
assertUnhashable(func() { d.Get(k) })
assertUnhashable(func() { d.Set(k, 1) })
assertUnhashable(func() { d.Del(k) })
assertUnhashable(func() { NewDictWithData(k,1) })
}
// = ~nil
d = Dict{}
assertData()
assertGet(1, nil)
assertGet(2, nil)
assertGet("a", nil)
assertGet("b", nil)
d.Del(1)
assertData()
assertGet(1, nil)
assertGet(2, nil)
assertGet("a", nil)
assertGet("b", nil)
assertPanics("nil.Set", "Set called on nil map", func() { d.Set(1, "x") })
}
// benchmarks for map and Dict compare them from performance point of view.
func BenchmarkMapGet(b *testing.B) {
m := map[any]any{}
for i := 0; i < 100; i++ {
m[i] = i
}
m["abc"] = 777
b.Run("string", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = m["abc"]
}
})
b.Run("int", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = m[77]
}
})
}
func BenchmarkDictGet(b *testing.B) {
d := NewDict()
for i := 0; i < 100; i++ {
d.Set(i, i)
}
d.Set("abc", 777)
b.Run("string", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = d.Get("abc")
}
})
b.Run("int", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = d.Get(77)
}
})
}
......@@ -22,12 +22,29 @@
// long ↔ *big.Int
// float ↔ float64
// float ← floatX
// list ↔ []interface{}
// list ↔ []any
// tuple ↔ ogórek.Tuple
// dict ↔ map[interface{}]interface{}
//
//
// For strings there are two modes. In the first, default, mode both py2/py3
// For dicts there are two modes. In the first, default, mode Python dicts are
// decoded into standard Go map. This mode tries to use builtin Go type, but
// cannot mirror py behaviour fully because e.g. int(1), big.Int(1) and
// float64(1.0) are all treated as different keys by Go, while Python treats
// them as being equal. It also does not support decoding dicts with tuple
// used in keys:
//
// dict ↔ map[any]any PyDict=n mode, default
// ← ogórek.Dict
//
// With PyDict=y mode, however, Python dicts are decoded as ogórek.Dict which
// mirrors behaviour of Python dict with respect to keys equality, and with
// respect to which types are allowed to be used as keys.
//
// dict ↔ ogórek.Dict PyDict=y mode
// ← map[any]any
//
//
// For strings there are also two modes. In the first, default, mode both py2/py3
// str and py2 unicode are decoded into string with py2 str being considered
// as UTF-8 encoded. Correspondingly for protocol ≤ 2 Go string is encoded as
// UTF-8 encoded py2 str, and for protocol ≥ 3 as py3 str / py2 unicode.
......@@ -155,6 +172,11 @@
// Using the helpers fits into Python3 strings/bytes model but also allows to
// handle the data generated from under Python2.
//
// Similarly Dict considers ByteString to be equal to both string and Bytes
// with the same underlying content. This allows programs to access Dict via
// string/bytes keys following Python3 model, while still being able to handle
// dictionaries generated from under Python2.
//
//
// --------
//
......
......@@ -505,6 +505,41 @@ func (e *Encoder) encodeMap(m reflect.Value) error {
return e.emit(opDict)
}
func (e *Encoder) encodeDict(d Dict) error {
l := d.Len()
// protocol >= 1: ø dict -> EMPTY_DICT
if e.config.Protocol >= 1 && l == 0 {
return e.emit(opEmptyDict)
}
// MARK + ... + DICT
// TODO cycles + sort keys (see encodeMap for details)
err := e.emit(opMark)
if err != nil {
return err
}
d.Iter()(func(k, v any) bool {
err = e.encode(reflectValueOf(k))
if err != nil {
return false
}
err = e.encode(reflectValueOf(v))
if err != nil {
return false
}
return true
})
if err != nil {
return err
}
return e.emit(opDict)
}
func (e *Encoder) encodeCall(v *Call) error {
err := e.encodeClass(&v.Callable)
if err != nil {
......@@ -578,6 +613,8 @@ func (e *Encoder) encodeStruct(st reflect.Value) error {
return e.encodeRef(&v)
case big.Int:
return e.encodeLong(&v)
case Dict:
return e.encodeDict(v)
}
structTags := getStructTags(st)
......
......@@ -5,24 +5,27 @@ package ogórek
import (
"bytes"
"fmt"
"reflect"
)
func Fuzz(data []byte) int {
f1 := fuzz(data, false)
f2 := fuzz(data, true)
f := 0
f += fuzz(data, false, false)
f += fuzz(data, false, true)
f += fuzz(data, true, false)
f += fuzz(data, true, true)
f := f1+f2
if f > 1 {
f = 1
}
return f
}
func fuzz(data []byte, strictUnicode bool) int {
func fuzz(data []byte, pyDict, strictUnicode bool) int {
// obj = decode(data) - this tests things like stack overflow in Decoder
buf := bytes.NewBuffer(data)
dec := NewDecoderWithConfig(buf, &DecoderConfig{
PyDict: pyDict,
StrictUnicode: strictUnicode,
})
obj, err := dec.Decode()
......@@ -38,7 +41,7 @@ func fuzz(data []byte, strictUnicode bool) int {
// because obj - as we got it as decoding from input - is known not to
// contain arbitrary Go structs.
for proto := 0; proto <= highestProtocol; proto++ {
subj := fmt.Sprintf("strictUnicode %v: protocol %d", strictUnicode, proto)
subj := fmt.Sprintf("pyDict %v: strictUnicode %v: protocol %d", pyDict, strictUnicode, proto)
buf.Reset()
enc := NewEncoderWithConfig(buf, &EncoderConfig{
......@@ -67,6 +70,7 @@ func fuzz(data []byte, strictUnicode bool) int {
encoded := buf.String()
dec = NewDecoderWithConfig(bytes.NewBufferString(encoded), &DecoderConfig{
PyDict: pyDict,
StrictUnicode: strictUnicode,
})
obj2, err := dec.Decode()
......@@ -75,7 +79,7 @@ func fuzz(data []byte, strictUnicode bool) int {
panic(fmt.Sprintf("%s: decode back error: %s\npickle: %q", subj, err, encoded))
}
if !reflect.DeepEqual(obj, obj2) {
if !deepEqual(obj, obj2) {
panic(fmt.Sprintf("%s: decode·encode != identity:\nhave: %#v\nwant: %#v", subj, obj2, obj))
}
}
......
module github.com/kisielk/og-rek
go 1.12
go 1.18
require github.com/aristanetworks/gomap v0.0.0-20230726210543-f4e41046dced
require golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 // indirect
github.com/aristanetworks/gomap v0.0.0-20230726210543-f4e41046dced h1:HxlRMDx/VeRqzj3nvqX9k4tjeBcEIkoNHDJPsS389hs=
github.com/aristanetworks/gomap v0.0.0-20230726210543-f4e41046dced/go.mod h1:p7lmI+ecoe1RTyD11SPXWsSQ3H+pJ4cp5y7vtKW4QdM=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
......@@ -189,6 +189,11 @@ type DecoderConfig struct {
// decoded into ByteString in this mode. See StrictUnicode mode
// documentation in top-level package overview for details.
StrictUnicode bool
// PyDict, when true, requests to decode Python dicts as ogórek.Dict
// instead of builtin map. See PyDict mode documentation in top-level
// package overview for details.
PyDict bool
}
// NewDecoder returns a new Decoder with the default configuration.
......@@ -1009,29 +1014,75 @@ func mapTryAssign(m map[interface{}]interface{}, key, value interface{}) (ok boo
return
}
// dictTryAssign is like mapTryAssign but for Dict.
func dictTryAssign(d Dict, key, value interface{}) (ok bool) {
defer func() {
if r := recover(); r != nil {
ok = false
}
}()
d.Set(key, value)
ok = true
return
}
func (d *Decoder) loadDict() error {
k, err := d.marker()
if err != nil {
return err
}
m := make(map[interface{}]interface{}, 0)
items := d.stack[k+1:]
if len(items) % 2 != 0 {
return fmt.Errorf("pickle: loadDict: odd # of elements")
}
var m interface{}
if d.config.PyDict {
m, err = d.loadDictDict(items)
} else {
m, err = d.loadDictMap(items)
}
if err != nil {
return err
}
d.stack = append(d.stack[:k], m)
return nil
}
func (d *Decoder) loadDictMap(items []interface{}) (map[interface{}]interface{}, error) {
m := make(map[interface{}]interface{}, len(items)/2)
for i := 0; i < len(items); i += 2 {
key := items[i]
if !mapTryAssign(m, key, items[i+1]) {
return fmt.Errorf("pickle: loadDict: invalid key type %T", key)
return nil, fmt.Errorf("pickle: loadDict: map: invalid key type %T", key)
}
}
d.stack = append(d.stack[:k], m)
return nil
return m, nil
}
func (d *Decoder) loadDictDict(items []interface{}) (Dict, error) {
m := NewDictWithSizeHint(len(items)/2)
for i := 0; i < len(items); i += 2 {
key := items[i]
if !dictTryAssign(m, key, items[i+1]) {
return Dict{}, fmt.Errorf("pickle: loadDict: Dict: invalid key type %T", key)
}
}
return m, nil
}
func (d *Decoder) loadEmptyDict() error {
m := make(map[interface{}]interface{}, 0)
var m interface{}
if d.config.PyDict {
m = NewDict()
} else {
m = make(map[interface{}]interface{}, 0)
}
d.push(m)
return nil
}
......@@ -1218,10 +1269,14 @@ func (d *Decoder) loadSetItem() error {
switch m := m.(type) {
case map[interface{}]interface{}:
if !mapTryAssign(m, k, v) {
return fmt.Errorf("pickle: loadSetItem: invalid key type %T", k)
return fmt.Errorf("pickle: loadSetItem: map: invalid key type %T", k)
}
case Dict:
if !dictTryAssign(m, k, v) {
return fmt.Errorf("pickle: loadSetItem: Dict: invalid key type %T", k)
}
default:
return fmt.Errorf("pickle: loadSetItem: expected a map, got %T", m)
return fmt.Errorf("pickle: loadSetItem: expected a map or Dict, got %T", m)
}
return nil
}
......@@ -1234,23 +1289,31 @@ func (d *Decoder) loadSetItems() error {
if k < 1 {
return errStackUnderflow
}
if (len(d.stack) - (k + 1)) % 2 != 0 {
return fmt.Errorf("pickle: loadSetItems: odd # of elements")
}
l := d.stack[k-1]
switch m := l.(type) {
case map[interface{}]interface{}:
if (len(d.stack) - (k + 1)) % 2 != 0 {
return fmt.Errorf("pickle: loadSetItems: odd # of elements")
}
for i := k + 1; i < len(d.stack); i += 2 {
key := d.stack[i]
if !mapTryAssign(m, key, d.stack[i+1]) {
return fmt.Errorf("pickle: loadSetItems: invalid key type %T", key)
return fmt.Errorf("pickle: loadSetItems: map: invalid key type %T", key)
}
}
d.stack = append(d.stack[:k-1], m)
case Dict:
for i := k + 1; i < len(d.stack); i += 2 {
key := d.stack[i]
if !dictTryAssign(m, key, d.stack[i+1]) {
return fmt.Errorf("pickle: loadSetItems: Dict: invalid key type %T", key)
}
}
default:
return fmt.Errorf("pickle: loadSetItems: expected a map, got %T", m)
return fmt.Errorf("pickle: loadSetItems: expected a map or Dict, got %T", m)
}
d.stack = append(d.stack[:k-1], l)
return nil
}
......
......@@ -7,7 +7,6 @@ import (
"fmt"
"io"
"math/big"
"reflect"
"strconv"
"strings"
"testing"
......@@ -86,6 +85,9 @@ type TestEntry struct {
strictUnicodeN bool // whether to test with StrictUnicode=n while decoding/encoding
strictUnicodeY bool // whether to test with StrictUnicode=y while decoding/encoding
pyDictN bool // whether to test with PyDict=n while decoding/encoding
pyDictY bool // ----//---- PyDict=y
}
// X, I, P0, P1, P* form a language to describe decode/encode tests:
......@@ -105,7 +107,8 @@ type TestEntry struct {
// the entry is tested under both StrictUnicode=n and StrictUnicode=y modes.
func X(name string, object interface{}, picklev ...TestPickle) TestEntry {
return TestEntry{name: name, objectIn: object, objectOut: object, picklev: picklev,
strictUnicodeN: true, strictUnicodeY: true}
strictUnicodeN: true, strictUnicodeY: true,
pyDictN: true, pyDictY: true}
}
// Xuauto is syntactic sugar to prepare one TestEntry that is tested only under StrictUnicode=n mode.
......@@ -122,6 +125,38 @@ func Xustrict(name string, object interface{}, picklev ...TestPickle) TestEntry
return x
}
// Xdgo is syntactic sugar to prepare one TestEntry that is tested only under PyDict=n mode.
func Xdgo(name string, object interface{}, picklev ...TestPickle) TestEntry {
x := X(name, object, picklev...)
x.pyDictY = false
return x
}
// Xdpy is syntactic sugar to prepare one TestEntry that is tested only under PyDict=y mode.
func Xdpy(name string, object interface{}, picklev ...TestPickle) TestEntry {
x := X(name, object, picklev...)
x.pyDictN = false
return x
}
// Xuauto_dgo is syntactic sugar to prepare one TestEntry that is tested only
// under StrictUnicode=n ^ pyDict=n mode.
func Xuauto_dgo(name string, object interface{}, picklev ...TestPickle) TestEntry {
x := X(name, object, picklev...)
x.strictUnicodeY = false
x.pyDictY = false
return x
}
// Xuauto_dpy is syntactic sugar to prepare one TestEntry that is tested only
// under StrictUnicode=n ^ pyDict=y mode.
func Xuauto_dpy(name string, object interface{}, picklev ...TestPickle) TestEntry {
x := X(name, object, picklev...)
x.strictUnicodeY = false
x.pyDictN = false
return x
}
// Xloosy is syntactic sugar to prepare one TestEntry with loosy encoding.
func Xloosy(name string, objectIn, objectOut interface{}, picklev ...TestPickle) TestEntry {
x := X(name, objectIn, picklev...)
......@@ -129,9 +164,9 @@ func Xloosy(name string, objectIn, objectOut interface{}, picklev ...TestPickle)
return x
}
// Xloosy_uauto is like Xuauto but for Xloosy.
func Xloosy_uauto(name string, objectIn, objectOut interface{}, picklev ...TestPickle) TestEntry {
x := Xuauto(name, objectIn, picklev...)
// Xloosy_uauto_dgo is like Xuauto_dgo but for Xloosy.
func Xloosy_uauto_dgo(name string, objectIn, objectOut interface{}, picklev ...TestPickle) TestEntry {
x := Xuauto_dgo(name, objectIn, picklev...)
x.objectOut = objectOut
return x
}
......@@ -416,19 +451,46 @@ var tests = []TestEntry{
// bytearray(text, encoding); GLOBAL + BINUNICODE + TUPLE + REDUCE
I("c__builtin__\nbytearray\nq\x00(X\x13\x00\x00\x00hello\n\xc3\x90\xc2\xbc\xc3\x90\xc2\xb8\xc3\x91\xc2\x80\x01q\x01X\x07\x00\x00\x00latin-1q\x02tq\x03Rq\x04.")),
// dicts in default PyDict=n mode
Xdgo("dict({})", make(map[interface{}]interface{}),
P0("(d."), // MARK + DICT
P1_("}."), // EMPTY_DICT
I("(dp0\n.")),
Xuauto_dgo("dict({'a': '1'})", map[interface{}]interface{}{"a": "1"},
P0("(S\"a\"\nS\"1\"\nd."), // MARK + STRING + DICT
P12("(U\x01aU\x011d."), // MARK + SHORT_BINSTRING + DICT
P3("(X\x01\x00\x00\x00aX\x01\x00\x00\x001d."), // MARK + BINUNICODE + DICT
P4_("(\x8c\x01a\x8c\x011d.")), // MARK + SHORT_BINUNICODE + DICT
Xuauto_dgo("dict({'a': '1', 'b': '2'})", map[interface{}]interface{}{"a": "1", "b": "2"},
// map iteration order is not stable - test only decoding
I("(S\"a\"\nS\"1\"\nS\"b\"\nS\"2\"\nd."), // P0: MARK + STRING + DICT
I("(U\x01aU\x011U\x01bU\x012d."), // P12: MARK + SHORT_BINSTRING + DICT
// P3: MARK + BINUNICODE + DICT
I("(X\x01\x00\x00\x00aX\x01\x00\x00\x001X\x01\x00\x00\x00bX\x01\x00\x00\x002d."),
I("(\x8c\x01a\x8c\x011\x8c\x01b\x8c\x012d."), // P4_: MARK + SHORT_BINUNICODE + DICT
I("(dS'a'\nS'1'\nsS'b'\nS'2'\ns."), // MARK + DICT + STRING + SETITEM
I("}(U\x01aU\x011U\x01bU\x012u."), // EMPTY_DICT + MARK + SHORT_BINSTRING + SETITEMS
I("(dp0\nS'a'\np1\nS'1'\np2\nsS'b'\np3\nS'2'\np4\ns.")),
// dicts in PyDict=y mode
X("dict({})", make(map[interface{}]interface{}),
Xdpy("dict({})", NewDict(),
P0("(d."), // MARK + DICT
P1_("}."), // EMPTY_DICT
I("(dp0\n.")),
Xuauto("dict({'a': '1'})", map[interface{}]interface{}{"a": "1"},
Xuauto_dpy("dict({'a': '1'})", NewDictWithData("a","1"),
P0("(S\"a\"\nS\"1\"\nd."), // MARK + STRING + DICT
P12("(U\x01aU\x011d."), // MARK + SHORT_BINSTRING + DICT
P3("(X\x01\x00\x00\x00aX\x01\x00\x00\x001d."), // MARK + BINUNICODE + DICT
P4_("(\x8c\x01a\x8c\x011d.")), // MARK + SHORT_BINUNICODE + DICT
Xuauto("dict({'a': '1', 'b': '2'})", map[interface{}]interface{}{"a": "1", "b": "2"},
Xuauto_dpy("dict({'a': '1', 'b': '2'})", NewDictWithData("a","1", "b","2"),
// map iteration order is not stable - test only decoding
I("(S\"a\"\nS\"1\"\nS\"b\"\nS\"2\"\nd."), // P0: MARK + STRING + DICT
I("(U\x01aU\x011U\x01bU\x012d."), // P12: MARK + SHORT_BINSTRING + DICT
......@@ -441,6 +503,21 @@ var tests = []TestEntry{
I("}(U\x01aU\x011U\x01bU\x012u."), // EMPTY_DICT + MARK + SHORT_BINSTRING + SETITEMS
I("(dp0\nS'a'\np1\nS'1'\np2\nsS'b'\np3\nS'2'\np4\ns.")),
Xdpy("dict({123L: 0})", NewDictWithData(bigInt("123"), int64(0)),
P0("(L123L\nI0\nd."), // MARK + LONG + INT + DICT
P1("(L123L\nK\x00d."), // MARK + LONG + BININT1 + DICT
I("(\x8a\x01{K\x00d.")), // MARK + LONG1 + BININT1 + DICT
Xdpy("dict(tuple(): 0)", NewDictWithData(Tuple{}, int64(0)),
P0("((tI0\nd."), // MARK + MARK + TUPLE + INT + DICT
P1_("()K\x00d.")), // MARK + EMPTY_TUPLE + BININT1 + DICT
Xdpy("dict(tuple(1,2): 0)", NewDictWithData(Tuple{int64(1), int64(2)}, int64(0)),
P0("((I1\nI2\ntI0\nd."), // MARK + MARK + INT + INT + TUPLE + INT + DICT
P1("((K\x01K\x02tK\x00d."), // MARK + MARK + BININT1 + BININT1 + TUPLE + BININT1 + DICT
P2_("(K\x01K\x02\x86K\x00d.")), // MARK + BININT1 + BININT1 + TUPLE2 + BININT1 + DICT
Xuauto("foo.bar # global", Class{Module: "foo", Name: "bar"},
P0123("cfoo\nbar\n."), // GLOBAL
P4_("\x8c\x03foo\x8c\x03bar\x93."), // SHORT_BINUNICODE + STACK_GLOBAL
......@@ -480,9 +557,9 @@ var tests = []TestEntry{
X("LONG_BINPUT", []interface{}{int64(17)},
I("(lr0000I17\na.")),
Xuauto("graphite message1", graphiteObject1, graphitePickle1),
Xuauto("graphite message2", graphiteObject2, graphitePickle2),
Xuauto("graphite message3", graphiteObject3, graphitePickle3),
Xuauto_dgo("graphite message1", graphiteObject1, graphitePickle1),
Xuauto_dgo("graphite message2", graphiteObject2, graphitePickle2),
Xuauto_dgo("graphite message3", graphiteObject3, graphitePickle3),
Xuauto("too long line", longLine, I("V" + longLine + "\n.")),
// opcodes from protocol 4
......@@ -492,7 +569,7 @@ var tests = []TestEntry{
// loosy encode: decoding back gives another object.
// the only case where ogórek encoding is loosy is for Go struct types.
Xloosy_uauto("[]ogórek.foo{\"Qux\", 4}", []foo{{"Qux", 4}},
Xloosy_uauto_dgo("[]ogórek.foo{\"Qux\", 4}", []foo{{"Qux", 4}},
[]interface{}{map[interface{}]interface{}{"Foo": "Qux", "Bar": int64(4)}},
// MARK + STRING + INT + DICT + LIST
......@@ -519,9 +596,16 @@ type foo struct {
// protocol prefix is always automatically prepended and is always concrete.
var protoPrefixTemplate = string([]byte{opProto, 0xff})
// TestDecode verifies ogórek decoder.
func TestDecode(t *testing.T) {
for _, test := range tests {
// WithEachMode runs f under all decoding/encoding modes covered by test entry.
func (test TestEntry) WithEachMode(t *testing.T, f func(t *testing.T, decConfig DecoderConfig, encConfig EncoderConfig)) {
for _, pyDict := range []bool{false, true} {
if pyDict && !test.pyDictY {
continue
}
if !pyDict && !test.pyDictN {
continue
}
for _, strictUnicode := range []bool{false, true} {
if strictUnicode && !test.strictUnicodeY {
continue
......@@ -529,8 +613,28 @@ func TestDecode(t *testing.T) {
if !strictUnicode && !test.strictUnicodeN {
continue
}
testname := fmt.Sprintf("%s/StrictUnicode=%s", test.name, yn(strictUnicode))
t.Run(fmt.Sprintf("%s/PyDict=%s/StrictUnicode=%s", test.name, yn(pyDict), yn(strictUnicode)),
func(t *testing.T) {
decConfig := DecoderConfig{
PyDict: pyDict,
StrictUnicode: strictUnicode,
}
encConfig := EncoderConfig{
// no PyDict setting for encoder
StrictUnicode: strictUnicode,
}
f(t, decConfig, encConfig)
})
}
}
}
// TestDecode verifies ogórek decoder.
func TestDecode(t *testing.T) {
for _, test := range tests {
test.WithEachMode(t, func(t *testing.T, decConfig DecoderConfig, encConfig EncoderConfig) {
for _, pickle := range test.picklev {
if pickle.err != nil {
continue
......@@ -543,32 +647,24 @@ func TestDecode(t *testing.T) {
data := string([]byte{opProto, byte(proto)}) +
pickle.data[len(protoPrefixTemplate):]
t.Run(fmt.Sprintf("%s/%q/proto=%d", testname, data, proto), func(t *testing.T) {
testDecode(t, strictUnicode, test.objectOut, data)
t.Run(fmt.Sprintf("%q/proto=%d", data, proto), func(t *testing.T) {
testDecode(t, decConfig, test.objectOut, data)
})
}
} else {
t.Run(fmt.Sprintf("%s/%q", testname, pickle.data), func(t *testing.T) {
testDecode(t, strictUnicode, test.objectOut, pickle.data)
t.Run(fmt.Sprintf("%q", pickle.data), func(t *testing.T) {
testDecode(t, decConfig, test.objectOut, pickle.data)
})
}
}
}
})
}
}
// TestEncode verifies ogórek encoder.
func TestEncode(t *testing.T) {
for _, test := range tests {
for _, strictUnicode := range []bool{false, true} {
if strictUnicode && !test.strictUnicodeY {
continue
}
if !strictUnicode && !test.strictUnicodeN {
continue
}
testname := fmt.Sprintf("%s/StrictUnicode=%s", test.name, yn(strictUnicode))
test.WithEachMode(t, func(t *testing.T, decConfig DecoderConfig, encConfig EncoderConfig) {
alreadyTested := make(map[int]bool) // protocols we tested encode with so far
for _, pickle := range test.picklev {
for _, proto := range pickle.protov {
......@@ -578,8 +674,8 @@ func TestEncode(t *testing.T) {
dataOk = string([]byte{opProto, byte(proto)}) + dataOk
}
t.Run(fmt.Sprintf("%s/proto=%d", testname, proto), func(t *testing.T) {
testEncode(t, proto, strictUnicode, test.objectIn, test.objectOut, dataOk, pickle.err)
t.Run(fmt.Sprintf("proto=%d", proto), func(t *testing.T) {
testEncode(t, proto, encConfig, decConfig, test.objectIn, test.objectOut, dataOk, pickle.err)
})
alreadyTested[proto] = true
......@@ -592,11 +688,11 @@ func TestEncode(t *testing.T) {
continue
}
t.Run(fmt.Sprintf("%s/proto=%d(roundtrip)", testname, proto), func(t *testing.T) {
testEncode(t, proto, strictUnicode, test.objectIn, test.objectOut, "", nil)
t.Run(fmt.Sprintf("proto=%d(roundtrip)", proto), func(t *testing.T) {
testEncode(t, proto, encConfig, decConfig, test.objectIn, test.objectOut, "", nil)
})
}
}
})
}
}
......@@ -604,11 +700,9 @@ func TestEncode(t *testing.T) {
//
// It also verifies decoder robustness - via feeding it various kinds of
// corrupt data derived from input.
func testDecode(t *testing.T, strictUnicode bool, object interface{}, input string) {
func testDecode(t *testing.T, decConfig DecoderConfig, object interface{}, input string) {
newDecoder := func(r io.Reader) *Decoder {
return NewDecoderWithConfig(r, &DecoderConfig{
StrictUnicode: strictUnicode,
})
return NewDecoderWithConfig(r, &decConfig)
}
// decode(input) -> expected
......@@ -619,7 +713,7 @@ func testDecode(t *testing.T, strictUnicode bool, object interface{}, input stri
t.Error(err)
}
if !reflect.DeepEqual(v, object) {
if !deepEqual(v, object) {
t.Errorf("decode:\nhave: %#v\nwant: %#v", v, object)
}
......@@ -665,18 +759,16 @@ func testDecode(t *testing.T, strictUnicode bool, object interface{}, input stri
// encode-back tests are still performed.
//
// If errOk != nil, object encoding must produce that error.
func testEncode(t *testing.T, proto int, strictUnicode bool, object, objectDecodedBack interface{}, dataOk string, errOk error) {
func testEncode(t *testing.T, proto int, encConfig EncoderConfig, decConfig DecoderConfig, object, objectDecodedBack interface{}, dataOk string, errOk error) {
newEncoder := func(w io.Writer) *Encoder {
return NewEncoderWithConfig(w, &EncoderConfig{
Protocol: proto,
StrictUnicode: strictUnicode,
})
econf := EncoderConfig{}
econf = encConfig
econf.Protocol = proto
return NewEncoderWithConfig(w, &econf)
}
newDecoder := func(r io.Reader) *Decoder {
return NewDecoderWithConfig(r, &DecoderConfig{
StrictUnicode: strictUnicode,
})
return NewDecoderWithConfig(r, &decConfig)
}
buf := &bytes.Buffer{}
......@@ -716,9 +808,9 @@ func testEncode(t *testing.T, proto int, strictUnicode bool, object, objectDecod
if err != nil {
t.Errorf("encode -> decode -> error: %s", err)
} else {
if !reflect.DeepEqual(v, objectDecodedBack) {
if !deepEqual(v, objectDecodedBack) {
what := "identity"
if !reflect.DeepEqual(object, objectDecodedBack) {
if !deepEqual(object, objectDecodedBack) {
what = "expected object"
}
t.Errorf("encode -> decode != %s\nhave: %#v\nwant: %#v", what, v, objectDecodedBack)
......@@ -742,7 +834,7 @@ func TestDecodeMultiple(t *testing.T) {
t.Errorf("step #%v: %v", i, err)
}
if !reflect.DeepEqual(obj, objOk) {
if !deepEqual(obj, objOk) {
t.Errorf("step #%v: %q ; want %q", i, obj, objOk)
}
}
......@@ -929,7 +1021,7 @@ func TestPersistentRefs(t *testing.T) {
errExpect = "pickle: handleRef: " + e.Error()
}
if !(reflect.DeepEqual(v, expected) &&
if !(deepEqual(v, expected) &&
((err == nil && errExpect == "") || err.Error() == errExpect)) {
t.Errorf("%q: decode -> %#v, %q; want %#v, %q",
tt.input, v, err, expected, errExpect)
......@@ -955,7 +1047,7 @@ func TestPersistentRefs(t *testing.T) {
continue
}
if !reflect.DeepEqual(v, tt.expected) {
if !deepEqual(v, tt.expected) {
t.Errorf("%q: expected -> encode -> decode != identity\nhave: %#v\nwant: %#v",
tt.input, v, tt.expected)
}
......@@ -1012,7 +1104,9 @@ func BenchmarkDecode(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
buf := bytes.NewBuffer(input)
dec := NewDecoder(buf)
dec := NewDecoderWithConfig(buf, &DecoderConfig{
PyDict: true, // so that decoding e.g. {(): 0} does not fail
})
j := 0
for ; ; j++ {
......
......@@ -45,7 +45,7 @@ func TestAsInt64(t *testing.T) {
}
}
if !reflect.DeepEqual(out, tt.outOK) {
if !deepEqual(out, tt.outOK) {
t.Errorf("%T %#v -> %T %#v ; want %T %#v",
tt.in, tt.in, out, out, tt.outOK, tt.outOK)
}
......@@ -100,12 +100,12 @@ func TestAsBytesString(t *testing.T) {
serrOK = Estring(tt.in)
}
if !(bout == boutOK && reflect.DeepEqual(berr, berrOK)) {
if !(bout == boutOK && deepEqual(berr, berrOK)) {
t.Errorf("%#v: AsBytes:\nhave %#v %#v\nwant %#v %#v",
tt.in, bout, berr, boutOK, berrOK)
}
if !(sout == soutOK && reflect.DeepEqual(serr, serrOK)) {
if !(sout == soutOK && deepEqual(serr, serrOK)) {
t.Errorf("%#v: AsString:\nhave %#v %#v\nwant %#v %#v",
tt.in, sout, serr, soutOK, serrOK)
}
......
//go:build !go1.21
package ogórek
import (
"math/big"
)
func bigInt_Float64(b *big.Int) (float64, big.Accuracy) {
return new(big.Float).SetInt(b).Float64()
}
//go:build go1.21
package ogórek
import (
"math/big"
)
func bigInt_Float64(b *big.Int) (float64, big.Accuracy) {
return b.Float64()
}
//go:build !go1.19
package ogórek
import (
"hash/maphash"
)
func maphash_String(seed maphash.Seed, s string) uint64 {
var h maphash.Hash
h.SetSeed(seed)
h.WriteString(s)
return h.Sum64()
}
//go:build go1.19
package ogórek
import (
"hash/maphash"
)
func maphash_String(seed maphash.Seed, s string) uint64 {
return maphash.String(seed, s)
}
package ogórek
// Utilities that complement std reflect package.
import (
"reflect"
)
// deepEqual is like reflect.DeepEqual but also supports Dict.
//
// It is needed because reflect.DeepEqual considers two Dicts not-equal because
// each Dict is made with its own seed.
//
// XXX only top-level Dict is supported currently.
// For example comparing Dict inside list with the same won't work.
func deepEqual(a, b any) bool {
da, ok := a.(Dict)
if !ok {
return reflect.DeepEqual(a, b)
}
db, ok := b.(Dict)
if !ok {
return false // Dict != non-dict
}
if da.Len() != db.Len() {
return false
}
// XXX O(n^2) because we want to compare keys exactly and so cannot use
// db.Get(ka) because Dict.Get uses general equality that would match e.g. int == int64
eq := true
da.Iter()(func(ka, va any) bool {
keq := false
db.Iter()(func(kb, vb any) bool {
// NOTE don't use reflect.Equal(ka,kb) because it does not handle e.g. big.Int
if reflect.TypeOf(ka) == reflect.TypeOf(kb) && equal(ka,kb) {
if reflect.DeepEqual(va, vb) {
keq = true
}
return false
}
return true
})
if !keq {
eq = false
return false
}
return true
})
return eq
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment