Commit 7eddbcf4 authored by Stefan Behnel's avatar Stefan Behnel

optimise sorted([listcomp]) by sorting in-place

parent 70c5f657
...@@ -1387,35 +1387,42 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1387,35 +1387,42 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all') expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
def _handle_simple_function_sorted(self, node, pos_args): def _handle_simple_function_sorted(self, node, pos_args):
"""Transform sorted(genexpr) into [listcomp].sort(). CPython """Transform sorted(genexpr) and sorted([listcomp]) into
just reads the iterable into a list and calls .sort() on it. [listcomp].sort(). CPython just reads the iterable into a
Expanding the iterable in a listcomp is still faster. list and calls .sort() on it. Expanding the iterable in a
listcomp is still faster and the result can be sorted in
place.
""" """
if len(pos_args) != 1: if len(pos_args) != 1:
return node return node
if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): if isinstance(pos_args[0], ExprNodes.ComprehensionNode) \
return node and pos_args[0].target.type is Builtin.list_type:
gen_expr_node = pos_args[0] listcomp_node = pos_args[0]
loop_node = gen_expr_node.loop loop_node = listcomp_node.loop
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node) elif isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
if yield_expression is None: gen_expr_node = pos_args[0]
return node loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
if yield_expression is None:
return node
result_node = UtilNodes.ResultRefNode( target = ExprNodes.ListNode(node.pos, args = [])
pos = loop_node.pos, type = Builtin.list_type, may_hold_none=False) append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos, expr = yield_expression,
target = ExprNodes.CloneNode(target))
target = ExprNodes.ListNode(node.pos, args = []) Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos, expr = yield_expression,
target = ExprNodes.CloneNode(target))
Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node) listcomp_node = ExprNodes.ComprehensionNode(
gen_expr_node.pos, loop = loop_node, target = target,
append = append_node, type = Builtin.list_type,
expr_scope = gen_expr_node.expr_scope,
has_local_scope = True)
else:
return node
listcomp_node = ExprNodes.ComprehensionNode( result_node = UtilNodes.ResultRefNode(
gen_expr_node.pos, loop = loop_node, target = target, pos = loop_node.pos, type = Builtin.list_type, may_hold_none=False)
append = append_node, type = Builtin.list_type,
expr_scope = gen_expr_node.expr_scope,
has_local_scope = True)
listcomp_assign_node = Nodes.SingleAssignmentNode( listcomp_assign_node = Nodes.SingleAssignmentNode(
node.pos, lhs = result_node, rhs = listcomp_node, first = True) node.pos, lhs = result_node, rhs = listcomp_node, first = True)
......
cimport cython
def smoketest(): def smoketest():
""" """
>>> smoketest() >>> smoketest()
...@@ -76,3 +78,12 @@ def listcomp_as_condition(sequence): ...@@ -76,3 +78,12 @@ def listcomp_as_condition(sequence):
if [1 for c in sequence if c in '+-*/<=>!%&|([^~,']: if [1 for c in sequence if c in '+-*/<=>!%&|([^~,']:
return True return True
return False return False
@cython.test_fail_if_path_exists("//SimpleCallNode//ComprehensionNode")
@cython.test_assert_path_exists("//ComprehensionNode")
def sorted_listcomp(sequence):
"""
>>> sorted_listcomp([3,2,4])
[3, 4, 5]
"""
return sorted([ n+1 for n in sequence ])
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment