Commit 19cdcd6d authored by Robert Bradshaw's avatar Robert Bradshaw

Add @cython.returns(type) decorator.

parent 9ec01981
......@@ -1280,7 +1280,8 @@ class FuncDefNode(StatNode, BlockNode):
# needs_closure boolean Whether or not this function has inner functions/classes/yield
# needs_outer_scope boolean Whether or not this function requires outer scope
# pymethdef_required boolean Force Python method struct generation
# directive_locals { string : NameNode } locals defined by cython.locals(...)
# directive_locals { string : ExprNode } locals defined by cython.locals(...)
# directive_returns [ExprNode] type defined by cython.returns(...)
# star_arg PyArgDeclNode or None * argument
# starstar_arg PyArgDeclNode or None ** argument
......@@ -1872,6 +1873,7 @@ class CFuncDefNode(FuncDefNode):
inline_in_pxd = False
decorators = None
directive_locals = None
directive_returns = None
override = None
def unqualified_name(self):
......@@ -1881,6 +1883,12 @@ class CFuncDefNode(FuncDefNode):
if self.directive_locals is None:
self.directive_locals = {}
self.directive_locals.update(env.directives['locals'])
if self.directive_returns is not None:
base_type = self.directive_returns.analyse_as_type(env)
if base_type is None:
error(self.directive_returns.pos, "Not a type")
base_type = PyrexTypes.error_type
else:
base_type = self.base_type.analyse(env)
# The 2 here is because we need both function and argument names.
if isinstance(self.declarator, CFuncDeclaratorNode):
......@@ -2664,7 +2672,7 @@ class DefNode(FuncDefNode):
self.num_required_kw_args = rk
self.num_required_args = r
def as_cfunction(self, cfunc=None, scope=None, overridable=True):
def as_cfunction(self, cfunc=None, scope=None, overridable=True, returns=None):
if self.star_arg:
error(self.star_arg.pos, "cdef function cannot have star argument")
if self.starstar_arg:
......@@ -2724,7 +2732,8 @@ class DefNode(FuncDefNode):
nogil = cfunc_type.nogil,
visibility = 'private',
api = False,
directive_locals = getattr(cfunc, 'directive_locals', {}))
directive_locals = getattr(cfunc, 'directive_locals', {}),
directive_returns = returns)
def is_cdef_func_compatible(self):
"""Determines if the function's signature is compatible with a
......
......@@ -136,6 +136,7 @@ directive_types = {
'cfunc' : None, # decorators do not take directive value
'ccall' : None,
'cclass' : None,
'returns' : type,
}
for key, val in directive_defaults.items():
......
......@@ -867,6 +867,11 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
raise PostParseError(pos,
'The %s directive takes one compile-time string argument' % optname)
return (optname, str(args[0].value))
elif directivetype is type:
if kwds is not None or len(args) != 1:
raise PostParseError(pos,
'The %s directive takes one type argument' % optname)
return (optname, args[0])
elif directivetype is dict:
if len(args) != 0:
raise PostParseError(pos,
......@@ -1835,7 +1840,7 @@ class AdjustDefByDirectives(CythonTransform, SkipDeclarations):
if self.in_py_class:
error(node.pos, "cfunc directive is not allowed here")
else:
node = node.as_cfunction(overridable=False)
node = node.as_cfunction(overridable=False, returns=self.directives.get('returns'))
return self.visit(node)
self.visitchildren(node)
return node
......
......@@ -28,6 +28,8 @@ class _EmptyDecoratorAndManager(object):
cclass = ccall = cfunc = _EmptyDecoratorAndManager()
returns = lambda type_arg: _EmptyDecoratorAndManager()
final = internal = _empty_decorator
def inline(f, *args, **kwds):
......
......@@ -110,3 +110,16 @@ def test_ccall_method(x):
1
"""
return x.meth()
@cython.cfunc
@cython.returns(p_int)
@cython.locals(xptr=p_int)
def typed_return(xptr):
return xptr
def test_typed_return():
"""
>>> test_typed_return()
"""
x = cython.declare(int, 5)
assert typed_return(cython.address(x)) == cython.address(x)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment