Commit b9bfa498 authored by Stefan Behnel's avatar Stefan Behnel

fix compiler crash for generator expressions with a constant False condition

--HG--
extra : transplant_source : C%C7%2Ak%C4%89%DA%C1f%85%86%0D%9E%7F_%B4%17%D2t%40
parent b640de61
......@@ -3117,11 +3117,30 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
break
else:
assert condition_result == False
# prevent killing generators, but simplify them as much as possible
yield_expr = self._find_genexpr_yield(if_clause.body)
if yield_expr is not None:
if_clause.condition = ExprNodes.BoolNode(if_clause.condition.pos, value=False)
yield_expr.arg = ExprNodes.NoneNode(yield_expr.arg.pos)
if_clauses.append(if_clause)
else:
# False clauses outside of generators can safely be deleted
pass
if not if_clauses:
return node.else_clause
node.if_clauses = if_clauses
return node
def _find_genexpr_yield(self, node):
body_node_types = (Nodes.ForInStatNode, Nodes.IfStatNode)
while isinstance(node, body_node_types):
node = node.body
if isinstance(node, Nodes.ExprStatNode):
node = node.expr
if isinstance(node, ExprNodes.YieldExprNode):
return node
return None
# in the future, other nodes can have their own handler method here
# that can replace them with a constant result node
......
......@@ -21,6 +21,16 @@ def genexpr_if():
assert x == 'abc' # don't leak
return result
def genexpr_if_false():
"""
>>> genexpr_if_false()
[]
"""
x = 'abc'
result = list( x*2 for x in range(5) if False )
assert x == 'abc' # don't leak
return result
def genexpr_with_lambda():
"""
>>> genexpr_with_lambda()
......
# mode: run
# cython: language_level=3
"""
Adapted from CPython's test_grammar.py
"""
def genexpr_simple():
"""
>>> sum([ x**2 for x in range(10) ])
285
>>> sum(genexpr_simple())
285
"""
return (x**2 for x in range(10))
def genexpr_conditional():
"""
>>> sum([ x*x for x in range(10) if x%2 ])
165
>>> sum(genexpr_conditional())
165
"""
return (x*x for x in range(10) if x%2)
def genexpr_nested2():
"""
>>> sum([x for x in range(10)])
45
>>> sum(genexpr_nested2())
45
"""
return (x for x in (y for y in range(10)))
def genexpr_nested3():
"""
>>> sum([x for x in range(10)])
45
>>> sum(genexpr_nested3())
45
"""
return (x for x in (y for y in (z for z in range(10))))
def genexpr_nested_listcomp():
"""
>>> sum([x for x in range(10)])
45
>>> sum(genexpr_nested_listcomp())
45
"""
return (x for x in [y for y in (z for z in range(10))])
def genexpr_nested_conditional():
"""
>>> sum([ x for x in [y for y in [z for z in range(10) if True]] if True ])
45
>>> sum(genexpr_nested_conditional())
45
"""
return (x for x in (y for y in (z for z in range(10) if True)) if True)
def genexpr_nested2_conditional_empty():
"""
>>> sum(genexpr_nested2_conditional_empty())
0
"""
return (y for y in (z for z in range(10) if True) if False)
def genexpr_nested3_conditional_empty():
"""
>>> sum(genexpr_nested3_conditional_empty())
0
"""
return (x for x in (y for y in (z for z in range(10) if True) if False) if True)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment