Commit 90f19619 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Buffers released at function exit

parent 68e1429c
...@@ -97,12 +97,12 @@ class BufferTransform(CythonTransform): ...@@ -97,12 +97,12 @@ class BufferTransform(CythonTransform):
# For all buffers, insert extra variables in the scope. # For all buffers, insert extra variables in the scope.
# The variables are also accessible from the buffer_info # The variables are also accessible from the buffer_info
# on the buffer entry # on the buffer entry
bufvars = [(name, entry) for name, entry bufvars = [entry for name, entry
in scope.entries.iteritems() in scope.entries.iteritems()
if entry.type.is_buffer] if entry.type.is_buffer]
for name, entry in bufvars: for entry in bufvars:
name = entry.name
buftype = entry.type buftype = entry.type
# Get or make a type string checker # Get or make a type string checker
...@@ -130,6 +130,7 @@ class BufferTransform(CythonTransform): ...@@ -130,6 +130,7 @@ class BufferTransform(CythonTransform):
entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars,
shapevars, tschecker) shapevars, tschecker)
entry.buffer_aux.temp_var = temp_var entry.buffer_aux.temp_var = temp_var
scope.buffer_entries = bufvars
self.scope = scope self.scope = scope
# Notes: The cast to <char*> gets around Cython not supporting const types # Notes: The cast to <char*> gets around Cython not supporting const types
...@@ -223,10 +224,31 @@ class BufferTransform(CythonTransform): ...@@ -223,10 +224,31 @@ class BufferTransform(CythonTransform):
return result return result
buffer_cleanup_fragment = TreeFragment(u"""
if BUF is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>BUF, &BUFINFO)
""")
def funcdef_buffer_cleanup(self, node):
pos = node.pos
env = node.local_scope
cleanups = [self.buffer_cleanup_fragment.substitute({
u"BUF" : NameNode(pos, name=entry.name),
u"BUFINFO": NameNode(pos, name=entry.buffer_aux.buffer_info_var.name)
})
for entry in node.local_scope.buffer_entries]
cleanup_stats = []
for c in cleanups: cleanup_stats += c.stats
cleanup = StatListNode(pos, stats=cleanup_stats)
cleanup.analyse_expressions(env)
result = TryFinallyStatNode.create_analysed(pos, env, body=node.body, finally_clause=cleanup)
node.body = result
return node
# #
# Transforms # Transforms
# #
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
self.handle_scope(node, node.scope) self.handle_scope(node, node.scope)
self.visitchildren(node) self.visitchildren(node)
...@@ -235,7 +257,7 @@ class BufferTransform(CythonTransform): ...@@ -235,7 +257,7 @@ class BufferTransform(CythonTransform):
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
self.handle_scope(node, node.local_scope) self.handle_scope(node, node.local_scope)
self.visitchildren(node) self.visitchildren(node)
return node return self.funcdef_buffer_cleanup(node)
def visit_SingleAssignmentNode(self, node): def visit_SingleAssignmentNode(self, node):
# On assignments, two buffer-related things can happen: # On assignments, two buffer-related things can happen:
......
...@@ -810,6 +810,13 @@ class NameNode(AtomicExprNode): ...@@ -810,6 +810,13 @@ class NameNode(AtomicExprNode):
is_name = 1 is_name = 1
def create_analysed_rvalue(pos, env, entry):
node = NameNode(pos)
node.analyse_types(env, entry=entry)
return node
create_analysed_rvalue = staticmethod(create_analysed_rvalue)
def compile_time_value(self, denv): def compile_time_value(self, denv):
try: try:
return denv.lookup(self.name) return denv.lookup(self.name)
...@@ -862,8 +869,10 @@ class NameNode(AtomicExprNode): ...@@ -862,8 +869,10 @@ class NameNode(AtomicExprNode):
if self.entry.is_pyglobal and self.entry.is_member: if self.entry.is_pyglobal and self.entry.is_member:
env.use_utility_code(type_cache_invalidation_code) env.use_utility_code(type_cache_invalidation_code)
def analyse_types(self, env): def analyse_types(self, env, entry=None):
self.entry = env.lookup(self.name) if entry is None:
entry = env.lookup(self.name)
self.entry = entry
if not self.entry: if not self.entry:
self.entry = env.declare_builtin(self.name, self.pos) self.entry = env.declare_builtin(self.name, self.pos)
if not self.entry: if not self.entry:
......
...@@ -252,6 +252,11 @@ class StatListNode(Node): ...@@ -252,6 +252,11 @@ class StatListNode(Node):
child_attrs = ["stats"] child_attrs = ["stats"]
def create_analysed(pos, env, *args, **kw):
node = StatListNode(pos, *args, **kw)
return node # No node-specific analysis necesarry
create_analysed = staticmethod(create_analysed)
def analyse_control_flow(self, env): def analyse_control_flow(self, env):
for stat in self.stats: for stat in self.stats:
stat.analyse_control_flow(env) stat.analyse_control_flow(env)
...@@ -3523,6 +3528,12 @@ class TryFinallyStatNode(StatNode): ...@@ -3523,6 +3528,12 @@ class TryFinallyStatNode(StatNode):
# continue in the try block, since we have no problem # continue in the try block, since we have no problem
# handling it. # handling it.
def create_analysed(pos, env, body, finally_clause):
node = TryFinallyStatNode(pos, body=body, finally_clause=finally_clause)
node.cleanup_list = []
return node
create_analysed = staticmethod(create_analysed)
def analyse_control_flow(self, env): def analyse_control_flow(self, env):
env.start_branching(self.pos) env.start_branching(self.pos)
self.body.analyse_control_flow(env) self.body.analyse_control_flow(env)
......
cimport __cython__ cimport __cython__
__doc__ = u""" __doc__ = u"""
>>> fb = MockBuffer("=f", "f", [1.0, 1.25, 0.75, 1.0], (2,2)) >>> A = MockBuffer("i", range(10), label="A")
>>> printbuf_float(fb, (2,2)) >>> B = MockBuffer("i", range(10), label="B")
1.0 1.25 >>> E = ErrorBuffer("E")
0.75 1.0 >>> acquire_release(A, B)
acquired A
released A
acquired B
released B
>>> acquire_raise(A)
acquired A
released A
Traceback (most recent call last):
...
Exception: on purpose
>>> printbuf_float(MockBuffer("f", [1.0, 1.25, 0.75, 1.0]), (4,))
acquired
1.0 1.25 0.75 1.0
released
>>> printbuf_int_2d(MockBuffer("i", range(6), (2,3)), (2,3))
acquired
0 1 2
3 4 5
released
""" """
def acquire_release(o1, o2):
cdef object[int] buf
buf = o1
buf = o2
def acquire_raise(o):
cdef object[int] buf
buf = o
raise Exception("on purpose")
def printbuf_float(o, shape): def printbuf_float(o, shape):
# should make shape builtin # should make shape builtin
cdef object[float, 2] buf cdef object[float] buf
buf = o
cdef int i, j
for i in range(shape[0]):
print buf[i],
print
def printbuf_int_2d(o, shape):
# should make shape builtin
cdef object[int, 2] buf
buf = o buf = o
cdef int i, j cdef int i, j
for i in range(shape[0]): for i in range(shape[0]):
...@@ -19,9 +57,20 @@ def printbuf_float(o, shape): ...@@ -19,9 +57,20 @@ def printbuf_float(o, shape):
print print
ctypedef char* (*write_func_ptr)(char*, object)
cdef char* write_float(char* buf, object value):
(<float*>buf)[0] = <float>value
return buf + sizeof(float)
cdef char* write_int(char* buf, object value):
(<int*>buf)[0] = <int>value
return buf + sizeof(int)
sizes = { # long can hold a pointer on all target platforms,
'f': sizeof(float) # though really we should have a seperate typedef for this..
# TODO: Should create subclasses of MockBuffer instead.
typemap = {
'f': (sizeof(float), <unsigned long>&write_float),
'i': (sizeof(int), <unsigned long>&write_int)
} }
cimport stdlib cimport stdlib
...@@ -32,14 +81,21 @@ cdef class MockBuffer: ...@@ -32,14 +81,21 @@ cdef class MockBuffer:
cdef int len, itemsize, ndim cdef int len, itemsize, ndim
cdef Py_ssize_t* strides cdef Py_ssize_t* strides
cdef Py_ssize_t* shape cdef Py_ssize_t* shape
cdef write_func_ptr wfunc
cdef object label
def __init__(self, format, typechar, data, shape=None, strides=None): def __init__(self, typechar, data, shape=None, strides=None, format=None, label=None):
self.itemsize = sizes[typechar] self.label = label
if format is None: format = "=%s" % typechar
self.itemsize, x = typemap[typechar]
self.wfunc = <write_func_ptr><unsigned long>x
if shape is None: shape = (len(data),) if shape is None: shape = (len(data),)
if strides is None: if strides is None:
strides = [] strides = []
cumprod = 1 cumprod = 1
for s in shape: rshape = list(shape)
rshape.reverse()
for s in rshape:
strides.append(cumprod) strides.append(cumprod)
cumprod *= s cumprod *= s
strides.reverse() strides.reverse()
...@@ -68,11 +124,32 @@ cdef class MockBuffer: ...@@ -68,11 +124,32 @@ cdef class MockBuffer:
buffer.suboffsets = NULL buffer.suboffsets = NULL
buffer.itemsize = self.itemsize buffer.itemsize = self.itemsize
buffer.internal = NULL buffer.internal = NULL
print "acquired",
if self.label:
print self.label
else:
print
def __releasebuffer__(MockBuffer self, Py_buffer* buffer):
print "released",
if self.label:
print self.label
else:
print
cdef fill_buffer(self, typechar, object data): cdef fill_buffer(self, typechar, object data):
cdef int idx = 0 cdef char* it = self.buffer
for value in data: for value in data:
(<float*>(self.buffer + idx))[0] = <float>value it = self.wfunc(it, value)
idx += sizeof(float)
cdef class ErrorBuffer:
cdef object label
def __init__(self, label):
self.label = label
def __getbuffer__(MockBuffer self, Py_buffer* buffer, int flags):
raise Exception("acquiring %s" % self.label)
def __releasebuffer__(MockBuffer self, Py_buffer* buffer):
raise Exception("releasing %s" % self.label)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment