Commit 66204a91 authored by Robert Bradshaw's avatar Robert Bradshaw

Better type checking for C++ iterators.

parent 0b547d5c
......@@ -2010,33 +2010,46 @@ class IteratorNode(ExprNode):
return py_object_type
def analyse_cpp_types(self, env):
begin = self.sequence.type.scope.lookup("begin")
end = self.sequence.type.scope.lookup("end")
if begin is None:
sequence_type = self.sequence.type
if sequence_type.is_ptr:
sequence_type = sequence_type.base_type
begin = sequence_type.scope.lookup("begin")
end = sequence_type.scope.lookup("end")
if (begin is None
or not begin.type.is_ptr
or not begin.type.base_type.is_cfunction
or begin.type.base_type.args):
error(self.pos, "missing begin() on %s" % self.sequence.type)
self.type = error_type
return
if end is None:
if (end is None
or not end.type.is_ptr
or not end.type.base_type.is_cfunction
or end.type.base_type.args):
error(self.pos, "missing end() on %s" % self.sequence.type)
self.type = error_type
return
iter_type = begin.type.base_type.return_type
if iter_type.is_cpp_class:
# TODO(robertwb): Check argument types.
if iter_type.scope.lookup("operator!=") is None:
if env.lookup_operator_for_types(
self.pos,
"!=",
[iter_type, end.type.base_type.return_type]) is None:
error(self.pos, "missing operator!= on result of begin() on %s" % self.sequence.type)
self.type = error_type
return
if iter_type.scope.lookup("operator++") is None:
if env.lookup_operator_for_types(self.pos, '++', [iter_type]) is None:
error(self.pos, "missing operator++ on result of begin() on %s" % self.sequence.type)
self.type = error_type
return
if iter_type.scope.lookup("operator*") is None:
if env.lookup_operator_for_types(self.pos, '*', [iter_type]) is None:
error(self.pos, "missing operator* on result of begin() on %s" % self.sequence.type)
self.type = error_type
return
self.type = iter_type
elif iter_type.is_ptr:
if not (iter_type == end.type.base_type.return_type):
error(self.pos, "incompatible types for begin() and end()")
self.type = iter_type
else:
error(self.pos, "result type of begin() on %s must be a C++ class or pointer" % self.sequence.type)
......@@ -2133,7 +2146,7 @@ class IteratorNode(ExprNode):
code.putln("if (%s < 0) break;" % self.counter_cname)
if sequence_type.is_cpp_class:
# TODO: Cache end() call?
code.putln("if (%s == %s.end()) break;" % (
code.putln("if (!(%s != %s.end())) break;" % (
self.result(),
self.sequence.result()));
code.putln("%s = *%s;" % (
......@@ -2198,7 +2211,7 @@ class NextNode(AtomicExprNode):
if iterator_type.is_ptr or iterator_type.is_array:
return iterator_type.base_type
elif iterator_type.is_cpp_class:
item_type = iterator_type.scope.lookup("operator*").type.base_type.return_type
item_type = env.lookup_operator_for_types(self.pos, "*", [iterator_type]).type.base_type.return_type
if item_type.is_reference:
item_type = item_type.ref_base_type
return item_type
......
......@@ -758,6 +758,13 @@ class Scope(object):
return None
return PyrexTypes.best_match(operands, function.all_alternatives())
def lookup_operator_for_types(self, pos, operator, types):
from Nodes import Node
class FakeOperand(Node):
pass
operands = [FakeOperand(pos, type=type) for type in types]
return self.lookup_operator(operator, operands)
def use_utility_code(self, new_code):
self.global_scope().use_utility_code(new_code)
......
......@@ -2,6 +2,12 @@
from libcpp.vector cimport vector
cdef extern from "cpp_iterators_simple.h":
cdef cppclass DoublePointerIter:
DoublePointerIter(double* start, int len)
double* begin()
double* end()
def test_vector(py_v):
"""
>>> test_vector([1, 2, 3])
......@@ -27,3 +33,18 @@ def test_ptrs():
v.push_back(&b)
v.push_back(&c)
return [item[0] for item in v]
def test_custom():
"""
>>> test_custom()
[1.0, 2.0, 3.0]
"""
cdef double* values = [1, 2, 3]
cdef DoublePointerIter* iter
try:
iter = new DoublePointerIter(values, 3)
# TODO: It'd be nice to automatically dereference this in a way that
# would not conflict with the pointer slicing iteration.
return [x for x in iter[0]]
finally:
del iter
class DoublePointerIter {
public:
DoublePointerIter(double* start, int len) : start_(start), len_(len) { }
double* begin() { return start_; }
double* end() { return start_ + len_; }
private:
double* start_;
int len_;
};
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment