Commit c1e8c914 authored by Stefan Behnel's avatar Stefan Behnel

fix scoping rules for comprehensions and inlined generator expressions by...

fix scoping rules for comprehensions and inlined generator expressions by injecting a separate scope instance
parent 140a9be0
......@@ -3898,7 +3898,29 @@ class ListNode(SequenceNode):
# generate_evaluation_code which will do that.
class ComprehensionNode(ExprNode):
class ScopedExprNode(ExprNode):
# Abstract base class for ExprNodes that have their own local
# scope, such as generator expressions.
#
# expr_scope Scope the inner scope of the expression
subexprs = []
expr_scope = None
def analyse_types(self, env):
# nothing to do here, the children will be analysed separately
pass
def analyse_expressions(self, env):
# nothing to do here, the children will be analysed separately
pass
def analyse_scoped_expressions(self, env):
# this is called with the expr_scope as env
pass
class ComprehensionNode(ScopedExprNode):
subexprs = ["target"]
child_attrs = ["loop", "append"]
......@@ -3907,11 +3929,14 @@ class ComprehensionNode(ExprNode):
def analyse_declarations(self, env):
self.append.target = self # this is used in the PyList_Append of the inner loop
self.loop.analyse_declarations(env)
self.expr_scope = Symtab.GeneratorExpressionScope(env)
self.loop.analyse_declarations(self.expr_scope)
def analyse_types(self, env):
self.target.analyse_expressions(env)
self.type = self.target.type
def analyse_scoped_expressions(self, env):
self.loop.analyse_expressions(env)
def may_be_none(self):
......@@ -3980,21 +4005,25 @@ class DictComprehensionAppendNode(ComprehensionAppendNode):
code.error_goto_if(self.result(), self.pos)))
class GeneratorExpressionNode(ExprNode):
class GeneratorExpressionNode(ScopedExprNode):
# A generator expression, e.g. (i for i in range(10))
#
# Result is a generator.
#
# loop ForStatNode the for-loop, containing a YieldExprNode
subexprs = []
child_attrs = ["loop"]
type = py_object_type
def analyse_declarations(self, env):
self.loop.analyse_declarations(env)
self.expr_scope = Symtab.GeneratorExpressionScope(env)
self.loop.analyse_declarations(self.expr_scope)
def analyse_types(self, env):
self.is_temp = True
def analyse_scoped_expressions(self, env):
self.loop.analyse_expressions(env)
def may_be_none(self):
......@@ -4004,6 +4033,24 @@ class GeneratorExpressionNode(ExprNode):
self.loop.annotate(code)
class InlinedGeneratorExpressionNode(GeneratorExpressionNode):
# An inlined generator expression for which the result is
# calculated inside of the loop.
#
# loop ForStatNode the for-loop, not containing any YieldExprNodes
# result_node ResultRefNode the reference to the result value temp
child_attrs = ["loop"]
def analyse_types(self, env):
self.type = self.result_node.type
self.is_temp = True
def generate_result_code(self, code):
self.result_node.result_code = self.result()
self.loop.generate_execution_code(code)
class SetNode(ExprNode):
# Set constructor.
......
......@@ -91,6 +91,8 @@ frame_cname = pyrex_prefix + "frame"
frame_code_cname = pyrex_prefix + "frame_code"
binding_cfunc = pyrex_prefix + "binding_PyCFunctionType"
genexpr_id_ref = 'genexpr'
line_c_macro = "__LINE__"
file_c_macro = "__FILE__"
......
......@@ -1130,7 +1130,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
return node
if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
return node
loop_node = pos_args[0].loop
gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop
collector = self.YieldNodeCollector()
collector.visitchildren(loop_node)
......@@ -1140,14 +1141,12 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
yield_expression = yield_node.arg
del collector
result_ref = UtilNodes.ResultRefNode(pos=node.pos)
result_ref.type = PyrexTypes.c_bint_type
if is_any:
condition = yield_expression
else:
condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)
result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)
test_node = Nodes.IfStatNode(
yield_node.pos,
else_clause = None,
......@@ -1182,7 +1181,9 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
Visitor.RecursiveNodeReplacer(yield_node, test_node).visitchildren(loop_node)
return UtilNodes.TempResultFromStatNode(result_ref, loop_node)
return ExprNodes.InlinedGeneratorExpressionNode(
gen_expr_node.pos, loop = loop_node, result_node = result_ref,
expr_scope = gen_expr_node.expr_scope)
# specific handlers for general call nodes
......
......@@ -1030,9 +1030,16 @@ property NAME:
node.analyse_declarations(self.env_stack[-1])
return node
def visit_GeneratorExpressionNode(self, node):
self.visitchildren(node)
def visit_ScopedExprNode(self, node):
node.analyse_declarations(self.env_stack[-1])
if self.seen_vars_stack:
self.seen_vars_stack.append(set(self.seen_vars_stack[-1]))
else:
self.seen_vars_stack.append(set())
self.env_stack.append(node.expr_scope)
self.visitchildren(node)
self.env_stack.pop()
self.seen_vars_stack.pop()
return node
def visit_TempResultFromStatNode(self, node):
......@@ -1134,6 +1141,12 @@ class AnalyseExpressionsTransform(CythonTransform):
self.visitchildren(node)
return node
def visit_ScopedExprNode(self, node):
node.expr_scope.infer_types()
node.analyse_scoped_expressions(node.expr_scope)
self.visitchildren(node)
return node
class AlignFunctionDefinitions(CythonTransform):
"""
This class takes the signatures from a .pxd file and applies them to
......
......@@ -269,6 +269,7 @@ class Scope(object):
self.lambda_defs = []
self.control_flow = ControlFlow.LinearControlFlow()
self.return_type = None
self.id_counters = {}
def start_branching(self, pos):
self.control_flow = self.control_flow.start_branch(pos)
......@@ -298,6 +299,18 @@ class Scope(object):
return self.mangle(prefix)
#return self.parent_scope.mangle(prefix, self.name)
def next_id(self, name=None):
# Return a cname fragment that is unique for this scope.
try:
count = self.id_counters[name] + 1
except KeyError:
count = 0
self.id_counters[name] = count
if name:
return '%s%d' % (name, count)
else:
return '%d' % count
def global_scope(self):
# Return the module-level scope containing this scope.
return self.outer_scope.global_scope()
......@@ -1245,6 +1258,29 @@ class LocalScope(Scope):
entry.original_cname = entry.cname
entry.cname = "%s->%s" % (Naming.cur_scope_cname, entry.cname)
class GeneratorExpressionScope(LocalScope):
"""Scope for generator expressions and comprehensions. As opposed
to generators, these can be easily inlined in some cases, so all
we really need is a scope that holds the loop variable(s).
"""
def __init__(self, outer_scope):
name = outer_scope.global_scope().next_id(Naming.genexpr_id_ref)
LocalScope.__init__(self, name, outer_scope)
self.directives = outer_scope.directives
self.genexp_prefix = "%s%s" % (Naming.pyrex_prefix, name)
def mangle(self, prefix, name):
return '%s%s' % (self.genexp_prefix, LocalScope.mangle(self, prefix, name))
def declare_var(self, name, type, pos,
cname = None, visibility = 'private', is_cdef = 0):
cname = '%s%s' % (self.genexp_prefix, self.outer_scope.mangle(Naming.var_prefix, name))
entry = self.outer_scope.declare_var(None, type, pos, cname, visibility, is_cdef)
self.entries[name] = entry
return entry
class ClosureScope(LocalScope):
is_closure_scope = True
......
......@@ -119,7 +119,7 @@ class ResultRefNode(AtomicExprNode):
subexprs = []
lhs_of_first_assignment = False
def __init__(self, expression=None, pos=None):
def __init__(self, expression=None, pos=None, type=None):
self.expression = expression
self.pos = None
if expression is not None:
......@@ -128,6 +128,8 @@ class ResultRefNode(AtomicExprNode):
self.type = expression.type
if pos is not None:
self.pos = pos
if type is not None:
self.type = type
assert self.pos is not None
def analyse_types(self, env):
......
......@@ -53,10 +53,10 @@ def all_item(x):
"""
return all(x)
@cython.test_assert_path_exists("//ForInStatNode")
@cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode",
"//GeneratorExpressionNode")
"//YieldExprNode")
def all_in_simple_gen(seq):
"""
>>> all_in_simple_gen([1,1,1])
......@@ -82,10 +82,42 @@ def all_in_simple_gen(seq):
"""
return all(x for x in seq)
@cython.test_assert_path_exists("//ForInStatNode")
@cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode",
"//GeneratorExpressionNode")
"//YieldExprNode")
def all_in_simple_gen_scope(seq):
"""
>>> all_in_simple_gen_scope([1,1,1])
True
>>> all_in_simple_gen_scope([1,1,0])
False
>>> all_in_simple_gen_scope([1,0,1])
False
>>> all_in_simple_gen_scope(VerboseGetItem([1,1,1,1,1]))
0
1
2
3
4
5
True
>>> all_in_simple_gen_scope(VerboseGetItem([1,1,0,1,1]))
0
1
2
False
"""
x = 'abc'
result = all(x for x in seq)
assert x == 'abc'
return result
@cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode")
def all_in_conditional_gen(seq):
"""
>>> all_in_conditional_gen([3,6,9])
......@@ -133,10 +165,10 @@ def all_lower_case_characters(unicode ustring):
"""
return all(uchar.islower() for uchar in ustring)
@cython.test_assert_path_exists("//ForInStatNode")
@cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode",
"//GeneratorExpressionNode")
"//YieldExprNode")
def all_in_typed_gen(seq):
"""
>>> all_in_typed_gen([1,1,1])
......@@ -165,10 +197,10 @@ def all_in_typed_gen(seq):
cdef int x
return all(x for x in seq)
@cython.test_assert_path_exists("//ForInStatNode")
@cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode",
"//GeneratorExpressionNode")
"//YieldExprNode")
def all_in_nested_gen(seq):
"""
>>> all(x for L in [[1,1,1],[1,1,1],[1,1,1]] for x in L)
......
......@@ -51,10 +51,10 @@ def any_item(x):
"""
return any(x)
@cython.test_assert_path_exists("//ForInStatNode")
@cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode",
"//GeneratorExpressionNode")
"//YieldExprNode")
def any_in_simple_gen(seq):
"""
>>> any_in_simple_gen([0,1,0])
......@@ -78,10 +78,40 @@ def any_in_simple_gen(seq):
"""
return any(x for x in seq)
@cython.test_assert_path_exists("//ForInStatNode")
@cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode",
"//GeneratorExpressionNode")
"//YieldExprNode")
def any_in_simple_gen_scope(seq):
"""
>>> any_in_simple_gen_scope([0,1,0])
True
>>> any_in_simple_gen_scope([0,0,0])
False
>>> any_in_simple_gen_scope(VerboseGetItem([0,0,1,0,0]))
0
1
2
True
>>> any_in_simple_gen_scope(VerboseGetItem([0,0,0,0,0]))
0
1
2
3
4
5
False
"""
x = 'abc'
result = any(x for x in seq)
assert x == 'abc'
return result
@cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode")
def any_in_conditional_gen(seq):
"""
>>> any_in_conditional_gen([3,6,9])
......@@ -127,10 +157,10 @@ def any_lower_case_characters(unicode ustring):
"""
return any(uchar.islower() for uchar in ustring)
@cython.test_assert_path_exists("//ForInStatNode")
@cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode",
"//GeneratorExpressionNode")
"//YieldExprNode")
def any_in_typed_gen(seq):
"""
>>> any_in_typed_gen([0,1,0])
......@@ -157,10 +187,10 @@ def any_in_typed_gen(seq):
cdef int x
return any(x for x in seq)
@cython.test_assert_path_exists("//ForInStatNode")
@cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode",
"//GeneratorExpressionNode")
"//YieldExprNode")
def any_in_nested_gen(seq):
"""
>>> any(x for L in [[0,0,0],[0,0,1],[0,0,0]] for x in L)
......
......@@ -6,7 +6,7 @@ __doc__ = u"""
[('args', (2, 3)), ('kwds', {'k': 5}), ('x', 1), ('y', 'hi'), ('z', 5)]
>>> sorted(get_locals_items_listcomp(1,2,3, k=5))
[('args', (2, 3)), ('item', None), ('kwds', {'k': 5}), ('x', 1), ('y', 'hi'), ('z', 5)]
[('args', (2, 3)), ('kwds', {'k': 5}), ('x', 1), ('y', 'hi'), ('z', 5)]
"""
def get_locals(x, *args, **kwds):
......@@ -20,7 +20,6 @@ def get_locals_items(x, *args, **kwds):
return locals().items()
def get_locals_items_listcomp(x, *args, **kwds):
# FIXME: 'item' should *not* appear in locals() !
cdef int z = 5
y = "hi"
return [ item for item in locals().items() ]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment