Commit 87c3455c authored by Robert Bradshaw's avatar Robert Bradshaw

Optional arguments in cpdef functions

parent 68592200
...@@ -1514,13 +1514,19 @@ class SimpleCallNode(ExprNode): ...@@ -1514,13 +1514,19 @@ class SimpleCallNode(ExprNode):
self.result_code = "<error>" self.result_code = "<error>"
return return
# Check no. of args # Check no. of args
expected_nargs = len(func_type.args) max_nargs = len(func_type.args)
expected_nargs = max_nargs - func_type.optional_arg_count
actual_nargs = len(self.args) actual_nargs = len(self.args)
if actual_nargs < expected_nargs \ if actual_nargs < expected_nargs \
or (not func_type.has_varargs and actual_nargs > expected_nargs): or (not func_type.has_varargs and actual_nargs > max_nargs):
expected_str = str(expected_nargs) expected_str = str(expected_nargs)
if func_type.has_varargs: if func_type.has_varargs:
expected_str = "at least " + expected_str expected_str = "at least " + expected_str
elif func_type.optional_arg_count:
if actual_nargs > max_nargs:
expected_str = "at least " + expected_str
else:
expected_str = "at most " + str(max_nargs)
error(self.pos, error(self.pos,
"Call with wrong number of arguments (expected %s, got %s)" "Call with wrong number of arguments (expected %s, got %s)"
% (expected_str, actual_nargs)) % (expected_str, actual_nargs))
...@@ -1529,10 +1535,10 @@ class SimpleCallNode(ExprNode): ...@@ -1529,10 +1535,10 @@ class SimpleCallNode(ExprNode):
self.result_code = "<error>" self.result_code = "<error>"
return return
# Coerce arguments # Coerce arguments
for i in range(expected_nargs): for i in range(min(max_nargs, actual_nargs)):
formal_type = func_type.args[i].type formal_type = func_type.args[i].type
self.args[i] = self.args[i].coerce_to(formal_type, env) self.args[i] = self.args[i].coerce_to(formal_type, env)
for i in range(expected_nargs, actual_nargs): for i in range(max_nargs, actual_nargs):
if self.args[i].type.is_pyobject: if self.args[i].type.is_pyobject:
error(self.args[i].pos, error(self.args[i].pos,
"Python object cannot be passed as a varargs parameter") "Python object cannot be passed as a varargs parameter")
...@@ -1558,10 +1564,14 @@ class SimpleCallNode(ExprNode): ...@@ -1558,10 +1564,14 @@ class SimpleCallNode(ExprNode):
zip(formal_args, self.args): zip(formal_args, self.args):
arg_code = actual_arg.result_as(formal_arg.type) arg_code = actual_arg.result_as(formal_arg.type)
arg_list_code.append(arg_code) arg_list_code.append(arg_code)
if func_type.optional_arg_count:
for formal_arg in formal_args[len(self.args):]:
arg_list_code.append(formal_arg.type.cast_code('0'))
arg_list_code.append(str(max(0, len(formal_args) - len(self.args))))
for actual_arg in self.args[len(formal_args):]: for actual_arg in self.args[len(formal_args):]:
arg_list_code.append(actual_arg.result_code) arg_list_code.append(actual_arg.result_code)
result = "%s(%s)" % (self.function.result_code, result = "%s(%s)" % (self.function.result_code,
join(arg_list_code, ",")) join(arg_list_code, ", "))
if self.wrapper_call or \ if self.wrapper_call or \
self.function.entry.is_unbound_cmethod and self.function.entry.type.is_overridable: self.function.entry.is_unbound_cmethod and self.function.entry.type.is_overridable:
result = "(%s = 1, %s)" % (Naming.skip_dispatch_cname, result) result = "(%s = 1, %s)" % (Naming.skip_dispatch_cname, result)
...@@ -3539,6 +3549,9 @@ class PyTypeTestNode(CoercionNode): ...@@ -3539,6 +3549,9 @@ class PyTypeTestNode(CoercionNode):
self.result_ctype = arg.ctype() self.result_ctype = arg.ctype()
env.use_utility_code(type_test_utility_code) env.use_utility_code(type_test_utility_code)
def analyse_types(self, env):
pass
def result_in_temp(self): def result_in_temp(self):
return self.arg.result_in_temp() return self.arg.result_in_temp()
...@@ -3552,7 +3565,7 @@ class PyTypeTestNode(CoercionNode): ...@@ -3552,7 +3565,7 @@ class PyTypeTestNode(CoercionNode):
if self.type.typeobj_is_available(): if self.type.typeobj_is_available():
code.putln( code.putln(
"if (!__Pyx_TypeTest(%s, %s)) %s" % ( "if (!__Pyx_TypeTest(%s, %s)) %s" % (
self.arg.py_result(), self.arg.py_result(),
self.type.typeptr_cname, self.type.typeptr_cname,
code.error_goto(self.pos))) code.error_goto(self.pos)))
else: else:
......
...@@ -60,6 +60,7 @@ gilstate_cname = pyrex_prefix + "state" ...@@ -60,6 +60,7 @@ gilstate_cname = pyrex_prefix + "state"
skip_dispatch_cname = pyrex_prefix + "skip_dispatch" skip_dispatch_cname = pyrex_prefix + "skip_dispatch"
empty_tuple = pyrex_prefix + "empty_tuple" empty_tuple = pyrex_prefix + "empty_tuple"
cleanup_cname = pyrex_prefix + "module_cleanup" cleanup_cname = pyrex_prefix + "module_cleanup"
optional_count_cname = pyrex_prefix + "optional_arg_count"
extern_c_macro = pyrex_prefix.upper() + "EXTERN_C" extern_c_macro = pyrex_prefix.upper() + "EXTERN_C"
......
...@@ -316,6 +316,7 @@ class CFuncDeclaratorNode(CDeclaratorNode): ...@@ -316,6 +316,7 @@ class CFuncDeclaratorNode(CDeclaratorNode):
# with_gil boolean Acquire gil around function body # with_gil boolean Acquire gil around function body
overridable = 0 overridable = 0
optional_arg_count = 0
def analyse(self, return_type, env): def analyse(self, return_type, env):
func_type_args = [] func_type_args = []
...@@ -337,7 +338,7 @@ class CFuncDeclaratorNode(CDeclaratorNode): ...@@ -337,7 +338,7 @@ class CFuncDeclaratorNode(CDeclaratorNode):
func_type_args.append( func_type_args.append(
PyrexTypes.CFuncTypeArg(name, type, arg_node.pos)) PyrexTypes.CFuncTypeArg(name, type, arg_node.pos))
if arg_node.default: if arg_node.default:
error(arg_node.pos, "C function argument cannot have default value") self.optional_arg_count += 1
exc_val = None exc_val = None
exc_check = 0 exc_check = 0
if return_type.is_pyobject \ if return_type.is_pyobject \
...@@ -363,6 +364,7 @@ class CFuncDeclaratorNode(CDeclaratorNode): ...@@ -363,6 +364,7 @@ class CFuncDeclaratorNode(CDeclaratorNode):
"Function cannot return a function") "Function cannot return a function")
func_type = PyrexTypes.CFuncType( func_type = PyrexTypes.CFuncType(
return_type, func_type_args, self.has_varargs, return_type, func_type_args, self.has_varargs,
optional_arg_count = self.optional_arg_count,
exception_value = exc_val, exception_check = exc_check, exception_value = exc_val, exception_check = exc_check,
calling_convention = self.base.calling_convention, calling_convention = self.base.calling_convention,
nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable) nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable)
...@@ -609,10 +611,24 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -609,10 +611,24 @@ class FuncDefNode(StatNode, BlockNode):
# entry Symtab.Entry # entry Symtab.Entry
py_func = None py_func = None
assmt = None
def analyse_default_values(self, env):
genv = env.global_scope()
for arg in self.args:
if arg.default:
if arg.is_generic:
if not hasattr(arg, 'default_entry'):
arg.default.analyse_types(genv)
arg.default = arg.default.coerce_to(arg.type, genv)
arg.default.allocate_temps(genv)
arg.default_entry = genv.add_default_value(arg.type)
arg.default_entry.used = 1
else:
error(arg.pos,
"This argument cannot have a default value")
arg.default = None
def analyse_expressions(self, env):
pass
def need_gil_acquisition(self, lenv): def need_gil_acquisition(self, lenv):
return 0 return 0
...@@ -749,7 +765,24 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -749,7 +765,24 @@ class FuncDefNode(StatNode, BlockNode):
code.put_var_incref(entry) code.put_var_incref(entry)
def generate_execution_code(self, code): def generate_execution_code(self, code):
pass # Evaluate and store argument default values
for arg in self.args:
default = arg.default
if default:
default.generate_evaluation_code(code)
default.make_owned_reference(code)
code.putln(
"%s = %s;" % (
arg.default_entry.cname,
default.result_as(arg.default_entry.type)))
if default.is_temp and default.type.is_pyobject:
code.putln(
"%s = 0;" %
default.result_code)
# For Python class methods, create and store function object
if self.assmt:
self.assmt.generate_execution_code(code)
class CFuncDefNode(FuncDefNode): class CFuncDefNode(FuncDefNode):
...@@ -826,12 +859,20 @@ class CFuncDefNode(FuncDefNode): ...@@ -826,12 +859,20 @@ class CFuncDefNode(FuncDefNode):
error(self.pos, "Function declared nogil has Python locals or temporaries") error(self.pos, "Function declared nogil has Python locals or temporaries")
return with_gil return with_gil
def analyse_expressions(self, env):
self.args = self.declarator.args
self.analyse_default_values(env)
if self.overridable:
self.py_func.analyse_expressions(env)
def generate_function_header(self, code, with_pymethdef): def generate_function_header(self, code, with_pymethdef):
arg_decls = [] arg_decls = []
type = self.type type = self.type
visibility = self.entry.visibility visibility = self.entry.visibility
for arg in type.args: for arg in type.args:
arg_decls.append(arg.declaration_code()) arg_decls.append(arg.declaration_code())
if type.optional_arg_count:
arg_decls.append("int %s" % Naming.optional_count_cname)
if type.has_varargs: if type.has_varargs:
arg_decls.append("...") arg_decls.append("...")
if not arg_decls: if not arg_decls:
...@@ -861,7 +902,17 @@ class CFuncDefNode(FuncDefNode): ...@@ -861,7 +902,17 @@ class CFuncDefNode(FuncDefNode):
pass pass
def generate_argument_parsing_code(self, code): def generate_argument_parsing_code(self, code):
pass rev_args = zip(self.declarator.args, self.type.args)
rev_args.reverse()
i = 0
for darg, targ in rev_args:
if darg.default:
code.putln('if (%s > %s) {' % (Naming.optional_count_cname, i))
code.putln('%s = %s;' % (targ.cname, darg.default_entry.cname))
i += 1
for _ in range(i):
code.putln('}')
code.putln('/* defaults */')
def generate_argument_conversion_code(self, code): def generate_argument_conversion_code(self, code):
pass pass
...@@ -1102,21 +1153,6 @@ class DefNode(FuncDefNode): ...@@ -1102,21 +1153,6 @@ class DefNode(FuncDefNode):
if env.is_py_class_scope: if env.is_py_class_scope:
self.synthesize_assignment_node(env) self.synthesize_assignment_node(env)
def analyse_default_values(self, env):
genv = env.global_scope()
for arg in self.args:
if arg.default:
if arg.is_generic:
arg.default.analyse_types(genv)
arg.default = arg.default.coerce_to(arg.type, genv)
arg.default.allocate_temps(genv)
arg.default_entry = genv.add_default_value(arg.type)
arg.default_entry.used = 1
else:
error(arg.pos,
"This argument cannot have a default value")
arg.default = None
def synthesize_assignment_node(self, env): def synthesize_assignment_node(self, env):
import ExprNodes import ExprNodes
self.assmt = SingleAssignmentNode(self.pos, self.assmt = SingleAssignmentNode(self.pos,
...@@ -1485,25 +1521,6 @@ class DefNode(FuncDefNode): ...@@ -1485,25 +1521,6 @@ class DefNode(FuncDefNode):
error(arg.pos, "Cannot test type of extern C class " error(arg.pos, "Cannot test type of extern C class "
"without type object name specification") "without type object name specification")
def generate_execution_code(self, code):
# Evaluate and store argument default values
for arg in self.args:
default = arg.default
if default:
default.generate_evaluation_code(code)
default.make_owned_reference(code)
code.putln(
"%s = %s;" % (
arg.default_entry.cname,
default.result_as(arg.default_entry.type)))
if default.is_temp and default.type.is_pyobject:
code.putln(
"%s = 0;" %
default.result_code)
# For Python class methods, create and store function object
if self.assmt:
self.assmt.generate_execution_code(code)
def error_value(self): def error_value(self):
return self.entry.signature.error_value return self.entry.signature.error_value
......
...@@ -1710,8 +1710,8 @@ def p_api(s): ...@@ -1710,8 +1710,8 @@ def p_api(s):
def p_cdef_statement(s, level, visibility = 'private', api = 0, def p_cdef_statement(s, level, visibility = 'private', api = 0,
overridable = False): overridable = False):
pos = s.position() pos = s.position()
if overridable and level not in ('c_class', 'c_class_pxd'): # if overridable and level not in ('c_class', 'c_class_pxd'):
error(pos, "Overridable cdef function not allowed here") # error(pos, "Overridable cdef function not allowed here")
visibility = p_visibility(s, visibility) visibility = p_visibility(s, visibility)
api = api or p_api(s) api = api or p_api(s)
if api: if api:
......
...@@ -587,10 +587,11 @@ class CFuncType(CType): ...@@ -587,10 +587,11 @@ class CFuncType(CType):
def __init__(self, return_type, args, has_varargs = 0, def __init__(self, return_type, args, has_varargs = 0,
exception_value = None, exception_check = 0, calling_convention = "", exception_value = None, exception_check = 0, calling_convention = "",
nogil = 0, with_gil = 0, is_overridable = 0): nogil = 0, with_gil = 0, is_overridable = 0, optional_arg_count = 0):
self.return_type = return_type self.return_type = return_type
self.args = args self.args = args
self.has_varargs = has_varargs self.has_varargs = has_varargs
self.optional_arg_count = optional_arg_count
self.exception_value = exception_value self.exception_value = exception_value
self.exception_check = exception_check self.exception_check = exception_check
self.calling_convention = calling_convention self.calling_convention = calling_convention
...@@ -639,6 +640,8 @@ class CFuncType(CType): ...@@ -639,6 +640,8 @@ class CFuncType(CType):
return 0 return 0
if self.has_varargs <> other_type.has_varargs: if self.has_varargs <> other_type.has_varargs:
return 0 return 0
if self.optional_arg_count <> other_type.optional_arg_count:
return 0
if not self.return_type.same_as(other_type.return_type): if not self.return_type.same_as(other_type.return_type):
return 0 return 0
if not self.same_calling_convention_as(other_type): if not self.same_calling_convention_as(other_type):
...@@ -664,6 +667,8 @@ class CFuncType(CType): ...@@ -664,6 +667,8 @@ class CFuncType(CType):
or not self.args[i].type.same_as(other_type.args[i].type) or not self.args[i].type.same_as(other_type.args[i].type)
if self.has_varargs <> other_type.has_varargs: if self.has_varargs <> other_type.has_varargs:
return 0 return 0
if self.optional_arg_count <> other_type.optional_arg_count:
return 0
if not self.return_type.subtype_of_resolved_type(other_type.return_type): if not self.return_type.subtype_of_resolved_type(other_type.return_type):
return 0 return 0
return 1 return 1
...@@ -691,6 +696,8 @@ class CFuncType(CType): ...@@ -691,6 +696,8 @@ class CFuncType(CType):
for arg in self.args: for arg in self.args:
arg_decl_list.append( arg_decl_list.append(
arg.type.declaration_code("", for_display, pyrex = pyrex)) arg.type.declaration_code("", for_display, pyrex = pyrex))
if self.optional_arg_count:
arg_decl_list.append("int %s" % Naming.optional_count_cname)
if self.has_varargs: if self.has_varargs:
arg_decl_list.append("...") arg_decl_list.append("...")
arg_decl_code = string.join(arg_decl_list, ",") arg_decl_code = string.join(arg_decl_list, ",")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment