Commit bea0d457 authored by Stefan Behnel's avatar Stefan Behnel

enumerate optimisation (#316)

parent bd1645a8
......@@ -99,6 +99,13 @@ class IterationTransform(Visitor.VisitorTransform):
return self._transform_dict_iteration(
node, dict_obj, keys, values)
# enumerate() ?
if iterator.self is None and \
isinstance(function, ExprNodes.NameNode) and \
function.entry.is_builtin and \
function.name == 'enumerate':
return self._transform_enumerate_iteration(node, iterator)
# range() iteration?
if Options.convert_range and node.target.type.is_int:
if iterator.self is None and \
......@@ -109,6 +116,81 @@ class IterationTransform(Visitor.VisitorTransform):
return node
def _transform_enumerate_iteration(self, node, enumerate_function):
args = enumerate_function.arg_tuple.args
if len(args) == 0:
error(enumerate_function.pos,
"enumerate() requires an iterable argument")
return node
elif len(args) > 1:
error(enumerate_function.pos,
"enumerate() takes at most 1 argument")
return node
if not node.target.is_sequence_constructor:
# leave this untouched for now
return node
targets = node.target.args
if len(targets) != 2:
# leave this untouched for now
return node
if not isinstance(targets[0], ExprNodes.NameNode):
# leave this untouched for now
return node
enumerate_target, iterable_target = targets
counter_type = enumerate_target.type
if not counter_type.is_pyobject and not counter_type.is_int:
# nothing we can do here, I guess
return node
temp = UtilNodes.TempHandle(counter_type)
init_val = ExprNodes.IntNode(enumerate_function.pos, value='0',
type=counter_type)
inc_expression = ExprNodes.AddNode(
enumerate_function.pos,
operand1 = temp.ref(enumerate_target.pos),
operand2 = ExprNodes.IntNode(node.pos, value='1',
type=counter_type),
operator = '+',
type = counter_type,
is_temp = counter_type.is_pyobject
)
enumerate_assignment_in_loop = Nodes.SingleAssignmentNode(
pos = enumerate_target.pos,
lhs = enumerate_target,
rhs = temp.ref(enumerate_target.pos))
inc_statement = Nodes.SingleAssignmentNode(
pos = enumerate_target.pos,
lhs = temp.ref(enumerate_target.pos),
rhs = inc_expression)
node.body.stats.insert(0, enumerate_assignment_in_loop)
node.body.stats.insert(1, inc_statement)
node.target = iterable_target
node.iterator.sequence = enumerate_function.arg_tuple.args[0]
# recurse into loop to check for further optimisations
node = self.visit_ForInStatNode(node)
statements = [
Nodes.SingleAssignmentNode(
pos = enumerate_target.pos,
lhs = temp.ref(enumerate_target.pos),
rhs = init_val),
node
]
return UtilNodes.TempsBlockNode(
node.pos, temps=[temp],
body=Nodes.StatListNode(
node.pos,
stats = statements
))
def _transform_range_iteration(self, node, range_function):
args = range_function.arg_tuple.args
if len(args) < 3:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment