Commit c1e8c914 authored by Stefan Behnel's avatar Stefan Behnel

fix scoping rules for comprehensions and inlined generator expressions by...

fix scoping rules for comprehensions and inlined generator expressions by injecting a separate scope instance
parent 140a9be0
...@@ -3898,7 +3898,29 @@ class ListNode(SequenceNode): ...@@ -3898,7 +3898,29 @@ class ListNode(SequenceNode):
# generate_evaluation_code which will do that. # generate_evaluation_code which will do that.
class ComprehensionNode(ExprNode): class ScopedExprNode(ExprNode):
# Abstract base class for ExprNodes that have their own local
# scope, such as generator expressions.
#
# expr_scope Scope the inner scope of the expression
subexprs = []
expr_scope = None
def analyse_types(self, env):
# nothing to do here, the children will be analysed separately
pass
def analyse_expressions(self, env):
# nothing to do here, the children will be analysed separately
pass
def analyse_scoped_expressions(self, env):
# this is called with the expr_scope as env
pass
class ComprehensionNode(ScopedExprNode):
subexprs = ["target"] subexprs = ["target"]
child_attrs = ["loop", "append"] child_attrs = ["loop", "append"]
...@@ -3907,11 +3929,14 @@ class ComprehensionNode(ExprNode): ...@@ -3907,11 +3929,14 @@ class ComprehensionNode(ExprNode):
def analyse_declarations(self, env): def analyse_declarations(self, env):
self.append.target = self # this is used in the PyList_Append of the inner loop self.append.target = self # this is used in the PyList_Append of the inner loop
self.loop.analyse_declarations(env) self.expr_scope = Symtab.GeneratorExpressionScope(env)
self.loop.analyse_declarations(self.expr_scope)
def analyse_types(self, env): def analyse_types(self, env):
self.target.analyse_expressions(env) self.target.analyse_expressions(env)
self.type = self.target.type self.type = self.target.type
def analyse_scoped_expressions(self, env):
self.loop.analyse_expressions(env) self.loop.analyse_expressions(env)
def may_be_none(self): def may_be_none(self):
...@@ -3980,21 +4005,25 @@ class DictComprehensionAppendNode(ComprehensionAppendNode): ...@@ -3980,21 +4005,25 @@ class DictComprehensionAppendNode(ComprehensionAppendNode):
code.error_goto_if(self.result(), self.pos))) code.error_goto_if(self.result(), self.pos)))
class GeneratorExpressionNode(ExprNode): class GeneratorExpressionNode(ScopedExprNode):
# A generator expression, e.g. (i for i in range(10)) # A generator expression, e.g. (i for i in range(10))
# #
# Result is a generator. # Result is a generator.
# #
# loop ForStatNode the for-loop, containing a YieldExprNode # loop ForStatNode the for-loop, containing a YieldExprNode
subexprs = []
child_attrs = ["loop"] child_attrs = ["loop"]
type = py_object_type type = py_object_type
def analyse_declarations(self, env): def analyse_declarations(self, env):
self.loop.analyse_declarations(env) self.expr_scope = Symtab.GeneratorExpressionScope(env)
self.loop.analyse_declarations(self.expr_scope)
def analyse_types(self, env): def analyse_types(self, env):
self.is_temp = True
def analyse_scoped_expressions(self, env):
self.loop.analyse_expressions(env) self.loop.analyse_expressions(env)
def may_be_none(self): def may_be_none(self):
...@@ -4004,6 +4033,24 @@ class GeneratorExpressionNode(ExprNode): ...@@ -4004,6 +4033,24 @@ class GeneratorExpressionNode(ExprNode):
self.loop.annotate(code) self.loop.annotate(code)
class InlinedGeneratorExpressionNode(GeneratorExpressionNode):
# An inlined generator expression for which the result is
# calculated inside of the loop.
#
# loop ForStatNode the for-loop, not containing any YieldExprNodes
# result_node ResultRefNode the reference to the result value temp
child_attrs = ["loop"]
def analyse_types(self, env):
self.type = self.result_node.type
self.is_temp = True
def generate_result_code(self, code):
self.result_node.result_code = self.result()
self.loop.generate_execution_code(code)
class SetNode(ExprNode): class SetNode(ExprNode):
# Set constructor. # Set constructor.
......
...@@ -91,6 +91,8 @@ frame_cname = pyrex_prefix + "frame" ...@@ -91,6 +91,8 @@ frame_cname = pyrex_prefix + "frame"
frame_code_cname = pyrex_prefix + "frame_code" frame_code_cname = pyrex_prefix + "frame_code"
binding_cfunc = pyrex_prefix + "binding_PyCFunctionType" binding_cfunc = pyrex_prefix + "binding_PyCFunctionType"
genexpr_id_ref = 'genexpr'
line_c_macro = "__LINE__" line_c_macro = "__LINE__"
file_c_macro = "__FILE__" file_c_macro = "__FILE__"
......
...@@ -1130,7 +1130,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1130,7 +1130,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
return node return node
if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
return node return node
loop_node = pos_args[0].loop gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop
collector = self.YieldNodeCollector() collector = self.YieldNodeCollector()
collector.visitchildren(loop_node) collector.visitchildren(loop_node)
...@@ -1140,14 +1141,12 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1140,14 +1141,12 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
yield_expression = yield_node.arg yield_expression = yield_node.arg
del collector del collector
result_ref = UtilNodes.ResultRefNode(pos=node.pos)
result_ref.type = PyrexTypes.c_bint_type
if is_any: if is_any:
condition = yield_expression condition = yield_expression
else: else:
condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression) condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)
result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)
test_node = Nodes.IfStatNode( test_node = Nodes.IfStatNode(
yield_node.pos, yield_node.pos,
else_clause = None, else_clause = None,
...@@ -1182,7 +1181,9 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1182,7 +1181,9 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
Visitor.RecursiveNodeReplacer(yield_node, test_node).visitchildren(loop_node) Visitor.RecursiveNodeReplacer(yield_node, test_node).visitchildren(loop_node)
return UtilNodes.TempResultFromStatNode(result_ref, loop_node) return ExprNodes.InlinedGeneratorExpressionNode(
gen_expr_node.pos, loop = loop_node, result_node = result_ref,
expr_scope = gen_expr_node.expr_scope)
# specific handlers for general call nodes # specific handlers for general call nodes
......
...@@ -1030,9 +1030,16 @@ property NAME: ...@@ -1030,9 +1030,16 @@ property NAME:
node.analyse_declarations(self.env_stack[-1]) node.analyse_declarations(self.env_stack[-1])
return node return node
def visit_GeneratorExpressionNode(self, node): def visit_ScopedExprNode(self, node):
self.visitchildren(node)
node.analyse_declarations(self.env_stack[-1]) node.analyse_declarations(self.env_stack[-1])
if self.seen_vars_stack:
self.seen_vars_stack.append(set(self.seen_vars_stack[-1]))
else:
self.seen_vars_stack.append(set())
self.env_stack.append(node.expr_scope)
self.visitchildren(node)
self.env_stack.pop()
self.seen_vars_stack.pop()
return node return node
def visit_TempResultFromStatNode(self, node): def visit_TempResultFromStatNode(self, node):
...@@ -1133,6 +1140,12 @@ class AnalyseExpressionsTransform(CythonTransform): ...@@ -1133,6 +1140,12 @@ class AnalyseExpressionsTransform(CythonTransform):
node.body.analyse_expressions(node.local_scope) node.body.analyse_expressions(node.local_scope)
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_ScopedExprNode(self, node):
node.expr_scope.infer_types()
node.analyse_scoped_expressions(node.expr_scope)
self.visitchildren(node)
return node
class AlignFunctionDefinitions(CythonTransform): class AlignFunctionDefinitions(CythonTransform):
""" """
......
...@@ -269,7 +269,8 @@ class Scope(object): ...@@ -269,7 +269,8 @@ class Scope(object):
self.lambda_defs = [] self.lambda_defs = []
self.control_flow = ControlFlow.LinearControlFlow() self.control_flow = ControlFlow.LinearControlFlow()
self.return_type = None self.return_type = None
self.id_counters = {}
def start_branching(self, pos): def start_branching(self, pos):
self.control_flow = self.control_flow.start_branch(pos) self.control_flow = self.control_flow.start_branch(pos)
...@@ -297,7 +298,19 @@ class Scope(object): ...@@ -297,7 +298,19 @@ class Scope(object):
prefix = "%s%s_" % (Naming.pyrex_prefix, name) prefix = "%s%s_" % (Naming.pyrex_prefix, name)
return self.mangle(prefix) return self.mangle(prefix)
#return self.parent_scope.mangle(prefix, self.name) #return self.parent_scope.mangle(prefix, self.name)
def next_id(self, name=None):
# Return a cname fragment that is unique for this scope.
try:
count = self.id_counters[name] + 1
except KeyError:
count = 0
self.id_counters[name] = count
if name:
return '%s%d' % (name, count)
else:
return '%d' % count
def global_scope(self): def global_scope(self):
# Return the module-level scope containing this scope. # Return the module-level scope containing this scope.
return self.outer_scope.global_scope() return self.outer_scope.global_scope()
...@@ -1244,7 +1257,30 @@ class LocalScope(Scope): ...@@ -1244,7 +1257,30 @@ class LocalScope(Scope):
elif entry.in_closure: elif entry.in_closure:
entry.original_cname = entry.cname entry.original_cname = entry.cname
entry.cname = "%s->%s" % (Naming.cur_scope_cname, entry.cname) entry.cname = "%s->%s" % (Naming.cur_scope_cname, entry.cname)
class GeneratorExpressionScope(LocalScope):
"""Scope for generator expressions and comprehensions. As opposed
to generators, these can be easily inlined in some cases, so all
we really need is a scope that holds the loop variable(s).
"""
def __init__(self, outer_scope):
name = outer_scope.global_scope().next_id(Naming.genexpr_id_ref)
LocalScope.__init__(self, name, outer_scope)
self.directives = outer_scope.directives
self.genexp_prefix = "%s%s" % (Naming.pyrex_prefix, name)
def mangle(self, prefix, name):
return '%s%s' % (self.genexp_prefix, LocalScope.mangle(self, prefix, name))
def declare_var(self, name, type, pos,
cname = None, visibility = 'private', is_cdef = 0):
cname = '%s%s' % (self.genexp_prefix, self.outer_scope.mangle(Naming.var_prefix, name))
entry = self.outer_scope.declare_var(None, type, pos, cname, visibility, is_cdef)
self.entries[name] = entry
return entry
class ClosureScope(LocalScope): class ClosureScope(LocalScope):
is_closure_scope = True is_closure_scope = True
......
...@@ -119,7 +119,7 @@ class ResultRefNode(AtomicExprNode): ...@@ -119,7 +119,7 @@ class ResultRefNode(AtomicExprNode):
subexprs = [] subexprs = []
lhs_of_first_assignment = False lhs_of_first_assignment = False
def __init__(self, expression=None, pos=None): def __init__(self, expression=None, pos=None, type=None):
self.expression = expression self.expression = expression
self.pos = None self.pos = None
if expression is not None: if expression is not None:
...@@ -128,6 +128,8 @@ class ResultRefNode(AtomicExprNode): ...@@ -128,6 +128,8 @@ class ResultRefNode(AtomicExprNode):
self.type = expression.type self.type = expression.type
if pos is not None: if pos is not None:
self.pos = pos self.pos = pos
if type is not None:
self.type = type
assert self.pos is not None assert self.pos is not None
def analyse_types(self, env): def analyse_types(self, env):
......
...@@ -53,10 +53,10 @@ def all_item(x): ...@@ -53,10 +53,10 @@ def all_item(x):
""" """
return all(x) return all(x)
@cython.test_assert_path_exists("//ForInStatNode") @cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode", @cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode", "//YieldExprNode")
"//GeneratorExpressionNode")
def all_in_simple_gen(seq): def all_in_simple_gen(seq):
""" """
>>> all_in_simple_gen([1,1,1]) >>> all_in_simple_gen([1,1,1])
...@@ -82,10 +82,42 @@ def all_in_simple_gen(seq): ...@@ -82,10 +82,42 @@ def all_in_simple_gen(seq):
""" """
return all(x for x in seq) return all(x for x in seq)
@cython.test_assert_path_exists("//ForInStatNode") @cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode", @cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode", "//YieldExprNode")
"//GeneratorExpressionNode") def all_in_simple_gen_scope(seq):
"""
>>> all_in_simple_gen_scope([1,1,1])
True
>>> all_in_simple_gen_scope([1,1,0])
False
>>> all_in_simple_gen_scope([1,0,1])
False
>>> all_in_simple_gen_scope(VerboseGetItem([1,1,1,1,1]))
0
1
2
3
4
5
True
>>> all_in_simple_gen_scope(VerboseGetItem([1,1,0,1,1]))
0
1
2
False
"""
x = 'abc'
result = all(x for x in seq)
assert x == 'abc'
return result
@cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode")
def all_in_conditional_gen(seq): def all_in_conditional_gen(seq):
""" """
>>> all_in_conditional_gen([3,6,9]) >>> all_in_conditional_gen([3,6,9])
...@@ -133,10 +165,10 @@ def all_lower_case_characters(unicode ustring): ...@@ -133,10 +165,10 @@ def all_lower_case_characters(unicode ustring):
""" """
return all(uchar.islower() for uchar in ustring) return all(uchar.islower() for uchar in ustring)
@cython.test_assert_path_exists("//ForInStatNode") @cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode", @cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode", "//YieldExprNode")
"//GeneratorExpressionNode")
def all_in_typed_gen(seq): def all_in_typed_gen(seq):
""" """
>>> all_in_typed_gen([1,1,1]) >>> all_in_typed_gen([1,1,1])
...@@ -165,10 +197,10 @@ def all_in_typed_gen(seq): ...@@ -165,10 +197,10 @@ def all_in_typed_gen(seq):
cdef int x cdef int x
return all(x for x in seq) return all(x for x in seq)
@cython.test_assert_path_exists("//ForInStatNode") @cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode", @cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode", "//YieldExprNode")
"//GeneratorExpressionNode")
def all_in_nested_gen(seq): def all_in_nested_gen(seq):
""" """
>>> all(x for L in [[1,1,1],[1,1,1],[1,1,1]] for x in L) >>> all(x for L in [[1,1,1],[1,1,1],[1,1,1]] for x in L)
......
...@@ -51,10 +51,10 @@ def any_item(x): ...@@ -51,10 +51,10 @@ def any_item(x):
""" """
return any(x) return any(x)
@cython.test_assert_path_exists("//ForInStatNode") @cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode", @cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode", "//YieldExprNode")
"//GeneratorExpressionNode")
def any_in_simple_gen(seq): def any_in_simple_gen(seq):
""" """
>>> any_in_simple_gen([0,1,0]) >>> any_in_simple_gen([0,1,0])
...@@ -78,10 +78,40 @@ def any_in_simple_gen(seq): ...@@ -78,10 +78,40 @@ def any_in_simple_gen(seq):
""" """
return any(x for x in seq) return any(x for x in seq)
@cython.test_assert_path_exists("//ForInStatNode") @cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode", @cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode", "//YieldExprNode")
"//GeneratorExpressionNode") def any_in_simple_gen_scope(seq):
"""
>>> any_in_simple_gen_scope([0,1,0])
True
>>> any_in_simple_gen_scope([0,0,0])
False
>>> any_in_simple_gen_scope(VerboseGetItem([0,0,1,0,0]))
0
1
2
True
>>> any_in_simple_gen_scope(VerboseGetItem([0,0,0,0,0]))
0
1
2
3
4
5
False
"""
x = 'abc'
result = any(x for x in seq)
assert x == 'abc'
return result
@cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode")
def any_in_conditional_gen(seq): def any_in_conditional_gen(seq):
""" """
>>> any_in_conditional_gen([3,6,9]) >>> any_in_conditional_gen([3,6,9])
...@@ -127,10 +157,10 @@ def any_lower_case_characters(unicode ustring): ...@@ -127,10 +157,10 @@ def any_lower_case_characters(unicode ustring):
""" """
return any(uchar.islower() for uchar in ustring) return any(uchar.islower() for uchar in ustring)
@cython.test_assert_path_exists("//ForInStatNode") @cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode", @cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode", "//YieldExprNode")
"//GeneratorExpressionNode")
def any_in_typed_gen(seq): def any_in_typed_gen(seq):
""" """
>>> any_in_typed_gen([0,1,0]) >>> any_in_typed_gen([0,1,0])
...@@ -157,10 +187,10 @@ def any_in_typed_gen(seq): ...@@ -157,10 +187,10 @@ def any_in_typed_gen(seq):
cdef int x cdef int x
return any(x for x in seq) return any(x for x in seq)
@cython.test_assert_path_exists("//ForInStatNode") @cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode", @cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode", "//YieldExprNode")
"//GeneratorExpressionNode")
def any_in_nested_gen(seq): def any_in_nested_gen(seq):
""" """
>>> any(x for L in [[0,0,0],[0,0,1],[0,0,0]] for x in L) >>> any(x for L in [[0,0,0],[0,0,1],[0,0,0]] for x in L)
......
...@@ -6,7 +6,7 @@ __doc__ = u""" ...@@ -6,7 +6,7 @@ __doc__ = u"""
[('args', (2, 3)), ('kwds', {'k': 5}), ('x', 1), ('y', 'hi'), ('z', 5)] [('args', (2, 3)), ('kwds', {'k': 5}), ('x', 1), ('y', 'hi'), ('z', 5)]
>>> sorted(get_locals_items_listcomp(1,2,3, k=5)) >>> sorted(get_locals_items_listcomp(1,2,3, k=5))
[('args', (2, 3)), ('item', None), ('kwds', {'k': 5}), ('x', 1), ('y', 'hi'), ('z', 5)] [('args', (2, 3)), ('kwds', {'k': 5}), ('x', 1), ('y', 'hi'), ('z', 5)]
""" """
def get_locals(x, *args, **kwds): def get_locals(x, *args, **kwds):
...@@ -20,7 +20,6 @@ def get_locals_items(x, *args, **kwds): ...@@ -20,7 +20,6 @@ def get_locals_items(x, *args, **kwds):
return locals().items() return locals().items()
def get_locals_items_listcomp(x, *args, **kwds): def get_locals_items_listcomp(x, *args, **kwds):
# FIXME: 'item' should *not* appear in locals() !
cdef int z = 5 cdef int z = 5
y = "hi" y = "hi"
return [ item for item in locals().items() ] return [ item for item in locals().items() ]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment