Commit 69b2e1a5 authored by Mark Florisson's avatar Mark Florisson

Support fused memoryviews & fused base types, disable fused generators,...

Support fused memoryviews & fused base types, disable fused generators, disable memoryview pointer dtypes
parent 5575fabb
...@@ -168,16 +168,18 @@ def valid_memslice_dtype(dtype): ...@@ -168,16 +168,18 @@ def valid_memslice_dtype(dtype):
return ( return (
dtype.is_error or dtype.is_error or
dtype.is_ptr or # Pointers are not valid (yet)
# (dtype.is_ptr and valid_memslice_dtype(dtype.base_type)) or
dtype.is_numeric or dtype.is_numeric or
dtype.is_struct or dtype.is_struct or
dtype.is_pyobject or dtype.is_pyobject or
dtype.is_fused or # accept this as it will be replaced by specializations later
(dtype.is_typedef and valid_memslice_dtype(dtype.typedef_base_type)) (dtype.is_typedef and valid_memslice_dtype(dtype.typedef_base_type))
) )
def validate_memslice_dtype(pos, dtype): def validate_memslice_dtype(pos, dtype):
if not valid_memslice_dtype(dtype): if not valid_memslice_dtype(dtype):
error(pos, "Invalid base type for memoryview slice") error(pos, "Invalid base type for memoryview slice: %s" % dtype)
class MemoryViewSliceBufferEntry(Buffer.BufferEntry): class MemoryViewSliceBufferEntry(Buffer.BufferEntry):
......
...@@ -843,6 +843,7 @@ class CSimpleBaseTypeNode(CBaseTypeNode): ...@@ -843,6 +843,7 @@ class CSimpleBaseTypeNode(CBaseTypeNode):
class MemoryViewSliceTypeNode(CBaseTypeNode): class MemoryViewSliceTypeNode(CBaseTypeNode):
name = 'memoryview'
child_attrs = ['base_type_node', 'axes'] child_attrs = ['base_type_node', 'axes']
def analyse(self, env, could_be_name = False): def analyse(self, env, could_be_name = False):
...@@ -2222,10 +2223,7 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2222,10 +2223,7 @@ class FusedCFuncDefNode(StatListNode):
for cname, fused_to_specific in permutations: for cname, fused_to_specific in permutations:
copied_node = copy.deepcopy(self.node) copied_node = copy.deepcopy(self.node)
for arg in copied_node.args: self._specialize_function_args(copied_node.args, fused_to_specific)
if arg.type.is_fused:
arg.type = arg.type.specialize(fused_to_specific)
copied_node.return_type = self.node.return_type.specialize( copied_node.return_type = self.node.return_type.specialize(
fused_to_specific) fused_to_specific)
...@@ -2287,8 +2285,8 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2287,8 +2285,8 @@ class FusedCFuncDefNode(StatListNode):
self.create_new_local_scope(copied_node, env, fused_to_specific) self.create_new_local_scope(copied_node, env, fused_to_specific)
# Make the argument types in the CFuncDeclarator specific # Make the argument types in the CFuncDeclarator specific
for arg in copied_node.cfunc_declarator.args: self._specialize_function_args(copied_node.cfunc_declarator.args,
arg.type = arg.type.specialize(fused_to_specific) fused_to_specific)
type.specialize_entry(entry, cname) type.specialize_entry(entry, cname)
env.cfunc_entries.append(entry) env.cfunc_entries.append(entry)
...@@ -2312,6 +2310,15 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2312,6 +2310,15 @@ class FusedCFuncDefNode(StatListNode):
else: else:
self.py_func = orig_py_func self.py_func = orig_py_func
def _specialize_function_args(self, args, fused_to_specific):
import MemoryView
for arg in args:
if arg.type.is_fused:
arg.type = arg.type.specialize(fused_to_specific)
if arg.type.is_memoryviewslice:
MemoryView.validate_memslice_dtype(arg.pos, arg.type.dtype)
def create_new_local_scope(self, node, env, f2s): def create_new_local_scope(self, node, env, f2s):
""" """
Create a new local scope for the copied node and append it to Create a new local scope for the copied node and append it to
...@@ -2332,8 +2339,12 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2332,8 +2339,12 @@ class FusedCFuncDefNode(StatListNode):
def specialize_copied_def(self, node, cname, py_entry, f2s, fused_types): def specialize_copied_def(self, node, cname, py_entry, f2s, fused_types):
"""Specialize the copy of a DefNode given the copied node, """Specialize the copy of a DefNode given the copied node,
the specialization cname and the original DefNode entry""" the specialization cname and the original DefNode entry"""
type_strings = [f2s[fused_type].typeof_name() type_strings = [
for fused_type in fused_types] fused_type.specialize(f2s).typeof_name()
for fused_type in fused_types
]
#type_strings = [f2s[fused_type].typeof_name()
# for fused_type in fused_types]
node.specialized_signature_string = ', '.join(type_strings) node.specialized_signature_string = ', '.join(type_strings)
......
...@@ -1446,6 +1446,13 @@ if VALUE is not None: ...@@ -1446,6 +1446,13 @@ if VALUE is not None:
else: else:
error(type_node.pos, "Not a type") error(type_node.pos, "Not a type")
if node.is_generator and node.has_fused_arguments:
node.has_fused_arguments = False
error(node.pos, "Fused generators not supported")
node.gbody = Nodes.StatListNode(node.pos,
stats=[],
body=Nodes.PassStatNode(node.pos))
if node.has_fused_arguments: if node.has_fused_arguments:
if self.fused_function: if self.fused_function:
if self.fused_function not in self.fused_error_funcs: if self.fused_function not in self.fused_error_funcs:
......
...@@ -415,6 +415,8 @@ class MemoryViewSliceType(PyrexType): ...@@ -415,6 +415,8 @@ class MemoryViewSliceType(PyrexType):
exception_value = None exception_value = None
exception_check = True exception_check = True
subtypes = ['dtype']
def __init__(self, base_dtype, axes): def __init__(self, base_dtype, axes):
''' '''
MemoryViewSliceType(base, axes) MemoryViewSliceType(base, axes)
...@@ -462,7 +464,8 @@ class MemoryViewSliceType(PyrexType): ...@@ -462,7 +464,8 @@ class MemoryViewSliceType(PyrexType):
self.mode = MemoryView.get_mode(axes) self.mode = MemoryView.get_mode(axes)
self.writable_needed = False self.writable_needed = False
self.dtype_name = MemoryView.mangle_dtype_name(self.dtype) if not self.dtype.is_fused:
self.dtype_name = MemoryView.mangle_dtype_name(self.dtype)
def same_as_resolved_type(self, other_type): def same_as_resolved_type(self, other_type):
return ((other_type.is_memoryviewslice and return ((other_type.is_memoryviewslice and
...@@ -711,11 +714,17 @@ class MemoryViewSliceType(PyrexType): ...@@ -711,11 +714,17 @@ class MemoryViewSliceType(PyrexType):
import MemoryView import MemoryView
axes_code_list = [] axes_code_list = []
for access, packing in self.axes: for idx, (access, packing) in enumerate(self.axes):
flag = MemoryView.get_memoryview_flag(access, packing) flag = MemoryView.get_memoryview_flag(access, packing)
if flag == "strided": if flag == "strided":
axes_code_list.append(":") axes_code_list.append(":")
else: else:
if flag == 'contiguous':
have_follow = [p for a, p in self.axes[idx - 1:idx + 1]
if p == 'follow']
if have_follow or self.ndim == 1:
flag = '1'
axes_code_list.append("::" + flag) axes_code_list.append("::" + flag)
if self.dtype.is_pyobject: if self.dtype.is_pyobject:
...@@ -725,6 +734,13 @@ class MemoryViewSliceType(PyrexType): ...@@ -725,6 +734,13 @@ class MemoryViewSliceType(PyrexType):
return "%s[%s]" % (dtype_name, ", ".join(axes_code_list)) return "%s[%s]" % (dtype_name, ", ".join(axes_code_list))
def specialize(self, values):
"This does not validate the base type!!"
dtype = self.dtype.specialize(values)
if dtype is not self.dtype:
return MemoryViewSliceType(dtype, self.axes)
class BufferType(BaseType): class BufferType(BaseType):
# #
...@@ -2579,10 +2595,10 @@ def get_fused_cname(fused_cname, orig_cname): ...@@ -2579,10 +2595,10 @@ def get_fused_cname(fused_cname, orig_cname):
fused_cname, orig_cname)) fused_cname, orig_cname))
def get_all_specific_permutations(fused_types, id="", f2s=()): def get_all_specific_permutations(fused_types, id="", f2s=()):
fused_type = fused_types[0] fused_type, = fused_types[0].get_fused_types()
result = [] result = []
for newid, specific_type in enumerate(sorted(fused_type.types)): for newid, specific_type in enumerate(fused_type.types):
# f2s = dict(f2s, **{ fused_type: specific_type }) # f2s = dict(f2s, **{ fused_type: specific_type })
f2s = dict(f2s) f2s = dict(f2s)
f2s.update({ fused_type: specific_type }) f2s.update({ fused_type: specific_type })
......
...@@ -62,6 +62,7 @@ Indeed, one may write:: ...@@ -62,6 +62,7 @@ Indeed, one may write::
cdef otherfunc(A *x): cdef otherfunc(A *x):
... ...
Selecting Specializations Selecting Specializations
========================= =========================
You can select a specialization (an instance of the function with specific or specialized (i.e., You can select a specialization (an instance of the function with specific or specialized (i.e.,
...@@ -79,6 +80,15 @@ You can index functions with types to get certain specializations, i.e.:: ...@@ -79,6 +80,15 @@ You can index functions with types to get certain specializations, i.e.::
# From Python space # From Python space
func[cython.float, cython.double](myfloat, mydouble) func[cython.float, cython.double](myfloat, mydouble)
If a fused type is used as a base type, this will mean that the base type is the fused type, so the
base type is what needs to be specialized::
cdef myfunc(A *x):
...
# Specialize using int, not int *
myfunc[int](myint)
Calling Calling
------- -------
A fused function can also be called with arguments, where the dispatch is figured out automatically:: A fused function can also be called with arguments, where the dispatch is figured out automatically::
......
...@@ -15,6 +15,7 @@ Contents: ...@@ -15,6 +15,7 @@ Contents:
external_C_code external_C_code
source_files_and_compilation source_files_and_compilation
wrapping_CPlusPlus wrapping_CPlusPlus
fusedtypes
limitations limitations
pyrex_differences pyrex_differences
early_binding_for_speed early_binding_for_speed
......
...@@ -13,9 +13,12 @@ def closure3(cython.integral i): ...@@ -13,9 +13,12 @@ def closure3(cython.integral i):
def inner(): def inner():
return lambda cython.floating f: f return lambda cython.floating f: f
def generator(cython.integral i):
yield i
_ERRORS = u""" _ERRORS = u"""
e_fused_closure.pyx:6:4: Cannot nest fused functions e_fused_closure.pyx:6:4: Cannot nest fused functions
e_fused_closure.pyx:10:11: Cannot nest fused functions e_fused_closure.pyx:10:11: Cannot nest fused functions
e_fused_closure.pyx:14:15: Cannot nest fused functions e_fused_closure.pyx:14:15: Cannot nest fused functions
e_fused_closure.pyx:16:0: Fused generators not supported
""" """
...@@ -29,6 +29,12 @@ func[float, int](x) ...@@ -29,6 +29,12 @@ func[float, int](x)
func[float, int](x, y, y) func[float, int](x, y, y)
func(x, y=y) func(x, y=y)
ctypedef fused memslice_dtype_t:
cython.p_int # invalid dtype
cython.long
def f(memslice_dtype_t[:, :] a):
pass
# This is all valid # This is all valid
dtype5 = fused_type(int, long, float) dtype5 = fused_type(int, long, float)
...@@ -54,4 +60,5 @@ fused_types.pyx:27:4: Not enough types specified to specialize the function, int ...@@ -54,4 +60,5 @@ fused_types.pyx:27:4: Not enough types specified to specialize the function, int
fused_types.pyx:28:16: Call with wrong number of arguments (expected 2, got 1) fused_types.pyx:28:16: Call with wrong number of arguments (expected 2, got 1)
fused_types.pyx:29:16: Call with wrong number of arguments (expected 2, got 3) fused_types.pyx:29:16: Call with wrong number of arguments (expected 2, got 3)
fused_types.pyx:30:4: Keyword and starred arguments not allowed in cdef functions. fused_types.pyx:30:4: Keyword and starred arguments not allowed in cdef functions.
fused_types.pyx:36:6: Invalid base type for memoryview slice: int *
""" """
...@@ -117,6 +117,28 @@ def test_fused_with_pointer(): ...@@ -117,6 +117,28 @@ def test_fused_with_pointer():
print print
print fused_with_pointer(string_array).decode('ascii') print fused_with_pointer(string_array).decode('ascii')
include "cythonarrayutil.pxi"
cpdef cython.integral test_fused_memoryviews(cython.integral[:, ::1] a):
"""
>>> import cython
>>> a = create_array((3, 5), mode="c")
>>> test_fused_memoryviews[cython.int](a)
7
"""
return a[1, 2]
ctypedef int[:, ::1] memview_int
ctypedef long[:, ::1] memview_long
memview_t = cython.fused_type(memview_int, memview_long)
def test_fused_memoryview_def(memview_t a):
"""
>>> a = create_array((3, 5), mode="c")
>>> test_fused_memoryview_def["memview_int"](a)
7
"""
return a[1, 2]
cdef test_specialize(fused_type1 x, fused_type1 *y, composed_t z, other_t *a): cdef test_specialize(fused_type1 x, fused_type1 *y, composed_t z, other_t *a):
cdef fused_type1 result cdef fused_type1 result
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment