Commit a4c3ab1c authored by Robert Bradshaw's avatar Robert Bradshaw

Default cdef args via struct

parent 5882c337
...@@ -1560,14 +1560,25 @@ class SimpleCallNode(ExprNode): ...@@ -1560,14 +1560,25 @@ class SimpleCallNode(ExprNode):
return "<error>" return "<error>"
formal_args = func_type.args formal_args = func_type.args
arg_list_code = [] arg_list_code = []
for (formal_arg, actual_arg) in \ args = zip(formal_args, self.args)
zip(formal_args, self.args): max_nargs = len(func_type.args)
expected_nargs = max_nargs - func_type.optional_arg_count
actual_nargs = len(self.args)
for formal_arg, actual_arg in args[:expected_nargs]:
arg_code = actual_arg.result_as(formal_arg.type) arg_code = actual_arg.result_as(formal_arg.type)
arg_list_code.append(arg_code) arg_list_code.append(arg_code)
if func_type.optional_arg_count: if func_type.optional_arg_count:
for formal_arg in formal_args[len(self.args):]: if expected_nargs == actual_nargs:
arg_list_code.append(formal_arg.type.cast_code('0')) arg_list_code.append(func_type.op_args.cast_code('NULL'))
arg_list_code.append(str(max(0, len(formal_args) - len(self.args)))) else:
optional_arg_code = [str(actual_nargs - expected_nargs)]
for formal_arg, actual_arg in args[expected_nargs:actual_nargs]:
arg_code = actual_arg.result_as(formal_arg.type)
optional_arg_code.append(arg_code)
# for formal_arg in formal_args[actual_nargs:max_nargs]:
# optional_arg_code.append(formal_arg.type.cast_code('0'))
optional_arg_struct = '{%s}' % ','.join(optional_arg_code)
arg_list_code.append('&' + func_type.op_args.base_type.cast_code(optional_arg_struct))
for actual_arg in self.args[len(formal_args):]: for actual_arg in self.args[len(formal_args):]:
arg_list_code.append(actual_arg.result_code) arg_list_code.append(actual_arg.result_code)
result = "%s(%s)" % (self.function.result_code, result = "%s(%s)" % (self.function.result_code,
......
...@@ -32,6 +32,7 @@ var_prefix = pyrex_prefix + "v_" ...@@ -32,6 +32,7 @@ var_prefix = pyrex_prefix + "v_"
vtable_prefix = pyrex_prefix + "vtable_" vtable_prefix = pyrex_prefix + "vtable_"
vtabptr_prefix = pyrex_prefix + "vtabptr_" vtabptr_prefix = pyrex_prefix + "vtabptr_"
vtabstruct_prefix = pyrex_prefix + "vtabstruct_" vtabstruct_prefix = pyrex_prefix + "vtabstruct_"
opt_arg_prefix = pyrex_prefix + "opt_args_"
args_cname = pyrex_prefix + "args" args_cname = pyrex_prefix + "args"
kwdlist_cname = pyrex_prefix + "argnames" kwdlist_cname = pyrex_prefix + "argnames"
...@@ -60,7 +61,7 @@ gilstate_cname = pyrex_prefix + "state" ...@@ -60,7 +61,7 @@ gilstate_cname = pyrex_prefix + "state"
skip_dispatch_cname = pyrex_prefix + "skip_dispatch" skip_dispatch_cname = pyrex_prefix + "skip_dispatch"
empty_tuple = pyrex_prefix + "empty_tuple" empty_tuple = pyrex_prefix + "empty_tuple"
cleanup_cname = pyrex_prefix + "module_cleanup" cleanup_cname = pyrex_prefix + "module_cleanup"
optional_count_cname = pyrex_prefix + "optional_arg_count" optional_args_cname = pyrex_prefix + "optional_args"
extern_c_macro = pyrex_prefix.upper() + "EXTERN_C" extern_c_macro = pyrex_prefix.upper() + "EXTERN_C"
......
...@@ -339,6 +339,21 @@ class CFuncDeclaratorNode(CDeclaratorNode): ...@@ -339,6 +339,21 @@ class CFuncDeclaratorNode(CDeclaratorNode):
PyrexTypes.CFuncTypeArg(name, type, arg_node.pos)) PyrexTypes.CFuncTypeArg(name, type, arg_node.pos))
if arg_node.default: if arg_node.default:
self.optional_arg_count += 1 self.optional_arg_count += 1
if self.optional_arg_count:
scope = StructOrUnionScope()
scope.declare_var('n', PyrexTypes.c_int_type, self.pos)
for arg in func_type_args[len(func_type_args)-self.optional_arg_count:]:
scope.declare_var(arg.name, arg.type, arg.pos, allow_pyobject = 1)
struct_cname = Naming.opt_arg_prefix + self.base.name
self.op_args_struct = env.global_scope().declare_struct_or_union(name = struct_cname,
kind = 'struct',
scope = scope,
typedef_flag = 0,
pos = self.pos,
cname = struct_cname)
self.op_args_struct.used = 1
exc_val = None exc_val = None
exc_check = 0 exc_check = 0
if return_type.is_pyobject \ if return_type.is_pyobject \
...@@ -368,6 +383,8 @@ class CFuncDeclaratorNode(CDeclaratorNode): ...@@ -368,6 +383,8 @@ class CFuncDeclaratorNode(CDeclaratorNode):
exception_value = exc_val, exception_check = exc_check, exception_value = exc_val, exception_check = exc_check,
calling_convention = self.base.calling_convention, calling_convention = self.base.calling_convention,
nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable) nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable)
if self.optional_arg_count:
func_type.op_args = PyrexTypes.c_ptr_type(self.op_args_struct.type)
return self.base.analyse(func_type, env) return self.base.analyse(func_type, env)
...@@ -383,6 +400,7 @@ class CArgDeclNode(Node): ...@@ -383,6 +400,7 @@ class CArgDeclNode(Node):
# is_kw_only boolean Is a keyword-only argument # is_kw_only boolean Is a keyword-only argument
is_self_arg = 0 is_self_arg = 0
is_generic = 1
def analyse(self, env): def analyse(self, env):
#print "CArgDeclNode.analyse: is_self_arg =", self.is_self_arg ### #print "CArgDeclNode.analyse: is_self_arg =", self.is_self_arg ###
...@@ -813,6 +831,9 @@ class CFuncDefNode(FuncDefNode): ...@@ -813,6 +831,9 @@ class CFuncDefNode(FuncDefNode):
# from the base type of an extension type. # from the base type of an extension type.
self.type = type self.type = type
type.is_overridable = self.overridable type.is_overridable = self.overridable
for formal_arg, type_arg in zip(self.declarator.args, type.args):
formal_arg.type = type_arg.type
formal_arg.cname = type_arg.cname
name = name_declarator.name name = name_declarator.name
cname = name_declarator.cname cname = name_declarator.cname
self.entry = env.declare_cfunction( self.entry = env.declare_cfunction(
...@@ -869,16 +890,16 @@ class CFuncDefNode(FuncDefNode): ...@@ -869,16 +890,16 @@ class CFuncDefNode(FuncDefNode):
arg_decls = [] arg_decls = []
type = self.type type = self.type
visibility = self.entry.visibility visibility = self.entry.visibility
for arg in type.args: for arg in type.args[:len(type.args)-type.optional_arg_count]:
arg_decls.append(arg.declaration_code()) arg_decls.append(arg.declaration_code())
if type.optional_arg_count: if type.optional_arg_count:
arg_decls.append("int %s" % Naming.optional_count_cname) arg_decls.append(type.op_args.declaration_code(Naming.optional_args_cname))
if type.has_varargs: if type.has_varargs:
arg_decls.append("...") arg_decls.append("...")
if not arg_decls: if not arg_decls:
arg_decls = ["void"] arg_decls = ["void"]
entity = type.function_header_code(self.entry.func_cname, entity = type.function_header_code(self.entry.func_cname,
string.join(arg_decls, ",")) string.join(arg_decls, ", "))
if visibility == 'public': if visibility == 'public':
dll_linkage = "DL_EXPORT" dll_linkage = "DL_EXPORT"
else: else:
...@@ -895,24 +916,26 @@ class CFuncDefNode(FuncDefNode): ...@@ -895,24 +916,26 @@ class CFuncDefNode(FuncDefNode):
header)) header))
def generate_argument_declarations(self, env, code): def generate_argument_declarations(self, env, code):
# Arguments already declared in function header for arg in self.declarator.args:
pass if arg.default:
code.putln('%s = %s;' % (arg.type.declaration_code(arg.cname), arg.default_entry.cname))
def generate_keyword_list(self, code): def generate_keyword_list(self, code):
pass pass
def generate_argument_parsing_code(self, code): def generate_argument_parsing_code(self, code):
rev_args = zip(self.declarator.args, self.type.args) rev_args = self.declarator.args
rev_args.reverse()
i = 0 i = 0
for darg, targ in rev_args: if self.type.optional_arg_count:
if darg.default: code.putln('if (%s) {' % Naming.optional_args_cname)
code.putln('if (%s > %s) {' % (Naming.optional_count_cname, i)) for arg in rev_args:
code.putln('%s = %s;' % (targ.cname, darg.default_entry.cname)) if arg.default:
code.putln('if (%s->n > %s) {' % (Naming.optional_args_cname, i))
code.putln('%s = %s->%s;' % (arg.cname, Naming.optional_args_cname, arg.declarator.name))
i += 1 i += 1
for _ in range(i): for _ in range(self.type.optional_arg_count):
code.putln('}')
code.putln('}') code.putln('}')
code.putln('/* defaults */')
def generate_argument_conversion_code(self, code): def generate_argument_conversion_code(self, code):
pass pass
......
...@@ -693,14 +693,14 @@ class CFuncType(CType): ...@@ -693,14 +693,14 @@ class CFuncType(CType):
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
arg_decl_list = [] arg_decl_list = []
for arg in self.args: for arg in self.args[:len(self.args)-self.optional_arg_count]:
arg_decl_list.append( arg_decl_list.append(
arg.type.declaration_code("", for_display, pyrex = pyrex)) arg.type.declaration_code("", for_display, pyrex = pyrex))
if self.optional_arg_count: if self.optional_arg_count:
arg_decl_list.append("int %s" % Naming.optional_count_cname) arg_decl_list.append(self.op_args.declaration_code(Naming.optional_args_cname))
if self.has_varargs: if self.has_varargs:
arg_decl_list.append("...") arg_decl_list.append("...")
arg_decl_code = string.join(arg_decl_list, ",") arg_decl_code = string.join(arg_decl_list, ", ")
if not arg_decl_code and not pyrex: if not arg_decl_code and not pyrex:
arg_decl_code = "void" arg_decl_code = "void"
exc_clause = "" exc_clause = ""
......
...@@ -1058,14 +1058,14 @@ class StructOrUnionScope(Scope): ...@@ -1058,14 +1058,14 @@ class StructOrUnionScope(Scope):
Scope.__init__(self, "?", None, None) Scope.__init__(self, "?", None, None)
def declare_var(self, name, type, pos, def declare_var(self, name, type, pos,
cname = None, visibility = 'private', is_cdef = 0): cname = None, visibility = 'private', is_cdef = 0, allow_pyobject = 0):
# Add an entry for an attribute. # Add an entry for an attribute.
if not cname: if not cname:
cname = name cname = name
entry = self.declare(name, cname, type, pos) entry = self.declare(name, cname, type, pos)
entry.is_variable = 1 entry.is_variable = 1
self.var_entries.append(entry) self.var_entries.append(entry)
if type.is_pyobject: if type.is_pyobject and not allow_pyobject:
error(pos, error(pos,
"C struct/union member cannot be a Python object") "C struct/union member cannot be a Python object")
if visibility <> 'private': if visibility <> 'private':
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment