Commit 837f5498 authored by Stefan Behnel

Optimise set iteration using _PySet_NextEntry() in CPython.

Closes #2048.
parent 9286f3d4
......@@ -33,6 +33,9 @@ Features added
* ``dict.pop()`` is optimised.
Original patch by Antoine Pitrou. (Github issue #2047)
* Iteration over sets and frozensets is optimised.
(Github issue #2048)
* Calls to builtin methods that are not specifically optimised into C-API calls
now use a cache that avoids repeated lookups of the underlying C function.
(Github issue #2054)
......
......@@ -6443,6 +6443,66 @@ class DictIterationNextNode(Node):
var.release(code)
class SetIterationNextNode(Node):
    # Loop-body helper that fetches the next item of a set through
    # __Pyx_set_iter_next() (which uses _PySet_NextEntry() in CPython) and
    # passes along the expected set size so that size changes can be detected.
    # Meant to be placed inside of a WhileStatNode.  Created in Optimize.py.
    #
    # set_obj            ExprNode   the object handed to __Pyx_set_iter_next()
    # expected_size      ExprNode   set size recorded before the loop started
    # pos_index_var      ExprNode   C position counter, passed by address
    # value_target       ExprNode   loop target that receives each item
    # is_set_flag        ExprNode   C flag passed as 'source_is_set'
    # value_ref          TempNode   object temp that receives the raw item
    # coerced_value_var  ExprNode   value_ref coerced to the target's type

    child_attrs = ['set_obj', 'expected_size', 'pos_index_var',
                   'coerced_value_var', 'value_target', 'is_set_flag']

    # Both are only created in analyse_expressions().
    coerced_value_var = None
    value_ref = None

    def __init__(self, set_obj, expected_size, pos_index_var, value_target, is_set_flag):
        # The node itself is typed as a C boolean temp and positioned at the set.
        attributes = dict(
            set_obj=set_obj,
            expected_size=expected_size,
            pos_index_var=pos_index_var,
            value_target=value_target,
            is_set_flag=is_set_flag,
            is_temp=True,
            type=PyrexTypes.c_bint_type,
        )
        Node.__init__(self, set_obj.pos, **attributes)

    def analyse_expressions(self, env):
        from . import ExprNodes

        # Plain expression children are analysed in declaration order.
        for attr_name in ('set_obj', 'expected_size', 'pos_index_var'):
            analysed = getattr(self, attr_name).analyse_types(env)
            setattr(self, attr_name, analysed)

        # The loop target is an assignment target; the item arrives as a
        # Python object and is coerced to whatever type the target has.
        target = self.value_target.analyse_target_types(env)
        self.value_target = target
        item_ref = ExprNodes.TempNode(target.pos, type=PyrexTypes.py_object_type)
        self.value_ref = item_ref
        self.coerced_value_var = item_ref.coerce_to(target.type, env)

        self.is_set_flag = self.is_set_flag.analyse_types(env)
        return self

    def generate_function_definitions(self, env, code):
        # Only the set expression is forwarded here.
        self.set_obj.generate_function_definitions(env, code)

    def generate_execution_code(self, code):
        set_iter_utility = UtilityCode.load_cached("set_iter", "Optimize.c")
        code.globalstate.use_utility_code(set_iter_utility)
        self.set_obj.generate_evaluation_code(code)

        item_ref = self.value_ref
        item_ref.allocate(code)

        # C int temp holding the helper's status: 0 => leave the loop,
        # -1 => jump to the error label, anything else => item available.
        status = code.funcstate.allocate_temp(PyrexTypes.c_int_type, False)
        call_args = (
            status,
            self.set_obj.py_result(),
            self.expected_size.result(),
            self.pos_index_var.result(),
            item_ref.result(),
            self.is_set_flag.result(),
        )
        code.putln("%s = __Pyx_set_iter_next(%s, %s, &%s, &%s, %s);" % call_args)
        code.putln("if (unlikely(%s == 0)) break;" % status)
        code.putln(code.error_goto_if("%s == -1" % status, self.pos))
        code.funcstate.release_temp(status)

        # evaluate all coercions before the assignments
        code.put_gotref(item_ref.result())
        coerced = self.coerced_value_var
        coerced.generate_evaluation_code(code)
        self.value_target.generate_assignment_code(coerced, code)
        item_ref.release(code)
def ForStatNode(pos, **kw):
if 'iterator' in kw:
if kw['iterator'].is_async:
......
......@@ -193,10 +193,15 @@ class IterationTransform(Visitor.EnvTransform):
if annotation.is_subscript:
annotation = annotation.base # container base type
# FIXME: generalise annotation evaluation => maybe provide a "qualified name" also for imported names?
if annotation.is_name and annotation.entry and annotation.entry.qualified_name == 'typing.Dict':
annotation_type = Builtin.dict_type
elif annotation.is_name and annotation.name == 'Dict':
annotation_type = Builtin.dict_type
if annotation.is_name:
if annotation.entry and annotation.entry.qualified_name == 'typing.Dict':
annotation_type = Builtin.dict_type
elif annotation.name == 'Dict':
annotation_type = Builtin.dict_type
if annotation.entry and annotation.entry.qualified_name in ('typing.Set', 'typing.FrozenSet'):
annotation_type = Builtin.set_type
elif annotation.name in ('Set', 'FrozenSet'):
annotation_type = Builtin.set_type
if Builtin.dict_type in (iterator.type, annotation_type):
# like iterating over dict.keys()
......@@ -206,6 +211,13 @@ class IterationTransform(Visitor.EnvTransform):
return self._transform_dict_iteration(
node, dict_obj=iterator, method=None, keys=True, values=False)
if (Builtin.set_type in (iterator.type, annotation_type) or
Builtin.frozenset_type in (iterator.type, annotation_type)):
if reversed:
# CPython raises an error here: not a sequence
return node
return self._transform_set_iteration(node, iterator)
# C array (slice) iteration?
if iterator.type.is_ptr or iterator.type.is_array:
return self._transform_carray_iteration(node, iterator, reversed=reversed)
......@@ -968,6 +980,85 @@ class IterationTransform(Visitor.EnvTransform):
PyrexTypes.CFuncTypeArg("p_is_dict", PyrexTypes.c_int_ptr_type, None),
])
# C signature of the "__Pyx_set_iterator" helper called from
# _transform_set_iteration(): it takes the iterable and a compile-time
# "is_set" flag, returns a Python object, and writes the original length
# and the runtime "is set" flag through the two pointer arguments.
PySet_Iterator_func_type = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("set", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("is_set", PyrexTypes.c_int_type, None),
PyrexTypes.CFuncTypeArg("p_orig_length", PyrexTypes.c_py_ssize_t_ptr_type, None),
PyrexTypes.CFuncTypeArg("p_is_set", PyrexTypes.c_int_ptr_type, None),
])
def _transform_set_iteration(self, node, set_obj):
    """
    Rewrite a for-loop over 'set_obj' into a condition-less while-loop.

    The resulting TempsBlockNode contains, in order: the reset of the
    position counter to 0, the call to "__Pyx_set_iterator" (which also
    fills in the original length and the "is set" flag through pointers),
    and a WhileStatNode whose body starts with a SetIterationNextNode
    that assigns the loop target.  The loop's else clause is carried over.
    """
    handles = []

    def new_handle(ctype):
        # Create a temp handle and register it with the enclosing block.
        handle = UtilNodes.TempHandle(ctype)
        handles.append(handle)
        return handle

    set_temp = new_handle(PyrexTypes.py_object_type).ref(set_obj.pos)
    pos_temp = new_handle(PyrexTypes.c_py_ssize_t_type).ref(node.pos)

    # The loop body must be a statement list so that the "next" node
    # can be put in front of it.
    body = node.body
    if not isinstance(body, Nodes.StatListNode):
        body = Nodes.StatListNode(pos=node.body.pos, stats=[node.body])

    # keep original length to guard against set modification
    set_len_temp = new_handle(PyrexTypes.c_py_ssize_t_type)
    set_len_temp_addr = ExprNodes.AmpersandNode(
        node.pos,
        operand=set_len_temp.ref(set_obj.pos),
        type=PyrexTypes.c_ptr_type(set_len_temp.type))

    is_set_handle = new_handle(PyrexTypes.c_int_type)
    is_set_temp = is_set_handle.ref(node.pos)
    is_set_temp_addr = ExprNodes.AmpersandNode(
        node.pos,
        operand=is_set_temp,
        type=PyrexTypes.c_ptr_type(is_set_handle.type))

    next_node = Nodes.SetIterationNextNode(
        set_temp, set_len_temp.ref(set_obj.pos), pos_temp, node.target, is_set_temp)
    next_node = next_node.analyse_expressions(self.current_env())
    body.stats.insert(0, next_node)

    def flag_node(value):
        # Turn a truth value into a constant C integer node (0 or 1).
        int_value = 1 if value else 0
        return ExprNodes.IntNode(node.pos, value=str(int_value), constant_result=int_value)

    reset_position = Nodes.SingleAssignmentNode(
        node.pos,
        lhs=pos_temp,
        rhs=ExprNodes.IntNode(node.pos, value='0', constant_result=0))

    iterator_call = ExprNodes.PythonCapiCallNode(
        set_obj.pos,
        "__Pyx_set_iterator",
        self.PySet_Iterator_func_type,
        utility_code=UtilityCode.load_cached("set_iter", "Optimize.c"),
        args=[
            set_obj,
            flag_node(set_obj.type is Builtin.set_type),
            set_len_temp_addr,
            is_set_temp_addr,
        ],
        is_temp=True,
    )
    setup_iteration = Nodes.SingleAssignmentNode(
        set_obj.pos,
        lhs=set_temp,
        rhs=iterator_call)

    # No loop condition: the SetIterationNextNode emits the "break" itself.
    loop = Nodes.WhileStatNode(
        node.pos,
        condition=None,
        body=body,
        else_clause=node.else_clause,
    )

    return UtilNodes.TempsBlockNode(
        node.pos,
        temps=handles,
        body=Nodes.StatListNode(
            node.pos,
            stats=[reset_position, setup_iteration, loop]))
class SwitchTransform(Visitor.EnvTransform):
"""
......
......@@ -421,6 +421,70 @@ static CYTHON_INLINE int __Pyx_dict_iter_next(
}
/////////////// set_iter.proto ///////////////
// Prepares the iteration: returns the object to iterate over and reports
// through the pointers the original length and whether the set fast path applies.
static CYTHON_INLINE PyObject* __Pyx_set_iterator(PyObject* iterable, int is_set,
Py_ssize_t* p_orig_length, int* p_source_is_set); /*proto*/
// Fetches the next item into *value; the result is used by the generated loop
// as: 0 => leave the loop, -1 => error, otherwise an item is available.
static CYTHON_INLINE int __Pyx_set_iter_next(
PyObject* iter_obj, Py_ssize_t orig_length,
Py_ssize_t* ppos, PyObject **value,
int source_is_set); /*proto*/
/////////////// set_iter ///////////////
//@requires: ObjectHandling.c::IterFinish
// Prepare iteration over 'iterable'.
//
// iterable         object to iterate over
// is_set           non-zero if the compiler already knows that it is a set
// p_orig_length    out: size of the set at loop start (0 on the generic path)
// p_source_is_set  out: non-zero if the set fast path is used
//
// Returns a new reference: the set itself on the fast path, otherwise the
// result of PyObject_GetIter() (NULL on error).
static CYTHON_INLINE PyObject* __Pyx_set_iterator(PyObject* iterable, int is_set,
                                                  Py_ssize_t* p_orig_length, int* p_source_is_set) {
#if CYTHON_COMPILING_IN_CPYTHON
    is_set = is_set || likely(PySet_CheckExact(iterable) || PyFrozenSet_CheckExact(iterable));
    *p_source_is_set = is_set;
    if (likely(is_set)) {
        *p_orig_length = PySet_Size(iterable);
        Py_INCREF(iterable);
        return iterable;
    }
#else
    (void)is_set;
    *p_source_is_set = 0;
#endif
    // Generic iterator fallback.  Always give the length a defined value:
    // the caller passes it on by value to __Pyx_set_iter_next() even though
    // it is ignored there when the source is not a set.
    *p_orig_length = 0;
    return PyObject_GetIter(iterable);
}
// Fetch the next item of the iteration set up by __Pyx_set_iterator().
//
// iter_obj       the set itself (fast path) or an iterator (generic path)
// orig_length    set size at loop start, used to detect modification
// ppos           in/out: position state for _PySet_NextEntry()
// value          out: receives a new reference to the next item
// source_is_set  non-zero to use the set fast path
//
// Returns 1 if *value was filled in, 0 when the iteration is finished and
// -1 on error.  The generated loop breaks on 0 and raises on -1.
static CYTHON_INLINE int __Pyx_set_iter_next(
        PyObject* iter_obj, Py_ssize_t orig_length,
        Py_ssize_t* ppos, PyObject **value,
        int source_is_set) {
    if (!CYTHON_COMPILING_IN_CPYTHON || unlikely(!source_is_set)) {
        *value = PyIter_Next(iter_obj);
        if (unlikely(!*value)) {
            // NOTE(review): __Pyx_IterFinish() is defined elsewhere; presumably it
            // returns 0 on normal exhaustion and -1 if an error is pending.
            return __Pyx_IterFinish();
        }
        (void)orig_length;
        (void)ppos;
        // An item was fetched: report success.  Returning 0 here would make the
        // caller leave the loop after the first item and leak the reference.
        return 1;
    }
#if CYTHON_COMPILING_IN_CPYTHON
    if (unlikely(PySet_GET_SIZE(iter_obj) != orig_length)) {
        PyErr_SetString(
            PyExc_RuntimeError,
            "set changed size during iteration");
        return -1;
    }
    {
        Py_hash_t hash;
        int ret = _PySet_NextEntry(iter_obj, ppos, value, &hash);
        // CPython does not raise errors here, only if !isinstance(iter_obj, set/frozenset)
        assert (ret != -1);
        if (likely(ret)) {
            // The reference is owned by the caller, so take one here.
            Py_INCREF(*value);
            return 1;
        }
        return 0;
    }
#endif
}
/////////////// py_set_discard_unhashable ///////////////
static int __Pyx_PySet_DiscardUnhashable(PyObject *set, PyObject *key) {
......
# mode: run
# tag: set
cimport cython
# List comprehension over an argument typed as 'set'.  The tree assertion
# requires that a SetIterationNextNode is present in the compiled function.
@cython.test_assert_path_exists(
"//SetIterationNextNode",
)
def set_iter_comp(set s):
"""
>>> s = set([1, 2, 3])
>>> sorted(set_iter_comp(s))
[1, 2, 3]
"""
return [x for x in s]
# Same as the untyped comprehension test, but with the loop variable declared
# as a C int.  The doctest calls this function itself so that the typed
# variant is actually exercised.
@cython.test_assert_path_exists(
    "//SetIterationNextNode",
)
def set_iter_comp_typed(set s):
    """
    >>> s = set([1, 2, 3])
    >>> sorted(set_iter_comp_typed(s))
    [1, 2, 3]
    """
    cdef int x
    return [x for x in s]
# List comprehension over an argument typed as 'frozenset'.  The tree
# assertion requires a SetIterationNextNode here as well.
@cython.test_assert_path_exists(
"//SetIterationNextNode",
)
def frozenset_iter_comp(frozenset s):
"""
>>> s = frozenset([1, 2, 3])
>>> sorted(frozenset_iter_comp(s))
[1, 2, 3]
"""
return [x for x in s]
# Comprehension over a frozenset that is created inline from the set argument.
# The doctest calls this function itself so that this variant is actually
# exercised.
@cython.test_assert_path_exists(
    "//SetIterationNextNode",
)
def set_iter_comp_frozenset(set s):
    """
    >>> s = set([1, 2, 3])
    >>> sorted(set_iter_comp_frozenset(s))
    [1, 2, 3]
    """
    return [x for x in frozenset(s)]
# Adds 'value' to the set on every loop iteration.  Per the doctest, adding
# a value that is already contained lets the loop finish normally, while
# adding a new value (4) must raise the "changed size" RuntimeError.
@cython.test_assert_path_exists(
"//SetIterationNextNode",
)
def set_iter_modify(set s, int value):
"""
>>> s = set([1, 2, 3])
>>> sorted(set_iter_modify(s, 1))
[1, 2, 3]
>>> sorted(set_iter_modify(s, 2))
[1, 2, 3]
>>> sorted(set_iter_modify(s, 3))
[1, 2, 3]
>>> sorted(set_iter_modify(s, 4))
Traceback (most recent call last):
RuntimeError: set changed size during iteration
"""
for x in s:
s.add(value)
return s
# enumerate() over a set with a C int counter.  The tree assertions require
# that no call to the 'enumerate' name remains, and that both an AddNode and
# a SetIterationNextNode are present in the compiled function.
@cython.test_fail_if_path_exists(
"//SimpleCallNode//NameNode[@name = 'enumerate']",
)
@cython.test_assert_path_exists(
"//AddNode",
"//SetIterationNextNode",
)
def set_iter_enumerate(set s):
"""
>>> s = set(['a', 'b', 'c'])
>>> numbers, values = set_iter_enumerate(s)
>>> sorted(numbers)
[0, 1, 2]
>>> sorted(values)
['a', 'b', 'c']
"""
cdef int i
numbers = []
values = []
# Collect counters and items separately; set order is not relied upon,
# the doctest sorts both lists.
for i, x in enumerate(s):
numbers.append(i)
values.append(x)
return numbers, values
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment