Commit 82afcb57 authored by Stefan Behnel's avatar Stefan Behnel

optimise sum([int_const for ...]) into an inlined sum(genexpr)

parent 8caa6c6a
...@@ -1339,14 +1339,26 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1339,14 +1339,26 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
""" """
if len(pos_args) not in (1,2): if len(pos_args) not in (1,2):
return node return node
if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode,
ExprNodes.ComprehensionNode)):
return node return node
gen_expr_node = pos_args[0] gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node) yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
if yield_expression is None: if yield_expression is None:
return node return node
else: # ComprehensionNode
yield_stat_node = gen_expr_node.append
yield_expression = yield_stat_node.expr
try:
if not yield_expression.is_literal or not yield_expression.type.is_int:
return node
except AttributeError:
return node # in case we don't have a type yet
# special case: old Py2 backwards compatible "sum([int_const for ...])"
# can safely be unpacked into a genexpr
if len(pos_args) == 1: if len(pos_args) == 1:
start = ExprNodes.IntNode(node.pos, value='0', constant_result=0) start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
...@@ -1375,7 +1387,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1375,7 +1387,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
return ExprNodes.InlinedGeneratorExpressionNode( return ExprNodes.InlinedGeneratorExpressionNode(
gen_expr_node.pos, loop = exec_code, result_node = result_ref, gen_expr_node.pos, loop = exec_code, result_node = result_ref,
expr_scope = gen_expr_node.expr_scope, orig_func = 'sum') expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
has_local_scope = gen_expr_node.has_local_scope)
def _handle_simple_function_min(self, node, pos_args): def _handle_simple_function_min(self, node, pos_args):
return self._optimise_min_max(node, pos_args, '<') return self._optimise_min_max(node, pos_args, '<')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment