Commit 11f1bc8f authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Initial working support for buffers as function arguments

parent 4fc5eee9
...@@ -133,19 +133,11 @@ class BufferTransform(CythonTransform): ...@@ -133,19 +133,11 @@ class BufferTransform(CythonTransform):
scope.buffer_entries = bufvars scope.buffer_entries = bufvars
self.scope = scope self.scope = scope
# Notes: The cast to <char*> gets around Cython not supporting const types
acquire_buffer_fragment = TreeFragment(u""" acquire_buffer_fragment = TreeFragment(u"""
TMP = LHS __cython__.PyObject_GetBuffer(<__cython__.PyObject*>SUBJECT, &BUFINFO, 0)
if TMP is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
TMP = RHS
if TMP is not None:
__cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
TSCHECKER(<char*>BUFINFO.format) TSCHECKER(<char*>BUFINFO.format)
ASSIGN_AUX
LHS = TMP
""") """)
fetch_strides = TreeFragment(u""" fetch_strides = TreeFragment(u"""
TARGET = BUFINFO.strides[IDX] TARGET = BUFINFO.strides[IDX]
""") """)
...@@ -154,35 +146,64 @@ class BufferTransform(CythonTransform): ...@@ -154,35 +146,64 @@ class BufferTransform(CythonTransform):
TARGET = BUFINFO.shape[IDX] TARGET = BUFINFO.shape[IDX]
""") """)
def reacquire_buffer(self, node): def acquire_buffer_stats(self, entry, buffer_aux, pos):
bufaux = node.lhs.entry.buffer_aux # Just the stats for acquiring and unpacking the buffer auxiliaries
auxass = [] auxass = []
for idx, entry in enumerate(bufaux.stridevars): for idx, strideentry in enumerate(buffer_aux.stridevars):
entry.used = True strideentry.used = True
ass = self.fetch_strides.substitute({ ass = self.fetch_strides.substitute({
u"TARGET": NameNode(node.pos, name=entry.name), u"TARGET": NameNode(pos, name=strideentry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name), u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx)), u"IDX": IntNode(pos, value=EncodedString(idx)),
}) })
auxass.append(ass) auxass += ass.stats
for idx, entry in enumerate(bufaux.shapevars): for idx, shapeentry in enumerate(buffer_aux.shapevars):
entry.used = True shapeentry.used = True
ass = self.fetch_shape.substitute({ ass = self.fetch_shape.substitute({
u"TARGET": NameNode(node.pos, name=entry.name), u"TARGET": NameNode(pos, name=shapeentry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name), u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx)) u"IDX": IntNode(pos, value=EncodedString(idx))
}) })
auxass.append(ass) auxass += ass.stats
buffer_aux.buffer_info_var.used = True
bufaux.buffer_info_var.used = True
acq = self.acquire_buffer_fragment.substitute({ acq = self.acquire_buffer_fragment.substitute({
u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name), u"SUBJECT" : NameNode(pos, name=entry.name),
u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"TSCHECKER": NameNode(pos, name=buffer_aux.tschecker.name)
}, pos=pos)
return acq.stats + auxass
def acquire_argument_buffer_stats(self, entry, pos):
# On function entry, not getting a buffer is an uncatchable
# exception, so we don't need to worry about what happens if
# we don't get a buffer.
stats = self.acquire_buffer_stats(entry, entry.buffer_aux, pos)
for s in stats:
s.analyse_declarations(self.scope)
s.analyse_expressions(self.scope)
return stats
# Notes: The cast to <char*> gets around Cython not supporting const types
reacquire_buffer_fragment = TreeFragment(u"""
TMP = LHS
if TMP is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
TMP = RHS
if TMP is not None:
ACQUIRE
LHS = TMP
""")
def reacquire_buffer(self, node):
buffer_aux = node.lhs.entry.buffer_aux
acquire_stats = self.acquire_buffer_stats(buffer_aux.temp_var, buffer_aux, node.pos)
acq = self.reacquire_buffer_fragment.substitute({
u"TMP" : NameNode(pos=node.pos, name=buffer_aux.temp_var.name),
u"LHS" : node.lhs, u"LHS" : node.lhs,
u"RHS": node.rhs, u"RHS": node.rhs,
u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass), u"ACQUIRE": StatListNode(node.pos, stats=acquire_stats),
u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name), u"BUFINFO": NameNode(pos=node.pos, name=buffer_aux.buffer_info_var.name)
u"TSCHECKER": NameNode(node.pos, name=bufaux.tschecker.name)
}, pos=node.pos) }, pos=node.pos)
# Note: The below should probably be refactored into something # Note: The below should probably be refactored into something
# like fragment.substitute(..., context=self.context), with # like fragment.substitute(..., context=self.context), with
...@@ -228,21 +249,19 @@ class BufferTransform(CythonTransform): ...@@ -228,21 +249,19 @@ class BufferTransform(CythonTransform):
if BUF is not None: if BUF is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>BUF, &BUFINFO) __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>BUF, &BUFINFO)
""") """)
def funcdef_buffer_cleanup(self, node): def funcdef_buffer_cleanup(self, node, pos):
pos = node.pos
env = node.local_scope env = node.local_scope
cleanups = [self.buffer_cleanup_fragment.substitute({ cleanups = [self.buffer_cleanup_fragment.substitute({
u"BUF" : NameNode(pos, name=entry.name), u"BUF" : NameNode(pos, name=entry.name),
u"BUFINFO": NameNode(pos, name=entry.buffer_aux.buffer_info_var.name) u"BUFINFO": NameNode(pos, name=entry.buffer_aux.buffer_info_var.name)
}) }, pos=pos)
for entry in node.local_scope.buffer_entries] for entry in node.local_scope.buffer_entries]
cleanup_stats = [] cleanup_stats = []
for c in cleanups: cleanup_stats += c.stats for c in cleanups: cleanup_stats += c.stats
cleanup = StatListNode(pos, stats=cleanup_stats) cleanup = StatListNode(pos, stats=cleanup_stats)
cleanup.analyse_expressions(env) cleanup.analyse_expressions(env)
result = TryFinallyStatNode.create_analysed(pos, env, body=node.body, finally_clause=cleanup) result = TryFinallyStatNode.create_analysed(pos, env, body=node.body, finally_clause=cleanup)
node.body = result node.body = StatListNode.create_analysed(pos, env, stats=[result])
return node return node
# #
...@@ -257,7 +276,13 @@ class BufferTransform(CythonTransform): ...@@ -257,7 +276,13 @@ class BufferTransform(CythonTransform):
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
self.handle_scope(node, node.local_scope) self.handle_scope(node, node.local_scope)
self.visitchildren(node) self.visitchildren(node)
return self.funcdef_buffer_cleanup(node) node = self.funcdef_buffer_cleanup(node, node.pos)
stats = []
for arg in node.local_scope.arg_entries:
if arg.type.is_buffer:
stats += self.acquire_argument_buffer_stats(arg, node.pos)
node.body.stats = stats + node.body.stats
return node
def visit_SingleAssignmentNode(self, node): def visit_SingleAssignmentNode(self, node):
# On assignments, two buffer-related things can happen: # On assignments, two buffer-related things can happen:
......
...@@ -204,9 +204,15 @@ class BufferType(BaseType): ...@@ -204,9 +204,15 @@ class BufferType(BaseType):
self.dtype = dtype self.dtype = dtype
self.ndim = ndim self.ndim = ndim
def as_argument_type(self):
return self
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.base, name) return getattr(self.base, name)
def __repr__(self):
return "<BufferType %r>" % self.base
class PyObjectType(PyrexType): class PyObjectType(PyrexType):
# #
......
...@@ -22,6 +22,10 @@ __doc__ = u""" ...@@ -22,6 +22,10 @@ __doc__ = u"""
acquired A acquired A
released A released A
>>> print_buffer_as_argument(MockBuffer("i", range(6)), 6)
acquired
0 1 2 3 4 5
released
>>> printbuf_float(MockBuffer("f", [1.0, 1.25, 0.75, 1.0]), (4,)) >>> printbuf_float(MockBuffer("f", [1.0, 1.25, 0.75, 1.0]), (4,))
acquired acquired
1.0 1.25 0.75 1.0 1.0 1.25 0.75 1.0
...@@ -43,9 +47,17 @@ def acquire_release(o1, o2): ...@@ -43,9 +47,17 @@ def acquire_release(o1, o2):
def acquire_raise(o): def acquire_raise(o):
cdef object[int] buf cdef object[int] buf
buf = o buf = o
print "a"
raise Exception("on purpose") raise Exception("on purpose")
def print_buffer_as_argument(object[int] bufarg, int n):
cdef int i
for i in range(n):
print bufarg[i],
print
# default values
#
def printbuf_float(o, shape): def printbuf_float(o, shape):
# should make shape builtin # should make shape builtin
cdef object[float] buf cdef object[float] buf
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment