Commit 082aee71 authored by Robert Bradshaw's avatar Robert Bradshaw

Cleanup slice iteration code.

parent 0718018a
...@@ -1667,11 +1667,10 @@ class IteratorNode(ExprNode): ...@@ -1667,11 +1667,10 @@ class IteratorNode(ExprNode):
def analyse_types(self, env): def analyse_types(self, env):
self.sequence.analyse_types(env) self.sequence.analyse_types(env)
if isinstance(self.sequence, SliceIndexNode) and \ if (self.sequence.type.is_array or self.sequence.type.is_ptr) and \
(self.sequence.base.type.is_array or self.sequence.base.type.is_ptr) \ not self.sequence.type.is_string:
or self.sequence.type.is_array and self.sequence.type.size is not None:
# C array iteration will be transformed later on # C array iteration will be transformed later on
pass self.type = self.sequence.type
else: else:
self.sequence = self.sequence.coerce_to_pyobject(env) self.sequence = self.sequence.coerce_to_pyobject(env)
self.is_temp = 1 self.is_temp = 1
...@@ -1686,6 +1685,8 @@ class IteratorNode(ExprNode): ...@@ -1686,6 +1685,8 @@ class IteratorNode(ExprNode):
code.funcstate.release_temp(self.counter_cname) code.funcstate.release_temp(self.counter_cname)
def generate_result_code(self, code): def generate_result_code(self, code):
if self.sequence.type.is_array or self.sequence.type.is_ptr:
raise InternalError("for in carray slice not transformed")
is_builtin_sequence = self.sequence.type is list_type or \ is_builtin_sequence = self.sequence.type is list_type or \
self.sequence.type is tuple_type self.sequence.type is tuple_type
may_be_a_sequence = is_builtin_sequence or not self.sequence.type.is_builtin_type may_be_a_sequence = is_builtin_sequence or not self.sequence.type.is_builtin_type
...@@ -1733,6 +1734,8 @@ class NextNode(AtomicExprNode): ...@@ -1733,6 +1734,8 @@ class NextNode(AtomicExprNode):
def __init__(self, iterator, env): def __init__(self, iterator, env):
self.pos = iterator.pos self.pos = iterator.pos
self.iterator = iterator self.iterator = iterator
if iterator.type.is_ptr or iterator.type.is_array:
self.type = iterator.type.base_type
self.is_temp = 1 self.is_temp = 1
def generate_result_code(self, code): def generate_result_code(self, code):
...@@ -2008,6 +2011,7 @@ class IndexNode(ExprNode): ...@@ -2008,6 +2011,7 @@ class IndexNode(ExprNode):
return return
is_slice = isinstance(self.index, SliceNode) is_slice = isinstance(self.index, SliceNode)
# Potentially overflowing index value.
if not is_slice and isinstance(self.index, IntNode) and Utils.long_literal(self.index.value): if not is_slice and isinstance(self.index, IntNode) and Utils.long_literal(self.index.value):
self.index = self.index.coerce_to_pyobject(env) self.index = self.index.coerce_to_pyobject(env)
...@@ -2092,7 +2096,9 @@ class IndexNode(ExprNode): ...@@ -2092,7 +2096,9 @@ class IndexNode(ExprNode):
else: else:
if base_type.is_ptr or base_type.is_array: if base_type.is_ptr or base_type.is_array:
self.type = base_type.base_type self.type = base_type.base_type
if self.index.type.is_pyobject: if is_slice:
self.type = base_type
elif self.index.type.is_pyobject:
self.index = self.index.coerce_to( self.index = self.index.coerce_to(
PyrexTypes.c_py_ssize_t_type, env) PyrexTypes.c_py_ssize_t_type, env)
elif not self.index.type.is_int: elif not self.index.type.is_int:
...@@ -2147,6 +2153,8 @@ class IndexNode(ExprNode): ...@@ -2147,6 +2153,8 @@ class IndexNode(ExprNode):
return "PyTuple_GET_ITEM(%s, %s)" % (self.base.result(), self.index.result()) return "PyTuple_GET_ITEM(%s, %s)" % (self.base.result(), self.index.result())
elif self.base.type is unicode_type and self.type is PyrexTypes.c_py_unicode_type: elif self.base.type is unicode_type and self.type is PyrexTypes.c_py_unicode_type:
return "PyUnicode_AS_UNICODE(%s)[%s]" % (self.base.result(), self.index.result()) return "PyUnicode_AS_UNICODE(%s)[%s]" % (self.base.result(), self.index.result())
elif (self.type.is_ptr or self.type.is_array) and self.type == self.base.type:
error(self.pos, "Invalid use of pointer slice")
else: else:
return "(%s[%s])" % ( return "(%s[%s])" % (
self.base.result(), self.index.result()) self.base.result(), self.index.result())
...@@ -2401,7 +2409,9 @@ class SliceIndexNode(ExprNode): ...@@ -2401,7 +2409,9 @@ class SliceIndexNode(ExprNode):
base_type = self.base.type base_type = self.base.type
if base_type.is_string: if base_type.is_string:
self.type = bytes_type self.type = bytes_type
elif base_type.is_array or base_type.is_ptr: elif base_type.is_ptr:
self.type = base_type
elif base_type.is_array:
# we need a ptr type here instead of an array type, as # we need a ptr type here instead of an array type, as
# array types can result in invalid type casts in the C # array types can result in invalid type casts in the C
# code # code
...@@ -6027,13 +6037,9 @@ class CmpNode(object): ...@@ -6027,13 +6037,9 @@ class CmpNode(object):
def is_ptr_contains(self): def is_ptr_contains(self):
if self.operator in ('in', 'not_in'): if self.operator in ('in', 'not_in'):
iterator = self.operand2 container_type = self.operand2.type
if iterator.type.is_ptr or iterator.type.is_array: return (container_type.is_ptr or container_type.is_array) \
return iterator.type.base_type is not PyrexTypes.c_char_type and not container_type.is_string
if (isinstance(iterator, IndexNode) and
isinstance(iterator.index, (SliceNode, CoerceFromPyTypeNode)) and
(iterator.base.type.is_array or iterator.base.type.is_ptr)):
return iterator.base.type.base_type is not PyrexTypes.c_char_type
def generate_operation_code(self, code, result_code, def generate_operation_code(self, code, result_code,
operand1, op , operand2): operand1, op , operand2):
......
...@@ -4295,9 +4295,6 @@ class ForInStatNode(LoopNode, StatNode): ...@@ -4295,9 +4295,6 @@ class ForInStatNode(LoopNode, StatNode):
self.target.analyse_target_types(env) self.target.analyse_target_types(env)
self.iterator.analyse_expressions(env) self.iterator.analyse_expressions(env)
self.item = ExprNodes.NextNode(self.iterator, env) self.item = ExprNodes.NextNode(self.iterator, env)
if not self.target.type.assignable_from(self.item.type) and \
(self.iterator.sequence.type.is_ptr or self.iterator.sequence.type.is_array):
self.item.type = self.iterator.sequence.type.base_type
self.item = self.item.coerce_to(self.target.type, env) self.item = self.item.coerce_to(self.target.type, env)
self.body.analyse_expressions(env) self.body.analyse_expressions(env)
if self.else_clause: if self.else_clause:
......
...@@ -146,16 +146,13 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -146,16 +146,13 @@ class IterationTransform(Visitor.VisitorTransform):
node, dict_obj=iterator, keys=True, values=False) node, dict_obj=iterator, keys=True, values=False)
# C array (slice) iteration? # C array (slice) iteration?
plain_iterator = unwrap_coerced_node(iterator) if False:
if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \ plain_iterator = unwrap_coerced_node(iterator)
(plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr): if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \
return self._transform_carray_iteration(node, plain_iterator) (plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr):
if isinstance(plain_iterator, ExprNodes.IndexNode) and \
isinstance(plain_iterator.index, (ExprNodes.SliceNode, ExprNodes.CoerceFromPyTypeNode)):
iterator_base = unwrap_coerced_node(plain_iterator.base)
if iterator_base.type.is_array or iterator_base.type.is_ptr:
return self._transform_carray_iteration(node, plain_iterator) return self._transform_carray_iteration(node, plain_iterator)
if iterator.type.is_array:
if iterator.type.is_ptr or iterator.type.is_array:
return self._transform_carray_iteration(node, iterator) return self._transform_carray_iteration(node, iterator)
if iterator.type in (Builtin.bytes_type, Builtin.unicode_type): if iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
return self._transform_string_iteration(node, iterator) return self._transform_string_iteration(node, iterator)
...@@ -220,7 +217,7 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -220,7 +217,7 @@ class IterationTransform(Visitor.VisitorTransform):
def _transform_string_iteration(self, node, slice_node): def _transform_string_iteration(self, node, slice_node):
if not node.target.type.is_int: if not node.target.type.is_int:
return node return self._transform_carray_iteration(node, slice_node)
if slice_node.type is Builtin.unicode_type: if slice_node.type is Builtin.unicode_type:
unpack_func = "PyUnicode_AS_UNICODE" unpack_func = "PyUnicode_AS_UNICODE"
len_func = "PyUnicode_GET_SIZE" len_func = "PyUnicode_GET_SIZE"
...@@ -270,11 +267,13 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -270,11 +267,13 @@ class IterationTransform(Visitor.VisitorTransform):
stop = slice_node.stop stop = slice_node.stop
step = None step = None
if not stop: if not stop:
if not slice_base.type.is_pyobject:
error(slice_node.pos, "C array iteration requires known end index")
return node return node
elif isinstance(slice_node, ExprNodes.IndexNode): elif isinstance(slice_node, ExprNodes.IndexNode):
# slice_node.index must be a SliceNode # slice_node.index must be a SliceNode
slice_base = unwrap_coerced_node(slice_node.base) slice_base = slice_node.base
index = unwrap_coerced_node(slice_node.index) index = slice_node.index
start = index.start start = index.start
stop = index.stop stop = index.stop
step = index.step step = index.step
...@@ -285,7 +284,8 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -285,7 +284,8 @@ class IterationTransform(Visitor.VisitorTransform):
or step.constant_result == 0 \ or step.constant_result == 0 \
or step.constant_result > 0 and not stop \ or step.constant_result > 0 and not stop \
or step.constant_result < 0 and not start: or step.constant_result < 0 and not start:
error(step.pos, "C array iteration requires known step size and end index") if not slice_base.type.is_pyobject:
error(step.pos, "C array iteration requires known step size and end index")
return node return node
else: else:
# step sign is handled internally by ForFromStatNode # step sign is handled internally by ForFromStatNode
...@@ -293,14 +293,20 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -293,14 +293,20 @@ class IterationTransform(Visitor.VisitorTransform):
step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type, step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
value=abs(step.constant_result), value=abs(step.constant_result),
constant_result=abs(step.constant_result)) constant_result=abs(step.constant_result))
elif slice_node.type.is_array and slice_node.type.size is not None: elif slice_node.type.is_array:
if slice_node.type.size is None:
error(step.pos, "C array iteration requires known end index")
return node
slice_base = slice_node slice_base = slice_node
start = None start = None
stop = ExprNodes.IntNode( stop = ExprNodes.IntNode(
slice_node.pos, value=str(slice_node.type.size), slice_node.pos, value=str(slice_node.type.size),
type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size) type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
step = None step = None
else: else:
if not slice_node.type.is_pyobject:
error(slice_node.pos, "Invalid C array iteration")
return node return node
if start: if start:
......
...@@ -44,7 +44,7 @@ def void_ptr_slice(py_x, L, int a, int b): ...@@ -44,7 +44,7 @@ def void_ptr_slice(py_x, L, int a, int b):
L_c[i] = <void*>L[i] L_c[i] = <void*>L[i]
assert (x in L_c[:b]) == (py_x in L[:b]) assert (x in L_c[:b]) == (py_x in L[:b])
assert (x in L_c[a:b]) == (py_x in L[a:b]) assert (x in L_c[a:b]) == (py_x in L[a:b])
# assert (x in L_c[a:b:2]) == (py_x in L[a:b:2]) assert (x in L_c[a:b:2]) == (py_x in L[a:b:2])
finally: finally:
free(L_c) free(L_c)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment