Commit 724f5756 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Only define PyObject_GetBuffer etc. if really needed

parent 9f930fcf
This diff is collapsed.
......@@ -260,7 +260,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
self.generate_module_cleanup_func(env, code)
self.generate_filename_table(code)
self.generate_utility_functions(env, code)
self.generate_buffer_compatability_functions(env, code)
self.generate_declarations_for_modules(env, modules, code.h)
......@@ -441,8 +440,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln(" #define PyBUF_ANY_CONTIGUOUS (0x0080 | PyBUF_STRIDES)")
code.putln(" #define PyBUF_INDIRECT (0x0100 | PyBUF_STRIDES)")
code.putln("")
code.putln(" static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags);")
code.putln(" static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view);")
code.putln("#endif")
code.put(builtin_module_name_utility_code[0])
......@@ -1956,106 +1953,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.put(PyrexTypes.type_conversion_functions)
code.putln("")
def generate_buffer_compatability_functions(self, env, code):
# will be refactored
try:
env.entries[u'numpy']
code.put("""
static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
/* This function is always called after a type-check; safe to cast */
PyArrayObject *arr = (PyArrayObject*)obj;
PyArray_Descr *type = (PyArray_Descr*)arr->descr;
int typenum = PyArray_TYPE(obj);
if (!PyTypeNum_ISNUMBER(typenum)) {
PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
return -1;
}
/*
NumPy format codes doesn't completely match buffer codes;
seems safest to retranslate.
01234567890123456789012345*/
const char* base_codes = "?bBhHiIlLqQfdgfdgO";
char* format = (char*)malloc(4);
char* fp = format;
*fp++ = type->byteorder;
if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
*fp++ = base_codes[typenum];
*fp = 0;
view->buf = arr->data;
view->readonly = !PyArray_ISWRITEABLE(obj);
view->ndim = PyArray_NDIM(arr);
view->strides = PyArray_STRIDES(arr);
view->shape = PyArray_DIMS(arr);
view->suboffsets = NULL;
view->format = format;
view->itemsize = type->elsize;
view->internal = 0;
return 0;
}
static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
free((char*)view->format);
view->format = NULL;
}
""")
except KeyError:
pass
# Search all types for __getbuffer__ overloads
types = []
def find_buffer_types(scope):
for m in scope.cimported_modules:
find_buffer_types(m)
for e in scope.type_entries:
t = e.type
if t.is_extension_type:
release = get = None
for x in t.scope.pyfunc_entries:
if x.name == u"__getbuffer__": get = x.func_cname
elif x.name == u"__releasebuffer__": release = x.func_cname
if get:
types.append((t.typeptr_cname, get, release))
find_buffer_types(self.scope)
# For now, hard-code numpy imported as "numpy"
try:
ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
except KeyError:
pass
code.putln("#if PY_VERSION_HEX < 0x02060000")
code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
if len(types) > 0:
clause = "if"
for t, get, release in types:
code.putln("%s (PyObject_TypeCheck(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
clause = "else if"
code.putln("else {")
code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
code.putln("return -1;")
if len(types) > 0: code.putln("}")
code.putln("}")
code.putln("")
code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
if len(types) > 0:
clause = "if"
for t, get, release in types:
if release:
code.putln("%s (PyObject_TypeCheck(obj, %s)) %s(obj, view);" % (clause, t, release))
clause = "else if"
code.putln("}")
code.putln("")
code.putln("#endif")
#------------------------------------------------------------------------------------
#
# Runtime support code
......
......@@ -15,6 +15,7 @@ from TypeSlots import \
get_special_method_signature, get_property_accessor_signature
import ControlFlow
import __builtin__
from sets import Set as set
possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
......@@ -626,8 +627,11 @@ class Scope:
return [entry for entry in self.temp_entries
if entry not in self.free_temp_entries]
def use_utility_code(self, new_code):
self.global_scope().use_utility_code(new_code)
def use_utility_code(self, new_code, name=None):
self.global_scope().use_utility_code(new_code, name)
def has_utility_code(self, name):
return self.global_scope().has_utility_code(name)
def generate_library_function_declarations(self, code):
# Generate extern decls for C library funcs used.
......@@ -748,6 +752,7 @@ class ModuleScope(Scope):
# doc_cname string C name of module doc string
# const_counter integer Counter for naming constants
# utility_code_used [string] Utility code to be included
# utility_code_names set(string) (Optional) names for named (often generated) utility code
# default_entries [Entry] Function argument default entries
# python_include_files [string] Standard Python headers to be included
# include_files [string] Other C headers to be included
......@@ -782,6 +787,7 @@ class ModuleScope(Scope):
self.doc_cname = Naming.moddoc_cname
self.const_counter = 1
self.utility_code_used = []
self.utility_code_names = set()
self.default_entries = []
self.module_entries = {}
self.python_include_files = ["Python.h", "structmember.h"]
......@@ -940,13 +946,25 @@ class ModuleScope(Scope):
self.const_counter = n + 1
return "%s%s%d" % (Naming.const_prefix, prefix, n)
def use_utility_code(self, new_code):
def use_utility_code(self, new_code, name=None):
# Add string to list of utility code to be included,
# if not already there (tested using 'is').
# if not already there (tested using the provided name,
# or 'is' if name=None -- if the utility code is dynamically
# generated, use the name, otherwise it is not needed).
if name is not None:
if name in self.utility_code_names:
return
for old_code in self.utility_code_used:
if old_code is new_code:
return
self.utility_code_used.append(new_code)
self.utility_code_names.add(name)
def has_utility_code(self, name):
# Checks if utility code (that is registered by name) has
# previously been registered. This is useful if the utility code
# is dynamically generated to avoid re-generation.
return name in self.utility_code_names
def declare_c_class(self, name, pos, defining = 0, implementing = 0,
module_name = None, base_type = None, objstruct_cname = None,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment