Commit 52981fec authored by Vitja Makarov's avatar Vitja Makarov

Support for dynamic default arguments, fix #674

parent 5bac3a3f
...@@ -5918,6 +5918,9 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin): ...@@ -5918,6 +5918,9 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
code_object = None code_object = None
binding = False binding = False
def_node = None def_node = None
defaults = None
defaults_struct = None
defaults_pyobjects = 0
type = py_object_type type = py_object_type
is_temp = 1 is_temp = 1
...@@ -5933,10 +5936,48 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin): ...@@ -5933,10 +5936,48 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
env.use_utility_code(fused_function_utility_code) env.use_utility_code(fused_function_utility_code)
else: else:
env.use_utility_code(binding_cfunc_utility_code) env.use_utility_code(binding_cfunc_utility_code)
self.analyse_default_args(env)
#TODO(craig,haoyu) This should be moved to a better place #TODO(craig,haoyu) This should be moved to a better place
self.set_mod_name(env) self.set_mod_name(env)
def analyse_default_args(self, env):
"""
Handle non-literal function's default arguments.
"""
nonliteral_objects = []
nonliteral_other = []
for arg in self.def_node.args:
if arg.default and not arg.default.is_literal:
arg.is_dynamic = True
if arg.type.is_pyobject:
nonliteral_objects.append(arg)
else:
nonliteral_other.append(arg)
if nonliteral_objects or nonliteral_objects:
module_scope = env.global_scope()
cname = module_scope.next_id(Naming.defaults_struct_prefix)
scope = Symtab.StructOrUnionScope(cname)
self.defaults = []
for arg in nonliteral_objects:
entry = scope.declare_var(arg.name, arg.type, None,
Naming.arg_prefix + arg.name,
allow_pyobject=True)
self.defaults.append((arg, entry))
for arg in nonliteral_other:
entry = scope.declare_var(arg.name, arg.type, None,
Naming.arg_prefix + arg.name,
allow_pyobject=False)
self.defaults.append((arg, entry))
entry = module_scope.declare_struct_or_union(
None, 'struct', scope, 1, None, cname=cname)
self.defaults_struct = scope
self.defaults_pyobjects = len(nonliteral_objects)
for arg, entry in self.defaults:
arg.default_value = '%s->%s' % (
Naming.dynamic_args_cname, entry.cname)
self.def_node.defaults_struct = self.defaults_struct.name
def may_be_none(self): def may_be_none(self):
return False return False
...@@ -6018,6 +6059,17 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin): ...@@ -6018,6 +6059,17 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
self.result())) self.result()))
code.put_giveref(self.py_result()) code.put_giveref(self.py_result())
if self.defaults:
code.putln(
'if (!__Pyx_CyFunction_InitDefaults(%s, sizeof(%s), %d)) %s' % (
self.result(), self.defaults_struct.name,
self.defaults_pyobjects, code.error_goto(self.pos)))
defaults = '__Pyx_CyFunction_Defaults(%s, %s)' % (
self.defaults_struct.name, self.result())
for arg, entry in self.defaults:
arg.generate_assignment_code(code, target='%s->%s' % (
defaults, entry.cname))
if self.specialized_cpdefs: if self.specialized_cpdefs:
self.generate_fused_cpdef(code, code_object_result, flags) self.generate_fused_cpdef(code, code_object_result, flags)
......
...@@ -49,6 +49,8 @@ closure_scope_prefix = pyrex_prefix + "scope_" ...@@ -49,6 +49,8 @@ closure_scope_prefix = pyrex_prefix + "scope_"
closure_class_prefix = pyrex_prefix + "scope_struct_" closure_class_prefix = pyrex_prefix + "scope_struct_"
lambda_func_prefix = pyrex_prefix + "lambda_" lambda_func_prefix = pyrex_prefix + "lambda_"
module_is_main = pyrex_prefix + "module_is_main_" module_is_main = pyrex_prefix + "module_is_main_"
defaults_struct_prefix = pyrex_prefix + "defaults"
dynamic_args_cname = pyrex_prefix + "dynamic_args"
args_cname = pyrex_prefix + "args" args_cname = pyrex_prefix + "args"
generator_cname = pyrex_prefix + "generator" generator_cname = pyrex_prefix + "generator"
......
...@@ -677,6 +677,7 @@ class CArgDeclNode(Node): ...@@ -677,6 +677,7 @@ class CArgDeclNode(Node):
# is_self_arg boolean Is the "self" arg of an extension type method # is_self_arg boolean Is the "self" arg of an extension type method
# is_type_arg boolean Is the "class" arg of an extension type classmethod # is_type_arg boolean Is the "class" arg of an extension type classmethod
# is_kw_only boolean Is a keyword-only argument # is_kw_only boolean Is a keyword-only argument
# is_dynamic boolean Non-literal arg stored inside CyFunction
child_attrs = ["base_type", "declarator", "default"] child_attrs = ["base_type", "declarator", "default"]
...@@ -690,6 +691,7 @@ class CArgDeclNode(Node): ...@@ -690,6 +691,7 @@ class CArgDeclNode(Node):
name_declarator = None name_declarator = None
default_value = None default_value = None
annotation = None annotation = None
is_dynamic = 0
def analyse(self, env, nonempty = 0, is_self_arg = False): def analyse(self, env, nonempty = 0, is_self_arg = False):
if is_self_arg: if is_self_arg:
...@@ -738,6 +740,21 @@ class CArgDeclNode(Node): ...@@ -738,6 +740,21 @@ class CArgDeclNode(Node):
if self.default: if self.default:
self.default.annotate(code) self.default.annotate(code)
def generate_assignment_code(self, code, target=None):
default = self.default
if default is None or default.is_literal:
return
if target is None:
target = self.calculate_default_value_code(code)
default.generate_evaluation_code(code)
default.make_owned_reference(code)
result = default.result_as(self.type)
code.putln("%s = %s;" % (target, result))
if self.type.is_pyobject:
code.put_giveref(default.result())
default.generate_post_assignment_code(code)
default.free_temps(code)
class CBaseTypeNode(Node): class CBaseTypeNode(Node):
# Abstract base class for C base type nodes. # Abstract base class for C base type nodes.
...@@ -1800,20 +1817,8 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1800,20 +1817,8 @@ class FuncDefNode(StatNode, BlockNode):
def generate_execution_code(self, code): def generate_execution_code(self, code):
# Evaluate and store argument default values # Evaluate and store argument default values
for arg in self.args: for arg in self.args:
default = arg.default if not arg.is_dynamic:
if default: arg.generate_assignment_code(code)
if not default.is_literal:
default.generate_evaluation_code(code)
default.make_owned_reference(code)
result = default.result_as(arg.type)
code.putln(
"%s = %s;" % (
arg.calculate_default_value_code(code),
result))
if arg.type.is_pyobject:
code.put_giveref(default.result())
default.generate_post_assignment_code(code)
default.free_temps(code)
# For Python class methods, create and store function object # For Python class methods, create and store function object
if self.assmt: if self.assmt:
self.assmt.generate_execution_code(code) self.assmt.generate_execution_code(code)
...@@ -2645,6 +2650,7 @@ class DefNode(FuncDefNode): ...@@ -2645,6 +2650,7 @@ class DefNode(FuncDefNode):
self_in_stararg = 0 self_in_stararg = 0
py_cfunc_node = None py_cfunc_node = None
requires_classobj = False requires_classobj = False
defaults_struct = None # Dynamic kwrds structure name
doc = None doc = None
fused_py_func = False fused_py_func = False
...@@ -3512,6 +3518,11 @@ class DefNode(FuncDefNode): ...@@ -3512,6 +3518,11 @@ class DefNode(FuncDefNode):
code.putln("PyObject* values[%d] = {%s};" % ( code.putln("PyObject* values[%d] = {%s};" % (
max_args, ','.join('0'*max_args))) max_args, ','.join('0'*max_args)))
if self.defaults_struct:
code.putln('%s *%s = __Pyx_CyFunction_Defaults(%s, %s);' % (
self.defaults_struct, Naming.dynamic_args_cname,
self.defaults_struct, Naming.self_cname))
# assign borrowed Python default values to the values array, # assign borrowed Python default values to the values array,
# so that they can be overwritten by received arguments below # so that they can be overwritten by received arguments below
for i, arg in enumerate(args): for i, arg in enumerate(args):
......
...@@ -13,6 +13,9 @@ ...@@ -13,6 +13,9 @@
#define __Pyx_CyFunction_GetClassObj(f) \ #define __Pyx_CyFunction_GetClassObj(f) \
(((__pyx_CyFunctionObject *) (f))->func_classobj) (((__pyx_CyFunctionObject *) (f))->func_classobj)
#define __Pyx_CyFunction_Defaults(type, f) \
((type *)(((__pyx_CyFunctionObject *) (f))->defaults))
typedef struct { typedef struct {
PyCFunctionObject func; PyCFunctionObject func;
...@@ -24,6 +27,10 @@ typedef struct { ...@@ -24,6 +27,10 @@ typedef struct {
PyObject *func_code; PyObject *func_code;
PyObject *func_closure; PyObject *func_closure;
PyObject *func_classobj; /* No-args super() class cell */ PyObject *func_classobj; /* No-args super() class cell */
/* Dynamic default args*/
void *defaults;
int defaults_pyobjects;
} __pyx_CyFunctionObject; } __pyx_CyFunctionObject;
static PyTypeObject *__pyx_CyFunctionType = 0; static PyTypeObject *__pyx_CyFunctionType = 0;
...@@ -36,6 +43,11 @@ static PyObject *__Pyx_CyFunction_New(PyTypeObject *, ...@@ -36,6 +43,11 @@ static PyObject *__Pyx_CyFunction_New(PyTypeObject *,
PyObject *self, PyObject *module, PyObject *self, PyObject *module,
PyObject* code); PyObject* code);
static CYTHON_INLINE void *__Pyx_CyFunction_InitDefaults(PyObject *m,
size_t size,
int pyobjects);
static int __Pyx_CyFunction_init(void); static int __Pyx_CyFunction_init(void);
//////////////////// CythonFunction //////////////////// //////////////////// CythonFunction ////////////////////
...@@ -246,6 +258,9 @@ static PyObject *__Pyx_CyFunction_New(PyTypeObject *type, PyMethodDef *ml, int f ...@@ -246,6 +258,9 @@ static PyObject *__Pyx_CyFunction_New(PyTypeObject *type, PyMethodDef *ml, int f
op->func_classobj = NULL; op->func_classobj = NULL;
Py_XINCREF(code); Py_XINCREF(code);
op->func_code = code; op->func_code = code;
/* Dynamic Default args */
op->defaults_pyobjects = 0;
op->defaults = NULL;
PyObject_GC_Track(op); PyObject_GC_Track(op);
return (PyObject *) op; return (PyObject *) op;
} }
...@@ -260,6 +275,18 @@ __Pyx_CyFunction_clear(__pyx_CyFunctionObject *m) ...@@ -260,6 +275,18 @@ __Pyx_CyFunction_clear(__pyx_CyFunctionObject *m)
Py_CLEAR(m->func_doc); Py_CLEAR(m->func_doc);
Py_CLEAR(m->func_code); Py_CLEAR(m->func_code);
Py_CLEAR(m->func_classobj); Py_CLEAR(m->func_classobj);
if (m->defaults) {
PyObject **pydefaults = __Pyx_CyFunction_Defaults(PyObject *, m);
int i;
for (i = 0; i < m->defaults_pyobjects; i++)
Py_XDECREF(pydefaults[i]);
PyMem_Free(m->defaults);
m->defaults = NULL;
}
return 0; return 0;
} }
...@@ -281,6 +308,15 @@ static int __Pyx_CyFunction_traverse(__pyx_CyFunctionObject *m, visitproc visit, ...@@ -281,6 +308,15 @@ static int __Pyx_CyFunction_traverse(__pyx_CyFunctionObject *m, visitproc visit,
Py_VISIT(m->func_doc); Py_VISIT(m->func_doc);
Py_VISIT(m->func_code); Py_VISIT(m->func_code);
Py_VISIT(m->func_classobj); Py_VISIT(m->func_classobj);
if (m->defaults) {
PyObject **pydefaults = __Pyx_CyFunction_Defaults(PyObject *, m);
int i;
for (i = 0; i < m->defaults_pyobjects; i++)
Py_VISIT(pydefaults[i]);
}
return 0; return 0;
} }
...@@ -384,6 +420,17 @@ static int __Pyx_CyFunction_init(void) ...@@ -384,6 +420,17 @@ static int __Pyx_CyFunction_init(void)
return 0; return 0;
} }
void *__Pyx_CyFunction_InitDefaults(PyObject *func, size_t size, int pyobjects)
{
__pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func;
m->defaults = PyMem_Malloc(size);
if (!m->defaults)
return PyErr_NoMemory();
memset(m->defaults, 0, sizeof(size));
m->defaults_pyobjects = pyobjects;
return m->defaults;
}
//////////////////// CyFunctionClassCell.proto //////////////////// //////////////////// CyFunctionClassCell.proto ////////////////////
static CYTHON_INLINE void __Pyx_CyFunction_InitClassCell(PyObject *cyfunctions, static CYTHON_INLINE void __Pyx_CyFunction_InitClassCell(PyObject *cyfunctions,
PyObject *classobj); PyObject *classobj);
......
...@@ -18,7 +18,6 @@ temp_sideeffects_T654 ...@@ -18,7 +18,6 @@ temp_sideeffects_T654
class_scope_T671 class_scope_T671
slice2_T636 slice2_T636
builtin_subtype_methods_T653 builtin_subtype_methods_T653
default_args_T674
# CPython regression tests that don't current work: # CPython regression tests that don't current work:
pyregr.test_threadsignals pyregr.test_threadsignals
......
# mode: run
# ticket: 674
cdef class Foo:
cdef str name
def __init__(self, name):
self.name = name
def __repr__(self):
return '<%s>' % self.name
def test_exttype_args(a, b, c):
"""
>>> f1 = test_exttype_args([1, 2, 3], 123, Foo('Foo'))
>>> f2 = test_exttype_args([0], 0, Foo('Bar'))
>>> f1()
([1, 2, 3], 123, <Foo>)
>>> f2()
([0], 0, <Bar>)
"""
def inner(a=a, int b=b, Foo c=c):
return a, b, c
return inner
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment