Commit 68e1429c authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Introduced BufferType, start of numpy-independent testcase, GetBuffer improvements

parent 1139758c
...@@ -9,6 +9,8 @@ import PyrexTypes ...@@ -9,6 +9,8 @@ import PyrexTypes
from sets import Set as set from sets import Set as set
class PureCFuncNode(Node): class PureCFuncNode(Node):
child_attrs = []
def __init__(self, pos, cname, type, c_code, visibility='private'): def __init__(self, pos, cname, type, c_code, visibility='private'):
self.pos = pos self.pos = pos
self.cname = cname self.cname = cname
...@@ -97,14 +99,14 @@ class BufferTransform(CythonTransform): ...@@ -97,14 +99,14 @@ class BufferTransform(CythonTransform):
# on the buffer entry # on the buffer entry
bufvars = [(name, entry) for name, entry bufvars = [(name, entry) for name, entry
in scope.entries.iteritems() in scope.entries.iteritems()
if entry.type.buffer_options is not None] if entry.type.is_buffer]
for name, entry in bufvars: for name, entry in bufvars:
bufopts = entry.type.buffer_options buftype = entry.type
# Get or make a type string checker # Get or make a type string checker
tschecker = self.tschecker(bufopts.dtype) tschecker = self.tschecker(buftype.dtype)
# Declare auxiliary vars # Declare auxiliary vars
bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name), bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name),
...@@ -116,7 +118,7 @@ class BufferTransform(CythonTransform): ...@@ -116,7 +118,7 @@ class BufferTransform(CythonTransform):
stridevars = [] stridevars = []
shapevars = [] shapevars = []
for idx in range(bufopts.ndim): for idx in range(buftype.ndim):
# stride # stride
varname = temp_name_handle(u"%s_%s%d" % (name, "stride", idx)) varname = temp_name_handle(u"%s_%s%d" % (name, "stride", idx))
var = scope.declare_var(varname, PyrexTypes.c_int_type, node.pos, is_cdef=True) var = scope.declare_var(varname, PyrexTypes.c_int_type, node.pos, is_cdef=True)
...@@ -216,7 +218,7 @@ class BufferTransform(CythonTransform): ...@@ -216,7 +218,7 @@ class BufferTransform(CythonTransform):
expr = AddNode(pos, operator='+', operand1=expr, operand2=next) expr = AddNode(pos, operator='+', operand1=expr, operand2=next)
casted = TypecastNode(pos, operand=expr, casted = TypecastNode(pos, operand=expr,
type=PyrexTypes.c_ptr_type(node.base.entry.type.buffer_options.dtype)) type=PyrexTypes.c_ptr_type(node.base.entry.type.dtype))
result = IndexNode(pos, base=casted, index=IntNode(pos, value='0')) result = IndexNode(pos, base=casted, index=IntNode(pos, value='0'))
return result return result
...@@ -412,3 +414,4 @@ class BufferTransform(CythonTransform): ...@@ -412,3 +414,4 @@ class BufferTransform(CythonTransform):
# TODO: # TODO:
# - buf must be NULL before getting new buffer # - buf must be NULL before getting new buffer
...@@ -1302,12 +1302,12 @@ class IndexNode(ExprNode): ...@@ -1302,12 +1302,12 @@ class IndexNode(ExprNode):
skip_child_analysis = False skip_child_analysis = False
buffer_access = False buffer_access = False
if self.base.type.buffer_options is not None: if self.base.type.is_buffer:
if isinstance(self.index, TupleNode): if isinstance(self.index, TupleNode):
indices = self.index.args indices = self.index.args
else: else:
indices = [self.index] indices = [self.index]
if len(indices) == self.base.type.buffer_options.ndim: if len(indices) == self.base.type.ndim:
buffer_access = True buffer_access = True
skip_child_analysis = True skip_child_analysis = True
for x in indices: for x in indices:
...@@ -1320,7 +1320,7 @@ class IndexNode(ExprNode): ...@@ -1320,7 +1320,7 @@ class IndexNode(ExprNode):
# for x in indices] # for x in indices]
self.indices = indices self.indices = indices
self.index = None self.index = None
self.type = self.base.type.buffer_options.dtype self.type = self.base.type.dtype
self.is_temp = 1 self.is_temp = 1
self.is_buffer_access = True self.is_buffer_access = True
......
...@@ -2003,8 +2003,24 @@ static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) { ...@@ -2003,8 +2003,24 @@ static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
except KeyError: except KeyError:
pass pass
# For now, hard-code numpy imported as "numpy" # Search all types for __getbuffer__ overloads
types = [] types = []
def find_buffer_types(scope):
for m in scope.cimported_modules:
find_buffer_types(m)
for e in scope.type_entries:
t = e.type
if t.is_extension_type:
release = get = None
for x in t.scope.pyfunc_entries:
if x.name == u"__getbuffer__": get = x.func_cname
elif x.name == u"__releasebuffer__": release = x.func_cname
if get:
types.append((t.typeptr_cname, get, release))
find_buffer_types(self.scope)
# For now, hard-code numpy imported as "numpy"
try: try:
ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer")) types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
...@@ -2015,7 +2031,7 @@ static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) { ...@@ -2015,7 +2031,7 @@ static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
if len(types) > 0: if len(types) > 0:
clause = "if" clause = "if"
for t, get, release in types: for t, get, release in types:
code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get)) code.putln("%s (PyObject_TypeCheck(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
clause = "else if" clause = "else if"
code.putln("else {") code.putln("else {")
code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);") code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
...@@ -2027,7 +2043,8 @@ static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) { ...@@ -2027,7 +2043,8 @@ static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
if len(types) > 0: if len(types) > 0:
clause = "if" clause = "if"
for t, get, release in types: for t, get, release in types:
code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release)) if release:
code.putln("%s (PyObject_TypeCheck(obj, %s)) %s(obj, view);" % (clause, t, release))
clause = "else if" clause = "else if"
code.putln("}") code.putln("}")
code.putln("") code.putln("")
......
...@@ -627,8 +627,7 @@ class CBufferAccessTypeNode(Node): ...@@ -627,8 +627,7 @@ class CBufferAccessTypeNode(Node):
def analyse(self, env): def analyse(self, env):
base_type = self.base_type_node.analyse(env) base_type = self.base_type_node.analyse(env)
dtype = self.dtype_node.analyse(env) dtype = self.dtype_node.analyse(env)
options = PyrexTypes.BufferOptions(dtype=dtype, ndim=self.ndim) self.type = PyrexTypes.BufferType(base_type, dtype=dtype, ndim=self.ndim)
self.type = PyrexTypes.create_buffer_type(base_type, options)
return self.type return self.type
class CComplexBaseTypeNode(CBaseTypeNode): class CComplexBaseTypeNode(CBaseTypeNode):
......
...@@ -6,21 +6,6 @@ from Cython import Utils ...@@ -6,21 +6,6 @@ from Cython import Utils
import Naming import Naming
import copy import copy
class BufferOptions:
# dtype PyrexType
# ndim int
def __init__(self, dtype, ndim):
self.dtype = dtype
self.ndim = ndim
def create_buffer_type(base_type, buffer_options):
# Make a shallow copy of base_type and then annotate it
# with the buffer information
result = copy.copy(base_type)
result.buffer_options = buffer_options
return result
class BaseType: class BaseType:
# #
...@@ -57,6 +42,7 @@ class PyrexType(BaseType): ...@@ -57,6 +42,7 @@ class PyrexType(BaseType):
# is_unicode boolean Is a UTF-8 encoded C char * type # is_unicode boolean Is a UTF-8 encoded C char * type
# is_returncode boolean Is used only to signal exceptions # is_returncode boolean Is used only to signal exceptions
# is_error boolean Is the dummy error type # is_error boolean Is the dummy error type
# is_buffer boolean Is buffer access type
# has_attributes boolean Has C dot-selectable attributes # has_attributes boolean Has C dot-selectable attributes
# default_value string Initial value # default_value string Initial value
# parsetuple_format string Format char for PyArg_ParseTuple # parsetuple_format string Format char for PyArg_ParseTuple
...@@ -106,11 +92,11 @@ class PyrexType(BaseType): ...@@ -106,11 +92,11 @@ class PyrexType(BaseType):
is_unicode = 0 is_unicode = 0
is_returncode = 0 is_returncode = 0
is_error = 0 is_error = 0
is_buffer = 0
has_attributes = 0 has_attributes = 0
default_value = "" default_value = ""
parsetuple_format = "" parsetuple_format = ""
pymemberdef_typecode = None pymemberdef_typecode = None
buffer_options = None # can contain a BufferOptions instance
def resolve(self): def resolve(self):
# If a typedef, returns the base type. # If a typedef, returns the base type.
...@@ -202,6 +188,26 @@ class CTypedefType(BaseType): ...@@ -202,6 +188,26 @@ class CTypedefType(BaseType):
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.typedef_base_type, name) return getattr(self.typedef_base_type, name)
class BufferType(BaseType):
#
# Delegates most attribute
# lookups to the base type. ANYTHING NOT DEFINED
# HERE IS DELEGATED!
# dtype PyrexType
# ndim int
is_buffer = 1
def __init__(self, base, dtype, ndim):
self.base = base
self.dtype = dtype
self.ndim = ndim
def __getattr__(self, name):
return getattr(self.base, name)
class PyObjectType(PyrexType): class PyObjectType(PyrexType):
# #
# Base class for all Python object types (reference-counted). # Base class for all Python object types (reference-counted).
......
cimport __cython__
__doc__ = u"""
>>> fb = MockBuffer("=f", "f", [1.0, 1.25, 0.75, 1.0], (2,2))
>>> printbuf_float(fb, (2,2))
1.0 1.25
0.75 1.0
"""
def printbuf_float(o, shape):
# should make shape builtin
cdef object[float, 2] buf
buf = o
cdef int i, j
for i in range(shape[0]):
for j in range(shape[1]):
print buf[i, j],
print
sizes = {
'f': sizeof(float)
}
cimport stdlib
cdef class MockBuffer:
cdef object format
cdef char* buffer
cdef int len, itemsize, ndim
cdef Py_ssize_t* strides
cdef Py_ssize_t* shape
def __init__(self, format, typechar, data, shape=None, strides=None):
self.itemsize = sizes[typechar]
if shape is None: shape = (len(data),)
if strides is None:
strides = []
cumprod = 1
for s in shape:
strides.append(cumprod)
cumprod *= s
strides.reverse()
strides = [x * self.itemsize for x in strides]
self.format = format
self.len = len(data) * self.itemsize
self.buffer = <char*>stdlib.malloc(self.len)
self.fill_buffer(typechar, data)
self.ndim = len(shape)
self.strides = <Py_ssize_t*>stdlib.malloc(self.ndim * sizeof(Py_ssize_t))
for i, x in enumerate(strides):
self.strides[i] = x
self.shape = <Py_ssize_t*>stdlib.malloc(self.ndim * sizeof(Py_ssize_t))
def __getbuffer__(MockBuffer self, Py_buffer* buffer, int flags):
if buffer is NULL:
print u"locking!"
return
buffer.buf = self.buffer
buffer.len = self.len
buffer.readonly = 0
buffer.format = <char*>self.format
buffer.ndim = self.ndim
buffer.shape = self.shape
buffer.strides = self.strides
buffer.suboffsets = NULL
buffer.itemsize = self.itemsize
buffer.internal = NULL
cdef fill_buffer(self, typechar, object data):
cdef int idx = 0
for value in data:
(<float*>(self.buffer + idx))[0] = <float>value
idx += sizeof(float)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment