Commit 15e5567f authored by Robert Bradshaw's avatar Robert Bradshaw

in and not in operators for C arrays and sliced pointers

parent a4e2de9d
...@@ -2011,7 +2011,8 @@ class IndexNode(ExprNode): ...@@ -2011,7 +2011,8 @@ class IndexNode(ExprNode):
# Handle the case where base is a literal char* (and we expect a string, not an int) # Handle the case where base is a literal char* (and we expect a string, not an int)
if isinstance(self.base, BytesNode) or is_slice: if isinstance(self.base, BytesNode) or is_slice:
self.base = self.base.coerce_to_pyobject(env) if not (self.base.type.is_ptr or self.base.type.is_array):
self.base = self.base.coerce_to_pyobject(env)
skip_child_analysis = False skip_child_analysis = False
buffer_access = False buffer_access = False
...@@ -2092,7 +2093,7 @@ class IndexNode(ExprNode): ...@@ -2092,7 +2093,7 @@ class IndexNode(ExprNode):
if self.index.type.is_pyobject: if self.index.type.is_pyobject:
self.index = self.index.coerce_to( self.index = self.index.coerce_to(
PyrexTypes.c_py_ssize_t_type, env) PyrexTypes.c_py_ssize_t_type, env)
if not self.index.type.is_int: elif not self.index.type.is_int:
error(self.pos, error(self.pos,
"Invalid index type '%s'" % "Invalid index type '%s'" %
self.index.type) self.index.type)
...@@ -5995,10 +5996,11 @@ class CmpNode(object): ...@@ -5995,10 +5996,11 @@ class CmpNode(object):
(op, operand1.type, operand2.type)) (op, operand1.type, operand2.type))
def is_python_comparison(self): def is_python_comparison(self):
return not self.is_c_string_contains() and ( return (not self.is_ptr_contains()
self.has_python_operands() and not self.is_c_string_contains()
or (self.cascade and self.cascade.is_python_comparison()) and (self.has_python_operands()
or self.operator in ('in', 'not_in')) or (self.cascade and self.cascade.is_python_comparison())
or self.operator in ('in', 'not_in')))
def coerce_operands_to(self, dst_type, env): def coerce_operands_to(self, dst_type, env):
operand2 = self.operand2 operand2 = self.operand2
...@@ -6010,7 +6012,8 @@ class CmpNode(object): ...@@ -6010,7 +6012,8 @@ class CmpNode(object):
def is_python_result(self): def is_python_result(self):
return ((self.has_python_operands() and return ((self.has_python_operands() and
self.operator not in ('is', 'is_not', 'in', 'not_in') and self.operator not in ('is', 'is_not', 'in', 'not_in') and
not self.is_c_string_contains()) not self.is_c_string_contains() and
not self.is_ptr_contains())
or (self.cascade and self.cascade.is_python_result())) or (self.cascade and self.cascade.is_python_result()))
def is_c_string_contains(self): def is_c_string_contains(self):
...@@ -6019,6 +6022,16 @@ class CmpNode(object): ...@@ -6019,6 +6022,16 @@ class CmpNode(object):
and (self.operand2.type.is_string or self.operand2.type is bytes_type)) or and (self.operand2.type.is_string or self.operand2.type is bytes_type)) or
(self.operand1.type is PyrexTypes.c_py_unicode_type (self.operand1.type is PyrexTypes.c_py_unicode_type
and self.operand2.type is unicode_type)) and self.operand2.type is unicode_type))
def is_ptr_contains(self):
if self.operator in ('in', 'not_in'):
iterator = self.operand2
if iterator.type.is_ptr or iterator.type.is_array:
return iterator.type.base_type is not PyrexTypes.c_char_type
if (isinstance(iterator, IndexNode) and
isinstance(iterator.index, (SliceNode, CoerceFromPyTypeNode)) and
(iterator.base.type.is_array or iterator.base.type.is_ptr)):
return iterator.base.type.base_type is not PyrexTypes.c_char_type
def generate_operation_code(self, code, result_code, def generate_operation_code(self, code, result_code,
operand1, op , operand2): operand1, op , operand2):
...@@ -6214,6 +6227,12 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -6214,6 +6227,12 @@ class PrimaryCmpNode(ExprNode, CmpNode):
env.use_utility_code(char_in_bytes_utility_code) env.use_utility_code(char_in_bytes_utility_code)
self.operand2 = self.operand2.as_none_safe_node( self.operand2 = self.operand2.as_none_safe_node(
"argument of type 'NoneType' is not iterable") "argument of type 'NoneType' is not iterable")
elif self.is_ptr_contains():
if self.cascade:
error(self.pos, "Cascading comparison not yet supported for 'val in sliced pointer'.")
self.type = PyrexTypes.c_bint_type
# Will be transformed by IterationTransform
return
else: else:
common_type = py_object_type common_type = py_object_type
self.is_pycmp = True self.is_pycmp = True
......
...@@ -4295,6 +4295,9 @@ class ForInStatNode(LoopNode, StatNode): ...@@ -4295,6 +4295,9 @@ class ForInStatNode(LoopNode, StatNode):
self.target.analyse_target_types(env) self.target.analyse_target_types(env)
self.iterator.analyse_expressions(env) self.iterator.analyse_expressions(env)
self.item = ExprNodes.NextNode(self.iterator, env) self.item = ExprNodes.NextNode(self.iterator, env)
if not self.target.type.assignable_from(self.item.type) and \
(self.iterator.sequence.type.is_ptr or self.iterator.sequence.type.is_array):
self.item.type = self.iterator.sequence.type.base_type
self.item = self.item.coerce_to(self.target.type, env) self.item = self.item.coerce_to(self.target.type, env)
self.body.analyse_expressions(env) self.body.analyse_expressions(env)
if self.else_clause: if self.else_clause:
......
...@@ -77,11 +77,62 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -77,11 +77,62 @@ class IterationTransform(Visitor.VisitorTransform):
self.visitchildren(node) self.visitchildren(node)
self.current_scope = oldscope self.current_scope = oldscope
return node return node
def visit_PrimaryCmpNode(self, node):
if node.is_ptr_contains():
# for t in operand2:
# if operand1 == t:
# res = True
# break
# else:
# res = False
pos = node.pos
res_handle = UtilNodes.TempHandle(PyrexTypes.c_bint_type)
res = res_handle.ref(pos)
result_ref = UtilNodes.ResultRefNode(node)
if isinstance(node.operand2, ExprNodes.IndexNode):
base_type = node.operand2.base.type.base_type
else:
base_type = node.operand2.type.base_type
target_handle = UtilNodes.TempHandle(base_type)
target = target_handle.ref(pos)
cmp_node = ExprNodes.PrimaryCmpNode(
pos, operator=u'==', operand1=node.operand1, operand2=target)
if_body = Nodes.StatListNode(
pos,
stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
Nodes.BreakStatNode(pos)])
if_node = Nodes.IfStatNode(
pos,
if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
else_clause=None)
for_loop = UtilNodes.TempsBlockNode(
pos,
temps = [target_handle],
body = Nodes.ForInStatNode(
pos,
target=target,
iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
body=if_node,
else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
for_loop.analyse_expressions(self.current_scope)
for_loop = self(for_loop)
new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)
if node.operator == 'not_in':
new_node = ExprNodes.NotNode(pos, operand=new_node)
return new_node
else:
self.visitchildren(node)
return node
def visit_ForInStatNode(self, node): def visit_ForInStatNode(self, node):
self.visitchildren(node) self.visitchildren(node)
return self._optimise_for_loop(node) return self._optimise_for_loop(node)
def _optimise_for_loop(self, node): def _optimise_for_loop(self, node):
iterator = node.iterator.sequence iterator = node.iterator.sequence
if iterator.type is Builtin.dict_type: if iterator.type is Builtin.dict_type:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment