Commit db1e13e8 authored by Stefan Behnel's avatar Stefan Behnel

make AnalyseDeclarationsTransform inherit from EnvTransform to fix...

make AnalyseDeclarationsTransform inherit from EnvTransform to fix inconsistencies in scope tracking
parent 325da76f
......@@ -34,10 +34,10 @@ cdef map_starred_assignment(list lhs_targets, list starred_assignments, list lhs
#class WithTransform(CythonTransform, SkipDeclarations):
#class DecoratorTransform(CythonTransform, SkipDeclarations):
#class AnalyseDeclarationsTransform(CythonTransform):
#class AnalyseDeclarationsTransform(EnvTransform):
cdef class AnalyseExpressionsTransform(CythonTransform):
cdef list env_stack
pass
cdef class ExpandInplaceOperators(EnvTransform):
pass
......
......@@ -1349,7 +1349,7 @@ class ForwardDeclareTypes(CythonTransform):
return node
class AnalyseDeclarationsTransform(CythonTransform):
class AnalyseDeclarationsTransform(EnvTransform):
basic_property = TreeFragment(u"""
property NAME:
......@@ -1398,11 +1398,12 @@ if VALUE is not None:
in_lambda = 0
def __call__(self, root):
self.env_stack = [root.scope]
# needed to determine if a cdef var is declared after it's used.
self.seen_vars_stack = []
self.fused_error_funcs = set()
return super(AnalyseDeclarationsTransform, self).__call__(root)
super_class = super(AnalyseDeclarationsTransform, self)
self._super_visit_FuncDefNode = super_class.visit_FuncDefNode
return super_class.__call__(root)
def visit_NameNode(self, node):
self.seen_vars_stack[-1].add(node.name)
......@@ -1410,24 +1411,18 @@ if VALUE is not None:
def visit_ModuleNode(self, node):
self.seen_vars_stack.append(set())
node.analyse_declarations(self.env_stack[-1])
node.analyse_declarations(self.current_env())
self.visitchildren(node)
self.seen_vars_stack.pop()
return node
def visit_LambdaNode(self, node):
self.in_lambda += 1
node.analyse_declarations(self.env_stack[-1])
node.analyse_declarations(self.current_env())
self.visitchildren(node)
self.in_lambda -= 1
return node
def visit_ClassDefNode(self, node):
self.env_stack.append(node.scope)
self.visitchildren(node)
self.env_stack.pop()
return node
def visit_CClassDefNode(self, node):
node = self.visit_ClassDefNode(node)
if node.scope and node.scope.implemented:
......@@ -1548,7 +1543,7 @@ if VALUE is not None:
analyse its children (which are in turn normal functions). If we're a
normal function, just analyse the body of the function.
"""
env = self.env_stack[-1]
env = self.current_env()
self.seen_vars_stack.append(set())
lenv = node.local_scope
......@@ -1567,23 +1562,23 @@ if VALUE is not None:
else:
node.body.analyse_declarations(lenv)
self._handle_nogil_cleanup(lenv, node)
self.env_stack.append(lenv)
self.visitchildren(node)
self.env_stack.pop()
self._super_visit_FuncDefNode(node)
self.seen_vars_stack.pop()
return node
def visit_DefNode(self, node):
node = self.visit_FuncDefNode(node)
env = self.env_stack[-1]
env = self.current_env()
if (not isinstance(node, Nodes.DefNode) or
node.fused_py_func or node.is_generator_body or
not node.needs_assignment_synthesis(env)):
return node
return [node, self._synthesize_assignment(node, env)]
def visit_GeneratorBodyDefNode(self, node):
return self.visit_FuncDefNode(node)
def _synthesize_assignment(self, node, env):
# Synthesize assignment node and put it right after defnode
genv = env
......@@ -1622,15 +1617,15 @@ if VALUE is not None:
return assmt
def visit_ScopedExprNode(self, node):
env = self.env_stack[-1]
env = self.current_env()
node.analyse_declarations(env)
# the node may or may not have a local scope
if node.has_local_scope:
self.seen_vars_stack.append(set(self.seen_vars_stack[-1]))
self.env_stack.append(node.expr_scope)
self.enter_scope(node, node.expr_scope)
node.analyse_scoped_declarations(node.expr_scope)
self.visitchildren(node)
self.env_stack.pop()
self.exit_scope()
self.seen_vars_stack.pop()
else:
node.analyse_scoped_declarations(env)
......@@ -1639,7 +1634,7 @@ if VALUE is not None:
def visit_TempResultFromStatNode(self, node):
self.visitchildren(node)
node.analyse_declarations(self.env_stack[-1])
node.analyse_declarations(self.current_env())
return node
def visit_CppClassNode(self, node):
......@@ -1804,18 +1799,15 @@ if VALUE is not None:
class AnalyseExpressionsTransform(CythonTransform):
def visit_ModuleNode(self, node):
self.env_stack = [node.scope]
node.scope.infer_types()
node.body.analyse_expressions(node.scope)
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
self.env_stack.append(node.local_scope)
node.local_scope.infer_types()
node.body.analyse_expressions(node.local_scope)
self.visitchildren(node)
self.env_stack.pop()
return node
def visit_ScopedExprNode(self, node):
......
......@@ -201,7 +201,7 @@ class Entry(object):
self.defining_entry = self
def __repr__(self):
return "%s(name=%s, type=%s)" % (type(self).__name__, self.name, self.type)
return "%s(<%x>, name=%s, type=%s)" % (type(self).__name__, id(self), self.name, self.type)
def redeclared(self, pos):
error(pos, "'%s' does not match previous declaration" % self.name)
......
......@@ -321,7 +321,8 @@ class EnvTransform(CythonTransform):
This transformation keeps a stack of the environments.
"""
def __call__(self, root):
self.env_stack = [(root, root.scope)]
self.env_stack = []
self.enter_scope(root, root.scope)
return super(EnvTransform, self).__call__(root)
def current_env(self):
......@@ -333,10 +334,16 @@ class EnvTransform(CythonTransform):
def global_scope(self):
return self.current_env().global_scope()
def enter_scope(self, node, scope):
self.env_stack.append((node, scope))
def exit_scope(self):
self.env_stack.pop()
def visit_FuncDefNode(self, node):
self.env_stack.append((node, node.local_scope))
self.enter_scope(node, node.local_scope)
self.visitchildren(node)
self.env_stack.pop()
self.exit_scope()
return node
def visit_GeneratorBodyDefNode(self, node):
......@@ -344,22 +351,22 @@ class EnvTransform(CythonTransform):
return node
def visit_ClassDefNode(self, node):
self.env_stack.append((node, node.scope))
self.enter_scope(node, node.scope)
self.visitchildren(node)
self.env_stack.pop()
self.exit_scope()
return node
def visit_CStructOrUnionDefNode(self, node):
self.env_stack.append((node, node.scope))
self.enter_scope(node, node.scope)
self.visitchildren(node)
self.env_stack.pop()
self.exit_scope()
return node
def visit_ScopedExprNode(self, node):
if node.expr_scope:
self.env_stack.append((node, node.expr_scope))
self.enter_scope(node, node.expr_scope)
self.visitchildren(node)
self.env_stack.pop()
self.exit_scope()
else:
self.visitchildren(node)
return node
......@@ -369,9 +376,9 @@ class EnvTransform(CythonTransform):
if node.default:
attrs = [ attr for attr in node.child_attrs if attr != 'default' ]
self.visitchildren(node, attrs)
self.env_stack.append((node, self.current_env().outer_scope))
self.enter_scope(node, self.current_env().outer_scope)
self.visitchildren(node, ('default',))
self.env_stack.pop()
self.exit_scope()
else:
self.visitchildren(node)
return node
......
# mode: run
# tag: closures
cimport cython
@cython.test_fail_if_path_exists(
'//NameNode[@entry.in_closure = True]',
'//NameNode[@entry.from_closure = True]')
def test_func_default():
"""
>>> func = test_func_default()
>>> func()
1
>>> func(2)
2
"""
def default():
return 1
def func(arg=default()):
return arg
return func
# cython: binding=True
# mode: run
# tag: cyfunction
# tag: cyfunction, closures
cimport cython
import sys
......@@ -131,3 +131,59 @@ def test_dynamic_defaults_fused():
for i, f in enumerate(funcs):
print "i", i, "func result", f(1.0), "defaults", get_defaults(f)
@cython.test_fail_if_path_exists(
'//NameNode[@entry.in_closure = True]',
'//NameNode[@entry.from_closure = True]')
def test_func_default_inlined():
"""
Make sure we don't accidentally generate a closure.
>>> func = test_func_default_inlined()
>>> func()
1
>>> func(2)
2
"""
def default():
return 1
def func(arg=default()):
return arg
return func
@cython.test_fail_if_path_exists(
'//NameNode[@entry.in_closure = True]',
'//NameNode[@entry.from_closure = True]')
def test_func_default_scope():
"""
Test that the default value expression is evaluated in the outer scope.
>>> func = test_func_default_scope()
3
>>> func()
[0, 1, 2, 3]
>>> func(2)
2
"""
i = -1
def func(arg=[ i for i in range(4) ]):
return arg
print i # list comps leak in Py2 mode => i == 3
return func
def test_func_default_scope_local():
"""
>>> func = test_func_default_scope_local()
-1
>>> func()
[0, 1, 2, 3]
>>> func(2)
2
"""
i = -1
def func(arg=list(i for i in range(4))):
return arg
print i # genexprs don't leak
return func
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment