Commit 5f37f843 authored by Kurt Smith's avatar Kurt Smith Committed by Mark Florisson

coercion and validation for memoryview wrappers.

parent 307ad639
...@@ -3,6 +3,7 @@ from PyrexTypes import * ...@@ -3,6 +3,7 @@ from PyrexTypes import *
from UtilityCode import CythonUtilityCode from UtilityCode import CythonUtilityCode
from Errors import error from Errors import error
from Scanning import StringSourceDescriptor from Scanning import StringSourceDescriptor
import Options
class CythonScope(ModuleScope): class CythonScope(ModuleScope):
is_cython_builtin = 1 is_cython_builtin = 1
...@@ -239,7 +240,7 @@ memviewext_typeptr_cname = Naming.typeptr_prefix+memview_name ...@@ -239,7 +240,7 @@ memviewext_typeptr_cname = Naming.typeptr_prefix+memview_name
memviewext_typeobj_cname = '__pyx_tobj_'+memview_name memviewext_typeobj_cname = '__pyx_tobj_'+memview_name
memviewext_objstruct_cname = '__pyx_obj_'+memview_name memviewext_objstruct_cname = '__pyx_obj_'+memview_name
view_utility_code = CythonUtilityCode(u""" view_utility_code = CythonUtilityCode(u"""
cdef class Enum: cdef class Enum(object):
cdef object name cdef object name
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
...@@ -257,25 +258,141 @@ cdef extern from *: ...@@ -257,25 +258,141 @@ cdef extern from *:
int __Pyx_GetBuffer(object, Py_buffer *, int) int __Pyx_GetBuffer(object, Py_buffer *, int)
void __Pyx_ReleaseBuffer(Py_buffer *) void __Pyx_ReleaseBuffer(Py_buffer *)
cdef class memoryview:
cdef class memoryview(object):
cdef Py_buffer view cdef Py_buffer view
cdef int gotbuf_flag cdef int gotbuf_flag
def __cinit__(self): def __cinit__(memoryview self, object obj, int flags):
self.gotbuf_flag = 0
cdef memoryview from_obj(memoryview self, obj, int flags):
__Pyx_GetBuffer(obj, &self.view, flags) __Pyx_GetBuffer(obj, &self.view, flags)
self.gotbuf_flag = 1
def __dealloc__(memoryview self): def __dealloc__(memoryview self):
if self.gotbuf_flag:
__Pyx_ReleaseBuffer(&self.view) __Pyx_ReleaseBuffer(&self.view)
self.gotbuf_flag = 0
cdef memoryview memoryview_cwrapper(object o, int flags):
return memoryview(o, flags)
# XXX: put in #defines...
DEF BUF_MAX_NDIMS = %d
DEF __Pyx_MEMVIEW_DIRECT = 1
DEF __Pyx_MEMVIEW_PTR = 2
DEF __Pyx_MEMVIEW_FULL = 4
DEF __Pyx_MEMVIEW_CONTIG = 8
DEF __Pyx_MEMVIEW_STRIDED = 16
DEF __Pyx_MEMVIEW_FOLLOW = 32
cdef extern from *:
struct __pyx_obj_memoryview:
Py_buffer view
ctypedef struct __Pyx_mv_DimInfo:
Py_ssize_t shape, strides, suboffsets
ctypedef struct __Pyx_memviewstruct:
__pyx_obj_memoryview *memviewext
char *data
__Pyx_mv_DimInfo diminfo[BUF_MAX_NDIMS]
cdef is_cf_contig(int *specs, int ndim):
""", prefix="__pyx_viewaxis_") is_c_contig = is_f_contig = False
# c_contiguous: 'follow', 'follow', ..., 'follow', 'contig'
if specs[ndim-1] & __Pyx_MEMVIEW_CONTIG:
for i in range(0, ndim-1):
if not (specs[i] & __Pyx_MEMVIEW_FOLLOW):
break
else:
is_c_contig = True
# f_contiguous: 'contig', 'follow', 'follow', ..., 'follow'
elif ndim > 1 and (specs[0] & __Pyx_MEMVIEW_CONTIG):
for i in range(1, ndim):
if not (specs[i] & __Pyx_MEMVIEW_FOLLOW):
break
else:
is_f_contig = True
return is_c_contig, is_f_contig
cdef object pyxmemview_from_memview(
memoryview memview,
int *axes_specs,
int ndim,
Py_ssize_t itemsize,
char *format,
__Pyx_memviewstruct *pyx_memview):
cdef int i
if ndim > BUF_MAX_NDIMS:
raise ValueError("number of dimensions exceed maximum of" + str(BUF_MAX_NDIMS))
cdef Py_buffer pybuf = memview.view
if pybuf.ndim != ndim:
raise ValueError("incompatible number of dimensions.")
cdef str pyx_format = pybuf.format
cdef str view_format = format
if pyx_format != view_format:
raise ValueError("Buffer and memoryview datatype formats do not match.")
if itemsize != pybuf.itemsize:
raise ValueError("Buffer and memoryview itemsize do not match.")
if not pybuf.strides:
raise ValueError("no stride information provided.")
has_suboffsets = True
if not pybuf.suboffsets:
has_suboffsets = False
is_c_contig, is_f_contig = is_cf_contig(axes_specs, ndim)
cdef int spec = 0
for i in range(ndim):
istr = str(i)
spec = axes_specs[i]
if spec & __Pyx_MEMVIEW_CONTIG:
if pybuf.strides[i] != 1:
raise ValueError("Dimension "+istr+" in axes specification is incompatible with buffer.")
if spec & (__Pyx_MEMVIEW_STRIDED | __Pyx_MEMVIEW_FOLLOW):
if pybuf.strides[i] <= 1:
raise ValueError("Dimension "+istr+" in axes specification is incompatible with buffer.")
if spec & __Pyx_MEMVIEW_DIRECT:
if has_suboffsets and pybuf.suboffsets[i] >= 0:
raise ValueError("Dimension "+istr+" in axes specification is incompatible with buffer.")
if spec & (__Pyx_MEMVIEW_PTR | __Pyx_MEMVIEW_FULL):
if not has_suboffsets:
raise ValueError("Buffer object does not provide suboffsets.")
if spec & __Pyx_MEMVIEW_PTR:
if pybuf.suboffsets[i] < 0:
raise ValueError("Buffer object suboffset in dimension "+istr+"must be >= 0.")
if is_f_contig:
idx = 0; stride = 1
for i in range(ndim):
if stride != pybuf.strides[i]:
raise ValueError("Buffer object not fortran contiguous.")
stride = stride * pybuf.shape[i]
elif is_c_contig:
idx = ndim-1; stride = 1
for i in range(ndim-1,-1,-1):
if stride != pybuf.strides[i]:
raise ValueError("Buffer object not C contiguous.")
stride = stride * pybuf.shape[i]
for i in range(ndim):
pyx_memview.diminfo[i].strides = pybuf.strides[i]
pyx_memview.diminfo[i].shape = pybuf.shape[i]
if has_suboffsets:
pyx_memview.diminfo[i].suboffsets = pybuf.suboffsets[i]
pyx_memview.memviewext = <__pyx_obj_memoryview*>memview
pyx_memview.data = <char *>pybuf.buf
""" % Options.buffer_max_dims, name="foobar", prefix="__pyx_viewaxis_")
cyarray_prefix = u'__pyx_cythonarray_' cyarray_prefix = u'__pyx_cythonarray_'
cython_array_utility_code = CythonUtilityCode(u''' cython_array_utility_code = CythonUtilityCode(u'''
...@@ -300,7 +417,7 @@ cdef class array: ...@@ -300,7 +417,7 @@ cdef class array:
Py_ssize_t *shape Py_ssize_t *shape
Py_ssize_t *strides Py_ssize_t *strides
Py_ssize_t itemsize Py_ssize_t itemsize
char *mode str mode
def __cinit__(array self, tuple shape, Py_ssize_t itemsize, char *format, mode="c"): def __cinit__(array self, tuple shape, Py_ssize_t itemsize, char *format, mode="c"):
...@@ -360,12 +477,12 @@ cdef class array: ...@@ -360,12 +477,12 @@ cdef class array:
if not self.data: if not self.data:
raise MemoryError("unable to allocate array data.") raise MemoryError("unable to allocate array data.")
def __getbuffer__(array self, Py_buffer *info, int flags): def __getbuffer__(self, Py_buffer *info, int flags):
cdef int bufmode cdef int bufmode = -1
if self.mode == b"c": if self.mode == b"c":
bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS
if self.mode == b"fortran": elif self.mode == b"fortran":
bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS
if not (flags & bufmode): if not (flags & bufmode):
raise ValueError("Can only create a buffer that is contiguous in memory.") raise ValueError("Can only create a buffer that is contiguous in memory.")
......
...@@ -7578,20 +7578,26 @@ class CoerceToMemViewNode(CoercionNode): ...@@ -7578,20 +7578,26 @@ class CoerceToMemViewNode(CoercionNode):
self.type = dst_type self.type = dst_type
self.is_temp = 1 self.is_temp = 1
self.env = env self.env = env
import MemoryView
self.env.use_utility_code(MemoryView.obj_to_memview_code)
# MemoryView.use_memview_cwrap(env)
def generate_result_code(self, code): def generate_result_code(self, code):
# create a cython.memoryview object.
# declare a new temporary cython.memoryview variable.
import MemoryView import MemoryView
memviewobj = code.funcstate.allocate_temp(PyrexTypes.py_object_type, manage_ref=True)
# -) initialize cython.memview object with self.arg, it calls buf_flag = MemoryView.get_buf_flag(self.type.axes)
# __Pyx_GetBuffer on it. code.putln("%s = (PyObject *)__pyx_viewaxis_memoryview_cwrapper(%s, %s);" % (memviewobj, self.arg.result(), buf_flag))
# -) check the axes specifiers for the underlying memview's Py_buffer, ndim = len(self.type.axes)
# make sure they're compatible with the dst_type's axes specs. spec_int_arr = code.funcstate.allocate_temp(PyrexTypes.c_array_type(PyrexTypes.c_int_type, ndim),manage_ref=True)
# -) output the temp assignment code (see specs_code = MemoryView.specs_to_code(self.type.axes)
# CoerceFromPyTypeNode.generate_result_code for example) for idx, cspec in enumerate(specs_code):
pass code.putln("%s[%d] = %s;" % (spec_int_arr, idx, cspec))
itemsize = self.type.dtype.sign_and_name()
format = MemoryView.format_from_type(self.type.dtype)
code.putln("__pyx_viewaxis_pyxmemview_from_memview((struct __pyx_obj_memoryview *)%s, %s, %d, sizeof(%s), \"%s\", &%s);" % (memviewobj, spec_int_arr, ndim, itemsize, format, self.result()))
code.funcstate.release_temp(memviewobj)
code.funcstate.release_temp(spec_int_arr)
code.putln('/* @@@ */')
class CastNode(CoercionNode): class CastNode(CoercionNode):
# Wrap a node in a C type cast. # Wrap a node in a C type cast.
......
...@@ -4,6 +4,7 @@ from Visitor import CythonTransform ...@@ -4,6 +4,7 @@ from Visitor import CythonTransform
import Options import Options
import CythonScope import CythonScope
from Code import UtilityCode from Code import UtilityCode
from UtilityCode import CythonUtilityCode
START_ERR = "there must be nothing or the value 0 (zero) in the start slot." START_ERR = "there must be nothing or the value 0 (zero) in the start slot."
STOP_ERR = "Axis specification only allowed in the 'stop' slot." STOP_ERR = "Axis specification only allowed in the 'stop' slot."
...@@ -15,13 +16,87 @@ INVALID_ERR = "Invalid axis specification." ...@@ -15,13 +16,87 @@ INVALID_ERR = "Invalid axis specification."
EXPR_ERR = "no expressions allowed in axis spec, only names (e.g. cython.view.contig)." EXPR_ERR = "no expressions allowed in axis spec, only names (e.g. cython.view.contig)."
CF_ERR = "Invalid axis specification for a C/Fortran contiguous array." CF_ERR = "Invalid axis specification for a C/Fortran contiguous array."
def use_memview_util_code(env): memview_c_contiguous = "PyBUF_C_CONTIGUOUS"
memview_f_contiguous = "PyBUF_F_CONTIGUOUS"
memview_any_contiguous = "PyBUF_ANY_CONTIGUOUS"
memview_full_access = "PyBUF_FULL"
memview_strided_access = "PyBUF_STRIDED"
MEMVIEW_DIRECT = 1
MEMVIEW_PTR = 2
MEMVIEW_FULL = 4
MEMVIEW_CONTIG = 8
MEMVIEW_STRIDED= 16
MEMVIEW_FOLLOW = 32
_spec_to_const = {
'contig' : MEMVIEW_CONTIG,
'strided': MEMVIEW_STRIDED,
'follow' : MEMVIEW_FOLLOW,
'direct' : MEMVIEW_DIRECT,
'ptr' : MEMVIEW_PTR,
'full' : MEMVIEW_FULL
}
def specs_to_code(specs):
arr = []
for access, packing in specs:
arr.append("(%s | %s)" % (_spec_to_const[access], _spec_to_const[packing]))
return arr
# XXX: add complex support below...
_typename_to_format = {
'char' : 'c',
'signed char' : 'b',
'unsigned char' : 'B',
'short' : 'h',
'unsigned short' : 'H',
'int' : 'i',
'unsigned int' : 'I',
'long' : 'l',
'unsigned long' : 'L',
'long long' : 'q',
'unsigned long long' : 'Q',
'float' : 'f',
'double' : 'd',
}
def format_from_type(base_type):
return _typename_to_format[base_type.sign_and_name()]
def get_buf_flag(specs):
is_c_contig, is_f_contig = is_cf_contig(specs)
if is_c_contig:
return memview_c_contiguous
elif is_f_contig:
return memview_f_contiguous
access, packing = zip(*specs)
assert 'follow' not in packing
if 'full' in access or 'ptr' in access:
return memview_full_access
else:
return memview_strided_access
def use_cython_util_code(env, lu_name):
import CythonScope import CythonScope
cythonscope = env.global_scope().context.cython_scope cythonscope = env.global_scope().context.cython_scope
viewscope = cythonscope.viewscope viewscope = cythonscope.viewscope
memview_entry = viewscope.lookup_here(CythonScope.memview_name) entry = viewscope.lookup_here(lu_name)
assert memview_entry is cythonscope.memviewentry entry.used = 1
memview_entry.used = 1 return entry
def use_memview_util_code(env):
import CythonScope
memview_entry = use_cython_util_code(env, CythonScope.memview_name)
def use_memview_cwrap(env):
import CythonScope
mv_cwrap_entry = use_cython_util_code(env, CythonScope.memview_cwrap_name)
def get_axes_specs(env, axes): def get_axes_specs(env, axes):
''' '''
...@@ -123,25 +198,30 @@ def get_axes_specs(env, axes): ...@@ -123,25 +198,30 @@ def get_axes_specs(env, axes):
return axes_specs return axes_specs
def validate_axes_specs(pos, specs): def is_cf_contig(specs):
packing_specs = ('contig', 'strided', 'follow')
access_specs = ('direct', 'ptr', 'full')
is_c_contig = is_f_contig = False is_c_contig = is_f_contig = False
packing_idx = 1 packing_idx = 1
if (specs[0][packing_idx] == 'contig' and if (specs[-1][packing_idx] == 'contig' and
all(axis[packing_idx] == 'follow' for axis in specs[:-1])):
# c_contiguous: 'follow', 'follow', ..., 'follow', 'contig'
is_c_contig = True
elif (len(specs) > 1 and
specs[0][packing_idx] == 'contig' and
all(axis[packing_idx] == 'follow' for axis in specs[1:])): all(axis[packing_idx] == 'follow' for axis in specs[1:])):
# f_contiguous: 'contig', 'follow', 'follow', ..., 'follow' # f_contiguous: 'contig', 'follow', 'follow', ..., 'follow'
is_f_contig = True is_f_contig = True
elif (len(specs) > 1 and return is_c_contig, is_f_contig
specs[-1][packing_idx] == 'contig' and
all(axis[packing_idx] == 'follow' for axis in specs[:-1])): def validate_axes_specs(pos, specs):
# c_contiguous: 'follow', 'follow', ..., 'follow', 'contig'
is_c_contig = True packing_specs = ('contig', 'strided', 'follow')
access_specs = ('direct', 'ptr', 'full')
is_c_contig, is_f_contig = is_cf_contig(specs)
has_contig = has_follow = has_strided = False has_contig = has_follow = has_strided = False
...@@ -236,9 +316,18 @@ class MemoryViewTransform(CythonTransform): ...@@ -236,9 +316,18 @@ class MemoryViewTransform(CythonTransform):
return node return node
def visit_SingleAssignmentNode(self, node): def visit_SingleAssignmentNode(self, node):
import pdb; pdb.set_trace()
return node return node
spec_constants_code = UtilityCode(proto="""
#define __Pyx_MEMVIEW_DIRECT 1
#define __Pyx_MEMVIEW_PTR 2
#define __Pyx_MEMVIEW_FULL 4
#define __Pyx_MEMVIEW_CONTIG 8
#define __Pyx_MEMVIEW_STRIDED 16
#define __Pyx_MEMVIEW_FOLLOW 32
"""
)
memviewstruct_cname = u'__Pyx_memviewstruct' memviewstruct_cname = u'__Pyx_memviewstruct'
memviewstruct_declare_code = UtilityCode(proto=""" memviewstruct_declare_code = UtilityCode(proto="""
......
...@@ -2547,6 +2547,9 @@ c_pyx_buffer_ptr_type = CPtrType(c_pyx_buffer_type) ...@@ -2547,6 +2547,9 @@ c_pyx_buffer_ptr_type = CPtrType(c_pyx_buffer_type)
c_pyx_buffer_nd_type = CStructOrUnionType("__Pyx_LocalBuf_ND", "struct", c_pyx_buffer_nd_type = CStructOrUnionType("__Pyx_LocalBuf_ND", "struct",
None, 1, "__Pyx_LocalBuf_ND") None, 1, "__Pyx_LocalBuf_ND")
cython_memoryview_type = CStructOrUnionType("__pyx_obj_memoryview", "struct",
None, 1, "__pyx_obj_memoryview")
error_type = ErrorType() error_type = ErrorType()
unspecified_type = UnspecifiedType() unspecified_type = UnspecifiedType()
......
u''' u'''
>>> f() >>> f()
>>> g() >>> g()
>>> call()
''' '''
# from cython.view cimport memoryview from cython.view cimport memoryview
from cython cimport array, PyBUF_C_CONTIGUOUS from cython cimport array, PyBUF_C_CONTIGUOUS
def f(): def f():
pass cdef array arr = array(shape=(10,10), itemsize=sizeof(int), format='i')
# cdef array arr = array(shape=(10,10), itemsize=sizeof(int), format='i') cdef memoryview mv = memoryview(arr, PyBUF_C_CONTIGUOUS)
# cdef memoryview mv = memoryview(arr, PyBUF_C_CONTIGUOUS)
def g(): def g():
# cdef int[::1] mview = array((10,), itemsize=sizeof(int), format='i') # cdef int[::1] mview = array((10,), itemsize=sizeof(int), format='i')
cdef int[::1] mview = array((10,), itemsize=sizeof(int), format='i') cdef int[::1] mview = array((10,), itemsize=sizeof(int), format='i')
mview = array((10,), itemsize=sizeof(int), format='i')
cdef class Foo: cdef class Foo:
cdef int[:] mview cdef int[::1] mview
def __init__(self): def __init__(self):
self.mview = array((10,), itemsize=sizeof(int), format='i') self.mview = array((10,), itemsize=sizeof(int), format='i')
self.mview = array((10,), itemsize=sizeof(int), format='i')
class pyfoo: class pyfoo:
def __init__(self): def __init__(self):
self.mview = array((10,), itemsize=sizeof(long), format='l') self.mview = array((10,), itemsize=sizeof(long), format='l')
# self.mview = arr
cdef cdg(): cdef cdg():
cdef double[:] dmv = array((10,), itemsize=sizeof(double), format='d') cdef double[::1] dmv = array((10,), itemsize=sizeof(double), format='d')
dmv = array((10,), itemsize=sizeof(double), format='d')
cdef float[:,:] global_mv = array((10,10), itemsize=sizeof(float), format='f') cdef float[:,::1] global_mv = array((10,10), itemsize=sizeof(float), format='f')
global_mv = array((10,10), itemsize=sizeof(float), format='f')
def call(): def call():
cdg() cdg()
f = Foo() f = Foo()
pf = pyfoo() pf = pyfoo()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment