Commit 2e3a306d authored by Stefan Behnel's avatar Stefan Behnel

implement two-args enumerate() with counter start value

parent 8982a94c
......@@ -527,9 +527,9 @@ class IterationTransform(Visitor.VisitorTransform):
error(enumerate_function.pos,
"enumerate() requires an iterable argument")
return node
elif len(args) > 1:
elif len(args) > 2:
error(enumerate_function.pos,
"enumerate() takes at most 1 argument")
"enumerate() takes at most 2 arguments")
return node
if not node.target.is_sequence_constructor:
......@@ -550,10 +550,15 @@ class IterationTransform(Visitor.VisitorTransform):
# nothing we can do here, I guess
return node
temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
value='0',
type=counter_type,
constant_result=0))
if len(args) == 2:
start = unwrap_coerced_node(args[1]).coerce_to(counter_type, self.current_scope)
else:
start = ExprNodes.IntNode(enumerate_function.pos,
value='0',
type=counter_type,
constant_result=0)
temp = UtilNodes.LetRefNode(start)
inc_expression = ExprNodes.AddNode(
enumerate_function.pos,
operand1 = temp,
......@@ -586,7 +591,7 @@ class IterationTransform(Visitor.VisitorTransform):
node.target = iterable_target
node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
node.iterator.sequence = enumerate_function.arg_tuple.args[0]
node.iterator.sequence = args[0]
# recurse into loop to check for further optimisations
return UtilNodes.LetNode(temp, self._optimise_for_loop(node, node.iterator.sequence))
......
......@@ -14,6 +14,18 @@ def go_py_enumerate():
for i,k in enumerate(range(1,5)):
print i, k
@cython.test_fail_if_path_exists("//SimpleCallNode//NameNode[@name = 'enumerate']")
def go_py_enumerate_start():
"""
>>> go_py_enumerate_start()
5 1
6 2
7 3
8 4
"""
for i,k in enumerate(range(1,5), 5):
print i, k
@cython.test_fail_if_path_exists("//SimpleCallNode//NameNode[@name = 'enumerate']")
def go_c_enumerate():
"""
......@@ -136,6 +148,18 @@ def multi_enumerate():
for a,(b,(c,d)) in enumerate(enumerate(enumerate(range(1,5)))):
print a,b,c,d
@cython.test_fail_if_path_exists("//SimpleCallNode//NameNode[@name = 'enumerate']")
def multi_enumerate_start():
"""
>>> multi_enumerate_start()
0 2 0 1
1 3 1 2
2 4 2 3
3 5 3 4
"""
for a,(b,(c,d)) in enumerate(enumerate(enumerate(range(1,5)), 2)):
print a,b,c,d
@cython.test_fail_if_path_exists("//SimpleCallNode")
def multi_c_enumerate():
"""
......@@ -160,3 +184,15 @@ def convert_target_enumerate(L):
cdef int a,b
for a, b in enumerate(L):
print a,b
@cython.test_fail_if_path_exists("//SimpleCallNode")
def convert_target_enumerate_start(L, int n):
"""
>>> convert_target_enumerate_start([2,3,5], 3)
3 2
4 3
5 5
"""
cdef int a,b
for a, b in enumerate(L, n):
print a,b
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment