Commit eaa6d477 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Implemented mode flag and strided mode for buffers

parent e175a590
...@@ -80,11 +80,18 @@ class IntroduceBufferAuxiliaryVars(CythonTransform): ...@@ -80,11 +80,18 @@ class IntroduceBufferAuxiliaryVars(CythonTransform):
result.used = True result.used = True
return result return result
stridevars = [var(Naming.bufstride_prefix, i, "0") for i in range(entry.type.ndim)] stridevars = [var(Naming.bufstride_prefix, i, "0") for i in range(entry.type.ndim)]
shapevars = [var(Naming.bufshape_prefix, i, "0") for i in range(entry.type.ndim)] shapevars = [var(Naming.bufshape_prefix, i, "0") for i in range(entry.type.ndim)]
suboffsetvars = [var(Naming.bufsuboffset_prefix, i, "-1") for i in range(entry.type.ndim)]
entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, tschecker) entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, tschecker)
entry.buffer_aux.lookup = get_buf_lookup_full(scope, entry.type.ndim) mode = entry.type.mode
if mode == 'full':
suboffsetvars = [var(Naming.bufsuboffset_prefix, i, "-1") for i in range(entry.type.ndim)]
entry.buffer_aux.lookup = get_buf_lookup_full(scope, entry.type.ndim)
elif mode == 'strided':
suboffsetvars = None
entry.buffer_aux.lookup = get_buf_lookup_strided(scope, entry.type.ndim)
entry.buffer_aux.suboffsetvars = suboffsetvars entry.buffer_aux.suboffsetvars = suboffsetvars
entry.buffer_aux.get_buffer_cname = tschecker entry.buffer_aux.get_buffer_cname = tschecker
...@@ -105,7 +112,13 @@ class IntroduceBufferAuxiliaryVars(CythonTransform): ...@@ -105,7 +112,13 @@ class IntroduceBufferAuxiliaryVars(CythonTransform):
def get_flags(buffer_aux, buffer_type): def get_flags(buffer_aux, buffer_type):
flags = 'PyBUF_FORMAT | PyBUF_INDIRECT' flags = 'PyBUF_FORMAT'
if buffer_type.mode == 'full':
flags += '| PyBUF_INDIRECT'
elif buffer_type.mode == 'strided':
flags += '| PyBUF_STRIDES'
else:
assert False
if buffer_aux.writable_needed: flags += "| PyBUF_WRITABLE" if buffer_aux.writable_needed: flags += "| PyBUF_WRITABLE"
return flags return flags
...@@ -116,14 +129,17 @@ def used_buffer_aux_vars(entry): ...@@ -116,14 +129,17 @@ def used_buffer_aux_vars(entry):
for s in buffer_aux.stridevars: s.used = True for s in buffer_aux.stridevars: s.used = True
for s in buffer_aux.suboffsetvars: s.used = True for s in buffer_aux.suboffsetvars: s.used = True
def put_unpack_buffer_aux_into_scope(buffer_aux, code): def put_unpack_buffer_aux_into_scope(buffer_aux, mode, code):
# Generate code to copy the needed struct info into local
# variables.
bufstruct = buffer_aux.buffer_info_var.cname bufstruct = buffer_aux.buffer_info_var.cname
# __pyx_bstride_0_buf = __pyx_bstruct_buf.strides[0] and so on varspec = [("strides", buffer_aux.stridevars),
("shape", buffer_aux.shapevars)]
if mode == 'full':
varspec.append(("suboffsets", buffer_aux.suboffsetvars))
for field, vars in (("strides", buffer_aux.stridevars), for field, vars in varspec:
("shape", buffer_aux.shapevars),
("suboffsets", buffer_aux.suboffsetvars)):
code.putln(" ".join(["%s = %s.%s[%d];" % code.putln(" ".join(["%s = %s.%s[%d];" %
(s.cname, bufstruct, field, idx) (s.cname, bufstruct, field, idx)
for idx, s in enumerate(vars)])) for idx, s in enumerate(vars)]))
...@@ -146,7 +162,7 @@ def put_acquire_arg_buffer(entry, code, pos): ...@@ -146,7 +162,7 @@ def put_acquire_arg_buffer(entry, code, pos):
pos)) pos))
# An exception raised in arg parsing cannot be catched, so no # An exception raised in arg parsing cannot be catched, so no
# need to do care about the buffer then. # need to do care about the buffer then.
put_unpack_buffer_aux_into_scope(buffer_aux, code) put_unpack_buffer_aux_into_scope(buffer_aux, entry.type.mode, code)
#def put_release_buffer_normal(entry, code): #def put_release_buffer_normal(entry, code):
# code.putln("if (%s != Py_None) PyObject_ReleaseBuffer(%s, &%s);" % ( # code.putln("if (%s != Py_None) PyObject_ReleaseBuffer(%s, &%s);" % (
...@@ -215,7 +231,7 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type, ...@@ -215,7 +231,7 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
code.end_block() code.end_block()
# Unpack indices # Unpack indices
code.end_block() code.end_block()
put_unpack_buffer_aux_into_scope(buffer_aux, code) put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
code.putln(code.error_goto_if_neg(retcode_cname, pos)) code.putln(code.error_goto_if_neg(retcode_cname, pos))
code.func.release_temp(retcode_cname) code.func.release_temp(retcode_cname)
else: else:
...@@ -227,7 +243,7 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type, ...@@ -227,7 +243,7 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
code.putln(code.error_goto(pos)) code.putln(code.error_goto(pos))
code.put('} else {') code.put('} else {')
# Unpack indices # Unpack indices
put_unpack_buffer_aux_into_scope(buffer_aux, code) put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
code.putln('}') code.putln('}')
...@@ -266,8 +282,6 @@ def put_access(entry, index_signeds, index_cnames, pos, code): ...@@ -266,8 +282,6 @@ def put_access(entry, index_signeds, index_cnames, pos, code):
code.putln("if (%s) %s = %d;" % ( code.putln("if (%s) %s = %d;" % (
code.unlikely("%s >= %s" % (cname, shape.cname)), code.unlikely("%s >= %s" % (cname, shape.cname)),
tmp_cname, idx)) tmp_cname, idx))
# if boundscheck or not nonegs:
# code.putln("}")
if boundscheck: if boundscheck:
code.put("if (%s) " % code.unlikely("%s != -1" % tmp_cname)) code.put("if (%s) " % code.unlikely("%s != -1" % tmp_cname))
code.begin_block() code.begin_block()
...@@ -275,16 +289,20 @@ def put_access(entry, index_signeds, index_cnames, pos, code): ...@@ -275,16 +289,20 @@ def put_access(entry, index_signeds, index_cnames, pos, code):
code.putln(code.error_goto(pos)) code.putln(code.error_goto(pos))
code.end_block() code.end_block()
code.func.release_temp(tmp_cname) code.func.release_temp(tmp_cname)
# Create buffer lookup and return it
offset = " + ".join(["%s * %s" % (idx, stride.cname) # Create buffer lookup and return it
for idx, stride in params = []
zip(index_cnames, bufaux.stridevars)]) if entry.type.mode == 'full':
ptrcode = "(%s.buf + %s)" % (bufstruct, offset) for i, s, o in zip(index_cnames, bufaux.stridevars, bufaux.suboffsetvars):
params.append(i)
params.append(s.cname)
params.append(o.cname)
else:
for i, s in zip(index_cnames, bufaux.stridevars):
params.append(i)
params.append(s.cname)
ptrcode = "%s(%s.buf, %s)" % (bufaux.lookup, bufstruct, ptrcode = "%s(%s.buf, %s)" % (bufaux.lookup, bufstruct,
", ".join([", ".join([i, s.cname, o.cname]) for i, s, o in ", ".join(params))
zip(index_cnames, bufaux.stridevars, bufaux.suboffsetvars)]))
valuecode = "*%s" % entry.type.buffer_ptr_type.cast_code(ptrcode) valuecode = "*%s" % entry.type.buffer_ptr_type.cast_code(ptrcode)
return valuecode return valuecode
...@@ -297,6 +315,25 @@ def use_empty_bufstruct_code(env, max_ndim): ...@@ -297,6 +315,25 @@ def use_empty_bufstruct_code(env, max_ndim):
""") % (", ".join(["0"] * max_ndim), ", ".join(["-1"] * max_ndim)) """) % (", ".join(["0"] * max_ndim), ", ".join(["-1"] * max_ndim))
env.use_utility_code([code, ""]) env.use_utility_code([code, ""])
def get_buf_lookup_strided(env, nd):
"""
Generates and registers as utility a buffer lookup function for the right number
of dimensions. The function gives back a void* at the right location.
"""
name = "__Pyx_BufPtrStrided_%dd" % nd
if not env.has_utility_code(name):
# _i_ndex, _s_tride
args = ", ".join(["i%d, s%d" % (i, i) for i in range(nd)])
offset = " + ".join(["i%d * s%d" % (i, i) for i in range(nd)])
proto = dedent("""\
#define %s(buf, %s) ((char*)buf + %s)
""") % (name, args, offset)
env.use_utility_code([proto, ""], name=name)
return name
def get_buf_lookup_full(env, nd): def get_buf_lookup_full(env, nd):
""" """
Generates and registers as utility a buffer lookup function for the right number Generates and registers as utility a buffer lookup function for the right number
......
...@@ -600,7 +600,8 @@ class CBufferAccessTypeNode(Node): ...@@ -600,7 +600,8 @@ class CBufferAccessTypeNode(Node):
def analyse(self, env): def analyse(self, env):
base_type = self.base_type_node.analyse(env) base_type = self.base_type_node.analyse(env)
dtype = self.dtype_node.analyse(env) dtype = self.dtype_node.analyse(env)
self.type = PyrexTypes.BufferType(base_type, dtype=dtype, ndim=self.ndim) self.type = PyrexTypes.BufferType(base_type, dtype=dtype, ndim=self.ndim,
mode=self.mode)
return self.type return self.type
class CComplexBaseTypeNode(CBaseTypeNode): class CComplexBaseTypeNode(CBaseTypeNode):
......
...@@ -84,6 +84,7 @@ ERR_BUF_INT = '"%s" must be an integer' ...@@ -84,6 +84,7 @@ ERR_BUF_INT = '"%s" must be an integer'
ERR_BUF_NONNEG = '"%s" must be non-negative' ERR_BUF_NONNEG = '"%s" must be non-negative'
ERR_CDEF_INCLASS = 'Cannot assign default value to cdef class attributes' ERR_CDEF_INCLASS = 'Cannot assign default value to cdef class attributes'
ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables' ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables'
ERR_BUF_MODEHELP = 'Only allowed buffer modes are "full" or "strided" (as a compile-time string)'
class PostParse(CythonTransform): class PostParse(CythonTransform):
""" """
Basic interpretation of the parse tree, as well as validity Basic interpretation of the parse tree, as well as validity
...@@ -155,7 +156,7 @@ class PostParse(CythonTransform): ...@@ -155,7 +156,7 @@ class PostParse(CythonTransform):
return stats return stats
# buffer access # buffer access
buffer_options = ("dtype", "ndim") # ordered! buffer_options = ("dtype", "ndim", "mode") # ordered!
def visit_CBufferAccessTypeNode(self, node): def visit_CBufferAccessTypeNode(self, node):
if not self.scope_type == 'function': if not self.scope_type == 'function':
raise PostParseError(node.pos, ERR_BUF_LOCALONLY) raise PostParseError(node.pos, ERR_BUF_LOCALONLY)
...@@ -176,7 +177,6 @@ class PostParse(CythonTransform): ...@@ -176,7 +177,6 @@ class PostParse(CythonTransform):
raise PostParseError(item.key.pos, ERR_BUF_DUP % key) raise PostParseError(item.key.pos, ERR_BUF_DUP % key)
options[name] = item.value options[name] = item.value
provided = options.keys()
# get dtype # get dtype
dtype = options.get("dtype") dtype = options.get("dtype")
if dtype is None: if dtype is None:
...@@ -184,7 +184,7 @@ class PostParse(CythonTransform): ...@@ -184,7 +184,7 @@ class PostParse(CythonTransform):
node.dtype_node = dtype node.dtype_node = dtype
# get ndim # get ndim
if "ndim" in provided: if "ndim" in options:
ndimnode = options["ndim"] ndimnode = options["ndim"]
if not isinstance(ndimnode, IntNode): if not isinstance(ndimnode, IntNode):
# Compile-time values (DEF) are currently resolved by the parser, # Compile-time values (DEF) are currently resolved by the parser,
...@@ -196,6 +196,17 @@ class PostParse(CythonTransform): ...@@ -196,6 +196,17 @@ class PostParse(CythonTransform):
node.ndim = int(ndimnode.value) node.ndim = int(ndimnode.value)
else: else:
node.ndim = 1 node.ndim = 1
if "mode" in options:
modenode = options["mode"]
if not isinstance(modenode, StringNode):
raise PostParseError(modenode.pos, ERR_BUF_MODEHELP)
mode = modenode.value
if not mode in ('full', 'strided'):
raise PostParseError(modenode.pos, ERR_BUF_MODEHELP)
node.mode = mode
else:
node.mode = 'full'
# We're done with the parse tree args # We're done with the parse tree args
node.positional_args = None node.positional_args = None
......
...@@ -196,14 +196,18 @@ class BufferType(BaseType): ...@@ -196,14 +196,18 @@ class BufferType(BaseType):
# dtype PyrexType # dtype PyrexType
# ndim int # ndim int
# mode str
# is_buffer boolean
# writable boolean
is_buffer = 1 is_buffer = 1
writable = True writable = True
def __init__(self, base, dtype, ndim): def __init__(self, base, dtype, ndim, mode):
self.base = base self.base = base
self.dtype = dtype self.dtype = dtype
self.ndim = ndim self.ndim = ndim
self.buffer_ptr_type = CPtrType(dtype) self.buffer_ptr_type = CPtrType(dtype)
self.mode = mode
def as_argument_type(self): def as_argument_type(self):
return self return self
......
...@@ -8,6 +8,8 @@ def f(): ...@@ -8,6 +8,8 @@ def f():
cdef object[ndim=-1] buf2 cdef object[ndim=-1] buf2
cdef object[int, 'a'] buf3 cdef object[int, 'a'] buf3
cdef object[int,2,3,4,5,6] buf4 cdef object[int,2,3,4,5,6] buf4
cdef object[int, 2, 'foo'] buf5
cdef object[int, 2, well] buf6
_ERRORS = u""" _ERRORS = u"""
1:11: Buffer types only allowed as function local variables 1:11: Buffer types only allowed as function local variables
...@@ -17,5 +19,7 @@ _ERRORS = u""" ...@@ -17,5 +19,7 @@ _ERRORS = u"""
8:15: "dtype" missing 8:15: "dtype" missing
9:21: "ndim" must be an integer 9:21: "ndim" must be an integer
10:15: Too many buffer options 10:15: Too many buffer options
11:24: Only allowed buffer modes are "full" or "strided" (as a compile-time string)
12:28: Only allowed buffer modes are "full" or "strided" (as a compile-time string)
""" """
...@@ -477,6 +477,19 @@ def writable(obj): ...@@ -477,6 +477,19 @@ def writable(obj):
cdef object[unsigned short int, 3] buf = obj cdef object[unsigned short int, 3] buf = obj
buf[2, 2, 1] = 23 buf[2, 2, 1] = 23
@testcase
def strided(object[int, 1, 'strided'] buf):
"""
>>> A = IntMockBuffer("A", range(4))
>>> strided(A)
acquired A
released A
2
>>> A.recieved_flags
['FORMAT', 'ND', 'STRIDES']
"""
return buf[2]
# #
# Coercions # Coercions
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment