Commit 10c8bc1a authored by Stefan Behnel's avatar Stefan Behnel

rewrite dict.setdefault() optimisation to fix double hashing in Py3

parent de44bcf5
......@@ -2294,6 +2294,7 @@ class OptimizeBuiltinCalls(Visitor.MethodDispatcherTransform):
PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("is_safe_type", PyrexTypes.c_int_type, None),
])
def _handle_simple_method_dict_setdefault(self, node, args, is_unbound_method):
......@@ -2304,12 +2305,22 @@ class OptimizeBuiltinCalls(Visitor.MethodDispatcherTransform):
elif len(args) != 3:
self._error_wrong_arg_count('dict.setdefault', node, args, "2 or 3")
return node
key_type = args[1].type
if key_type.is_builtin_type:
is_safe_type = int(key_type.name in
'str bytes unicode float int long bool')
elif key_type is PyrexTypes.py_object_type:
is_safe_type = -1 # don't know
else:
is_safe_type = 0 # definitely not
args.append(ExprNodes.IntNode(
node.pos, value=is_safe_type, constant_result=is_safe_type))
return self._substitute_method_call(
node, "__Pyx_PyDict_SetDefault", self.Pyx_PyDict_SetDefault_func_type,
'setdefault', is_unbound_method, args,
may_return_none = True,
utility_code = load_c_utility('dict_setdefault'))
may_return_none=True,
utility_code=load_c_utility('dict_setdefault'))
### unicode type methods
......
......@@ -315,25 +315,26 @@ static PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObjec
/////////////// dict_setdefault.proto ///////////////
static PyObject *__Pyx_PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *default_value); /*proto*/
static CYTHON_INLINE PyObject *__Pyx_PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *default_value, int is_safe_type); /*proto*/
/////////////// dict_setdefault ///////////////
static PyObject *__Pyx_PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *default_value) {
static CYTHON_INLINE PyObject *__Pyx_PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *default_value, int is_safe_type) {
PyObject* value;
if (is_safe_type == 1 || (is_safe_type == -1 &&
(PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)))) {
/* these presumably have repeatably safe and fast hash functions */
#if PY_MAJOR_VERSION >= 3
value = PyDict_GetItemWithError(d, key);
if (unlikely(!value)) {
if (unlikely(PyErr_Occurred()))
return NULL;
if (unlikely(PyDict_SetItem(d, key, default_value) == -1))
return NULL;
value = default_value;
}
Py_INCREF(value);
value = PyDict_GetItemWithError(d, key);
if (unlikely(!value)) {
if (unlikely(PyErr_Occurred()))
return NULL;
if (unlikely(PyDict_SetItem(d, key, default_value) == -1))
return NULL;
value = default_value;
}
Py_INCREF(value);
#else
if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) {
/* these presumably have safe hash functions */
value = PyDict_GetItem(d, key);
if (unlikely(!value)) {
if (unlikely(PyDict_SetItem(d, key, default_value) == -1))
......@@ -341,10 +342,10 @@ static PyObject *__Pyx_PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *d
value = default_value;
}
Py_INCREF(value);
#endif
} else {
value = PyObject_CallMethodObjArgs(d, PYIDENT("setdefault"), key, default_value, NULL);
}
#endif
return value;
}
......
......@@ -11,6 +11,17 @@ class Hashable(object):
def __eq__(self, other):
return isinstance(other, Hashable)
class CountedHashable(object):
def __init__(self):
self.hash_count = 0
self.eq_count = 0
def __hash__(self):
self.hash_count += 1
return 42
def __eq__(self, other):
self.eq_count += 1
return id(self) == id(other)
@cython.test_fail_if_path_exists('//AttributeNode')
@cython.test_assert_path_exists('//PythonCapiCallNode')
@cython.locals(d=dict)
......@@ -36,6 +47,17 @@ def setdefault1(d, key):
>>> len(d)
2
>>> d[Hashable()]
>>> hashed1 = CountedHashable()
>>> y = {hashed1: 5}
>>> hashed2 = CountedHashable()
>>> setdefault1(y, hashed2)
>>> hashed1.hash_count
1
>>> hashed2.hash_count
1
>>> hashed1.eq_count + hashed2.eq_count
1
"""
return d.setdefault(key)
......@@ -72,5 +94,17 @@ def setdefault2(d, key, value):
3
>>> d[Hashable()]
55
>>> hashed1 = CountedHashable()
>>> y = {hashed1: 5}
>>> hashed2 = CountedHashable()
>>> setdefault2(y, hashed2, [])
[]
>>> hashed1.hash_count
1
>>> hashed2.hash_count
1
>>> hashed1.eq_count + hashed2.eq_count
1
"""
return d.setdefault(key, value)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment