Commit acdacfe9 authored by Dag Sverre Seljebotn

Buffers: Support for dtype=object

parent 71433dd7
...@@ -117,6 +117,7 @@ ERR_BUF_DUP = '"%s" buffer option already supplied' ...@@ -117,6 +117,7 @@ ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing' ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_MODE = 'Only allowed buffer modes are "full" or "strided" (as a compile-time string)' ERR_BUF_MODE = 'Only allowed buffer modes are "full" or "strided" (as a compile-time string)'
ERR_BUF_NDIM = 'ndim must be a non-negative integer' ERR_BUF_NDIM = 'ndim must be a non-negative integer'
ERR_BUF_DTYPE = 'dtype must be "object", numeric type or a struct'
def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True): def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True):
""" """
...@@ -159,11 +160,16 @@ def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, nee ...@@ -159,11 +160,16 @@ def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, nee
if need_complete: if need_complete:
raise CompileError(globalpos, ERR_BUF_MISSING % name) raise CompileError(globalpos, ERR_BUF_MISSING % name)
ndim = options["ndim"] dtype = options.get("dtype")
if not isinstance(ndim, int) or ndim < 0: if dtype and dtype.is_extension_type:
raise CompileError(globalpos, ERR_BUF_DTYPE)
ndim = options.get("ndim")
if ndim and (not isinstance(ndim, int) or ndim < 0):
raise CompileError(globalpos, ERR_BUF_NDIM) raise CompileError(globalpos, ERR_BUF_NDIM)
if not options["mode"] in ('full', 'strided'): mode = options.get("mode")
if mode and not (mode in ('full', 'strided')):
raise CompileError(globalpos, ERR_BUF_MODE) raise CompileError(globalpos, ERR_BUF_MODE)
return options return options
...@@ -307,14 +313,18 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type, ...@@ -307,14 +313,18 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
code.putln('}') code.putln('}')
def put_access(entry, index_signeds, index_cnames, options, pos, code): def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, code):
"""Returns a c string which can be used to access the buffer """
for reading or writing. Generates code to process indices and calculate an offset into
a buffer. Returns a C string which gives a pointer which can be
read from or written to at will (it is an expression so caller should
store it in a temporary if it is used more than once).
As the bounds checking can have any number of combinations of unsigned As the bounds checking can have any number of combinations of unsigned
arguments, smart optimizations etc. we insert it directly in the function arguments, smart optimizations etc. we insert it directly in the function
body. The lookup however is delegated to an inline function that is instantiated body. The lookup however is delegated to an inline function that is instantiated
once per ndim (lookup with suboffsets tend to get quite complicated). once per ndim (lookup with suboffsets tend to get quite complicated).
""" """
bufaux = entry.buffer_aux bufaux = entry.buffer_aux
bufstruct = bufaux.buffer_info_var.cname bufstruct = bufaux.buffer_info_var.cname
...@@ -371,12 +381,11 @@ def put_access(entry, index_signeds, index_cnames, options, pos, code): ...@@ -371,12 +381,11 @@ def put_access(entry, index_signeds, index_cnames, options, pos, code):
funcname = "__Pyx_BufPtrStrided%dd" % nd funcname = "__Pyx_BufPtrStrided%dd" % nd
funcgen = buf_lookup_strided_code funcgen = buf_lookup_strided_code
# Make sure the utility code is available
code.globalstate.use_generated_code(funcgen, name=funcname, nd=nd) code.globalstate.use_generated_code(funcgen, name=funcname, nd=nd)
ptrcode = "%s(%s.buf, %s)" % (funcname, bufstruct, ", ".join(params)) ptrcode = "%s(%s.buf, %s)" % (funcname, bufstruct, ", ".join(params))
valuecode = "*%s" % entry.type.buffer_ptr_type.cast_code(ptrcode) return entry.type.buffer_ptr_type.cast_code(ptrcode)
return valuecode
def use_empty_bufstruct_code(env, max_ndim): def use_empty_bufstruct_code(env, max_ndim):
...@@ -421,11 +430,16 @@ def buf_lookup_full_code(proto, defin, name, nd): ...@@ -421,11 +430,16 @@ def buf_lookup_full_code(proto, defin, name, nd):
def mangle_dtype_name(dtype): def mangle_dtype_name(dtype):
# Use prefixes to separate user defined types from builtins # Use prefixes to separate user defined types from builtins
# (consider "typedef float unsigned_int") # (consider "typedef float unsigned_int")
if dtype.typestring is None: if dtype.is_pyobject:
prefix = "nn_" return "object"
elif dtype.is_ptr:
return "ptr"
else: else:
prefix = "" if dtype.typestring is None:
return prefix + dtype.declaration_code("").replace(" ", "_") prefix = "nn_"
else:
prefix = ""
return prefix + dtype.declaration_code("").replace(" ", "_")
def get_ts_check_item(dtype, writer): def get_ts_check_item(dtype, writer):
# See if we can consume one (unnamed) dtype as next item # See if we can consume one (unnamed) dtype as next item
......
...@@ -1370,6 +1370,7 @@ class IndexNode(ExprNode): ...@@ -1370,6 +1370,7 @@ class IndexNode(ExprNode):
self.index = None self.index = None
self.type = self.base.type.dtype self.type = self.base.type.dtype
self.is_buffer_access = True self.is_buffer_access = True
self.buffer_type = self.base.entry.type
if getting: if getting:
# we only need a temp because result_code isn't refactored to # we only need a temp because result_code isn't refactored to
...@@ -1457,8 +1458,13 @@ class IndexNode(ExprNode): ...@@ -1457,8 +1458,13 @@ class IndexNode(ExprNode):
def generate_result_code(self, code): def generate_result_code(self, code):
if self.is_buffer_access: if self.is_buffer_access:
valuecode = self.buffer_access_code(code) ptrcode = self.buffer_lookup_code(code)
code.putln("%s = %s;" % (self.result_code, valuecode)) code.putln("%s = *%s;" % (
self.result_code,
self.buffer_type.buffer_ptr_type.cast_code(ptrcode)))
# Must incref the value we pulled out.
if self.buffer_type.dtype.is_pyobject:
code.putln("Py_INCREF((PyObject*)%s);" % self.result_code)
elif self.type.is_pyobject: elif self.type.is_pyobject:
if self.index.type.is_int: if self.index.type.is_int:
function = "__Pyx_GetItemInt" function = "__Pyx_GetItemInt"
...@@ -1496,8 +1502,26 @@ class IndexNode(ExprNode): ...@@ -1496,8 +1502,26 @@ class IndexNode(ExprNode):
def generate_assignment_code(self, rhs, code): def generate_assignment_code(self, rhs, code):
self.generate_subexpr_evaluation_code(code) self.generate_subexpr_evaluation_code(code)
if self.is_buffer_access: if self.is_buffer_access:
valuecode = self.buffer_access_code(code) ptrexpr = self.buffer_lookup_code(code)
code.putln("%s = %s;" % (valuecode, rhs.result_code)) if self.buffer_type.dtype.is_pyobject:
# Must manage refcounts. Decref what is already there
# and incref what we put in.
ptr = code.funcstate.allocate_temp(self.buffer_type.buffer_ptr_type)
if rhs.is_temp:
rhs_code = code.funcstate.allocate_temp(rhs.type)
else:
rhs_code = rhs.result_code
code.putln("%s = %s;" % (ptr, ptrexpr))
code.putln("Py_DECREF(*%s); Py_INCREF(%s);" % (
ptr, rhs_code
))
code.putln("*%s = %s;" % (ptr, rhs_code))
if rhs.is_temp:
code.funcstate.release_temp(rhs_code)
code.funcstate.release_temp(ptr)
else:
# Simple case
code.putln("*%s = %s;" % (ptrexpr, rhs.result_code))
elif self.type.is_pyobject: elif self.type.is_pyobject:
self.generate_setitem_code(rhs.py_result(), code) self.generate_setitem_code(rhs.py_result(), code)
else: else:
...@@ -1524,21 +1548,18 @@ class IndexNode(ExprNode): ...@@ -1524,21 +1548,18 @@ class IndexNode(ExprNode):
code.error_goto(self.pos))) code.error_goto(self.pos)))
self.generate_subexpr_disposal_code(code) self.generate_subexpr_disposal_code(code)
def buffer_access_code(self, code): def buffer_lookup_code(self, code):
# Assign indices to temps # Assign indices to temps
index_temps = [code.funcstate.allocate_temp(i.type) for i in self.indices] index_temps = [code.funcstate.allocate_temp(i.type) for i in self.indices]
for temp, index in zip(index_temps, self.indices): for temp, index in zip(index_temps, self.indices):
code.putln("%s = %s;" % (temp, index.result_code)) code.putln("%s = %s;" % (temp, index.result_code))
# Generate buffer access code using these temps # Generate buffer access code using these temps
import Buffer import Buffer
valuecode = Buffer.put_access(entry=self.base.entry, return Buffer.put_buffer_lookup_code(entry=self.base.entry,
index_signeds=[i.type.signed for i in self.indices], index_signeds=[i.type.signed for i in self.indices],
index_cnames=index_temps, index_cnames=index_temps,
options=self.options, options=self.options,
pos=self.pos, code=code) pos=self.pos, code=code)
return valuecode
class SliceIndexNode(ExprNode): class SliceIndexNode(ExprNode):
# 2-element slice indexing # 2-element slice indexing
......
...@@ -231,6 +231,7 @@ class PyObjectType(PyrexType): ...@@ -231,6 +231,7 @@ class PyObjectType(PyrexType):
parsetuple_format = "O" parsetuple_format = "O"
pymemberdef_typecode = "T_OBJECT" pymemberdef_typecode = "T_OBJECT"
buffer_defaults = None buffer_defaults = None
typestring = "O"
def __str__(self): def __str__(self):
return "Python object" return "Python object"
......
...@@ -14,6 +14,7 @@ cimport python_buffer ...@@ -14,6 +14,7 @@ cimport python_buffer
cimport stdio cimport stdio
cimport cython cimport cython
cimport refcount
__test__ = {} __test__ = {}
setup_string = """ setup_string = """
...@@ -708,6 +709,62 @@ def printbuf_cytypedef2(object[cytypedef2] buf, shape): ...@@ -708,6 +709,62 @@ def printbuf_cytypedef2(object[cytypedef2] buf, shape):
print buf[i], print buf[i],
print print
#
# Object access
#
from python_ref cimport Py_INCREF, Py_DECREF
def addref(*args):
    """Increment the reference count of every object passed as an argument."""
    for obj in args:
        Py_INCREF(obj)
def decref(*args):
    """Decrement the reference count of every object passed as an argument."""
    for obj in args:
        Py_DECREF(obj)
def get_refcount(x):
    """Return the value reported by refcount.CyTest_GetRefcount for x."""
    # NOTE(review): x is handed straight to the helper without being bound
    # to another local, presumably so the reported count is not perturbed
    # by this wrapper -- confirm against the refcount helper module.
    return refcount.CyTest_GetRefcount(x)
@testcase
def printbuf_object(object[object] buf, shape):
    """
    Print the repr and the reference count of each element of an object
    buffer, one element per line.

    Only play with unique objects, interned numbers etc. will have
    unpredictable refcounts.

    ObjectMockBuffer doesn't do anything about increfing/decrefing,
    we do the "buffer implementor" refcounting directly in the
    testcase.

    >>> a, b, c = "globally_unique_string_23234123", {4:23}, [34,3]
    >>> get_refcount(a), get_refcount(b), get_refcount(c)
    (2, 2, 2)
    >>> A = ObjectMockBuffer(None, [a, b, c])
    >>> printbuf_object(A, (3,))
    'globally_unique_string_23234123' 2
    {4: 23} 2
    [34, 3] 2
    """
    cdef int i
    # Only the first dimension of the caller-supplied shape is walked.
    for i in range(shape[0]):
        # The element is looked up inline twice rather than stored in a
        # local. NOTE(review): the expected count of 2 in the doctest
        # presumably relies on no extra local reference being held here --
        # confirm before refactoring this line.
        print repr(buf[i]), refcount.CyTest_GetRefcount(buf[i])
@testcase
def assign_to_object(object[object] buf, int idx, obj):
    """
    Store obj at position idx of an object buffer.

    The doctest below expects the element that is replaced to lose one
    reference and the object that is stored to gain one.

    See comments on printbuf_object above.

    >>> a, b = [1, 2, 3], [4, 5, 6]
    >>> get_refcount(a), get_refcount(b)
    (2, 2)
    >>> addref(a)
    >>> A = ObjectMockBuffer(None, [1, a]) # 1, ...,otherwise it thinks nested lists...
    >>> get_refcount(a), get_refcount(b)
    (3, 2)
    >>> assign_to_object(A, 1, b)
    >>> get_refcount(a), get_refcount(b)
    (2, 3)
    >>> decref(b)
    """
    # Single buffer-indexed assignment; this is the operation under test.
    buf[idx] = obj
# #
# Testcase support code (more tests below!, because of scope rules) # Testcase support code (more tests below!, because of scope rules)
...@@ -735,6 +792,8 @@ cdef class MockBuffer: ...@@ -735,6 +792,8 @@ cdef class MockBuffer:
cdef public object fail cdef public object fail
def __init__(self, label, data, shape=None, strides=None, format=None, offset=0): def __init__(self, label, data, shape=None, strides=None, format=None, offset=0):
# It is important not to store references to data after the constructor
# as refcounting is checked on object buffers.
self.label = label self.label = label
self.release_ok = True self.release_ok = True
self.log = "" self.log = ""
...@@ -894,6 +953,18 @@ cdef class UnsignedShortMockBuffer(MockBuffer): ...@@ -894,6 +953,18 @@ cdef class UnsignedShortMockBuffer(MockBuffer):
cdef get_itemsize(self): return sizeof(unsigned short) cdef get_itemsize(self): return sizeof(unsigned short)
cdef get_default_format(self): return "=H" cdef get_default_format(self): return "=H"
cdef extern from *:
    # The C name given for this "function" is the cast "(void*)", so a call
    # addr_of_pyobject(obj) is declared as taking an object and yielding a
    # void*. NOTE(review): this presumably compiles to a plain pointer cast
    # of the object with no reference counting -- confirm in generated C.
    void* addr_of_pyobject "(void*)"(object)
cdef class ObjectMockBuffer(MockBuffer):
    # Mock buffer whose items are pointer-sized slots, each holding the
    # address of a Python object.

    cdef int write(self, char* buf, object value) except -1:
        # Write the address of value into the slot that buf points at.
        # No Py_INCREF is performed here.
        (<void**>buf)[0] = addr_of_pyobject(value)
        return 0

    # One item occupies the size of a C pointer.
    cdef get_itemsize(self): return sizeof(void*)
    # Format string used when the caller does not supply one.
    cdef get_default_format(self): return "=O"
cdef class IntStridedMockBuffer(IntMockBuffer): cdef class IntStridedMockBuffer(IntMockBuffer):
cdef __cythonbufferdefaults__ = {"mode" : "strided"} cdef __cythonbufferdefaults__ = {"mode" : "strided"}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment