Commit eaa6d477 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Implemented mode flag and strided mode for buffers

parent e175a590
......@@ -80,11 +80,18 @@ class IntroduceBufferAuxiliaryVars(CythonTransform):
result.used = True
return result
stridevars = [var(Naming.bufstride_prefix, i, "0") for i in range(entry.type.ndim)]
shapevars = [var(Naming.bufshape_prefix, i, "0") for i in range(entry.type.ndim)]
suboffsetvars = [var(Naming.bufsuboffset_prefix, i, "-1") for i in range(entry.type.ndim)]
entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, tschecker)
mode = entry.type.mode
if mode == 'full':
suboffsetvars = [var(Naming.bufsuboffset_prefix, i, "-1") for i in range(entry.type.ndim)]
entry.buffer_aux.lookup = get_buf_lookup_full(scope, entry.type.ndim)
elif mode == 'strided':
suboffsetvars = None
entry.buffer_aux.lookup = get_buf_lookup_strided(scope, entry.type.ndim)
entry.buffer_aux.suboffsetvars = suboffsetvars
entry.buffer_aux.get_buffer_cname = tschecker
......@@ -105,7 +112,13 @@ class IntroduceBufferAuxiliaryVars(CythonTransform):
def get_flags(buffer_aux, buffer_type):
flags = 'PyBUF_FORMAT | PyBUF_INDIRECT'
flags = 'PyBUF_FORMAT'
if buffer_type.mode == 'full':
flags += '| PyBUF_INDIRECT'
elif buffer_type.mode == 'strided':
flags += '| PyBUF_STRIDES'
else:
assert False
if buffer_aux.writable_needed: flags += "| PyBUF_WRITABLE"
return flags
......@@ -116,14 +129,17 @@ def used_buffer_aux_vars(entry):
for s in buffer_aux.stridevars: s.used = True
for s in buffer_aux.suboffsetvars: s.used = True
def put_unpack_buffer_aux_into_scope(buffer_aux, code):
def put_unpack_buffer_aux_into_scope(buffer_aux, mode, code):
# Generate code to copy the needed struct info into local
# variables.
bufstruct = buffer_aux.buffer_info_var.cname
# __pyx_bstride_0_buf = __pyx_bstruct_buf.strides[0] and so on
varspec = [("strides", buffer_aux.stridevars),
("shape", buffer_aux.shapevars)]
if mode == 'full':
varspec.append(("suboffsets", buffer_aux.suboffsetvars))
for field, vars in (("strides", buffer_aux.stridevars),
("shape", buffer_aux.shapevars),
("suboffsets", buffer_aux.suboffsetvars)):
for field, vars in varspec:
code.putln(" ".join(["%s = %s.%s[%d];" %
(s.cname, bufstruct, field, idx)
for idx, s in enumerate(vars)]))
......@@ -146,7 +162,7 @@ def put_acquire_arg_buffer(entry, code, pos):
pos))
# An exception raised in arg parsing cannot be catched, so no
# need to do care about the buffer then.
put_unpack_buffer_aux_into_scope(buffer_aux, code)
put_unpack_buffer_aux_into_scope(buffer_aux, entry.type.mode, code)
#def put_release_buffer_normal(entry, code):
# code.putln("if (%s != Py_None) PyObject_ReleaseBuffer(%s, &%s);" % (
......@@ -215,7 +231,7 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
code.end_block()
# Unpack indices
code.end_block()
put_unpack_buffer_aux_into_scope(buffer_aux, code)
put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
code.putln(code.error_goto_if_neg(retcode_cname, pos))
code.func.release_temp(retcode_cname)
else:
......@@ -227,7 +243,7 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
code.putln(code.error_goto(pos))
code.put('} else {')
# Unpack indices
put_unpack_buffer_aux_into_scope(buffer_aux, code)
put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
code.putln('}')
......@@ -266,8 +282,6 @@ def put_access(entry, index_signeds, index_cnames, pos, code):
code.putln("if (%s) %s = %d;" % (
code.unlikely("%s >= %s" % (cname, shape.cname)),
tmp_cname, idx))
# if boundscheck or not nonegs:
# code.putln("}")
if boundscheck:
code.put("if (%s) " % code.unlikely("%s != -1" % tmp_cname))
code.begin_block()
......@@ -277,14 +291,18 @@ def put_access(entry, index_signeds, index_cnames, pos, code):
code.func.release_temp(tmp_cname)
# Create buffer lookup and return it
offset = " + ".join(["%s * %s" % (idx, stride.cname)
for idx, stride in
zip(index_cnames, bufaux.stridevars)])
ptrcode = "(%s.buf + %s)" % (bufstruct, offset)
params = []
if entry.type.mode == 'full':
for i, s, o in zip(index_cnames, bufaux.stridevars, bufaux.suboffsetvars):
params.append(i)
params.append(s.cname)
params.append(o.cname)
else:
for i, s in zip(index_cnames, bufaux.stridevars):
params.append(i)
params.append(s.cname)
ptrcode = "%s(%s.buf, %s)" % (bufaux.lookup, bufstruct,
", ".join([", ".join([i, s.cname, o.cname]) for i, s, o in
zip(index_cnames, bufaux.stridevars, bufaux.suboffsetvars)]))
", ".join(params))
valuecode = "*%s" % entry.type.buffer_ptr_type.cast_code(ptrcode)
return valuecode
......@@ -297,6 +315,25 @@ def use_empty_bufstruct_code(env, max_ndim):
""") % (", ".join(["0"] * max_ndim), ", ".join(["-1"] * max_ndim))
env.use_utility_code([code, ""])
def get_buf_lookup_strided(env, nd):
"""
Generates and registers as utility a buffer lookup function for the right number
of dimensions. The function gives back a void* at the right location.
"""
name = "__Pyx_BufPtrStrided_%dd" % nd
if not env.has_utility_code(name):
# _i_ndex, _s_tride
args = ", ".join(["i%d, s%d" % (i, i) for i in range(nd)])
offset = " + ".join(["i%d * s%d" % (i, i) for i in range(nd)])
proto = dedent("""\
#define %s(buf, %s) ((char*)buf + %s)
""") % (name, args, offset)
env.use_utility_code([proto, ""], name=name)
return name
def get_buf_lookup_full(env, nd):
"""
Generates and registers as utility a buffer lookup function for the right number
......
......@@ -600,7 +600,8 @@ class CBufferAccessTypeNode(Node):
def analyse(self, env):
base_type = self.base_type_node.analyse(env)
dtype = self.dtype_node.analyse(env)
self.type = PyrexTypes.BufferType(base_type, dtype=dtype, ndim=self.ndim)
self.type = PyrexTypes.BufferType(base_type, dtype=dtype, ndim=self.ndim,
mode=self.mode)
return self.type
class CComplexBaseTypeNode(CBaseTypeNode):
......
......@@ -84,6 +84,7 @@ ERR_BUF_INT = '"%s" must be an integer'
ERR_BUF_NONNEG = '"%s" must be non-negative'
ERR_CDEF_INCLASS = 'Cannot assign default value to cdef class attributes'
ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables'
ERR_BUF_MODEHELP = 'Only allowed buffer modes are "full" or "strided" (as a compile-time string)'
class PostParse(CythonTransform):
"""
Basic interpretation of the parse tree, as well as validity
......@@ -155,7 +156,7 @@ class PostParse(CythonTransform):
return stats
# buffer access
buffer_options = ("dtype", "ndim") # ordered!
buffer_options = ("dtype", "ndim", "mode") # ordered!
def visit_CBufferAccessTypeNode(self, node):
if not self.scope_type == 'function':
raise PostParseError(node.pos, ERR_BUF_LOCALONLY)
......@@ -176,7 +177,6 @@ class PostParse(CythonTransform):
raise PostParseError(item.key.pos, ERR_BUF_DUP % key)
options[name] = item.value
provided = options.keys()
# get dtype
dtype = options.get("dtype")
if dtype is None:
......@@ -184,7 +184,7 @@ class PostParse(CythonTransform):
node.dtype_node = dtype
# get ndim
if "ndim" in provided:
if "ndim" in options:
ndimnode = options["ndim"]
if not isinstance(ndimnode, IntNode):
# Compile-time values (DEF) are currently resolved by the parser,
......@@ -197,6 +197,17 @@ class PostParse(CythonTransform):
else:
node.ndim = 1
if "mode" in options:
modenode = options["mode"]
if not isinstance(modenode, StringNode):
raise PostParseError(modenode.pos, ERR_BUF_MODEHELP)
mode = modenode.value
if not mode in ('full', 'strided'):
raise PostParseError(modenode.pos, ERR_BUF_MODEHELP)
node.mode = mode
else:
node.mode = 'full'
# We're done with the parse tree args
node.positional_args = None
node.keyword_args = None
......
......@@ -196,14 +196,18 @@ class BufferType(BaseType):
# dtype PyrexType
# ndim int
# mode str
# is_buffer boolean
# writable boolean
is_buffer = 1
writable = True
def __init__(self, base, dtype, ndim):
def __init__(self, base, dtype, ndim, mode):
self.base = base
self.dtype = dtype
self.ndim = ndim
self.buffer_ptr_type = CPtrType(dtype)
self.mode = mode
def as_argument_type(self):
return self
......
......@@ -8,6 +8,8 @@ def f():
cdef object[ndim=-1] buf2
cdef object[int, 'a'] buf3
cdef object[int,2,3,4,5,6] buf4
cdef object[int, 2, 'foo'] buf5
cdef object[int, 2, well] buf6
_ERRORS = u"""
1:11: Buffer types only allowed as function local variables
......@@ -17,5 +19,7 @@ _ERRORS = u"""
8:15: "dtype" missing
9:21: "ndim" must be an integer
10:15: Too many buffer options
11:24: Only allowed buffer modes are "full" or "strided" (as a compile-time string)
12:28: Only allowed buffer modes are "full" or "strided" (as a compile-time string)
"""
......@@ -477,6 +477,19 @@ def writable(obj):
cdef object[unsigned short int, 3] buf = obj
buf[2, 2, 1] = 23
@testcase
def strided(object[int, 1, 'strided'] buf):
"""
>>> A = IntMockBuffer("A", range(4))
>>> strided(A)
acquired A
released A
2
>>> A.recieved_flags
['FORMAT', 'ND', 'STRIDES']
"""
return buf[2]
#
# Coercions
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment