Commit 1f04fc75 authored by Stefan Behnel's avatar Stefan Behnel

slightly faster for-in iteration for arbitrary iterables and generators by...

slightly faster for-in iteration for arbitrary iterables and generators by caching 'tp_iternext' function pointer
parent fd236aa1
...@@ -1789,6 +1789,7 @@ class IteratorNode(ExprNode): ...@@ -1789,6 +1789,7 @@ class IteratorNode(ExprNode):
# sequence ExprNode # sequence ExprNode
type = py_object_type type = py_object_type
iter_func_ptr = None
subexprs = ['sequence'] subexprs = ['sequence']
...@@ -1814,6 +1815,11 @@ class IteratorNode(ExprNode): ...@@ -1814,6 +1815,11 @@ class IteratorNode(ExprNode):
def release_counter_temp(self, code): def release_counter_temp(self, code):
code.funcstate.release_temp(self.counter_cname) code.funcstate.release_temp(self.counter_cname)
_func_iternext_type = PyrexTypes.CPtrType(PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None),
]))
def generate_result_code(self, code): def generate_result_code(self, code):
if self.sequence.type.is_array or self.sequence.type.is_ptr: if self.sequence.type.is_array or self.sequence.type.is_ptr:
raise InternalError("for in carray slice not transformed") raise InternalError("for in carray slice not transformed")
...@@ -1841,9 +1847,17 @@ class IteratorNode(ExprNode): ...@@ -1841,9 +1847,17 @@ class IteratorNode(ExprNode):
self.sequence.py_result(), self.sequence.py_result(),
code.error_goto_if_null(self.result(), self.pos))) code.error_goto_if_null(self.result(), self.pos)))
code.put_gotref(self.py_result()) code.put_gotref(self.py_result())
self.iter_func_ptr = code.funcstate.allocate_temp(self._func_iternext_type, manage_ref=False)
code.putln("%s = Py_TYPE(%s)->tp_iternext;" % (self.iter_func_ptr, self.py_result()))
if may_be_a_sequence: if may_be_a_sequence:
code.putln("}") code.putln("}")
def free_temps(self, code):
if self.iter_func_ptr:
code.funcstate.release_temp(self.iter_func_ptr)
self.iter_func_ptr = None
ExprNode.free_temps(self, code)
class NextNode(AtomicExprNode): class NextNode(AtomicExprNode):
# Used as part of for statement implementation. # Used as part of for statement implementation.
...@@ -1851,7 +1865,7 @@ class NextNode(AtomicExprNode): ...@@ -1851,7 +1865,7 @@ class NextNode(AtomicExprNode):
# Created during analyse_types phase. # Created during analyse_types phase.
# The iterator is not owned by this node. # The iterator is not owned by this node.
# #
# iterator ExprNode # iterator IteratorNode
type = py_object_type type = py_object_type
...@@ -1896,12 +1910,26 @@ class NextNode(AtomicExprNode): ...@@ -1896,12 +1910,26 @@ class NextNode(AtomicExprNode):
if len(type_checks) == 1: if len(type_checks) == 1:
return return
code.putln("{") code.putln("{")
if self.iterator.iter_func_ptr:
code.putln(
"%s = %s(%s);" % (
self.result(),
self.iterator.iter_func_ptr,
self.iterator.py_result()))
code.putln("if (unlikely(!%s)) {" % self.result())
code.putln("if (PyErr_Occurred()) {")
code.putln("if (likely(PyErr_ExceptionMatches(PyExc_StopIteration))) PyErr_Clear();")
code.putln("else %s" % code.error_goto(self.pos))
code.putln("}")
code.putln("break;")
code.putln("}")
else:
code.putln( code.putln(
"%s = PyIter_Next(%s);" % ( "%s = PyIter_Next(%s);" % (
self.result(), self.result(),
self.iterator.py_result())) self.iterator.py_result()))
code.putln( code.putln(
"if (!%s) {" % "if (unlikely(!%s)) {" %
self.result()) self.result())
code.putln(code.error_goto_if_PyErr(self.pos)) code.putln(code.error_goto_if_PyErr(self.pos))
code.putln("break;") code.putln("break;")
......
# mode: run
# tag: forin
import sys
def for_in_pyiter(it):
"""
>>> for_in_pyiter(Iterable(5))
[0, 1, 2, 3, 4]
"""
l = []
for item in it:
l.append(item)
return l
def for_in_list():
"""
>>> for_in_pyiter([1,2,3,4,5])
[1, 2, 3, 4, 5]
"""
class Iterable(object):
"""
>>> for_in_pyiter(Iterable(5))
[0, 1, 2, 3, 4]
"""
def __init__(self, N):
self.N = N
self.i = 0
def __iter__(self):
return self
def __next__(self):
if self.i < self.N:
i = self.i
self.i += 1
return i
raise StopIteration
next = __next__
if sys.version_info[0] >= 3:
class NextReplacingIterable(object):
def __init__(self):
self.i = 0
def __iter__(self):
return self
def __next__(self):
if self.i > 5:
raise StopIteration
self.i += 1
self.__next__ = self.next2
return 1
def next2(self):
self.__next__ = self.next3
return 2
def next3(self):
del self.__next__
raise StopIteration
else:
class NextReplacingIterable(object):
def __init__(self):
self.i = 0
def __iter__(self):
return self
def next(self):
if self.i > 5:
raise StopIteration
self.i += 1
self.next = self.next2
return 1
def next2(self):
self.next = self.next3
return 2
def next3(self):
del self.next
raise StopIteration
def for_in_next_replacing_iter():
"""
>>> for_in_pyiter(NextReplacingIterable())
[1, 1, 1, 1, 1, 1]
"""
def for_in_gen(N):
"""
>>> for_in_pyiter(for_in_gen(10))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
"""
for i in xrange(N):
yield i
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment