Commit 36f3f6d0 authored by Robert Bradshaw's avatar Robert Bradshaw

Add optional args to any cdef overridden function

parent 9ada8c82
......@@ -1567,13 +1567,10 @@ class SimpleCallNode(ExprNode):
for formal_arg, actual_arg in args[:expected_nargs]:
arg_code = actual_arg.result_as(formal_arg.type)
arg_list_code.append(arg_code)
if func_type.optional_arg_count:
if expected_nargs == actual_nargs:
if func_type.old_signature:
struct_type = func_type.old_signature.op_args
else:
struct_type = func_type.op_args
optional_args = struct_type.cast_code('NULL')
optional_args = 'NULL'
else:
optional_arg_code = [str(actual_nargs - expected_nargs)]
for formal_arg, actual_arg in args[expected_nargs:actual_nargs]:
......@@ -1582,11 +1579,10 @@ class SimpleCallNode(ExprNode):
# for formal_arg in formal_args[actual_nargs:max_nargs]:
# optional_arg_code.append(formal_arg.type.cast_code('0'))
optional_arg_struct = '{%s}' % ','.join(optional_arg_code)
optional_args = '&' + func_type.op_args.base_type.cast_code(optional_arg_struct)
if func_type.old_signature and \
func_type.old_signature.op_args != func_type.op_args:
optional_args = func_type.old_signature.op_args.cast_code(optional_args)
optional_args = PyrexTypes.c_void_ptr_type.cast_code(
'&' + func_type.op_arg_struct.base_type.cast_code(optional_arg_struct))
arg_list_code.append(optional_args)
for actual_arg in self.args[len(formal_args):]:
arg_list_code.append(actual_arg.result_code)
result = "%s(%s)" % (self.function.result_code,
......
......@@ -62,6 +62,7 @@ skip_dispatch_cname = pyrex_prefix + "skip_dispatch"
empty_tuple = pyrex_prefix + "empty_tuple"
cleanup_cname = pyrex_prefix + "module_cleanup"
optional_args_cname = pyrex_prefix + "optional_args"
no_opt_args = pyrex_prefix + "no_opt_args"
extern_c_macro = pyrex_prefix.upper() + "EXTERN_C"
......
......@@ -384,7 +384,7 @@ class CFuncDeclaratorNode(CDeclaratorNode):
calling_convention = self.base.calling_convention,
nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable)
if self.optional_arg_count:
func_type.op_args = PyrexTypes.c_ptr_type(self.op_args_struct.type)
func_type.op_arg_struct = PyrexTypes.c_ptr_type(self.op_args_struct.type)
return self.base.analyse(func_type, env)
......@@ -763,6 +763,7 @@ class FuncDefNode(StatNode, BlockNode):
# ----- Python version
if self.py_func:
self.py_func.generate_function_definitions(env, code)
self.generate_optarg_wrapper_function(env, code)
def put_stararg_decrefs(self, code):
pass
......@@ -782,6 +783,9 @@ class FuncDefNode(StatNode, BlockNode):
for entry in env.arg_entries:
code.put_var_incref(entry)
def generate_optarg_wrapper_function(self, env, code):
pass
def generate_execution_code(self, code):
# Evaluate and store argument default values
for arg in self.args:
......@@ -845,11 +849,7 @@ class CFuncDefNode(FuncDefNode):
if self.overridable:
import ExprNodes
arg_names = [arg.name for arg in self.type.args]
self_arg = ExprNodes.NameNode(self.pos, name=arg_names[0])
cfunc = ExprNodes.AttributeNode(self.pos, obj=self_arg, attribute=self.declarator.base.name)
c_call = ExprNodes.SimpleCallNode(self.pos, function=cfunc, args=[ExprNodes.NameNode(self.pos, name=n) for n in arg_names[1:]], wrapper_call=True)
py_func_body = ReturnStatNode(pos=self.pos, return_type=PyrexTypes.py_object_type, value=c_call)
py_func_body = self.call_self_node()
self.py_func = DefNode(pos = self.pos,
name = self.declarator.base.name,
args = self.declarator.args,
......@@ -865,6 +865,16 @@ class CFuncDefNode(FuncDefNode):
self.override = OverrideCheckNode(self.pos, py_func = self.py_func)
self.body = StatListNode(self.pos, stats=[self.override, self.body])
def call_self_node(self, omit_optional_args=0):
import ExprNodes
args = self.type.args
if omit_optional_args:
args = args[:len(args) - self.type.optional_arg_count]
arg_names = [arg.name for arg in args]
self_arg = ExprNodes.NameNode(self.pos, name=arg_names[0])
cfunc = ExprNodes.AttributeNode(self.pos, obj=self_arg, attribute=self.declarator.base.name)
c_call = ExprNodes.SimpleCallNode(self.pos, function=cfunc, args=[ExprNodes.NameNode(self.pos, name=n) for n in arg_names[1:]], wrapper_call=True)
return ReturnStatNode(pos=self.pos, return_type=PyrexTypes.py_object_type, value=c_call)
def declare_arguments(self, env):
for arg in self.type.args:
......@@ -886,20 +896,22 @@ class CFuncDefNode(FuncDefNode):
if self.overridable:
self.py_func.analyse_expressions(env)
def generate_function_header(self, code, with_pymethdef):
def generate_function_header(self, code, with_pymethdef, with_opt_args = 1):
arg_decls = []
type = self.type
visibility = self.entry.visibility
for arg in type.args[:len(type.args)-type.optional_arg_count]:
arg_decls.append(arg.declaration_code())
if type.optional_arg_count:
arg_decls.append(type.op_args.declaration_code(Naming.optional_args_cname))
if type.optional_arg_count and with_opt_args:
arg_decls.append(type.op_arg_struct.declaration_code(Naming.optional_args_cname))
if type.has_varargs:
arg_decls.append("...")
if not arg_decls:
arg_decls = ["void"]
entity = type.function_header_code(self.entry.func_cname,
string.join(arg_decls, ", "))
cname = self.entry.func_cname
if not with_opt_args:
cname += Naming.no_opt_args
entity = type.function_header_code(cname, string.join(arg_decls, ", "))
if visibility == 'public':
dll_linkage = "DL_EXPORT"
else:
......@@ -973,6 +985,19 @@ class CFuncDefNode(FuncDefNode):
def caller_will_check_exceptions(self):
return self.entry.type.exception_check
def generate_optarg_wrapper_function(self, env, code):
if self.type.optional_arg_count and \
self.type.original_sig and not self.type.original_sig.optional_arg_count:
code.putln()
self.generate_function_header(code, 0, with_opt_args = 0)
if not self.return_type.is_void:
code.put('return ')
args = self.type.args
arglist = [arg.cname for arg in args[:len(args)-self.type.optional_arg_count]]
arglist.append('NULL')
code.putln('%s(%s);' % (self.entry.func_cname, ', '.join(arglist)))
code.putln('}')
class PyArgDeclNode(Node):
# Argument which must be a Python object (used
......
......@@ -584,7 +584,7 @@ class CFuncType(CType):
# with_gil boolean Acquire gil around function body
is_cfunction = 1
old_signature = None
original_sig = None
def __init__(self, return_type, args, has_varargs = 0,
exception_value = None, exception_check = 0, calling_convention = "",
......@@ -680,15 +680,9 @@ class CFuncType(CType):
return 0
if not self.same_calling_convention_as(other_type):
return 0
self.old_signature = other_type
self.original_sig = other_type.original_sig or other_type
if as_cmethod:
self.args[0] = other_type.args[0]
if self.optional_arg_count and \
self.optional_arg_count == other_type.optional_arg_count:
self.op_args = other_type.op_args
print self.op_args, other_type.op_args, self.optional_arg_count, other_type.optional_arg_count
elif self.optional_arg_count:
print self.op_args, other_type.op_args, self.optional_arg_count, other_type.optional_arg_count
return 1
......@@ -741,8 +735,7 @@ class CFuncType(CType):
arg_decl_list.append(
arg.type.declaration_code("", for_display, pyrex = pyrex))
if self.optional_arg_count:
arg_decl_list.append(self.op_args.declaration_code(Naming.optional_args_cname))
# arg_decl_list.append(c_void_ptr_type.declaration_code(Naming.optional_args_cname))
arg_decl_list.append(self.op_arg_struct.declaration_code(Naming.optional_args_cname))
if self.has_varargs:
arg_decl_list.append("...")
arg_decl_code = string.join(arg_decl_list, ", ")
......
......@@ -1273,14 +1273,26 @@ class CClassScope(ClassScope):
if defining and entry.func_cname:
error(pos, "'%s' already defined" % name)
#print "CClassScope.declare_cfunction: checking signature" ###
if type.compatible_signature_with(entry.type, as_cmethod = 1):
if type.same_c_signature_as(entry.type, as_cmethod = 1):
pass
elif type.compatible_signature_with(entry.type, as_cmethod = 1):
if type.optional_arg_count and not type.original_sig.optional_arg_count:
# Need to put a wrapper taking no optional arguments
# into the method table.
wrapper_func_cname = self.mangle(Naming.func_prefix, name) + Naming.no_opt_args
wrapper_func_name = name + Naming.no_opt_args
if entry.type.optional_arg_count:
old_entry = self.lookup_here(wrapper_func_name)
old_entry.func_cname = wrapper_func_cname
else:
entry.func_cname = wrapper_func_cname
entry.name = wrapper_func_name
entry = self.add_cfunction(name, type, pos, cname or name, visibility)
defining = 1
entry.type = type
elif type.same_c_signature_as(entry.type, as_cmethod = 1):
print "not compatible", name
# if type.narrower_c_signature_than(entry.type, as_cmethod = 1):
# entry.type = type
else:
print "here"
error(pos, "Signature not compatible with previous declaration")
else:
if self.defined:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment