Commit dae3f8f1 authored by Stefan Behnel's avatar Stefan Behnel

fix optimised iteration over sliced C arrays with given step size

parent 67ebd9f7
...@@ -30,6 +30,11 @@ class FakePythonEnv(object): ...@@ -30,6 +30,11 @@ class FakePythonEnv(object):
"A fake environment for creating type test nodes etc." "A fake environment for creating type test nodes etc."
nogil = False nogil = False
def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
if isinstance(node, coercion_nodes):
return node.arg
return node
def unwrap_node(node): def unwrap_node(node):
while isinstance(node, UtilNodes.ResultRefNode): while isinstance(node, UtilNodes.ResultRefNode):
node = node.expression node = node.expression
...@@ -90,19 +95,18 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -90,19 +95,18 @@ class IterationTransform(Visitor.VisitorTransform):
node, dict_obj=iterator, keys=True, values=False) node, dict_obj=iterator, keys=True, values=False)
# C array (slice) iteration? # C array (slice) iteration?
plain_iterator = iterator plain_iterator = unwrap_coerced_node(iterator)
if isinstance(iterator, ExprNodes.CoerceToPyTypeNode):
plain_iterator = iterator.arg
if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \ if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \
(plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr): (plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr):
return self._transform_carray_iteration(node, plain_iterator) return self._transform_carray_iteration(node, plain_iterator)
elif isinstance(plain_iterator, ExprNodes.IndexNode) and \ if isinstance(plain_iterator, ExprNodes.IndexNode) and \
isinstance(plain_iterator.index, (ExprNodes.SliceNode, ExprNodes.CoerceFromPyTypeNode)) and \ isinstance(plain_iterator.index, (ExprNodes.SliceNode, ExprNodes.CoerceFromPyTypeNode)):
(plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr): iterator_base = unwrap_coerced_node(plain_iterator.base)
if iterator_base.type.is_array or iterator_base.type.is_ptr:
return self._transform_carray_iteration(node, plain_iterator) return self._transform_carray_iteration(node, plain_iterator)
elif iterator.type.is_array: if iterator.type.is_array:
return self._transform_carray_iteration(node, iterator) return self._transform_carray_iteration(node, iterator)
elif iterator.type in (Builtin.bytes_type, Builtin.unicode_type): if iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
return self._transform_string_iteration(node, iterator) return self._transform_string_iteration(node, iterator)
# the rest is based on function calls # the rest is based on function calls
...@@ -218,10 +222,8 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -218,10 +222,8 @@ class IterationTransform(Visitor.VisitorTransform):
return node return node
elif isinstance(slice_node, ExprNodes.IndexNode): elif isinstance(slice_node, ExprNodes.IndexNode):
# slice_node.index must be a SliceNode # slice_node.index must be a SliceNode
slice_base = slice_node.base slice_base = unwrap_coerced_node(slice_node.base)
index = slice_node.index index = unwrap_coerced_node(slice_node.index)
if isinstance(index, ExprNodes.CoerceFromPyTypeNode):
index = index.arg
start = index.start start = index.start
stop = index.stop stop = index.stop
step = index.step step = index.step
...@@ -260,6 +262,13 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -260,6 +262,13 @@ class IterationTransform(Visitor.VisitorTransform):
stop = None stop = None
else: else:
stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope) stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
if stop is None:
if neg_step:
stop = ExprNodes.IntNode(
slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
else:
error(slice_node.pos, "C array iteration requires known step size and end index")
return node
ptr_type = slice_base.type ptr_type = slice_base.type
if ptr_type.is_array: if ptr_type.is_array:
......
...@@ -48,23 +48,40 @@ def slice_charptr_for_loop_c(): ...@@ -48,23 +48,40 @@ def slice_charptr_for_loop_c():
def slice_charptr_for_loop_c_step(): def slice_charptr_for_loop_c_step():
""" """
>>> slice_charptr_for_loop_c_step() >>> slice_charptr_for_loop_c_step()
['p', 't', 'q', 'C', 'B'] Acba
['p', 't', 'q', 'C', 'B'] ['A', 'c', 'b', 'a']
Acba
['A', 'c', 'b', 'a']
bA
['b', 'A'] ['b', 'A']
acB
['a', 'c', 'B'] ['a', 'c', 'B']
acB
['a', 'c', 'B'] ['a', 'c', 'B']
<BLANKLINE>
[] []
ptqC
['p', 't', 'q', 'C'] ['p', 't', 'q', 'C']
pq
['p', 'q'] ['p', 'q']
""" """
cdef unicode ustring = cstring.decode('ASCII')
cdef char c cdef char c
print [ chr(c) for c in cstring[:3:-1] ] print ustring[3::-1]
print [ chr(c) for c in cstring[None:3:-1] ] print [ chr(c) for c in cstring[3::-1] ]
print ustring[3:None:-1]
print [ chr(c) for c in cstring[3:None:-1] ]
print ustring[1:5:2]
print [ chr(c) for c in cstring[1:5:2] ] print [ chr(c) for c in cstring[1:5:2] ]
print ustring[:5:2]
print [ chr(c) for c in cstring[:5:2] ] print [ chr(c) for c in cstring[:5:2] ]
print ustring[None:5:2]
print [ chr(c) for c in cstring[None:5:2] ] print [ chr(c) for c in cstring[None:5:2] ]
print ustring[4:9:-1]
print [ chr(c) for c in cstring[4:9:-1] ] print [ chr(c) for c in cstring[4:9:-1] ]
print ustring[8:4:-1]
print [ chr(c) for c in cstring[8:4:-1] ] print [ chr(c) for c in cstring[8:4:-1] ]
print ustring[8:4:-2]
print [ chr(c) for c in cstring[8:4:-2] ] print [ chr(c) for c in cstring[8:4:-2] ]
@cython.test_assert_path_exists("//ForFromStatNode", @cython.test_assert_path_exists("//ForFromStatNode",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment