Commit f58ddd33 authored by Stefan Behnel's avatar Stefan Behnel

allow reassignments to C++ variables while iterating over the original value in a for-loop

parent e9f4ee4b
......@@ -1973,6 +1973,7 @@ class IteratorNode(ExprNode):
type = py_object_type
iter_func_ptr = None
counter_cname = None
cpp_iterator_cname = None
reversed = False # currently only used for list/tuple types (see Optimize.py)
subexprs = ['sequence']
......@@ -2060,12 +2061,20 @@ class IteratorNode(ExprNode):
error(self.pos, "result type of begin() on %s must be a C++ class or pointer" % self.sequence.type)
self.type = error_type
return
def generate_result_code(self, code):
sequence_type = self.sequence.type
if sequence_type.is_cpp_class:
if self.sequence.is_name:
# safe: C++ won't allow you to reassign to class references
begin_func = "%s.begin" % self.sequence.result()
else:
sequence_type = PyrexTypes.c_ptr_type(sequence_type)
self.cpp_iterator_cname = code.funcstate.allocate_temp(sequence_type, manage_ref=False)
code.putln("%s = &%s;" % (self.cpp_iterator_cname, self.sequence.result()))
begin_func = "%s->begin" % self.cpp_iterator_cname
# TODO: Limit scope.
code.putln("%s = %s.begin();" % (self.result(), self.sequence.result()))
code.putln("%s = %s();" % (self.result(), begin_func))
return
if sequence_type.is_array or sequence_type.is_ptr:
raise InternalError("for in carray slice not transformed")
......@@ -2150,10 +2159,14 @@ class IteratorNode(ExprNode):
if self.reversed:
code.putln("if (%s < 0) break;" % self.counter_cname)
if sequence_type.is_cpp_class:
if self.cpp_iterator_cname:
end_func = "%s->end" % self.cpp_iterator_cname
else:
end_func = "%s.end" % self.sequence.result()
# TODO: Cache end() call?
code.putln("if (!(%s != %s.end())) break;" % (
code.putln("if (!(%s != %s())) break;" % (
self.result(),
self.sequence.result()));
end_func))
code.putln("%s = *%s;" % (
result_name,
self.result()))
......@@ -2195,6 +2208,8 @@ class IteratorNode(ExprNode):
if self.iter_func_ptr:
code.funcstate.release_temp(self.iter_func_ptr)
self.iter_func_ptr = None
if self.cpp_iterator_cname:
code.funcstate.release_temp(self.cpp_iterator_cname)
ExprNode.free_temps(self, code)
......
......@@ -63,3 +63,32 @@ def test_iteration_over_heap_vector(L):
return [ i for i in deref(vint) ]
finally:
del vint
def test_iteration_in_generator(vector[int] vint):
"""
>>> list( test_iteration_in_generator([1,2]) )
[1, 2]
"""
for i in vint:
yield i
def test_iteration_in_generator_reassigned():
"""
>>> list( test_iteration_in_generator_reassigned() )
[1]
"""
cdef vector[int] *vint = new vector[int]()
cdef vector[int] *orig_vint = vint
vint.push_back(1)
reassign = True
try:
for i in deref(vint):
yield i
if reassign:
reassign = False
vint = new vector[int]()
vint.push_back(2)
finally:
del orig_vint
if vint is not orig_vint:
del vint
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment