Commit c400bbd5 authored by Stefan Behnel's avatar Stefan Behnel

do not let set/dict comprehensions leak in Py2, only list comprehensions

parent be1cfc62
...@@ -3940,8 +3940,10 @@ class ComprehensionNode(ScopedExprNode): ...@@ -3940,8 +3940,10 @@ class ComprehensionNode(ScopedExprNode):
subexprs = ["target"] subexprs = ["target"]
child_attrs = ["loop", "append"] child_attrs = ["loop", "append"]
# different behaviour in Py2 and Py3: leak loop variables or not? # leak loop variables or not? non-leaking Py3 behaviour is
has_local_scope = False # Py2 behaviour as default # default, except for list comprehensions where the behaviour
# differs in Py2 and Py3 (see Parsing.py)
has_local_scope = True
def infer_type(self, env): def infer_type(self, env):
return self.target.infer_type(env) return self.target.infer_type(env)
......
...@@ -780,7 +780,9 @@ def p_list_maker(s): ...@@ -780,7 +780,9 @@ def p_list_maker(s):
loop = p_comp_for(s, append) loop = p_comp_for(s, append)
s.expect(']') s.expect(']')
return ExprNodes.ComprehensionNode( return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, target=target) pos, loop=loop, append=append, target=target,
# list comprehensions leak their loop variable in Py2
has_local_scope = s.context.language_level > 2)
else: else:
if s.sy == ',': if s.sy == ',':
s.next() s.next()
......
# cython: language_level=3 # cython: language_level=3
try:
sorted
except NameError:
def sorted(seq):
seq = list(seq)
seq.sort()
return seq
def print_function(*args): def print_function(*args):
""" """
>>> print_function(1,2,3) >>> print_function(1,2,3)
...@@ -17,3 +25,33 @@ def unicode_literals(): ...@@ -17,3 +25,33 @@ def unicode_literals():
""" """
print(isinstance(ustring, unicode) or type(ustring)) print(isinstance(ustring, unicode) or type(ustring))
return ustring return ustring
def list_comp():
"""
>>> list_comp()
[0, 4, 8]
"""
x = 'abc'
result = [x*2 for x in range(5) if x % 2 == 0]
assert x == 'abc' # don't leak in Py3 code
return result
def set_comp():
"""
>>> sorted(set_comp())
[0, 4, 8]
"""
x = 'abc'
result = {x*2 for x in range(5) if x % 2 == 0}
assert x == 'abc' # don't leak
return result
def dict_comp():
"""
>>> sorted(dict_comp().items())
[(0, 0), (2, 4), (4, 8)]
"""
x = 'abc'
result = {x:x*2 for x in range(5) if x % 2 == 0}
assert x == 'abc' # don't leak
return result
...@@ -12,7 +12,7 @@ def dictcomp(): ...@@ -12,7 +12,7 @@ def dictcomp():
result = { x+2:x*2 result = { x+2:x*2
for x in range(5) for x in range(5)
if x % 2 == 0 } if x % 2 == 0 }
assert x != 'abc' assert x == 'abc' # do not leak!
return result return result
@cython.test_fail_if_path_exists( @cython.test_fail_if_path_exists(
......
...@@ -13,9 +13,12 @@ def setcomp(): ...@@ -13,9 +13,12 @@ def setcomp():
>>> sorted(setcomp()) >>> sorted(setcomp())
[0, 4, 8] [0, 4, 8]
""" """
return { x*2 x = 'abc'
result = { x*2
for x in range(5) for x in range(5)
if x % 2 == 0 } if x % 2 == 0 }
assert x == 'abc' # do not leak
return result
@cython.test_fail_if_path_exists( @cython.test_fail_if_path_exists(
"//GeneratorExpressionNode", "//GeneratorExpressionNode",
...@@ -30,9 +33,12 @@ def genexp_set(): ...@@ -30,9 +33,12 @@ def genexp_set():
>>> sorted(genexp_set()) >>> sorted(genexp_set())
[0, 4, 8] [0, 4, 8]
""" """
return set( x*2 x = 'abc'
for x in range(5) result = set( x*2
if x % 2 == 0 ) for x in range(5)
if x % 2 == 0 )
assert x == 'abc' # do not leak
return result
cdef class A: cdef class A:
def __repr__(self): return u"A" def __repr__(self): return u"A"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment