Commit 1230a868 authored by Stefan Behnel's avatar Stefan Behnel

generalise genexpr inlining to generators with more than one yield expression

parent 34428ea0
...@@ -100,6 +100,9 @@ class _YieldNodeCollector(Visitor.TreeVisitor): ...@@ -100,6 +100,9 @@ class _YieldNodeCollector(Visitor.TreeVisitor):
def visit_LambdaNode(self, node): def visit_LambdaNode(self, node):
pass pass
def visit_FuncDefNode(self, node):
pass
def _find_single_yield_expression(node): def _find_single_yield_expression(node):
collector = _YieldNodeCollector() collector = _YieldNodeCollector()
...@@ -113,6 +116,15 @@ def _find_single_yield_expression(node): ...@@ -113,6 +116,15 @@ def _find_single_yield_expression(node):
return None, None return None, None
def _find_yield_expressions(node):
collector = _YieldNodeCollector()
collector.visitchildren(node)
return [
(yield_node.arg, collector.yield_stat_nodes[yield_node])
for yield_node in collector.yield_nodes
]
class IterationTransform(Visitor.EnvTransform): class IterationTransform(Visitor.EnvTransform):
"""Transform some common for-in loop patterns into efficient C loops: """Transform some common for-in loop patterns into efficient C loops:
...@@ -1606,19 +1618,19 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1606,19 +1618,19 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
elif isinstance(arg, ExprNodes.GeneratorExpressionNode): elif isinstance(arg, ExprNodes.GeneratorExpressionNode):
gen_expr_node = arg gen_expr_node = arg
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = _find_single_yield_expression(loop_node) yield_expressions = _find_yield_expressions(loop_node)
if yield_expression is None: if not yield_expressions:
return node return node
list_node = ExprNodes.InlinedGeneratorExpressionNode( list_node = ExprNodes.InlinedGeneratorExpressionNode(
node.pos, gen_expr_node, orig_func='sorted', node.pos, gen_expr_node, orig_func='sorted',
comprehension_type=Builtin.list_type) comprehension_type=Builtin.list_type)
for yield_expression, yield_stat_node in yield_expressions:
append_node = ExprNodes.ComprehensionAppendNode( append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos, yield_expression.pos,
expr=yield_expression, expr=yield_expression,
target=list_node.target) target=list_node.target)
Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node) Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
elif arg.is_sequence_constructor: elif arg.is_sequence_constructor:
...@@ -1797,8 +1809,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1797,8 +1809,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
gen_expr_node = pos_args[0] gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = _find_single_yield_expression(loop_node) yield_expressions = _find_yield_expressions(loop_node)
if yield_expression is None: if not yield_expressions:
return node return node
result_node = ExprNodes.InlinedGeneratorExpressionNode( result_node = ExprNodes.InlinedGeneratorExpressionNode(
...@@ -1806,12 +1818,13 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1806,12 +1818,13 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
orig_func='set' if target_type is Builtin.set_type else 'list', orig_func='set' if target_type is Builtin.set_type else 'list',
comprehension_type=target_type) comprehension_type=target_type)
for yield_expression, yield_stat_node in yield_expressions:
append_node = ExprNodes.ComprehensionAppendNode( append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos, yield_expression.pos,
expr=yield_expression, expr=yield_expression,
target=result_node.target) target=result_node.target)
Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node) Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
return result_node return result_node
def _handle_simple_function_dict(self, node, pos_args): def _handle_simple_function_dict(self, node, pos_args):
...@@ -1826,10 +1839,11 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1826,10 +1839,11 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
gen_expr_node = pos_args[0] gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = _find_single_yield_expression(loop_node) yield_expressions = _find_yield_expressions(loop_node)
if yield_expression is None: if not yield_expressions:
return node return node
for yield_expression, _ in yield_expressions:
if not isinstance(yield_expression, ExprNodes.TupleNode): if not isinstance(yield_expression, ExprNodes.TupleNode):
return node return node
if len(yield_expression.args) != 2: if len(yield_expression.args) != 2:
...@@ -1839,13 +1853,14 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1839,13 +1853,14 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
node.pos, gen_expr_node, orig_func='dict', node.pos, gen_expr_node, orig_func='dict',
comprehension_type=Builtin.dict_type) comprehension_type=Builtin.dict_type)
for yield_expression, yield_stat_node in yield_expressions:
append_node = ExprNodes.DictComprehensionAppendNode( append_node = ExprNodes.DictComprehensionAppendNode(
yield_expression.pos, yield_expression.pos,
key_expr = yield_expression.args[0], key_expr=yield_expression.args[0],
value_expr = yield_expression.args[1], value_expr=yield_expression.args[1],
target=result_node.target) target=result_node.target)
Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node) Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
return result_node return result_node
# specific handlers for general call nodes # specific handlers for general call nodes
...@@ -3101,18 +3116,20 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, ...@@ -3101,18 +3116,20 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
gen_expr_node = args[1] gen_expr_node = args[1]
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = _find_single_yield_expression(loop_node) yield_expressions = _find_yield_expressions(loop_node)
if yield_expression is not None: if yield_expressions:
inlined_genexpr = ExprNodes.InlinedGeneratorExpressionNode( inlined_genexpr = ExprNodes.InlinedGeneratorExpressionNode(
node.pos, gen_expr_node, orig_func='list', node.pos, gen_expr_node, orig_func='list',
comprehension_type=Builtin.list_type) comprehension_type=Builtin.list_type)
for yield_expression, yield_stat_node in yield_expressions:
append_node = ExprNodes.ComprehensionAppendNode( append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos, yield_expression.pos,
expr=yield_expression, expr=yield_expression,
target=inlined_genexpr.target) target=inlined_genexpr.target)
Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node) Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
args[1] = inlined_genexpr args[1] = inlined_genexpr
return self._substitute_method_call( return self._substitute_method_call(
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment