Commit 7de1010c authored by Stefan Behnel's avatar Stefan Behnel

Avoid slow Pythran dtype checking code in fused function dispatch if pythran is not used.

parent 8581fb72
......@@ -424,7 +424,7 @@ class FusedCFuncDefNode(StatListNode):
if ndarray is not None:
if isinstance(arg, ndarray):
dtype = arg.dtype
arg_is_pythran_compatible = True
""" + ("arg_is_pythran_compatible = True" if pythran_types else "") + """
elif __pyx_memoryview_check(arg):
arg_base = arg.base
if isinstance(arg_base, ndarray):
......@@ -438,24 +438,30 @@ class FusedCFuncDefNode(StatListNode):
if dtype is not None:
itemsize = dtype.itemsize
kind = ord(dtype.kind)
dtype_signed = kind == 'i'
""")
pyx_code.indent(2)
if pythran_types:
pyx_code.put_chunk(
u"""
# We only support the endianness of the current compiler
byteorder = dtype.byteorder
if byteorder == "<" and not __Pyx_Is_Little_Endian():
arg_is_pythran_compatible = False
if byteorder == ">" and __Pyx_Is_Little_Endian():
elif byteorder == ">" and __Pyx_Is_Little_Endian():
arg_is_pythran_compatible = False
dtype_signed = kind == 'i'
if arg_is_pythran_compatible:
cur_stride = itemsize
for dim,stride in zip(reversed(arg.shape),reversed(arg.strides)):
if stride != cur_stride:
shape = arg.shape
strides = arg.strides
for i in range(arg.ndim):
if strides[i] != cur_stride:
arg_is_pythran_compatible = False
break
cur_stride *= dim
cur_stride *= shape[i]
else:
arg_is_pythran_compatible = not (arg.flags.f_contiguous and arg.ndim > 1)
""")
pyx_code.indent(2)
pyx_code.named_insertion_point("numpy_dtype_checks")
self._buffer_check_numpy_dtype(pyx_code, buffer_types, pythran_types)
pyx_code.dedent(2)
......@@ -464,7 +470,7 @@ class FusedCFuncDefNode(StatListNode):
self._buffer_parse_format_string_check(
pyx_code, decl_code, specialized_type, env)
def _buffer_declarations(self, pyx_code, decl_code, all_buffer_types):
def _buffer_declarations(self, pyx_code, decl_code, all_buffer_types, pythran_types):
"""
If we have any buffer specializations, write out some variable
declarations and imports.
......@@ -484,9 +490,13 @@ class FusedCFuncDefNode(StatListNode):
cdef Py_ssize_t itemsize
cdef bint dtype_signed
cdef char kind
cdef bint arg_is_pythran_compatible
itemsize = -1
""")
if pythran_types:
pyx_code.local_variable_declarations.put_chunk(u"""
cdef bint arg_is_pythran_compatible
arg_is_pythran_compatible = False
""")
......@@ -670,7 +680,7 @@ class FusedCFuncDefNode(StatListNode):
default_idx += 1
if all_buffer_types:
self._buffer_declarations(pyx_code, decl_code, all_buffer_types)
self._buffer_declarations(pyx_code, decl_code, all_buffer_types, pythran_types)
env.use_utility_code(Code.UtilityCode.load_cached("Import", "ImportExport.c"))
env.use_utility_code(Code.UtilityCode.load_cached("ImportNumPyArray", "ImportExport.c"))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment