Commit db1e13e8 authored by Stefan Behnel's avatar Stefan Behnel

make AnalyseDeclarationsTransform inherit from EnvTransform to fix...

make AnalyseDeclarationsTransform inherit from EnvTransform to fix inconsistencies in scope tracking
parent 325da76f
...@@ -34,10 +34,10 @@ cdef map_starred_assignment(list lhs_targets, list starred_assignments, list lhs ...@@ -34,10 +34,10 @@ cdef map_starred_assignment(list lhs_targets, list starred_assignments, list lhs
#class WithTransform(CythonTransform, SkipDeclarations): #class WithTransform(CythonTransform, SkipDeclarations):
#class DecoratorTransform(CythonTransform, SkipDeclarations): #class DecoratorTransform(CythonTransform, SkipDeclarations):
#class AnalyseDeclarationsTransform(CythonTransform): #class AnalyseDeclarationsTransform(EnvTransform):
cdef class AnalyseExpressionsTransform(CythonTransform): cdef class AnalyseExpressionsTransform(CythonTransform):
cdef list env_stack pass
cdef class ExpandInplaceOperators(EnvTransform): cdef class ExpandInplaceOperators(EnvTransform):
pass pass
......
...@@ -1349,7 +1349,7 @@ class ForwardDeclareTypes(CythonTransform): ...@@ -1349,7 +1349,7 @@ class ForwardDeclareTypes(CythonTransform):
return node return node
class AnalyseDeclarationsTransform(CythonTransform): class AnalyseDeclarationsTransform(EnvTransform):
basic_property = TreeFragment(u""" basic_property = TreeFragment(u"""
property NAME: property NAME:
...@@ -1398,11 +1398,12 @@ if VALUE is not None: ...@@ -1398,11 +1398,12 @@ if VALUE is not None:
in_lambda = 0 in_lambda = 0
def __call__(self, root): def __call__(self, root):
self.env_stack = [root.scope]
# needed to determine if a cdef var is declared after it's used. # needed to determine if a cdef var is declared after it's used.
self.seen_vars_stack = [] self.seen_vars_stack = []
self.fused_error_funcs = set() self.fused_error_funcs = set()
return super(AnalyseDeclarationsTransform, self).__call__(root) super_class = super(AnalyseDeclarationsTransform, self)
self._super_visit_FuncDefNode = super_class.visit_FuncDefNode
return super_class.__call__(root)
def visit_NameNode(self, node): def visit_NameNode(self, node):
self.seen_vars_stack[-1].add(node.name) self.seen_vars_stack[-1].add(node.name)
...@@ -1410,24 +1411,18 @@ if VALUE is not None: ...@@ -1410,24 +1411,18 @@ if VALUE is not None:
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
self.seen_vars_stack.append(set()) self.seen_vars_stack.append(set())
node.analyse_declarations(self.env_stack[-1]) node.analyse_declarations(self.current_env())
self.visitchildren(node) self.visitchildren(node)
self.seen_vars_stack.pop() self.seen_vars_stack.pop()
return node return node
def visit_LambdaNode(self, node): def visit_LambdaNode(self, node):
self.in_lambda += 1 self.in_lambda += 1
node.analyse_declarations(self.env_stack[-1]) node.analyse_declarations(self.current_env())
self.visitchildren(node) self.visitchildren(node)
self.in_lambda -= 1 self.in_lambda -= 1
return node return node
def visit_ClassDefNode(self, node):
self.env_stack.append(node.scope)
self.visitchildren(node)
self.env_stack.pop()
return node
def visit_CClassDefNode(self, node): def visit_CClassDefNode(self, node):
node = self.visit_ClassDefNode(node) node = self.visit_ClassDefNode(node)
if node.scope and node.scope.implemented: if node.scope and node.scope.implemented:
...@@ -1548,7 +1543,7 @@ if VALUE is not None: ...@@ -1548,7 +1543,7 @@ if VALUE is not None:
analyse its children (which are in turn normal functions). If we're a analyse its children (which are in turn normal functions). If we're a
normal function, just analyse the body of the function. normal function, just analyse the body of the function.
""" """
env = self.env_stack[-1] env = self.current_env()
self.seen_vars_stack.append(set()) self.seen_vars_stack.append(set())
lenv = node.local_scope lenv = node.local_scope
...@@ -1567,23 +1562,23 @@ if VALUE is not None: ...@@ -1567,23 +1562,23 @@ if VALUE is not None:
else: else:
node.body.analyse_declarations(lenv) node.body.analyse_declarations(lenv)
self._handle_nogil_cleanup(lenv, node) self._handle_nogil_cleanup(lenv, node)
self._super_visit_FuncDefNode(node)
self.env_stack.append(lenv)
self.visitchildren(node)
self.env_stack.pop()
self.seen_vars_stack.pop() self.seen_vars_stack.pop()
return node return node
def visit_DefNode(self, node): def visit_DefNode(self, node):
node = self.visit_FuncDefNode(node) node = self.visit_FuncDefNode(node)
env = self.env_stack[-1] env = self.current_env()
if (not isinstance(node, Nodes.DefNode) or if (not isinstance(node, Nodes.DefNode) or
node.fused_py_func or node.is_generator_body or node.fused_py_func or node.is_generator_body or
not node.needs_assignment_synthesis(env)): not node.needs_assignment_synthesis(env)):
return node return node
return [node, self._synthesize_assignment(node, env)] return [node, self._synthesize_assignment(node, env)]
def visit_GeneratorBodyDefNode(self, node):
return self.visit_FuncDefNode(node)
def _synthesize_assignment(self, node, env): def _synthesize_assignment(self, node, env):
# Synthesize assignment node and put it right after defnode # Synthesize assignment node and put it right after defnode
genv = env genv = env
...@@ -1622,15 +1617,15 @@ if VALUE is not None: ...@@ -1622,15 +1617,15 @@ if VALUE is not None:
return assmt return assmt
def visit_ScopedExprNode(self, node): def visit_ScopedExprNode(self, node):
env = self.env_stack[-1] env = self.current_env()
node.analyse_declarations(env) node.analyse_declarations(env)
# the node may or may not have a local scope # the node may or may not have a local scope
if node.has_local_scope: if node.has_local_scope:
self.seen_vars_stack.append(set(self.seen_vars_stack[-1])) self.seen_vars_stack.append(set(self.seen_vars_stack[-1]))
self.env_stack.append(node.expr_scope) self.enter_scope(node, node.expr_scope)
node.analyse_scoped_declarations(node.expr_scope) node.analyse_scoped_declarations(node.expr_scope)
self.visitchildren(node) self.visitchildren(node)
self.env_stack.pop() self.exit_scope()
self.seen_vars_stack.pop() self.seen_vars_stack.pop()
else: else:
node.analyse_scoped_declarations(env) node.analyse_scoped_declarations(env)
...@@ -1639,7 +1634,7 @@ if VALUE is not None: ...@@ -1639,7 +1634,7 @@ if VALUE is not None:
def visit_TempResultFromStatNode(self, node): def visit_TempResultFromStatNode(self, node):
self.visitchildren(node) self.visitchildren(node)
node.analyse_declarations(self.env_stack[-1]) node.analyse_declarations(self.current_env())
return node return node
def visit_CppClassNode(self, node): def visit_CppClassNode(self, node):
...@@ -1804,18 +1799,15 @@ if VALUE is not None: ...@@ -1804,18 +1799,15 @@ if VALUE is not None:
class AnalyseExpressionsTransform(CythonTransform): class AnalyseExpressionsTransform(CythonTransform):
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
self.env_stack = [node.scope]
node.scope.infer_types() node.scope.infer_types()
node.body.analyse_expressions(node.scope) node.body.analyse_expressions(node.scope)
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
self.env_stack.append(node.local_scope)
node.local_scope.infer_types() node.local_scope.infer_types()
node.body.analyse_expressions(node.local_scope) node.body.analyse_expressions(node.local_scope)
self.visitchildren(node) self.visitchildren(node)
self.env_stack.pop()
return node return node
def visit_ScopedExprNode(self, node): def visit_ScopedExprNode(self, node):
......
...@@ -201,7 +201,7 @@ class Entry(object): ...@@ -201,7 +201,7 @@ class Entry(object):
self.defining_entry = self self.defining_entry = self
def __repr__(self): def __repr__(self):
return "%s(name=%s, type=%s)" % (type(self).__name__, self.name, self.type) return "%s(<%x>, name=%s, type=%s)" % (type(self).__name__, id(self), self.name, self.type)
def redeclared(self, pos): def redeclared(self, pos):
error(pos, "'%s' does not match previous declaration" % self.name) error(pos, "'%s' does not match previous declaration" % self.name)
......
...@@ -321,7 +321,8 @@ class EnvTransform(CythonTransform): ...@@ -321,7 +321,8 @@ class EnvTransform(CythonTransform):
This transformation keeps a stack of the environments. This transformation keeps a stack of the environments.
""" """
def __call__(self, root): def __call__(self, root):
self.env_stack = [(root, root.scope)] self.env_stack = []
self.enter_scope(root, root.scope)
return super(EnvTransform, self).__call__(root) return super(EnvTransform, self).__call__(root)
def current_env(self): def current_env(self):
...@@ -333,10 +334,16 @@ class EnvTransform(CythonTransform): ...@@ -333,10 +334,16 @@ class EnvTransform(CythonTransform):
def global_scope(self): def global_scope(self):
return self.current_env().global_scope() return self.current_env().global_scope()
def enter_scope(self, node, scope):
self.env_stack.append((node, scope))
def exit_scope(self):
self.env_stack.pop()
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
self.env_stack.append((node, node.local_scope)) self.enter_scope(node, node.local_scope)
self.visitchildren(node) self.visitchildren(node)
self.env_stack.pop() self.exit_scope()
return node return node
def visit_GeneratorBodyDefNode(self, node): def visit_GeneratorBodyDefNode(self, node):
...@@ -344,22 +351,22 @@ class EnvTransform(CythonTransform): ...@@ -344,22 +351,22 @@ class EnvTransform(CythonTransform):
return node return node
def visit_ClassDefNode(self, node): def visit_ClassDefNode(self, node):
self.env_stack.append((node, node.scope)) self.enter_scope(node, node.scope)
self.visitchildren(node) self.visitchildren(node)
self.env_stack.pop() self.exit_scope()
return node return node
def visit_CStructOrUnionDefNode(self, node): def visit_CStructOrUnionDefNode(self, node):
self.env_stack.append((node, node.scope)) self.enter_scope(node, node.scope)
self.visitchildren(node) self.visitchildren(node)
self.env_stack.pop() self.exit_scope()
return node return node
def visit_ScopedExprNode(self, node): def visit_ScopedExprNode(self, node):
if node.expr_scope: if node.expr_scope:
self.env_stack.append((node, node.expr_scope)) self.enter_scope(node, node.expr_scope)
self.visitchildren(node) self.visitchildren(node)
self.env_stack.pop() self.exit_scope()
else: else:
self.visitchildren(node) self.visitchildren(node)
return node return node
...@@ -369,9 +376,9 @@ class EnvTransform(CythonTransform): ...@@ -369,9 +376,9 @@ class EnvTransform(CythonTransform):
if node.default: if node.default:
attrs = [ attr for attr in node.child_attrs if attr != 'default' ] attrs = [ attr for attr in node.child_attrs if attr != 'default' ]
self.visitchildren(node, attrs) self.visitchildren(node, attrs)
self.env_stack.append((node, self.current_env().outer_scope)) self.enter_scope(node, self.current_env().outer_scope)
self.visitchildren(node, ('default',)) self.visitchildren(node, ('default',))
self.env_stack.pop() self.exit_scope()
else: else:
self.visitchildren(node) self.visitchildren(node)
return node return node
......
# mode: run
# tag: closures
cimport cython
@cython.test_fail_if_path_exists(
'//NameNode[@entry.in_closure = True]',
'//NameNode[@entry.from_closure = True]')
def test_func_default():
"""
>>> func = test_func_default()
>>> func()
1
>>> func(2)
2
"""
def default():
return 1
def func(arg=default()):
return arg
return func
# cython: binding=True # cython: binding=True
# mode: run # mode: run
# tag: cyfunction # tag: cyfunction, closures
cimport cython cimport cython
import sys import sys
...@@ -131,3 +131,59 @@ def test_dynamic_defaults_fused(): ...@@ -131,3 +131,59 @@ def test_dynamic_defaults_fused():
for i, f in enumerate(funcs): for i, f in enumerate(funcs):
print "i", i, "func result", f(1.0), "defaults", get_defaults(f) print "i", i, "func result", f(1.0), "defaults", get_defaults(f)
@cython.test_fail_if_path_exists(
'//NameNode[@entry.in_closure = True]',
'//NameNode[@entry.from_closure = True]')
def test_func_default_inlined():
"""
Make sure we don't accidentally generate a closure.
>>> func = test_func_default_inlined()
>>> func()
1
>>> func(2)
2
"""
def default():
return 1
def func(arg=default()):
return arg
return func
@cython.test_fail_if_path_exists(
'//NameNode[@entry.in_closure = True]',
'//NameNode[@entry.from_closure = True]')
def test_func_default_scope():
"""
Test that the default value expression is evaluated in the outer scope.
>>> func = test_func_default_scope()
3
>>> func()
[0, 1, 2, 3]
>>> func(2)
2
"""
i = -1
def func(arg=[ i for i in range(4) ]):
return arg
print i # list comps leak in Py2 mode => i == 3
return func
def test_func_default_scope_local():
"""
>>> func = test_func_default_scope_local()
-1
>>> func()
[0, 1, 2, 3]
>>> func(2)
2
"""
i = -1
def func(arg=list(i for i in range(4))):
return arg
print i # genexprs don't leak
return func
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment