Commit 724f5756 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Only define PyObject_GetBuffer etc. if really needed

parent 9f930fcf
...@@ -8,6 +8,87 @@ from Cython.Compiler.Errors import CompileError ...@@ -8,6 +8,87 @@ from Cython.Compiler.Errors import CompileError
import PyrexTypes import PyrexTypes
from sets import Set as set from sets import Set as set
class IntroduceBufferAuxiliaryVars(CythonTransform):
#
# Entry point
#
buffers_exists = False
def __call__(self, node):
assert isinstance(node, ModuleNode)
result = super(IntroduceBufferAuxiliaryVars, self).__call__(node)
if self.buffers_exists:
if "endian.h" not in node.scope.include_files:
node.scope.include_files.append("endian.h")
use_py2_buffer_functions(node.scope)
node.scope.use_utility_code(buffer_boundsfail_error_utility_code)
return result
#
# Basic operations for transforms
#
def handle_scope(self, node, scope):
# For all buffers, insert extra variables in the scope.
# The variables are also accessible from the buffer_info
# on the buffer entry
bufvars = [entry for name, entry
in scope.entries.iteritems()
if entry.type.is_buffer]
if len(bufvars) > 0:
self.buffers_exists = True
if isinstance(node, ModuleNode) and len(bufvars) > 0:
# for now...note that pos is wrong
raise CompileError(node.pos, "Buffer vars not allowed in module scope")
for entry in bufvars:
name = entry.name
buftype = entry.type
# Get or make a type string checker
tschecker = buffer_type_checker(buftype.dtype, scope)
# Declare auxiliary vars
cname = scope.mangle(Naming.bufstruct_prefix, name)
bufinfo = scope.declare_var(name="$%s" % cname, cname=cname,
type=PyrexTypes.c_py_buffer_type, pos=node.pos)
bufinfo.used = True
def var(prefix, idx):
cname = scope.mangle(prefix, "%d_%s" % (idx, name))
result = scope.declare_var("$%s" % cname, PyrexTypes.c_py_ssize_t_type,
node.pos, cname=cname, is_cdef=True)
result.init = "0"
if entry.is_arg:
result.used = True
return result
stridevars = [var(Naming.bufstride_prefix, i) for i in range(entry.type.ndim)]
shapevars = [var(Naming.bufshape_prefix, i) for i in range(entry.type.ndim)]
entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, tschecker)
scope.buffer_entries = bufvars
self.scope = scope
def visit_ModuleNode(self, node):
self.handle_scope(node, node.scope)
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
self.handle_scope(node, node.local_scope)
self.visitchildren(node)
return node
def get_flags(buffer_aux, buffer_type): def get_flags(buffer_aux, buffer_type):
flags = 'PyBUF_FORMAT | PyBUF_INDIRECT' flags = 'PyBUF_FORMAT | PyBUF_INDIRECT'
if buffer_aux.writable_needed: flags += "| PyBUF_WRITABLE" if buffer_aux.writable_needed: flags += "| PyBUF_WRITABLE"
...@@ -229,129 +310,51 @@ static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) { ...@@ -229,129 +310,51 @@ static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) {
"""] """]
class IntroduceBufferAuxiliaryVars(CythonTransform): #
# Utils for creating type string checkers
# #
# Entry point def mangle_dtype_name(dtype):
# # Use prefixes to seperate user defined types from builtins
# (consider "typedef float unsigned_int")
def __call__(self, node): if dtype.typestring is None:
assert isinstance(node, ModuleNode) prefix = "nn_"
self.tscheckers = {} else:
self.tsfuncs = set() prefix = ""
self.ts_funcs = [] return prefix + dtype.declaration_code("").replace(" ", "_")
self.ts_item_checkers = {}
self.module_scope = node.scope def get_ts_check_item(dtype, env):
self.module_pos = node.pos # See if we can consume one (unnamed) dtype as next item
result = super(IntroduceBufferAuxiliaryVars, self).__call__(node) # Put native types and structs in seperate namespaces (as one could create a struct named unsigned_int...)
# Register ts stuff name = "__Pyx_BufferTypestringCheck_item_%s" % mangle_dtype_name(dtype)
if "endian.h" not in node.scope.include_files: if not env.has_utility_code(name):
node.scope.include_files.append("endian.h") char = dtype.typestring
result.body.stats += self.ts_funcs if char is not None:
return result
#
# Basic operations for transforms
#
def handle_scope(self, node, scope):
# For all buffers, insert extra variables in the scope.
# The variables are also accessible from the buffer_info
# on the buffer entry
bufvars = [entry for name, entry
in scope.entries.iteritems()
if entry.type.is_buffer]
if isinstance(node, ModuleNode) and len(bufvars) > 0:
# for now...note that pos is wrong
raise CompileError(node.pos, "Buffer vars not allowed in module scope")
for entry in bufvars:
name = entry.name
buftype = entry.type
# Get or make a type string checker
tschecker = self.buffer_type_checker(buftype.dtype, scope)
# Declare auxiliary vars
cname = scope.mangle(Naming.bufstruct_prefix, name)
bufinfo = scope.declare_var(name="$%s" % cname, cname=cname,
type=PyrexTypes.c_py_buffer_type, pos=node.pos)
bufinfo.used = True
def var(prefix, idx):
cname = scope.mangle(prefix, "%d_%s" % (idx, name))
result = scope.declare_var("$%s" % cname, PyrexTypes.c_py_ssize_t_type,
node.pos, cname=cname, is_cdef=True)
result.init = "0"
if entry.is_arg:
result.used = True
return result
stridevars = [var(Naming.bufstride_prefix, i) for i in range(entry.type.ndim)]
shapevars = [var(Naming.bufshape_prefix, i) for i in range(entry.type.ndim)]
entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, tschecker)
scope.buffer_entries = bufvars
self.scope = scope
def visit_ModuleNode(self, node):
node.scope.use_utility_code(buffer_boundsfail_error_utility_code)
self.handle_scope(node, node.scope)
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
self.handle_scope(node, node.local_scope)
self.visitchildren(node)
return node
#
# Utils for creating type string checkers
#
def mangle_dtype_name(self, dtype):
# Use prefixes to seperate user defined types from builtins
# (consider "typedef float unsigned_int")
if dtype.typestring is None:
prefix = "nn_"
else:
prefix = ""
return prefix + dtype.declaration_code("").replace(" ", "_")
def get_ts_check_item(self, dtype, env):
# See if we can consume one (unnamed) dtype as next item
# Put native types and structs in seperate namespaces (as one could create a struct named unsigned_int...)
name = "__Pyx_BufferTypestringCheck_item_%s" % self.mangle_dtype_name(dtype)
funcnode = self.ts_item_checkers.get(dtype)
if not name in self.tsfuncs:
char = dtype.typestring
if char is not None:
# Can use direct comparison # Can use direct comparison
code = """\ code = """\
if (*ts != '%s') { if (*ts != '%s') {
PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (rejecting on '%%s')", ts); PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (rejecting on '%%s')", ts);
return NULL; return NULL;
} else return ts + 1; } else return ts + 1;
""" % char """ % char
else: else:
# Cannot trust declared size; but rely on int vs float and # Cannot trust declared size; but rely on int vs float and
# signed/unsigned to be correctly declared # signed/unsigned to be correctly declared
ctype = dtype.declaration_code("") ctype = dtype.declaration_code("")
code = """\ code = """\
int ok; int ok;
switch (*ts) {""" switch (*ts) {"""
if dtype.is_int: if dtype.is_int:
types = [ types = [
('b', 'char'), ('h', 'short'), ('i', 'int'), ('b', 'char'), ('h', 'short'), ('i', 'int'),
('l', 'long'), ('q', 'long long') ('l', 'long'), ('q', 'long long')
] ]
if dtype.signed == 0: if dtype.signed == 0:
code += "".join(["\n case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;" % code += "".join(["\n case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;" %
(char.upper(), ctype, against, ctype) for char, against in types]) (char.upper(), ctype, against, ctype) for char, against in types])
else: else:
code += "".join(["\n case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" % code += "".join(["\n case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" %
(char, ctype, against, ctype) for char, against in types]) (char, ctype, against, ctype) for char, against in types])
code += """\ code += """\
default: ok = 0; default: ok = 0;
} }
if (!ok) { if (!ok) {
...@@ -359,23 +362,22 @@ class IntroduceBufferAuxiliaryVars(CythonTransform): ...@@ -359,23 +362,22 @@ class IntroduceBufferAuxiliaryVars(CythonTransform):
return NULL; return NULL;
} else return ts + 1; } else return ts + 1;
""" """
env.use_utility_code(["""\ env.use_utility_code(["""\
static const char* %s(const char* ts); /*proto*/ static const char* %s(const char* ts); /*proto*/
""" % name, """ """ % name, """
static const char* %s(const char* ts) { static const char* %s(const char* ts) {
%s %s
} }
""" % (name, code)]) """ % (name, code)], name=name)
self.tsfuncs.add(name)
return name return name
def get_ts_check_simple(self, dtype, env): def get_ts_check_simple(dtype, env):
# Check whole string for single unnamed item # Check whole string for single unnamed item
name = "__Pyx_BufferTypestringCheck_simple_%s" % self.mangle_dtype_name(dtype) name = "__Pyx_BufferTypestringCheck_simple_%s" % mangle_dtype_name(dtype)
if not name in self.tsfuncs: if not env.has_utility_code(name):
itemchecker = self.get_ts_check_item(dtype, env) itemchecker = get_ts_check_item(dtype, env)
utilcode = [""" utilcode = ["""
static int %s(Py_buffer* buf, int e_nd); /*proto*/ static int %s(Py_buffer* buf, int e_nd); /*proto*/
""" % name,""" """ % name,"""
static int %(name)s(Py_buffer* buf, int e_nd) { static int %(name)s(Py_buffer* buf, int e_nd) {
...@@ -398,200 +400,133 @@ static int %(name)s(Py_buffer* buf, int e_nd) { ...@@ -398,200 +400,133 @@ static int %(name)s(Py_buffer* buf, int e_nd) {
} }
return 0; return 0;
}""" % locals()] }""" % locals()]
env.use_utility_code(buffer_check_utility_code) env.use_utility_code(buffer_check_utility_code)
env.use_utility_code(utilcode) env.use_utility_code(utilcode, name)
self.tsfuncs.add(name) return name
return name
def buffer_type_checker(dtype, env):
def buffer_type_checker(self, dtype, env): # Creates a type checker function for the given type.
# Creates a type checker function for the given type. if dtype.is_struct_or_union:
# Each checker is created as utility code. However, as each function assert False
# is dynamically constructed we also keep a set self.tsfuncs containing elif dtype.is_int or dtype.is_float:
# the right functions for the types that are already created. # This includes simple typedef-ed types
if dtype.is_struct_or_union: funcname = get_ts_check_simple(dtype, env)
assert False else:
elif dtype.is_int or dtype.is_float: assert False
# This includes simple typedef-ed types return funcname
funcname = self.get_ts_check_simple(dtype, env)
else: def use_py2_buffer_functions(env):
assert False # will be refactored
return funcname try:
env.entries[u'numpy']
env.use_utility_code(["","""
static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
class BufferTransform(CythonTransform): /* This function is always called after a type-check; safe to cast */
""" PyArrayObject *arr = (PyArrayObject*)obj;
Run after type analysis. Takes care of the buffer functionality. PyArray_Descr *type = (PyArray_Descr*)arr->descr;
Expects to be run on the full module. If you need to process a fragment
one should look into refactoring this transform. int typenum = PyArray_TYPE(obj);
""" if (!PyTypeNum_ISNUMBER(typenum)) {
# Abbreviations: PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
# "ts" means typestring and/or typestring checking stuff return -1;
}
scope = None
#
# Entry point
#
def __call__(self, node):
assert isinstance(node, ModuleNode)
try:
cymod = self.context.modules[u'__cython__']
except KeyError:
# No buffer fun for this module
return node
self.bufstruct_type = cymod.entries[u'Py_buffer'].type
self.tscheckers = {}
self.ts_funcs = []
self.ts_item_checkers = {}
self.module_scope = node.scope
self.module_pos = node.pos
result = super(BufferTransform, self).__call__(node)
# Register ts stuff
if "endian.h" not in node.scope.include_files:
node.scope.include_files.append("endian.h")
result.body.stats += self.ts_funcs
return result
acquire_buffer_fragment = TreeFragment(u"""
__cython__.PyObject_GetBuffer(<__cython__.PyObject*>SUBJECT, &BUFINFO, 0)
TSCHECKER(<char*>BUFINFO.format)
""")
fetch_strides = TreeFragment(u"""
TARGET = BUFINFO.strides[IDX]
""")
fetch_shape = TreeFragment(u"""
TARGET = BUFINFO.shape[IDX]
""")
def acquire_buffer_stats(self, entry, buffer_aux, pos):
# Just the stats for acquiring and unpacking the buffer auxiliaries
auxass = []
for idx, strideentry in enumerate(buffer_aux.stridevars):
strideentry.used = True
ass = self.fetch_strides.substitute({
u"TARGET": NameNode(pos, name=strideentry.name),
u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"IDX": IntNode(pos, value=EncodedString(idx)),
})
auxass += ass.stats
for idx, shapeentry in enumerate(buffer_aux.shapevars):
shapeentry.used = True
ass = self.fetch_shape.substitute({
u"TARGET": NameNode(pos, name=shapeentry.name),
u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"IDX": IntNode(pos, value=EncodedString(idx))
})
auxass += ass.stats
buffer_aux.buffer_info_var.used = True
acq = self.acquire_buffer_fragment.substitute({
u"SUBJECT" : NameNode(pos, name=entry.name),
u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"TSCHECKER": NameNode(pos, name=buffer_aux.tschecker.name)
}, pos=pos)
return acq.stats + auxass
def acquire_argument_buffer_stats(self, entry, pos):
# On function entry, not getting a buffer is an uncatchable
# exception, so we don't need to worry about what happens if
# we don't get a buffer.
stats = self.acquire_buffer_stats(entry, entry.buffer_aux, pos)
for s in stats:
s.analyse_declarations(self.scope)
#s.analyse_expressions(self.scope)
return stats
# Notes: The cast to <char*> gets around Cython not supporting const types
reacquire_buffer_fragment = TreeFragment(u"""
TMP = LHS
if TMP is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
TMP = RHS
if TMP is not None:
ACQUIRE
LHS = TMP
""")
def reacquire_buffer(self, node):
buffer_aux = node.lhs.entry.buffer_aux
acquire_stats = self.acquire_buffer_stats(buffer_aux.temp_var, buffer_aux, node.pos)
acq = self.reacquire_buffer_fragment.substitute({
u"TMP" : NameNode(pos=node.pos, name=buffer_aux.temp_var.name),
u"LHS" : node.lhs,
u"RHS": node.rhs,
u"ACQUIRE": StatListNode(node.pos, stats=acquire_stats),
u"BUFINFO": NameNode(pos=node.pos, name=buffer_aux.buffer_info_var.name)
}, pos=node.pos)
# Preserve first assignment info on LHS
if node.first:
# TODO: Prettier code
acq.stats[4].first = True
del acq.stats[0]
del acq.stats[0]
# Note: The below should probably be refactored into something
# like fragment.substitute(..., context=self.context), with
# TreeFragment getting context.pipeline_until_now() and
# applying it on the fragment.
acq.analyse_declarations(self.scope)
acq.analyse_expressions(self.scope)
stats = acq.stats
return stats
def assign_into_buffer(self, node):
result = SingleAssignmentNode(node.pos,
rhs=self.visit(node.rhs),
lhs=self.buffer_index(node.lhs))
result.analyse_expressions(self.scope)
return result
buffer_cleanup_fragment = TreeFragment(u""" /*
if BUF is not None: NumPy format codes doesn't completely match buffer codes;
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>BUF, &BUFINFO) seems safest to retranslate.
""") 01234567890123456789012345*/
def funcdef_buffer_cleanup(self, node, pos): const char* base_codes = "?bBhHiIlLqQfdgfdgO";
env = node.local_scope
cleanups = [self.buffer_cleanup_fragment.substitute({ char* format = (char*)malloc(4);
u"BUF" : NameNode(pos, name=entry.name), char* fp = format;
u"BUFINFO": NameNode(pos, name=entry.buffer_aux.buffer_info_var.name) *fp++ = type->byteorder;
}, pos=pos) if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
for entry in node.local_scope.buffer_entries] *fp++ = base_codes[typenum];
cleanup_stats = [] *fp = 0;
for c in cleanups: cleanup_stats += c.stats
cleanup = StatListNode(pos, stats=cleanup_stats) view->buf = arr->data;
cleanup.analyse_expressions(env) view->readonly = !PyArray_ISWRITEABLE(obj);
result = TryFinallyStatNode.create_analysed(pos, env, body=node.body, finally_clause=cleanup) view->ndim = PyArray_NDIM(arr);
node.body = StatListNode.create_analysed(pos, env, stats=[result]) view->strides = PyArray_STRIDES(arr);
return node view->shape = PyArray_DIMS(arr);
view->suboffsets = NULL;
# view->format = format;
# Transforms view->itemsize = type->elsize;
#
view->internal = 0;
def visit_ModuleNode(self, node): return 0;
self.handle_scope(node, node.scope) }
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node): static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
self.handle_scope(node, node.local_scope) free((char*)view->format);
self.visitchildren(node) view->format = NULL;
node = self.funcdef_buffer_cleanup(node, node.pos) }
stats = []
for arg in node.local_scope.arg_entries:
if arg.type.is_buffer:
stats += self.acquire_argument_buffer_stats(arg, node.pos)
node.body.stats = stats + node.body.stats
return node
"""])
except KeyError:
pass
codename = "PyObject_GetBuffer" # just a representative unique key
# Search all types for __getbuffer__ overloads
types = []
def find_buffer_types(scope):
for m in scope.cimported_modules:
find_buffer_types(m)
for e in scope.type_entries:
t = e.type
if t.is_extension_type:
release = get = None
for x in t.scope.pyfunc_entries:
if x.name == u"__getbuffer__": get = x.func_cname
elif x.name == u"__releasebuffer__": release = x.func_cname
if get:
types.append((t.typeptr_cname, get, release))
find_buffer_types(env)
# For now, hard-code numpy imported as "numpy"
try:
ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
except KeyError:
pass
code = """
#if PY_VERSION_HEX < 0x02060000
static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {
"""
if len(types) > 0:
clause = "if"
for t, get, release in types:
code += " %s (PyObject_TypeCheck(obj, %s)) return %s(obj, view, flags);\n" % (clause, t, get)
clause = "else if"
code += " else {\n"
code += """\
PyErr_Format(PyExc_TypeError, "'%100s' does not have the buffer interface", Py_TYPE(obj)->tp_name);
return -1;
"""
if len(types) > 0: code += " }"
code += """
}
# TODO: static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {
# - buf must be NULL before getting new buffer """
if len(types) > 0:
clause = "if"
for t, get, release in types:
if release:
code += "%s (PyObject_TypeCheck(obj, %s)) %s(obj, view);" % (clause, t, release)
clause = "else if"
code += """
}
#endif
"""
env.use_utility_code(["""\
#if PY_VERSION_HEX < 0x02060000
static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags);
static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view);
#endif
""" ,code], codename)
...@@ -260,7 +260,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -260,7 +260,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
self.generate_module_cleanup_func(env, code) self.generate_module_cleanup_func(env, code)
self.generate_filename_table(code) self.generate_filename_table(code)
self.generate_utility_functions(env, code) self.generate_utility_functions(env, code)
self.generate_buffer_compatability_functions(env, code)
self.generate_declarations_for_modules(env, modules, code.h) self.generate_declarations_for_modules(env, modules, code.h)
...@@ -441,8 +440,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -441,8 +440,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln(" #define PyBUF_ANY_CONTIGUOUS (0x0080 | PyBUF_STRIDES)") code.putln(" #define PyBUF_ANY_CONTIGUOUS (0x0080 | PyBUF_STRIDES)")
code.putln(" #define PyBUF_INDIRECT (0x0100 | PyBUF_STRIDES)") code.putln(" #define PyBUF_INDIRECT (0x0100 | PyBUF_STRIDES)")
code.putln("") code.putln("")
code.putln(" static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags);")
code.putln(" static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view);")
code.putln("#endif") code.putln("#endif")
code.put(builtin_module_name_utility_code[0]) code.put(builtin_module_name_utility_code[0])
...@@ -1956,106 +1953,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1956,106 +1953,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.put(PyrexTypes.type_conversion_functions) code.put(PyrexTypes.type_conversion_functions)
code.putln("") code.putln("")
def generate_buffer_compatability_functions(self, env, code):
# will be refactored
try:
env.entries[u'numpy']
code.put("""
static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
/* This function is always called after a type-check; safe to cast */
PyArrayObject *arr = (PyArrayObject*)obj;
PyArray_Descr *type = (PyArray_Descr*)arr->descr;
int typenum = PyArray_TYPE(obj);
if (!PyTypeNum_ISNUMBER(typenum)) {
PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
return -1;
}
/*
NumPy format codes doesn't completely match buffer codes;
seems safest to retranslate.
01234567890123456789012345*/
const char* base_codes = "?bBhHiIlLqQfdgfdgO";
char* format = (char*)malloc(4);
char* fp = format;
*fp++ = type->byteorder;
if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
*fp++ = base_codes[typenum];
*fp = 0;
view->buf = arr->data;
view->readonly = !PyArray_ISWRITEABLE(obj);
view->ndim = PyArray_NDIM(arr);
view->strides = PyArray_STRIDES(arr);
view->shape = PyArray_DIMS(arr);
view->suboffsets = NULL;
view->format = format;
view->itemsize = type->elsize;
view->internal = 0;
return 0;
}
static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
free((char*)view->format);
view->format = NULL;
}
""")
except KeyError:
pass
# Search all types for __getbuffer__ overloads
types = []
def find_buffer_types(scope):
for m in scope.cimported_modules:
find_buffer_types(m)
for e in scope.type_entries:
t = e.type
if t.is_extension_type:
release = get = None
for x in t.scope.pyfunc_entries:
if x.name == u"__getbuffer__": get = x.func_cname
elif x.name == u"__releasebuffer__": release = x.func_cname
if get:
types.append((t.typeptr_cname, get, release))
find_buffer_types(self.scope)
# For now, hard-code numpy imported as "numpy"
try:
ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
except KeyError:
pass
code.putln("#if PY_VERSION_HEX < 0x02060000")
code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
if len(types) > 0:
clause = "if"
for t, get, release in types:
code.putln("%s (PyObject_TypeCheck(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
clause = "else if"
code.putln("else {")
code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
code.putln("return -1;")
if len(types) > 0: code.putln("}")
code.putln("}")
code.putln("")
code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
if len(types) > 0:
clause = "if"
for t, get, release in types:
if release:
code.putln("%s (PyObject_TypeCheck(obj, %s)) %s(obj, view);" % (clause, t, release))
clause = "else if"
code.putln("}")
code.putln("")
code.putln("#endif")
#------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------
# #
# Runtime support code # Runtime support code
......
...@@ -15,6 +15,7 @@ from TypeSlots import \ ...@@ -15,6 +15,7 @@ from TypeSlots import \
get_special_method_signature, get_property_accessor_signature get_special_method_signature, get_property_accessor_signature
import ControlFlow import ControlFlow
import __builtin__ import __builtin__
from sets import Set as set
possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
...@@ -626,9 +627,12 @@ class Scope: ...@@ -626,9 +627,12 @@ class Scope:
return [entry for entry in self.temp_entries return [entry for entry in self.temp_entries
if entry not in self.free_temp_entries] if entry not in self.free_temp_entries]
def use_utility_code(self, new_code): def use_utility_code(self, new_code, name=None):
self.global_scope().use_utility_code(new_code) self.global_scope().use_utility_code(new_code, name)
def has_utility_code(self, name):
return self.global_scope().has_utility_code(name)
def generate_library_function_declarations(self, code): def generate_library_function_declarations(self, code):
# Generate extern decls for C library funcs used. # Generate extern decls for C library funcs used.
#if self.pow_function_used: #if self.pow_function_used:
...@@ -748,6 +752,7 @@ class ModuleScope(Scope): ...@@ -748,6 +752,7 @@ class ModuleScope(Scope):
# doc_cname string C name of module doc string # doc_cname string C name of module doc string
# const_counter integer Counter for naming constants # const_counter integer Counter for naming constants
# utility_code_used [string] Utility code to be included # utility_code_used [string] Utility code to be included
# utility_code_names set(string) (Optional) names for named (often generated) utility code
# default_entries [Entry] Function argument default entries # default_entries [Entry] Function argument default entries
# python_include_files [string] Standard Python headers to be included # python_include_files [string] Standard Python headers to be included
# include_files [string] Other C headers to be included # include_files [string] Other C headers to be included
...@@ -782,6 +787,7 @@ class ModuleScope(Scope): ...@@ -782,6 +787,7 @@ class ModuleScope(Scope):
self.doc_cname = Naming.moddoc_cname self.doc_cname = Naming.moddoc_cname
self.const_counter = 1 self.const_counter = 1
self.utility_code_used = [] self.utility_code_used = []
self.utility_code_names = set()
self.default_entries = [] self.default_entries = []
self.module_entries = {} self.module_entries = {}
self.python_include_files = ["Python.h", "structmember.h"] self.python_include_files = ["Python.h", "structmember.h"]
...@@ -940,13 +946,25 @@ class ModuleScope(Scope): ...@@ -940,13 +946,25 @@ class ModuleScope(Scope):
self.const_counter = n + 1 self.const_counter = n + 1
return "%s%s%d" % (Naming.const_prefix, prefix, n) return "%s%s%d" % (Naming.const_prefix, prefix, n)
def use_utility_code(self, new_code): def use_utility_code(self, new_code, name=None):
# Add string to list of utility code to be included, # Add string to list of utility code to be included,
# if not already there (tested using 'is'). # if not already there (tested using the provided name,
# or 'is' if name=None -- if the utility code is dynamically
# generated, use the name, otherwise it is not needed).
if name is not None:
if name in self.utility_code_names:
return
for old_code in self.utility_code_used: for old_code in self.utility_code_used:
if old_code is new_code: if old_code is new_code:
return return
self.utility_code_used.append(new_code) self.utility_code_used.append(new_code)
self.utility_code_names.add(name)
def has_utility_code(self, name):
# Checks if utility code (that is registered by name) has
# previously been registered. This is useful if the utility code
# is dynamically generated to avoid re-generation.
return name in self.utility_code_names
def declare_c_class(self, name, pos, defining = 0, implementing = 0, def declare_c_class(self, name, pos, defining = 0, implementing = 0,
module_name = None, base_type = None, objstruct_cname = None, module_name = None, base_type = None, objstruct_cname = None,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment