Commit 082156ee authored by Stefan Behnel's avatar Stefan Behnel

enumerate fixes: single-statement bodies, avoid redundant deep recursion during loop optimisation

parent 131e416d
......@@ -72,6 +72,9 @@ class IterationTransform(Visitor.VisitorTransform):
def visit_ForInStatNode(self, node):
self.visitchildren(node)
return self._optimise_for_loop(node)
def _optimise_for_loop(self, node):
iterator = node.iterator.sequence
if iterator.type is Builtin.dict_type:
# like iterating over dict.keys()
......@@ -158,23 +161,30 @@ class IterationTransform(Visitor.VisitorTransform):
is_temp = counter_type.is_pyobject
)
enumerate_assignment_in_loop = Nodes.SingleAssignmentNode(
pos = enumerate_target.pos,
lhs = enumerate_target,
rhs = temp.ref(enumerate_target.pos))
loop_body = [
Nodes.SingleAssignmentNode(
pos = enumerate_target.pos,
lhs = enumerate_target,
rhs = temp.ref(enumerate_target.pos)),
Nodes.SingleAssignmentNode(
pos = enumerate_target.pos,
lhs = temp.ref(enumerate_target.pos),
rhs = inc_expression)
]
inc_statement = Nodes.SingleAssignmentNode(
pos = enumerate_target.pos,
lhs = temp.ref(enumerate_target.pos),
rhs = inc_expression)
if isinstance(node.body, Nodes.StatListNode):
node.body.stats = loop_body + node.body.stats
else:
loop_body.append(node.body)
node.body = Nodes.StatListNode(
node.body.pos,
stats = loop_body)
node.body.stats.insert(0, enumerate_assignment_in_loop)
node.body.stats.insert(1, inc_statement)
node.target = iterable_target
node.iterator.sequence = enumerate_function.arg_tuple.args[0]
# recurse into loop to check for further optimisations
node = self.visit_ForInStatNode(node)
node = self._optimise_for_loop(node)
statements = [
Nodes.SingleAssignmentNode(
......
......@@ -53,6 +53,14 @@ __doc__ = u"""
3 4
:: 3 4
>>> py_enumerate_dict({})
:: 55 99
>>> py_enumerate_dict(dict(a=1, b=2, c=3))
0 a
1 c
2 b
:: 2 b
"""
def go_py_enumerate():
......@@ -69,6 +77,13 @@ def go_c_enumerate_step():
for i,k in enumerate(range(1,7,2)):
print i, k
def py_enumerate_dict(dict d):
cdef int i = 55
k = 99
for i,k in enumerate(d):
print i, k
print u"::", i, k
def py_enumerate_break(*t):
i,k = 55,99
for i,k in enumerate(t):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment