Commit c40ff48f authored by Stefan Behnel's avatar Stefan Behnel

fix decorator lookup by avoiding (re-)assignments to the function/class name...

fix decorator lookup by avoiding (re-)assignments to the function/class name before their execution (ticket #593)
parent fc956d33
...@@ -2326,6 +2326,13 @@ class DefNode(FuncDefNode): ...@@ -2326,6 +2326,13 @@ class DefNode(FuncDefNode):
else: else:
rhs.binding = False rhs.binding = False
if self.decorators:
for decorator in self.decorators[::-1]:
rhs = ExprNodes.SimpleCallNode(
decorator.pos,
function = decorator.decorator,
args = [rhs])
self.assmt = SingleAssignmentNode(self.pos, self.assmt = SingleAssignmentNode(self.pos,
lhs = ExprNodes.NameNode(self.pos, name = self.name), lhs = ExprNodes.NameNode(self.pos, name = self.name),
rhs = rhs) rhs = rhs)
...@@ -3198,8 +3205,9 @@ class PyClassDefNode(ClassDefNode): ...@@ -3198,8 +3205,9 @@ class PyClassDefNode(ClassDefNode):
# classobj ClassNode Class object # classobj ClassNode Class object
# target NameNode Variable to assign class object to # target NameNode Variable to assign class object to
child_attrs = ["body", "dict", "metaclass", "mkw", "bases", "classobj", "target"] child_attrs = ["body", "dict", "metaclass", "mkw", "bases", "class_result", "target"]
decorators = None decorators = None
class_result = None
py3_style_class = False # Python3 style class (bases+kwargs) py3_style_class = False # Python3 style class (bases+kwargs)
def __init__(self, pos, name, bases, doc, body, decorators = None, def __init__(self, pos, name, bases, doc, body, decorators = None,
...@@ -3302,6 +3310,16 @@ class PyClassDefNode(ClassDefNode): ...@@ -3302,6 +3310,16 @@ class PyClassDefNode(ClassDefNode):
return cenv return cenv
def analyse_declarations(self, env): def analyse_declarations(self, env):
class_result = self.classobj
if self.decorators:
from ExprNodes import SimpleCallNode
for decorator in self.decorators[::-1]:
class_result = SimpleCallNode(
decorator.pos,
function = decorator.decorator,
args = [class_result])
self.class_result = class_result
self.class_result.analyse_declarations(env)
self.target.analyse_target_declaration(env) self.target.analyse_target_declaration(env)
cenv = self.create_scope(env) cenv = self.create_scope(env)
cenv.directives = env.directives cenv.directives = env.directives
...@@ -3314,7 +3332,7 @@ class PyClassDefNode(ClassDefNode): ...@@ -3314,7 +3332,7 @@ class PyClassDefNode(ClassDefNode):
self.metaclass.analyse_expressions(env) self.metaclass.analyse_expressions(env)
self.mkw.analyse_expressions(env) self.mkw.analyse_expressions(env)
self.dict.analyse_expressions(env) self.dict.analyse_expressions(env)
self.classobj.analyse_expressions(env) self.class_result.analyse_expressions(env)
genv = env.global_scope() genv = env.global_scope()
cenv = self.scope cenv = self.scope
self.body.analyse_expressions(cenv) self.body.analyse_expressions(cenv)
...@@ -3334,9 +3352,9 @@ class PyClassDefNode(ClassDefNode): ...@@ -3334,9 +3352,9 @@ class PyClassDefNode(ClassDefNode):
self.dict.generate_evaluation_code(code) self.dict.generate_evaluation_code(code)
cenv.namespace_cname = cenv.class_obj_cname = self.dict.result() cenv.namespace_cname = cenv.class_obj_cname = self.dict.result()
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
self.classobj.generate_evaluation_code(code) self.class_result.generate_evaluation_code(code)
cenv.namespace_cname = cenv.class_obj_cname = self.classobj.result() cenv.namespace_cname = cenv.class_obj_cname = self.classobj.result()
self.target.generate_assignment_code(self.classobj, code) self.target.generate_assignment_code(self.class_result, code)
self.dict.generate_disposal_code(code) self.dict.generate_disposal_code(code)
self.dict.free_temps(code) self.dict.free_temps(code)
if self.py3_style_class: if self.py3_style_class:
...@@ -3424,6 +3442,9 @@ class CClassDefNode(ClassDefNode): ...@@ -3424,6 +3442,9 @@ class CClassDefNode(ClassDefNode):
if env.in_cinclude and not self.objstruct_name: if env.in_cinclude and not self.objstruct_name:
error(self.pos, "Object struct name specification required for " error(self.pos, "Object struct name specification required for "
"C class defined in 'extern from' block") "C class defined in 'extern from' block")
if self.decorators:
error(self.pos,
"Decorators not allowed on cdef classes (used on type '%s')" % self.class_name)
self.base_type = None self.base_type = None
# Now that module imports are cached, we need to # Now that module imports are cached, we need to
# import the modules for extern classes. # import the modules for extern classes.
......
...@@ -1193,39 +1193,23 @@ class WithTransform(CythonTransform, SkipDeclarations): ...@@ -1193,39 +1193,23 @@ class WithTransform(CythonTransform, SkipDeclarations):
return node return node
class DecoratorTransform(CythonTransform, SkipDeclarations): class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations):
"""Originally, this was the only place where decorators were
transformed into the corresponding calling code. Now, this is
done directly in DefNode and PyClassDefNode to avoid reassignments
to the function/class name - except for cdef class methods. For
those, the reassignment is required as methods are originally
defined in the PyMethodDef struct.
"""
def visit_DefNode(self, func_node): def visit_DefNode(self, func_node):
self.visitchildren(func_node) scope_type = self.scope_type
if not func_node.decorators: func_node = self.visit_FuncDefNode(func_node)
if scope_type is not 'cclass' or not func_node.decorators:
return func_node return func_node
return self._handle_decorators( return self._handle_decorators(
func_node, func_node.name) func_node, func_node.name)
def visit_CClassDefNode(self, class_node):
# This doesn't currently work, so it's disabled.
#
# Problem: assignments to cdef class names do not work. They
# would require an additional check anyway, as the extension
# type must not change its C type, so decorators cannot
# replace an extension type, just alter it and return it.
self.visitchildren(class_node)
if not class_node.decorators:
return class_node
error(class_node.pos,
"Decorators not allowed on cdef classes (used on type '%s')" % class_node.class_name)
return class_node
#return self._handle_decorators(
# class_node, class_node.class_name)
def visit_ClassDefNode(self, class_node):
self.visitchildren(class_node)
if not class_node.decorators:
return class_node
return self._handle_decorators(
class_node, class_node.name)
def _handle_decorators(self, node, name): def _handle_decorators(self, node, name):
decorator_result = ExprNodes.NameNode(node.pos, name = name) decorator_result = ExprNodes.NameNode(node.pos, name = name)
for decorator in node.decorators[::-1]: for decorator in node.decorators[::-1]:
......
import unittest
from Cython.TestUtils import TransformTest
from Cython.Compiler.ParseTreeTransforms import DecoratorTransform
class TestDecorator(TransformTest):
def test_decorator(self):
t = self.run_pipeline([DecoratorTransform(None)], u"""
def decorator(fun):
return fun
@decorator
def decorated():
pass
""")
self.assertCode(u"""
def decorator(fun):
return fun
def decorated():
pass
decorated = decorator(decorated)
""", t)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment