Commit 07d40c12 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Non-buffer code working again, typedefs working with buffers

parent 1ab79216
...@@ -69,7 +69,11 @@ class BufferTransform(CythonTransform): ...@@ -69,7 +69,11 @@ class BufferTransform(CythonTransform):
def __call__(self, node): def __call__(self, node):
assert isinstance(node, ModuleNode) assert isinstance(node, ModuleNode)
cymod = self.context.modules[u'__cython__'] try:
cymod = self.context.modules[u'__cython__']
except KeyError:
# No buffer fun for this module
return node
self.bufstruct_type = cymod.entries[u'Py_buffer'].type self.bufstruct_type = cymod.entries[u'Py_buffer'].type
self.tscheckers = {} self.tscheckers = {}
self.ts_funcs = [] self.ts_funcs = []
...@@ -194,9 +198,6 @@ class BufferTransform(CythonTransform): ...@@ -194,9 +198,6 @@ class BufferTransform(CythonTransform):
return result return result
buffer_access = TreeFragment(u"""
(<unsigned char*>(BUF.buf + OFFSET))[0]
""")
def buffer_index(self, node): def buffer_index(self, node):
pos = node.pos pos = node.pos
bufaux = node.base.entry.buffer_aux bufaux = node.base.entry.buffer_aux
...@@ -262,8 +263,6 @@ class BufferTransform(CythonTransform): ...@@ -262,8 +263,6 @@ class BufferTransform(CythonTransform):
else: else:
return node return node
# #
# Utils for creating type string checkers # Utils for creating type string checkers
# #
...@@ -285,12 +284,42 @@ class BufferTransform(CythonTransform): ...@@ -285,12 +284,42 @@ class BufferTransform(CythonTransform):
funcnode = self.ts_item_checkers.get(dtype) funcnode = self.ts_item_checkers.get(dtype)
if funcnode is None: if funcnode is None:
char = dtype.typestring char = dtype.typestring
funcnode = self.new_ts_func("item_%s" % self.mangle_dtype_name(dtype), """\ if char is not None and len(char) > 1:
if (*ts != '%s') { # Can use direct comparison
funcnode = self.new_ts_func("natitem_%s" % self.mangle_dtype_name(dtype), """\
if (*ts != '%s') {
PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch (rejecting on '%%s')", ts); PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch (rejecting on '%%s')", ts);
return NULL; return NULL;
} else return ts + 1; } else return ts + 1;
""" % char) """ % char)
else:
# Must deduce sign and length; rely on int vs. float to be correctly declared
ctype = dtype.declaration_code("")
code = """\
int ok;
switch (*ts) {"""
if dtype.is_int:
types = [
('b', 'char'), ('h', 'short'), ('i', 'int'),
('l', 'long'), ('q', 'long long')
]
code += "".join(["""\
case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;
case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;
""" % (char, ctype, against, ctype, char.upper(), ctype, "unsigned " + against, ctype) for
char, against in types])
code += """\
default: ok = 0;
}
if (!ok) {
PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch (rejecting on '%s')", ts);
return NULL;
} else return ts + 1;
"""
funcnode = self.new_ts_func("tdefitem_%s" % self.mangle_dtype_name(dtype), code)
self.ts_item_checkers[dtype] = funcnode self.ts_item_checkers[dtype] = funcnode
return funcnode.entry.cname return funcnode.entry.cname
...@@ -368,13 +397,13 @@ if (*ts != '%s') { ...@@ -368,13 +397,13 @@ if (*ts != '%s') {
self.ensure_ts_utils() self.ensure_ts_utils()
funcnode = self.tscheckers.get(dtype) funcnode = self.tscheckers.get(dtype)
if funcnode is None: if funcnode is None:
assert dtype.is_int or dtype.is_float or dtype.is_struct_or_union
if dtype.is_struct_or_union: if dtype.is_struct_or_union:
assert False assert False
elif dtype.is_typedef: elif dtype.is_int or dtype.is_float:
assert False # This includes simple typedef-ed types
else:
funcnode = self.create_ts_check_simple(dtype) funcnode = self.create_ts_check_simple(dtype)
else:
assert False
self.tscheckers[dtype] = funcnode self.tscheckers[dtype] = funcnode
return funcnode.entry return funcnode.entry
...@@ -383,80 +412,3 @@ if (*ts != '%s') { ...@@ -383,80 +412,3 @@ if (*ts != '%s') {
# TODO: # TODO:
# - buf must be NULL before getting new buffer # - buf must be NULL before getting new buffer
## get_buffer_func_type = PyrexTypes.CFuncType(
## PyrexTypes.c_int_type,
## [PyrexTypes.CFuncTypeArg(EncodedString("obj"), PyrexTypes.py_object_type, (0, 0, None), cname="obj"),
## PyrexTypes.CFuncTypeArg(EncodedString("view"), PyrexTypes.c_py_buffer_ptr_type, (0, 0, None), cname="view"),
## PyrexTypes.CFuncTypeArg(EncodedString("flags"), PyrexTypes.c_int_type, (0, 0, None), cname="flags"),
## ],
## exception_value = "-1"
## )
## numpy_get_buffer_body = """
## PyArrayObject *arr = (PyArrayObject*)obj;
## PyArray_Descr *type = (PyArray_Descr*)arr->descr;
## view->buf = arr->data;
## view->readonly = 0; /*fixme*/
## view->format = "B"; /*fixme*/
## view->ndim = arr->nd;
## view->strides = arr->strides;
## view->shape = arr->dimensions;
## view->suboffsets = 0;
## view->itemsize = type->elsize;
## view->internal = 0;
## return 0;
## """
# will be refactored
## code.put("""
## static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
## /* This function is always called after a type-check */
## PyArrayObject *arr = (PyArrayObject*)obj;
## PyArray_Descr *type = (PyArray_Descr*)arr->descr;
## view->buf = arr->data;
## view->readonly = 0; /*fixme*/
## view->format = "B"; /*fixme*/
## view->ndim = arr->nd;
## view->strides = arr->strides;
## view->shape = arr->dimensions;
## view->suboffsets = 0;
## view->itemsize = type->elsize;
## view->internal = 0;
## return 0;
## }
## static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
## }
## """)
## # For now, hard-code numpy imported as "numpy"
## ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
## types = [
## (ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer")
## ]
## # typeptr_cname = ndarrtype.typeptr_cname
## code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
## clause = "if"
## for t, get, release in types:
## code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
## clause = "else if"
## code.putln("else {")
## code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
## code.putln("return -1;")
## code.putln("}")
## code.putln("}")
## code.putln("")
## code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
## clause = "if"
## for t, get, release in types:
## code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release))
## clause = "else if"
## code.putln("}")
## code.putln("")
...@@ -1953,7 +1953,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1953,7 +1953,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def generate_buffer_compatability_functions(self, env, code): def generate_buffer_compatability_functions(self, env, code):
# will be refactored # will be refactored
code.put(""" try:
env.entries[u'numpy']
code.put("""
static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) { static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
/* This function is always called after a type-check; safe to cast */ /* This function is always called after a type-check; safe to cast */
PyArrayObject *arr = (PyArrayObject*)obj; PyArrayObject *arr = (PyArrayObject*)obj;
...@@ -1972,24 +1974,6 @@ static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) { ...@@ -1972,24 +1974,6 @@ static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
01234567890123456789012345*/ 01234567890123456789012345*/
const char* base_codes = "?bBhHiIlLqQfdgfdgO"; const char* base_codes = "?bBhHiIlLqQfdgfdgO";
/*
enum NPY_TYPES { NPY_BOOL=0,
NPY_BYTE, NPY_UBYTE,
NPY_SHORT, NPY_USHORT,
NPY_INT, NPY_UINT,
NPY_LONG, NPY_ULONG,
NPY_LONGLONG, NPY_ULONGLONG,
NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE,
NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE,
NPY_OBJECT=17,
NPY_STRING, NPY_UNICODE,
NPY_VOID,
NPY_NTYPES,
NPY_NOTYPE,
NPY_CHAR, special flag
NPY_USERDEF=256 leave room for characters
*/
char* format = (char*)malloc(4); char* format = (char*)malloc(4);
char* fp = format; char* fp = format;
*fp++ = type->byteorder; *fp++ = type->byteorder;
...@@ -2016,30 +2000,34 @@ static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) { ...@@ -2016,30 +2000,34 @@ static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
} }
""") """)
except KeyError:
pass
# For now, hard-code numpy imported as "numpy" # For now, hard-code numpy imported as "numpy"
ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type types = []
types = [ try:
(ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer") ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
] types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
except KeyError:
# typeptr_cname = ndarrtype.typeptr_cname pass
code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {") code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
clause = "if" if len(types) > 0:
for t, get, release in types: clause = "if"
code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get)) for t, get, release in types:
clause = "else if" code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
code.putln("else {") clause = "else if"
code.putln("else {")
code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);") code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
code.putln("return -1;") code.putln("return -1;")
code.putln("}") if len(types) > 0: code.putln("}")
code.putln("}") code.putln("}")
code.putln("") code.putln("")
code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {") code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
clause = "if" if len(types) > 0:
for t, get, release in types: clause = "if"
code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release)) for t, get, release in types:
clause = "else if" code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release))
clause = "else if"
code.putln("}") code.putln("}")
code.putln("") code.putln("")
......
...@@ -1093,7 +1093,8 @@ c_returncode_type = CIntType(2, 1, "T_INT", is_returncode = 1) ...@@ -1093,7 +1093,8 @@ c_returncode_type = CIntType(2, 1, "T_INT", is_returncode = 1)
c_anon_enum_type = CAnonEnumType(-1, 1) c_anon_enum_type = CAnonEnumType(-1, 1)
# the Py_buffer type is defined in Builtin.py # the Py_buffer type is defined in Builtin.py
c_py_buffer_ptr_type = CPtrType(CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer")) c_py_buffer_type = CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer")
c_py_buffer_ptr_type = CPtrType(c_py_buffer_type)
error_type = ErrorType() error_type = ErrorType()
......
cdef extern from "Python.h":
ctypedef int Py_intptr_t
cdef extern from "numpy/arrayobject.h":
ctypedef class numpy.ndarray [object PyArrayObject]:
cdef char *data
cdef int nd
cdef Py_intptr_t *dimensions
cdef Py_intptr_t *strides
cdef object base
# descr not implemented yet here...
cdef int flags
cdef int itemsize
cdef object weakreflist
ctypedef unsigned int npy_uint8
ctypedef unsigned int npy_uint16
ctypedef unsigned int npy_uint32
ctypedef unsigned int npy_uint64
ctypedef unsigned int npy_uint96
ctypedef unsigned int npy_uint128
ctypedef signed int npy_int64
ctypedef float npy_float32
ctypedef float npy_float64
ctypedef float npy_float80
ctypedef float npy_float96
ctypedef float npy_float128
ctypedef npy_int64 Tint64
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment