Commit 4228dfd1 authored by scoder's avatar scoder

Merge pull request #284 from pv/fused-fixes

fused types: specialize each base type only once, also for compound type args
parents bc04ebef 95d76de0
...@@ -82,7 +82,8 @@ class FusedCFuncDefNode(StatListNode): ...@@ -82,7 +82,8 @@ class FusedCFuncDefNode(StatListNode):
""" """
fused_compound_types = PyrexTypes.unique( fused_compound_types = PyrexTypes.unique(
[arg.type for arg in self.node.args if arg.type.is_fused]) [arg.type for arg in self.node.args if arg.type.is_fused])
permutations = PyrexTypes.get_all_specialized_permutations(fused_compound_types) fused_types = self._get_fused_base_types(fused_compound_types)
permutations = PyrexTypes.get_all_specialized_permutations(fused_types)
self.fused_compound_types = fused_compound_types self.fused_compound_types = fused_compound_types
...@@ -183,6 +184,17 @@ class FusedCFuncDefNode(StatListNode): ...@@ -183,6 +184,17 @@ class FusedCFuncDefNode(StatListNode):
else: else:
self.py_func = orig_py_func self.py_func = orig_py_func
def _get_fused_base_types(self, fused_compound_types):
"""
Get a list of unique basic fused types, from a list of
(possibly) compound fused types.
"""
base_types = []
seen = set()
for fused_type in fused_compound_types:
fused_type.get_fused_types(result=base_types, seen=seen)
return base_types
def _specialize_function_args(self, args, fused_to_specific): def _specialize_function_args(self, args, fused_to_specific):
for arg in args: for arg in args:
if arg.type.is_fused: if arg.type.is_fused:
...@@ -207,9 +219,10 @@ class FusedCFuncDefNode(StatListNode): ...@@ -207,9 +219,10 @@ class FusedCFuncDefNode(StatListNode):
node.has_fused_arguments = False node.has_fused_arguments = False
self.nodes.append(node) self.nodes.append(node)
def specialize_copied_def(self, node, cname, py_entry, f2s, fused_types): def specialize_copied_def(self, node, cname, py_entry, f2s, fused_compound_types):
"""Specialize the copy of a DefNode given the copied node, """Specialize the copy of a DefNode given the copied node,
the specialization cname and the original DefNode entry""" the specialization cname and the original DefNode entry"""
fused_types = self._get_fused_base_types(fused_compound_types)
type_strings = [ type_strings = [
PyrexTypes.specialization_signature_string(fused_type, f2s) PyrexTypes.specialization_signature_string(fused_type, f2s)
for fused_type in fused_types for fused_type in fused_types
...@@ -522,13 +535,13 @@ class FusedCFuncDefNode(StatListNode): ...@@ -522,13 +535,13 @@ class FusedCFuncDefNode(StatListNode):
""" """
from . import TreeFragment, Code, UtilityCode from . import TreeFragment, Code, UtilityCode
# { (arg_pos, FusedType) : specialized_type } fused_types = self._get_fused_base_types([
seen_fused_types = set() arg.type for arg in self.node.args if arg.type.is_fused])
context = { context = {
'memviewslice_cname': MemoryView.memviewslice_cname, 'memviewslice_cname': MemoryView.memviewslice_cname,
'func_args': self.node.args, 'func_args': self.node.args,
'n_fused': len([arg for arg in self.node.args]), 'n_fused': len(fused_types),
'name': orig_py_func.entry.name, 'name': orig_py_func.entry.name,
} }
...@@ -560,9 +573,17 @@ class FusedCFuncDefNode(StatListNode): ...@@ -560,9 +573,17 @@ class FusedCFuncDefNode(StatListNode):
fused_index = 0 fused_index = 0
default_idx = 0 default_idx = 0
all_buffer_types = set() all_buffer_types = set()
seen_fused_types = set()
for i, arg in enumerate(self.node.args): for i, arg in enumerate(self.node.args):
if arg.type.is_fused and arg.type not in seen_fused_types: if arg.type.is_fused:
seen_fused_types.add(arg.type) arg_fused_types = arg.type.get_fused_types()
if len(arg_fused_types) > 1:
raise NotImplementedError("Determination of more than one fused base "
"type per argument is not implemented.")
fused_type = arg_fused_types[0]
if arg.type.is_fused and fused_type not in seen_fused_types:
seen_fused_types.add(fused_type)
context.update( context.update(
arg_tuple_idx=i, arg_tuple_idx=i,
......
...@@ -18,7 +18,7 @@ ctypedef fused_type1 *composed_t ...@@ -18,7 +18,7 @@ ctypedef fused_type1 *composed_t
other_t = cython.fused_type(int, double) other_t = cython.fused_type(int, double)
ctypedef double *p_double ctypedef double *p_double
ctypedef int *p_int ctypedef int *p_int
fused_type3 = cython.fused_type(int, double)
def test_pure(): def test_pure():
""" """
...@@ -268,6 +268,12 @@ def get_array(itemsize, format): ...@@ -268,6 +268,12 @@ def get_array(itemsize, format):
result[6] = 6.0 result[6] = 6.0
return result return result
def get_intc_array():
result = array((10,), sizeof(int), 'i')
result[5] = 5.0
result[6] = 6.0
return result
def test_fused_memslice_dtype(cython.floating[:] array): def test_fused_memslice_dtype(cython.floating[:] array):
""" """
Note: the np.ndarray dtype test is in numpy_test Note: the np.ndarray dtype test is in numpy_test
...@@ -285,6 +291,42 @@ def test_fused_memslice_dtype(cython.floating[:] array): ...@@ -285,6 +291,42 @@ def test_fused_memslice_dtype(cython.floating[:] array):
print cython.typeof(array), cython.typeof(otherarray), \ print cython.typeof(array), cython.typeof(otherarray), \
array[5], otherarray[6] array[5], otherarray[6]
def test_fused_memslice_dtype_repeated(cython.floating[:] array1, cython.floating[:] array2):
"""
Note: the np.ndarray dtype test is in numpy_test
>>> import cython
>>> sorted(test_fused_memslice_dtype_repeated.__signatures__)
['double', 'float']
>>> test_fused_memslice_dtype_repeated(get_array(8, 'd'), get_array(8, 'd'))
double[:] double[:]
>>> test_fused_memslice_dtype_repeated(get_array(4, 'f'), get_array(4, 'f'))
float[:] float[:]
>>> test_fused_memslice_dtype_repeated(get_array(8, 'd'), get_array(4, 'f'))
Traceback (most recent call last):
ValueError: Buffer dtype mismatch, expected 'double' but got 'float'
"""
print cython.typeof(array1), cython.typeof(array2)
def test_fused_memslice_dtype_repeated_2(cython.floating[:] array1, cython.floating[:] array2,
fused_type3[:] array3):
"""
Note: the np.ndarray dtype test is in numpy_test
>>> import cython
>>> sorted(test_fused_memslice_dtype_repeated_2.__signatures__)
['double|double', 'double|int', 'float|double', 'float|int']
>>> test_fused_memslice_dtype_repeated_2(get_array(8, 'd'), get_array(8, 'd'), get_array(8, 'd'))
double[:] double[:] double[:]
>>> test_fused_memslice_dtype_repeated_2(get_array(8, 'd'), get_array(8, 'd'), get_intc_array())
double[:] double[:] int[:]
>>> test_fused_memslice_dtype_repeated_2(get_array(4, 'f'), get_array(4, 'f'), get_intc_array())
float[:] float[:] int[:]
"""
print cython.typeof(array1), cython.typeof(array2), cython.typeof(array3)
def test_cython_numeric(cython.numeric arg): def test_cython_numeric(cython.numeric arg):
""" """
Test to see whether complex numbers have their utility code declared Test to see whether complex numbers have their utility code declared
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment