Commit d5103c1f authored by Mark Florisson's avatar Mark Florisson

Support defaults tuple for fused functions

parent 4bb342ce
......@@ -615,7 +615,7 @@ class ExprNode(Node):
dst_type = dst_type.base_type
for signature in src_type.get_all_specific_function_types():
for signature in src_type.get_all_specialized_function_types():
if signature.same_as(dst_type):
src.type = signature
src.entry = src.type.entry
......@@ -2788,7 +2788,7 @@ class IndexNode(ExprNode):
"Index operation makes function only partially specific")
else:
# Fully specific, find the signature with the specialized entry
for signature in self.base.type.get_all_specific_function_types():
for signature in self.base.type.get_all_specialized_function_types():
if type.same_as(signature):
self.type = signature
......@@ -3683,7 +3683,7 @@ class SimpleCallNode(CallNode):
if overloaded_entry:
if self.function.type.is_fused:
functypes = self.function.type.get_all_specific_function_types()
functypes = self.function.type.get_all_specialized_function_types()
alternatives = [f.entry for f in functypes]
else:
alternatives = overloaded_entry.all_alternatives()
......@@ -5554,6 +5554,11 @@ class DictNode(ExprNode):
obj_conversion_errors = []
@classmethod
def from_pairs(cls, pos, pairs):
return cls(pos, key_value_pairs=[
DictItemNode(pos, key=k, value=v) for k, v in pairs])
def calculate_constant_result(self):
self.constant_result = dict([
item.constant_result for item in self.key_value_pairs])
......@@ -6102,13 +6107,20 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
is_temp = 1
specialized_cpdefs = None
is_specialization = False
def analyse_types(self, env):
if self.specialized_cpdefs:
self.binding = True
@classmethod
def from_defnode(cls, node, binding):
return cls(node.pos,
def_node=node,
pymethdef_cname=node.entry.pymethdef_cname,
binding=binding or node.specialized_cpdefs,
specialized_cpdefs=node.specialized_cpdefs,
code_object=CodeObjectNode(node))
def analyse_types(self, env):
if self.binding:
if self.specialized_cpdefs:
if self.specialized_cpdefs or self.is_specialization:
env.use_utility_code(fused_function_utility_code)
else:
env.use_utility_code(binding_cfunc_utility_code)
......@@ -6212,12 +6224,15 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
code.put_gotref(self.py_result())
def generate_cyfunction_code(self, code):
def_node = self.def_node
if self.specialized_cpdefs:
constructor = "__pyx_FusedFunction_NewEx"
def_node = self.specialized_cpdefs[0]
elif self.is_specialization:
constructor = "__pyx_FusedFunction_NewEx"
else:
constructor = "__Pyx_CyFunction_NewEx"
def_node = self.def_node
if self.code_object:
code_object_result = self.code_object.py_result()
......@@ -6280,64 +6295,6 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
code.putln('__Pyx_CyFunction_SetDefaultsGetter(%s, %s);' % (
self.result(), def_node.defaults_getter.entry.pyfunc_cname))
if self.specialized_cpdefs:
self.generate_fused_cpdef(code, code_object_result, flags)
def generate_fused_cpdef(self, code, code_object_result, flags):
"""
Generate binding function objects for all specialized cpdefs, and the
original fused one. The fused function gets a dict __signatures__
mapping the specialized signature to the specialized binding function.
In Python space, the specialized versions can be obtained by indexing
the fused function.
For unsubscripted dispatch, we also need to remember the positions of
the arguments with fused types.
"""
def goto_err(string):
string = "(%s)" % string
code.putln(code.error_goto_if_null(string % fmt_dict, self.pos))
# Set up an interpolation dict
fmt_dict = dict(
vars(Naming),
result=self.result(),
py_mod_name=self.get_py_mod_name(code),
self=self.self_result_code(),
code=code_object_result,
flags=flags,
func=code.funcstate.allocate_temp(py_object_type,
manage_ref=True),
signature=code.funcstate.allocate_temp(py_object_type,
manage_ref=True),
)
fmt_dict['sigdict'] = \
"((__pyx_FusedFunctionObject *) %(result)s)->__signatures__" % fmt_dict
# Initialize __signatures__
goto_err("%(sigdict)s = PyDict_New()")
# Now put all specialized cpdefs in __signatures__
for cpdef in self.specialized_cpdefs:
fmt_dict['signature_string'] = cpdef.specialized_signature_string
fmt_dict['pymethdef_cname'] = cpdef.entry.pymethdef_cname
goto_err('%(signature)s = PyUnicode_FromString('
'"%(signature_string)s")')
goto_err("%(func)s = __pyx_FusedFunction_NewEx("
"&%(pymethdef_cname)s, %(flags)s, %(self)s, %(py_mod_name)s, %(code)s)")
s = "PyDict_SetItem(%(sigdict)s, %(signature)s, %(func)s)"
code.put_error_if_neg(self.pos, s % fmt_dict)
code.putln("Py_DECREF(%(signature)s); %(signature)s = NULL;" % fmt_dict)
code.putln("Py_DECREF(%(func)s); %(func)s = NULL;" % fmt_dict)
code.funcstate.release_temp(fmt_dict['func'])
code.funcstate.release_temp(fmt_dict['signature'])
class InnerFunctionNode(PyCFunctionNode):
# Special PyCFunctionNode that depends on a closure class
......@@ -9259,11 +9216,18 @@ class ProxyNode(CoercionNode):
def __init__(self, arg):
super(ProxyNode, self).__init__(arg)
if hasattr(arg, 'type'):
self.type = arg.type
self.result_ctype = arg.result_ctype
if hasattr(arg, 'entry'):
self.entry = arg.entry
self._proxy_type()
def analyse_expressions(self, env):
self.arg.analyse_expressions(env)
self._proxy_type()
def _proxy_type(self):
if hasattr(self.arg, 'type'):
self.type = self.arg.type
self.result_ctype = self.arg.result_ctype
if hasattr(self.arg, 'entry'):
self.entry = self.arg.entry
def generate_result_code(self, code):
self.arg.generate_result_code(code)
......
......@@ -2058,7 +2058,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
for entry in pxd_env.cfunc_entries[:]:
if entry.type.is_fused:
# This call modifies the cfunc_entries in-place
entry.type.get_all_specific_function_types()
entry.type.get_all_specialized_function_types()
def generate_c_variable_import_code_for_module(self, module, env, code):
# Generate import code for all exported C functions in a cimported module.
......
......@@ -2237,8 +2237,22 @@ class FusedCFuncDefNode(StatListNode):
nodes [FuncDefNode] list of copies of node with different specific types
py_func DefNode the fused python function subscriptable from
Python space
__signatures__ A DictNode mapping signature specialization strings
to PyCFunction nodes
resulting_fused_function PyCFunction for the fused DefNode that delegates
to specializations
fused_func_assignment Assignment of the fused function to the function name
defaults_tuple TupleNode of defaults (letting PyCFunctionNode build
defaults would result in many different tuples)
specialized_pycfuncs List of synthesized pycfunction nodes for the
specializations
"""
__signatures__ = None
resulting_fused_function = None
fused_func_assignment = None
defaults_tuple = None
def __init__(self, node, env):
super(FusedCFuncDefNode, self).__init__(node.pos)
......@@ -2277,13 +2291,18 @@ class FusedCFuncDefNode(StatListNode):
# CFuncDefNodes in self.nodes
self.stats = self.nodes[:]
if self.py_func:
self.synthesize_defnodes()
self.stats.append(self.__signatures__)
def copy_def(self, env):
"""
Create a copy of the original def or lambda function for specialized
versions.
"""
fused_types = [arg.type for arg in self.node.args if arg.type.is_fused]
permutations = PyrexTypes.get_all_specific_permutations(fused_types)
fused_types = PyrexTypes.unique(
[arg.type for arg in self.node.args if arg.type.is_fused])
permutations = PyrexTypes.get_all_specialized_permutations(fused_types)
if self.node.entry in env.pyfunc_entries:
env.pyfunc_entries.remove(self.node.entry)
......@@ -2314,7 +2333,7 @@ class FusedCFuncDefNode(StatListNode):
Create a copy of the original c(p)def function for all specialized
versions.
"""
permutations = self.node.type.get_all_specific_permutations()
permutations = self.node.type.get_all_specialized_permutations()
# print 'Node %s has %d specializations:' % (self.node.entry.name,
# len(permutations))
# import pprint; pprint.pprint([d for cname, d in permutations])
......@@ -2386,7 +2405,6 @@ class FusedCFuncDefNode(StatListNode):
if arg.type.is_memoryviewslice:
MemoryView.validate_memslice_dtype(arg.pos, arg.type.dtype)
def create_new_local_scope(self, node, env, f2s):
"""
Create a new local scope for the copied node and append it to
......@@ -2617,26 +2635,120 @@ def __pyx_fused_cpdef(signatures, args, kwargs):
return py_func
def analyse_expressions(self, env):
"""
Analyse the expressions. Take care to only evaluate default arguments
once and clone the result for all specializations
"""
from ExprNodes import CloneNode, ProxyNode, TupleNode
if self.py_func:
self.__signatures__.analyse_expressions(env)
self.py_func.analyse_expressions(env)
self.resulting_fused_function.analyse_expressions(env)
self.fused_func_assignment.analyse_expressions(env)
self.defaults = defaults = []
for arg in self.node.args:
if arg.default:
arg.default.analyse_expressions(env)
defaults.append(ProxyNode(arg.default))
else:
defaults.append(None)
for node in self.stats:
node.analyse_expressions(env)
if isinstance(node, FuncDefNode):
for arg, default in zip(node.args, defaults):
if default is not None:
arg.default = CloneNode(default).coerce_to(arg.type, env)
if self.py_func:
args = [CloneNode(default) for default in defaults if default]
defaults_tuple = TupleNode(self.pos, args=args)
defaults_tuple.analyse_types(env, skip_children=True)
self.defaults_tuple = ProxyNode(defaults_tuple)
self.resulting_fused_function.arg.defaults_tuple = CloneNode(
self.defaults_tuple)
for pycfunc in self.specialized_pycfuncs:
pycfunc.defaults_tuple = CloneNode(self.defaults_tuple)
def synthesize_defnodes(self):
"""
Create the __signatures__ dict of PyCFunctionNode specializations.
"""
import ExprNodes, StringEncoding
if isinstance(self.nodes[0], CFuncDefNode):
nodes = [node.py_func for node in self.nodes]
else:
nodes = self.nodes
signatures = [
StringEncoding.EncodedString(node.specialized_signature_string)
for node in nodes]
keys = [ExprNodes.StringNode(node.pos, value=sig)
for node, sig in zip(nodes, signatures)]
values = [ExprNodes.PyCFunctionNode.from_defnode(node, True)
for node in nodes]
self.__signatures__ = ExprNodes.DictNode.from_pairs(self.pos,
zip(keys, values))
self.specialized_pycfuncs = values
for pycfuncnode in values:
pycfuncnode.is_specialization = True
def generate_function_definitions(self, env, code):
# Ensure the indexable fused function is generated first, so we can
# use its docstring
# self.stats.insert(0, self.stats.pop())
if self.py_func:
self.py_func.pymethdef_required = True
self.fused_func_assignment.generate_function_definitions(env, code)
for stat in self.stats:
# print stat.entry, stat.entry.used
if stat.entry.used:
if isinstance(stat, FuncDefNode) and stat.entry.used:
code.mark_pos(stat.pos)
stat.generate_function_definitions(env, code)
def generate_execution_code(self, code):
import ExprNodes
for default in self.defaults:
if default is not None:
default.generate_evaluation_code(code)
if self.py_func:
self.defaults_tuple.generate_evaluation_code(code)
for stat in self.stats:
if stat.entry.used:
code.mark_pos(stat.pos)
code.mark_pos(stat.pos)
if isinstance(stat, ExprNodes.ExprNode):
stat.generate_evaluation_code(code)
elif not isinstance(stat, FuncDefNode) or stat.entry.used:
stat.generate_execution_code(code)
if self.__signatures__:
self.resulting_fused_function.generate_evaluation_code(code)
code.putln(
"((__pyx_FusedFunctionObject *) %s)->__signatures__ = %s;" %
(self.resulting_fused_function.result(),
self.__signatures__.result()))
code.put_giveref(self.__signatures__.result())
self.fused_func_assignment.generate_execution_code(code)
# Dispose of results
self.resulting_fused_function.generate_disposal_code(code)
self.defaults_tuple.generate_disposal_code(code)
for default in self.defaults:
if default is not None:
default.generate_disposal_code(code)
def annotate(self, code):
for stat in self.stats:
if stat.entry.used:
stat.annotate(code)
stat.annotate(code)
class PyArgDeclNode(Node):
......@@ -3047,9 +3159,9 @@ class DefNode(FuncDefNode):
decorator.decorator.analyse_expressions(env)
def needs_assignment_synthesis(self, env, code=None):
if self.is_wrapper:
if self.is_wrapper or self.specialized_cpdefs:
return False
if self.specialized_cpdefs or self.is_staticmethod:
if self.is_staticmethod:
return True
if self.no_assignment_synthesis:
return False
......
......@@ -1493,9 +1493,13 @@ if VALUE is not None:
if node.py_func:
node.stats.insert(0, node.py_func)
self.visit(node.py_func)
if node.py_func.needs_assignment_synthesis(env):
node = [node, self._synthesize_assignment(node.py_func, env)]
node.py_func = self.visit(node.py_func)
pycfunc = ExprNodes.PyCFunctionNode.from_defnode(node.py_func,
True)
pycfunc = ExprNodes.ProxyNode(pycfunc.coerce_to_temp(env))
node.resulting_fused_function = pycfunc
node.fused_func_assignment = self._create_assignment(
node.py_func, ExprNodes.CloneNode(pycfunc), env)
else:
node.body.analyse_declarations(lenv)
......@@ -1538,29 +1542,26 @@ if VALUE is not None:
pymethdef_cname=node.entry.pymethdef_cname,
code_object=ExprNodes.CodeObjectNode(node))
else:
rhs = ExprNodes.PyCFunctionNode(
node.pos,
def_node=node,
pymethdef_cname=node.entry.pymethdef_cname,
binding=self.current_directives.get('binding'),
specialized_cpdefs=node.specialized_cpdefs,
code_object=ExprNodes.CodeObjectNode(node))
binding = self.current_directives.get('binding')
rhs = ExprNodes.PyCFunctionNode.from_defnode(node, binding)
if env.is_py_class_scope:
rhs.binding = True
node.is_cyfunction = rhs.binding
return self._create_assignment(node, rhs, env)
if node.decorators:
for decorator in node.decorators[::-1]:
def _create_assignment(self, def_node, rhs, env):
if def_node.decorators:
for decorator in def_node.decorators[::-1]:
rhs = ExprNodes.SimpleCallNode(
decorator.pos,
function = decorator.decorator,
args = [rhs])
assmt = Nodes.SingleAssignmentNode(
node.pos,
lhs=ExprNodes.NameNode(node.pos,name=node.name),
def_node.pos,
lhs=ExprNodes.NameNode(def_node.pos, name=def_node.name),
rhs=rhs)
assmt.analyse_declarations(env)
return assmt
......
......@@ -2501,7 +2501,7 @@ class CFuncType(CType):
# All but map_with_specific_entries should be called only on functions
# with fused types (and not on their corresponding specific versions).
def get_all_specific_permutations(self, fused_types=None):
def get_all_specialized_permutations(self, fused_types=None):
"""
Permute all the types. For every specific instance of a fused type, we
want all other specific instances of all other fused types.
......@@ -2515,9 +2515,9 @@ class CFuncType(CType):
if fused_types is None:
fused_types = self.get_fused_types()
return get_all_specific_permutations(fused_types)
return get_all_specialized_permutations(fused_types)
def get_all_specific_function_types(self):
def get_all_specialized_function_types(self):
"""
Get all the specific function types of this one.
"""
......@@ -2532,7 +2532,7 @@ class CFuncType(CType):
cfunc_entries.remove(self.entry)
result = []
permutations = self.get_all_specific_permutations()
permutations = self.get_all_specialized_permutations()
for cname, fused_to_specific in permutations:
new_func_type = self.entry.type.specialize(fused_to_specific)
......@@ -2589,7 +2589,20 @@ def get_fused_cname(fused_cname, orig_cname):
return StringEncoding.EncodedString('%s%s%s' % (Naming.fused_func_prefix,
fused_cname, orig_cname))
def get_all_specific_permutations(fused_types, id="", f2s=()):
def unique(somelist):
seen = set()
result = []
for obj in somelist:
if obj not in seen:
result.append(obj)
seen.add(obj)
return result
def get_all_specialized_permutations(fused_types):
return _get_all_specialized_permutations(unique(fused_types))
def _get_all_specialized_permutations(fused_types, id="", f2s=()):
fused_type, = fused_types[0].get_fused_types()
result = []
......@@ -2604,7 +2617,7 @@ def get_all_specific_permutations(fused_types, id="", f2s=()):
cname = str(newid)
if len(fused_types) > 1:
result.extend(get_all_specific_permutations(
result.extend(_get_all_specialized_permutations(
fused_types[1:], cname, f2s))
else:
result.append((cname, f2s))
......@@ -2622,7 +2635,7 @@ def get_specialized_types(type):
result = type.types
else:
result = []
for cname, f2s in get_all_specific_permutations(type.get_fused_types()):
for cname, f2s in get_all_specialized_permutations(type.get_fused_types()):
result.append(type.specialize(f2s))
return sorted(result)
......
......@@ -1926,10 +1926,10 @@ class CClassScope(ClassScope):
# If the class defined in a pxd, specific entries have not been added.
# Ensure now that the parent (base) scope has specific entries
# Iterate over a copy as get_all_specific_function_types() will mutate
# Iterate over a copy as get_all_specialized_function_types() will mutate
for base_entry in base_scope.cfunc_entries[:]:
if base_entry.type.is_fused:
base_entry.type.get_all_specific_function_types()
base_entry.type.get_all_specialized_function_types()
for base_entry in base_scope.cfunc_entries:
cname = base_entry.cname
......
......@@ -591,6 +591,9 @@ __pyx_FusedFunction_descr_get(PyObject *self, PyObject *obj, PyObject *type)
Py_XINCREF(type);
meth->type = type;
Py_XINCREF(func->func.defaults_tuple);
meth->func.defaults_tuple = func->func.defaults_tuple;
if (func->func.flags & __Pyx_CYFUNCTION_CLASSMETHOD)
obj = type;
......@@ -600,6 +603,15 @@ __pyx_FusedFunction_descr_get(PyObject *self, PyObject *obj, PyObject *type)
return (PyObject *) meth;
}
static PyObject *
_obj_to_str(PyObject *obj)
{
if (PyType_Check(obj))
return PyObject_GetAttrString(obj, "__name__");
else
return PyObject_Str(obj);
}
static PyObject *
__pyx_FusedFunction_getitem(__pyx_FusedFunctionObject *self, PyObject *idx)
{
......@@ -625,11 +637,7 @@ __pyx_FusedFunction_getitem(__pyx_FusedFunctionObject *self, PyObject *idx)
for (i = 0; i < n; i++) {
PyObject *item = PyTuple_GET_ITEM(idx, i);
if (PyType_Check(item))
string = PyObject_GetAttrString(item, "__name__");
else
string = PyObject_Str(item);
string = _obj_to_str(item);
if (!string || PyList_Append(list, string) < 0)
goto __pyx_err;
......@@ -644,7 +652,7 @@ __pyx_err:
Py_DECREF(list);
Py_XDECREF(sep);
} else {
signature = PyObject_Str(idx);
signature = _obj_to_str(idx);
}
if (!signature)
......@@ -653,14 +661,20 @@ __pyx_err:
unbound_result_func = PyObject_GetItem(self->__signatures__, signature);
if (unbound_result_func) {
__pyx_FusedFunctionObject *unbound = (__pyx_FusedFunctionObject *) unbound_result_func;
Py_CLEAR(unbound->func.func_classobj);
Py_XINCREF(self->func.func_classobj);
unbound->func.func_classobj = self->func.func_classobj;
result_func = __pyx_FusedFunction_descr_get(unbound_result_func,
self->self, self->type);
if (self->self || self->type) {
__pyx_FusedFunctionObject *unbound = (__pyx_FusedFunctionObject *) unbound_result_func;
/* Todo: move this to InitClassCell */
Py_CLEAR(unbound->func.func_classobj);
Py_XINCREF(self->func.func_classobj);
unbound->func.func_classobj = self->func.func_classobj;
result_func = __pyx_FusedFunction_descr_get(unbound_result_func,
self->self, self->type);
} else {
result_func = unbound_result_func;
Py_INCREF(result_func);
}
}
Py_DECREF(signature);
......
......@@ -2,6 +2,7 @@
# mode: run
# tag: cyfunction
cimport cython
import sys
def get_defaults(func):
......@@ -85,3 +86,24 @@ def test_defaults_nonliteral_func_call(f):
return a
return func
_counter2 = 1.0
def counter2():
global _counter2
_counter2 += 1.0
return _counter2
def test_defaults_fused(cython.floating arg1, cython.floating arg2 = counter2()):
"""
>>> test_defaults_fused(1.0)
1.0 2.0
>>> test_defaults_fused(1.0, 3.0)
1.0 3.0
>>> _counter2
2.0
>>> get_defaults(test_defaults_fused)
(2.0,)
>>> get_defaults(test_defaults_fused[float])
(2.0,)
"""
print arg1, arg2
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment