Commit 822ac414 authored by scoder's avatar scoder Committed by GitHub

Merge pull request #2044 from pitrou/optimize_set_discard_remove

Issue #2042: add optimized versions of `x in set`, set.remove() and set.discard()
parents f17ece06 17b50ea0
...@@ -328,7 +328,10 @@ builtin_types_table = [ ...@@ -328,7 +328,10 @@ builtin_types_table = [
("set", "PySet_Type", [BuiltinMethod("__contains__", "TO", "b", "PySequence_Contains"), ("set", "PySet_Type", [BuiltinMethod("__contains__", "TO", "b", "PySequence_Contains"),
BuiltinMethod("clear", "T", "r", "PySet_Clear"), BuiltinMethod("clear", "T", "r", "PySet_Clear"),
# discard() and remove() have a special treatment for unhashable values # discard() and remove() have a special treatment for unhashable values
# BuiltinMethod("discard", "TO", "r", "PySet_Discard"), BuiltinMethod("discard", "TO", "r", "__Pyx_PySet_Discard",
utility_code=UtilityCode.load("py_set_discard", "Optimize.c")),
BuiltinMethod("remove", "TO", "r", "__Pyx_PySet_Remove",
utility_code=UtilityCode.load("py_set_remove", "Optimize.c")),
# update is actually variadic (see Github issue #1645) # update is actually variadic (see Github issue #1645)
# BuiltinMethod("update", "TO", "r", "__Pyx_PySet_Update", # BuiltinMethod("update", "TO", "r", "__Pyx_PySet_Update",
# utility_code=UtilityCode.load_cached("PySet_Update", "Builtins.c")), # utility_code=UtilityCode.load_cached("PySet_Update", "Builtins.c")),
......
...@@ -12197,6 +12197,11 @@ class CmpNode(object): ...@@ -12197,6 +12197,11 @@ class CmpNode(object):
self.special_bool_cmp_utility_code = UtilityCode.load_cached("PyDictContains", "ObjectHandling.c") self.special_bool_cmp_utility_code = UtilityCode.load_cached("PyDictContains", "ObjectHandling.c")
self.special_bool_cmp_function = "__Pyx_PyDict_ContainsTF" self.special_bool_cmp_function = "__Pyx_PyDict_ContainsTF"
return True return True
elif self.operand2.type is Builtin.set_type:
self.operand2 = self.operand2.as_none_safe_node("'NoneType' object is not iterable")
self.special_bool_cmp_utility_code = UtilityCode.load_cached("PySetContains", "ObjectHandling.c")
self.special_bool_cmp_function = "__Pyx_PySet_ContainsTF"
return True
elif self.operand2.type is Builtin.unicode_type: elif self.operand2.type is Builtin.unicode_type:
self.operand2 = self.operand2.as_none_safe_node("'NoneType' object is not iterable") self.operand2 = self.operand2.as_none_safe_node("'NoneType' object is not iterable")
self.special_bool_cmp_utility_code = UtilityCode.load_cached("PyUnicodeContains", "StringTools.c") self.special_bool_cmp_utility_code = UtilityCode.load_cached("PyUnicodeContains", "StringTools.c")
......
...@@ -1035,6 +1035,36 @@ static CYTHON_INLINE int __Pyx_PyDict_ContainsTF(PyObject* item, PyObject* dict, ...@@ -1035,6 +1035,36 @@ static CYTHON_INLINE int __Pyx_PyDict_ContainsTF(PyObject* item, PyObject* dict,
return unlikely(result < 0) ? result : (result == (eq == Py_EQ)); return unlikely(result < 0) ? result : (result == (eq == Py_EQ));
} }
/////////////// PySetContains.proto ///////////////
static CYTHON_INLINE int __Pyx_PySet_ContainsTF(PyObject* key, PyObject* set, int eq); /* proto */
/////////////// PySetContains ///////////////
static int __Pyx_PySet_ContainsUnhashable(PyObject *set, PyObject *key) {
int result = -1;
if (PySet_Check(key) && PyErr_ExceptionMatches(PyExc_TypeError)) {
/* Convert key to frozenset */
PyObject *tmpkey;
PyErr_Clear();
tmpkey = PyFrozenSet_New(key);
if (tmpkey != NULL) {
result = PySet_Contains(set, tmpkey);
Py_DECREF(tmpkey);
}
}
return result;
}
static CYTHON_INLINE int __Pyx_PySet_ContainsTF(PyObject* key, PyObject* set, int eq) {
int result = PySet_Contains(set, key);
if (unlikely(result < 0)) {
result = __Pyx_PySet_ContainsUnhashable(set, key);
}
return unlikely(result < 0) ? result : (result == (eq == Py_EQ));
}
/////////////// PySequenceContains.proto /////////////// /////////////// PySequenceContains.proto ///////////////
static CYTHON_INLINE int __Pyx_PySequence_ContainsTF(PyObject* item, PyObject* seq, int eq) { static CYTHON_INLINE int __Pyx_PySequence_ContainsTF(PyObject* item, PyObject* seq, int eq) {
......
...@@ -397,6 +397,76 @@ static CYTHON_INLINE int __Pyx_dict_iter_next( ...@@ -397,6 +397,76 @@ static CYTHON_INLINE int __Pyx_dict_iter_next(
} }
/////////////// py_set_discard_unhashable.proto ///////////////
static int __Pyx_PySet_DiscardUnhashable(PyObject *set, PyObject *key); /* proto */
/////////////// py_set_discard_unhashable ///////////////
static int __Pyx_PySet_DiscardUnhashable(PyObject *set, PyObject *key) {
PyObject *tmpkey;
int rv;
if (!PySet_Check(key) || !PyErr_ExceptionMatches(PyExc_TypeError))
return -1;
PyErr_Clear();
tmpkey = PyFrozenSet_New(key);
if (tmpkey == NULL)
return -1;
rv = PySet_Discard(set, tmpkey);
Py_DECREF(tmpkey);
return rv;
}
/////////////// py_set_discard.proto ///////////////
static CYTHON_INLINE int __Pyx_PySet_Discard(PyObject *set, PyObject *key); /*proto*/
/////////////// py_set_discard ///////////////
//@requires: py_set_discard_unhashable
static CYTHON_INLINE int __Pyx_PySet_Discard(PyObject *set, PyObject *key) {
int rv;
rv = PySet_Discard(set, key);
/* Convert *key* to frozenset if necessary */
if (unlikely(rv < 0)) {
return __Pyx_PySet_DiscardUnhashable(set, key);
}
return rv;
}
/////////////// py_set_remove.proto ///////////////
static CYTHON_INLINE int __Pyx_PySet_Remove(PyObject *set, PyObject *key); /*proto*/
/////////////// py_set_remove ///////////////
//@requires: py_set_discard_unhashable
static CYTHON_INLINE int __Pyx_PySet_Remove(PyObject *set, PyObject *key) {
int rv;
rv = PySet_Discard(set, key);
/* Convert *key* to frozenset if necessary */
if (unlikely(rv < 0)) {
rv = __Pyx_PySet_DiscardUnhashable(set, key);
}
if (rv == 0) {
/* Not found */
PyObject *tup;
tup = PyTuple_Pack(1, key);
if (!tup)
return -1;
PyErr_SetObject(PyExc_KeyError, tup);
Py_DECREF(tup);
return -1;
}
return 0;
}
/////////////// unicode_iter.proto /////////////// /////////////// unicode_iter.proto ///////////////
static CYTHON_INLINE int __Pyx_init_unicode_iteration( static CYTHON_INLINE int __Pyx_init_unicode_iteration(
......
...@@ -65,6 +65,32 @@ def test_set_add(): ...@@ -65,6 +65,32 @@ def test_set_add():
return s1 return s1
def test_set_contains(v):
"""
>>> test_set_contains(1)
True
>>> test_set_contains(2)
False
>>> test_set_contains(frozenset([1, 2, 3]))
True
>>> test_set_contains(frozenset([1, 2]))
False
>>> test_set_contains(set([1, 2, 3]))
True
>>> test_set_contains(set([1, 2]))
False
>>> try: test_set_contains([1, 2])
... except TypeError: pass
... else: print("NOT RAISED!")
"""
cdef set s1
s1 = set()
s1.add(1)
s1.add('a')
s1.add(frozenset([1, 2, 3]))
return v in s1
def test_set_update(v=None): def test_set_update(v=None):
""" """
>>> type(test_set_update()) is set >>> type(test_set_update()) is set
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment