Commit 7eddbcf4 authored by Stefan Behnel's avatar Stefan Behnel

optimise sorted([listcomp]) by sorting in-place

parent 70c5f657
......@@ -1387,23 +1387,25 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
def _handle_simple_function_sorted(self, node, pos_args):
"""Transform sorted(genexpr) into [listcomp].sort(). CPython
just reads the iterable into a list and calls .sort() on it.
Expanding the iterable in a listcomp is still faster.
"""Transform sorted(genexpr) and sorted([listcomp]) into
[listcomp].sort(). CPython just reads the iterable into a
list and calls .sort() on it. Expanding the iterable in a
listcomp is still faster and the result can be sorted in
place.
"""
if len(pos_args) != 1:
return node
if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
return node
if isinstance(pos_args[0], ExprNodes.ComprehensionNode) \
and pos_args[0].target.type is Builtin.list_type:
listcomp_node = pos_args[0]
loop_node = listcomp_node.loop
elif isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
if yield_expression is None:
return node
result_node = UtilNodes.ResultRefNode(
pos = loop_node.pos, type = Builtin.list_type, may_hold_none=False)
target = ExprNodes.ListNode(node.pos, args = [])
append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos, expr = yield_expression,
......@@ -1416,6 +1418,11 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
append = append_node, type = Builtin.list_type,
expr_scope = gen_expr_node.expr_scope,
has_local_scope = True)
else:
return node
result_node = UtilNodes.ResultRefNode(
pos = loop_node.pos, type = Builtin.list_type, may_hold_none=False)
listcomp_assign_node = Nodes.SingleAssignmentNode(
node.pos, lhs = result_node, rhs = listcomp_node, first = True)
......
cimport cython
def smoketest():
"""
>>> smoketest()
......@@ -76,3 +78,12 @@ def listcomp_as_condition(sequence):
if [1 for c in sequence if c in '+-*/<=>!%&|([^~,']:
return True
return False
@cython.test_fail_if_path_exists("//SimpleCallNode//ComprehensionNode")
@cython.test_assert_path_exists("//ComprehensionNode")
def sorted_listcomp(sequence):
"""
>>> sorted_listcomp([3,2,4])
[3, 4, 5]
"""
return sorted([ n+1 for n in sequence ])
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment