Commit b4a3a0bd authored by Kirill Smelkov's avatar Kirill Smelkov

.

parent c4b66f7d
...@@ -18,19 +18,22 @@ ...@@ -18,19 +18,22 @@
# See COPYING file for full licensing terms. # See COPYING file for full licensing terms.
# See https://www.nexedi.com/licensing for rationale and options. # See https://www.nexedi.com/licensing for rationale and options.
from numpy import arange, dtype, int32 from numpy import ndarray, arange, dtype, int32
from numpy.lib.stride_tricks import DummyArray
from wendelin.lib.xnumpy import restructure from wendelin.lib.xnumpy import restructure
from pytest import raises from pytest import raises
# XXX text # xbase returns original object from which arr was _as_strided viewed.
def test_restructure(): def xbase(arr):
# xbase returns original object from which arr was restructured. b = arr.base # arr -> typed view | DummyArray
def xbase(arr): if type(b) is not DummyArray:
b = arr.base # arr -> typed view b = b.base # it was typed view -> DummyArray
b = b.base # -> DummyArray assert type(b) is DummyArray
b = b.base # -> origin b = b.base # -> origin
return b return b
# XXX text
def test_restructure():
dtxy = dtype([('x', int32), ('y', int32)]) dtxy = dtype([('x', int32), ('y', int32)])
# C order # C order
...@@ -119,3 +122,22 @@ def test_restructure(): ...@@ -119,3 +122,22 @@ def test_restructure():
assert bxy[2]['y'] == 200 assert bxy[2]['y'] == 200
assert b[1,2] == 200 assert b[1,2] == 200
assert a[1,2] == 200 assert a[1,2] == 200
# custom class
class MyArray(ndarray):
pass
a = arange(4*3, dtype=int32).reshape((4,3))
# 0 1 2
# 3 4 5
# 6 7 8
# 9 10 11
a = a.view(type=MyArray)
b = a[:3,:2]
bxy = restructure(b, dtxy)
assert xbase(bxy) is b
assert bxy.dtype == dtxy
assert bxy.shape == (3,)
assert type(bxy) is MyArray
...@@ -29,10 +29,10 @@ from numpy.lib import stride_tricks as npst ...@@ -29,10 +29,10 @@ from numpy.lib import stride_tricks as npst
# It must be used with extreme care, because if there is math error in the # It must be used with extreme care, because if there is math error in the
# arguments, the resulting array can cover wrong memory. Bugs here thus can # arguments, the resulting array can cover wrong memory. Bugs here thus can
# lead to mysterious crashes. # lead to mysterious crashes.
def _as_strided(a, shape, stridev, dtype): def _as_strided(arr, shape, stridev, dtype):
# the code below is very close to # the code below is very close to
# #
# a = stride_tricks.as_strided(a, shape=shape, strides=stridev) # a = stride_tricks.as_strided(arr, shape=shape, strides=stridev)
# #
# but we don't use as_strided() because we also have to change dtype # but we don't use as_strided() because we also have to change dtype
# with shape and strides in one go - else changing dtype after either # with shape and strides in one go - else changing dtype after either
...@@ -41,19 +41,21 @@ def _as_strided(a, shape, stridev, dtype): ...@@ -41,19 +41,21 @@ def _as_strided(a, shape, stridev, dtype):
# "When changing to a larger dtype, its size must be a # "When changing to a larger dtype, its size must be a
# divisor of the total size in bytes of the last axis # divisor of the total size in bytes of the last axis
# of the array." # of the array."
aiface = dict(a.__array_interface__) aiface = dict(arr.__array_interface__)
aiface['shape'] = shape aiface['shape'] = shape
aiface['strides'] = stridev aiface['strides'] = stridev
# type: for now we only care that itemsize is the same # type: for now we only care that itemsize is the same
aiface['typestr'] = '|V%d' % dtype.itemsize aiface['typestr'] = '|V%d' % dtype.itemsize
aiface['descr'] = [('', aiface['typestr'])] aiface['descr'] = [('', aiface['typestr'])]
a = np.asarray(npst.DummyArray(aiface, base=a)) a = np.asarray(npst.DummyArray(aiface, base=arr))
# restore full dtype - it should not raise here, since itemsize is the same # restore full dtype - it should not raise here, since itemsize is the same
a.dtype = dtype a.dtype = dtype
# XXX restore full array type? # restore full array type (mimics subok=True)
if type(a) is not type(arr):
a = a.view(type=type(arr))
# we are done # we are done
return a return a
...@@ -122,10 +124,4 @@ def restructure(arr, dtype): ...@@ -122,10 +124,4 @@ def restructure(arr, dtype):
# NOTE we cannot use just np.ndarray because if arr is a slice it can give: # NOTE we cannot use just np.ndarray because if arr is a slice it can give:
# TypeError: expected a single-segment buffer object # TypeError: expected a single-segment buffer object
#return np.ndarray.__new__(type(arr), shape, dtype, buffer(arr), 0, stridev) #return np.ndarray.__new__(type(arr), shape, dtype, buffer(arr), 0, stridev)
a = _as_strided(arr, shape, stridev, dtype) return _as_strided(arr, shape, stridev, dtype)
# restore full array type
a = a.view(type=type(arr))
# we are done
return a
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment