Commit 1230a868 authored by Stefan Behnel's avatar Stefan Behnel

generalise genexpr inlining to generators with more than one yield expression

parent 34428ea0
......@@ -100,6 +100,9 @@ class _YieldNodeCollector(Visitor.TreeVisitor):
def visit_LambdaNode(self, node):
pass
def visit_FuncDefNode(self, node):
pass
def _find_single_yield_expression(node):
collector = _YieldNodeCollector()
......@@ -113,6 +116,15 @@ def _find_single_yield_expression(node):
return None, None
def _find_yield_expressions(node):
collector = _YieldNodeCollector()
collector.visitchildren(node)
return [
(yield_node.arg, collector.yield_stat_nodes[yield_node])
for yield_node in collector.yield_nodes
]
class IterationTransform(Visitor.EnvTransform):
"""Transform some common for-in loop patterns into efficient C loops:
......@@ -1606,20 +1618,20 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
elif isinstance(arg, ExprNodes.GeneratorExpressionNode):
gen_expr_node = arg
loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
if yield_expression is None:
yield_expressions = _find_yield_expressions(loop_node)
if not yield_expressions:
return node
list_node = ExprNodes.InlinedGeneratorExpressionNode(
node.pos, gen_expr_node, orig_func='sorted',
comprehension_type=Builtin.list_type)
append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos,
expr=yield_expression,
target=list_node.target)
Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
for yield_expression, yield_stat_node in yield_expressions:
append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos,
expr=yield_expression,
target=list_node.target)
Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
elif arg.is_sequence_constructor:
# sorted([a, b, c]) or sorted((a, b, c)). The result is always a list,
......@@ -1797,8 +1809,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
if yield_expression is None:
yield_expressions = _find_yield_expressions(loop_node)
if not yield_expressions:
return node
result_node = ExprNodes.InlinedGeneratorExpressionNode(
......@@ -1806,12 +1818,13 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
orig_func='set' if target_type is Builtin.set_type else 'list',
comprehension_type=target_type)
append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos,
expr=yield_expression,
target=result_node.target)
for yield_expression, yield_stat_node in yield_expressions:
append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos,
expr=yield_expression,
target=result_node.target)
Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
return result_node
def _handle_simple_function_dict(self, node, pos_args):
......@@ -1826,26 +1839,28 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
if yield_expression is None:
yield_expressions = _find_yield_expressions(loop_node)
if not yield_expressions:
return node
if not isinstance(yield_expression, ExprNodes.TupleNode):
return node
if len(yield_expression.args) != 2:
return node
for yield_expression, _ in yield_expressions:
if not isinstance(yield_expression, ExprNodes.TupleNode):
return node
if len(yield_expression.args) != 2:
return node
result_node = ExprNodes.InlinedGeneratorExpressionNode(
node.pos, gen_expr_node, orig_func='dict',
comprehension_type=Builtin.dict_type)
append_node = ExprNodes.DictComprehensionAppendNode(
yield_expression.pos,
key_expr = yield_expression.args[0],
value_expr = yield_expression.args[1],
target=result_node.target)
for yield_expression, yield_stat_node in yield_expressions:
append_node = ExprNodes.DictComprehensionAppendNode(
yield_expression.pos,
key_expr=yield_expression.args[0],
value_expr=yield_expression.args[1],
target=result_node.target)
Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
return result_node
# specific handlers for general call nodes
......@@ -3101,18 +3116,20 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
gen_expr_node = args[1]
loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
if yield_expression is not None:
yield_expressions = _find_yield_expressions(loop_node)
if yield_expressions:
inlined_genexpr = ExprNodes.InlinedGeneratorExpressionNode(
node.pos, gen_expr_node, orig_func='list',
comprehension_type=Builtin.list_type)
append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos,
expr=yield_expression,
target=inlined_genexpr.target)
for yield_expression, yield_stat_node in yield_expressions:
append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos,
expr=yield_expression,
target=inlined_genexpr.target)
Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
args[1] = inlined_genexpr
return self._substitute_method_call(
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment