Commit 00dd1dfb authored by Stefan Behnel's avatar Stefan Behnel

optimised for-in-reversed(array) etc., including char*, bytes and unicode

parent c7e5af8b
...@@ -160,15 +160,9 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -160,15 +160,9 @@ class IterationTransform(Visitor.VisitorTransform):
return self._transform_carray_iteration(node, plain_iterator) return self._transform_carray_iteration(node, plain_iterator)
if iterator.type.is_ptr or iterator.type.is_array: if iterator.type.is_ptr or iterator.type.is_array:
if reversed: return self._transform_carray_iteration(node, iterator, reversed=reversed)
# TODO: implement
return node
return self._transform_carray_iteration(node, iterator)
if iterator.type in (Builtin.bytes_type, Builtin.unicode_type): if iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
if reversed: return self._transform_string_iteration(node, iterator, reversed=reversed)
# TODO: implement
return node
return self._transform_string_iteration(node, iterator)
# the rest is based on function calls # the rest is based on function calls
if not isinstance(iterator, ExprNodes.SimpleCallNode): if not isinstance(iterator, ExprNodes.SimpleCallNode):
...@@ -253,7 +247,7 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -253,7 +247,7 @@ class IterationTransform(Visitor.VisitorTransform):
PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None) PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
]) ])
def _transform_string_iteration(self, node, slice_node): def _transform_string_iteration(self, node, slice_node, reversed=False):
if not node.target.type.is_int: if not node.target.type.is_int:
return self._transform_carray_iteration(node, slice_node) return self._transform_carray_iteration(node, slice_node)
if slice_node.type is Builtin.unicode_type: if slice_node.type is Builtin.unicode_type:
...@@ -295,9 +289,10 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -295,9 +289,10 @@ class IterationTransform(Visitor.VisitorTransform):
stop = len_node, stop = len_node,
type = slice_base_node.type, type = slice_base_node.type,
is_temp = 1, is_temp = 1,
))) ),
reversed = reversed))
def _transform_carray_iteration(self, node, slice_node): def _transform_carray_iteration(self, node, slice_node, reversed=False):
neg_step = False neg_step = False
if isinstance(slice_node, ExprNodes.SliceIndexNode): if isinstance(slice_node, ExprNodes.SliceIndexNode):
slice_base = slice_node.base slice_base = slice_node.base
...@@ -327,10 +322,13 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -327,10 +322,13 @@ class IterationTransform(Visitor.VisitorTransform):
return node return node
else: else:
# step sign is handled internally by ForFromStatNode # step sign is handled internally by ForFromStatNode
neg_step = step.constant_result < 0 step_value = step.constant_result
if reversed:
step_value = -step_value
neg_step = step_value < 0
step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type, step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
value=abs(step.constant_result), value=str(abs(step_value)),
constant_result=abs(step.constant_result)) constant_result=abs(step_value))
elif slice_node.type.is_array: elif slice_node.type.is_array:
if slice_node.type.size is None: if slice_node.type.size is None:
error(step.pos, "C array iteration requires known end index") error(step.pos, "C array iteration requires known end index")
...@@ -365,6 +363,13 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -365,6 +363,13 @@ class IterationTransform(Visitor.VisitorTransform):
error(slice_node.pos, "C array iteration requires known step size and end index") error(slice_node.pos, "C array iteration requires known step size and end index")
return node return node
if reversed:
if not start:
start = ExprNodes.IntNode(slice_node.pos, value="0", constant_result=0,
type=PyrexTypes.c_py_ssize_t_type)
# if step was provided, it was already negated above
start, stop = stop, start
ptr_type = slice_base.type ptr_type = slice_base.type
if ptr_type.is_array: if ptr_type.is_array:
ptr_type = ptr_type.element_ptr_type() ptr_type = ptr_type.element_ptr_type()
...@@ -380,6 +385,7 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -380,6 +385,7 @@ class IterationTransform(Visitor.VisitorTransform):
else: else:
start_ptr_node = carray_ptr start_ptr_node = carray_ptr
if stop and stop.constant_result != 0:
stop_ptr_node = ExprNodes.AddNode( stop_ptr_node = ExprNodes.AddNode(
stop.pos, stop.pos,
operand1=ExprNodes.CloneNode(carray_ptr), operand1=ExprNodes.CloneNode(carray_ptr),
...@@ -387,6 +393,8 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -387,6 +393,8 @@ class IterationTransform(Visitor.VisitorTransform):
operand2=stop, operand2=stop,
type=ptr_type type=ptr_type
).coerce_to_simple(self.current_scope) ).coerce_to_simple(self.current_scope)
else:
stop_ptr_node = ExprNodes.CloneNode(carray_ptr)
counter = UtilNodes.TempHandle(ptr_type) counter = UtilNodes.TempHandle(ptr_type)
counter_temp = counter.ref(node.target.pos) counter_temp = counter.ref(node.target.pos)
...@@ -430,11 +438,13 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -430,11 +438,13 @@ class IterationTransform(Visitor.VisitorTransform):
node.pos, node.pos,
stats = [target_assign, node.body]) stats = [target_assign, node.body])
relation1, relation2 = self._find_for_from_node_relations(neg_step, reversed)
for_node = Nodes.ForFromStatNode( for_node = Nodes.ForFromStatNode(
node.pos, node.pos,
bound1=start_ptr_node, relation1=neg_step and '>=' or '<=', bound1=start_ptr_node, relation1=relation1,
target=counter_temp, target=counter_temp,
relation2=neg_step and '>' or '<', bound2=stop_ptr_node, relation2=relation2, bound2=stop_ptr_node,
step=step, body=body, step=step, body=body,
else_clause=node.else_clause, else_clause=node.else_clause,
from_range=True) from_range=True)
...@@ -511,7 +521,19 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -511,7 +521,19 @@ class IterationTransform(Visitor.VisitorTransform):
node.iterator.sequence = enumerate_function.arg_tuple.args[0] node.iterator.sequence = enumerate_function.arg_tuple.args[0]
# recurse into loop to check for further optimisations # recurse into loop to check for further optimisations
return UtilNodes.LetNode(temp, self._optimise_for_loop(node)) return UtilNodes.LetNode(temp, self._optimise_for_loop(node, node.iterator.sequence))
def _find_for_from_node_relations(self, neg_step_value, reversed):
if reversed:
if neg_step_value:
return '<', '<='
else:
return '>', '>='
else:
if neg_step_value:
return '>=', '>'
else:
return '<=', '<'
def _transform_range_iteration(self, node, range_function, reversed=False): def _transform_range_iteration(self, node, range_function, reversed=False):
args = range_function.arg_tuple.args args = range_function.arg_tuple.args
...@@ -542,23 +564,15 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -542,23 +564,15 @@ class IterationTransform(Visitor.VisitorTransform):
bound1 = args[0].coerce_to_integer(self.current_scope) bound1 = args[0].coerce_to_integer(self.current_scope)
bound2 = args[1].coerce_to_integer(self.current_scope) bound2 = args[1].coerce_to_integer(self.current_scope)
relation1, relation2 = self._find_for_from_node_relations(step_value < 0, reversed)
if reversed: if reversed:
bound1, bound2 = bound2, bound1 bound1, bound2 = bound2, bound1
if step_value < 0: if step_value < 0:
step_value = -step_value step_value = -step_value
relation1 = '<'
relation2 = '<='
else:
relation1 = '>'
relation2 = '>='
else: else:
if step_value < 0: if step_value < 0:
step_value = -step_value step_value = -step_value
relation1 = '>='
relation2 = '>'
else:
relation1 = '<='
relation2 = '<'
step.value = str(step_value) step.value = str(step_value)
step.constant_result = step_value step.constant_result = step_value
......
...@@ -82,6 +82,9 @@ class MarkAssignments(CythonTransform): ...@@ -82,6 +82,9 @@ class MarkAssignments(CythonTransform):
'+', '+',
sequence.args[0], sequence.args[0],
sequence.args[2])) sequence.args[2]))
elif function.name == 'reversed' and len(sequence.args) == 1:
sequence = sequence.args[0]
if not is_special: if not is_special:
# A for-loop basically translates to subsequent calls to # A for-loop basically translates to subsequent calls to
# __getitem__(), so using an IndexNode here allows us to # __getitem__(), so using an IndexNode here allows us to
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment