Commit 23b8ea6d authored by Mark Florisson's avatar Mark Florisson

Support decorators for fused functions

parent 5a0effd0
...@@ -40,6 +40,7 @@ class FusedCFuncDefNode(StatListNode): ...@@ -40,6 +40,7 @@ class FusedCFuncDefNode(StatListNode):
resulting_fused_function = None resulting_fused_function = None
fused_func_assignment = None fused_func_assignment = None
defaults_tuple = None defaults_tuple = None
decorators = None
def __init__(self, node, env): def __init__(self, node, env):
super(FusedCFuncDefNode, self).__init__(node.pos) super(FusedCFuncDefNode, self).__init__(node.pos)
...@@ -49,6 +50,7 @@ class FusedCFuncDefNode(StatListNode): ...@@ -49,6 +50,7 @@ class FusedCFuncDefNode(StatListNode):
is_def = isinstance(self.node, DefNode) is_def = isinstance(self.node, DefNode)
if is_def: if is_def:
# self.node.decorators = []
self.copy_def(env) self.copy_def(env)
else: else:
self.copy_cdef(env) self.copy_cdef(env)
...@@ -91,6 +93,8 @@ class FusedCFuncDefNode(StatListNode): ...@@ -91,6 +93,8 @@ class FusedCFuncDefNode(StatListNode):
fused_to_specific) fused_to_specific)
copied_node.analyse_declarations(env) copied_node.analyse_declarations(env)
# copied_node.is_staticmethod = self.node.is_staticmethod
# copied_node.is_classmethod = self.node.is_classmethod
self.create_new_local_scope(copied_node, env, fused_to_specific) self.create_new_local_scope(copied_node, env, fused_to_specific)
self.specialize_copied_def(copied_node, cname, self.node.entry, self.specialize_copied_def(copied_node, cname, self.node.entry,
fused_to_specific, fused_compound_types) fused_to_specific, fused_compound_types)
......
...@@ -2524,7 +2524,8 @@ class DefNode(FuncDefNode): ...@@ -2524,7 +2524,8 @@ class DefNode(FuncDefNode):
sig.is_staticmethod = True sig.is_staticmethod = True
sig.has_generic_args = True sig.has_generic_args = True
if self.is_classmethod and self.has_fused_arguments and env.is_c_class_scope: if ((self.is_classmethod or self.is_staticmethod) and
self.has_fused_arguments and env.is_c_class_scope):
del self.decorator_indirection.stats[:] del self.decorator_indirection.stats[:]
for i in range(min(nfixed, len(self.args))): for i in range(min(nfixed, len(self.args))):
......
...@@ -1242,12 +1242,12 @@ class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations): ...@@ -1242,12 +1242,12 @@ class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations):
func_node = self.visit_FuncDefNode(func_node) func_node = self.visit_FuncDefNode(func_node)
if scope_type != 'cclass' or not func_node.decorators: if scope_type != 'cclass' or not func_node.decorators:
return func_node return func_node
return self._handle_decorators( return self.handle_decorators(func_node, func_node.decorators,
func_node, func_node.name) func_node.name)
def _handle_decorators(self, node, name): def handle_decorators(self, node, decorators, name):
decorator_result = ExprNodes.NameNode(node.pos, name = name) decorator_result = ExprNodes.NameNode(node.pos, name = name)
for decorator in node.decorators[::-1]: for decorator in decorators[::-1]:
decorator_result = ExprNodes.SimpleCallNode( decorator_result = ExprNodes.SimpleCallNode(
decorator.pos, decorator.pos,
function = decorator.decorator, function = decorator.decorator,
...@@ -1441,37 +1441,55 @@ if VALUE is not None: ...@@ -1441,37 +1441,55 @@ if VALUE is not None:
node.body.stats += stats node.body.stats += stats
return node return node
def visit_FuncDefNode(self, node): def _handle_fused_def_decorators(self, old_decorators, env, node):
""" """
Analyse a function and its body, as that hasn't happend yet. Also Create function calls to the decorators and reassignments to
analyse the directive_locals set by @cython.locals(). Then, if we are the function.
a function with fused arguments, replace the function (after it has
declared itself in the symbol table!) with a FusedCFuncDefNode, and
analyse its children (which are in turn normal functions). If we're a
normal function, just analyse the body of the function.
""" """
env = self.env_stack[-1] # Delete staticmethod and classmethod decorators, this is
# handled directly by the fused function object.
decorators = []
for decorator in old_decorators:
func = decorator.decorator
if (not func.is_name or
func.name not in ('staticmethod', 'classmethod') or
env.lookup_here(func.name)):
# not a static or classmethod
decorators.append(decorator)
if decorators:
transform = DecoratorTransform(self.context)
def_node = node.node
_, reassignments = transform.handle_decorators(
def_node, decorators, def_node.name)
reassignments.analyse_declarations(env)
node = [node, reassignments]
return node
def _handle_def(self, decorators, env, node):
"Handle def or cpdef fused functions"
# Create PyCFunction nodes for each specialization
node.stats.insert(0, node.py_func)
node.py_func = self.visit(node.py_func)
node.update_fused_defnode_entry(env)
pycfunc = ExprNodes.PyCFunctionNode.from_defnode(node.py_func,
True)
pycfunc = ExprNodes.ProxyNode(pycfunc.coerce_to_temp(env))
node.resulting_fused_function = pycfunc
# Create assignment node for our def function
node.fused_func_assignment = self._create_assignment(
node.py_func, ExprNodes.CloneNode(pycfunc), env)
self.seen_vars_stack.append(set()) if decorators:
lenv = node.local_scope node = self._handle_fused_def_decorators(decorators, env, node)
node.declare_arguments(lenv)
for var, type_node in node.directive_locals.items(): return node
if not lenv.lookup_here(var): # don't redeclare args
type = type_node.analyse_as_type(lenv)
if type:
lenv.declare_var(var, type, type_node.pos)
else:
error(type_node.pos, "Not a type")
if node.is_generator and node.has_fused_arguments: def _create_fused_function(self, env, node):
node.has_fused_arguments = False "Create a fused function for a DefNode with fused arguments"
error(node.pos, "Fused generators not supported") from Cython.Compiler import FusedNode
node.gbody = Nodes.StatListNode(node.pos,
stats=[],
body=Nodes.PassStatNode(node.pos))
if node.has_fused_arguments:
if self.fused_function or self.in_lambda: if self.fused_function or self.in_lambda:
if self.fused_function not in self.fused_error_funcs: if self.fused_function not in self.fused_error_funcs:
if self.in_lambda: if self.in_lambda:
...@@ -1488,29 +1506,18 @@ if VALUE is not None: ...@@ -1488,29 +1506,18 @@ if VALUE is not None:
return node return node
from Cython.Compiler import FusedNode decorators = getattr(node, 'decorators', None)
node = FusedNode.FusedCFuncDefNode(node, env) node = FusedNode.FusedCFuncDefNode(node, env)
self.fused_function = node self.fused_function = node
self.visitchildren(node) self.visitchildren(node)
self.fused_function = None self.fused_function = None
if node.py_func: if node.py_func:
# Create PyCFunction nodes for each specialization node = self._handle_def(decorators, env, node)
node.stats.insert(0, node.py_func)
node.py_func = self.visit(node.py_func)
node.update_fused_defnode_entry(env)
pycfunc = ExprNodes.PyCFunctionNode.from_defnode(node.py_func,
True)
pycfunc = ExprNodes.ProxyNode(pycfunc.coerce_to_temp(env))
node.resulting_fused_function = pycfunc
# Create assignment node for our def function return node
node.fused_func_assignment = self._create_assignment(
node.py_func, ExprNodes.CloneNode(pycfunc), env)
else:
node.body.analyse_declarations(lenv)
def _handle_nogil_cleanup(self, lenv, node):
"Handle cleanup for 'with gil' blocks in nogil functions."
if lenv.nogil and lenv.has_with_gil_block: if lenv.nogil and lenv.has_with_gil_block:
# Acquire the GIL for cleanup in 'nogil' functions, by wrapping # Acquire the GIL for cleanup in 'nogil' functions, by wrapping
# the entire function body in try/finally. # the entire function body in try/finally.
...@@ -1518,9 +1525,47 @@ if VALUE is not None: ...@@ -1518,9 +1525,47 @@ if VALUE is not None:
# Nodes.FuncDefNode.generate_function_definitions() # Nodes.FuncDefNode.generate_function_definitions()
node.body = Nodes.NogilTryFinallyStatNode( node.body = Nodes.NogilTryFinallyStatNode(
node.body.pos, node.body.pos,
body = node.body, body=node.body,
finally_clause = Nodes.EnsureGILNode(node.body.pos), finally_clause=Nodes.EnsureGILNode(node.body.pos))
)
def _handle_fused(self, node):
if node.is_generator and node.has_fused_arguments:
node.has_fused_arguments = False
error(node.pos, "Fused generators not supported")
node.gbody = Nodes.StatListNode(node.pos,
stats=[],
body=Nodes.PassStatNode(node.pos))
return node.has_fused_arguments
def visit_FuncDefNode(self, node):
"""
Analyse a function and its body, as that hasn't happend yet. Also
analyse the directive_locals set by @cython.locals(). Then, if we are
a function with fused arguments, replace the function (after it has
declared itself in the symbol table!) with a FusedCFuncDefNode, and
analyse its children (which are in turn normal functions). If we're a
normal function, just analyse the body of the function.
"""
env = self.env_stack[-1]
self.seen_vars_stack.append(set())
lenv = node.local_scope
node.declare_arguments(lenv)
for var, type_node in node.directive_locals.items():
if not lenv.lookup_here(var): # don't redeclare args
type = type_node.analyse_as_type(lenv)
if type:
lenv.declare_var(var, type, type_node.pos)
else:
error(type_node.pos, "Not a type")
if self._handle_fused(node):
node = self._create_fused_function(env, node)
else:
node.body.analyse_declarations(lenv)
self._handle_nogil_cleanup(lenv, node)
self.env_stack.append(lenv) self.env_stack.append(lenv)
self.visitchildren(node) self.visitchildren(node)
......
...@@ -304,3 +304,20 @@ def test_code_object(cython.floating dummy = 2.0): ...@@ -304,3 +304,20 @@ def test_code_object(cython.floating dummy = 2.0):
>>> getcode(test_code_object) is getcode(test_code_object[float]) >>> getcode(test_code_object) is getcode(test_code_object[float])
True True
""" """
def create_dec(value):
def dec(f):
if not hasattr(f, 'order'):
f.order = []
f.order.append(value)
return f
return dec
@create_dec(1)
@create_dec(2)
@create_dec(3)
def test_decorators(cython.floating arg):
"""
>>> test_decorators.order
[3, 2, 1]
"""
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment