Commit 0b547d5c authored by Robert Bradshaw's avatar Robert Bradshaw

C++ iterator type inference

parent 4a4f6583
...@@ -1997,6 +1997,18 @@ class IteratorNode(ExprNode): ...@@ -1997,6 +1997,18 @@ class IteratorNode(ExprNode):
PyrexTypes.py_object_type, [ PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None), PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None),
])) ]))
def infer_type(self, env):
sequence_type = self.sequence.infer_type(env)
if (sequence_type.is_array or sequence_type.is_ptr) and \
not sequence_type.is_string:
return sequence_type
elif sequence_type.is_cpp_class:
begin = sequence_type.scope.lookup("begin")
if begin is not None:
return begin.type.base_type.return_type
return py_object_type
def analyse_cpp_types(self, env): def analyse_cpp_types(self, env):
begin = self.sequence.type.scope.lookup("begin") begin = self.sequence.type.scope.lookup("begin")
end = self.sequence.type.scope.lookup("end") end = self.sequence.type.scope.lookup("end")
...@@ -2179,15 +2191,22 @@ class NextNode(AtomicExprNode): ...@@ -2179,15 +2191,22 @@ class NextNode(AtomicExprNode):
def __init__(self, iterator): def __init__(self, iterator):
self.pos = iterator.pos self.pos = iterator.pos
self.iterator = iterator self.iterator = iterator
iterator_type = iterator.type
def infer_type(self, env, iterator_type = None):
if iterator_type is None:
iterator_type = self.iterator.infer_type(env)
if iterator_type.is_ptr or iterator_type.is_array: if iterator_type.is_ptr or iterator_type.is_array:
self.type = iterator_type.base_type return iterator_type.base_type
elif iterator_type.is_cpp_class: elif iterator_type.is_cpp_class:
self.type = iterator_type.scope.lookup("operator*").type.base_type.return_type item_type = iterator_type.scope.lookup("operator*").type.base_type.return_type
if self.type.is_reference: if item_type.is_reference:
self.type = self.type.ref_base_type item_type = item_type.ref_base_type
return item_type
else: else:
self.type = py_object_type return py_object_type
def analyse_types(self, env):
self.type = self.infer_type(env, self.iterator.type)
self.is_temp = 1 self.is_temp = 1
def generate_result_code(self, code): def generate_result_code(self, code):
......
...@@ -915,10 +915,8 @@ class ControlFlowAnalysis(CythonTransform): ...@@ -915,10 +915,8 @@ class ControlFlowAnalysis(CythonTransform):
# naturally infer the base type of pointers, C arrays, # naturally infer the base type of pointers, C arrays,
# Python strings, etc., while correctly falling back to an # Python strings, etc., while correctly falling back to an
# object type when the base type cannot be handled. # object type when the base type cannot be handled.
self.mark_assignment(target, ExprNodes.IndexNode(
node.pos, self.mark_assignment(target, node.item)
base = sequence,
index = ExprNodes.IntNode(node.pos, value = '0')))
def visit_ForInStatNode(self, node): def visit_ForInStatNode(self, node):
condition_block = self.flow.nextblock() condition_block = self.flow.nextblock()
......
...@@ -5371,16 +5371,17 @@ class ForInStatNode(LoopNode, StatNode): ...@@ -5371,16 +5371,17 @@ class ForInStatNode(LoopNode, StatNode):
item = None item = None
def analyse_declarations(self, env): def analyse_declarations(self, env):
import ExprNodes
self.target.analyse_target_declaration(env) self.target.analyse_target_declaration(env)
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
if self.else_clause: if self.else_clause:
self.else_clause.analyse_declarations(env) self.else_clause.analyse_declarations(env)
self.item = ExprNodes.NextNode(self.iterator)
def analyse_expressions(self, env): def analyse_expressions(self, env):
import ExprNodes
self.target.analyse_target_types(env) self.target.analyse_target_types(env)
self.iterator.analyse_expressions(env) self.iterator.analyse_expressions(env)
self.item = ExprNodes.NextNode(self.iterator) self.item.analyse_expressions(env)
if (self.iterator.type.is_ptr or self.iterator.type.is_array) and \ if (self.iterator.type.is_ptr or self.iterator.type.is_array) and \
self.target.type.assignable_from(self.iterator.type): self.target.type.assignable_from(self.iterator.type):
# C array slice optimization. # C array slice optimization.
......
...@@ -26,6 +26,4 @@ def test_ptrs(): ...@@ -26,6 +26,4 @@ def test_ptrs():
v.push_back(&a) v.push_back(&a)
v.push_back(&b) v.push_back(&b)
v.push_back(&c) v.push_back(&c)
cdef double* item
return [item[0] for item in v] return [item[0] for item in v]
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment