Commit b8215f4e authored by scoder's avatar scoder Committed by GitHub

Merge pull request #1994 from scoder/optimised_memslice_indexing

Rewrite memoryview[i][j] and similar obvious cases into the faster memoryview[i, j]
parents f68639da 7601ba9a
...@@ -36,6 +36,9 @@ Features added ...@@ -36,6 +36,9 @@ Features added
* Some PEP-484/526 container type declarations are now considered for * Some PEP-484/526 container type declarations are now considered for
loop optimisations. loop optimisations.
* Indexing into memoryview slices with ``view[i][j]`` is now optimised into
``view[i, j]``.
* Python compatible ``cython.*`` types can now be mixed with type declarations * Python compatible ``cython.*`` types can now be mixed with type declarations
in Cython syntax. in Cython syntax.
......
...@@ -3690,23 +3690,33 @@ class IndexNode(_IndexingBaseNode): ...@@ -3690,23 +3690,33 @@ class IndexNode(_IndexingBaseNode):
else: else:
indices = [self.index] indices = [self.index]
base_type = self.base.type base = self.base
base_type = base.type
replacement_node = None replacement_node = None
if base_type.is_memoryviewslice: if base_type.is_memoryviewslice:
# memoryviewslice indexing or slicing # memoryviewslice indexing or slicing
from . import MemoryView from . import MemoryView
if base.is_memview_slice:
# For memory views, "view[i][j]" is the same as "view[i, j]" => use the latter for speed.
merged_indices = base.merged_indices(indices)
if merged_indices is not None:
base = base.base
base_type = base.type
indices = merged_indices
have_slices, indices, newaxes = MemoryView.unellipsify(indices, base_type.ndim) have_slices, indices, newaxes = MemoryView.unellipsify(indices, base_type.ndim)
if have_slices: if have_slices:
replacement_node = MemoryViewSliceNode(self.pos, indices=indices, base=self.base) replacement_node = MemoryViewSliceNode(self.pos, indices=indices, base=base)
else: else:
replacement_node = MemoryViewIndexNode(self.pos, indices=indices, base=self.base) replacement_node = MemoryViewIndexNode(self.pos, indices=indices, base=base)
elif base_type.is_buffer or base_type.is_pythran_expr: elif base_type.is_buffer or base_type.is_pythran_expr:
if base_type.is_pythran_expr or len(indices) == base_type.ndim: if base_type.is_pythran_expr or len(indices) == base_type.ndim:
# Buffer indexing # Buffer indexing
is_buffer_access = True is_buffer_access = True
indices = [index.analyse_types(env) for index in indices] indices = [index.analyse_types(env) for index in indices]
if base_type.is_pythran_expr: if base_type.is_pythran_expr:
do_replacement = all(index.type.is_int or index.is_slice or index.type.is_pythran_expr for index in indices) do_replacement = all(
index.type.is_int or index.is_slice or index.type.is_pythran_expr
for index in indices)
if do_replacement: if do_replacement:
for i,index in enumerate(indices): for i,index in enumerate(indices):
if index.is_slice: if index.is_slice:
...@@ -3716,7 +3726,7 @@ class IndexNode(_IndexingBaseNode): ...@@ -3716,7 +3726,7 @@ class IndexNode(_IndexingBaseNode):
else: else:
do_replacement = all(index.type.is_int for index in indices) do_replacement = all(index.type.is_int for index in indices)
if do_replacement: if do_replacement:
replacement_node = BufferIndexNode(self.pos, indices=indices, base=self.base) replacement_node = BufferIndexNode(self.pos, indices=indices, base=base)
# On cloning, indices is cloned. Otherwise, unpack index into indices. # On cloning, indices is cloned. Otherwise, unpack index into indices.
assert not isinstance(self.index, CloneNode) assert not isinstance(self.index, CloneNode)
...@@ -4425,6 +4435,37 @@ class MemoryViewSliceNode(MemoryViewIndexNode): ...@@ -4425,6 +4435,37 @@ class MemoryViewSliceNode(MemoryViewIndexNode):
else: else:
return MemoryCopySlice(self.pos, self) return MemoryCopySlice(self.pos, self)
def merged_indices(self, indices):
"""Return a new list of indices/slices with 'indices' merged into the current ones
according to slicing rules.
Is used to implement "view[i][j]" => "view[i, j]".
Return None if the indices cannot (easily) be merged at compile time.
"""
if not indices:
return None
# NOTE: Need to evaluate "self.original_indices" here as they might differ from "self.indices".
new_indices = self.original_indices[:]
indices = indices[:]
for i, s in enumerate(self.original_indices):
if s.is_slice:
if s.start.is_none and s.stop.is_none and s.step.is_none:
# Full slice found, replace by index.
new_indices[i] = indices[0]
indices.pop(0)
if not indices:
return new_indices
else:
# Found something non-trivial, e.g. a partial slice.
return None
elif not s.type.is_int:
# Not a slice, not an integer index => could be anything...
return None
if indices:
if len(new_indices) + len(indices) > self.base.type.ndim:
return None
new_indices += indices
return new_indices
def is_simple(self): def is_simple(self):
if self.is_ellipsis_noop: if self.is_ellipsis_noop:
# TODO: fix SimpleCallNode.is_simple() # TODO: fix SimpleCallNode.is_simple()
......
...@@ -1039,3 +1039,47 @@ def min_max_tree_restructuring(): ...@@ -1039,3 +1039,47 @@ def min_max_tree_restructuring():
cdef char[:] aview = a cdef char[:] aview = a
return max(<char>1, aview[0]), min(<char>5, aview[2]) return max(<char>1, aview[0]), min(<char>5, aview[2])
@cython.test_fail_if_path_exists(
'//MemoryViewSliceNode',
)
@cython.test_assert_path_exists(
'//MemoryViewIndexNode',
)
#@cython.boundscheck(False) # reduce C code clutter
def optimised_index_of_slice(int[:,:,:] arr, int x, int y, int z):
"""
>>> arr = IntMockBuffer("A", list(range(10*10*10)), shape=(10,10,10))
>>> optimised_index_of_slice(arr, 2, 3, 4)
acquired A
(123, 123)
(223, 223)
(133, 133)
(124, 124)
(234, 234)
(123, 123)
(123, 123)
(123, 123)
(134, 134)
(134, 134)
(234, 234)
(234, 234)
(234, 234)
released A
"""
print(arr[1, 2, 3], arr[1][2][3])
print(arr[x, 2, 3], arr[x][2][3])
print(arr[1, y, 3], arr[1][y][3])
print(arr[1, 2, z], arr[1][2][z])
print(arr[x, y, z], arr[x][y][z])
print(arr[1, 2, 3], arr[:, 2][1][3])
print(arr[1, 2, 3], arr[:, 2, :][1, 3])
print(arr[1, 2, 3], arr[:, 2, 3][1])
print(arr[1, y, z], arr[1, :][y][z])
print(arr[1, y, z], arr[1, :][y, z])
print(arr[x, y, z], arr[x][:][:][y][:][:][z])
print(arr[x, y, z], arr[:][x][:][y][:][:][z])
print(arr[x, y, z], arr[:, :][x][:, :][y][:][z])
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment