Commit 39570c69 authored by Stefan Behnel's avatar Stefan Behnel

implement sorted(genexp) as [listcomp].sort()

parent 134af5de
...@@ -1280,6 +1280,53 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1280,6 +1280,53 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
gen_expr_node.pos, loop = loop_node, result_node = result_ref, gen_expr_node.pos, loop = loop_node, result_node = result_ref,
expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all') expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
def _handle_simple_function_sorted(self, node, pos_args):
"""Transform sorted(genexpr) into [listcomp].sort(). CPython
just reads the iterable into a list and calls .sort() on it.
Expanding the iterable in a listcomp is still faster.
"""
if len(pos_args) != 1:
return node
if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
return node
gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
if yield_expression is None:
return node
result_node = UtilNodes.ResultRefNode(
pos = loop_node.pos, type = Builtin.list_type, may_hold_none=False)
target = ExprNodes.ListNode(node.pos, args = [])
append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos, expr = yield_expression,
target = ExprNodes.CloneNode(target))
Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
listcomp_node = ExprNodes.ComprehensionNode(
gen_expr_node.pos, loop = loop_node, target = target,
append = append_node, type = Builtin.list_type,
expr_scope = gen_expr_node.expr_scope,
has_local_scope = True)
listcomp_assign_node = Nodes.SingleAssignmentNode(
node.pos, lhs = result_node, rhs = listcomp_node, first = True)
sort_method = ExprNodes.AttributeNode(
node.pos, obj = result_node, attribute = EncodedString('sort'),
# entry ? type ?
needs_none_check = False)
sort_node = Nodes.ExprStatNode(
node.pos, expr = ExprNodes.SimpleCallNode(
node.pos, function = sort_method, args = []))
sort_node.analyse_declarations(self.current_env())
return UtilNodes.TempResultFromStatNode(
result_node,
Nodes.StatListNode(node.pos, stats = [ listcomp_assign_node, sort_node ]))
def _handle_simple_function_sum(self, node, pos_args): def _handle_simple_function_sum(self, node, pos_args):
"""Transform sum(genexpr) into an equivalent inlined aggregation loop. """Transform sum(genexpr) into an equivalent inlined aggregation loop.
""" """
......
...@@ -120,9 +120,10 @@ class ResultRefNode(AtomicExprNode): ...@@ -120,9 +120,10 @@ class ResultRefNode(AtomicExprNode):
subexprs = [] subexprs = []
lhs_of_first_assignment = False lhs_of_first_assignment = False
def __init__(self, expression=None, pos=None, type=None): def __init__(self, expression=None, pos=None, type=None, may_hold_none=True):
self.expression = expression self.expression = expression
self.pos = None self.pos = None
self.may_hold_none = may_hold_none
if expression is not None: if expression is not None:
self.pos = expression.pos self.pos = expression.pos
if hasattr(expression, "type"): if hasattr(expression, "type"):
...@@ -141,6 +142,11 @@ class ResultRefNode(AtomicExprNode): ...@@ -141,6 +142,11 @@ class ResultRefNode(AtomicExprNode):
if self.expression is not None: if self.expression is not None:
return self.expression.infer_type(env) return self.expression.infer_type(env)
def may_be_none(self):
if not self.type.is_pyobject:
return False
return self.may_hold_none
def _DISABLED_may_be_none(self): def _DISABLED_may_be_none(self):
# not sure if this is safe - the expression may not be the # not sure if this is safe - the expression may not be the
# only value that gets assigned # only value that gets assigned
......
...@@ -62,6 +62,8 @@ VER_DEP_MODULES = { ...@@ -62,6 +62,8 @@ VER_DEP_MODULES = {
]), ]),
(2,4) : (operator.le, lambda x: x in ['run.extern_builtins_T258' (2,4) : (operator.le, lambda x: x in ['run.extern_builtins_T258'
]), ]),
(2,3) : (operator.le, lambda x: x in ['run.builtin_sorted'
]),
(2,6) : (operator.lt, lambda x: x in ['run.print_function', (2,6) : (operator.lt, lambda x: x in ['run.print_function',
'run.cython3', 'run.cython3',
]), ]),
......
cimport cython
@cython.test_fail_if_path_exists("//GeneratorExpressionNode",
"//ComprehensionNode//NoneCheckNode")
@cython.test_assert_path_exists("//ComprehensionNode")
def sorted_genexp():
"""
>>> sorted_genexp()
[1, 4, 9, 16, 25, 36, 49, 64, 81, 100]
"""
return sorted(i*i for i in range(10,0,-1))
@cython.test_assert_path_exists("//SimpleCallNode//SimpleCallNode")
def sorted_list():
"""
>>> sorted_list()
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
"""
return sorted(list(range(10,0,-1)))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment