Commit 331758a4 authored by Robert Bradshaw's avatar Robert Bradshaw

Fix optional cdef arguments for c++, possible optimization when not all args are used.

parent 20383d1f
...@@ -1680,6 +1680,7 @@ class SimpleCallNode(CallNode): ...@@ -1680,6 +1680,7 @@ class SimpleCallNode(CallNode):
# self ExprNode or None used internally # self ExprNode or None used internally
# coerced_self ExprNode or None used internally # coerced_self ExprNode or None used internally
# wrapper_call bool used internally # wrapper_call bool used internally
# has_optional_args bool used internally
subexprs = ['self', 'coerced_self', 'function', 'args', 'arg_tuple'] subexprs = ['self', 'coerced_self', 'function', 'args', 'arg_tuple']
...@@ -1687,6 +1688,7 @@ class SimpleCallNode(CallNode): ...@@ -1687,6 +1688,7 @@ class SimpleCallNode(CallNode):
coerced_self = None coerced_self = None
arg_tuple = None arg_tuple = None
wrapper_call = False wrapper_call = False
has_optional_args = False
def compile_time_value(self, denv): def compile_time_value(self, denv):
function = self.function.compile_time_value(denv) function = self.function.compile_time_value(denv)
...@@ -1773,6 +1775,11 @@ class SimpleCallNode(CallNode): ...@@ -1773,6 +1775,11 @@ class SimpleCallNode(CallNode):
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
self.result_code = "<error>" self.result_code = "<error>"
return return
if func_type.optional_arg_count and expected_nargs != actual_nargs:
self.has_optional_args = 1
self.is_temp = 1
self.opt_arg_struct = env.allocate_temp(func_type.op_arg_struct.base_type)
env.release_temp(self.opt_arg_struct)
# Coerce arguments # Coerce arguments
for i in range(min(max_nargs, actual_nargs)): for i in range(min(max_nargs, actual_nargs)):
formal_type = func_type.args[i].type formal_type = func_type.args[i].type
...@@ -1818,15 +1825,7 @@ class SimpleCallNode(CallNode): ...@@ -1818,15 +1825,7 @@ class SimpleCallNode(CallNode):
if expected_nargs == actual_nargs: if expected_nargs == actual_nargs:
optional_args = 'NULL' optional_args = 'NULL'
else: else:
optional_arg_code = [str(actual_nargs - expected_nargs)] optional_args = "&%s" % self.opt_arg_struct
for formal_arg, actual_arg in args[expected_nargs:actual_nargs]:
arg_code = actual_arg.result_as(formal_arg.type)
optional_arg_code.append(arg_code)
# for formal_arg in formal_args[actual_nargs:max_nargs]:
# optional_arg_code.append(formal_arg.type.cast_code('0'))
optional_arg_struct = '{%s}' % ','.join(optional_arg_code)
optional_args = PyrexTypes.c_void_ptr_type.cast_code(
'&' + func_type.op_arg_struct.base_type.cast_code(optional_arg_struct))
arg_list_code.append(optional_args) arg_list_code.append(optional_args)
for actual_arg in self.args[len(formal_args):]: for actual_arg in self.args[len(formal_args):]:
...@@ -1849,6 +1848,19 @@ class SimpleCallNode(CallNode): ...@@ -1849,6 +1848,19 @@ class SimpleCallNode(CallNode):
arg_code, arg_code,
code.error_goto_if_null(self.result_code, self.pos))) code.error_goto_if_null(self.result_code, self.pos)))
elif func_type.is_cfunction: elif func_type.is_cfunction:
if self.has_optional_args:
actual_nargs = len(self.args)
expected_nargs = len(func_type.args) - func_type.optional_arg_count
code.putln("%s.%s = %s;" % (
self.opt_arg_struct,
Naming.pyrex_prefix + "n",
len(self.args) - expected_nargs))
args = zip(func_type.args, self.args)
for formal_arg, actual_arg in args[expected_nargs:actual_nargs]:
code.putln("%s.%s = %s;" % (
self.opt_arg_struct,
formal_arg.name,
actual_arg.result_as(formal_arg.type)))
exc_checks = [] exc_checks = []
if self.type.is_pyobject: if self.type.is_pyobject:
exc_checks.append("!%s" % self.result_code) exc_checks.append("!%s" % self.result_code)
...@@ -1883,12 +1895,12 @@ class SimpleCallNode(CallNode): ...@@ -1883,12 +1895,12 @@ class SimpleCallNode(CallNode):
rhs, rhs,
raise_py_exception, raise_py_exception,
code.error_goto(self.pos))) code.error_goto(self.pos)))
return else:
code.putln( if exc_checks:
"%s%s; %s" % ( goto_error = code.error_goto_if(" && ".join(exc_checks), self.pos)
lhs, else:
rhs, goto_error = ""
code.error_goto_if(" && ".join(exc_checks), self.pos))) code.putln("%s%s; %s" % (lhs, rhs, goto_error))
class GeneralCallNode(CallNode): class GeneralCallNode(CallNode):
# General Python function call, including keyword, # General Python function call, including keyword,
......
...@@ -7,7 +7,6 @@ from Cython.Distutils import build_ext ...@@ -7,7 +7,6 @@ from Cython.Distutils import build_ext
ext_modules=[ ext_modules=[
Extension("primes", ["primes.pyx"]), Extension("primes", ["primes.pyx"]),
Extension("spam", ["spam.pyx"]), Extension("spam", ["spam.pyx"]),
# Extension("optargs", ["optargs.pyx"], language = "c++"),
] ]
for file in glob.glob("*.pyx"): for file in glob.glob("*.pyx"):
......
...@@ -2,6 +2,11 @@ __doc__ = u""" ...@@ -2,6 +2,11 @@ __doc__ = u"""
>>> call2() >>> call2()
>>> call3() >>> call3()
>>> call4() >>> call4()
>>> test_foo()
2
3
7
26
""" """
# the calls: # the calls:
...@@ -19,3 +24,13 @@ def call4(): ...@@ -19,3 +24,13 @@ def call4():
cdef b(a, b, c=1, d=2): cdef b(a, b, c=1, d=2):
pass pass
cdef int foo(int a, int b=1, int c=1):
return a+b*c
def test_foo():
print foo(1)
print foo(1, 2)
print foo(1, 2, 3)
print foo(1, foo(2, 3), foo(4))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment