Commit c7e5af8b authored by Stefan Behnel's avatar Stefan Behnel

start optimising for-in-reversed(): unpack reversed(range(...))

parent 8e7816d2
......@@ -141,12 +141,14 @@ class IterationTransform(Visitor.VisitorTransform):
def visit_ForInStatNode(self, node):
self.visitchildren(node)
return self._optimise_for_loop(node)
return self._optimise_for_loop(node, node.iterator.sequence)
def _optimise_for_loop(self, node):
iterator = node.iterator.sequence
def _optimise_for_loop(self, node, iterator, reversed=False):
if iterator.type is Builtin.dict_type:
# like iterating over dict.keys()
if reversed:
# (reversed) dict iteration uses arbitrary order
return node
return self._transform_dict_iteration(
node, dict_obj=iterator, keys=True, values=False)
......@@ -158,8 +160,14 @@ class IterationTransform(Visitor.VisitorTransform):
return self._transform_carray_iteration(node, plain_iterator)
if iterator.type.is_ptr or iterator.type.is_array:
if reversed:
# TODO: implement
return node
return self._transform_carray_iteration(node, iterator)
if iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
if reversed:
# TODO: implement
return node
return self._transform_string_iteration(node, iterator)
# the rest is based on function calls
......@@ -170,6 +178,9 @@ class IterationTransform(Visitor.VisitorTransform):
# dict iteration?
if isinstance(function, ExprNodes.AttributeNode) and \
function.obj.type == Builtin.dict_type:
if reversed:
# (reversed) dict iteration uses arbitrary order
return node
dict_obj = function.obj
method = function.attribute
......@@ -186,21 +197,42 @@ class IterationTransform(Visitor.VisitorTransform):
return self._transform_dict_iteration(
node, dict_obj, keys, values)
# enumerate() ?
# enumerate/reversed ?
if iterator.self is None and function.is_name and \
function.entry and function.entry.is_builtin and \
function.name == 'enumerate':
function.entry and function.entry.is_builtin:
if function.name == 'enumerate':
if reversed:
# TODO: implement
return node
return self._transform_enumerate_iteration(node, iterator)
elif function.name == 'reversed':
if reversed:
# it is not safe to short-cut here due to evaluation rules,
# but this case is unlikely enough to just ignore it
return node
return self._transform_reversed_iteration(node, iterator)
# range() iteration?
if Options.convert_range and node.target.type.is_int:
if iterator.self is None and function.is_name and \
function.entry and function.entry.is_builtin and \
function.name in ('range', 'xrange'):
return self._transform_range_iteration(node, iterator)
return self._transform_range_iteration(node, iterator, reversed=reversed)
return node
def _transform_reversed_iteration(self, node, reversed_function):
args = reversed_function.arg_tuple.args
if len(args) == 0:
error(reversed_function.pos,
"reversed() requires an iterable argument")
return node
elif len(args) > 1:
error(reversed_function.pos,
"reversed() takes exactly 1 argument")
return node
return self._optimise_for_loop(node, args[0], reversed=True)
PyUnicode_AS_UNICODE_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_py_unicode_ptr_type, [
PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
......@@ -481,7 +513,7 @@ class IterationTransform(Visitor.VisitorTransform):
# recurse into loop to check for further optimisations
return UtilNodes.LetNode(temp, self._optimise_for_loop(node))
def _transform_range_iteration(self, node, range_function):
def _transform_range_iteration(self, node, range_function, reversed=False):
args = range_function.arg_tuple.args
if len(args) < 3:
step_pos = range_function.pos
......@@ -502,14 +534,6 @@ class IterationTransform(Visitor.VisitorTransform):
step = ExprNodes.IntNode(step_pos, value=str(step_value),
constant_result=step_value)
if step_value < 0:
step.value = str(-step_value)
relation1 = '>='
relation2 = '>'
else:
relation1 = '<='
relation2 = '<'
if len(args) == 1:
bound1 = ExprNodes.IntNode(range_function.pos, value='0',
constant_result=0)
......@@ -517,6 +541,27 @@ class IterationTransform(Visitor.VisitorTransform):
else:
bound1 = args[0].coerce_to_integer(self.current_scope)
bound2 = args[1].coerce_to_integer(self.current_scope)
if reversed:
bound1, bound2 = bound2, bound1
if step_value < 0:
step_value = -step_value
relation1 = '<'
relation2 = '<='
else:
relation1 = '>'
relation2 = '>='
else:
if step_value < 0:
step_value = -step_value
relation1 = '>='
relation2 = '>'
else:
relation1 = '<='
relation2 = '<'
step.value = str(step_value)
step.constant_result = step_value
step = step.coerce_to_integer(self.current_scope)
if not bound2.is_literal:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment