Commit f12d22b6 authored by Dag Sverre Seljebotn

Buffers: NumPy record array support, format string parsing improvements

parent b3004f5d
......@@ -562,14 +562,31 @@ def get_ts_check_item(dtype, writer):
return name
def get_typestringchecker(code, dtype):
    """
    Look up the typestring checker belonging to ``dtype`` and return its
    C name.

    The checker is requested from ``code.globalstate`` via
    ``use_code_from`` with ``create_typestringchecker`` as the generator,
    so it is emitted if needed.
    """
    checker_name = "__Pyx_CheckTypestring_%s" % mangle_dtype_name(dtype)
    code.globalstate.use_code_from(
        create_typestringchecker, checker_name, dtype=dtype)
    return checker_name
def create_typestringchecker(protocode, defcode, name, dtype):
def put_assert(cond, msg):
defcode.putln("if (!(%s)) {" % cond)
msg += ", got '%s'"
defcode.putln('PyErr_Format(PyExc_ValueError, "%s", ts);' % msg)
defcode.putln("return NULL;")
defcode.putln("}")
if dtype.is_error: return
simple = dtype.is_int or dtype.is_float or dtype.is_pyobject or dtype.is_extension_type or dtype.is_ptr
simple = dtype.is_simple_buffer_dtype()
complex_possible = dtype.is_struct_or_union and dtype.can_be_complex()
# Cannot add utility code recursively...
if simple:
itemchecker = get_ts_check_item(dtype, protocode)
else:
if not simple:
dtype_t = dtype.declaration_code("")
protocode.globalstate.use_utility_code(parse_typestring_repeat_code)
fields = dtype.scope.var_entries
......@@ -580,18 +597,58 @@ def create_typestringchecker(protocode, defcode, name, dtype):
prevtype = None
for f in fields:
if n and f.type != prevtype:
field_blocks.append((n, prevtype, get_ts_check_item(prevtype, protocode)))
field_blocks.append((n, prevtype, get_typestringchecker(protocode, prevtype)))
n = 0
prevtype = f.type
n += 1
field_blocks.append((n, f.type, get_ts_check_item(f.type, protocode)))
field_blocks.append((n, f.type, get_typestringchecker(protocode, f.type)))
protocode.putln("static const char* %s(const char* ts); /*proto*/" % name)
defcode.putln("static const char* %s(const char* ts) {" % name)
if simple:
defcode.putln("int ok;")
defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
defcode.putln("if (*ts == '1') ++ts;")
defcode.putln("ts = %s(ts); if (!ts) return NULL;" % itemchecker)
if dtype.typestring is not None:
assert len(dtype.typestring) == 1
# Can use direct comparison
defcode.putln("ok = (*ts == '%s');" % dtype.typestring)
else:
# Cannot trust declared size; but rely on int vs float and
# signed/unsigned to be correctly declared. Use a switch statement
# on all possible format codes to validate that the size is ok.
# (Note that many codes may map to same size, e.g. 'i' and 'l'
# may both be four bytes).
ctype = dtype.declaration_code("")
defcode.putln("switch (*ts) {")
if dtype.is_int:
types = [
('b', 'char'), ('h', 'short'), ('i', 'int'),
('l', 'long'), ('q', 'long long')
]
elif dtype.is_float:
types = [('f', 'float'), ('d', 'double'), ('g', 'long double')]
else:
assert False
if dtype.signed == 0:
for char, against in types:
defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(unsigned %s) && (%s)-1 > 0); break;" %
(char.upper(), ctype, against, ctype))
else:
for char, against in types:
defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" %
(char, ctype, against, ctype))
defcode.putln("default: ok = 0;")
defcode.putln("}")
defcode.putln("if (!ok) {")
if dtype.typestring is not None:
errmsg = "Buffer datatype mismatch (expected '%s', got '%%s')" % dtype.typestring
else:
errmsg = "Buffer datatype mismatch (rejecting on '%s')"
defcode.putln('PyErr_Format(PyExc_ValueError, "%s", ts);' % errmsg)
defcode.putln("return NULL;");
defcode.putln("}")
defcode.putln("++ts;")
elif complex_possible:
# Could be a struct representing a complex number, so allow
# for parsing a "Zf" spec.
......@@ -623,15 +680,25 @@ def create_typestringchecker(protocode, defcode, name, dtype):
else:
defcode.putln("int n, count;")
defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
for n, type, checker in field_blocks:
if n == 1:
defcode.putln("if (*ts == '1') ++ts;")
defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
else:
defcode.putln("n = %d;" % n);
defcode.putln("do {")
defcode.putln("ts = __Pyx_ParseTypestringRepeat(ts, &count); n -= count;")
defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
simple = type.is_simple_buffer_dtype()
if not simple:
put_assert("*ts == 'T' && *(ts+1) == '{'", "Expected start of %s" % type.declaration_code("", for_display=True))
defcode.putln("ts += 2;")
defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
if not simple:
put_assert("*ts == '}'", "Expected end of '%s'" % type.declaration_code("", for_display=True))
defcode.putln("++ts;")
if n > 1:
defcode.putln("} while (n > 0);");
defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
......@@ -651,11 +718,7 @@ def get_getbuffer_code(dtype, code):
name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype)
if not code.globalstate.has_code(name):
code.globalstate.use_utility_code(acquire_utility_code)
typestringchecker = "__Pyx_CheckTypestring_%s" % mangle_dtype_name(dtype)
code.globalstate.use_code_from(create_typestringchecker,
typestringchecker,
dtype=dtype)
typestringchecker = get_typestringchecker(code, dtype)
dtype_name = str(dtype)
dtype_cname = dtype.declaration_code("")
utilcode = [dedent("""
......
......@@ -140,6 +140,10 @@ class PyrexType(BaseType):
# a struct whose attributes are not defined, etc.
return 1
def is_simple_buffer_dtype(self):
    # A dtype is "simple" when any one of the int / float / Python object /
    # extension type / pointer flags is set. The flags are consulted in
    # that order and the first truthy one (or the last flag) is returned.
    simple = self.is_int or self.is_float
    simple = simple or self.is_pyobject
    simple = simple or self.is_extension_type
    return simple or self.is_ptr
class CTypedefType(BaseType):
#
# Pseudo-type defined with a ctypedef statement in a
......
cimport python_buffer as pybuf
cimport stdlib
cdef extern from "Python.h":
ctypedef int Py_intptr_t
......@@ -26,6 +27,11 @@ cdef extern from "numpy/arrayobject.h":
NPY_C_CONTIGUOUS,
NPY_F_CONTIGUOUS
# Extension-type declaration for NumPy's dtype object, backed by the C
# struct PyArray_Descr. Only the members listed here are exposed.
ctypedef class numpy.dtype [object PyArray_Descr]:
    # Type code; ndarray.__getbuffer__ compares it against the NPY_*
    # constants to choose a format character.
    cdef int type_num
    # For dtypes with fields, ndarray.__getbuffer__ iterates this with
    # .iteritems() and takes item [0] of each value as the field's dtype.
    cdef object fields
    # Not referenced in the visible part of this file.
    cdef object names
ctypedef class numpy.ndarray [object PyArrayObject]:
cdef __cythonbufferdefaults__ = {"mode": "strided"}
......@@ -36,6 +42,7 @@ cdef extern from "numpy/arrayobject.h":
npy_intp *shape "dimensions"
npy_intp *strides
int flags
dtype descr
# Note: This syntax (function definition in pxd files) is an
# experimental exception made for __getbuffer__ and __releasebuffer__
......@@ -57,7 +64,6 @@ cdef extern from "numpy/arrayobject.h":
raise ValueError("ndarray is not Fortran contiguous")
info.buf = PyArray_DATA(self)
# info.obj = None # this is automatic
info.ndim = PyArray_NDIM(self)
info.strides = <Py_ssize_t*>PyArray_STRIDES(self)
info.shape = <Py_ssize_t*>PyArray_DIMS(self)
......@@ -65,31 +71,104 @@ cdef extern from "numpy/arrayobject.h":
info.itemsize = PyArray_ITEMSIZE(self)
info.readonly = not PyArray_ISWRITEABLE(self)
# Formats that are not tested and working in Cython are not
# made available from this pxd file yet.
cdef int t = PyArray_TYPE(self)
cdef char* f = NULL
if t == NPY_BYTE: f = "b"
elif t == NPY_UBYTE: f = "B"
elif t == NPY_SHORT: f = "h"
elif t == NPY_USHORT: f = "H"
elif t == NPY_INT: f = "i"
elif t == NPY_UINT: f = "I"
elif t == NPY_LONG: f = "l"
elif t == NPY_ULONG: f = "L"
elif t == NPY_LONGLONG: f = "q"
elif t == NPY_ULONGLONG: f = "Q"
elif t == NPY_FLOAT: f = "f"
elif t == NPY_DOUBLE: f = "d"
elif t == NPY_LONGDOUBLE: f = "g"
elif t == NPY_CFLOAT: f = "Zf"
elif t == NPY_CDOUBLE: f = "Zd"
elif t == NPY_CLONGDOUBLE: f = "Zg"
elif t == NPY_OBJECT: f = "O"
if f == NULL:
raise ValueError("only objects, int and float dtypes supported for ndarray buffer access so far (dtype is %d)" % t)
info.format = f
cdef int t
cdef char* f = NULL
cdef dtype descr = self.descr
cdef list stack
cdef bint hasfields = PyDataType_HASFIELDS(descr)
# Ugly hack warning:
# Cython currently will not support helper functions in
# pxd files -- so we must keep our own, manual stack!
# In addition, avoid allocation of the stack in the common
# case that we are dealing with a single non-nested datatype...
# (this would look much prettier if we could use utility
# functions).
if not hasfields:
info.obj = None # do not call releasebuffer
t = descr.type_num
if t == NPY_BYTE: f = "b"
elif t == NPY_UBYTE: f = "B"
elif t == NPY_SHORT: f = "h"
elif t == NPY_USHORT: f = "H"
elif t == NPY_INT: f = "i"
elif t == NPY_UINT: f = "I"
elif t == NPY_LONG: f = "l"
elif t == NPY_ULONG: f = "L"
elif t == NPY_LONGLONG: f = "q"
elif t == NPY_ULONGLONG: f = "Q"
elif t == NPY_FLOAT: f = "f"
elif t == NPY_DOUBLE: f = "d"
elif t == NPY_LONGDOUBLE: f = "g"
elif t == NPY_CFLOAT: f = "Zf"
elif t == NPY_CDOUBLE: f = "Zd"
elif t == NPY_CLONGDOUBLE: f = "Zg"
elif t == NPY_OBJECT: f = "O"
else:
raise ValueError("unknown dtype code in numpy.pxd (%d)" % t)
info.format = f
return
else:
info.obj = self # need to call releasebuffer
info.format = <char*>stdlib.malloc(255) # static size
f = info.format
stack = [iter(descr.fields.iteritems())]
while True:
iterator = stack[-1]
descr = None
while descr is None:
try:
descr = iterator.next()[1][0]
except StopIteration:
stack.pop()
if len(stack) > 0:
f[0] = "}"
f += 1
iterator = stack[-1]
else:
f[0] = 0 # Terminate string!
return
hasfields = PyDataType_HASFIELDS(descr)
if not hasfields:
t = descr.type_num
if f - info.format > 240: # this should leave room for "T{" and "}" as well
raise RuntimeError("Format string allocated too short.")
if t == NPY_BYTE: f[0] = "b"
elif t == NPY_UBYTE: f[0] = "B"
elif t == NPY_SHORT: f[0] = "h"
elif t == NPY_USHORT: f[0] = "H"
elif t == NPY_INT: f[0] = "i"
elif t == NPY_UINT: f[0] = "I"
elif t == NPY_LONG: f[0] = "l"
elif t == NPY_ULONG: f[0] = "L"
elif t == NPY_LONGLONG: f[0] = "q"
elif t == NPY_ULONGLONG: f[0] = "Q"
elif t == NPY_FLOAT: f[0] = "f"
elif t == NPY_DOUBLE: f[0] = "d"
elif t == NPY_LONGDOUBLE: f[0] = "g"
elif t == NPY_CFLOAT: f[0] = "Z"; f[1] = "f"; f += 1
elif t == NPY_CDOUBLE: f[0] = "Z"; f[1] = "d"; f += 1
elif t == NPY_CLONGDOUBLE: f[0] = "Z"; f[1] = "g"; f += 1
elif t == NPY_OBJECT: f[0] = "O"
else:
raise ValueError("unknown dtype code in numpy.pxd (%d)" % t)
f += 1
else:
f[0] = "T"
f[1] = "{"
f += 2
stack.append(iter(descr.fields.iteritems()))
def __releasebuffer__(ndarray self, Py_buffer* info):
    # Release the format string produced by __getbuffer__.
    #
    # __getbuffer__ heap-allocates info.format only when the dtype has
    # fields (record arrays). In every other case info.format points at a
    # string literal and info.obj is set to None so that this method is
    # not expected to be called at all. Re-check the dtype here instead of
    # relying on that alone, so that a literal is never passed to free().
    if PyDataType_HASFIELDS(self.descr):
        stdlib.free(info.format)
cdef void* PyArray_DATA(ndarray arr)
......@@ -100,6 +179,9 @@ cdef extern from "numpy/arrayobject.h":
cdef npy_intp PyArray_DIMS(ndarray arr)
cdef Py_ssize_t PyArray_ITEMSIZE(ndarray arr)
cdef int PyArray_CHKFLAGS(ndarray arr, int flags)
# PyArray_HASFIELDS takes only the array; it has no flags argument
# (unlike PyArray_CHKFLAGS above).
cdef int PyArray_HASFIELDS(ndarray arr)
cdef int PyDataType_HASFIELDS(dtype obj)
ctypedef signed int npy_byte
ctypedef signed int npy_short
......
......@@ -1292,6 +1292,15 @@ cdef struct MyStruct:
int d
int e
# Two-int struct; serves as the member type of NestedStruct.x and
# NestedStruct.y below.
cdef struct SmallStruct:
    int a
    int b

# Struct made of two nested SmallStruct members followed by a plain int.
# NestedStructMockBuffer describes it with the format b"2T{ii}i".
cdef struct NestedStruct:
    SmallStruct x
    SmallStruct y
    int z
cdef class MyStructMockBuffer(MockBuffer):
cdef int write(self, char* buf, object value) except -1:
cdef MyStruct* s
......@@ -1302,6 +1311,16 @@ cdef class MyStructMockBuffer(MockBuffer):
cdef get_itemsize(self): return sizeof(MyStruct)
cdef get_default_format(self): return b"2bq2i"
cdef class NestedStructMockBuffer(MockBuffer):
    # MockBuffer specialisation whose items are NestedStruct values.

    cdef int write(self, char* buf, object value) except -1:
        # Store the five values of `value` into the NestedStruct located
        # at `buf`, in the order x.a, x.b, y.a, y.b, z.
        cdef NestedStruct* item = <NestedStruct*>buf
        item.x.a, item.x.b, item.y.a, item.y.b, item.z = value
        return 0

    cdef get_itemsize(self):
        # Size in bytes of one NestedStruct item.
        return sizeof(NestedStruct)

    cdef get_default_format(self):
        # Format string used when the caller does not supply one.
        return b"2T{ii}i"
@testcase
def basic_struct(object[MyStruct] buf):
"""
......@@ -1316,6 +1335,21 @@ def basic_struct(object[MyStruct] buf):
"""
print buf[0].a, buf[0].b, buf[0].c, buf[0].d, buf[0].e
@testcase
def nested_struct(object[NestedStruct] buf):
    """
>>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)]))
1 2 3 4 5
>>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="T{ii}T{2i}i"))
1 2 3 4 5
>>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="iiiii"))
Traceback (most recent call last):
...
ValueError: Expected start of SmallStruct, got 'iiiii'
"""
    # Print all five int fields of the first item: both fields of x,
    # both fields of y, then z.
    print buf[0].x.a, buf[0].x.b, buf[0].y.a, buf[0].y.b, buf[0].z
cdef struct LongComplex:
long double real
long double imag
......
......@@ -129,12 +129,22 @@ try:
>>> test_dtype(np.int32, inc1_int32_t)
>>> test_dtype(np.float64, inc1_float64_t)
Unsupported types:
>>> a = np.zeros((10,), dtype=np.dtype('i4,i4'))
>>> inc1_byte(a)
>>> test_recordarray()
>>> test_nested_dtypes(np.zeros((3,), dtype=np.dtype([\
('a', np.dtype('i,i')),\
('b', np.dtype('i,i'))\
])))
array([((0, 0), (0, 0)), ((1, 2), (1, 4)), ((1, 2), (1, 4))],
dtype=[('a', [('f0', '<i4'), ('f1', '<i4')]), ('b', [('f0', '<i4'), ('f1', '<i4')])])
>>> test_nested_dtypes(np.zeros((3,), dtype=np.dtype([\
('a', np.dtype('i,f')),\
('b', np.dtype('i,i'))\
])))
Traceback (most recent call last):
...
ValueError: only objects, int and float dtypes supported for ndarray buffer access so far (dtype is 20)
...
ValueError: Buffer datatype mismatch (expected 'i', got 'f}T{ii}')
>>> test_good_cast()
True
......@@ -261,6 +271,49 @@ def test_dtype(dtype, inc1):
inc1(a)
if a[1] != 11: print "failed!"
# Struct of two ints; test_recordarray uses it as the buffer item type
# for an array created with dtype 'i,i'.
cdef struct DoubleInt:
    int x, y
def test_recordarray():
    # Access a two-item array of dtype 'i,i' through a DoubleInt-typed
    # buffer. Prints "failed" once for every check that does not hold,
    # and prints nothing when all checks pass.
    cdef object[DoubleInt] arr
    arr = np.array([(5,5), (4, 6)], dtype=np.dtype('i,i'))
    cdef DoubleInt rec
    # Read a whole item into a local struct.
    rec = arr[0]
    if rec.x != 5: print "failed"
    if rec.y != 5: print "failed"
    rec.y += 5
    # Write the modified local struct back as a whole item.
    arr[1] = rec
    # In-place updates of individual fields of a buffer item.
    arr[0].x -= 2
    arr[0].y += 3
    if arr[0].x != 3: print "failed"
    if arr[0].y != 8: print "failed"
    if arr[1].x != 5: print "failed"
    if arr[1].y != 10: print "failed"
# Two nested DoubleInt members; item type of the buffer in
# test_nested_dtypes.
cdef struct NestedStruct:
    DoubleInt a
    DoubleInt b

# Like DoubleInt, but the first member is a float instead of an int.
cdef struct BadDoubleInt:
    float x
    int y

# Nested struct whose second member is a BadDoubleInt; item type of the
# buffer declared in test_bad_nested_dtypes.
cdef struct BadNestedStruct:
    DoubleInt a
    BadDoubleInt b
def test_nested_dtypes(obj):
    # Acquire obj as a buffer of NestedStruct items; the remaining lines
    # modify items 1 and 2 and the buffer object is returned.
    cdef object[NestedStruct] arr = obj
    # Assign nested fields of item 1 one attribute at a time.
    arr[1].a.x = 1
    arr[1].a.y = 2
    # Read a nested field of item 0 while assigning into item 1.
    arr[1].b.x = arr[0].a.y + 1
    arr[1].b.y = 4
    # Whole-item assignment: copy item 1 into item 2.
    arr[2] = arr[1]
    return arr
def test_bad_nested_dtypes():
    # Only declares a buffer variable with BadNestedStruct items; nothing
    # is assigned to it or read from it.
    # NOTE(review): no doctest in the visible part of this file calls
    # this function -- confirm whether it is meant to be exercised.
    cdef object[BadNestedStruct] arr
def test_good_cast():
# Check that a signed int can round-trip through casted unsigned int access
cdef np.ndarray[unsigned int, cast=True] arr = np.array([-100], dtype='i')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment