Commit 36f3f6d0 authored by Robert Bradshaw's avatar Robert Bradshaw

Add optional args to any cdef overridden function

parent 9ada8c82
...@@ -1567,13 +1567,10 @@ class SimpleCallNode(ExprNode): ...@@ -1567,13 +1567,10 @@ class SimpleCallNode(ExprNode):
for formal_arg, actual_arg in args[:expected_nargs]: for formal_arg, actual_arg in args[:expected_nargs]:
arg_code = actual_arg.result_as(formal_arg.type) arg_code = actual_arg.result_as(formal_arg.type)
arg_list_code.append(arg_code) arg_list_code.append(arg_code)
if func_type.optional_arg_count: if func_type.optional_arg_count:
if expected_nargs == actual_nargs: if expected_nargs == actual_nargs:
if func_type.old_signature: optional_args = 'NULL'
struct_type = func_type.old_signature.op_args
else:
struct_type = func_type.op_args
optional_args = struct_type.cast_code('NULL')
else: else:
optional_arg_code = [str(actual_nargs - expected_nargs)] optional_arg_code = [str(actual_nargs - expected_nargs)]
for formal_arg, actual_arg in args[expected_nargs:actual_nargs]: for formal_arg, actual_arg in args[expected_nargs:actual_nargs]:
...@@ -1582,11 +1579,10 @@ class SimpleCallNode(ExprNode): ...@@ -1582,11 +1579,10 @@ class SimpleCallNode(ExprNode):
# for formal_arg in formal_args[actual_nargs:max_nargs]: # for formal_arg in formal_args[actual_nargs:max_nargs]:
# optional_arg_code.append(formal_arg.type.cast_code('0')) # optional_arg_code.append(formal_arg.type.cast_code('0'))
optional_arg_struct = '{%s}' % ','.join(optional_arg_code) optional_arg_struct = '{%s}' % ','.join(optional_arg_code)
optional_args = '&' + func_type.op_args.base_type.cast_code(optional_arg_struct) optional_args = PyrexTypes.c_void_ptr_type.cast_code(
if func_type.old_signature and \ '&' + func_type.op_arg_struct.base_type.cast_code(optional_arg_struct))
func_type.old_signature.op_args != func_type.op_args:
optional_args = func_type.old_signature.op_args.cast_code(optional_args)
arg_list_code.append(optional_args) arg_list_code.append(optional_args)
for actual_arg in self.args[len(formal_args):]: for actual_arg in self.args[len(formal_args):]:
arg_list_code.append(actual_arg.result_code) arg_list_code.append(actual_arg.result_code)
result = "%s(%s)" % (self.function.result_code, result = "%s(%s)" % (self.function.result_code,
......
...@@ -62,6 +62,7 @@ skip_dispatch_cname = pyrex_prefix + "skip_dispatch" ...@@ -62,6 +62,7 @@ skip_dispatch_cname = pyrex_prefix + "skip_dispatch"
empty_tuple = pyrex_prefix + "empty_tuple" empty_tuple = pyrex_prefix + "empty_tuple"
cleanup_cname = pyrex_prefix + "module_cleanup" cleanup_cname = pyrex_prefix + "module_cleanup"
optional_args_cname = pyrex_prefix + "optional_args" optional_args_cname = pyrex_prefix + "optional_args"
no_opt_args = pyrex_prefix + "no_opt_args"
extern_c_macro = pyrex_prefix.upper() + "EXTERN_C" extern_c_macro = pyrex_prefix.upper() + "EXTERN_C"
......
...@@ -384,7 +384,7 @@ class CFuncDeclaratorNode(CDeclaratorNode): ...@@ -384,7 +384,7 @@ class CFuncDeclaratorNode(CDeclaratorNode):
calling_convention = self.base.calling_convention, calling_convention = self.base.calling_convention,
nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable) nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable)
if self.optional_arg_count: if self.optional_arg_count:
func_type.op_args = PyrexTypes.c_ptr_type(self.op_args_struct.type) func_type.op_arg_struct = PyrexTypes.c_ptr_type(self.op_args_struct.type)
return self.base.analyse(func_type, env) return self.base.analyse(func_type, env)
...@@ -763,6 +763,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -763,6 +763,7 @@ class FuncDefNode(StatNode, BlockNode):
# ----- Python version # ----- Python version
if self.py_func: if self.py_func:
self.py_func.generate_function_definitions(env, code) self.py_func.generate_function_definitions(env, code)
self.generate_optarg_wrapper_function(env, code)
def put_stararg_decrefs(self, code): def put_stararg_decrefs(self, code):
pass pass
...@@ -782,6 +783,9 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -782,6 +783,9 @@ class FuncDefNode(StatNode, BlockNode):
for entry in env.arg_entries: for entry in env.arg_entries:
code.put_var_incref(entry) code.put_var_incref(entry)
def generate_optarg_wrapper_function(self, env, code):
pass
def generate_execution_code(self, code): def generate_execution_code(self, code):
# Evaluate and store argument default values # Evaluate and store argument default values
for arg in self.args: for arg in self.args:
...@@ -845,11 +849,7 @@ class CFuncDefNode(FuncDefNode): ...@@ -845,11 +849,7 @@ class CFuncDefNode(FuncDefNode):
if self.overridable: if self.overridable:
import ExprNodes import ExprNodes
arg_names = [arg.name for arg in self.type.args] py_func_body = self.call_self_node()
self_arg = ExprNodes.NameNode(self.pos, name=arg_names[0])
cfunc = ExprNodes.AttributeNode(self.pos, obj=self_arg, attribute=self.declarator.base.name)
c_call = ExprNodes.SimpleCallNode(self.pos, function=cfunc, args=[ExprNodes.NameNode(self.pos, name=n) for n in arg_names[1:]], wrapper_call=True)
py_func_body = ReturnStatNode(pos=self.pos, return_type=PyrexTypes.py_object_type, value=c_call)
self.py_func = DefNode(pos = self.pos, self.py_func = DefNode(pos = self.pos,
name = self.declarator.base.name, name = self.declarator.base.name,
args = self.declarator.args, args = self.declarator.args,
...@@ -865,6 +865,16 @@ class CFuncDefNode(FuncDefNode): ...@@ -865,6 +865,16 @@ class CFuncDefNode(FuncDefNode):
self.override = OverrideCheckNode(self.pos, py_func = self.py_func) self.override = OverrideCheckNode(self.pos, py_func = self.py_func)
self.body = StatListNode(self.pos, stats=[self.override, self.body]) self.body = StatListNode(self.pos, stats=[self.override, self.body])
def call_self_node(self, omit_optional_args=0):
import ExprNodes
args = self.type.args
if omit_optional_args:
args = args[:len(args) - self.type.optional_arg_count]
arg_names = [arg.name for arg in args]
self_arg = ExprNodes.NameNode(self.pos, name=arg_names[0])
cfunc = ExprNodes.AttributeNode(self.pos, obj=self_arg, attribute=self.declarator.base.name)
c_call = ExprNodes.SimpleCallNode(self.pos, function=cfunc, args=[ExprNodes.NameNode(self.pos, name=n) for n in arg_names[1:]], wrapper_call=True)
return ReturnStatNode(pos=self.pos, return_type=PyrexTypes.py_object_type, value=c_call)
def declare_arguments(self, env): def declare_arguments(self, env):
for arg in self.type.args: for arg in self.type.args:
...@@ -886,20 +896,22 @@ class CFuncDefNode(FuncDefNode): ...@@ -886,20 +896,22 @@ class CFuncDefNode(FuncDefNode):
if self.overridable: if self.overridable:
self.py_func.analyse_expressions(env) self.py_func.analyse_expressions(env)
def generate_function_header(self, code, with_pymethdef): def generate_function_header(self, code, with_pymethdef, with_opt_args = 1):
arg_decls = [] arg_decls = []
type = self.type type = self.type
visibility = self.entry.visibility visibility = self.entry.visibility
for arg in type.args[:len(type.args)-type.optional_arg_count]: for arg in type.args[:len(type.args)-type.optional_arg_count]:
arg_decls.append(arg.declaration_code()) arg_decls.append(arg.declaration_code())
if type.optional_arg_count: if type.optional_arg_count and with_opt_args:
arg_decls.append(type.op_args.declaration_code(Naming.optional_args_cname)) arg_decls.append(type.op_arg_struct.declaration_code(Naming.optional_args_cname))
if type.has_varargs: if type.has_varargs:
arg_decls.append("...") arg_decls.append("...")
if not arg_decls: if not arg_decls:
arg_decls = ["void"] arg_decls = ["void"]
entity = type.function_header_code(self.entry.func_cname, cname = self.entry.func_cname
string.join(arg_decls, ", ")) if not with_opt_args:
cname += Naming.no_opt_args
entity = type.function_header_code(cname, string.join(arg_decls, ", "))
if visibility == 'public': if visibility == 'public':
dll_linkage = "DL_EXPORT" dll_linkage = "DL_EXPORT"
else: else:
...@@ -973,6 +985,19 @@ class CFuncDefNode(FuncDefNode): ...@@ -973,6 +985,19 @@ class CFuncDefNode(FuncDefNode):
def caller_will_check_exceptions(self): def caller_will_check_exceptions(self):
return self.entry.type.exception_check return self.entry.type.exception_check
def generate_optarg_wrapper_function(self, env, code):
if self.type.optional_arg_count and \
self.type.original_sig and not self.type.original_sig.optional_arg_count:
code.putln()
self.generate_function_header(code, 0, with_opt_args = 0)
if not self.return_type.is_void:
code.put('return ')
args = self.type.args
arglist = [arg.cname for arg in args[:len(args)-self.type.optional_arg_count]]
arglist.append('NULL')
code.putln('%s(%s);' % (self.entry.func_cname, ', '.join(arglist)))
code.putln('}')
class PyArgDeclNode(Node): class PyArgDeclNode(Node):
# Argument which must be a Python object (used # Argument which must be a Python object (used
......
...@@ -584,7 +584,7 @@ class CFuncType(CType): ...@@ -584,7 +584,7 @@ class CFuncType(CType):
# with_gil boolean Acquire gil around function body # with_gil boolean Acquire gil around function body
is_cfunction = 1 is_cfunction = 1
old_signature = None original_sig = None
def __init__(self, return_type, args, has_varargs = 0, def __init__(self, return_type, args, has_varargs = 0,
exception_value = None, exception_check = 0, calling_convention = "", exception_value = None, exception_check = 0, calling_convention = "",
...@@ -680,15 +680,9 @@ class CFuncType(CType): ...@@ -680,15 +680,9 @@ class CFuncType(CType):
return 0 return 0
if not self.same_calling_convention_as(other_type): if not self.same_calling_convention_as(other_type):
return 0 return 0
self.old_signature = other_type self.original_sig = other_type.original_sig or other_type
if as_cmethod: if as_cmethod:
self.args[0] = other_type.args[0] self.args[0] = other_type.args[0]
if self.optional_arg_count and \
self.optional_arg_count == other_type.optional_arg_count:
self.op_args = other_type.op_args
print self.op_args, other_type.op_args, self.optional_arg_count, other_type.optional_arg_count
elif self.optional_arg_count:
print self.op_args, other_type.op_args, self.optional_arg_count, other_type.optional_arg_count
return 1 return 1
...@@ -741,8 +735,7 @@ class CFuncType(CType): ...@@ -741,8 +735,7 @@ class CFuncType(CType):
arg_decl_list.append( arg_decl_list.append(
arg.type.declaration_code("", for_display, pyrex = pyrex)) arg.type.declaration_code("", for_display, pyrex = pyrex))
if self.optional_arg_count: if self.optional_arg_count:
arg_decl_list.append(self.op_args.declaration_code(Naming.optional_args_cname)) arg_decl_list.append(self.op_arg_struct.declaration_code(Naming.optional_args_cname))
# arg_decl_list.append(c_void_ptr_type.declaration_code(Naming.optional_args_cname))
if self.has_varargs: if self.has_varargs:
arg_decl_list.append("...") arg_decl_list.append("...")
arg_decl_code = string.join(arg_decl_list, ", ") arg_decl_code = string.join(arg_decl_list, ", ")
......
...@@ -1273,14 +1273,26 @@ class CClassScope(ClassScope): ...@@ -1273,14 +1273,26 @@ class CClassScope(ClassScope):
if defining and entry.func_cname: if defining and entry.func_cname:
error(pos, "'%s' already defined" % name) error(pos, "'%s' already defined" % name)
#print "CClassScope.declare_cfunction: checking signature" ### #print "CClassScope.declare_cfunction: checking signature" ###
if type.compatible_signature_with(entry.type, as_cmethod = 1): if type.same_c_signature_as(entry.type, as_cmethod = 1):
pass
elif type.compatible_signature_with(entry.type, as_cmethod = 1):
if type.optional_arg_count and not type.original_sig.optional_arg_count:
# Need to put a wrapper taking no optional arguments
# into the method table.
wrapper_func_cname = self.mangle(Naming.func_prefix, name) + Naming.no_opt_args
wrapper_func_name = name + Naming.no_opt_args
if entry.type.optional_arg_count:
old_entry = self.lookup_here(wrapper_func_name)
old_entry.func_cname = wrapper_func_cname
else:
entry.func_cname = wrapper_func_cname
entry.name = wrapper_func_name
entry = self.add_cfunction(name, type, pos, cname or name, visibility)
defining = 1
entry.type = type entry.type = type
elif type.same_c_signature_as(entry.type, as_cmethod = 1):
print "not compatible", name
# if type.narrower_c_signature_than(entry.type, as_cmethod = 1): # if type.narrower_c_signature_than(entry.type, as_cmethod = 1):
# entry.type = type # entry.type = type
else: else:
print "here"
error(pos, "Signature not compatible with previous declaration") error(pos, "Signature not compatible with previous declaration")
else: else:
if self.defined: if self.defined:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment