Commit a8305590 authored by Mark Florisson's avatar Mark Florisson

Support optional fused-typed arguments in cdef functions

parent 430ec7cf
...@@ -572,22 +572,6 @@ class CFuncDeclaratorNode(CDeclaratorNode): ...@@ -572,22 +572,6 @@ class CFuncDeclaratorNode(CDeclaratorNode):
elif self.optional_arg_count: elif self.optional_arg_count:
error(self.pos, "Non-default argument follows default argument") error(self.pos, "Non-default argument follows default argument")
if self.optional_arg_count:
scope = StructOrUnionScope()
arg_count_member = '%sn' % Naming.pyrex_prefix
scope.declare_var(arg_count_member, PyrexTypes.c_int_type, self.pos)
for arg in func_type_args[len(func_type_args)-self.optional_arg_count:]:
scope.declare_var(arg.name, arg.type, arg.pos, allow_pyobject = 1)
struct_cname = env.mangle(Naming.opt_arg_prefix, self.base.name)
self.op_args_struct = env.global_scope().declare_struct_or_union(name = struct_cname,
kind = 'struct',
scope = scope,
typedef_flag = 0,
pos = self.pos,
cname = struct_cname)
self.op_args_struct.defined_in_pxd = 1
self.op_args_struct.used = 1
exc_val = None exc_val = None
exc_check = 0 exc_check = 0
if self.exception_check == '+': if self.exception_check == '+':
...@@ -629,8 +613,19 @@ class CFuncDeclaratorNode(CDeclaratorNode): ...@@ -629,8 +613,19 @@ class CFuncDeclaratorNode(CDeclaratorNode):
exception_value = exc_val, exception_check = exc_check, exception_value = exc_val, exception_check = exc_check,
calling_convention = self.base.calling_convention, calling_convention = self.base.calling_convention,
nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable) nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable)
if self.optional_arg_count: if self.optional_arg_count:
func_type.op_arg_struct = PyrexTypes.c_ptr_type(self.op_args_struct.type) if func_type.is_fused:
# This is a bit of a hack... When we need to create specialized CFuncTypes
# on the fly because the cdef is defined in a pxd, we need to declare the specialized optional arg
# struct
def declare_opt_arg_struct(func_type, fused_cname):
self.declare_optional_arg_struct(func_type, env, fused_cname)
func_type.declare_opt_arg_struct = declare_opt_arg_struct
else:
self.declare_optional_arg_struct(func_type, env)
callspec = env.directives['callspec'] callspec = env.directives['callspec']
if callspec: if callspec:
current = func_type.calling_convention current = func_type.calling_convention
...@@ -640,6 +635,38 @@ class CFuncDeclaratorNode(CDeclaratorNode): ...@@ -640,6 +635,38 @@ class CFuncDeclaratorNode(CDeclaratorNode):
func_type.calling_convention = callspec func_type.calling_convention = callspec
return self.base.analyse(func_type, env) return self.base.analyse(func_type, env)
def declare_optional_arg_struct(self, func_type, env, fused_cname=None):
"""
Declares the optional argument struct (the struct used to hold the
values for optional arguments). For fused cdef functions, this is
deferred as analyse_declarations is called only once (on the fused
cdef function).
"""
scope = StructOrUnionScope()
arg_count_member = '%sn' % Naming.pyrex_prefix
scope.declare_var(arg_count_member, PyrexTypes.c_int_type, self.pos)
for arg in func_type.args[len(func_type.args)-self.optional_arg_count:]:
scope.declare_var(arg.name, arg.type, arg.pos, allow_pyobject = 1)
struct_cname = env.mangle(Naming.opt_arg_prefix, self.base.name)
if fused_cname is not None:
struct_cname = PyrexTypes.get_fused_cname(fused_cname, struct_cname)
op_args_struct = env.global_scope().declare_struct_or_union(
name = struct_cname,
kind = 'struct',
scope = scope,
typedef_flag = 0,
pos = self.pos,
cname = struct_cname)
op_args_struct.defined_in_pxd = 1
op_args_struct.used = 1
func_type.op_arg_struct = PyrexTypes.c_ptr_type(op_args_struct.type)
class CArgDeclNode(Node): class CArgDeclNode(Node):
# Item in a function declaration argument list. # Item in a function declaration argument list.
...@@ -2020,6 +2047,9 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2020,6 +2047,9 @@ class FusedCFuncDefNode(StatListNode):
if node.return_type.is_fused: if node.return_type.is_fused:
assert not n.return_type.is_fused assert not n.return_type.is_fused
if n.cfunc_declarator.optional_arg_count:
assert n.type.op_arg_struct
assert node.type.is_fused assert node.type.is_fused
node.entry.fused_cfunction = self node.entry.fused_cfunction = self
...@@ -2044,6 +2074,9 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2044,6 +2074,9 @@ class FusedCFuncDefNode(StatListNode):
copied_node.entry.type = newtype copied_node.entry.type = newtype
newtype.entry = copied_node.entry newtype.entry = copied_node.entry
self.node.cfunc_declarator.declare_optional_arg_struct(
newtype, env, fused_cname=cname)
copied_node.return_type = newtype.return_type copied_node.return_type = newtype.return_type
copied_node.create_local_scope(env) copied_node.create_local_scope(env)
copied_node.local_scope.fused_to_specific = fused_to_specific copied_node.local_scope.fused_to_specific = fused_to_specific
......
...@@ -1762,9 +1762,11 @@ class CFuncType(CType): ...@@ -1762,9 +1762,11 @@ class CFuncType(CType):
# nogil boolean Can be called without gil # nogil boolean Can be called without gil
# with_gil boolean Acquire gil around function body # with_gil boolean Acquire gil around function body
# templates [string] or None # templates [string] or None
# cached_specialized_types [CFuncType] cached specialized versions of the CFuncType if defined in a pxd
is_cfunction = 1 is_cfunction = 1
original_sig = None original_sig = None
cached_specialized_types = None
subtypes = ['return_type', 'args'] subtypes = ['return_type', 'args']
...@@ -1991,6 +1993,7 @@ class CFuncType(CType): ...@@ -1991,6 +1993,7 @@ class CFuncType(CType):
new_templates = None new_templates = None
else: else:
new_templates = [v.specialize(values) for v in self.templates] new_templates = [v.specialize(values) for v in self.templates]
return CFuncType(self.return_type.specialize(values), return CFuncType(self.return_type.specialize(values),
[arg.specialize(values) for arg in self.args], [arg.specialize(values) for arg in self.args],
has_varargs = 0, has_varargs = 0,
...@@ -2032,11 +2035,20 @@ class CFuncType(CType): ...@@ -2032,11 +2035,20 @@ class CFuncType(CType):
""" """
assert self.is_fused assert self.is_fused
if self.entry.fused_cfunction:
return [n.type for n in self.entry.fused_cfunction.nodes]
elif self.cached_specialized_types is not None:
return self.cached_specialized_types
result = [] result = []
permutations = self.get_all_specific_permutations() permutations = self.get_all_specific_permutations()
for cname, fused_to_specific in permutations: for cname, fused_to_specific in permutations:
new_func_type = self.entry.type.specialize(fused_to_specific) new_func_type = self.entry.type.specialize(fused_to_specific)
if self.optional_arg_count:
# Remember, this method is set by CFuncDeclaratorNode
self.declare_opt_arg_struct(new_func_type, cname)
new_entry = copy.deepcopy(self.entry) new_entry = copy.deepcopy(self.entry)
new_entry.cname = self.get_specific_cname(cname) new_entry.cname = self.get_specific_cname(cname)
...@@ -2045,6 +2057,8 @@ class CFuncType(CType): ...@@ -2045,6 +2057,8 @@ class CFuncType(CType):
result.append(new_func_type) result.append(new_func_type)
self.cached_specialized_types = result
return result return result
def get_specific_cname(self, fused_cname): def get_specific_cname(self, fused_cname):
...@@ -2053,11 +2067,16 @@ class CFuncType(CType): ...@@ -2053,11 +2067,16 @@ class CFuncType(CType):
for the corresponding function with specific types. for the corresponding function with specific types.
""" """
assert self.is_fused assert self.is_fused
return '%s%s%s' % (Naming.fused_func_prefix, return get_fused_cname(fused_cname, self.entry.func_cname)
fused_cname,
self.entry.func_cname)
def get_fused_cname(fused_cname, orig_cname):
"""
Given the fused cname id and an original cname, return a specialized cname
"""
return '%s%s%s' % (Naming.fused_func_prefix, fused_cname, orig_cname)
def map_with_specific_entries(entry, func, *args, **kwargs): def map_with_specific_entries(entry, func, *args, **kwargs):
""" """
Call func for every specific function instance. If this is not a Call func for every specific function instance. If this is not a
...@@ -2089,7 +2108,7 @@ def get_all_specific_permutations(fused_types, id="", f2s=()): ...@@ -2089,7 +2108,7 @@ def get_all_specific_permutations(fused_types, id="", f2s=()):
if id: if id:
cname = '%s_%s' % (id, newid) cname = '%s_%s' % (id, newid)
else: else:
cname = newid cname = str(newid)
if len(fused_types) > 1: if len(fused_types) > 1:
result.extend(get_all_specific_permutations( result.extend(get_all_specific_permutations(
......
...@@ -13,7 +13,7 @@ ctypedef cython.fused_type(int, long) integral ...@@ -13,7 +13,7 @@ ctypedef cython.fused_type(int, long) integral
ctypedef cython.fused_type(int, long, float, double, string_t) fused_type1 ctypedef cython.fused_type(int, long, float, double, string_t) fused_type1
ctypedef cython.fused_type(string_t) fused_type2 ctypedef cython.fused_type(string_t) fused_type2
ctypedef fused_type1 *composed_t ctypedef fused_type1 *composed_t
ctypedef cython.fused_type(int, long, float, double) other_t ctypedef cython.fused_type(int, double) other_t
ctypedef double *p_double ctypedef double *p_double
ctypedef int *p_int ctypedef int *p_int
...@@ -164,8 +164,8 @@ def test_specializations(): ...@@ -164,8 +164,8 @@ def test_specializations():
# print test_specialize[double](1.1, somedouble_p, otherdouble_p) # print test_specialize[double](1.1, somedouble_p, otherdouble_p)
# print # print
#cdef opt_args(integral x, floating y = 4.0): cdef opt_args(integral x, floating y = 4.0):
# print x, y print x, y
def test_opt_args(): def test_opt_args():
""" """
...@@ -176,8 +176,8 @@ def test_opt_args(): ...@@ -176,8 +176,8 @@ def test_opt_args():
3 4.0 3 4.0
3 4.0 3 4.0
""" """
#opt_args[int, float](3) opt_args[int, float](3)
#opt_args[int, double](3) opt_args[int, double](3)
#opt_args[int, float](3, 4.0) opt_args[int, float](3, 4.0)
#opt_args[int, double](3, 4.0) opt_args[int, double](3, 4.0)
...@@ -37,6 +37,7 @@ ctypedef cython.fused_type(str, unicode, bytes) builtin_t ...@@ -37,6 +37,7 @@ ctypedef cython.fused_type(str, unicode, bytes) builtin_t
cdef object_t add_simple(object_t obj, simple_t simple) cdef object_t add_simple(object_t obj, simple_t simple)
cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple) cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple)
cdef public_optional_args(object_t obj, simple_t simple = *)
######## header.h ######## ######## header.h ########
...@@ -54,6 +55,9 @@ cdef object_t add_simple(object_t obj, simple_t simple): ...@@ -54,6 +55,9 @@ cdef object_t add_simple(object_t obj, simple_t simple):
cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple): cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple):
return obj.a + simple return obj.a + simple
cdef public_optional_args(object_t obj, simple_t simple = 6):
return obj.a, simple
######## b.pyx ######## ######## b.pyx ########
from a cimport * from a cimport *
...@@ -82,3 +86,9 @@ assert f(mystruct, 5).a == 10 ...@@ -82,3 +86,9 @@ assert f(mystruct, 5).a == 10
f = add_simple[mystruct_t, int] f = add_simple[mystruct_t, int]
assert f(mystruct, 5).a == 10 assert f(mystruct, 5).a == 10
assert public_optional_args(mystruct, 5) == (5, 5)
assert public_optional_args[mystruct_t, int](mystruct) == (5, 6)
assert public_optional_args[mystruct_t, float](mystruct) == (5, 6.0)
assert public_optional_args[mystruct_t, float](mystruct, 7.0) == (5, 7.0)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment