Commit e2bd21ab authored by Mark Florisson's avatar Mark Florisson

Runtime dispatch to specialized cpdef

parent 2f20794c
This diff is collapsed.
...@@ -1852,6 +1852,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1852,6 +1852,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("/*--- Function import code ---*/") code.putln("/*--- Function import code ---*/")
for module in imported_modules: for module in imported_modules:
self.specialize_fused_types(module, env)
self.generate_c_function_import_code_for_module(module, env, code) self.generate_c_function_import_code_for_module(module, env, code)
code.putln("/*--- Execution code ---*/") code.putln("/*--- Execution code ---*/")
...@@ -2059,11 +2060,23 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -2059,11 +2060,23 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
if entry.defined_in_pxd: if entry.defined_in_pxd:
self.generate_type_import_code(env, entry.type, entry.pos, code) self.generate_type_import_code(env, entry.type, entry.pos, code)
def specialize_fused_types(self, pxd_env, impl_env):
"""
If fused c(p)def functions are defined in an imported pxd, but not
used in this implementation file, we still have fused entries and
not specialized ones. This method replaces any fused entries with their
specialized ones.
"""
for entry in pxd_env.cfunc_entries[:]:
if entry.type.is_fused:
# This call modifies the cfunc_entries in-place
entry.type.get_all_specific_function_types()
def generate_c_function_import_code_for_module(self, module, env, code): def generate_c_function_import_code_for_module(self, module, env, code):
# Generate import code for all exported C functions in a cimported module. # Generate import code for all exported C functions in a cimported module.
entries = [] entries = []
for entry in module.cfunc_entries: for entry in module.cfunc_entries:
if entry.defined_in_pxd: if entry.defined_in_pxd and entry.used:
entries.append(entry) entries.append(entry)
if entries: if entries:
env.use_utility_code(import_module_utility_code) env.use_utility_code(import_module_utility_code)
......
This diff is collapsed.
...@@ -1144,6 +1144,8 @@ if VALUE is not None: ...@@ -1144,6 +1144,8 @@ if VALUE is not None:
if node.has_fused_arguments: if node.has_fused_arguments:
node = Nodes.FusedCFuncDefNode(node, self.env_stack[-1]) node = Nodes.FusedCFuncDefNode(node, self.env_stack[-1])
self.visitchildren(node) self.visitchildren(node)
if node.py_func:
node.stats.append(node.py_func)
else: else:
node.body.analyse_declarations(lenv) node.body.analyse_declarations(lenv)
self.env_stack.append(lenv) self.env_stack.append(lenv)
...@@ -1354,18 +1356,10 @@ class AnalyseExpressionsTransform(EnvTransform): ...@@ -1354,18 +1356,10 @@ class AnalyseExpressionsTransform(EnvTransform):
re-analyse the types. re-analyse the types.
""" """
self.visit_Node(node) self.visit_Node(node)
type = node.type
if node.is_fused_index: if node.is_fused_index and node.type is not PyrexTypes.error_type:
if node.type is PyrexTypes.error_type:
node.type = PyrexTypes.error_type
else:
node.base.type = node.type
node.base.entry = getattr(node, 'entry', None) or node.type.entry
node = node.base node = node.base
node.analyse_types(self.env_stack[-1])
return node return node
......
...@@ -62,6 +62,29 @@ class BaseType(object): ...@@ -62,6 +62,29 @@ class BaseType(object):
is_fused = property(get_fused_types, doc="Whether this type or any of its " is_fused = property(get_fused_types, doc="Whether this type or any of its "
"subtypes is a fused type") "subtypes is a fused type")
def __lt__(self, other):
"""
For sorting. The sorting order should correspond to the preference of
conversion from Python types.
"""
return NotImplemented
def py_type_name(self):
"""
Return the name of the Python type that can coerce to this type.
"""
def typeof_name(self):
"""
Return the string with which fused python functions can be indexed.
"""
if self.is_builtin_type or self.py_type_name() == 'object':
index_name = self.py_type_name()
else:
index_name = str(self)
return index_name
class PyrexType(BaseType): class PyrexType(BaseType):
# #
# Base class for all Pyrex types. # Base class for all Pyrex types.
...@@ -334,6 +357,8 @@ class CTypedefType(BaseType): ...@@ -334,6 +357,8 @@ class CTypedefType(BaseType):
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.typedef_base_type, name) return getattr(self.typedef_base_type, name)
def py_type_name(self):
return self.typedef_base_type.py_type_name()
class BufferType(BaseType): class BufferType(BaseType):
# #
...@@ -418,6 +443,17 @@ class PyObjectType(PyrexType): ...@@ -418,6 +443,17 @@ class PyObjectType(PyrexType):
else: else:
return cname return cname
def py_type_name(self):
return "object"
def __lt__(self, other):
"""
Make sure we sort highest, as instance checking on py_type_name
('object') is always true
"""
return False
class BuiltinObjectType(PyObjectType): class BuiltinObjectType(PyObjectType):
# objstruct_cname string Name of PyObject struct # objstruct_cname string Name of PyObject struct
...@@ -514,6 +550,10 @@ class BuiltinObjectType(PyObjectType): ...@@ -514,6 +550,10 @@ class BuiltinObjectType(PyObjectType):
to_object_struct and self.objstruct_cname or "PyObject", # self.objstruct_cname may be None to_object_struct and self.objstruct_cname or "PyObject", # self.objstruct_cname may be None
expr_code) expr_code)
def py_type_name(self):
return self.name
class PyExtensionType(PyObjectType): class PyExtensionType(PyObjectType):
# #
...@@ -621,6 +661,12 @@ class PyExtensionType(PyObjectType): ...@@ -621,6 +661,12 @@ class PyExtensionType(PyObjectType):
return "<PyExtensionType %s%s>" % (self.scope.class_name, return "<PyExtensionType %s%s>" % (self.scope.class_name,
("", " typedef")[self.typedef_flag]) ("", " typedef")[self.typedef_flag])
def py_type_name(self):
if not self.module_name:
return self.name
return "__import__(%r, None, None, ['']).%s" % (self.module_name,
self.name)
class CType(PyrexType): class CType(PyrexType):
# #
...@@ -773,6 +819,17 @@ class CNumericType(CType): ...@@ -773,6 +819,17 @@ class CNumericType(CType):
cname=" ") cname=" ")
return True return True
def __lt__(self, other):
"Sort based on rank, preferring signed over unsigned"
if other.is_numeric:
return self.rank > other.rank and self.signed >= other.signed
return NotImplemented
def py_type_name(self):
if self.rank <= 4:
return "(int, long)"
return "float"
type_conversion_predeclarations = "" type_conversion_predeclarations = ""
type_conversion_functions = "" type_conversion_functions = ""
...@@ -1010,6 +1067,9 @@ class CBIntType(CIntType): ...@@ -1010,6 +1067,9 @@ class CBIntType(CIntType):
def __str__(self): def __str__(self):
return 'bint' return 'bint'
def py_type_name(self):
return "bool"
class CPyUCS4IntType(CIntType): class CPyUCS4IntType(CIntType):
# Py_UCS4 # Py_UCS4
...@@ -1339,6 +1399,9 @@ class CComplexType(CNumericType): ...@@ -1339,6 +1399,9 @@ class CComplexType(CNumericType):
def binary_op(self, op): def binary_op(self, op):
return self.lookup_op(2, op) return self.lookup_op(2, op)
def py_type_name(self):
return "complex"
complex_ops = { complex_ops = {
(1, '-'): 'neg', (1, '-'): 'neg',
(1, 'zero'): 'is_zero', (1, 'zero'): 'is_zero',
...@@ -2040,7 +2103,6 @@ class CFuncType(CType): ...@@ -2040,7 +2103,6 @@ class CFuncType(CType):
elif self.cached_specialized_types is not None: elif self.cached_specialized_types is not None:
return self.cached_specialized_types return self.cached_specialized_types
cfunc_entries = self.entry.scope.cfunc_entries cfunc_entries = self.entry.scope.cfunc_entries
cfunc_entries.remove(self.entry) cfunc_entries.remove(self.entry)
...@@ -2482,6 +2544,10 @@ class CStringType(object): ...@@ -2482,6 +2544,10 @@ class CStringType(object):
assert isinstance(value, str) assert isinstance(value, str)
return '"%s"' % StringEncoding.escape_byte_string(value) return '"%s"' % StringEncoding.escape_byte_string(value)
def py_type_name(self):
if self.is_unicode:
return "unicode"
return "bytes"
class CUTF8CharArrayType(CStringType, CArrayType): class CUTF8CharArrayType(CStringType, CArrayType):
# C 'char []' type. # C 'char []' type.
...@@ -2726,6 +2792,8 @@ def best_match(args, functions, pos=None, env=None): ...@@ -2726,6 +2792,8 @@ def best_match(args, functions, pos=None, env=None):
the same weight, we return None (as there is no best match). If pos the same weight, we return None (as there is no best match). If pos
is not None, we also generate an error. is not None, we also generate an error.
""" """
from Cython import Utils
# TODO: args should be a list of types, not a list of Nodes. # TODO: args should be a list of types, not a list of Nodes.
actual_nargs = len(args) actual_nargs = len(args)
...@@ -2764,8 +2832,9 @@ def best_match(args, functions, pos=None, env=None): ...@@ -2764,8 +2832,9 @@ def best_match(args, functions, pos=None, env=None):
return candidates[0][0] return candidates[0][0]
elif len(candidates) == 0: elif len(candidates) == 0:
if pos is not None: if pos is not None:
if len(errors) == 1: func, errmsg = errors[0]
error(pos, errors[0][1]) if len(errors) == 1 or [1 for func, e in errors if e == errmsg]:
error(pos, errmsg)
else: else:
error(pos, "no suitable method found") error(pos, "no suitable method found")
return None return None
......
...@@ -64,7 +64,8 @@ def sizeof(arg): ...@@ -64,7 +64,8 @@ def sizeof(arg):
return 1 return 1
def typeof(arg): def typeof(arg):
return type(arg) return arg.__class__.__name__
# return type(arg)
def address(arg): def address(arg):
return pointer(type(arg))([arg]) return pointer(type(arg))([arg])
...@@ -233,9 +234,7 @@ class typedef(CythonType): ...@@ -233,9 +234,7 @@ class typedef(CythonType):
return self.name or str(self._basetype) return self.name or str(self._basetype)
class _FusedType(CythonType): class _FusedType(CythonType):
pass
def __call__(self, type, value):
return value
def fused_type(*args): def fused_type(*args):
......
...@@ -613,6 +613,7 @@ def run_forked_test(result, run_func, test_name, fork=True): ...@@ -613,6 +613,7 @@ def run_forked_test(result, run_func, test_name, fork=True):
gc.collect() gc.collect()
return return
module_name = test_name.split()[-1]
# fork to make sure we do not keep the tested module loaded # fork to make sure we do not keep the tested module loaded
result_handle, result_file = tempfile.mkstemp() result_handle, result_file = tempfile.mkstemp()
os.close(result_handle) os.close(result_handle)
......
cimport cython
cy = __import__("cython")
cpdef func1(self, cython.integral x):
print "%s," % (self,),
if cython.integral is int:
print 'x is int', x, cython.typeof(x)
else:
print 'x is long', x, cython.typeof(x)
class A(object):
meth = func1
def __str__(self):
return "A"
pyfunc = func1
def test_fused_cpdef():
"""
>>> test_fused_cpdef()
None, x is int 2 int
None, x is long 2 long
None, x is long 2 long
<BLANKLINE>
None, x is int 2 int
None, x is long 2 long
<BLANKLINE>
A, x is int 2 int
A, x is long 2 long
A, x is long 2 long
A, x is long 2 long
"""
func1[int](None, 2)
func1[long](None, 2)
func1(None, 2)
print
pyfunc[cy.int](None, 2)
pyfunc(None, 2)
print
A.meth[cy.int](A(), 2)
A.meth(A(), 2)
A().meth[cy.long](2)
A().meth(2)
def assert_raise(func, *args):
try:
func(*args)
except TypeError:
pass
else:
assert False, "Function call did not raise TypeError"
def test_badcall():
"""
>>> test_badcall()
"""
assert_raise(pyfunc)
assert_raise(pyfunc, 1, 2, 3)
assert_raise(pyfunc[cy.int], 10, 11, 12)
assert_raise(pyfunc, None, object())
assert_raise(A().meth)
assert_raise(A.meth)
assert_raise(A().meth[cy.int])
assert_raise(A.meth[cy.int])
ctypedef long double long_double
cpdef multiarg(cython.integral x, cython.floating y):
if cython.integral is int:
print "x is an int,",
else:
print "x is a long,",
if cython.floating is long_double:
print "y is a long double:",
elif float is cython.floating:
print "y is a float:",
else:
print "y is a double:",
print x, y
def test_multiarg():
"""
>>> test_multiarg()
x is an int, y is a float: 1 2.0
x is an int, y is a float: 1 2.0
x is a long, y is a double: 4 5.0
"""
multiarg[int, float](1, 2.0)
multiarg[cy.int, cy.float](1, 2.0)
multiarg(4, 5.0)
...@@ -21,7 +21,7 @@ ctypedef int *p_int ...@@ -21,7 +21,7 @@ ctypedef int *p_int
def test_pure(): def test_pure():
""" """
>>> test_pure() >>> test_pure()
(10+0j) 10
""" """
mytype = pure_cython.typedef(pure_cython.fused_type(int, long, complex)) mytype = pure_cython.typedef(pure_cython.fused_type(int, long, complex))
print mytype(10) print mytype(10)
......
...@@ -32,18 +32,20 @@ cdef public class MyExt [ type MyExtType, object MyExtObject ]: ...@@ -32,18 +32,20 @@ cdef public class MyExt [ type MyExtType, object MyExtObject ]:
ctypedef char *string_t ctypedef char *string_t
ctypedef cython.fused_type(int, float) simple_t ctypedef cython.fused_type(int, float) simple_t
ctypedef cython.fused_type(int, float, string_t) less_simple_t ctypedef cython.fused_type(int, float, string_t) less_simple_t
ctypedef cython.fused_type(mystruct_t, myunion_t, MyExt) object_t ctypedef cython.fused_type(mystruct_t, myunion_t, MyExt) struct_t
ctypedef cython.fused_type(str, unicode, bytes) builtin_t ctypedef cython.fused_type(str, unicode, bytes) builtin_t
cdef object_t add_simple(object_t obj, simple_t simple) cdef struct_t add_simple(struct_t obj, simple_t simple)
cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple) cdef less_simple_t add_to_simple(struct_t obj, less_simple_t simple)
cdef public_optional_args(object_t obj, simple_t simple = *) cdef public_optional_args(struct_t obj, simple_t simple = *)
ctypedef cython.fused_type(float, double) floating
ctypedef cython.fused_type(int, long) integral
cdef class TestFusedExtMethods(object): cdef class TestFusedExtMethods(object):
cdef floating method(self, integral x, floating y) cdef cython.floating method(self, cython.integral x, cython.floating y)
cpdef cpdef_method(self, cython.integral x, cython.floating y)
ctypedef cython.fused_type(TestFusedExtMethods, object, list) object_t
cpdef public_cpdef(cython.integral x, cython.floating y, object_t z)
######## header.h ######## ######## header.h ########
...@@ -54,28 +56,37 @@ typedef long extern_long; ...@@ -54,28 +56,37 @@ typedef long extern_long;
cimport cython cimport cython
cdef object_t add_simple(object_t obj, simple_t simple): cdef struct_t add_simple(struct_t obj, simple_t simple):
obj.a = <int> (obj.a + simple) obj.a = <int> (obj.a + simple)
return obj return obj
cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple): cdef less_simple_t add_to_simple(struct_t obj, less_simple_t simple):
return obj.a + simple return obj.a + simple
cdef public_optional_args(object_t obj, simple_t simple = 6): cdef public_optional_args(struct_t obj, simple_t simple = 6):
return obj.a, simple return obj.a, simple
cdef class TestFusedExtMethods(object): cdef class TestFusedExtMethods(object):
cdef floating method(self, integral x, floating y): cdef cython.floating method(self, cython.integral x, cython.floating y):
if integral is int: if cython.integral is int:
x += 1 x += 1
if floating is double: if cython.floating is double:
y += 2.0 y += 2.0
return x + y return x + y
cpdef cpdef_method(self, cython.integral x, cython.floating y):
return cython.typeof(x), cython.typeof(y)
cpdef public_cpdef(cython.integral x, cython.floating y, object_t z):
return cython.typeof(x), cython.typeof(y), cython.typeof(z)
######## b.pyx ######## ######## b.pyx ########
cimport cython
cimport a as a_cmod
from a cimport * from a cimport *
cdef mystruct_t mystruct cdef mystruct_t mystruct
...@@ -134,9 +145,12 @@ assert obj.method[int, double](x, b) == 14.0 ...@@ -134,9 +145,12 @@ assert obj.method[int, double](x, b) == 14.0
# Test inheritance # Test inheritance
cdef class Subclass(TestFusedExtMethods): cdef class Subclass(TestFusedExtMethods):
cdef floating method(self, integral x, floating y): cdef cython.floating method(self, cython.integral x, cython.floating y):
return -x -y return -x -y
cpdef cpdef_method(self, cython.integral x, cython.floating y):
return x, y
cdef Subclass myobj = Subclass() cdef Subclass myobj = Subclass()
assert myobj.method[int, float](5, 5.0) == -10 assert myobj.method[int, float](5, 5.0) == -10
...@@ -147,3 +161,44 @@ assert meth(myobj, 5, 5.0) == -10 ...@@ -147,3 +161,44 @@ assert meth(myobj, 5, 5.0) == -10
meth = myobj.method[int, float] meth = myobj.method[int, float]
assert meth(myobj, 5, 5.0) == -10 assert meth(myobj, 5, 5.0) == -10
# Test cpdef functions and methods
cy = __import__("cython")
import a as a_mod
def ae(result, expected):
"assert equals"
if result != expected:
print 'result :', result
print 'expected:', expected
assert result == expected
ae(a_mod.public_cpdef["int, float, list"](5, 6, [7]), ("int", "float", "list"))
ae(a_mod.public_cpdef[int, float, list](5, 6, [7]), ("int", "float", "list"))
idx = cy.typeof(0), cy.typeof(0.0), cy.typeof([])
ae(a_mod.public_cpdef[idx](5, 6, [7]), ("int", "float", "list"))
ae(a_mod.public_cpdef[cy.int, cy.double, cython.typeof(obj)](5, 6, obj), ("int", "double", "TestFusedExtMethods"))
ae(a_mod.public_cpdef[cy.int, cy.double, cython.typeof(obj)](5, 6, myobj), ("int", "double", "TestFusedExtMethods"))
ae(public_cpdef[int, float, list](5, 6, [7]), ("int", "float", "list"))
ae(public_cpdef[int, double, TestFusedExtMethods](5, 6, obj), ("int", "double", "TestFusedExtMethods"))
ae(public_cpdef[int, double, TestFusedExtMethods](5, 6, myobj), ("int", "double", "TestFusedExtMethods"))
ae(obj.cpdef_method(10, 10.0), ("long", "double"))
ae(myobj.cpdef_method(10, 10.0), (10, 10.0))
ae(obj.cpdef_method[int, float](10, 10.0), ("int", "float"))
ae(myobj.cpdef_method[int, float](10, 10.0), (10, 10.0))
s = """\
import cython as cy
ae(obj.cpdef_method[cy.int, cy.float](10, 10.0), ("int", "float"))
ae(myobj.cpdef_method[cy.int, cy.float](10, 10.0), (10, 10.0))
"""
d = {'obj': obj, 'myobj': myobj, 'ae': ae}
exec s in d, d
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment