Commit 2f20794c authored by Mark Florisson's avatar Mark Florisson

Support fused cpdef functions with Python indexing

parent 1472e87b
...@@ -16,6 +16,13 @@ class CythonScope(ModuleScope): ...@@ -16,6 +16,13 @@ class CythonScope(ModuleScope):
defining = 1, defining = 1,
cname='<error>') cname='<error>')
for fused_type in (cy_integral_type, cy_floating_type, cy_numeric_type):
entry = self.declare_typedef(fused_type.name,
fused_type,
None,
cname='<error>')
entry.in_cinclude = True
def lookup_type(self, name): def lookup_type(self, name):
# This function should go away when types are all first-level objects. # This function should go away when types are all first-level objects.
type = parse_basic_type(name) type = parse_basic_type(name)
......
...@@ -1431,6 +1431,12 @@ class NameNode(AtomicExprNode): ...@@ -1431,6 +1431,12 @@ class NameNode(AtomicExprNode):
def analyse_target_types(self, env): def analyse_target_types(self, env):
self.analyse_entry(env) self.analyse_entry(env)
if (not self.is_lvalue() and self.entry.is_cfunction and
self.entry.fused_cfunction and self.entry.as_variable):
self.entry = self.entry.as_variable
self.type = self.entry.type
if not self.is_lvalue(): if not self.is_lvalue():
error(self.pos, "Assignment to non-lvalue '%s'" error(self.pos, "Assignment to non-lvalue '%s'"
% self.name) % self.name)
...@@ -2079,10 +2085,15 @@ class IndexNode(ExprNode): ...@@ -2079,10 +2085,15 @@ class IndexNode(ExprNode):
# indices is used on buffer access, index on non-buffer access. # indices is used on buffer access, index on non-buffer access.
# The former contains a clean list of index parameters, the # The former contains a clean list of index parameters, the
# latter whatever Python object is needed for index access. # latter whatever Python object is needed for index access.
#
# is_fused_index boolean Whether the index is used to specialize a
# c(p)def function
subexprs = ['base', 'index', 'indices'] subexprs = ['base', 'index', 'indices']
indices = None indices = None
is_fused_index = False
def __init__(self, pos, index, *args, **kw): def __init__(self, pos, index, *args, **kw):
ExprNode.__init__(self, pos, index=index, *args, **kw) ExprNode.__init__(self, pos, index=index, *args, **kw)
self._index = index self._index = index
...@@ -2335,6 +2346,8 @@ class IndexNode(ExprNode): ...@@ -2335,6 +2346,8 @@ class IndexNode(ExprNode):
""" """
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
self.is_fused_index = True
base_type = self.base.type base_type = self.base.type
specific_types = [] specific_types = []
positions = [] positions = []
...@@ -2345,10 +2358,23 @@ class IndexNode(ExprNode): ...@@ -2345,10 +2358,23 @@ class IndexNode(ExprNode):
elif isinstance(self.index, TupleNode): elif isinstance(self.index, TupleNode):
for arg in self.index.args: for arg in self.index.args:
positions.append(arg.pos) positions.append(arg.pos)
specific_types.append(arg.analyse_as_type(env)) specific_type = arg.analyse_as_type(env)
specific_types.append(specific_type)
else: else:
return error(self.pos, "Can only index fused functions with types") return error(self.pos, "Can only index fused functions with types")
if not Utils.all(specific_types):
self.index.analyse_types(env)
if not self.base.entry.as_variable:
error(self.pos, "cdef function must be indexed with types")
else:
# A cpdef function indexed with Python objects
self.entry = self.base.entry.as_variable
self.type = self.entry.type
return
fused_types = base_type.get_fused_types() fused_types = base_type.get_fused_types()
if len(specific_types) > len(fused_types): if len(specific_types) > len(fused_types):
return error(self.pos, "Too many types specified") return error(self.pos, "Too many types specified")
...@@ -5145,7 +5171,13 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin): ...@@ -5145,7 +5171,13 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
type = py_object_type type = py_object_type
is_temp = 1 is_temp = 1
specialized_cpdefs = None
fused_args_positions = None
def analyse_types(self, env): def analyse_types(self, env):
if self.specialized_cpdefs:
self.binding = True
if self.binding: if self.binding:
env.use_utility_code(binding_cfunc_utility_code) env.use_utility_code(binding_cfunc_utility_code)
...@@ -5180,6 +5212,76 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin): ...@@ -5180,6 +5212,76 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
code.error_goto_if_null(self.result(), self.pos))) code.error_goto_if_null(self.result(), self.pos)))
code.put_gotref(self.py_result()) code.put_gotref(self.py_result())
if self.specialized_cpdefs:
self.generate_fused_cpdef(code)
def generate_fused_cpdef(self, code):
"""
Generate binding function objects for all specialized cpdefs, and the
original fused one. The fused function gets a dict __signatures__
mapping the specialized signature to the specialized binding function.
In Python space, the specialized versions can be obtained by indexing
the fused function.
For unsubscripted dispatch, we also need to remember the positions of
the arguments with fused types.
"""
def goto_err(string):
string = "(%s)" % string
code.putln(code.error_goto_if_null(string % fmt_dict, self.pos))
# Set up an interpolation dict
fmt_dict = dict(
vars(Naming),
result=self.result(),
py_mod_name=self.get_py_mod_name(code),
self=self.self_result_code(),
func=code.funcstate.allocate_temp(py_object_type,
manage_ref=True),
signature=code.funcstate.allocate_temp(py_object_type,
manage_ref=True),
)
fmt_dict['sigdict'] = \
"((%(binding_cfunc)s_object *) %(result)s)->__signatures__" % fmt_dict
# Initialize __signatures__ and set __PYX_FUSED_ARGS_POSITIONS
goto_err("%(sigdict)s = PyDict_New()")
assert self.fused_args_positions
pos_str = ','.join(map(str, self.fused_args_positions))
string_const = code.globalstate.new_string_const(pos_str, pos_str)
fmt_dict["pos_const_cname"] = string_const.cname
goto_err("%(signature)s = PyUnicode_FromString(%(pos_const_cname)s)")
code.put_error_if_neg(self.pos,
'PyDict_SetItemString(%(sigdict)s, '
'"__PYX_FUSED_ARGS_POSITIONS", '
'%(signature)s)' % fmt_dict)
code.putln("Py_DECREF(%(signature)s); %(signature)s = NULL;" % fmt_dict)
# Now put all specialized cpdefs in __signatures__
for cpdef in self.specialized_cpdefs:
fmt_dict['signature_string'] = cpdef.specialized_signature_string
fmt_dict['pymethdef_cname'] = cpdef.entry.pymethdef_cname
goto_err('%(signature)s = PyUnicode_FromString('
'"%(signature_string)s")')
goto_err("%(func)s = %(binding_cfunc)s_NewEx("
"&%(pymethdef_cname)s, %(self)s, %(py_mod_name)s)")
s = "PyDict_SetItem(%(sigdict)s, %(signature)s, %(func)s)"
code.put_error_if_neg(self.pos, s % fmt_dict)
code.putln("Py_DECREF(%(signature)s); %(signature)s = NULL;" % fmt_dict)
code.putln("Py_DECREF(%(func)s); %(func)s = NULL;" % fmt_dict)
code.funcstate.release_temp(fmt_dict['func'])
code.funcstate.release_temp(fmt_dict['signature'])
class InnerFunctionNode(PyCFunctionNode): class InnerFunctionNode(PyCFunctionNode):
# Special PyCFunctionNode that depends on a closure class # Special PyCFunctionNode that depends on a closure class
# #
...@@ -8655,65 +8757,325 @@ proto=""" ...@@ -8655,65 +8757,325 @@ proto="""
(((x) < 0) & ((unsigned long)(x) == 0-(unsigned long)(x))) (((x) < 0) & ((unsigned long)(x) == 0-(unsigned long)(x)))
""") """)
binding_cfunc_utility_code = UtilityCode( binding_cfunc_utility_code = UtilityCode(
proto=""" proto="""
#define %(binding_cfunc)s_USED 1 #define %(binding_cfunc)s_USED 1
#include <structmember.h>
#define __PYX_O %(binding_cfunc)s_object
#define __PYX_T %(binding_cfunc)s_type
#define __PYX_TP %(binding_cfunc)s
#define __PYX_M(MEMB) %(binding_cfunc)s_##MEMB
typedef struct { typedef struct {
PyCFunctionObject func; PyCFunctionObject func;
} %(binding_cfunc)s_object; PyObject *__signatures__;
PyObject *type;
PyObject *__dict__;
} __PYX_O;
static PyTypeObject %(binding_cfunc)s_type; /* Binding PyCFunction Prototypes */
static PyTypeObject *%(binding_cfunc)s = NULL; static PyObject *__PYX_M(NewEx)(PyMethodDef *ml, PyObject *self, PyObject *module); /* proto */
static PyObject *%(binding_cfunc)s_NewEx(PyMethodDef *ml, PyObject *self, PyObject *module); /* proto */
#define %(binding_cfunc)s_New(ml, self) %(binding_cfunc)s_NewEx(ml, self, NULL) #define %(binding_cfunc)s_New(ml, self) %(binding_cfunc)s_NewEx(ml, self, NULL)
static int %(binding_cfunc)s_init(void); /* proto */ static int __PYX_M(init)(void);
static void __PYX_M(dealloc)(__PYX_O *m);
static int __PYX_M(traverse)(__PYX_O *m, visitproc visit, void *arg);
static PyObject *__PYX_M(descr_get)(PyObject *func, PyObject *obj, PyObject *type);
static PyObject *__PYX_M(getitem)(__PYX_O *m, PyObject *idx);
static PyObject *__PYX_M(call)(PyObject *func, PyObject *args, PyObject *kw);
static PyObject *__PYX_M(get__name__)(__PYX_O *func, void *closure);
static int __PYX_M(set__name__)(__PYX_O *func, PyObject *value, void *closure);
static PyGetSetDef __PYX_M(getsets)[] = {
{"__name__", (getter) __PYX_M(get__name__), (setter) __PYX_M(set__name__), NULL},
{NULL},
};
static PyMemberDef __PYX_M(members)[] = {
{"__signatures__",
T_OBJECT,
offsetof(__PYX_O, __signatures__),
PY_WRITE_RESTRICTED},
{"__dict__", T_OBJECT, offsetof(__PYX_O, __dict__), 0},
};
static PyMappingMethods __PYX_M(mapping_methods) = {
0, /*mp_length*/
(binaryfunc) __PYX_M(getitem), /*mp_subscript*/
0, /*mp_ass_subscript*/
};
static PyTypeObject __PYX_T = {
PyVarObject_HEAD_INIT(0, 0)
__Pyx_NAMESTR("cython_function_or_method"), /*tp_name*/
sizeof(__PYX_O), /*tp_basicsize*/
0, /*tp_itemsize*/
(destructor) __PYX_M(dealloc), /*tp_dealloc*/
0, /*tp_print*/
0, /*tp_getattr*/
0, /*tp_setattr*/
#if PY_MAJOR_VERSION < 3
0, /*tp_compare*/
#else
0, /*reserved*/
#endif
0, /*tp_repr*/
0, /*tp_as_number*/
0, /*tp_as_sequence*/
&__PYX_M(mapping_methods), /*tp_as_mapping*/
0, /*tp_hash*/
__PYX_M(call), /*tp_call*/
0, /*tp_str*/
0, /*tp_getattro*/
0, /*tp_setattro*/
0, /*tp_as_buffer*/
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags*/
0, /*tp_doc*/
(traverseproc) __PYX_M(traverse), /*tp_traverse*/
0, /*tp_clear*/
0, /*tp_richcompare*/
0, /*tp_weaklistoffset*/
0, /*tp_iter*/
0, /*tp_iternext*/
0, /*tp_methods*/
__PYX_M(members), /*tp_members*/
__PYX_M(getsets), /*tp_getset*/
&PyCFunction_Type, /*tp_base*/
0, /*tp_dict*/
__PYX_M(descr_get), /*tp_descr_get*/
0, /*tp_descr_set*/
offsetof(__PYX_O, __dict__), /*tp_dictoffset*/
0, /*tp_init*/
0, /*tp_alloc*/
0, /*tp_new*/
0, /*tp_free*/
0, /*tp_is_gc*/
0, /*tp_bases*/
0, /*tp_mro*/
0, /*tp_cache*/
0, /*tp_subclasses*/
0, /*tp_weaklist*/
0, /*tp_del*/
#if PY_VERSION_HEX >= 0x02060000
0, /*tp_version_tag*/
#endif
};
static PyTypeObject *__PYX_TP = NULL;
""" % Naming.__dict__, """ % Naming.__dict__,
impl=""" impl="""
static PyObject *__PYX_M(NewWithDict)(PyMethodDef *ml, PyObject *self, PyObject *module, PyObject *dict) {
static PyObject *%(binding_cfunc)s_NewEx(PyMethodDef *ml, PyObject *self, PyObject *module) { __PYX_O *op = PyObject_GC_New(__PYX_O, __PYX_TP);
%(binding_cfunc)s_object *op = PyObject_GC_New(%(binding_cfunc)s_object, %(binding_cfunc)s);
if (op == NULL) if (op == NULL)
return NULL; return NULL;
op->func.m_ml = ml; op->func.m_ml = ml;
Py_XINCREF(self); Py_XINCREF(self);
op->func.m_self = self; op->func.m_self = self;
Py_XINCREF(module); Py_XINCREF(module);
op->func.m_module = module; op->func.m_module = module;
Py_XINCREF(dict);
op->__dict__ = dict;
op->__signatures__ = NULL;
op->type = NULL;
PyObject_GC_Track(op); PyObject_GC_Track(op);
return (PyObject *)op; return (PyObject *)op;
} }
static void %(binding_cfunc)s_dealloc(%(binding_cfunc)s_object *m) { static PyObject *__PYX_M(NewEx)(PyMethodDef *ml, PyObject *self, PyObject *module) {
PyObject *dict = PyDict_New();
PyObject *result;
if (!dict)
return NULL;
result = __PYX_M(NewWithDict)(ml, self, module, dict);
Py_DECREF(dict);
return result;
}
static void __PYX_M(dealloc)(__PYX_O *m) {
PyObject_GC_UnTrack(m); PyObject_GC_UnTrack(m);
Py_XDECREF(m->func.m_self); Py_XDECREF(m->func.m_self);
Py_XDECREF(m->func.m_module); Py_XDECREF(m->func.m_module);
Py_XDECREF(m->__signatures__);
Py_XDECREF(m->__dict__);
Py_XDECREF(m->type);
PyObject_GC_Del(m); PyObject_GC_Del(m);
} }
static PyObject *%(binding_cfunc)s_descr_get(PyObject *func, PyObject *obj, PyObject *type) { static int __PYX_M(traverse)(__PYX_O *m, visitproc visit, void *arg) {
Py_VISIT(m->func.m_self);
Py_VISIT(m->func.m_module);
Py_VISIT(m->__signatures__);
Py_VISIT(m->__dict__);
return 0;
}
static PyObject *__PYX_M(get__name__)(__PYX_O *func, void *closure) {
PyObject *result = PyDict_GetItemString(func->__dict__, "__name__");
if (result) {
/* Borrowed reference! */
Py_INCREF(result);
return result;
}
return PyUnicode_FromString(func->func.m_ml->ml_name);
}
static int __PYX_M(set__name__)(__PYX_O *func, PyObject *value, void *closure) {
return PyDict_SetItemString(func->__dict__, "__name__", value);
}
/*
Note: PyMethod_New() will create a bound or unbound method that does not take
PyCFunctionObject into account, it will not accept an additional
'self' argument in the unbound case, or will take one less argument in the
bound method case.
*/
static PyObject *__PYX_M(descr_get)(PyObject *op, PyObject *obj, PyObject *type) {
__PYX_O *func = (__PYX_O *) op;
if (func->func.m_self) {
/* Do not allow rebinding */
Py_INCREF(op);
return op;
}
if (obj == Py_None) if (obj == Py_None)
obj = NULL; obj = NULL;
return PyMethod_New(func, obj, type);
if (1 || func->__signatures__) {
/* Fused bound or unbound method */
__PYX_O *meth = (__PYX_O *) __PYX_M(NewWithDict)(func->func.m_ml,
obj,
func->func.m_module,
func->__dict__);
meth->__signatures__ = func->__signatures__;
Py_XINCREF(meth->__signatures__);
meth->type = type;
Py_XINCREF(type);
return (PyObject *) meth;
} else {
PyObject *meth = PyDescr_NewMethod((PyTypeObject *) type, func->func.m_ml);
PyObject *self = obj;
if (self == NULL)
self = Py_None;
if (meth == NULL)
return NULL;
return PyObject_CallMethod(meth, "__get__", "OO", self, type);
}
} }
static int %(binding_cfunc)s_init(void) { static PyObject *__PYX_M(getitem)(__PYX_O *m, PyObject *idx) {
%(binding_cfunc)s_type = PyCFunction_Type; PyObject *signature = NULL;
%(binding_cfunc)s_type.tp_name = __Pyx_NAMESTR("cython_binding_builtin_function_or_method"); PyObject *unbound_result_func;
%(binding_cfunc)s_type.tp_dealloc = (destructor)%(binding_cfunc)s_dealloc; PyObject *result_func = NULL;
%(binding_cfunc)s_type.tp_descr_get = %(binding_cfunc)s_descr_get; PyObject *type = NULL;
if (PyType_Ready(&%(binding_cfunc)s_type) < 0) {
if (m->__signatures__ == NULL) {
PyErr_SetString(PyExc_TypeError, "Function is not fused");
return NULL;
}
if (!(signature = PyObject_Str(idx)))
return NULL;
unbound_result_func = PyObject_GetItem(m->__signatures__, signature);
if (unbound_result_func) {
if (m->func.m_self)
type = (PyObject *) m->func.m_self->ob_type;
result_func = __PYX_M(descr_get)(unbound_result_func, m->func.m_self, m->type);
}
Py_DECREF(signature);
Py_XDECREF(unbound_result_func);
return result_func;
}
static PyObject *__PYX_M(call)(PyObject *func, PyObject *args, PyObject *kw) {
__PYX_O *binding_func = (__PYX_O *) func;
PyObject *dtype = binding_func->type;
Py_ssize_t argc;
PyObject *new_func = NULL;
PyObject *result;
if (binding_func->__signatures__) {
PyObject *module = PyImport_ImportModule("cython");
if (!module)
return NULL;
new_func = PyObject_CallMethod(module, "_specialized_from_args", "OOO",
binding_func->__signatures__, args, kw);
Py_DECREF(module);
if (!new_func)
return NULL;
func = new_func;
}
if (dtype && !binding_func->func.m_self) {
/* Unbound method call, make sure that the first argument is acceptable
as 'self' */
PyObject *self;
argc = PyTuple_GET_SIZE(args);
if (argc < 1) {
PyErr_Format(PyExc_TypeError, "Need at least one argument, 0 given.");
return NULL;
}
self = PyTuple_GET_ITEM(args, 0);
if (!PyObject_IsInstance(self, dtype)) {
PyErr_Format(PyExc_TypeError,
"First argument should be of type %%s, got %%s.",
((PyTypeObject *) dtype)->tp_name,
self->ob_type->tp_name);
return NULL;
}
args = PyTuple_GetSlice(args, 1, argc);
if (args == NULL) {
return NULL;
}
func = new_func = PyCFunction_NewEx(binding_func->func.m_ml, self, dtype);
}
result = PyCFunction_Call(func, args, kw);
Py_XDECREF(new_func);
return result;
}
static int __PYX_M(init)(void) {
if (PyType_Ready(&__PYX_T) < 0) {
return -1; return -1;
} }
%(binding_cfunc)s = &%(binding_cfunc)s_type; __PYX_TP = &__PYX_T;
return 0; return 0;
} }
#undef __PYX_O
#undef __PYX_T
#undef __PYX_TP
#undef __PYX_M
""" % Naming.__dict__) """ % Naming.__dict__)
generator_utility_code = UtilityCode( generator_utility_code = UtilityCode(
proto=""" proto="""
static PyObject *__Pyx_Generator_Next(PyObject *self); static PyObject *__Pyx_Generator_Next(PyObject *self);
......
...@@ -1823,8 +1823,12 @@ class CFuncDefNode(FuncDefNode): ...@@ -1823,8 +1823,12 @@ class CFuncDefNode(FuncDefNode):
# An error will be produced in the cdef function # An error will be produced in the cdef function
self.overridable = False self.overridable = False
self.declare_cpdef_wrapper(env)
self.create_local_scope(env)
def declare_cpdef_wrapper(self, env):
if self.overridable: if self.overridable:
import ExprNodes name = self.entry.name
py_func_body = self.call_self_node(is_module_scope = env.is_module_scope) py_func_body = self.call_self_node(is_module_scope = env.is_module_scope)
self.py_func = DefNode(pos = self.pos, self.py_func = DefNode(pos = self.pos,
name = self.entry.name, name = self.entry.name,
...@@ -1842,7 +1846,6 @@ class CFuncDefNode(FuncDefNode): ...@@ -1842,7 +1846,6 @@ class CFuncDefNode(FuncDefNode):
if not env.is_module_scope or Options.lookup_module_cpdef: if not env.is_module_scope or Options.lookup_module_cpdef:
self.override = OverrideCheckNode(self.pos, py_func = self.py_func) self.override = OverrideCheckNode(self.pos, py_func = self.py_func)
self.body = StatListNode(self.pos, stats=[self.override, self.body]) self.body = StatListNode(self.pos, stats=[self.override, self.body])
self.create_local_scope(env)
def _validate_type_visibility(self, type, pos, env): def _validate_type_visibility(self, type, pos, env):
""" """
...@@ -2025,16 +2028,15 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2025,16 +2028,15 @@ class FusedCFuncDefNode(StatListNode):
Then when a function lookup occurs (to e.g. call it), the call can be Then when a function lookup occurs (to e.g. call it), the call can be
dispatched to the right function. dispatched to the right function.
node FuncDefNode the original function node FuncDefNode the original function
nodes [FuncDefNode] list of copies of node with different specific types nodes [FuncDefNode] list of copies of node with different specific types
py_func DefNode the original python function (in case of a cpdef)
""" """
child_attrs = ['nodes']
def __init__(self, node, env): def __init__(self, node, env):
super(FusedCFuncDefNode, self).__init__(node.pos) super(FusedCFuncDefNode, self).__init__(node.pos)
self.nodes = self.stats = [] self.nodes = []
self.node = node self.node = node
self.copy_cdefs(env) self.copy_cdefs(env)
...@@ -2055,6 +2057,10 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2055,6 +2057,10 @@ class FusedCFuncDefNode(StatListNode):
node.entry.fused_cfunction = self node.entry.fused_cfunction = self
self.stats = self.nodes[:]
if self.py_func:
self.stats.append(self.py_func)
def copy_cdefs(self, env): def copy_cdefs(self, env):
""" """
Gives a list of fused types and the parent environment, make copies Gives a list of fused types and the parent environment, make copies
...@@ -2067,7 +2073,16 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2067,7 +2073,16 @@ class FusedCFuncDefNode(StatListNode):
# len(permutations)) # len(permutations))
# import pprint; pprint.pprint([d for cname, d in permutations]) # import pprint; pprint.pprint([d for cname, d in permutations])
env.cfunc_entries.remove(self.node.entry) if self.node.entry in env.cfunc_entries:
env.cfunc_entries.remove(self.node.entry)
# Prevent copying of the python function
self.py_func = self.node.py_func
self.node.py_func = None
if self.py_func:
env.pyfunc_entries.remove(self.py_func.entry)
fused_types = self.node.type.get_fused_types()
for cname, fused_to_specific in permutations: for cname, fused_to_specific in permutations:
copied_node = copy.deepcopy(self.node) copied_node = copy.deepcopy(self.node)
...@@ -2104,6 +2119,25 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2104,6 +2119,25 @@ class FusedCFuncDefNode(StatListNode):
type.specialize_entry(entry, cname) type.specialize_entry(entry, cname)
env.cfunc_entries.append(entry) env.cfunc_entries.append(entry)
# If a cpdef, declare all specialized cpdefs
copied_node.declare_cpdef_wrapper(env)
if copied_node.py_func:
env.pyfunc_entries.remove(copied_node.py_func.entry)
type_strings = [str(fused_to_specific[fused_type])
for fused_type in fused_types]
if len(type_strings) == 1:
sigstring = type_strings[0]
else:
sigstring = '(%s)' % ', '.join(type_strings)
copied_node.py_func.specialized_signature_string = sigstring
copied_node.py_func.fused_py_func = self.py_func
e = copied_node.py_func.entry
e.pymethdef_cname = PyrexTypes.get_fused_cname(
cname, e.pymethdef_cname)
num_errors = Errors.num_errors num_errors = Errors.num_errors
transform = ParseTreeTransforms.ReplaceFusedTypeChecks( transform = ParseTreeTransforms.ReplaceFusedTypeChecks(
copied_node.local_scope) copied_node.local_scope)
...@@ -2112,6 +2146,18 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2112,6 +2146,18 @@ class FusedCFuncDefNode(StatListNode):
if Errors.num_errors > num_errors: if Errors.num_errors > num_errors:
break break
if self.py_func:
self.py_func.specialized_cpdefs = [n.py_func for n in self.nodes]
self.py_func.fused_args_positions = [
i for i, arg in enumerate(self.node.type.args)
if arg.is_fused]
from Cython.Compiler import TreeFragment
fragment = TreeFragment.TreeFragment(u"""
raise ValueError("Index the function to get a specialized version")
""", level='function')
self.py_func.body = fragment.substitute()
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
for stat in self.stats: for stat in self.stats:
# print stat.entry, stat.entry.used # print stat.entry, stat.entry.used
...@@ -2167,6 +2213,12 @@ class DefNode(FuncDefNode): ...@@ -2167,6 +2213,12 @@ class DefNode(FuncDefNode):
# when the def statement is inside a Python class definition. # when the def statement is inside a Python class definition.
# #
# assmt AssignmentNode Function construction/assignment # assmt AssignmentNode Function construction/assignment
#
# fused_py_func DefNode The original fused cpdef DefNode
# (in case this is a specialization)
# specialized_cpdefs [DefNode] list of specialized cpdef DefNodes
# fused_args_positions [int] list of the positions of the
# arguments with fused types
child_attrs = ["args", "star_arg", "starstar_arg", "body", "decorators"] child_attrs = ["args", "star_arg", "starstar_arg", "body", "decorators"]
...@@ -2186,6 +2238,10 @@ class DefNode(FuncDefNode): ...@@ -2186,6 +2238,10 @@ class DefNode(FuncDefNode):
starstar_arg = None starstar_arg = None
doc = None doc = None
fused_py_func = False
specialized_cpdefs = None
fused_args_positions = None
def __init__(self, pos, **kwds): def __init__(self, pos, **kwds):
FuncDefNode.__init__(self, pos, **kwds) FuncDefNode.__init__(self, pos, **kwds)
k = rk = r = 0 k = rk = r = 0
...@@ -2506,11 +2562,22 @@ class DefNode(FuncDefNode): ...@@ -2506,11 +2562,22 @@ class DefNode(FuncDefNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.local_scope.directives = env.directives self.local_scope.directives = env.directives
self.analyse_default_values(env) self.analyse_default_values(env)
if self.specialized_cpdefs:
for arg in self.args + self.local_scope.arg_entries:
arg.needs_conversion = False
arg.type = py_object_type
self.local_scope.entries.clear()
del self.local_scope.var_entries[:]
if self.needs_assignment_synthesis(env): if self.needs_assignment_synthesis(env):
# Shouldn't we be doing this at the module level too? # Shouldn't we be doing this at the module level too?
self.synthesize_assignment_node(env) self.synthesize_assignment_node(env)
def needs_assignment_synthesis(self, env, code=None): def needs_assignment_synthesis(self, env, code=None):
if self.specialized_cpdefs:
return True
if self.no_assignment_synthesis: if self.no_assignment_synthesis:
return False return False
# Should enable for module level as well, that will require more testing... # Should enable for module level as well, that will require more testing...
...@@ -2534,7 +2601,11 @@ class DefNode(FuncDefNode): ...@@ -2534,7 +2601,11 @@ class DefNode(FuncDefNode):
self.pos, pymethdef_cname = self.entry.pymethdef_cname) self.pos, pymethdef_cname = self.entry.pymethdef_cname)
else: else:
rhs = ExprNodes.PyCFunctionNode( rhs = ExprNodes.PyCFunctionNode(
self.pos, pymethdef_cname = self.entry.pymethdef_cname, binding = env.directives['binding']) self.pos,
pymethdef_cname = self.entry.pymethdef_cname,
binding = env.directives['binding'],
specialized_cpdefs = self.specialized_cpdefs,
fused_args_positions = self.fused_args_positions)
if env.is_py_class_scope: if env.is_py_class_scope:
if not self.is_staticmethod and not self.is_classmethod: if not self.is_staticmethod and not self.is_classmethod:
...@@ -2573,8 +2644,15 @@ class DefNode(FuncDefNode): ...@@ -2573,8 +2644,15 @@ class DefNode(FuncDefNode):
if mf: mf += " " if mf: mf += " "
header = "static %s%s(%s)" % (mf, dc, arg_code) header = "static %s%s(%s)" % (mf, dc, arg_code)
code.putln("%s; /*proto*/" % header) code.putln("%s; /*proto*/" % header)
if proto_only: if proto_only:
if self.fused_py_func:
# If we are the specialized version of the cpdef, we still
# want the prototype for the "fused cpdef", in case we're
# checking to see if our method was overridden in Python
self.fused_py_func.generate_function_header(code, with_pymethdef, proto_only=True)
return return
if (Options.docstrings and self.entry.doc and if (Options.docstrings and self.entry.doc and
not self.entry.scope.is_property_scope and not self.entry.scope.is_property_scope and
(not self.entry.is_special or self.entry.wrapperbase_cname)): (not self.entry.is_special or self.entry.wrapperbase_cname)):
...@@ -2588,7 +2666,7 @@ class DefNode(FuncDefNode): ...@@ -2588,7 +2666,7 @@ class DefNode(FuncDefNode):
if self.entry.is_special: if self.entry.is_special:
code.putln( code.putln(
"struct wrapperbase %s;" % self.entry.wrapperbase_cname) "struct wrapperbase %s;" % self.entry.wrapperbase_cname)
if with_pymethdef: if with_pymethdef or self.fused_py_func:
code.put( code.put(
"static PyMethodDef %s = " % "static PyMethodDef %s = " %
self.entry.pymethdef_cname) self.entry.pymethdef_cname)
......
...@@ -1318,20 +1318,21 @@ if VALUE is not None: ...@@ -1318,20 +1318,21 @@ if VALUE is not None:
# --------------------------------------- # ---------------------------------------
return property return property
class AnalyseExpressionsTransform(CythonTransform): class AnalyseExpressionsTransform(EnvTransform):
nested_index_node = False
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
self.env_stack = [node.scope]
node.scope.infer_types() node.scope.infer_types()
node.body.analyse_expressions(node.scope) node.body.analyse_expressions(node.scope)
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
self.env_stack.append(node.local_scope)
node.local_scope.infer_types() node.local_scope.infer_types()
node.body.analyse_expressions(node.local_scope) node.body.analyse_expressions(node.local_scope)
self.visitchildren(node) self.visitchildren(node)
self.env_stack.pop()
return node return node
def visit_ScopedExprNode(self, node): def visit_ScopedExprNode(self, node):
...@@ -1347,14 +1348,23 @@ class AnalyseExpressionsTransform(CythonTransform): ...@@ -1347,14 +1348,23 @@ class AnalyseExpressionsTransform(CythonTransform):
argument types with the Attribute- or NameNode referring to the argument types with the Attribute- or NameNode referring to the
function. We then need to copy over the specialization properties to function. We then need to copy over the specialization properties to
the attribute or name node. the attribute or name node.
Because the indexing might be a Python indexing operation on a fused
function, or (usually) a Cython indexing operation, we need to
re-analyse the types.
""" """
self.visit_Node(node) self.visit_Node(node)
type = node.type type = node.type
if type.is_cfunction and node.base.type.is_fused: if node.is_fused_index:
node.base.type = node.type if node.type is PyrexTypes.error_type:
node.base.entry = node.type.entry node.type = PyrexTypes.error_type
node = node.base else:
node.base.type = node.type
node.base.entry = getattr(node, 'entry', None) or node.type.entry
node = node.base
node.analyse_types(self.env_stack[-1])
return node return node
......
...@@ -666,15 +666,15 @@ class FusedType(PyrexType): ...@@ -666,15 +666,15 @@ class FusedType(PyrexType):
See http://wiki.cython.org/enhancements/fusedtypes See http://wiki.cython.org/enhancements/fusedtypes
types [CSimpleBaseTypeNode] is the list of types to be fused types [PyrexType] is the list of types to be fused
name str the name of the ctypedef name str the name of the ctypedef
""" """
is_fused = 1 is_fused = 1
name = None
def __init__(self, types): def __init__(self, types, name=None):
self.types = types self.types = types
self.name = name
def declaration_code(self, entity_code, for_display = 0, def declaration_code(self, entity_code, for_display = 0,
dll_linkage = None, pyrex = 0): dll_linkage = None, pyrex = 0):
...@@ -2079,7 +2079,8 @@ class CFuncType(CType): ...@@ -2079,7 +2079,8 @@ class CFuncType(CType):
if entry.is_cmethod: if entry.is_cmethod:
entry.cname = entry.name entry.cname = entry.name
if entry.is_inherited: if entry.is_inherited:
entry.cname = "%s.%s" % (Naming.obj_base_cname, entry.cname) entry.cname = StringEncoding.EncodedString(
"%s.%s" % (Naming.obj_base_cname, entry.cname))
else: else:
entry.cname = get_fused_cname(cname, entry.cname) entry.cname = get_fused_cname(cname, entry.cname)
...@@ -2092,7 +2093,8 @@ def get_fused_cname(fused_cname, orig_cname): ...@@ -2092,7 +2093,8 @@ def get_fused_cname(fused_cname, orig_cname):
Given the fused cname id and an original cname, return a specialized cname Given the fused cname id and an original cname, return a specialized cname
""" """
assert fused_cname and orig_cname assert fused_cname and orig_cname
return '%s%s%s' % (Naming.fused_func_prefix, fused_cname, orig_cname) return StringEncoding.EncodedString('%s%s%s' % (Naming.fused_func_prefix,
fused_cname, orig_cname))
def get_all_specific_permutations(fused_types, id="", f2s=()): def get_all_specific_permutations(fused_types, id="", f2s=()):
fused_type = fused_types[0] fused_type = fused_types[0]
...@@ -2631,6 +2633,17 @@ c_size_t_ptr_type = CPtrType(c_size_t_type) ...@@ -2631,6 +2633,17 @@ c_size_t_ptr_type = CPtrType(c_size_t_type)
c_py_buffer_type = CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer") c_py_buffer_type = CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer")
c_py_buffer_ptr_type = CPtrType(c_py_buffer_type) c_py_buffer_ptr_type = CPtrType(c_py_buffer_type)
# Not sure whether the unsigned versions and 'long long' should be in there
# long long requires C99 and might be slow, and would always get preferred
# when specialization happens through calling and not indexing
cy_integral_type = FusedType([c_int_type, c_long_type], name="integral")
# Omitting long double as it might be slow
cy_floating_type = FusedType([c_float_type, c_double_type], name="floating")
cy_numeric_type = FusedType([c_long_type,
c_double_type,
c_double_complex_type], name="numeric")
error_type = ErrorType() error_type = ErrorType()
unspecified_type = UnspecifiedType() unspecified_type = UnspecifiedType()
......
...@@ -135,6 +135,9 @@ class PointerType(CythonType): ...@@ -135,6 +135,9 @@ class PointerType(CythonType):
else: else:
return not self._items and not value._items return not self._items and not value._items
def __repr__(self):
return "%s *" % (self._basetype,)
class ArrayType(PointerType): class ArrayType(PointerType):
def __init__(self): def __init__(self):
...@@ -218,13 +221,17 @@ def union(**members): ...@@ -218,13 +221,17 @@ def union(**members):
class typedef(CythonType): class typedef(CythonType):
def __init__(self, type): def __init__(self, type, name=None):
self._basetype = type self._basetype = type
self.name = name
def __call__(self, *arg): def __call__(self, *arg):
value = cast(self._basetype, *arg) value = cast(self._basetype, *arg)
return value return value
def __repr__(self):
return self.name or str(self._basetype)
class _FusedType(CythonType): class _FusedType(CythonType):
def __call__(self, type, value): def __call__(self, type, value):
...@@ -235,6 +242,7 @@ def fused_type(*args): ...@@ -235,6 +242,7 @@ def fused_type(*args):
if not args: if not args:
raise TypeError("Expected at least one type as argument") raise TypeError("Expected at least one type as argument")
# Find the numeric type with biggest rank if all types are numeric
rank = -1 rank = -1
for type in args: for type in args:
if type not in (py_int, py_long, py_float, py_complex): if type not in (py_int, py_long, py_float, py_complex):
...@@ -251,13 +259,18 @@ def fused_type(*args): ...@@ -251,13 +259,18 @@ def fused_type(*args):
return _FusedType() return _FusedType()
py_int = int def _specialized_from_args(signatures, args, kwargs):
"Perhaps this should be implemented in a TreeFragment in Cython code"
raise Exception("yet to be implemented")
py_int = typedef(int, "int")
try: try:
py_long = long py_long = typedef(long, "long")
except NameError: # Py3 except NameError: # Py3
py_long = int py_long = typedef(int, "long")
py_float = float py_float = typedef(float, "float")
py_complex = complex py_complex = typedef(complex, "complex")
try: try:
...@@ -278,28 +291,39 @@ float_types = ['longdouble', 'double', 'float'] ...@@ -278,28 +291,39 @@ float_types = ['longdouble', 'double', 'float']
complex_types = ['longdoublecomplex', 'doublecomplex', 'floatcomplex', 'complex'] complex_types = ['longdoublecomplex', 'doublecomplex', 'floatcomplex', 'complex']
other_types = ['bint', 'void'] other_types = ['bint', 'void']
to_repr = {
'longlong': 'long long',
'longdouble': 'long double',
'longdoublecomplex': 'long double complex',
'doublecomplex': 'double complex',
'floatcomplex': 'float complex',
}.get
gs = globals() gs = globals()
for name in int_types: for name in int_types:
gs[name] = typedef(py_int) reprname = to_repr(name, name)
gs[name] = typedef(py_int, reprname)
if name != 'Py_UNICODE' and not name.endswith('size_t'): if name != 'Py_UNICODE' and not name.endswith('size_t'):
gs['u'+name] = typedef(py_int) gs['u'+name] = typedef(py_int, "unsigned " + reprname)
gs['s'+name] = typedef(py_int) gs['s'+name] = typedef(py_int, "signed " + reprname)
for name in float_types: for name in float_types:
gs[name] = typedef(py_float) gs[name] = typedef(py_float, to_repr(name, name))
for name in complex_types: for name in complex_types:
gs[name] = typedef(py_complex) gs[name] = typedef(py_complex, to_repr(name, name))
bint = typedef(bool) bint = typedef(bool, "bint")
void = typedef(int) void = typedef(int, "void")
for t in int_types + float_types + complex_types + other_types: for t in int_types + float_types + complex_types + other_types:
for i in range(1, 4): for i in range(1, 4):
gs["%s_%s" % ('p'*i, t)] = globals()[t]._pointer(i) gs["%s_%s" % ('p'*i, t)] = globals()[t]._pointer(i)
void = typedef(None) void = typedef(None, "void")
NULL = p_void(0) NULL = p_void(0)
type_ordering = [py_int, py_long, py_float, py_complex] integral = floating = numeric = _FusedType()
\ No newline at end of file
type_ordering = [py_int, py_long, py_float, py_complex]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment