Commit 19cdcd6d authored by Robert Bradshaw's avatar Robert Bradshaw

Add @cython.returns(type) decorator.

parent 9ec01981
...@@ -1280,7 +1280,8 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1280,7 +1280,8 @@ class FuncDefNode(StatNode, BlockNode):
# needs_closure boolean Whether or not this function has inner functions/classes/yield # needs_closure boolean Whether or not this function has inner functions/classes/yield
# needs_outer_scope boolean Whether or not this function requires outer scope # needs_outer_scope boolean Whether or not this function requires outer scope
# pymethdef_required boolean Force Python method struct generation # pymethdef_required boolean Force Python method struct generation
# directive_locals { string : NameNode } locals defined by cython.locals(...) # directive_locals { string : ExprNode } locals defined by cython.locals(...)
# directive_returns [ExprNode] type defined by cython.returns(...)
# star_arg PyArgDeclNode or None * argument # star_arg PyArgDeclNode or None * argument
# starstar_arg PyArgDeclNode or None ** argument # starstar_arg PyArgDeclNode or None ** argument
...@@ -1872,6 +1873,7 @@ class CFuncDefNode(FuncDefNode): ...@@ -1872,6 +1873,7 @@ class CFuncDefNode(FuncDefNode):
inline_in_pxd = False inline_in_pxd = False
decorators = None decorators = None
directive_locals = None directive_locals = None
directive_returns = None
override = None override = None
def unqualified_name(self): def unqualified_name(self):
...@@ -1881,7 +1883,13 @@ class CFuncDefNode(FuncDefNode): ...@@ -1881,7 +1883,13 @@ class CFuncDefNode(FuncDefNode):
if self.directive_locals is None: if self.directive_locals is None:
self.directive_locals = {} self.directive_locals = {}
self.directive_locals.update(env.directives['locals']) self.directive_locals.update(env.directives['locals'])
base_type = self.base_type.analyse(env) if self.directive_returns is not None:
base_type = self.directive_returns.analyse_as_type(env)
if base_type is None:
error(self.directive_returns.pos, "Not a type")
base_type = PyrexTypes.error_type
else:
base_type = self.base_type.analyse(env)
# The 2 here is because we need both function and argument names. # The 2 here is because we need both function and argument names.
if isinstance(self.declarator, CFuncDeclaratorNode): if isinstance(self.declarator, CFuncDeclaratorNode):
name_declarator, type = self.declarator.analyse(base_type, env, name_declarator, type = self.declarator.analyse(base_type, env,
...@@ -2664,7 +2672,7 @@ class DefNode(FuncDefNode): ...@@ -2664,7 +2672,7 @@ class DefNode(FuncDefNode):
self.num_required_kw_args = rk self.num_required_kw_args = rk
self.num_required_args = r self.num_required_args = r
def as_cfunction(self, cfunc=None, scope=None, overridable=True): def as_cfunction(self, cfunc=None, scope=None, overridable=True, returns=None):
if self.star_arg: if self.star_arg:
error(self.star_arg.pos, "cdef function cannot have star argument") error(self.star_arg.pos, "cdef function cannot have star argument")
if self.starstar_arg: if self.starstar_arg:
...@@ -2724,7 +2732,8 @@ class DefNode(FuncDefNode): ...@@ -2724,7 +2732,8 @@ class DefNode(FuncDefNode):
nogil = cfunc_type.nogil, nogil = cfunc_type.nogil,
visibility = 'private', visibility = 'private',
api = False, api = False,
directive_locals = getattr(cfunc, 'directive_locals', {})) directive_locals = getattr(cfunc, 'directive_locals', {}),
directive_returns = returns)
def is_cdef_func_compatible(self): def is_cdef_func_compatible(self):
"""Determines if the function's signature is compatible with a """Determines if the function's signature is compatible with a
......
...@@ -136,6 +136,7 @@ directive_types = { ...@@ -136,6 +136,7 @@ directive_types = {
'cfunc' : None, # decorators do not take directive value 'cfunc' : None, # decorators do not take directive value
'ccall' : None, 'ccall' : None,
'cclass' : None, 'cclass' : None,
'returns' : type,
} }
for key, val in directive_defaults.items(): for key, val in directive_defaults.items():
......
...@@ -867,6 +867,11 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -867,6 +867,11 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
raise PostParseError(pos, raise PostParseError(pos,
'The %s directive takes one compile-time string argument' % optname) 'The %s directive takes one compile-time string argument' % optname)
return (optname, str(args[0].value)) return (optname, str(args[0].value))
elif directivetype is type:
if kwds is not None or len(args) != 1:
raise PostParseError(pos,
'The %s directive takes one type argument' % optname)
return (optname, args[0])
elif directivetype is dict: elif directivetype is dict:
if len(args) != 0: if len(args) != 0:
raise PostParseError(pos, raise PostParseError(pos,
...@@ -1835,7 +1840,7 @@ class AdjustDefByDirectives(CythonTransform, SkipDeclarations): ...@@ -1835,7 +1840,7 @@ class AdjustDefByDirectives(CythonTransform, SkipDeclarations):
if self.in_py_class: if self.in_py_class:
error(node.pos, "cfunc directive is not allowed here") error(node.pos, "cfunc directive is not allowed here")
else: else:
node = node.as_cfunction(overridable=False) node = node.as_cfunction(overridable=False, returns=self.directives.get('returns'))
return self.visit(node) return self.visit(node)
self.visitchildren(node) self.visitchildren(node)
return node return node
......
...@@ -28,6 +28,8 @@ class _EmptyDecoratorAndManager(object): ...@@ -28,6 +28,8 @@ class _EmptyDecoratorAndManager(object):
cclass = ccall = cfunc = _EmptyDecoratorAndManager() cclass = ccall = cfunc = _EmptyDecoratorAndManager()
returns = lambda type_arg: _EmptyDecoratorAndManager()
final = internal = _empty_decorator final = internal = _empty_decorator
def inline(f, *args, **kwds): def inline(f, *args, **kwds):
......
...@@ -110,3 +110,16 @@ def test_ccall_method(x): ...@@ -110,3 +110,16 @@ def test_ccall_method(x):
1 1
""" """
return x.meth() return x.meth()
@cython.cfunc
@cython.returns(p_int)
@cython.locals(xptr=p_int)
def typed_return(xptr):
return xptr
def test_typed_return():
"""
>>> test_typed_return()
"""
x = cython.declare(int, 5)
assert typed_return(cython.address(x)) == cython.address(x)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment