Commit 11f1bc8f authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Initial working support for buffers as function arguments

parent 4fc5eee9
......@@ -100,7 +100,7 @@ class BufferTransform(CythonTransform):
bufvars = [entry for name, entry
in scope.entries.iteritems()
if entry.type.is_buffer]
for entry in bufvars:
name = entry.name
buftype = entry.type
......@@ -133,19 +133,11 @@ class BufferTransform(CythonTransform):
scope.buffer_entries = bufvars
self.scope = scope
# Notes: The cast to <char*> gets around Cython not supporting const types
acquire_buffer_fragment = TreeFragment(u"""
TMP = LHS
if TMP is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
TMP = RHS
if TMP is not None:
__cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
TSCHECKER(<char*>BUFINFO.format)
ASSIGN_AUX
LHS = TMP
__cython__.PyObject_GetBuffer(<__cython__.PyObject*>SUBJECT, &BUFINFO, 0)
TSCHECKER(<char*>BUFINFO.format)
""")
fetch_strides = TreeFragment(u"""
TARGET = BUFINFO.strides[IDX]
""")
......@@ -154,35 +146,64 @@ class BufferTransform(CythonTransform):
TARGET = BUFINFO.shape[IDX]
""")
def reacquire_buffer(self, node):
bufaux = node.lhs.entry.buffer_aux
def acquire_buffer_stats(self, entry, buffer_aux, pos):
# Just the stats for acquiring and unpacking the buffer auxiliaries
auxass = []
for idx, entry in enumerate(bufaux.stridevars):
entry.used = True
for idx, strideentry in enumerate(buffer_aux.stridevars):
strideentry.used = True
ass = self.fetch_strides.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx)),
u"TARGET": NameNode(pos, name=strideentry.name),
u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"IDX": IntNode(pos, value=EncodedString(idx)),
})
auxass.append(ass)
auxass += ass.stats
for idx, entry in enumerate(bufaux.shapevars):
entry.used = True
for idx, shapeentry in enumerate(buffer_aux.shapevars):
shapeentry.used = True
ass = self.fetch_shape.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx))
u"TARGET": NameNode(pos, name=shapeentry.name),
u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"IDX": IntNode(pos, value=EncodedString(idx))
})
auxass.append(ass)
bufaux.buffer_info_var.used = True
auxass += ass.stats
buffer_aux.buffer_info_var.used = True
acq = self.acquire_buffer_fragment.substitute({
u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
u"SUBJECT" : NameNode(pos, name=entry.name),
u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"TSCHECKER": NameNode(pos, name=buffer_aux.tschecker.name)
}, pos=pos)
return acq.stats + auxass
def acquire_argument_buffer_stats(self, entry, pos):
# On function entry, not getting a buffer is an uncatchable
# exception, so we don't need to worry about what happens if
# we don't get a buffer.
stats = self.acquire_buffer_stats(entry, entry.buffer_aux, pos)
for s in stats:
s.analyse_declarations(self.scope)
s.analyse_expressions(self.scope)
return stats
# Notes: The cast to <char*> gets around Cython not supporting const types
reacquire_buffer_fragment = TreeFragment(u"""
TMP = LHS
if TMP is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
TMP = RHS
if TMP is not None:
ACQUIRE
LHS = TMP
""")
def reacquire_buffer(self, node):
buffer_aux = node.lhs.entry.buffer_aux
acquire_stats = self.acquire_buffer_stats(buffer_aux.temp_var, buffer_aux, node.pos)
acq = self.reacquire_buffer_fragment.substitute({
u"TMP" : NameNode(pos=node.pos, name=buffer_aux.temp_var.name),
u"LHS" : node.lhs,
u"RHS": node.rhs,
u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name),
u"TSCHECKER": NameNode(node.pos, name=bufaux.tschecker.name)
u"ACQUIRE": StatListNode(node.pos, stats=acquire_stats),
u"BUFINFO": NameNode(pos=node.pos, name=buffer_aux.buffer_info_var.name)
}, pos=node.pos)
# Note: The below should probably be refactored into something
# like fragment.substitute(..., context=self.context), with
......@@ -228,21 +249,19 @@ class BufferTransform(CythonTransform):
if BUF is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>BUF, &BUFINFO)
""")
def funcdef_buffer_cleanup(self, node):
pos = node.pos
def funcdef_buffer_cleanup(self, node, pos):
env = node.local_scope
cleanups = [self.buffer_cleanup_fragment.substitute({
u"BUF" : NameNode(pos, name=entry.name),
u"BUFINFO": NameNode(pos, name=entry.buffer_aux.buffer_info_var.name)
})
}, pos=pos)
for entry in node.local_scope.buffer_entries]
cleanup_stats = []
for c in cleanups: cleanup_stats += c.stats
cleanup = StatListNode(pos, stats=cleanup_stats)
cleanup.analyse_expressions(env)
result = TryFinallyStatNode.create_analysed(pos, env, body=node.body, finally_clause=cleanup)
node.body = result
node.body = StatListNode.create_analysed(pos, env, stats=[result])
return node
#
......@@ -257,7 +276,13 @@ class BufferTransform(CythonTransform):
def visit_FuncDefNode(self, node):
self.handle_scope(node, node.local_scope)
self.visitchildren(node)
return self.funcdef_buffer_cleanup(node)
node = self.funcdef_buffer_cleanup(node, node.pos)
stats = []
for arg in node.local_scope.arg_entries:
if arg.type.is_buffer:
stats += self.acquire_argument_buffer_stats(arg, node.pos)
node.body.stats = stats + node.body.stats
return node
def visit_SingleAssignmentNode(self, node):
# On assignments, two buffer-related things can happen:
......
......@@ -204,9 +204,15 @@ class BufferType(BaseType):
self.dtype = dtype
self.ndim = ndim
def as_argument_type(self):
return self
def __getattr__(self, name):
return getattr(self.base, name)
def __repr__(self):
return "<BufferType %r>" % self.base
class PyObjectType(PyrexType):
#
......
......@@ -21,7 +21,11 @@ __doc__ = u"""
>>> A.printlog()
acquired A
released A
>>> print_buffer_as_argument(MockBuffer("i", range(6)), 6)
acquired
0 1 2 3 4 5
released
>>> printbuf_float(MockBuffer("f", [1.0, 1.25, 0.75, 1.0]), (4,))
acquired
1.0 1.25 0.75 1.0
......@@ -43,8 +47,16 @@ def acquire_release(o1, o2):
def acquire_raise(o):
cdef object[int] buf
buf = o
print "a"
raise Exception("on purpose")
def print_buffer_as_argument(object[int] bufarg, int n):
cdef int i
for i in range(n):
print bufarg[i],
print
# default values
#
def printbuf_float(o, shape):
# should make shape builtin
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment