Commit cf05c0a3 authored by Stefan Behnel's avatar Stefan Behnel

decorator support (partly by Fabrizio Milo)

parent e8ae53c9
...@@ -274,6 +274,16 @@ class CodeWriter(TreeVisitor): ...@@ -274,6 +274,16 @@ class CodeWriter(TreeVisitor):
self.visit(node.body) self.visit(node.body)
self.dedent() self.dedent()
def visit_ReturnStatNode(self, node):
self.startline("return ")
self.visit(node.value)
self.endline()
def visit_DecoratorNode(self, node):
self.startline("@")
self.visit(node.decorator)
self.endline()
def visit_ReraiseStatNode(self, node): def visit_ReraiseStatNode(self, node):
self.line("raise") self.line("raise")
......
...@@ -65,6 +65,7 @@ def make_lexicon(): ...@@ -65,6 +65,7 @@ def make_lexicon():
escapeseq = Str("\\") + (two_oct | three_oct | two_hex | escapeseq = Str("\\") + (two_oct | three_oct | two_hex |
Str('u') + four_hex | Str('x') + two_hex | AnyChar) Str('u') + four_hex | Str('x') + two_hex | AnyChar)
deco = Str("@")
bra = Any("([{") bra = Any("([{")
ket = Any(")]}") ket = Any(")]}")
punct = Any(":,;+-*/|&<>=.%`~^?") punct = Any(":,;+-*/|&<>=.%`~^?")
...@@ -82,6 +83,7 @@ def make_lexicon(): ...@@ -82,6 +83,7 @@ def make_lexicon():
(longconst, 'LONG'), (longconst, 'LONG'),
(fltconst, 'FLOAT'), (fltconst, 'FLOAT'),
(imagconst, 'IMAG'), (imagconst, 'IMAG'),
(deco, 'DECORATOR'),
(punct | diphthong, TEXT), (punct | diphthong, TEXT),
(bra, Method('open_bracket_action')), (bra, Method('open_bracket_action')),
......
...@@ -356,7 +356,7 @@ def create_generate_code(context, options, result): ...@@ -356,7 +356,7 @@ def create_generate_code(context, options, result):
def create_default_pipeline(context, options, result): def create_default_pipeline(context, options, result):
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from Buffer import BufferTransform from Buffer import BufferTransform
from ModuleNode import check_c_classes from ModuleNode import check_c_classes
...@@ -365,6 +365,7 @@ def create_default_pipeline(context, options, result): ...@@ -365,6 +365,7 @@ def create_default_pipeline(context, options, result):
NormalizeTree(context), NormalizeTree(context),
PostParse(context), PostParse(context),
WithTransform(context), WithTransform(context),
DecoratorTransform(context),
AnalyseDeclarationsTransform(context), AnalyseDeclarationsTransform(context),
check_c_classes, check_c_classes,
AnalyseExpressionsTransform(context), AnalyseExpressionsTransform(context),
......
...@@ -1235,13 +1235,19 @@ class PyArgDeclNode(Node): ...@@ -1235,13 +1235,19 @@ class PyArgDeclNode(Node):
# entry Symtab.Entry # entry Symtab.Entry
child_attrs = [] child_attrs = []
pass
class DecoratorNode(Node):
# A decorator
#
# decorator NameNode or CallNode
child_attrs = ['decorator']
class DefNode(FuncDefNode): class DefNode(FuncDefNode):
# A Python function definition. # A Python function definition.
# #
# name string the Python name of the function # name string the Python name of the function
# decorators [DecoratorNode] list of decorators
# args [CArgDeclNode] formal arguments # args [CArgDeclNode] formal arguments
# star_arg PyArgDeclNode or None * argument # star_arg PyArgDeclNode or None * argument
# starstar_arg PyArgDeclNode or None ** argument # starstar_arg PyArgDeclNode or None ** argument
...@@ -1253,13 +1259,14 @@ class DefNode(FuncDefNode): ...@@ -1253,13 +1259,14 @@ class DefNode(FuncDefNode):
# #
# assmt AssignmentNode Function construction/assignment # assmt AssignmentNode Function construction/assignment
child_attrs = ["args", "star_arg", "starstar_arg", "body"] child_attrs = ["args", "star_arg", "starstar_arg", "body", "decorators"]
assmt = None assmt = None
num_kwonly_args = 0 num_kwonly_args = 0
num_required_kw_args = 0 num_required_kw_args = 0
reqd_kw_flags_cname = "0" reqd_kw_flags_cname = "0"
is_wrapper = 0 is_wrapper = 0
decorators = None
def __init__(self, pos, **kwds): def __init__(self, pos, **kwds):
FuncDefNode.__init__(self, pos, **kwds) FuncDefNode.__init__(self, pos, **kwds)
......
...@@ -217,6 +217,26 @@ class WithTransform(CythonTransform): ...@@ -217,6 +217,26 @@ class WithTransform(CythonTransform):
return result.stats return result.stats
class DecoratorTransform(CythonTransform):
def visit_FuncDefNode(self, func_node):
if not func_node.decorators:
return func_node
decorator_result = NameNode(func_node.pos, name = func_node.name)
for decorator in func_node.decorators[::-1]:
decorator_result = SimpleCallNode(
decorator.pos,
function = decorator.decorator,
args = [decorator_result])
func_name_node = NameNode(func_node.pos, name = func_node.name)
reassignment = SingleAssignmentNode(
func_node.pos,
lhs = func_name_node,
rhs = decorator_result)
return [func_node, reassignment]
class AnalyseDeclarationsTransform(CythonTransform): class AnalyseDeclarationsTransform(CythonTransform):
def __call__(self, root): def __call__(self, root):
......
...@@ -1372,6 +1372,14 @@ def p_statement(s, ctx, first_statement = 0): ...@@ -1372,6 +1372,14 @@ def p_statement(s, ctx, first_statement = 0):
return p_DEF_statement(s) return p_DEF_statement(s)
elif s.sy == 'IF': elif s.sy == 'IF':
return p_IF_statement(s, ctx) return p_IF_statement(s, ctx)
elif s.sy == 'DECORATOR':
if ctx.level not in ('module', 'class', 'c_class', 'property'):
s.error('decorator not allowed here')
s.level = ctx.level
decorators = p_decorators(s)
if s.sy != 'def':
s.error("Decorators can only be followed by functions ")
return p_def_statement(s, decorators)
else: else:
overridable = 0 overridable = 0
if s.sy == 'cdef': if s.sy == 'cdef':
...@@ -2103,7 +2111,21 @@ def p_ctypedef_statement(s, ctx): ...@@ -2103,7 +2111,21 @@ def p_ctypedef_statement(s, ctx):
declarator = declarator, visibility = visibility, declarator = declarator, visibility = visibility,
in_pxd = ctx.level == 'module_pxd') in_pxd = ctx.level == 'module_pxd')
def p_def_statement(s): def p_decorators(s):
decorators = []
while s.sy == 'DECORATOR':
pos = s.position()
s.next()
decorator = ExprNodes.NameNode(
pos, name = Utils.EncodedString(
p_dotted_name(s, as_allowed=0)[2] ))
if s.sy == '(':
decorator = p_call(s, decorator)
decorators.append(Nodes.DecoratorNode(pos, decorator=decorator))
s.expect_newline("Expected a newline after decorator")
return decorators
def p_def_statement(s, decorators=None):
# s.sy == 'def' # s.sy == 'def'
pos = s.position() pos = s.position()
s.next() s.next()
...@@ -2132,7 +2154,7 @@ def p_def_statement(s): ...@@ -2132,7 +2154,7 @@ def p_def_statement(s):
doc, body = p_suite(s, Ctx(level = 'function'), with_doc = 1) doc, body = p_suite(s, Ctx(level = 'function'), with_doc = 1)
return Nodes.DefNode(pos, name = name, args = args, return Nodes.DefNode(pos, name = name, args = args,
star_arg = star_arg, starstar_arg = starstar_arg, star_arg = star_arg, starstar_arg = starstar_arg,
doc = doc, body = body) doc = doc, body = body, decorators = decorators)
def p_py_arg_decl(s): def p_py_arg_decl(s):
pos = s.position() pos = s.position()
......
import unittest
from Cython.TestUtils import TransformTest
from Cython.Compiler.ParseTreeTransforms import DecoratorTransform
class TestDecorator(TransformTest):
def test_decorator(self):
t = self.run_pipeline([DecoratorTransform(None)], u"""
def decorator(fun):
return fun
@decorator
def decorated():
pass
""")
self.assertCode(u"""
def decorator(fun):
return fun
def decorated():
pass
decorated = decorator(decorated)
""", t)
if __name__ == '__main__':
unittest.main()
__doc__ = u"""
>>> f(1,2)
4
>>> f.HERE
1
>>> g(1,2)
5
>>> g.HERE
5
>>> h(1,2)
6
>>> h.HERE
1
"""
class wrap:
def __init__(self, func):
self.func = func
self.HERE = 1
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
def decorate(func):
try:
func.HERE += 1
except AttributeError:
func = wrap(func)
return func
def decorate2(a,b):
return decorate
@decorate
def f(a,b):
return a+b+1
@decorate
@decorate
@decorate
@decorate
@decorate
def g(a,b):
return a+b+2
@decorate2(1,2)
def h(a,b):
return a+b+3
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment