Commit 7bd68089 authored by Gerald Dalley's avatar Gerald Dalley

Added PyDataType_SHAPE

parent 972aeb02
...@@ -176,7 +176,8 @@ cdef extern from "numpy/arrayobject.h": ...@@ -176,7 +176,8 @@ cdef extern from "numpy/arrayobject.h":
cdef object fields cdef object fields
cdef tuple names cdef tuple names
# Use PyDataType_HASSUBARRAY to test whether this field is # Use PyDataType_HASSUBARRAY to test whether this field is
# valid (the pointer can be NULL). # valid (the pointer can be NULL). Most users should access
# this field via the inline helper method PyDataType_SHAPE.
cdef PyArray_ArrayDescr* subarray cdef PyArray_ArrayDescr* subarray
ctypedef extern class numpy.flatiter [object PyArrayIterObject]: ctypedef extern class numpy.flatiter [object PyArrayIterObject]:
...@@ -798,6 +799,12 @@ cdef inline object PyArray_MultiIterNew4(a, b, c, d): ...@@ -798,6 +799,12 @@ cdef inline object PyArray_MultiIterNew4(a, b, c, d):
cdef inline object PyArray_MultiIterNew5(a, b, c, d, e): cdef inline object PyArray_MultiIterNew5(a, b, c, d, e):
return PyArray_MultiIterNew(5, <void*>a, <void*>b, <void*>c, <void*> d, <void*> e) return PyArray_MultiIterNew(5, <void*>a, <void*>b, <void*>c, <void*> d, <void*> e)
cdef inline tuple PyDataType_SHAPE(dtype d):
if PyDataType_HASSUBARRAY(d):
return <tuple>d.subarray.shape
else:
return None
cdef inline char* _util_dtypestring(dtype descr, char* f, char* end, int* offset) except NULL: cdef inline char* _util_dtypestring(dtype descr, char* f, char* end, int* offset) except NULL:
# Recursive utility function used in __getbuffer__ to get format # Recursive utility function used in __getbuffer__ to get format
# string. The new location in the format string is returned. # string. The new location in the format string is returned.
......
...@@ -19,10 +19,19 @@ def test_record_subarray(): ...@@ -19,10 +19,19 @@ def test_record_subarray():
cdef np.dtype a_descr = descr.fields['a'][0] cdef np.dtype a_descr = descr.fields['a'][0]
cdef np.dtype b_descr = descr.fields['b'][0] cdef np.dtype b_descr = descr.fields['b'][0]
# Make sure the dtype looks like we expect
assert descr.fields == {'a': (py_numpy.dtype('int32'), 0), assert descr.fields == {'a': (py_numpy.dtype('int32'), 0),
'b': (py_numpy.dtype(('<f8', (3, 3))), 4)} 'b': (py_numpy.dtype(('<f8', (3, 3))), 4)}
# Make sure that HASSUBARRAY is working
assert not np.PyDataType_HASSUBARRAY(descr) assert not np.PyDataType_HASSUBARRAY(descr)
assert not np.PyDataType_HASSUBARRAY(a_descr) assert not np.PyDataType_HASSUBARRAY(a_descr)
assert np.PyDataType_HASSUBARRAY(b_descr) assert np.PyDataType_HASSUBARRAY(b_descr)
# Make sure the direct field access works
assert <tuple>b_descr.subarray.shape == (3, 3) assert <tuple>b_descr.subarray.shape == (3, 3)
# Make sure the safe high-level helper function works
assert np.PyDataType_SHAPE(descr) is None
assert np.PyDataType_SHAPE(a_descr) is None
assert np.PyDataType_SHAPE(b_descr) == (3, 3)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment