Commit d5103c1f authored by Mark Florisson's avatar Mark Florisson

Support defaults tuple for fused functions

parent 4bb342ce
...@@ -615,7 +615,7 @@ class ExprNode(Node): ...@@ -615,7 +615,7 @@ class ExprNode(Node):
dst_type = dst_type.base_type dst_type = dst_type.base_type
for signature in src_type.get_all_specific_function_types(): for signature in src_type.get_all_specialized_function_types():
if signature.same_as(dst_type): if signature.same_as(dst_type):
src.type = signature src.type = signature
src.entry = src.type.entry src.entry = src.type.entry
...@@ -2788,7 +2788,7 @@ class IndexNode(ExprNode): ...@@ -2788,7 +2788,7 @@ class IndexNode(ExprNode):
"Index operation makes function only partially specific") "Index operation makes function only partially specific")
else: else:
# Fully specific, find the signature with the specialized entry # Fully specific, find the signature with the specialized entry
for signature in self.base.type.get_all_specific_function_types(): for signature in self.base.type.get_all_specialized_function_types():
if type.same_as(signature): if type.same_as(signature):
self.type = signature self.type = signature
...@@ -3683,7 +3683,7 @@ class SimpleCallNode(CallNode): ...@@ -3683,7 +3683,7 @@ class SimpleCallNode(CallNode):
if overloaded_entry: if overloaded_entry:
if self.function.type.is_fused: if self.function.type.is_fused:
functypes = self.function.type.get_all_specific_function_types() functypes = self.function.type.get_all_specialized_function_types()
alternatives = [f.entry for f in functypes] alternatives = [f.entry for f in functypes]
else: else:
alternatives = overloaded_entry.all_alternatives() alternatives = overloaded_entry.all_alternatives()
...@@ -5554,6 +5554,11 @@ class DictNode(ExprNode): ...@@ -5554,6 +5554,11 @@ class DictNode(ExprNode):
obj_conversion_errors = [] obj_conversion_errors = []
@classmethod
def from_pairs(cls, pos, pairs):
return cls(pos, key_value_pairs=[
DictItemNode(pos, key=k, value=v) for k, v in pairs])
def calculate_constant_result(self): def calculate_constant_result(self):
self.constant_result = dict([ self.constant_result = dict([
item.constant_result for item in self.key_value_pairs]) item.constant_result for item in self.key_value_pairs])
...@@ -6102,13 +6107,20 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin): ...@@ -6102,13 +6107,20 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
is_temp = 1 is_temp = 1
specialized_cpdefs = None specialized_cpdefs = None
is_specialization = False
def analyse_types(self, env): @classmethod
if self.specialized_cpdefs: def from_defnode(cls, node, binding):
self.binding = True return cls(node.pos,
def_node=node,
pymethdef_cname=node.entry.pymethdef_cname,
binding=binding or node.specialized_cpdefs,
specialized_cpdefs=node.specialized_cpdefs,
code_object=CodeObjectNode(node))
def analyse_types(self, env):
if self.binding: if self.binding:
if self.specialized_cpdefs: if self.specialized_cpdefs or self.is_specialization:
env.use_utility_code(fused_function_utility_code) env.use_utility_code(fused_function_utility_code)
else: else:
env.use_utility_code(binding_cfunc_utility_code) env.use_utility_code(binding_cfunc_utility_code)
...@@ -6212,12 +6224,15 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin): ...@@ -6212,12 +6224,15 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
code.put_gotref(self.py_result()) code.put_gotref(self.py_result())
def generate_cyfunction_code(self, code): def generate_cyfunction_code(self, code):
def_node = self.def_node
if self.specialized_cpdefs: if self.specialized_cpdefs:
constructor = "__pyx_FusedFunction_NewEx" constructor = "__pyx_FusedFunction_NewEx"
def_node = self.specialized_cpdefs[0] def_node = self.specialized_cpdefs[0]
elif self.is_specialization:
constructor = "__pyx_FusedFunction_NewEx"
else: else:
constructor = "__Pyx_CyFunction_NewEx" constructor = "__Pyx_CyFunction_NewEx"
def_node = self.def_node
if self.code_object: if self.code_object:
code_object_result = self.code_object.py_result() code_object_result = self.code_object.py_result()
...@@ -6280,64 +6295,6 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin): ...@@ -6280,64 +6295,6 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
code.putln('__Pyx_CyFunction_SetDefaultsGetter(%s, %s);' % ( code.putln('__Pyx_CyFunction_SetDefaultsGetter(%s, %s);' % (
self.result(), def_node.defaults_getter.entry.pyfunc_cname)) self.result(), def_node.defaults_getter.entry.pyfunc_cname))
if self.specialized_cpdefs:
self.generate_fused_cpdef(code, code_object_result, flags)
def generate_fused_cpdef(self, code, code_object_result, flags):
"""
Generate binding function objects for all specialized cpdefs, and the
original fused one. The fused function gets a dict __signatures__
mapping the specialized signature to the specialized binding function.
In Python space, the specialized versions can be obtained by indexing
the fused function.
For unsubscripted dispatch, we also need to remember the positions of
the arguments with fused types.
"""
def goto_err(string):
string = "(%s)" % string
code.putln(code.error_goto_if_null(string % fmt_dict, self.pos))
# Set up an interpolation dict
fmt_dict = dict(
vars(Naming),
result=self.result(),
py_mod_name=self.get_py_mod_name(code),
self=self.self_result_code(),
code=code_object_result,
flags=flags,
func=code.funcstate.allocate_temp(py_object_type,
manage_ref=True),
signature=code.funcstate.allocate_temp(py_object_type,
manage_ref=True),
)
fmt_dict['sigdict'] = \
"((__pyx_FusedFunctionObject *) %(result)s)->__signatures__" % fmt_dict
# Initialize __signatures__
goto_err("%(sigdict)s = PyDict_New()")
# Now put all specialized cpdefs in __signatures__
for cpdef in self.specialized_cpdefs:
fmt_dict['signature_string'] = cpdef.specialized_signature_string
fmt_dict['pymethdef_cname'] = cpdef.entry.pymethdef_cname
goto_err('%(signature)s = PyUnicode_FromString('
'"%(signature_string)s")')
goto_err("%(func)s = __pyx_FusedFunction_NewEx("
"&%(pymethdef_cname)s, %(flags)s, %(self)s, %(py_mod_name)s, %(code)s)")
s = "PyDict_SetItem(%(sigdict)s, %(signature)s, %(func)s)"
code.put_error_if_neg(self.pos, s % fmt_dict)
code.putln("Py_DECREF(%(signature)s); %(signature)s = NULL;" % fmt_dict)
code.putln("Py_DECREF(%(func)s); %(func)s = NULL;" % fmt_dict)
code.funcstate.release_temp(fmt_dict['func'])
code.funcstate.release_temp(fmt_dict['signature'])
class InnerFunctionNode(PyCFunctionNode): class InnerFunctionNode(PyCFunctionNode):
# Special PyCFunctionNode that depends on a closure class # Special PyCFunctionNode that depends on a closure class
...@@ -9259,11 +9216,18 @@ class ProxyNode(CoercionNode): ...@@ -9259,11 +9216,18 @@ class ProxyNode(CoercionNode):
def __init__(self, arg): def __init__(self, arg):
super(ProxyNode, self).__init__(arg) super(ProxyNode, self).__init__(arg)
if hasattr(arg, 'type'): self._proxy_type()
self.type = arg.type
self.result_ctype = arg.result_ctype def analyse_expressions(self, env):
if hasattr(arg, 'entry'): self.arg.analyse_expressions(env)
self.entry = arg.entry self._proxy_type()
def _proxy_type(self):
if hasattr(self.arg, 'type'):
self.type = self.arg.type
self.result_ctype = self.arg.result_ctype
if hasattr(self.arg, 'entry'):
self.entry = self.arg.entry
def generate_result_code(self, code): def generate_result_code(self, code):
self.arg.generate_result_code(code) self.arg.generate_result_code(code)
......
...@@ -2058,7 +2058,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -2058,7 +2058,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
for entry in pxd_env.cfunc_entries[:]: for entry in pxd_env.cfunc_entries[:]:
if entry.type.is_fused: if entry.type.is_fused:
# This call modifies the cfunc_entries in-place # This call modifies the cfunc_entries in-place
entry.type.get_all_specific_function_types() entry.type.get_all_specialized_function_types()
def generate_c_variable_import_code_for_module(self, module, env, code): def generate_c_variable_import_code_for_module(self, module, env, code):
# Generate import code for all exported C functions in a cimported module. # Generate import code for all exported C functions in a cimported module.
......
...@@ -2237,8 +2237,22 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2237,8 +2237,22 @@ class FusedCFuncDefNode(StatListNode):
nodes [FuncDefNode] list of copies of node with different specific types nodes [FuncDefNode] list of copies of node with different specific types
py_func DefNode the fused python function subscriptable from py_func DefNode the fused python function subscriptable from
Python space Python space
__signatures__ A DictNode mapping signature specialization strings
to PyCFunction nodes
resulting_fused_function PyCFunction for the fused DefNode that delegates
to specializations
fused_func_assignment Assignment of the fused function to the function name
defaults_tuple TupleNode of defaults (letting PyCFunctionNode build
defaults would result in many different tuples)
specialized_pycfuncs List of synthesized pycfunction nodes for the
specializations
""" """
__signatures__ = None
resulting_fused_function = None
fused_func_assignment = None
defaults_tuple = None
def __init__(self, node, env): def __init__(self, node, env):
super(FusedCFuncDefNode, self).__init__(node.pos) super(FusedCFuncDefNode, self).__init__(node.pos)
...@@ -2277,13 +2291,18 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2277,13 +2291,18 @@ class FusedCFuncDefNode(StatListNode):
# CFuncDefNodes in self.nodes # CFuncDefNodes in self.nodes
self.stats = self.nodes[:] self.stats = self.nodes[:]
if self.py_func:
self.synthesize_defnodes()
self.stats.append(self.__signatures__)
def copy_def(self, env): def copy_def(self, env):
""" """
Create a copy of the original def or lambda function for specialized Create a copy of the original def or lambda function for specialized
versions. versions.
""" """
fused_types = [arg.type for arg in self.node.args if arg.type.is_fused] fused_types = PyrexTypes.unique(
permutations = PyrexTypes.get_all_specific_permutations(fused_types) [arg.type for arg in self.node.args if arg.type.is_fused])
permutations = PyrexTypes.get_all_specialized_permutations(fused_types)
if self.node.entry in env.pyfunc_entries: if self.node.entry in env.pyfunc_entries:
env.pyfunc_entries.remove(self.node.entry) env.pyfunc_entries.remove(self.node.entry)
...@@ -2314,7 +2333,7 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2314,7 +2333,7 @@ class FusedCFuncDefNode(StatListNode):
Create a copy of the original c(p)def function for all specialized Create a copy of the original c(p)def function for all specialized
versions. versions.
""" """
permutations = self.node.type.get_all_specific_permutations() permutations = self.node.type.get_all_specialized_permutations()
# print 'Node %s has %d specializations:' % (self.node.entry.name, # print 'Node %s has %d specializations:' % (self.node.entry.name,
# len(permutations)) # len(permutations))
# import pprint; pprint.pprint([d for cname, d in permutations]) # import pprint; pprint.pprint([d for cname, d in permutations])
...@@ -2386,7 +2405,6 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2386,7 +2405,6 @@ class FusedCFuncDefNode(StatListNode):
if arg.type.is_memoryviewslice: if arg.type.is_memoryviewslice:
MemoryView.validate_memslice_dtype(arg.pos, arg.type.dtype) MemoryView.validate_memslice_dtype(arg.pos, arg.type.dtype)
def create_new_local_scope(self, node, env, f2s): def create_new_local_scope(self, node, env, f2s):
""" """
Create a new local scope for the copied node and append it to Create a new local scope for the copied node and append it to
...@@ -2617,25 +2635,119 @@ def __pyx_fused_cpdef(signatures, args, kwargs): ...@@ -2617,25 +2635,119 @@ def __pyx_fused_cpdef(signatures, args, kwargs):
return py_func return py_func
def analyse_expressions(self, env):
"""
Analyse the expressions. Take care to only evaluate default arguments
once and clone the result for all specializations
"""
from ExprNodes import CloneNode, ProxyNode, TupleNode
if self.py_func:
self.__signatures__.analyse_expressions(env)
self.py_func.analyse_expressions(env)
self.resulting_fused_function.analyse_expressions(env)
self.fused_func_assignment.analyse_expressions(env)
self.defaults = defaults = []
for arg in self.node.args:
if arg.default:
arg.default.analyse_expressions(env)
defaults.append(ProxyNode(arg.default))
else:
defaults.append(None)
for node in self.stats:
node.analyse_expressions(env)
if isinstance(node, FuncDefNode):
for arg, default in zip(node.args, defaults):
if default is not None:
arg.default = CloneNode(default).coerce_to(arg.type, env)
if self.py_func:
args = [CloneNode(default) for default in defaults if default]
defaults_tuple = TupleNode(self.pos, args=args)
defaults_tuple.analyse_types(env, skip_children=True)
self.defaults_tuple = ProxyNode(defaults_tuple)
self.resulting_fused_function.arg.defaults_tuple = CloneNode(
self.defaults_tuple)
for pycfunc in self.specialized_pycfuncs:
pycfunc.defaults_tuple = CloneNode(self.defaults_tuple)
def synthesize_defnodes(self):
"""
Create the __signatures__ dict of PyCFunctionNode specializations.
"""
import ExprNodes, StringEncoding
if isinstance(self.nodes[0], CFuncDefNode):
nodes = [node.py_func for node in self.nodes]
else:
nodes = self.nodes
signatures = [
StringEncoding.EncodedString(node.specialized_signature_string)
for node in nodes]
keys = [ExprNodes.StringNode(node.pos, value=sig)
for node, sig in zip(nodes, signatures)]
values = [ExprNodes.PyCFunctionNode.from_defnode(node, True)
for node in nodes]
self.__signatures__ = ExprNodes.DictNode.from_pairs(self.pos,
zip(keys, values))
self.specialized_pycfuncs = values
for pycfuncnode in values:
pycfuncnode.is_specialization = True
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
# Ensure the indexable fused function is generated first, so we can if self.py_func:
# use its docstring self.py_func.pymethdef_required = True
# self.stats.insert(0, self.stats.pop()) self.fused_func_assignment.generate_function_definitions(env, code)
for stat in self.stats: for stat in self.stats:
# print stat.entry, stat.entry.used if isinstance(stat, FuncDefNode) and stat.entry.used:
if stat.entry.used:
code.mark_pos(stat.pos) code.mark_pos(stat.pos)
stat.generate_function_definitions(env, code) stat.generate_function_definitions(env, code)
def generate_execution_code(self, code): def generate_execution_code(self, code):
import ExprNodes
for default in self.defaults:
if default is not None:
default.generate_evaluation_code(code)
if self.py_func:
self.defaults_tuple.generate_evaluation_code(code)
for stat in self.stats: for stat in self.stats:
if stat.entry.used:
code.mark_pos(stat.pos) code.mark_pos(stat.pos)
if isinstance(stat, ExprNodes.ExprNode):
stat.generate_evaluation_code(code)
elif not isinstance(stat, FuncDefNode) or stat.entry.used:
stat.generate_execution_code(code) stat.generate_execution_code(code)
if self.__signatures__:
self.resulting_fused_function.generate_evaluation_code(code)
code.putln(
"((__pyx_FusedFunctionObject *) %s)->__signatures__ = %s;" %
(self.resulting_fused_function.result(),
self.__signatures__.result()))
code.put_giveref(self.__signatures__.result())
self.fused_func_assignment.generate_execution_code(code)
# Dispose of results
self.resulting_fused_function.generate_disposal_code(code)
self.defaults_tuple.generate_disposal_code(code)
for default in self.defaults:
if default is not None:
default.generate_disposal_code(code)
def annotate(self, code): def annotate(self, code):
for stat in self.stats: for stat in self.stats:
if stat.entry.used:
stat.annotate(code) stat.annotate(code)
...@@ -3047,9 +3159,9 @@ class DefNode(FuncDefNode): ...@@ -3047,9 +3159,9 @@ class DefNode(FuncDefNode):
decorator.decorator.analyse_expressions(env) decorator.decorator.analyse_expressions(env)
def needs_assignment_synthesis(self, env, code=None): def needs_assignment_synthesis(self, env, code=None):
if self.is_wrapper: if self.is_wrapper or self.specialized_cpdefs:
return False return False
if self.specialized_cpdefs or self.is_staticmethod: if self.is_staticmethod:
return True return True
if self.no_assignment_synthesis: if self.no_assignment_synthesis:
return False return False
......
...@@ -1493,9 +1493,13 @@ if VALUE is not None: ...@@ -1493,9 +1493,13 @@ if VALUE is not None:
if node.py_func: if node.py_func:
node.stats.insert(0, node.py_func) node.stats.insert(0, node.py_func)
self.visit(node.py_func) node.py_func = self.visit(node.py_func)
if node.py_func.needs_assignment_synthesis(env): pycfunc = ExprNodes.PyCFunctionNode.from_defnode(node.py_func,
node = [node, self._synthesize_assignment(node.py_func, env)] True)
pycfunc = ExprNodes.ProxyNode(pycfunc.coerce_to_temp(env))
node.resulting_fused_function = pycfunc
node.fused_func_assignment = self._create_assignment(
node.py_func, ExprNodes.CloneNode(pycfunc), env)
else: else:
node.body.analyse_declarations(lenv) node.body.analyse_declarations(lenv)
...@@ -1538,29 +1542,26 @@ if VALUE is not None: ...@@ -1538,29 +1542,26 @@ if VALUE is not None:
pymethdef_cname=node.entry.pymethdef_cname, pymethdef_cname=node.entry.pymethdef_cname,
code_object=ExprNodes.CodeObjectNode(node)) code_object=ExprNodes.CodeObjectNode(node))
else: else:
rhs = ExprNodes.PyCFunctionNode( binding = self.current_directives.get('binding')
node.pos, rhs = ExprNodes.PyCFunctionNode.from_defnode(node, binding)
def_node=node,
pymethdef_cname=node.entry.pymethdef_cname,
binding=self.current_directives.get('binding'),
specialized_cpdefs=node.specialized_cpdefs,
code_object=ExprNodes.CodeObjectNode(node))
if env.is_py_class_scope: if env.is_py_class_scope:
rhs.binding = True rhs.binding = True
node.is_cyfunction = rhs.binding node.is_cyfunction = rhs.binding
return self._create_assignment(node, rhs, env)
if node.decorators: def _create_assignment(self, def_node, rhs, env):
for decorator in node.decorators[::-1]: if def_node.decorators:
for decorator in def_node.decorators[::-1]:
rhs = ExprNodes.SimpleCallNode( rhs = ExprNodes.SimpleCallNode(
decorator.pos, decorator.pos,
function = decorator.decorator, function = decorator.decorator,
args = [rhs]) args = [rhs])
assmt = Nodes.SingleAssignmentNode( assmt = Nodes.SingleAssignmentNode(
node.pos, def_node.pos,
lhs=ExprNodes.NameNode(node.pos,name=node.name), lhs=ExprNodes.NameNode(def_node.pos, name=def_node.name),
rhs=rhs) rhs=rhs)
assmt.analyse_declarations(env) assmt.analyse_declarations(env)
return assmt return assmt
......
...@@ -2501,7 +2501,7 @@ class CFuncType(CType): ...@@ -2501,7 +2501,7 @@ class CFuncType(CType):
# All but map_with_specific_entries should be called only on functions # All but map_with_specific_entries should be called only on functions
# with fused types (and not on their corresponding specific versions). # with fused types (and not on their corresponding specific versions).
def get_all_specific_permutations(self, fused_types=None): def get_all_specialized_permutations(self, fused_types=None):
""" """
Permute all the types. For every specific instance of a fused type, we Permute all the types. For every specific instance of a fused type, we
want all other specific instances of all other fused types. want all other specific instances of all other fused types.
...@@ -2515,9 +2515,9 @@ class CFuncType(CType): ...@@ -2515,9 +2515,9 @@ class CFuncType(CType):
if fused_types is None: if fused_types is None:
fused_types = self.get_fused_types() fused_types = self.get_fused_types()
return get_all_specific_permutations(fused_types) return get_all_specialized_permutations(fused_types)
def get_all_specific_function_types(self): def get_all_specialized_function_types(self):
""" """
Get all the specific function types of this one. Get all the specific function types of this one.
""" """
...@@ -2532,7 +2532,7 @@ class CFuncType(CType): ...@@ -2532,7 +2532,7 @@ class CFuncType(CType):
cfunc_entries.remove(self.entry) cfunc_entries.remove(self.entry)
result = [] result = []
permutations = self.get_all_specific_permutations() permutations = self.get_all_specialized_permutations()
for cname, fused_to_specific in permutations: for cname, fused_to_specific in permutations:
new_func_type = self.entry.type.specialize(fused_to_specific) new_func_type = self.entry.type.specialize(fused_to_specific)
...@@ -2589,7 +2589,20 @@ def get_fused_cname(fused_cname, orig_cname): ...@@ -2589,7 +2589,20 @@ def get_fused_cname(fused_cname, orig_cname):
return StringEncoding.EncodedString('%s%s%s' % (Naming.fused_func_prefix, return StringEncoding.EncodedString('%s%s%s' % (Naming.fused_func_prefix,
fused_cname, orig_cname)) fused_cname, orig_cname))
def get_all_specific_permutations(fused_types, id="", f2s=()): def unique(somelist):
seen = set()
result = []
for obj in somelist:
if obj not in seen:
result.append(obj)
seen.add(obj)
return result
def get_all_specialized_permutations(fused_types):
return _get_all_specialized_permutations(unique(fused_types))
def _get_all_specialized_permutations(fused_types, id="", f2s=()):
fused_type, = fused_types[0].get_fused_types() fused_type, = fused_types[0].get_fused_types()
result = [] result = []
...@@ -2604,7 +2617,7 @@ def get_all_specific_permutations(fused_types, id="", f2s=()): ...@@ -2604,7 +2617,7 @@ def get_all_specific_permutations(fused_types, id="", f2s=()):
cname = str(newid) cname = str(newid)
if len(fused_types) > 1: if len(fused_types) > 1:
result.extend(get_all_specific_permutations( result.extend(_get_all_specialized_permutations(
fused_types[1:], cname, f2s)) fused_types[1:], cname, f2s))
else: else:
result.append((cname, f2s)) result.append((cname, f2s))
...@@ -2622,7 +2635,7 @@ def get_specialized_types(type): ...@@ -2622,7 +2635,7 @@ def get_specialized_types(type):
result = type.types result = type.types
else: else:
result = [] result = []
for cname, f2s in get_all_specific_permutations(type.get_fused_types()): for cname, f2s in get_all_specialized_permutations(type.get_fused_types()):
result.append(type.specialize(f2s)) result.append(type.specialize(f2s))
return sorted(result) return sorted(result)
......
...@@ -1926,10 +1926,10 @@ class CClassScope(ClassScope): ...@@ -1926,10 +1926,10 @@ class CClassScope(ClassScope):
# If the class defined in a pxd, specific entries have not been added. # If the class defined in a pxd, specific entries have not been added.
# Ensure now that the parent (base) scope has specific entries # Ensure now that the parent (base) scope has specific entries
# Iterate over a copy as get_all_specific_function_types() will mutate # Iterate over a copy as get_all_specialized_function_types() will mutate
for base_entry in base_scope.cfunc_entries[:]: for base_entry in base_scope.cfunc_entries[:]:
if base_entry.type.is_fused: if base_entry.type.is_fused:
base_entry.type.get_all_specific_function_types() base_entry.type.get_all_specialized_function_types()
for base_entry in base_scope.cfunc_entries: for base_entry in base_scope.cfunc_entries:
cname = base_entry.cname cname = base_entry.cname
......
...@@ -591,6 +591,9 @@ __pyx_FusedFunction_descr_get(PyObject *self, PyObject *obj, PyObject *type) ...@@ -591,6 +591,9 @@ __pyx_FusedFunction_descr_get(PyObject *self, PyObject *obj, PyObject *type)
Py_XINCREF(type); Py_XINCREF(type);
meth->type = type; meth->type = type;
Py_XINCREF(func->func.defaults_tuple);
meth->func.defaults_tuple = func->func.defaults_tuple;
if (func->func.flags & __Pyx_CYFUNCTION_CLASSMETHOD) if (func->func.flags & __Pyx_CYFUNCTION_CLASSMETHOD)
obj = type; obj = type;
...@@ -600,6 +603,15 @@ __pyx_FusedFunction_descr_get(PyObject *self, PyObject *obj, PyObject *type) ...@@ -600,6 +603,15 @@ __pyx_FusedFunction_descr_get(PyObject *self, PyObject *obj, PyObject *type)
return (PyObject *) meth; return (PyObject *) meth;
} }
static PyObject *
_obj_to_str(PyObject *obj)
{
if (PyType_Check(obj))
return PyObject_GetAttrString(obj, "__name__");
else
return PyObject_Str(obj);
}
static PyObject * static PyObject *
__pyx_FusedFunction_getitem(__pyx_FusedFunctionObject *self, PyObject *idx) __pyx_FusedFunction_getitem(__pyx_FusedFunctionObject *self, PyObject *idx)
{ {
...@@ -625,11 +637,7 @@ __pyx_FusedFunction_getitem(__pyx_FusedFunctionObject *self, PyObject *idx) ...@@ -625,11 +637,7 @@ __pyx_FusedFunction_getitem(__pyx_FusedFunctionObject *self, PyObject *idx)
for (i = 0; i < n; i++) { for (i = 0; i < n; i++) {
PyObject *item = PyTuple_GET_ITEM(idx, i); PyObject *item = PyTuple_GET_ITEM(idx, i);
if (PyType_Check(item)) string = _obj_to_str(item);
string = PyObject_GetAttrString(item, "__name__");
else
string = PyObject_Str(item);
if (!string || PyList_Append(list, string) < 0) if (!string || PyList_Append(list, string) < 0)
goto __pyx_err; goto __pyx_err;
...@@ -644,7 +652,7 @@ __pyx_err: ...@@ -644,7 +652,7 @@ __pyx_err:
Py_DECREF(list); Py_DECREF(list);
Py_XDECREF(sep); Py_XDECREF(sep);
} else { } else {
signature = PyObject_Str(idx); signature = _obj_to_str(idx);
} }
if (!signature) if (!signature)
...@@ -653,14 +661,20 @@ __pyx_err: ...@@ -653,14 +661,20 @@ __pyx_err:
unbound_result_func = PyObject_GetItem(self->__signatures__, signature); unbound_result_func = PyObject_GetItem(self->__signatures__, signature);
if (unbound_result_func) { if (unbound_result_func) {
if (self->self || self->type) {
__pyx_FusedFunctionObject *unbound = (__pyx_FusedFunctionObject *) unbound_result_func; __pyx_FusedFunctionObject *unbound = (__pyx_FusedFunctionObject *) unbound_result_func;
/* Todo: move this to InitClassCell */
Py_CLEAR(unbound->func.func_classobj); Py_CLEAR(unbound->func.func_classobj);
Py_XINCREF(self->func.func_classobj); Py_XINCREF(self->func.func_classobj);
unbound->func.func_classobj = self->func.func_classobj; unbound->func.func_classobj = self->func.func_classobj;
result_func = __pyx_FusedFunction_descr_get(unbound_result_func, result_func = __pyx_FusedFunction_descr_get(unbound_result_func,
self->self, self->type); self->self, self->type);
} else {
result_func = unbound_result_func;
Py_INCREF(result_func);
}
} }
Py_DECREF(signature); Py_DECREF(signature);
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# mode: run # mode: run
# tag: cyfunction # tag: cyfunction
cimport cython
import sys import sys
def get_defaults(func): def get_defaults(func):
...@@ -85,3 +86,24 @@ def test_defaults_nonliteral_func_call(f): ...@@ -85,3 +86,24 @@ def test_defaults_nonliteral_func_call(f):
return a return a
return func return func
_counter2 = 1.0
def counter2():
global _counter2
_counter2 += 1.0
return _counter2
def test_defaults_fused(cython.floating arg1, cython.floating arg2 = counter2()):
"""
>>> test_defaults_fused(1.0)
1.0 2.0
>>> test_defaults_fused(1.0, 3.0)
1.0 3.0
>>> _counter2
2.0
>>> get_defaults(test_defaults_fused)
(2.0,)
>>> get_defaults(test_defaults_fused[float])
(2.0,)
"""
print arg1, arg2
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment