Commit e2bd21ab authored by Mark Florisson's avatar Mark Florisson

Runtime dispatch to specialized cpdef

parent 2f20794c
......@@ -1434,6 +1434,7 @@ class NameNode(AtomicExprNode):
if (not self.is_lvalue() and self.entry.is_cfunction and
self.entry.fused_cfunction and self.entry.as_variable):
# We need this for the fused 'def' TreeFragment
self.entry = self.entry.as_variable
self.type = self.entry.type
......@@ -2361,18 +2362,24 @@ class IndexNode(ExprNode):
specific_type = arg.analyse_as_type(env)
specific_types.append(specific_type)
else:
return error(self.pos, "Can only index fused functions with types")
specific_types = [False]
if not Utils.all(specific_types):
self.index.analyse_types(env)
if not self.base.entry.as_variable:
error(self.pos, "cdef function must be indexed with types")
error(self.pos, "Can only index fused functions with types")
else:
# A cpdef function indexed with Python objects
self.entry = self.base.entry.as_variable
self.type = self.entry.type
self.base.entry = self.entry = self.base.entry.as_variable
self.base.type = self.type = self.entry.type
self.base.is_temp = True
self.is_temp = True
self.entry.used = True
self.is_fused_index = False
return
fused_types = base_type.get_fused_types()
......@@ -2411,9 +2418,12 @@ class IndexNode(ExprNode):
# Pretend to be a normal attribute, for cdef extension
# methods
self.entry = signature.entry
self.is_attribute = self.base.is_attribute
self.is_attribute = True
self.obj = self.base.obj
self.entry.used = True
self.type.entry.used = True
self.base.type = signature
self.base.entry = signature.entry
break
else:
......@@ -3148,7 +3158,7 @@ class SimpleCallNode(CallNode):
elif hasattr(self.function, 'entry'):
overloaded_entry = self.function.entry
elif (isinstance(self.function, IndexNode) and
self.function.base.type.is_fused):
self.function.is_fused_index):
overloaded_entry = self.function.type.entry
else:
overloaded_entry = None
......@@ -5172,7 +5182,6 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
is_temp = 1
specialized_cpdefs = None
fused_args_positions = None
def analyse_types(self, env):
if self.specialized_cpdefs:
......@@ -5245,22 +5254,9 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
fmt_dict['sigdict'] = \
"((%(binding_cfunc)s_object *) %(result)s)->__signatures__" % fmt_dict
# Initialize __signatures__ and set __PYX_FUSED_ARGS_POSITIONS
# Initialize __signatures__
goto_err("%(sigdict)s = PyDict_New()")
assert self.fused_args_positions
pos_str = ','.join(map(str, self.fused_args_positions))
string_const = code.globalstate.new_string_const(pos_str, pos_str)
fmt_dict["pos_const_cname"] = string_const.cname
goto_err("%(signature)s = PyUnicode_FromString(%(pos_const_cname)s)")
code.put_error_if_neg(self.pos,
'PyDict_SetItemString(%(sigdict)s, '
'"__PYX_FUSED_ARGS_POSITIONS", '
'%(signature)s)' % fmt_dict)
code.putln("Py_DECREF(%(signature)s); %(signature)s = NULL;" % fmt_dict)
# Now put all specialized cpdefs in __signatures__
for cpdef in self.specialized_cpdefs:
fmt_dict['signature_string'] = cpdef.specialized_signature_string
......@@ -5938,8 +5934,8 @@ class TypeofNode(ExprNode):
def analyse_types(self, env):
self.operand.analyse_types(env)
self.literal = StringNode(
self.pos, value=StringEncoding.EncodedString(str(self.operand.type)))
value = StringEncoding.EncodedString(self.operand.type.typeof_name())
self.literal = StringNode(self.pos, value=value)
self.literal.analyse_types(env)
self.literal = self.literal.coerce_to_pyobject(env)
......@@ -8762,57 +8758,59 @@ proto="""
#define %(binding_cfunc)s_USED 1
#include <structmember.h>
#define __PYX_O %(binding_cfunc)s_object
#define __PYX_T %(binding_cfunc)s_type
#define __PYX_TP %(binding_cfunc)s
#define __PYX_M(MEMB) %(binding_cfunc)s_##MEMB
typedef struct {
PyCFunctionObject func;
PyObject *__signatures__;
PyObject *type;
PyObject *self;
PyObject *__dict__;
} __PYX_O;
} %(binding_cfunc)s_object;
/* Binding PyCFunction Prototypes */
static PyObject *__PYX_M(NewEx)(PyMethodDef *ml, PyObject *self, PyObject *module); /* proto */
static PyObject *%(binding_cfunc)s_NewEx(PyMethodDef *ml, PyObject *self, PyObject *module); /* proto */
#define %(binding_cfunc)s_New(ml, self) %(binding_cfunc)s_NewEx(ml, self, NULL)
static int __PYX_M(init)(void);
static void __PYX_M(dealloc)(__PYX_O *m);
static int __PYX_M(traverse)(__PYX_O *m, visitproc visit, void *arg);
static PyObject *__PYX_M(descr_get)(PyObject *func, PyObject *obj, PyObject *type);
static PyObject *__PYX_M(getitem)(__PYX_O *m, PyObject *idx);
static PyObject *__PYX_M(call)(PyObject *func, PyObject *args, PyObject *kw);
static PyObject *__PYX_M(get__name__)(__PYX_O *func, void *closure);
static int __PYX_M(set__name__)(__PYX_O *func, PyObject *value, void *closure);
static PyGetSetDef __PYX_M(getsets)[] = {
{"__name__", (getter) __PYX_M(get__name__), (setter) __PYX_M(set__name__), NULL},
static int %(binding_cfunc)s_init(void);
static void %(binding_cfunc)s_dealloc(%(binding_cfunc)s_object *m);
static int %(binding_cfunc)s_traverse(%(binding_cfunc)s_object *m, visitproc visit, void *arg);
static PyObject *%(binding_cfunc)s_descr_get(PyObject *func, PyObject *obj, PyObject *type);
static PyObject *%(binding_cfunc)s_getitem(%(binding_cfunc)s_object *m, PyObject *idx);
static PyObject *%(binding_cfunc)s_call(PyObject *func, PyObject *args, PyObject *kw);
static PyObject *%(binding_cfunc)s_get__name__(%(binding_cfunc)s_object *func, void *closure);
static int %(binding_cfunc)s_set__name__(%(binding_cfunc)s_object *func, PyObject *value, void *closure);
static PyGetSetDef %(binding_cfunc)s_getsets[] = {
{(char *)"__name__",
(getter) %(binding_cfunc)s_get__name__,
(setter) %(binding_cfunc)s_set__name__,
NULL},
{NULL},
};
static PyMemberDef __PYX_M(members)[] = {
{"__signatures__",
static PyMemberDef %(binding_cfunc)s_members[] = {
{(char *) "__signatures__",
T_OBJECT,
offsetof(%(binding_cfunc)s_object, __signatures__),
__Pyx_DOCSTR(0)},
{(char *) "__dict__",
T_OBJECT,
offsetof(__PYX_O, __signatures__),
PY_WRITE_RESTRICTED},
{"__dict__", T_OBJECT, offsetof(__PYX_O, __dict__), 0},
offsetof(%(binding_cfunc)s_object, __dict__),
__Pyx_DOCSTR(0)},
};
static PyMappingMethods __PYX_M(mapping_methods) = {
0, /*mp_length*/
(binaryfunc) __PYX_M(getitem), /*mp_subscript*/
0, /*mp_ass_subscript*/
static PyMappingMethods %(binding_cfunc)s_mapping_methods = {
0,
(binaryfunc) %(binding_cfunc)s_getitem,
0,
};
static PyTypeObject __PYX_T = {
static PyTypeObject __pyx_binding_PyCFunctionType_type = {
PyVarObject_HEAD_INIT(0, 0)
__Pyx_NAMESTR("cython_function_or_method"), /*tp_name*/
sizeof(__PYX_O), /*tp_basicsize*/
sizeof(%(binding_cfunc)s_object), /*tp_basicsize*/
0, /*tp_itemsize*/
(destructor) __PYX_M(dealloc), /*tp_dealloc*/
(destructor) %(binding_cfunc)s_dealloc, /*tp_dealloc*/
0, /*tp_print*/
0, /*tp_getattr*/
0, /*tp_setattr*/
......@@ -8824,29 +8822,29 @@ static PyTypeObject __PYX_T = {
0, /*tp_repr*/
0, /*tp_as_number*/
0, /*tp_as_sequence*/
&__PYX_M(mapping_methods), /*tp_as_mapping*/
&%(binding_cfunc)s_mapping_methods, /*tp_as_mapping*/
0, /*tp_hash*/
__PYX_M(call), /*tp_call*/
%(binding_cfunc)s_call, /*tp_call*/
0, /*tp_str*/
0, /*tp_getattro*/
0, /*tp_setattro*/
0, /*tp_as_buffer*/
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags*/
0, /*tp_doc*/
(traverseproc) __PYX_M(traverse), /*tp_traverse*/
(traverseproc) %(binding_cfunc)s_traverse, /*tp_traverse*/
0, /*tp_clear*/
0, /*tp_richcompare*/
0, /*tp_weaklistoffset*/
0, /*tp_iter*/
0, /*tp_iternext*/
0, /*tp_methods*/
__PYX_M(members), /*tp_members*/
__PYX_M(getsets), /*tp_getset*/
%(binding_cfunc)s_members, /*tp_members*/
%(binding_cfunc)s_getsets, /*tp_getset*/
&PyCFunction_Type, /*tp_base*/
0, /*tp_dict*/
__PYX_M(descr_get), /*tp_descr_get*/
%(binding_cfunc)s_descr_get, /*tp_descr_get*/
0, /*tp_descr_set*/
offsetof(__PYX_O, __dict__), /*tp_dictoffset*/
offsetof(%(binding_cfunc)s_object, __dict__), /*tp_dictoffset*/
0, /*tp_init*/
0, /*tp_alloc*/
0, /*tp_new*/
......@@ -8863,11 +8861,15 @@ static PyTypeObject __PYX_T = {
#endif
};
static PyTypeObject *__PYX_TP = NULL;
static PyTypeObject *%(binding_cfunc)s = NULL;
""" % Naming.__dict__,
impl="""
static PyObject *__PYX_M(NewWithDict)(PyMethodDef *ml, PyObject *self, PyObject *module, PyObject *dict) {
__PYX_O *op = PyObject_GC_New(__PYX_O, __PYX_TP);
static PyObject *
%(binding_cfunc)s_NewWithDict(PyMethodDef *ml, PyObject *self,
PyObject *module, PyObject *dict)
{
%(binding_cfunc)s_object *op = PyObject_GC_New(%(binding_cfunc)s_object,
%(binding_cfunc)s);
if (op == NULL)
return NULL;
op->func.m_ml = ml;
......@@ -8881,44 +8883,55 @@ static PyObject *__PYX_M(NewWithDict)(PyMethodDef *ml, PyObject *self, PyObject
Py_XINCREF(dict);
op->__dict__ = dict;
op->type = NULL;
op->self = NULL;
op->__signatures__ = NULL;
op->type = NULL;
PyObject_GC_Track(op);
return (PyObject *)op;
}
static PyObject *__PYX_M(NewEx)(PyMethodDef *ml, PyObject *self, PyObject *module) {
static PyObject *
%(binding_cfunc)s_NewEx(PyMethodDef *ml, PyObject *self, PyObject *module)
{
PyObject *dict = PyDict_New();
PyObject *result;
if (!dict)
return NULL;
result = __PYX_M(NewWithDict)(ml, self, module, dict);
result = %(binding_cfunc)s_NewWithDict(ml, self, module, dict);
Py_DECREF(dict);
return result;
}
static void __PYX_M(dealloc)(__PYX_O *m) {
static void %(binding_cfunc)s_dealloc(%(binding_cfunc)s_object *m) {
PyObject_GC_UnTrack(m);
Py_XDECREF(m->func.m_self);
Py_XDECREF(m->func.m_module);
Py_XDECREF(m->__signatures__);
Py_XDECREF(m->__dict__);
Py_XDECREF(m->self);
Py_XDECREF(m->type);
PyObject_GC_Del(m);
}
static int __PYX_M(traverse)(__PYX_O *m, visitproc visit, void *arg) {
static int
%(binding_cfunc)s_traverse(%(binding_cfunc)s_object *m, visitproc visit,
void *arg)
{
Py_VISIT(m->func.m_self);
Py_VISIT(m->func.m_module);
Py_VISIT(m->self);
Py_VISIT(m->type);
Py_VISIT(m->__signatures__);
Py_VISIT(m->__dict__);
return 0;
}
static PyObject *__PYX_M(get__name__)(__PYX_O *func, void *closure) {
static PyObject *
%(binding_cfunc)s_get__name__(%(binding_cfunc)s_object *func, void *closure)
{
PyObject *result = PyDict_GetItemString(func->__dict__, "__name__");
if (result) {
/* Borrowed reference! */
......@@ -8929,20 +8942,19 @@ static PyObject *__PYX_M(get__name__)(__PYX_O *func, void *closure) {
return PyUnicode_FromString(func->func.m_ml->ml_name);
}
static int __PYX_M(set__name__)(__PYX_O *func, PyObject *value, void *closure) {
static int
%(binding_cfunc)s_set__name__(%(binding_cfunc)s_object *func, PyObject *value,
void *closure)
{
return PyDict_SetItemString(func->__dict__, "__name__", value);
}
/*
Note: PyMethod_New() will create a bound or unbound method that does not take
PyCFunctionObject into account, it will not accept an additional
'self' argument in the unbound case, or will take one less argument in the
bound method case.
*/
static PyObject *__PYX_M(descr_get)(PyObject *op, PyObject *obj, PyObject *type) {
__PYX_O *func = (__PYX_O *) op;
static PyObject *
%(binding_cfunc)s_descr_get(PyObject *op, PyObject *obj, PyObject *type)
{
%(binding_cfunc)s_object *func = (%(binding_cfunc)s_object *) op;
if (func->func.m_self) {
if (func->self) {
/* Do not allow rebinding */
Py_INCREF(op);
return op;
......@@ -8951,128 +8963,167 @@ static PyObject *__PYX_M(descr_get)(PyObject *op, PyObject *obj, PyObject *type)
if (obj == Py_None)
obj = NULL;
if (1 || func->__signatures__) {
/* Fused bound or unbound method */
__PYX_O *meth = (__PYX_O *) __PYX_M(NewWithDict)(func->func.m_ml,
obj,
func->func.m_module,
func->__dict__);
meth->__signatures__ = func->__signatures__;
Py_XINCREF(meth->__signatures__);
%(binding_cfunc)s_object *meth = (%(binding_cfunc)s_object *) \
%(binding_cfunc)s_NewWithDict(func->func.m_ml,
func->func.m_self,
func->func.m_module,
func->__dict__);
meth->type = type;
Py_XINCREF(type);
Py_XINCREF(func->__signatures__);
meth->__signatures__ = func->__signatures__;
return (PyObject *) meth;
} else {
PyObject *meth = PyDescr_NewMethod((PyTypeObject *) type, func->func.m_ml);
PyObject *self = obj;
Py_XINCREF(type);
meth->type = type;
if (self == NULL)
self = Py_None;
Py_XINCREF(obj);
meth->self = obj;
if (meth == NULL)
return NULL;
return (PyObject *) meth;
return PyObject_CallMethod(meth, "__get__", "OO", self, type);
}
}
static PyObject *__PYX_M(getitem)(__PYX_O *m, PyObject *idx) {
static PyObject *
%(binding_cfunc)s_getitem(%(binding_cfunc)s_object *m, PyObject *idx)
{
PyObject *signature = NULL;
PyObject *unbound_result_func;
PyObject *result_func = NULL;
PyObject *type = NULL;
if (m->__signatures__ == NULL) {
PyErr_SetString(PyExc_TypeError, "Function is not fused");
return NULL;
}
if (!(signature = PyObject_Str(idx)))
return NULL;
if (PyTuple_Check(idx)) {
PyObject *list = PyList_New(0);
Py_ssize_t n = PyTuple_GET_SIZE(idx);
PyObject *string = NULL;
PyObject *sep = NULL;
int i;
unbound_result_func = PyObject_GetItem(m->__signatures__, signature);
if (!list)
return NULL;
for (i = 0; i < n; i++) {
PyObject *item = PyTuple_GET_ITEM(idx, i);
if (unbound_result_func) {
if (m->func.m_self)
type = (PyObject *) m->func.m_self->ob_type;
if (PyType_Check(item))
string = PyObject_GetAttrString(item, "__name__");
else
string = PyObject_Str(item);
result_func = __PYX_M(descr_get)(unbound_result_func, m->func.m_self, m->type);
if (!string || PyList_Append(list, string) < 0)
goto __pyx_err;
Py_DECREF(string);
}
sep = PyUnicode_FromString(", ");
if (sep)
signature = PyUnicode_Join(sep, list);
__pyx_err:
;
Py_DECREF(list);
Py_XDECREF(sep);
} else {
signature = PyObject_Str(idx);
}
if (!signature)
return NULL;
unbound_result_func = PyObject_GetItem(m->__signatures__, signature);
if (unbound_result_func)
result_func = %(binding_cfunc)s_descr_get(unbound_result_func,
m->self, m->type);
Py_DECREF(signature);
Py_XDECREF(unbound_result_func);
return result_func;
}
static PyObject *__PYX_M(call)(PyObject *func, PyObject *args, PyObject *kw) {
__PYX_O *binding_func = (__PYX_O *) func;
PyObject *dtype = binding_func->type;
Py_ssize_t argc;
/* Note: the 'self' from method binding is passed in in the args tuple,
whereas PyCFunctionObject's m_self is passed in as the first
argument to the C function. For extension methods we also need
to pass 'self' as 'm_self' and not as the first element of the
args tuple.
*/
static PyObject *
%(binding_cfunc)s_call(PyObject *func, PyObject *args, PyObject *kw)
{
%(binding_cfunc)s_object *binding_func = (%(binding_cfunc)s_object *) func;
Py_ssize_t argc = PyTuple_GET_SIZE(args);
PyObject *new_args = NULL;
PyObject *new_func = NULL;
PyObject *result;
if (binding_func->__signatures__) {
PyObject *module = PyImport_ImportModule("cython");
if (!module)
return NULL;
new_func = PyObject_CallMethod(module, "_specialized_from_args", "OOO",
binding_func->__signatures__, args, kw);
Py_DECREF(module);
if (!new_func)
PyObject *result = NULL;
PyObject *self = NULL;
if (binding_func->self) {
/* Bound method call, put 'self' in the args tuple */
Py_ssize_t i;
new_args = PyTuple_New(argc + 1);
if (!new_args)
return NULL;
func = new_func;
}
self = binding_func->self;
Py_INCREF(self);
PyTuple_SET_ITEM(new_args, 0, self);
if (dtype && !binding_func->func.m_self) {
/* Unbound method call, make sure that the first argument is acceptable
as 'self' */
PyObject *self;
for (i = 0; i < argc; i++) {
PyObject *item = PyTuple_GET_ITEM(args, i);
Py_INCREF(item);
PyTuple_SET_ITEM(new_args, i + 1, item);
}
argc = PyTuple_GET_SIZE(args);
args = new_args;
} else if (binding_func->type) {
/* Unbound method call */
if (argc < 1) {
PyErr_Format(PyExc_TypeError, "Need at least one argument, 0 given.");
return NULL;
}
self = PyTuple_GET_ITEM(args, 0);
if (!PyObject_IsInstance(self, dtype)) {
PyErr_Format(PyExc_TypeError,
"First argument should be of type %%s, got %%s.",
((PyTypeObject *) dtype)->tp_name,
self->ob_type->tp_name);
return NULL;
}
}
args = PyTuple_GetSlice(args, 1, argc);
if (args == NULL) {
return NULL;
}
if (self && !PyObject_IsInstance(self, binding_func->type)) {
PyErr_Format(PyExc_TypeError,
"First argument should be of type %%s, got %%s.",
((PyTypeObject *) binding_func->type)->tp_name,
self->ob_type->tp_name);
goto __pyx_err;
}
if (binding_func->__signatures__) {
/*
binaryfunc meth = (binaryfunc) binding_func->func.m_ml->ml_meth;
func = new_func = meth(binding_func->__signatures__, args);
*/
PyObject *tup = PyTuple_Pack(2, binding_func->__signatures__, args);
if (!tup)
goto __pyx_err;
func = new_func = PyCFunction_NewEx(binding_func->func.m_ml, self, dtype);
func = new_func = PyCFunction_Call(func, tup, NULL);
if (!new_func)
goto __pyx_err;
}
result = PyCFunction_Call(func, args, kw);
__pyx_err:
Py_XDECREF(new_args);
Py_XDECREF(new_func);
return result;
}
static int __PYX_M(init)(void) {
if (PyType_Ready(&__PYX_T) < 0) {
static int %(binding_cfunc)s_init(void) {
if (PyType_Ready(&%(binding_cfunc)s_type) < 0) {
return -1;
}
__PYX_TP = &__PYX_T;
%(binding_cfunc)s = &%(binding_cfunc)s_type;
return 0;
}
#undef __PYX_O
#undef __PYX_T
#undef __PYX_TP
#undef __PYX_M
""" % Naming.__dict__)
......
......@@ -1852,6 +1852,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("/*--- Function import code ---*/")
for module in imported_modules:
self.specialize_fused_types(module, env)
self.generate_c_function_import_code_for_module(module, env, code)
code.putln("/*--- Execution code ---*/")
......@@ -2059,11 +2060,23 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
if entry.defined_in_pxd:
self.generate_type_import_code(env, entry.type, entry.pos, code)
def specialize_fused_types(self, pxd_env, impl_env):
"""
If fused c(p)def functions are defined in an imported pxd, but not
used in this implementation file, we still have fused entries and
not specialized ones. This method replaces any fused entries with their
specialized ones.
"""
for entry in pxd_env.cfunc_entries[:]:
if entry.type.is_fused:
# This call modifies the cfunc_entries in-place
entry.type.get_all_specific_function_types()
def generate_c_function_import_code_for_module(self, module, env, code):
# Generate import code for all exported C functions in a cimported module.
entries = []
for entry in module.cfunc_entries:
if entry.defined_in_pxd:
if entry.defined_in_pxd and entry.used:
entries.append(entry)
if entries:
env.use_utility_code(import_module_utility_code)
......
......@@ -991,6 +991,7 @@ class CVarDefNode(StatNode):
# declarators [CDeclaratorNode]
# in_pxd boolean
# api boolean
# overridable boolean whether it is a cpdef
# decorators [cython.locals(...)] or None
# directive_locals { string : NameNode } locals defined by cython.locals(...)
......@@ -1008,6 +1009,8 @@ class CVarDefNode(StatNode):
self.dest_scope = dest_scope
base_type = self.base_type.analyse(env)
self.entry = None
# If the field is an external typedef, we cannot be sure about the type,
# so do conversion ourself rather than rely on the CPython mechanism (through
# a property; made in AnalyseDeclarationsTransform).
......@@ -1036,20 +1039,21 @@ class CVarDefNode(StatNode):
error(declarator.pos, "Missing name in declaration.")
return
if type.is_cfunction:
entry = dest_scope.declare_cfunction(name, type, declarator.pos,
self.entry = dest_scope.declare_cfunction(name, type, declarator.pos,
cname = cname, visibility = self.visibility, in_pxd = self.in_pxd,
api = self.api)
if entry is not None:
entry.directive_locals = copy.copy(self.directive_locals)
if self.entry is not None:
self.entry.is_overridable = self.overridable
self.entry.directive_locals = copy.copy(self.directive_locals)
else:
if self.directive_locals:
error(self.pos, "Decorators can only be followed by functions")
if self.in_pxd and self.visibility != 'extern':
error(self.pos,
"Only 'extern' C variable declaration allowed in .pxd file")
entry = dest_scope.declare_var(name, type, declarator.pos,
self.entry = dest_scope.declare_var(name, type, declarator.pos,
cname=cname, visibility=visibility, api=self.api, is_cdef=1)
entry.needs_property = need_property
self.entry.needs_property = need_property
class CStructOrUnionDefNode(StatNode):
......@@ -1841,6 +1845,7 @@ class CFuncDefNode(FuncDefNode):
self.py_func.is_module_scope = env.is_module_scope
self.py_func.analyse_declarations(env)
self.entry.as_variable = self.py_func.entry
self.entry.used = self.entry.as_variable.used = True
# Reset scope entry the above cfunction
env.entries[name] = self.entry
if not env.is_module_scope or Options.lookup_module_cpdef:
......@@ -2057,9 +2062,16 @@ class FusedCFuncDefNode(StatListNode):
node.entry.fused_cfunction = self
self.stats = self.nodes[:]
if self.py_func:
self.stats.append(self.py_func)
self.py_func.entry.fused_cfunction = self
for node in self.nodes:
node.py_func.fused_py_func = self.py_func
node.entry.as_variable = self.py_func.entry
# Copy the nodes as AnalyseDeclarationsTransform will append
# self.py_func to self.stats, as we only want specialized
# CFuncDefNodes in self.nodes
self.stats = self.nodes[:]
def copy_cdefs(self, env):
"""
......@@ -2077,10 +2089,10 @@ class FusedCFuncDefNode(StatListNode):
env.cfunc_entries.remove(self.node.entry)
# Prevent copying of the python function
self.py_func = self.node.py_func
orig_py_func = self.node.py_func
self.node.py_func = None
if self.py_func:
env.pyfunc_entries.remove(self.py_func.entry)
if orig_py_func:
env.pyfunc_entries.remove(orig_py_func.entry)
fused_types = self.node.type.get_fused_types()
......@@ -2108,7 +2120,7 @@ class FusedCFuncDefNode(StatListNode):
copied_node.local_scope.fused_to_specific = fused_to_specific
# This is copied from the original function, set it to false to
# stop recursivon
# stop recursion
copied_node.has_fused_arguments = False
self.nodes.append(copied_node)
......@@ -2123,16 +2135,18 @@ class FusedCFuncDefNode(StatListNode):
copied_node.declare_cpdef_wrapper(env)
if copied_node.py_func:
env.pyfunc_entries.remove(copied_node.py_func.entry)
# copied_node.py_func.self_in_stararg = True
type_strings = [
fused_to_specific[fused_type].typeof_name()
for fused_type in fused_types]
type_strings = [str(fused_to_specific[fused_type])
for fused_type in fused_types]
if len(type_strings) == 1:
sigstring = type_strings[0]
else:
sigstring = '(%s)' % ', '.join(type_strings)
sigstring = ', '.join(type_strings)
copied_node.py_func.specialized_signature_string = sigstring
copied_node.py_func.fused_py_func = self.py_func
e = copied_node.py_func.entry
e.pymethdef_cname = PyrexTypes.get_fused_cname(
......@@ -2146,22 +2160,129 @@ class FusedCFuncDefNode(StatListNode):
if Errors.num_errors > num_errors:
break
if self.py_func:
self.py_func.specialized_cpdefs = [n.py_func for n in self.nodes]
self.py_func.fused_args_positions = [
i for i, arg in enumerate(self.node.type.args)
if arg.is_fused]
if orig_py_func:
self.py_func = self.make_fused_cpdef(orig_py_func, env)
else:
self.py_func = orig_py_func
from Cython.Compiler import TreeFragment
fragment = TreeFragment.TreeFragment(u"""
raise ValueError("Index the function to get a specialized version")
""", level='function')
self.py_func.body = fragment.substitute()
def make_fused_cpdef(self, orig_py_func, env):
"""
This creates the function that is indexable from Python and does
runtime dispatch based on the argument types.
"""
from Cython.Compiler import TreeFragment
from Cython.Compiler import ParseTreeTransforms
# { (arg_pos, FusedType) : specialized_type }
seen_fused_types = cython.set()
# list of statements that do the instance checks
body_stmts = []
for i, arg_type in enumerate(self.node.type.args):
arg_type = arg_type.type
if arg_type.is_fused and arg_type not in seen_fused_types:
seen_fused_types.add(arg_type)
specialized_types = PyrexTypes.get_specific_types(arg_type)
# Prefer long over int, etc
specialized_types.sort()
seen_py_type_names = cython.set()
first_check = True
for specialized_type in specialized_types:
py_type_name = specialized_type.py_type_name()
if not py_type_name or py_type_name in seen_py_type_names:
continue
seen_py_type_names.add(py_type_name)
if first_check:
if_ = 'if'
first_check = False
else:
if_ = 'elif'
tup = (if_, i, py_type_name, len(seen_fused_types) - 1,
specialized_type.typeof_name())
body_stmts.append(
" %s isinstance(args[%d], %s): "
"dest_sig[%d] = '%s'" % tup)
fmt_dict = {
'body': '\n'.join(body_stmts),
'nargs': len(self.node.type.args),
'name': orig_py_func.entry.name,
}
fragment = TreeFragment.TreeFragment(u"""
def __pyx_fused_cpdef(signatures, args):
if len(args) < %(nargs)d:
raise TypeError("Invalid number of arguments, expected %(nargs)d, "
"got %%d" %% len(args))
import sys
if sys.version_info >= (3, 0):
long = int
unicode = str
else:
bytes = str
dest_sig = [None] * len(args)
# instance check body
%(body)s
candidates = []
for sig in signatures:
match_found = True
for src_type, dst_type in zip(sig.strip('()').split(', '), dest_sig):
if dst_type is not None and match_found:
match_found = src_type == dst_type
if match_found:
candidates.append(sig)
if not candidates:
raise TypeError("No matching signature found")
elif len(candidates) > 1:
raise TypeError("Function call with ambiguous argument types")
else:
return signatures[candidates[0]]
""" % fmt_dict, level='module')
# analyse the declarations of our fragment ...
py_func, = fragment.substitute(pos=self.node.pos).stats
# Analyse the function object ...
py_func.analyse_declarations(env)
# ... and its body
py_func.scope = env
ParseTreeTransforms.AnalyseDeclarationsTransform(None)(py_func)
e, orig_e = py_func.entry, orig_py_func.entry
# Update the new entry ...
py_func.name = e.name = orig_e.name
e.cname, e.func_cname = orig_e.cname, orig_e.func_cname
e.pymethdef_cname = orig_e.pymethdef_cname
# e.signature = TypeSlots.binaryfunc
# ... and the symbol table
del env.entries['__pyx_fused_cpdef']
env.entries[e.name].as_variable = e
env.pyfunc_entries.append(e)
py_func.specialized_cpdefs = [n.py_func for n in self.nodes]
return py_func
def generate_function_definitions(self, env, code):
for stat in self.stats:
# print stat.entry, stat.entry.used
if stat.entry.used:
code.mark_pos(stat.pos)
stat.generate_function_definitions(env, code)
def generate_execution_code(self, code):
......@@ -2217,8 +2338,6 @@ class DefNode(FuncDefNode):
# fused_py_func DefNode The original fused cpdef DefNode
# (in case this is a specialization)
# specialized_cpdefs [DefNode] list of specialized cpdef DefNodes
# fused_args_positions [int] list of the positions of the
# arguments with fused types
child_attrs = ["args", "star_arg", "starstar_arg", "body", "decorators"]
......@@ -2240,7 +2359,6 @@ class DefNode(FuncDefNode):
fused_py_func = False
specialized_cpdefs = None
fused_args_positions = None
def __init__(self, pos, **kwds):
FuncDefNode.__init__(self, pos, **kwds)
......@@ -2563,14 +2681,6 @@ class DefNode(FuncDefNode):
self.local_scope.directives = env.directives
self.analyse_default_values(env)
if self.specialized_cpdefs:
for arg in self.args + self.local_scope.arg_entries:
arg.needs_conversion = False
arg.type = py_object_type
self.local_scope.entries.clear()
del self.local_scope.var_entries[:]
if self.needs_assignment_synthesis(env):
# Shouldn't we be doing this at the module level too?
self.synthesize_assignment_node(env)
......@@ -2604,8 +2714,7 @@ class DefNode(FuncDefNode):
self.pos,
pymethdef_cname = self.entry.pymethdef_cname,
binding = env.directives['binding'],
specialized_cpdefs = self.specialized_cpdefs,
fused_args_positions = self.fused_args_positions)
specialized_cpdefs = self.specialized_cpdefs)
if env.is_py_class_scope:
if not self.is_staticmethod and not self.is_classmethod:
......
......@@ -1144,6 +1144,8 @@ if VALUE is not None:
if node.has_fused_arguments:
node = Nodes.FusedCFuncDefNode(node, self.env_stack[-1])
self.visitchildren(node)
if node.py_func:
node.stats.append(node.py_func)
else:
node.body.analyse_declarations(lenv)
self.env_stack.append(lenv)
......@@ -1354,17 +1356,9 @@ class AnalyseExpressionsTransform(EnvTransform):
re-analyse the types.
"""
self.visit_Node(node)
type = node.type
if node.is_fused_index:
if node.type is PyrexTypes.error_type:
node.type = PyrexTypes.error_type
else:
node.base.type = node.type
node.base.entry = getattr(node, 'entry', None) or node.type.entry
node = node.base
node.analyse_types(self.env_stack[-1])
if node.is_fused_index and node.type is not PyrexTypes.error_type:
node = node.base
return node
......
......@@ -62,6 +62,29 @@ class BaseType(object):
is_fused = property(get_fused_types, doc="Whether this type or any of its "
"subtypes is a fused type")
def __lt__(self, other):
"""
For sorting. The sorting order should correspond to the preference of
conversion from Python types.
"""
return NotImplemented
def py_type_name(self):
"""
Return the name of the Python type that can coerce to this type.
"""
def typeof_name(self):
"""
Return the string with which fused python functions can be indexed.
"""
if self.is_builtin_type or self.py_type_name() == 'object':
index_name = self.py_type_name()
else:
index_name = str(self)
return index_name
class PyrexType(BaseType):
#
# Base class for all Pyrex types.
......@@ -334,6 +357,8 @@ class CTypedefType(BaseType):
def __getattr__(self, name):
return getattr(self.typedef_base_type, name)
def py_type_name(self):
return self.typedef_base_type.py_type_name()
class BufferType(BaseType):
#
......@@ -418,6 +443,17 @@ class PyObjectType(PyrexType):
else:
return cname
def py_type_name(self):
return "object"
def __lt__(self, other):
"""
Make sure we sort highest, as instance checking on py_type_name
('object') is always true
"""
return False
class BuiltinObjectType(PyObjectType):
# objstruct_cname string Name of PyObject struct
......@@ -514,6 +550,10 @@ class BuiltinObjectType(PyObjectType):
to_object_struct and self.objstruct_cname or "PyObject", # self.objstruct_cname may be None
expr_code)
def py_type_name(self):
return self.name
class PyExtensionType(PyObjectType):
#
......@@ -621,6 +661,12 @@ class PyExtensionType(PyObjectType):
return "<PyExtensionType %s%s>" % (self.scope.class_name,
("", " typedef")[self.typedef_flag])
def py_type_name(self):
if not self.module_name:
return self.name
return "__import__(%r, None, None, ['']).%s" % (self.module_name,
self.name)
class CType(PyrexType):
#
......@@ -773,6 +819,17 @@ class CNumericType(CType):
cname=" ")
return True
def __lt__(self, other):
"Sort based on rank, preferring signed over unsigned"
if other.is_numeric:
return self.rank > other.rank and self.signed >= other.signed
return NotImplemented
def py_type_name(self):
if self.rank <= 4:
return "(int, long)"
return "float"
type_conversion_predeclarations = ""
type_conversion_functions = ""
......@@ -1010,6 +1067,9 @@ class CBIntType(CIntType):
def __str__(self):
return 'bint'
def py_type_name(self):
return "bool"
class CPyUCS4IntType(CIntType):
# Py_UCS4
......@@ -1339,6 +1399,9 @@ class CComplexType(CNumericType):
def binary_op(self, op):
return self.lookup_op(2, op)
def py_type_name(self):
return "complex"
complex_ops = {
(1, '-'): 'neg',
(1, 'zero'): 'is_zero',
......@@ -2040,7 +2103,6 @@ class CFuncType(CType):
elif self.cached_specialized_types is not None:
return self.cached_specialized_types
cfunc_entries = self.entry.scope.cfunc_entries
cfunc_entries.remove(self.entry)
......@@ -2482,6 +2544,10 @@ class CStringType(object):
assert isinstance(value, str)
return '"%s"' % StringEncoding.escape_byte_string(value)
def py_type_name(self):
if self.is_unicode:
return "unicode"
return "bytes"
class CUTF8CharArrayType(CStringType, CArrayType):
# C 'char []' type.
......@@ -2726,6 +2792,8 @@ def best_match(args, functions, pos=None, env=None):
the same weight, we return None (as there is no best match). If pos
is not None, we also generate an error.
"""
from Cython import Utils
# TODO: args should be a list of types, not a list of Nodes.
actual_nargs = len(args)
......@@ -2764,8 +2832,9 @@ def best_match(args, functions, pos=None, env=None):
return candidates[0][0]
elif len(candidates) == 0:
if pos is not None:
if len(errors) == 1:
error(pos, errors[0][1])
func, errmsg = errors[0]
if len(errors) == 1 or [1 for func, e in errors if e == errmsg]:
error(pos, errmsg)
else:
error(pos, "no suitable method found")
return None
......
......@@ -64,7 +64,8 @@ def sizeof(arg):
return 1
def typeof(arg):
return type(arg)
return arg.__class__.__name__
# return type(arg)
def address(arg):
return pointer(type(arg))([arg])
......@@ -233,9 +234,7 @@ class typedef(CythonType):
return self.name or str(self._basetype)
class _FusedType(CythonType):
def __call__(self, type, value):
return value
pass
def fused_type(*args):
......
......@@ -613,6 +613,7 @@ def run_forked_test(result, run_func, test_name, fork=True):
gc.collect()
return
module_name = test_name.split()[-1]
# fork to make sure we do not keep the tested module loaded
result_handle, result_file = tempfile.mkstemp()
os.close(result_handle)
......
cimport cython
cy = __import__("cython")
cpdef func1(self, cython.integral x):
print "%s," % (self,),
if cython.integral is int:
print 'x is int', x, cython.typeof(x)
else:
print 'x is long', x, cython.typeof(x)
class A(object):
meth = func1
def __str__(self):
return "A"
pyfunc = func1
def test_fused_cpdef():
"""
>>> test_fused_cpdef()
None, x is int 2 int
None, x is long 2 long
None, x is long 2 long
<BLANKLINE>
None, x is int 2 int
None, x is long 2 long
<BLANKLINE>
A, x is int 2 int
A, x is long 2 long
A, x is long 2 long
A, x is long 2 long
"""
func1[int](None, 2)
func1[long](None, 2)
func1(None, 2)
print
pyfunc[cy.int](None, 2)
pyfunc(None, 2)
print
A.meth[cy.int](A(), 2)
A.meth(A(), 2)
A().meth[cy.long](2)
A().meth(2)
def assert_raise(func, *args):
try:
func(*args)
except TypeError:
pass
else:
assert False, "Function call did not raise TypeError"
def test_badcall():
"""
>>> test_badcall()
"""
assert_raise(pyfunc)
assert_raise(pyfunc, 1, 2, 3)
assert_raise(pyfunc[cy.int], 10, 11, 12)
assert_raise(pyfunc, None, object())
assert_raise(A().meth)
assert_raise(A.meth)
assert_raise(A().meth[cy.int])
assert_raise(A.meth[cy.int])
ctypedef long double long_double
cpdef multiarg(cython.integral x, cython.floating y):
if cython.integral is int:
print "x is an int,",
else:
print "x is a long,",
if cython.floating is long_double:
print "y is a long double:",
elif float is cython.floating:
print "y is a float:",
else:
print "y is a double:",
print x, y
def test_multiarg():
"""
>>> test_multiarg()
x is an int, y is a float: 1 2.0
x is an int, y is a float: 1 2.0
x is a long, y is a double: 4 5.0
"""
multiarg[int, float](1, 2.0)
multiarg[cy.int, cy.float](1, 2.0)
multiarg(4, 5.0)
......@@ -21,7 +21,7 @@ ctypedef int *p_int
def test_pure():
"""
>>> test_pure()
(10+0j)
10
"""
mytype = pure_cython.typedef(pure_cython.fused_type(int, long, complex))
print mytype(10)
......
......@@ -32,18 +32,20 @@ cdef public class MyExt [ type MyExtType, object MyExtObject ]:
ctypedef char *string_t
ctypedef cython.fused_type(int, float) simple_t
ctypedef cython.fused_type(int, float, string_t) less_simple_t
ctypedef cython.fused_type(mystruct_t, myunion_t, MyExt) object_t
ctypedef cython.fused_type(mystruct_t, myunion_t, MyExt) struct_t
ctypedef cython.fused_type(str, unicode, bytes) builtin_t
cdef object_t add_simple(object_t obj, simple_t simple)
cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple)
cdef public_optional_args(object_t obj, simple_t simple = *)
ctypedef cython.fused_type(float, double) floating
ctypedef cython.fused_type(int, long) integral
cdef struct_t add_simple(struct_t obj, simple_t simple)
cdef less_simple_t add_to_simple(struct_t obj, less_simple_t simple)
cdef public_optional_args(struct_t obj, simple_t simple = *)
cdef class TestFusedExtMethods(object):
cdef floating method(self, integral x, floating y)
cdef cython.floating method(self, cython.integral x, cython.floating y)
cpdef cpdef_method(self, cython.integral x, cython.floating y)
ctypedef cython.fused_type(TestFusedExtMethods, object, list) object_t
cpdef public_cpdef(cython.integral x, cython.floating y, object_t z)
######## header.h ########
......@@ -54,28 +56,37 @@ typedef long extern_long;
cimport cython
cdef object_t add_simple(object_t obj, simple_t simple):
cdef struct_t add_simple(struct_t obj, simple_t simple):
obj.a = <int> (obj.a + simple)
return obj
cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple):
cdef less_simple_t add_to_simple(struct_t obj, less_simple_t simple):
return obj.a + simple
cdef public_optional_args(object_t obj, simple_t simple = 6):
cdef public_optional_args(struct_t obj, simple_t simple = 6):
return obj.a, simple
cdef class TestFusedExtMethods(object):
cdef floating method(self, integral x, floating y):
if integral is int:
cdef cython.floating method(self, cython.integral x, cython.floating y):
if cython.integral is int:
x += 1
if floating is double:
if cython.floating is double:
y += 2.0
return x + y
cpdef cpdef_method(self, cython.integral x, cython.floating y):
return cython.typeof(x), cython.typeof(y)
cpdef public_cpdef(cython.integral x, cython.floating y, object_t z):
return cython.typeof(x), cython.typeof(y), cython.typeof(z)
######## b.pyx ########
cimport cython
cimport a as a_cmod
from a cimport *
cdef mystruct_t mystruct
......@@ -134,9 +145,12 @@ assert obj.method[int, double](x, b) == 14.0
# Test inheritance
cdef class Subclass(TestFusedExtMethods):
cdef floating method(self, integral x, floating y):
cdef cython.floating method(self, cython.integral x, cython.floating y):
return -x -y
cpdef cpdef_method(self, cython.integral x, cython.floating y):
return x, y
cdef Subclass myobj = Subclass()
assert myobj.method[int, float](5, 5.0) == -10
......@@ -147,3 +161,44 @@ assert meth(myobj, 5, 5.0) == -10
meth = myobj.method[int, float]
assert meth(myobj, 5, 5.0) == -10
# Test cpdef functions and methods
cy = __import__("cython")
import a as a_mod
def ae(result, expected):
"assert equals"
if result != expected:
print 'result :', result
print 'expected:', expected
assert result == expected
ae(a_mod.public_cpdef["int, float, list"](5, 6, [7]), ("int", "float", "list"))
ae(a_mod.public_cpdef[int, float, list](5, 6, [7]), ("int", "float", "list"))
idx = cy.typeof(0), cy.typeof(0.0), cy.typeof([])
ae(a_mod.public_cpdef[idx](5, 6, [7]), ("int", "float", "list"))
ae(a_mod.public_cpdef[cy.int, cy.double, cython.typeof(obj)](5, 6, obj), ("int", "double", "TestFusedExtMethods"))
ae(a_mod.public_cpdef[cy.int, cy.double, cython.typeof(obj)](5, 6, myobj), ("int", "double", "TestFusedExtMethods"))
ae(public_cpdef[int, float, list](5, 6, [7]), ("int", "float", "list"))
ae(public_cpdef[int, double, TestFusedExtMethods](5, 6, obj), ("int", "double", "TestFusedExtMethods"))
ae(public_cpdef[int, double, TestFusedExtMethods](5, 6, myobj), ("int", "double", "TestFusedExtMethods"))
ae(obj.cpdef_method(10, 10.0), ("long", "double"))
ae(myobj.cpdef_method(10, 10.0), (10, 10.0))
ae(obj.cpdef_method[int, float](10, 10.0), ("int", "float"))
ae(myobj.cpdef_method[int, float](10, 10.0), (10, 10.0))
s = """\
import cython as cy
ae(obj.cpdef_method[cy.int, cy.float](10, 10.0), ("int", "float"))
ae(myobj.cpdef_method[cy.int, cy.float](10, 10.0), (10, 10.0))
"""
d = {'obj': obj, 'myobj': myobj, 'ae': ae}
exec s in d, d
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment