Commit 2d177dda authored by Robert Bradshaw's avatar Robert Bradshaw

Support C++ iterators in for..in loops.

parent 3575912c
......@@ -1982,6 +1982,8 @@ class IteratorNode(ExprNode):
not self.sequence.type.is_string:
# C array iteration will be transformed later on
self.type = self.sequence.type
elif self.sequence.type.is_cpp_class:
self.analyse_cpp_types(env)
else:
self.sequence = self.sequence.coerce_to_pyobject(env)
if self.sequence.type is list_type or \
......@@ -1995,9 +1997,46 @@ class IteratorNode(ExprNode):
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None),
]))
def analyse_cpp_types(self, env):
begin = self.sequence.type.scope.lookup("begin")
end = self.sequence.type.scope.lookup("end")
if begin is None:
error(self.pos, "missing begin() on %s" % self.sequence.type)
self.type = error_type
return
if end is None:
error(self.pos, "missing end() on %s" % self.sequence.type)
self.type = error_type
return
iter_type = begin.type.base_type.return_type
if iter_type.is_cpp_class:
# TODO(robertwb): Check argument types.
if iter_type.scope.lookup("operator!=") is None:
error(self.pos, "missing operator!= on result of begin() on %s" % self.sequence.type)
self.type = error_type
return
if iter_type.scope.lookup("operator++") is None:
error(self.pos, "missing operator++ on result of begin() on %s" % self.sequence.type)
self.type = error_type
return
if iter_type.scope.lookup("operator*") is None:
error(self.pos, "missing operator* on result of begin() on %s" % self.sequence.type)
self.type = error_type
return
self.type = iter_type
elif iter_type.is_ptr:
self.type = iter_type
else:
error(self.pos, "result type of begin() on %s must be a C++ class or pointer" % self.sequence.type)
self.type = error_type
return
def generate_result_code(self, code):
sequence_type = self.sequence.type
if sequence_type.is_cpp_class:
# TODO: Limit scope.
code.putln("%s = %s.begin();" % (self.result(), self.sequence.result()))
return
if sequence_type.is_array or sequence_type.is_ptr:
raise InternalError("for in carray slice not transformed")
is_builtin_sequence = sequence_type is list_type or \
......@@ -2080,7 +2119,17 @@ class IteratorNode(ExprNode):
sequence_type = self.sequence.type
if self.reversed:
code.putln("if (%s < 0) break;" % self.counter_cname)
if sequence_type is list_type:
if sequence_type.is_cpp_class:
# TODO: Cache end() call?
code.putln("if (%s == %s.end()) break;" % (
self.result(),
self.sequence.result()));
code.putln("%s = *%s;" % (
result_name,
self.result()))
code.putln("++%s;" % self.result())
return
elif sequence_type is list_type:
self.generate_next_sequence_item('List', result_name, code)
return
elif sequence_type is tuple_type:
......@@ -2127,13 +2176,16 @@ class NextNode(AtomicExprNode):
#
# iterator IteratorNode
type = py_object_type
def __init__(self, iterator):
self.pos = iterator.pos
self.iterator = iterator
if iterator.type.is_ptr or iterator.type.is_array:
self.type = iterator.type.base_type
iterator_type = iterator.type
if iterator_type.is_ptr or iterator_type.is_array:
self.type = iterator_type.base_type
elif iterator_type.is_cpp_class:
self.type = iterator_type.scope.lookup("operator*").type.base_type.return_type
else:
self.type = py_object_type
self.is_temp = 1
def generate_result_code(self, code):
......@@ -2459,6 +2511,18 @@ class IndexNode(ExprNode):
elif base_type.is_ptr or base_type.is_array:
return base_type.base_type
if base_type.is_cpp_class:
class FakeOperand:
def __init__(self, **kwds):
self.__dict__.update(kwds)
operands = [
FakeOperand(pos=self.pos, type=base_type),
FakeOperand(pos=self.pos, type=index_type),
]
index_func = env.lookup_operator('[]', operands)
if index_func is not None:
return index_func.type.base_type.return_type
# may be slicing or indexing, we don't know
if base_type in (unicode_type, str_type):
# these types always returns their own type on Python indexing/slicing
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment