Commit 10c8bc1a authored by Stefan Behnel

rewrite dict.setdefault() optimisation to fix double hashing in Py3

parent de44bcf5
...@@ -2294,6 +2294,7 @@ class OptimizeBuiltinCalls(Visitor.MethodDispatcherTransform): ...@@ -2294,6 +2294,7 @@ class OptimizeBuiltinCalls(Visitor.MethodDispatcherTransform):
PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None), PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None), PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None), PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("is_safe_type", PyrexTypes.c_int_type, None),
]) ])
def _handle_simple_method_dict_setdefault(self, node, args, is_unbound_method): def _handle_simple_method_dict_setdefault(self, node, args, is_unbound_method):
...@@ -2304,12 +2305,22 @@ class OptimizeBuiltinCalls(Visitor.MethodDispatcherTransform): ...@@ -2304,12 +2305,22 @@ class OptimizeBuiltinCalls(Visitor.MethodDispatcherTransform):
elif len(args) != 3: elif len(args) != 3:
self._error_wrong_arg_count('dict.setdefault', node, args, "2 or 3") self._error_wrong_arg_count('dict.setdefault', node, args, "2 or 3")
return node return node
key_type = args[1].type
if key_type.is_builtin_type:
is_safe_type = int(key_type.name in
'str bytes unicode float int long bool')
elif key_type is PyrexTypes.py_object_type:
is_safe_type = -1 # don't know
else:
is_safe_type = 0 # definitely not
args.append(ExprNodes.IntNode(
node.pos, value=is_safe_type, constant_result=is_safe_type))
return self._substitute_method_call( return self._substitute_method_call(
node, "__Pyx_PyDict_SetDefault", self.Pyx_PyDict_SetDefault_func_type, node, "__Pyx_PyDict_SetDefault", self.Pyx_PyDict_SetDefault_func_type,
'setdefault', is_unbound_method, args, 'setdefault', is_unbound_method, args,
may_return_none = True, may_return_none=True,
utility_code = load_c_utility('dict_setdefault')) utility_code=load_c_utility('dict_setdefault'))
### unicode type methods ### unicode type methods
......
...@@ -315,25 +315,26 @@ static PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObjec ...@@ -315,25 +315,26 @@ static PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObjec
/////////////// dict_setdefault.proto /////////////// /////////////// dict_setdefault.proto ///////////////
static PyObject *__Pyx_PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *default_value); /*proto*/ static CYTHON_INLINE PyObject *__Pyx_PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *default_value, int is_safe_type); /*proto*/
/////////////// dict_setdefault /////////////// /////////////// dict_setdefault ///////////////
static PyObject *__Pyx_PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *default_value) { static CYTHON_INLINE PyObject *__Pyx_PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *default_value, int is_safe_type) {
PyObject* value; PyObject* value;
if (is_safe_type == 1 || (is_safe_type == -1 &&
(PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)))) {
/* these presumably have repeatably safe and fast hash functions */
#if PY_MAJOR_VERSION >= 3 #if PY_MAJOR_VERSION >= 3
value = PyDict_GetItemWithError(d, key); value = PyDict_GetItemWithError(d, key);
if (unlikely(!value)) { if (unlikely(!value)) {
if (unlikely(PyErr_Occurred())) if (unlikely(PyErr_Occurred()))
return NULL; return NULL;
if (unlikely(PyDict_SetItem(d, key, default_value) == -1)) if (unlikely(PyDict_SetItem(d, key, default_value) == -1))
return NULL; return NULL;
value = default_value; value = default_value;
} }
Py_INCREF(value); Py_INCREF(value);
#else #else
if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) {
/* these presumably have safe hash functions */
value = PyDict_GetItem(d, key); value = PyDict_GetItem(d, key);
if (unlikely(!value)) { if (unlikely(!value)) {
if (unlikely(PyDict_SetItem(d, key, default_value) == -1)) if (unlikely(PyDict_SetItem(d, key, default_value) == -1))
...@@ -341,10 +342,10 @@ static PyObject *__Pyx_PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *d ...@@ -341,10 +342,10 @@ static PyObject *__Pyx_PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *d
value = default_value; value = default_value;
} }
Py_INCREF(value); Py_INCREF(value);
#endif
} else { } else {
value = PyObject_CallMethodObjArgs(d, PYIDENT("setdefault"), key, default_value, NULL); value = PyObject_CallMethodObjArgs(d, PYIDENT("setdefault"), key, default_value, NULL);
} }
#endif
return value; return value;
} }
......
...@@ -11,6 +11,17 @@ class Hashable(object): ...@@ -11,6 +11,17 @@ class Hashable(object):
def __eq__(self, other): def __eq__(self, other):
return isinstance(other, Hashable) return isinstance(other, Hashable)
class CountedHashable(object):
    """Hashable helper that records how often it is hashed and compared.

    Every instance returns the same constant hash, so two distinct instances
    always land on the same hash value, while equality holds only between an
    instance and itself.  The ``hash_count`` and ``eq_count`` attributes let
    a caller check how many times ``__hash__`` and ``__eq__`` were invoked.
    """

    def __init__(self):
        # Per-instance call counters, both starting at zero.
        self.hash_count = 0
        self.eq_count = 0

    def __hash__(self):
        """Count the call and return the fixed hash value 42."""
        self.hash_count += 1
        return 42

    def __eq__(self, other):
        """Count the call and report whether *other* is this very object."""
        self.eq_count += 1
        # Identity comparison; equivalent to comparing id() values.
        return self is other
@cython.test_fail_if_path_exists('//AttributeNode') @cython.test_fail_if_path_exists('//AttributeNode')
@cython.test_assert_path_exists('//PythonCapiCallNode') @cython.test_assert_path_exists('//PythonCapiCallNode')
@cython.locals(d=dict) @cython.locals(d=dict)
...@@ -36,6 +47,17 @@ def setdefault1(d, key): ...@@ -36,6 +47,17 @@ def setdefault1(d, key):
>>> len(d) >>> len(d)
2 2
>>> d[Hashable()] >>> d[Hashable()]
>>> hashed1 = CountedHashable()
>>> y = {hashed1: 5}
>>> hashed2 = CountedHashable()
>>> setdefault1(y, hashed2)
>>> hashed1.hash_count
1
>>> hashed2.hash_count
1
>>> hashed1.eq_count + hashed2.eq_count
1
""" """
return d.setdefault(key) return d.setdefault(key)
...@@ -72,5 +94,17 @@ def setdefault2(d, key, value): ...@@ -72,5 +94,17 @@ def setdefault2(d, key, value):
3 3
>>> d[Hashable()] >>> d[Hashable()]
55 55
>>> hashed1 = CountedHashable()
>>> y = {hashed1: 5}
>>> hashed2 = CountedHashable()
>>> setdefault2(y, hashed2, [])
[]
>>> hashed1.hash_count
1
>>> hashed2.hash_count
1
>>> hashed1.eq_count + hashed2.eq_count
1
""" """
return d.setdefault(key, value) return d.setdefault(key, value)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment