Commit 07a58cc3 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn Committed by Mark Florisson

use_utility_code vs. memoryviews fixup

parent 4fce6db0
......@@ -527,87 +527,111 @@ def buf_lookup_fortran_code(proto, defin, name, nd):
def use_py2_buffer_functions(env):
env.use_utility_code(GetAndReleaseBufferUtilityCode())
class GetAndReleaseBufferUtilityCode(object):
# Emulation of PyObject_GetBuffer and PyBuffer_Release for Python 2.
# For >= 2.6 we do double mode -- use the new buffer interface on objects
# which has the right tp_flags set, but emulation otherwise.
# Search all types for __getbuffer__ overloads
types = []
visited_scopes = set()
def find_buffer_types(scope):
if scope in visited_scopes:
return
visited_scopes.add(scope)
for m in scope.cimported_modules:
find_buffer_types(m)
for e in scope.type_entries:
t = e.type
if t.is_extension_type:
if e.name == 'array' and not e.used:
continue
release = get = None
for x in t.scope.pyfunc_entries:
if x.name == u"__getbuffer__": get = x.func_cname
elif x.name == u"__releasebuffer__": release = x.func_cname
if get:
types.append((t.typeptr_cname, get, release))
find_buffer_types(env)
requires = None
code = dedent("""
#if PY_MAJOR_VERSION < 3
static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {
#if PY_VERSION_HEX >= 0x02060000
if (PyObject_CheckBuffer(obj)) return PyObject_GetBuffer(obj, view, flags);
#endif
""")
if len(types) > 0:
clause = "if"
for t, get, release in types:
code += " %s (PyObject_TypeCheck(obj, %s)) return %s(obj, view, flags);\n" % (clause, t, get)
clause = "else if"
code += " else {\n"
code += dedent("""\
PyErr_Format(PyExc_TypeError, "'%100s' does not have the buffer interface", Py_TYPE(obj)->tp_name);
return -1;
""", 2)
if len(types) > 0: code += " }"
code += dedent("""
}
def __init__(self):
pass
def __eq__(self, other):
return isinstance(other, GetAndReleaseBufferUtilityCode)
def __hash__(self):
return 24342342
def get_tree(self): pass
static void __Pyx_ReleaseBuffer(Py_buffer *view) {
PyObject* obj = view->obj;
if (obj) {
#if PY_VERSION_HEX >= 0x02060000
if (PyObject_CheckBuffer(obj)) {PyBuffer_Release(view); return;}
def put_code(self, output):
code = output['utility_code_def']
proto = output['utility_code_proto']
env = output.module_node.scope
cython_scope = env.context.cython_scope
proto.put(dedent("""\
#if PY_MAJOR_VERSION < 3
static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags);
static void __Pyx_ReleaseBuffer(Py_buffer *view);
#else
#define __Pyx_GetBuffer PyObject_GetBuffer
#define __Pyx_ReleaseBuffer PyBuffer_Release
#endif
""")
if len(types) > 0:
clause = "if"
for t, get, release in types:
if release:
code += " "
code += "%s (PyObject_TypeCheck(obj, %s)) %s(obj, view);" % (clause, t, release)
"""))
# Search all types for __getbuffer__ overloads
types = []
visited_scopes = set()
def find_buffer_types(scope):
if scope in visited_scopes:
return
visited_scopes.add(scope)
for m in scope.cimported_modules:
find_buffer_types(m)
for e in scope.type_entries:
t = e.type
if t.is_extension_type:
if scope is cython_scope and not e.used:
continue
release = get = None
for x in t.scope.pyfunc_entries:
if x.name == u"__getbuffer__": get = x.func_cname
elif x.name == u"__releasebuffer__": release = x.func_cname
if get:
types.append((t.typeptr_cname, get, release))
find_buffer_types(env)
code.put(dedent("""
#if PY_MAJOR_VERSION < 3
static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {
#if PY_VERSION_HEX >= 0x02060000
if (PyObject_CheckBuffer(obj)) return PyObject_GetBuffer(obj, view, flags);
#endif
"""))
if len(types) > 0:
clause = "if"
for t, get, release in types:
code.putln(" %s (PyObject_TypeCheck(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
clause = "else if"
code += dedent("""
Py_DECREF(obj);
view->obj = NULL;
}
}
code.putln(" else {")
code.put(dedent("""\
PyErr_Format(PyExc_TypeError, "'%100s' does not have the buffer interface", Py_TYPE(obj)->tp_name);
return -1;
""", 2))
if len(types) > 0:
code.putln(" }")
code.put(dedent("""\
}
static void __Pyx_ReleaseBuffer(Py_buffer *view) {
PyObject* obj = view->obj;
if (obj) {
#if PY_VERSION_HEX >= 0x02060000
if (PyObject_CheckBuffer(obj)) {PyBuffer_Release(view); return;}
#endif
"""))
if len(types) > 0:
clause = "if"
for t, get, release in types:
if release:
code.putln("%s (PyObject_TypeCheck(obj, %s)) %s(obj, view);" % (clause, t, release))
clause = "else if"
code.put(dedent("""
Py_DECREF(obj);
view->obj = NULL;
}
}
#endif
""")
#endif
"""))
env.use_utility_code(UtilityCode(
proto = dedent("""\
#if PY_MAJOR_VERSION < 3
static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags);
static void __Pyx_ReleaseBuffer(Py_buffer *view);
#else
#define __Pyx_GetBuffer PyObject_GetBuffer
#define __Pyx_ReleaseBuffer PyBuffer_Release
#endif
"""), impl = code))
def mangle_dtype_name(dtype):
......@@ -783,15 +807,6 @@ typedef struct {
size_t parent_offset;
} __Pyx_BufFmt_StackElem;
static CYTHON_INLINE int __Pyx_GetBufferAndValidate(Py_buffer* buf, PyObject* obj, __Pyx_TypeInfo* dtype, int flags, int nd, int cast, __Pyx_BufFmt_StackElem* stack);
static CYTHON_INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info);
""", impl="""
static CYTHON_INLINE int __Pyx_IsLittleEndian(void) {
unsigned int n = 1;
return *(unsigned char*)(&n) != 0;
}
typedef struct {
__Pyx_StructField root;
__Pyx_BufFmt_StackElem* head;
......@@ -804,10 +819,11 @@ typedef struct {
} __Pyx_BufFmt_Context;
static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info);
static int __Pyx_GetBufferAndValidate(Py_buffer* buf, PyObject* obj, __Pyx_TypeInfo* dtype, int flags, int nd, int cast, __Pyx_BufFmt_StackElem* stack);
static CYTHON_INLINE int __Pyx_GetBufferAndValidate(Py_buffer* buf, PyObject* obj, __Pyx_TypeInfo* dtype, int flags, int nd, int cast, __Pyx_BufFmt_StackElem* stack);
static CYTHON_INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info);
""", impl="""
static INLINE int __Pyx_IsLittleEndian(void) {
static CYTHON_INLINE int __Pyx_IsLittleEndian(void) {
unsigned int n = 1;
return *(unsigned char*)(&n) != 0;
}
......
......@@ -500,7 +500,7 @@ class GlobalState(object):
]
def __init__(self, writer, emit_linenums=False):
def __init__(self, writer, module_node, emit_linenums=False):
self.filename_table = {}
self.filename_list = []
self.input_file_contents = {}
......@@ -509,6 +509,8 @@ class GlobalState(object):
self.in_utility_code_generation = False
self.emit_linenums = emit_linenums
self.parts = {}
self.module_node = module_node # because some utility code generation needs it
# (generating backwards-compatible Get/ReleaseBuffer
self.const_cname_counter = 1
self.string_const_index = {}
......
......@@ -4,6 +4,7 @@ from UtilityCode import CythonUtilityCode
from Errors import error
from Scanning import StringSourceDescriptor
import Options
import Buffer
class CythonScope(ModuleScope):
is_cython_builtin = 1
......@@ -239,7 +240,8 @@ memview_name = u'memoryview'
memview_typeptr_cname = Naming.typeptr_prefix+memview_name
memview_typeobj_cname = '__pyx_tobj_'+memview_name
memview_objstruct_cname = '__pyx_obj_'+memview_name
view_utility_code = CythonUtilityCode(u"""
view_utility_code = CythonUtilityCode(
u"""
cdef class Enum(object):
cdef object name
def __init__(self, name):
......@@ -274,7 +276,9 @@ cdef class memoryview(object):
cdef memoryview memoryview_cwrapper(object o, int flags):
return memoryview(o, flags)
""", name="view_code", prefix="__pyx_viewaxis_")
""", name="view_code",
prefix="__pyx_viewaxis_",
requires=(Buffer.GetAndReleaseBufferUtilityCode(),))
cyarray_prefix = u'__pyx_cythonarray_'
cython_array_utility_code = CythonUtilityCode(u'''
......
......@@ -117,31 +117,18 @@ def get_buf_flag(specs):
else:
return memview_strided_access
def use_cython_view_util_code(env, lu_name):
import CythonScope
cythonscope = env.global_scope().context.cython_scope
viewscope = cythonscope.viewscope
entry = viewscope.lookup_here(lu_name)
entry.used = 1
return entry
def use_cython_util_code(env, lu_name):
import CythonScope
cythonscope = env.global_scope().context.cython_scope
entry = cythonscope.lookup_here(lu_name)
entry.used = 1
return entry
def use_memview_util_code(env):
import CythonScope
return use_cython_view_util_code(env, CythonScope.memview_name)
env.use_utility_code(CythonScope.view_utility_code)
env.use_utility_code(memviewslice_declare_code)
def use_memview_cwrap(env):
import CythonScope
return use_cython_view_util_code(env, CythonScope.memview_cwrap_name)
env.use_utility_code(CythonScope.view_utility_code)
def use_cython_array(env):
return use_cython_util_code(env, 'array')
import CythonScope
env.use_utility_code(CythonScope.cython_array_utility_code)
def src_conforms_to_dst(src, dst):
'''
......@@ -318,7 +305,7 @@ def memoryviewslice_get_copy_func(from_memview, to_memview, mode, scope):
copy_contents_name = get_copy_contents_name(from_memview, to_memview)
scope.declare_cfunction(cython_name,
entry = scope.declare_cfunction(cython_name,
CFuncType(from_memview,
[CFuncTypeArg("memviewslice", from_memview, None)]),
pos = None,
......@@ -335,7 +322,7 @@ def memoryviewslice_get_copy_func(from_memview, to_memview, mode, scope):
copy_decl = ("static __Pyx_memviewslice "
"%s(const __Pyx_memviewslice); /* proto */\n" % (copy_name,))
return (copy_decl, copy_impl)
return (copy_decl, copy_impl, entry)
def get_copy_contents_func(from_mvs, to_mvs, cfunc_name):
assert from_mvs.dtype == to_mvs.dtype
......
......@@ -126,7 +126,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
if (h_types or h_vars or h_funcs or h_extension_types):
result.h_file = replace_suffix(result.c_file, ".h")
h_code = Code.CCodeWriter()
Code.GlobalState(h_code)
Code.GlobalState(h_code, self)
if options.generate_pxi:
result.i_file = replace_suffix(result.c_file, ".pxi")
i_code = Code.PyrexCodeWriter(result.i_file)
......@@ -195,7 +195,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
if api_vars or api_funcs or api_extension_types:
result.api_file = replace_suffix(result.c_file, "_api.h")
h_code = Code.CCodeWriter()
Code.GlobalState(h_code)
Code.GlobalState(h_code, self)
api_guard = Naming.api_guard_prefix + self.api_name(env)
h_code.put_h_guard(api_guard)
h_code.putln('#include "Python.h"')
......@@ -293,7 +293,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
else:
emit_linenums = options.emit_linenums
rootwriter = Code.CCodeWriter(emit_linenums=emit_linenums, c_line_in_traceback=options.c_line_in_traceback)
globalstate = Code.GlobalState(rootwriter, emit_linenums)
globalstate = Code.GlobalState(rootwriter, self, emit_linenums)
globalstate.initialize_main_c_code()
h_code = globalstate['h_code']
......
......@@ -822,8 +822,11 @@ class MemoryViewSliceTypeNode(CBaseTypeNode):
self.type = PyrexTypes.ErrorType()
return self.type
self.type = PyrexTypes.MemoryViewSliceType(base_type, axes_specs, env)
self.type = PyrexTypes.MemoryViewSliceType(base_type, axes_specs)
MemoryView.use_memview_util_code(env)
MemoryView.use_cython_array(env)
MemoryView.use_memview_util_code(env)
env.use_utility_code(MemoryView.memviewslice_declare_code)
return self.type
class CNestedBaseTypeNode(CBaseTypeNode):
......
......@@ -4,6 +4,7 @@ from time import time
import Errors
import DebugFlags
import Options
from Visitor import CythonTransform
from Errors import PyrexError, CompileError, InternalError, AbortError, error
#
......@@ -78,16 +79,39 @@ def inject_utility_code_stage_factory(context):
added = []
# Note: the list might be extended inside the loop (if some utility code
# pulls in other utility code)
# pulls in other utility code, explicitly or implicitly)
for utilcode in module_node.scope.utility_code_list:
if utilcode in added: continue
added.append(utilcode)
if utilcode.requires:
for dep in utilcode.requires:
if not dep in added and not dep in module_node.scope.utility_code_list:
module_node.scope.utility_code_list.append(dep)
tree = utilcode.get_tree()
if tree:
module_node.merge_in(tree.body, tree.scope, merge_scope=True)
return module_node
return inject_utility_code_stage
#class UseUtilityCodeDefinitions(CythonTransform):
# # Temporary hack to use any utility code in nodes' "utility_code_definitions".
# # This should be moved to the code generation phase of the relevant nodes once
# # it is safe to generate CythonUtilityCode at code generation time.
# def __call__(self, node):
# self.scope = node.scope
# return super(UseUtilityCodeDefinitions, self).__call__(node)
#
# def visit_AttributeNode(self, node):
# if node.entry and node.entry.utility_code_definition:
# self.scope.use_utility_code(node.entry.utility_code_definition)
# return node
#
# def visit_NameNode(self, node):
# for e in (node.entry, node.type_entry):
# if e and e.utility_code_definition:
# self.scope.use_utility_code(e.utility_code_definition)
# return node
#
# Pipeline factories
#
......@@ -167,6 +191,7 @@ def create_pipeline(context, mode, exclude_classes=()):
DropRefcountingTransform(),
FinalOptimizePhase(context),
GilCheck(),
# UseUtilityCodeDefinitions(context),
]
filtered_stages = []
for s in stages:
......
......@@ -320,7 +320,7 @@ class MemoryViewSliceType(PyrexType):
has_attributes = 1
scope = None
def __init__(self, base_dtype, axes, env):
def __init__(self, base_dtype, axes):
'''
MemoryViewSliceType(base, axes)
......@@ -357,7 +357,6 @@ class MemoryViewSliceType(PyrexType):
self.dtype = base_dtype
self.axes = axes
self.env = env
import MemoryView
self.is_c_contig, self.is_f_contig = MemoryView.is_cf_contig(self.axes)
......@@ -373,8 +372,6 @@ class MemoryViewSliceType(PyrexType):
assert not pyrex
assert not dll_linkage
import MemoryView
if not for_display:
self.env.use_utility_code(MemoryView.memviewslice_declare_code)
return self.base_declaration_code(
MemoryView.memviewslice_cname,
entity_code)
......@@ -384,11 +381,10 @@ class MemoryViewSliceType(PyrexType):
import Symtab, MemoryView
from MemoryView import axes_to_str
self.scope = scope = Symtab.CClassScope(
'mvs_class_'+self.specialization_suffix(),
self.env.global_scope(),
visibility='private')
None,
visibility='extern')
scope.parent_type = self
......@@ -403,17 +399,17 @@ class MemoryViewSliceType(PyrexType):
to_axes_c = [('direct', 'follow')]*(ndim-1) + to_axes_c
to_axes_f = to_axes_f + [('direct', 'follow')]*(ndim-1)
to_memview_c = MemoryViewSliceType(self.dtype, to_axes_c, self.env)
to_memview_f = MemoryViewSliceType(self.dtype, to_axes_f, self.env)
to_memview_c = MemoryViewSliceType(self.dtype, to_axes_c)
to_memview_f = MemoryViewSliceType(self.dtype, to_axes_f)
copy_contents_name_c =\
MemoryView.get_copy_contents_name(self, to_memview_c)
copy_contents_name_f =\
MemoryView.get_copy_contents_name(self, to_memview_f)
c_copy_decl, c_copy_impl = \
c_copy_decl, c_copy_impl, c_entry = \
MemoryView.memoryviewslice_get_copy_func(self, to_memview_c, 'c', self.scope)
f_copy_decl, f_copy_impl = \
f_copy_decl, f_copy_impl, f_entry = \
MemoryView.memoryviewslice_get_copy_func(self, to_memview_f, 'fortran', self.scope)
c_copy_contents_decl, c_copy_contents_impl = \
......@@ -430,19 +426,13 @@ class MemoryViewSliceType(PyrexType):
proto = f_copy_decl,
impl = f_copy_impl)
c_entry.utility_code_definition = c_util_code
f_entry.utility_code_definition = f_util_code
if copy_contents_name_c != copy_contents_name_f:
f_util_code.proto += f_copy_contents_decl
f_util_code.impl += f_copy_contents_impl
c_copy_used = [1 for util_code in self.env.global_scope().utility_code_list if c_util_code.proto == util_code.proto]
f_copy_used = [1 for util_code in self.env.global_scope().utility_code_list if f_util_code.proto == util_code.proto]
if not c_copy_used:
self.env.use_utility_code(c_util_code)
if not f_copy_used:
self.env.use_utility_code(f_util_code)
# is_c_contiguous and is_f_contiguous functions
for c_or_f, cython_name in (('c', 'is_c_contig'), ('fortran', 'is_f_contig')):
......@@ -460,17 +450,6 @@ class MemoryViewSliceType(PyrexType):
contig_util_code = UtilityCode(
proto = contig_decl, impl = contig_impl)
contig_used = [1 for util_code in \
self.env.global_scope().utility_code_list \
if contig_decl == util_code.proto]
if not contig_used:
self.env.use_utility_code(contig_util_code)
# use the supporting util code
MemoryView.use_cython_array(self.env)
MemoryView.use_memview_util_code(self.env)
return True
def specialization_suffix(self):
......
......@@ -124,6 +124,8 @@ class Entry(object):
# used uninitialized
# cf_used boolean Entry is used
# TODO: utility_code and utility_code_definition serves the same purpose...
inline_func_in_pxd = False
borrowed = 0
init = ""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment