Commit 75c761cf authored by Mark Florisson's avatar Mark Florisson

Support newaxis indexing for memoryview slices

    todo: support memoryview object newaxis indexing
parent 0edf8d2a
...@@ -2523,26 +2523,32 @@ class IndexNode(ExprNode): ...@@ -2523,26 +2523,32 @@ class IndexNode(ExprNode):
import MemoryView import MemoryView
skip_child_analysis = True skip_child_analysis = True
newaxes = [newaxis for newaxis in indices if newaxis.is_none]
have_slices, indices = MemoryView.unellipsify(indices, have_slices, indices = MemoryView.unellipsify(indices,
newaxes,
self.base.type.ndim) self.base.type.ndim)
self.memslice_index = len(indices) == self.base.type.ndim
self.memslice_index = (not newaxes and
len(indices) == self.base.type.ndim)
axes = [] axes = []
index_type = PyrexTypes.c_py_ssize_t_type index_type = PyrexTypes.c_py_ssize_t_type
new_indices = [] new_indices = []
if len(indices) > self.base.type.ndim: if len(indices) - len(newaxes) > self.base.type.ndim:
self.type = error_type self.type = error_type
return error(indices[self.base.type.ndim].pos, return error(indices[self.base.type.ndim].pos,
"Too many indices specified for type %s" % "Too many indices specified for type %s" %
self.base.type) self.base.type)
suboffsets_dim = -1 axis_idx = 0
for i, index in enumerate(indices[:]): for i, index in enumerate(indices[:]):
index.analyse_types(env) index.analyse_types(env)
access, packing = self.base.type.axes[i] if not index.is_none:
access, packing = self.base.type.axes[axis_idx]
axis_idx += 1
if isinstance(index, SliceNode): if isinstance(index, SliceNode):
suboffsets_dim = i
self.memslice_slice = True self.memslice_slice = True
if index.step.is_none: if index.step.is_none:
axes.append((access, packing)) axes.append((access, packing))
...@@ -2558,6 +2564,11 @@ class IndexNode(ExprNode): ...@@ -2558,6 +2564,11 @@ class IndexNode(ExprNode):
setattr(index, attr, value) setattr(index, attr, value)
new_indices.append(value) new_indices.append(value)
elif index.is_none:
self.memslice_slice = True
new_indices.append(index)
axes.append(('direct', 'strided'))
elif index.type.is_int or index.type.is_pyobject: elif index.type.is_int or index.type.is_pyobject:
if index.type.is_pyobject and not self.warned_untyped_idx: if index.type.is_pyobject and not self.warned_untyped_idx:
warning(index.pos, "Index should be typed for more " warning(index.pos, "Index should be typed for more "
......
...@@ -277,8 +277,8 @@ class MemoryViewSliceBufferEntry(Buffer.BufferEntry): ...@@ -277,8 +277,8 @@ class MemoryViewSliceBufferEntry(Buffer.BufferEntry):
""" """
Slice a memoryviewslice. Slice a memoryviewslice.
indices - list of index nodes. If not a SliceNode, then it must be indices - list of index nodes. If not a SliceNode, or NoneNode,
coercible to Py_ssize_t then it must be coercible to Py_ssize_t
Simply call __pyx_memoryview_slice_memviewslice with the right Simply call __pyx_memoryview_slice_memviewslice with the right
arguments. arguments.
...@@ -307,28 +307,14 @@ class MemoryViewSliceBufferEntry(Buffer.BufferEntry): ...@@ -307,28 +307,14 @@ class MemoryViewSliceBufferEntry(Buffer.BufferEntry):
code.putln("%(dst)s.memview = %(src)s.memview;" % locals()) code.putln("%(dst)s.memview = %(src)s.memview;" % locals())
code.put_incref_memoryviewslice(dst) code.put_incref_memoryviewslice(dst)
for dim, index in enumerate(indices): dim = -1
for index in indices:
error_goto = code.error_goto(index.pos) error_goto = code.error_goto(index.pos)
if not index.is_none:
if not isinstance(index, ExprNodes.SliceNode): dim += 1
# normal index
idx = index.result()
access, packing = self.type.axes[dim] access, packing = self.type.axes[dim]
if access == 'direct':
indirect = False
else:
indirect = True
generic = (access == 'full')
if new_ndim != 0:
return error(index.pos,
"All preceding dimensions must be "
"indexed and not sliced")
d = locals()
code.put(load_slice_util("SliceIndex", d))
else:
if isinstance(index, ExprNodes.SliceNode):
# slice, unspecified dimension, or part of ellipsis # slice, unspecified dimension, or part of ellipsis
d = locals() d = locals()
for s in "start stop step".split(): for s in "start stop step".split():
...@@ -344,7 +330,6 @@ class MemoryViewSliceBufferEntry(Buffer.BufferEntry): ...@@ -344,7 +330,6 @@ class MemoryViewSliceBufferEntry(Buffer.BufferEntry):
not d['have_step']): not d['have_step']):
# full slice (:), simply copy over the extent, stride # full slice (:), simply copy over the extent, stride
# and suboffset. Also update suboffset_dim if needed # and suboffset. Also update suboffset_dim if needed
access, packing = self.type.axes[dim]
d['access'] = access d['access'] = access
code.put(load_slice_util("SimpleSlice", d)) code.put(load_slice_util("SimpleSlice", d))
else: else:
...@@ -352,6 +337,31 @@ class MemoryViewSliceBufferEntry(Buffer.BufferEntry): ...@@ -352,6 +337,31 @@ class MemoryViewSliceBufferEntry(Buffer.BufferEntry):
new_ndim += 1 new_ndim += 1
elif index.is_none:
# newaxis
attribs = [('shape', 1), ('strides', 0), ('suboffsets', -1)]
for attrib, value in attribs:
code.putln("%s.%s[%d] = %d;" % (dst, attrib, new_ndim, value))
new_ndim += 1
else:
# normal index
idx = index.result()
if access == 'direct':
indirect = False
else:
indirect = True
generic = (access == 'full')
if new_ndim != 0:
return error(index.pos,
"All preceding dimensions must be "
"indexed and not sliced")
d = locals()
code.put(load_slice_util("SliceIndex", d))
if not no_suboffset_dim: if not no_suboffset_dim:
code.funcstate.release_temp(suboffset_dim) code.funcstate.release_temp(suboffset_dim)
...@@ -361,11 +371,13 @@ def empty_slice(pos): ...@@ -361,11 +371,13 @@ def empty_slice(pos):
return ExprNodes.SliceNode(pos, start=none, return ExprNodes.SliceNode(pos, start=none,
stop=none, step=none) stop=none, step=none)
def unellipsify(indices, ndim): def unellipsify(indices, newaxes, ndim):
result = [] result = []
seen_ellipsis = False seen_ellipsis = False
have_slices = False have_slices = False
n_indices = len(indices) - len(newaxes)
for index in indices: for index in indices:
if isinstance(index, ExprNodes.EllipsisNode): if isinstance(index, ExprNodes.EllipsisNode):
have_slices = True have_slices = True
...@@ -374,16 +386,19 @@ def unellipsify(indices, ndim): ...@@ -374,16 +386,19 @@ def unellipsify(indices, ndim):
if seen_ellipsis: if seen_ellipsis:
result.append(full_slice) result.append(full_slice)
else: else:
nslices = ndim - len(indices) + 1 nslices = ndim - n_indices + 1
result.extend([full_slice] * nslices) result.extend([full_slice] * nslices)
seen_ellipsis = True seen_ellipsis = True
else: else:
have_slices = have_slices or isinstance(index, ExprNodes.SliceNode) have_slices = (have_slices or
isinstance(index, ExprNodes.SliceNode) or
index.is_none)
result.append(index) result.append(index)
if len(result) < ndim: result_length = len(result) - len(newaxes)
if result_length < ndim:
have_slices = True have_slices = True
nslices = ndim - len(result) nslices = ndim - result_length
result.extend([empty_slice(indices[-1].pos)] * nslices) result.extend([empty_slice(indices[-1].pos)] * nslices)
return have_slices, result return have_slices, result
......
...@@ -713,6 +713,11 @@ cdef memoryview memview_slice(memoryview memview, object indices): ...@@ -713,6 +713,11 @@ cdef memoryview memview_slice(memoryview memview, object indices):
index, 0, 0, # start, stop, step index, 0, 0, # start, stop, step
0, 0, 0, # have_{start,stop,step} 0, 0, 0, # have_{start,stop,step}
False) False)
elif index is None:
p_dst.shape[new_ndim] = 1
p_dst.strides[new_ndim] = 0
p_dst.suboffsets[new_ndim] = -1
new_ndim += 1
else: else:
start = index.start or 0 start = index.start or 0
stop = index.stop or 0 stop = index.stop or 0
......
...@@ -49,6 +49,14 @@ def testcase(func): ...@@ -49,6 +49,14 @@ def testcase(func):
include "mockbuffers.pxi" include "mockbuffers.pxi"
include "cythonarrayutil.pxi" include "cythonarrayutil.pxi"
def _print_attributes(memview):
print "shape: " + " ".join(map(str, memview.shape))
print "strides: " + " ".join([str(stride // memview.itemsize)
for stride in memview.strides])
print "suboffsets: " + " ".join(
[str(suboffset if suboffset < 0 else suboffset // memview.itemsize)
for suboffset in memview.suboffsets])
# #
# Buffer acquire and release tests # Buffer acquire and release tests
# #
...@@ -2217,3 +2225,63 @@ def test_inplace_assignment(): ...@@ -2217,3 +2225,63 @@ def test_inplace_assignment():
m[0] = get_int() m[0] = get_int()
print m[0] print m[0]
@testcase
def test_newaxis(int[:] one_D):
"""
>>> A = IntMockBuffer("A", range(6))
>>> test_newaxis(A)
acquired A
3
3
3
3
released A
"""
cdef int[:, :] two_D_1 = one_D[None]
cdef int[:, :] two_D_2 = one_D[None, :]
cdef int[:, :] two_D_3 = one_D[:, None]
cdef int[:, :] two_D_4 = one_D[..., None]
print two_D_1[0, 3]
print two_D_2[0, 3]
print two_D_3[3, 0]
print two_D_4[3, 0]
@testcase
def test_newaxis2(int[:, :] two_D):
"""
>>> A = IntMockBuffer("A", range(6), shape=(3, 2))
>>> test_newaxis2(A)
acquired A
shape: 3 1 1
strides: 2 0 0
suboffsets: -1 -1 -1
<BLANKLINE>
shape: 1 2 1
strides: 0 1 0
suboffsets: -1 -1 -1
<BLANKLINE>
shape: 3 1 1 1
strides: 2 0 1 0
suboffsets: -1 -1 -1 -1
<BLANKLINE>
shape: 1 2 2 1
strides: 0 2 1 0
suboffsets: -1 -1 -1 -1
released A
"""
cdef int[:, :, :] a = two_D[..., None, 1, None]
cdef int[:, :, :] b = two_D[None, 1, ..., None]
cdef int[:, :, :, :] c = two_D[..., None, 1:, None]
cdef int[:, :, :, :] d = two_D[None, 1:, ..., None]
_print_attributes(a)
print
_print_attributes(b)
print
_print_attributes(c)
print
_print_attributes(d)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment