Commit 082aee71 authored by Robert Bradshaw's avatar Robert Bradshaw

Cleanup slice iteration code.

parent 0718018a
......@@ -1667,11 +1667,10 @@ class IteratorNode(ExprNode):
def analyse_types(self, env):
self.sequence.analyse_types(env)
if isinstance(self.sequence, SliceIndexNode) and \
(self.sequence.base.type.is_array or self.sequence.base.type.is_ptr) \
or self.sequence.type.is_array and self.sequence.type.size is not None:
if (self.sequence.type.is_array or self.sequence.type.is_ptr) and \
not self.sequence.type.is_string:
# C array iteration will be transformed later on
pass
self.type = self.sequence.type
else:
self.sequence = self.sequence.coerce_to_pyobject(env)
self.is_temp = 1
......@@ -1686,6 +1685,8 @@ class IteratorNode(ExprNode):
code.funcstate.release_temp(self.counter_cname)
def generate_result_code(self, code):
if self.sequence.type.is_array or self.sequence.type.is_ptr:
raise InternalError("for in carray slice not transformed")
is_builtin_sequence = self.sequence.type is list_type or \
self.sequence.type is tuple_type
may_be_a_sequence = is_builtin_sequence or not self.sequence.type.is_builtin_type
......@@ -1733,6 +1734,8 @@ class NextNode(AtomicExprNode):
def __init__(self, iterator, env):
self.pos = iterator.pos
self.iterator = iterator
if iterator.type.is_ptr or iterator.type.is_array:
self.type = iterator.type.base_type
self.is_temp = 1
def generate_result_code(self, code):
......@@ -2008,6 +2011,7 @@ class IndexNode(ExprNode):
return
is_slice = isinstance(self.index, SliceNode)
# Potentially overflowing index value.
if not is_slice and isinstance(self.index, IntNode) and Utils.long_literal(self.index.value):
self.index = self.index.coerce_to_pyobject(env)
......@@ -2092,7 +2096,9 @@ class IndexNode(ExprNode):
else:
if base_type.is_ptr or base_type.is_array:
self.type = base_type.base_type
if self.index.type.is_pyobject:
if is_slice:
self.type = base_type
elif self.index.type.is_pyobject:
self.index = self.index.coerce_to(
PyrexTypes.c_py_ssize_t_type, env)
elif not self.index.type.is_int:
......@@ -2147,6 +2153,8 @@ class IndexNode(ExprNode):
return "PyTuple_GET_ITEM(%s, %s)" % (self.base.result(), self.index.result())
elif self.base.type is unicode_type and self.type is PyrexTypes.c_py_unicode_type:
return "PyUnicode_AS_UNICODE(%s)[%s]" % (self.base.result(), self.index.result())
elif (self.type.is_ptr or self.type.is_array) and self.type == self.base.type:
error(self.pos, "Invalid use of pointer slice")
else:
return "(%s[%s])" % (
self.base.result(), self.index.result())
......@@ -2401,7 +2409,9 @@ class SliceIndexNode(ExprNode):
base_type = self.base.type
if base_type.is_string:
self.type = bytes_type
elif base_type.is_array or base_type.is_ptr:
elif base_type.is_ptr:
self.type = base_type
elif base_type.is_array:
# we need a ptr type here instead of an array type, as
# array types can result in invalid type casts in the C
# code
......@@ -6027,13 +6037,9 @@ class CmpNode(object):
def is_ptr_contains(self):
if self.operator in ('in', 'not_in'):
iterator = self.operand2
if iterator.type.is_ptr or iterator.type.is_array:
return iterator.type.base_type is not PyrexTypes.c_char_type
if (isinstance(iterator, IndexNode) and
isinstance(iterator.index, (SliceNode, CoerceFromPyTypeNode)) and
(iterator.base.type.is_array or iterator.base.type.is_ptr)):
return iterator.base.type.base_type is not PyrexTypes.c_char_type
container_type = self.operand2.type
return (container_type.is_ptr or container_type.is_array) \
and not container_type.is_string
def generate_operation_code(self, code, result_code,
operand1, op , operand2):
......
......@@ -4295,9 +4295,6 @@ class ForInStatNode(LoopNode, StatNode):
self.target.analyse_target_types(env)
self.iterator.analyse_expressions(env)
self.item = ExprNodes.NextNode(self.iterator, env)
if not self.target.type.assignable_from(self.item.type) and \
(self.iterator.sequence.type.is_ptr or self.iterator.sequence.type.is_array):
self.item.type = self.iterator.sequence.type.base_type
self.item = self.item.coerce_to(self.target.type, env)
self.body.analyse_expressions(env)
if self.else_clause:
......
......@@ -146,16 +146,13 @@ class IterationTransform(Visitor.VisitorTransform):
node, dict_obj=iterator, keys=True, values=False)
# C array (slice) iteration?
plain_iterator = unwrap_coerced_node(iterator)
if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \
(plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr):
return self._transform_carray_iteration(node, plain_iterator)
if isinstance(plain_iterator, ExprNodes.IndexNode) and \
isinstance(plain_iterator.index, (ExprNodes.SliceNode, ExprNodes.CoerceFromPyTypeNode)):
iterator_base = unwrap_coerced_node(plain_iterator.base)
if iterator_base.type.is_array or iterator_base.type.is_ptr:
if False:
plain_iterator = unwrap_coerced_node(iterator)
if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \
(plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr):
return self._transform_carray_iteration(node, plain_iterator)
if iterator.type.is_array:
if iterator.type.is_ptr or iterator.type.is_array:
return self._transform_carray_iteration(node, iterator)
if iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
return self._transform_string_iteration(node, iterator)
......@@ -220,7 +217,7 @@ class IterationTransform(Visitor.VisitorTransform):
def _transform_string_iteration(self, node, slice_node):
if not node.target.type.is_int:
return node
return self._transform_carray_iteration(node, slice_node)
if slice_node.type is Builtin.unicode_type:
unpack_func = "PyUnicode_AS_UNICODE"
len_func = "PyUnicode_GET_SIZE"
......@@ -270,11 +267,13 @@ class IterationTransform(Visitor.VisitorTransform):
stop = slice_node.stop
step = None
if not stop:
if not slice_base.type.is_pyobject:
error(slice_node.pos, "C array iteration requires known end index")
return node
elif isinstance(slice_node, ExprNodes.IndexNode):
# slice_node.index must be a SliceNode
slice_base = unwrap_coerced_node(slice_node.base)
index = unwrap_coerced_node(slice_node.index)
slice_base = slice_node.base
index = slice_node.index
start = index.start
stop = index.stop
step = index.step
......@@ -285,7 +284,8 @@ class IterationTransform(Visitor.VisitorTransform):
or step.constant_result == 0 \
or step.constant_result > 0 and not stop \
or step.constant_result < 0 and not start:
error(step.pos, "C array iteration requires known step size and end index")
if not slice_base.type.is_pyobject:
error(step.pos, "C array iteration requires known step size and end index")
return node
else:
# step sign is handled internally by ForFromStatNode
......@@ -293,14 +293,20 @@ class IterationTransform(Visitor.VisitorTransform):
step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
value=abs(step.constant_result),
constant_result=abs(step.constant_result))
elif slice_node.type.is_array and slice_node.type.size is not None:
elif slice_node.type.is_array:
if slice_node.type.size is None:
error(step.pos, "C array iteration requires known end index")
return node
slice_base = slice_node
start = None
stop = ExprNodes.IntNode(
slice_node.pos, value=str(slice_node.type.size),
type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
step = None
else:
if not slice_node.type.is_pyobject:
error(slice_node.pos, "Invalid C array iteration")
return node
if start:
......
......@@ -44,7 +44,7 @@ def void_ptr_slice(py_x, L, int a, int b):
L_c[i] = <void*>L[i]
assert (x in L_c[:b]) == (py_x in L[:b])
assert (x in L_c[a:b]) == (py_x in L[a:b])
# assert (x in L_c[a:b:2]) == (py_x in L[a:b:2])
assert (x in L_c[a:b:2]) == (py_x in L[a:b:2])
finally:
free(L_c)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment