Commit d75c7b77 authored by Stefan Behnel's avatar Stefan Behnel

only look at yield statements and not arbitrary yield expressions when...

only look at yield statements and not arbitrary yield expressions when inlining generator expressions
parent 51e31a55
...@@ -105,24 +105,24 @@ class _YieldNodeCollector(Visitor.TreeVisitor): ...@@ -105,24 +105,24 @@ class _YieldNodeCollector(Visitor.TreeVisitor):
def _find_single_yield_expression(node): def _find_single_yield_expression(node):
collector = _YieldNodeCollector() yield_statements = _find_yield_statements(node)
collector.visitchildren(node) if len(yield_statements) != 1:
if len(collector.yield_nodes) != 1:
return None, None
yield_node = collector.yield_nodes[0]
try:
return yield_node.arg, collector.yield_stat_nodes[yield_node]
except KeyError:
return None, None return None, None
return yield_statements[0]
def _find_yield_expressions(node): def _find_yield_statements(node):
collector = _YieldNodeCollector() collector = _YieldNodeCollector()
collector.visitchildren(node) collector.visitchildren(node)
return [ try:
(yield_node.arg, collector.yield_stat_nodes[yield_node]) yield_statements = [
for yield_node in collector.yield_nodes (yield_node.arg, collector.yield_stat_nodes[yield_node])
] for yield_node in collector.yield_nodes
]
except KeyError:
# found YieldExprNode without ExprStatNode (i.e. a non-statement usage of 'yield')
yield_statements = []
return yield_statements
class IterationTransform(Visitor.EnvTransform): class IterationTransform(Visitor.EnvTransform):
...@@ -1618,15 +1618,15 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1618,15 +1618,15 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
elif isinstance(arg, ExprNodes.GeneratorExpressionNode): elif isinstance(arg, ExprNodes.GeneratorExpressionNode):
gen_expr_node = arg gen_expr_node = arg
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
yield_expressions = _find_yield_expressions(loop_node) yield_statements = _find_yield_statements(loop_node)
if not yield_expressions: if not yield_statements:
return node return node
list_node = ExprNodes.InlinedGeneratorExpressionNode( list_node = ExprNodes.InlinedGeneratorExpressionNode(
node.pos, gen_expr_node, orig_func='sorted', node.pos, gen_expr_node, orig_func='sorted',
comprehension_type=Builtin.list_type) comprehension_type=Builtin.list_type)
for yield_expression, yield_stat_node in yield_expressions: for yield_expression, yield_stat_node in yield_statements:
append_node = ExprNodes.ComprehensionAppendNode( append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos, yield_expression.pos,
expr=yield_expression, expr=yield_expression,
...@@ -1809,8 +1809,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1809,8 +1809,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
gen_expr_node = pos_args[0] gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
yield_expressions = _find_yield_expressions(loop_node) yield_statements = _find_yield_statements(loop_node)
if not yield_expressions: if not yield_statements:
return node return node
result_node = ExprNodes.InlinedGeneratorExpressionNode( result_node = ExprNodes.InlinedGeneratorExpressionNode(
...@@ -1818,7 +1818,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1818,7 +1818,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
orig_func='set' if target_type is Builtin.set_type else 'list', orig_func='set' if target_type is Builtin.set_type else 'list',
comprehension_type=target_type) comprehension_type=target_type)
for yield_expression, yield_stat_node in yield_expressions: for yield_expression, yield_stat_node in yield_statements:
append_node = ExprNodes.ComprehensionAppendNode( append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos, yield_expression.pos,
expr=yield_expression, expr=yield_expression,
...@@ -1839,11 +1839,11 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1839,11 +1839,11 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
gen_expr_node = pos_args[0] gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
yield_expressions = _find_yield_expressions(loop_node) yield_statements = _find_yield_statements(loop_node)
if not yield_expressions: if not yield_statements:
return node return node
for yield_expression, _ in yield_expressions: for yield_expression, _ in yield_statements:
if not isinstance(yield_expression, ExprNodes.TupleNode): if not isinstance(yield_expression, ExprNodes.TupleNode):
return node return node
if len(yield_expression.args) != 2: if len(yield_expression.args) != 2:
...@@ -1853,7 +1853,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1853,7 +1853,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
node.pos, gen_expr_node, orig_func='dict', node.pos, gen_expr_node, orig_func='dict',
comprehension_type=Builtin.dict_type) comprehension_type=Builtin.dict_type)
for yield_expression, yield_stat_node in yield_expressions: for yield_expression, yield_stat_node in yield_statements:
append_node = ExprNodes.DictComprehensionAppendNode( append_node = ExprNodes.DictComprehensionAppendNode(
yield_expression.pos, yield_expression.pos,
key_expr=yield_expression.args[0], key_expr=yield_expression.args[0],
...@@ -3116,13 +3116,13 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, ...@@ -3116,13 +3116,13 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
gen_expr_node = args[1] gen_expr_node = args[1]
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
yield_expressions = _find_yield_expressions(loop_node) yield_statements = _find_yield_statements(loop_node)
if yield_expressions: if yield_statements:
inlined_genexpr = ExprNodes.InlinedGeneratorExpressionNode( inlined_genexpr = ExprNodes.InlinedGeneratorExpressionNode(
node.pos, gen_expr_node, orig_func='list', node.pos, gen_expr_node, orig_func='list',
comprehension_type=Builtin.list_type) comprehension_type=Builtin.list_type)
for yield_expression, yield_stat_node in yield_expressions: for yield_expression, yield_stat_node in yield_statements:
append_node = ExprNodes.ComprehensionAppendNode( append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos, yield_expression.pos,
expr=yield_expression, expr=yield_expression,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment