Commit c0e500d4 authored by Robert Bradshaw's avatar Robert Bradshaw

Fix cpdef cross-module dispatching bug, greatly simplify (also simplify optional args)

parent 6ed229d1
...@@ -1887,6 +1887,9 @@ class SimpleCallNode(CallNode): ...@@ -1887,6 +1887,9 @@ class SimpleCallNode(CallNode):
arg_code = actual_arg.result_as(formal_arg.type) arg_code = actual_arg.result_as(formal_arg.type)
arg_list_code.append(arg_code) arg_list_code.append(arg_code)
if func_type.is_overridable:
arg_list_code.append(str(int(self.wrapper_call or self.function.entry.is_unbound_cmethod)))
if func_type.optional_arg_count: if func_type.optional_arg_count:
if expected_nargs == actual_nargs: if expected_nargs == actual_nargs:
optional_args = 'NULL' optional_args = 'NULL'
...@@ -1898,9 +1901,9 @@ class SimpleCallNode(CallNode): ...@@ -1898,9 +1901,9 @@ class SimpleCallNode(CallNode):
arg_list_code.append(actual_arg.result_code) arg_list_code.append(actual_arg.result_code)
result = "%s(%s)" % (self.function.result_code, result = "%s(%s)" % (self.function.result_code,
join(arg_list_code, ", ")) join(arg_list_code, ", "))
if self.wrapper_call or \ # if self.wrapper_call or \
self.function.entry.is_unbound_cmethod and self.function.entry.type.is_overridable: # self.function.entry.is_unbound_cmethod and self.function.entry.type.is_overridable:
result = "(%s = 1, %s)" % (Naming.skip_dispatch_cname, result) # result = "(%s = 1, %s)" % (Naming.skip_dispatch_cname, result)
return result return result
def generate_result_code(self, code): def generate_result_code(self, code):
......
...@@ -76,7 +76,6 @@ print_function_kwargs = pyrex_prefix + "print_kwargs" ...@@ -76,7 +76,6 @@ print_function_kwargs = pyrex_prefix + "print_kwargs"
cleanup_cname = pyrex_prefix + "module_cleanup" cleanup_cname = pyrex_prefix + "module_cleanup"
pymoduledef_cname = pyrex_prefix + "moduledef" pymoduledef_cname = pyrex_prefix + "moduledef"
optional_args_cname = pyrex_prefix + "optional_args" optional_args_cname = pyrex_prefix + "optional_args"
no_opt_args = pyrex_prefix + "no_opt_args"
import_star = pyrex_prefix + "import_star" import_star = pyrex_prefix + "import_star"
import_star_set = pyrex_prefix + "import_star_set" import_star_set = pyrex_prefix + "import_star_set"
cur_scope_cname = pyrex_prefix + "cur_scope" cur_scope_cname = pyrex_prefix + "cur_scope"
......
...@@ -1027,7 +1027,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1027,7 +1027,7 @@ class FuncDefNode(StatNode, BlockNode):
code.exit_cfunc_scope() code.exit_cfunc_scope()
if self.py_func: if self.py_func:
self.py_func.generate_function_definitions(env, code) self.py_func.generate_function_definitions(env, code)
self.generate_optarg_wrapper_function(env, code) self.generate_wrapper_functions(code)
def declare_argument(self, env, arg): def declare_argument(self, env, arg):
if arg.type.is_void: if arg.type.is_void:
...@@ -1036,7 +1036,8 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1036,7 +1036,8 @@ class FuncDefNode(StatNode, BlockNode):
error(arg.pos, error(arg.pos,
"Argument type '%s' is incomplete" % arg.type) "Argument type '%s' is incomplete" % arg.type)
return env.declare_arg(arg.name, arg.type, arg.pos) return env.declare_arg(arg.name, arg.type, arg.pos)
def generate_optarg_wrapper_function(self, env, code):
def generate_wrapper_functions(self, code):
pass pass
def generate_execution_code(self, code): def generate_execution_code(self, code):
...@@ -1093,6 +1094,7 @@ class CFuncDefNode(FuncDefNode): ...@@ -1093,6 +1094,7 @@ class CFuncDefNode(FuncDefNode):
# with_gil boolean Acquire GIL around body # with_gil boolean Acquire GIL around body
# type CFuncType # type CFuncType
# py_func wrapper for calling from Python # py_func wrapper for calling from Python
# overridable whether or not this is a cpdef function
child_attrs = ["base_type", "declarator", "body", "py_func"] child_attrs = ["base_type", "declarator", "body", "py_func"]
...@@ -1188,21 +1190,22 @@ class CFuncDefNode(FuncDefNode): ...@@ -1188,21 +1190,22 @@ class CFuncDefNode(FuncDefNode):
if self.overridable: if self.overridable:
self.py_func.analyse_expressions(env) self.py_func.analyse_expressions(env)
def generate_function_header(self, code, with_pymethdef, with_opt_args = 1): def generate_function_header(self, code, with_pymethdef, with_opt_args = 1, with_dispatch = 1, cname = None):
arg_decls = [] arg_decls = []
type = self.type type = self.type
visibility = self.entry.visibility visibility = self.entry.visibility
for arg in type.args[:len(type.args)-type.optional_arg_count]: for arg in type.args[:len(type.args)-type.optional_arg_count]:
arg_decls.append(arg.declaration_code()) arg_decls.append(arg.declaration_code())
if with_dispatch and self.overridable:
arg_decls.append(PyrexTypes.c_int_type.declaration_code(Naming.skip_dispatch_cname))
if type.optional_arg_count and with_opt_args: if type.optional_arg_count and with_opt_args:
arg_decls.append(type.op_arg_struct.declaration_code(Naming.optional_args_cname)) arg_decls.append(type.op_arg_struct.declaration_code(Naming.optional_args_cname))
if type.has_varargs: if type.has_varargs:
arg_decls.append("...") arg_decls.append("...")
if not arg_decls: if not arg_decls:
arg_decls = ["void"] arg_decls = ["void"]
if cname is None:
cname = self.entry.func_cname cname = self.entry.func_cname
if not with_opt_args:
cname += Naming.no_opt_args
entity = type.function_header_code(cname, string.join(arg_decls, ", ")) entity = type.function_header_code(cname, string.join(arg_decls, ", "))
if visibility == 'public': if visibility == 'public':
dll_linkage = "DL_EXPORT" dll_linkage = "DL_EXPORT"
...@@ -1280,15 +1283,33 @@ class CFuncDefNode(FuncDefNode): ...@@ -1280,15 +1283,33 @@ class CFuncDefNode(FuncDefNode):
def caller_will_check_exceptions(self): def caller_will_check_exceptions(self):
return self.entry.type.exception_check return self.entry.type.exception_check
def generate_optarg_wrapper_function(self, env, code): def generate_wrapper_functions(self, code):
if self.type.optional_arg_count and \ # If the C signature of a function has changed, we need to generate
self.type.original_sig and not self.type.original_sig.optional_arg_count: # wrappers to put in the slots here.
k = 0
entry = self.entry
func_type = entry.type
while entry.prev_entry is not None:
k += 1
entry = entry.prev_entry
entry.func_cname = "%s%swrap_%s" % (self.entry.func_cname, Naming.pyrex_prefix, k)
code.putln() code.putln()
self.generate_function_header(code, 0, with_opt_args = 0) self.generate_function_header(code,
0,
with_dispatch = entry.type.is_overridable,
with_opt_args = entry.type.optional_arg_count,
cname = entry.func_cname)
if not self.return_type.is_void: if not self.return_type.is_void:
code.put('return ') code.put('return ')
args = self.type.args args = self.type.args
arglist = [arg.cname for arg in args[:len(args)-self.type.optional_arg_count]] arglist = [arg.cname for arg in args[:len(args)-self.type.optional_arg_count]]
if entry.type.is_overridable:
arglist.append(Naming.skip_dispatch_cname)
elif func_type.is_overridable:
arglist.append('0')
if entry.type.optional_arg_count:
arglist.append(Naming.optional_args_cname)
elif func_type.optional_arg_count:
arglist.append('NULL') arglist.append('NULL')
code.putln('%s(%s);' % (self.entry.func_cname, ', '.join(arglist))) code.putln('%s(%s);' % (self.entry.func_cname, ', '.join(arglist)))
code.putln('}') code.putln('}')
...@@ -2070,7 +2091,7 @@ class OverrideCheckNode(StatNode): ...@@ -2070,7 +2091,7 @@ class OverrideCheckNode(StatNode):
else: else:
self_arg = "((PyObject *)%s)" % self.args[0].cname self_arg = "((PyObject *)%s)" % self.args[0].cname
code.putln("/* Check if called by wrapper */") code.putln("/* Check if called by wrapper */")
code.putln("if (unlikely(%s)) %s = 0;" % (Naming.skip_dispatch_cname, Naming.skip_dispatch_cname)) code.putln("if (unlikely(%s)) ;" % Naming.skip_dispatch_cname)
code.putln("/* Check if overriden in Python */") code.putln("/* Check if overriden in Python */")
if self.py_func.is_module_scope: if self.py_func.is_module_scope:
code.putln("else {") code.putln("else {")
......
...@@ -732,7 +732,7 @@ class CFuncType(CType): ...@@ -732,7 +732,7 @@ class CFuncType(CType):
return 1 return 1
if not other_type.is_cfunction: if not other_type.is_cfunction:
return 0 return 0
if not self.is_overridable and other_type.is_overridable: if self.is_overridable != other_type.is_overridable:
return 0 return 0
nargs = len(self.args) nargs = len(self.args)
if nargs != len(other_type.args): if nargs != len(other_type.args):
...@@ -846,6 +846,8 @@ class CFuncType(CType): ...@@ -846,6 +846,8 @@ class CFuncType(CType):
for arg in self.args[:len(self.args)-self.optional_arg_count]: for arg in self.args[:len(self.args)-self.optional_arg_count]:
arg_decl_list.append( arg_decl_list.append(
arg.type.declaration_code("", for_display, pyrex = pyrex)) arg.type.declaration_code("", for_display, pyrex = pyrex))
if self.is_overridable:
arg_decl_list.append("int %s" % Naming.skip_dispatch_cname)
if self.optional_arg_count: if self.optional_arg_count:
arg_decl_list.append(self.op_arg_struct.declaration_code(Naming.optional_args_cname)) arg_decl_list.append(self.op_arg_struct.declaration_code(Naming.optional_args_cname))
if self.has_varargs: if self.has_varargs:
......
...@@ -138,6 +138,7 @@ class Entry: ...@@ -138,6 +138,7 @@ class Entry:
utility_code = None utility_code = None
is_overridable = 0 is_overridable = 0
buffer_aux = None buffer_aux = None
prev_entry = None
def __init__(self, name, cname, type, pos = None, init = None): def __init__(self, name, cname, type, pos = None, init = None):
self.name = name self.name = name
...@@ -277,7 +278,7 @@ class Scope: ...@@ -277,7 +278,7 @@ class Scope:
if name and dict.has_key(name): if name and dict.has_key(name):
if visibility == 'extern': if visibility == 'extern':
warning(pos, "'%s' redeclared " % name, 0) warning(pos, "'%s' redeclared " % name, 0)
else: elif visibility != 'ignore':
error(pos, "'%s' redeclared " % name) error(pos, "'%s' redeclared " % name)
entry = Entry(name, cname, type, pos = pos) entry = Entry(name, cname, type, pos = pos)
entry.in_cinclude = self.in_cinclude entry.in_cinclude = self.in_cinclude
...@@ -1411,22 +1412,8 @@ class CClassScope(ClassScope): ...@@ -1411,22 +1412,8 @@ class CClassScope(ClassScope):
if type.same_c_signature_as(entry.type, as_cmethod = 1) and type.nogil == entry.type.nogil: if type.same_c_signature_as(entry.type, as_cmethod = 1) and type.nogil == entry.type.nogil:
pass pass
elif type.compatible_signature_with(entry.type, as_cmethod = 1) and type.nogil == entry.type.nogil: elif type.compatible_signature_with(entry.type, as_cmethod = 1) and type.nogil == entry.type.nogil:
if type.optional_arg_count and not type.original_sig.optional_arg_count: entry = self.add_cfunction(name, type, pos, cname or name, visibility='ignore')
# Need to put a wrapper taking no optional arguments
# into the method table.
wrapper_func_cname = self.mangle(Naming.func_prefix, name) + Naming.no_opt_args
wrapper_func_name = name + Naming.no_opt_args
if entry.type.optional_arg_count:
old_entry = self.lookup_here(wrapper_func_name)
old_entry.func_cname = wrapper_func_cname
else:
entry.func_cname = wrapper_func_cname
entry.name = wrapper_func_name
entry = self.add_cfunction(name, type, pos, cname or name, visibility)
defining = 1 defining = 1
entry.type = type
# if type.narrower_c_signature_than(entry.type, as_cmethod = 1):
# entry.type = type
else: else:
error(pos, "Signature not compatible with previous declaration") error(pos, "Signature not compatible with previous declaration")
error(entry.pos, "Previous declaration is here") error(entry.pos, "Previous declaration is here")
...@@ -1442,8 +1429,10 @@ class CClassScope(ClassScope): ...@@ -1442,8 +1429,10 @@ class CClassScope(ClassScope):
def add_cfunction(self, name, type, pos, cname, visibility): def add_cfunction(self, name, type, pos, cname, visibility):
# Add a cfunction entry without giving it a func_cname. # Add a cfunction entry without giving it a func_cname.
prev_entry = self.lookup_here(name)
entry = ClassScope.add_cfunction(self, name, type, pos, cname, visibility) entry = ClassScope.add_cfunction(self, name, type, pos, cname, visibility)
entry.is_cmethod = 1 entry.is_cmethod = 1
entry.prev_entry = prev_entry
return entry return entry
def declare_property(self, name, doc, pos): def declare_property(self, name, doc, pos):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment