Commit 72d54fb4 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

PS: non-working state. Buffer access able to run fully in some very restricted cases

parent 6f0bc35a
......@@ -1275,36 +1275,59 @@ class IndexNode(ExprNode):
self.analyse_base_and_index_types(env, setting = 1)
def analyse_base_and_index_types(self, env, getting = 0, setting = 0):
self.is_buffer_access = False
self.base.analyse_types(env)
self.index.analyse_types(env)
if self.base.type.is_pyobject:
if self.index.type.is_int:
self.original_index_type = self.index.type
self.index = self.index.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env)
if getting:
env.use_utility_code(getitem_int_utility_code)
if setting:
env.use_utility_code(setitem_int_utility_code)
if self.base.type.buffer_options is not None:
if isinstance(self.index, TupleNode):
indices = self.index.args
# is_int_indices = 0 == sum([1 for i in self.index.args if not i.type.is_int])
else:
self.index = self.index.coerce_to_pyobject(env)
self.type = py_object_type
self.gil_check(env)
self.is_temp = 1
else:
if self.base.type.is_ptr or self.base.type.is_array:
self.type = self.base.type.base_type
# is_int_indices = self.index.type.is_int
indices = [self.index]
all_ints = True
for index in indices:
index.analyse_types(env)
if not index.type.is_int:
all_ints = False
if all_ints:
self.indices = indices
self.index = None
self.type = self.base.type.buffer_options.dtype
self.is_temp = 1
self.is_buffer_access = True
if not self.is_buffer_access:
self.index.analyse_types(env) # ok to analyse as tuple
if self.base.type.is_pyobject:
if self.index.type.is_int:
self.original_index_type = self.index.type
self.index = self.index.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env)
if getting:
env.use_utility_code(getitem_int_utility_code)
if setting:
env.use_utility_code(setitem_int_utility_code)
else:
self.index = self.index.coerce_to_pyobject(env)
self.type = py_object_type
self.gil_check(env)
self.is_temp = 1
else:
error(self.pos,
"Attempting to index non-array type '%s'" %
self.base.type)
self.type = PyrexTypes.error_type
if self.index.type.is_pyobject:
self.index = self.index.coerce_to(
PyrexTypes.c_py_ssize_t_type, env)
if not self.index.type.is_int:
error(self.pos,
"Invalid index type '%s'" %
self.index.type)
if self.base.type.is_ptr or self.base.type.is_array:
self.type = self.base.type.base_type
else:
error(self.pos,
"Attempting to index non-array type '%s'" %
self.base.type)
self.type = PyrexTypes.error_type
if self.index.type.is_pyobject:
self.index = self.index.coerce_to(
PyrexTypes.c_py_ssize_t_type, env)
if not self.index.type.is_int:
error(self.pos,
"Invalid index type '%s'" %
self.index.type)
gil_message = "Indexing Python object"
......@@ -1330,11 +1353,17 @@ class IndexNode(ExprNode):
def generate_subexpr_evaluation_code(self, code):
self.base.generate_evaluation_code(code)
self.index.generate_evaluation_code(code)
if self.index is not None:
self.index.generate_evaluation_code(code)
else:
for i in self.indices: i.generate_evaluation_code(code)
def generate_subexpr_disposal_code(self, code):
self.base.generate_disposal_code(code)
self.index.generate_disposal_code(code)
if self.index is not None:
self.index.generate_disposal_code(code)
else:
for i in self.indices: i.generate_disposal_code(code)
def generate_result_code(self, code):
if self.type.is_pyobject:
......
......@@ -354,7 +354,7 @@ def create_generate_code(context, options, result):
return generate_code
def create_default_pipeline(context, options, result):
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, BufferTransform
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor
from ModuleNode import check_c_classes
......@@ -367,6 +367,7 @@ def create_default_pipeline(context, options, result):
AnalyseDeclarationsTransform(context),
check_c_classes,
AnalyseExpressionsTransform(context),
BufferTransform(context),
# CreateClosureClasses(context),
create_generate_code(context, options, result)
]
......
......@@ -259,6 +259,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
self.generate_module_cleanup_func(env, code)
self.generate_filename_table(code)
self.generate_utility_functions(env, code)
self.generate_buffer_compatability_functions(env, code)
self.generate_declarations_for_modules(env, modules, code.h)
......@@ -438,6 +439,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln(" #define PyBUF_F_CONTIGUOUS (0x0040 | PyBUF_STRIDES)")
code.putln(" #define PyBUF_ANY_CONTIGUOUS (0x0080 | PyBUF_STRIDES)")
code.putln(" #define PyBUF_INDIRECT (0x0100 | PyBUF_STRIDES)")
code.putln("")
code.putln(" static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags);")
code.putln(" static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view);")
code.putln("#endif")
code.put(builtin_module_name_utility_code[0])
......@@ -1945,6 +1949,60 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.h.put(utility_code[0])
code.put(utility_code[1])
code.put(PyrexTypes.type_conversion_functions)
code.putln("")
def generate_buffer_compatability_functions(self, env, code):
# will be refactored
code.put("""
static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
/* This function is always called after a type-check */
PyArrayObject *arr = (PyArrayObject*)obj;
PyArray_Descr *type = (PyArray_Descr*)arr->descr;
view->buf = arr->data;
view->readonly = 0; /*fixme*/
view->format = "B"; /*fixme*/
view->ndim = arr->nd;
view->strides = arr->strides;
view->shape = arr->dimensions;
view->suboffsets = 0;
view->itemsize = type->elsize;
view->internal = 0;
return 0;
}
static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
}
""")
# For now, hard-code numpy imported as "numpy"
ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
types = [
(ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer")
]
# typeptr_cname = ndarrtype.typeptr_cname
code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
clause = "if"
for t, get, release in types:
code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
clause = "else if"
code.putln("else {")
code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
code.putln("return -1;")
code.putln("}")
code.putln("}")
code.putln("")
code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
clause = "if"
for t, get, release in types:
code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release))
clause = "else if"
code.putln("}")
code.putln("")
#------------------------------------------------------------------------------------
#
......
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
......@@ -137,12 +138,177 @@ class PostParse(CythonTransform):
if ndim_value < 0:
raise PostParseError(ndimnode.pos, ERR_BUF_NONNEG % 'ndim')
node.ndim = int(ndimnode.value)
else:
node.ndim = 1
# We're done with the parse tree args
node.positional_args = None
node.keyword_args = None
return node
class BufferTransform(CythonTransform):
"""
Run after type analysis. Takes care of the buffer functionality.
"""
scope = None
def __call__(self, node):
cymod = self.context.modules[u'__cython__']
self.buffer_type = cymod.entries[u'Py_buffer'].type
return super(BufferTransform, self).__call__(node)
def handle_scope(self, node, scope):
# For all buffers, insert extra variables in the scope.
# The variables are also accessible from the buffer_info
# on the buffer entry
bufvars = [(name, entry) for name, entry
in scope.entries.iteritems()
if entry.type.buffer_options is not None]
for name, entry in bufvars:
# Variable has buffer opts, declare auxiliary vars
bufopts = entry.type.buffer_options
bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name),
self.buffer_type, node.pos)
temp_var = scope.declare_var(temp_name_handle(u"%s_tmp" % name),
entry.type, node.pos)
stridevars = []
shapevars = []
for idx in range(bufopts.ndim):
# stride
varname = temp_name_handle(u"%s_%s%d" % (name, "stride", idx))
var = scope.declare_var(varname, PyrexTypes.c_int_type, node.pos, is_cdef=True)
stridevars.append(var)
# shape
varname = temp_name_handle(u"%s_%s%d" % (name, "shape", idx))
var = scope.declare_var(varname, PyrexTypes.c_uint_type, node.pos, is_cdef=True)
shapevars.append(var)
entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars,
shapevars)
entry.buffer_aux.temp_var = temp_var
self.scope = scope
def visit_ModuleNode(self, node):
self.handle_scope(node, node.scope)
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
self.handle_scope(node, node.local_scope)
self.visitchildren(node)
return node
acquire_buffer_fragment = TreeFragment(u"""
TMP = LHS
if TMP is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
TMP = RHS
__cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
ASSIGN_AUX
LHS = TMP
""")
fetch_strides = TreeFragment(u"""
TARGET = BUFINFO.strides[IDX]
""")
fetch_shape = TreeFragment(u"""
TARGET = BUFINFO.shape[IDX]
""")
# ass = SingleAssignmentNode(pos=node.pos,
# lhs=NameNode(node.pos, name=entry.name),
# rhs=IndexNode(node.pos,
# base=AttributeNode(node.pos,
# obj=NameNode(node.pos, name=bufaux.buffer_info_var.name),
# attribute=EncodedString("strides")),
# index=IntNode(node.pos, value=EncodedString(idx))))
# print ass.dump()
def visit_SingleAssignmentNode(self, node):
self.visitchildren(node)
bufaux = node.lhs.entry.buffer_aux
if bufaux is not None:
auxass = []
for idx, entry in enumerate(bufaux.stridevars):
entry.used = True
ass = self.fetch_strides.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx))
})
auxass.append(ass)
for idx, entry in enumerate(bufaux.shapevars):
entry.used = True
ass = self.fetch_shape.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx))
})
auxass.append(ass)
bufaux.buffer_info_var.used = True
acq = self.acquire_buffer_fragment.substitute({
u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
u"LHS" : node.lhs,
u"RHS": node.rhs,
u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name)
}, pos=node.pos)
# Note: The below should probably be refactored into something
# like fragment.substitute(..., context=self.context), with
# TreeFragment getting context.pipeline_until_now() and
# applying it on the fragment.
acq.analyse_declarations(self.scope)
acq.analyse_expressions(self.scope)
stats = acq.stats
# stats += [node] # Do assignment after successful buffer acquisition
# print acq.dump()
return stats
else:
return node
buffer_access = TreeFragment(u"""
(<unsigned char*>(BUF.buf + OFFSET))[0]
""")
def visit_IndexNode(self, node):
if node.is_buffer_access:
assert node.index is None
assert node.indices is not None
bufaux = node.base.entry.buffer_aux
assert bufaux is not None
to_sum = [ IntBinopNode(node.pos, operator='*', operand1=index,
operand2=NameNode(node.pos, name=stride.name))
for index, stride in zip(node.indices, bufaux.stridevars)]
print to_sum
indices = node.indices
# reduce * on indices
expr = to_sum[0]
for next in to_sum[1:]:
expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next)
tmp= self.buffer_access.substitute({
'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name),
'OFFSET': expr
})
tmp.analyse_expressions(self.scope)
return tmp.stats[0].expr
else:
return node
def visit_CallNode(self, node):
### print node.dump()
return node
# def visit_FuncDefNode(self, node):
# print node.dump()
class WithTransform(CythonTransform):
# EXCINFO is manually set to a variable that contains
......
......@@ -6,6 +6,22 @@ from Cython import Utils
import Naming
import copy
class BufferOptions:
# dtype PyrexType
# ndim int
def __init__(self, dtype, ndim):
self.dtype = dtype
self.ndim = ndim
def create_buffer_type(base_type, buffer_options):
# Make a shallow copy of base_type and then annotate it
# with the buffer information
result = copy.copy(base_type)
result.buffer_options = buffer_options
return result
class BaseType:
#
# Base class for all Pyrex types including pseudo-types.
......@@ -93,6 +109,7 @@ class PyrexType(BaseType):
default_value = ""
parsetuple_format = ""
pymemberdef_typecode = None
buffer_options = None # can contain a BufferOptions instance
def resolve(self):
# If a typedef, returns the base type.
......@@ -184,21 +201,6 @@ class CTypedefType(BaseType):
def __getattr__(self, name):
return getattr(self.typedef_base_type, name)
class BufferOptions:
# dtype PyrexType
# ndim int
def __init__(self, dtype, ndim):
self.dtype = dtype
self.ndim = ndim
def create_buffer_type(base_type, buffer_options):
# Make a shallow copy of base_type and then annotate it
# with the buffer information
result = copy.copy(base_type)
result.buffer_options = buffer_options
return result
class PyObjectType(PyrexType):
#
# Base class for all Python object types (reference-counted).
......@@ -208,7 +210,6 @@ class PyObjectType(PyrexType):
default_value = "0"
parsetuple_format = "O"
pymemberdef_typecode = "T_OBJECT"
buffer_options = None # can contain a BufferOptions instance
def __str__(self):
return "Python object"
......
......@@ -19,6 +19,14 @@ import __builtin__
possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
class BufferAux:
def __init__(self, buffer_info_var, stridevars, shapevars):
self.buffer_info_var = buffer_info_var
self.stridevars = stridevars
self.shapevars = shapevars
def __repr__(self):
return "<BufferAux %r>" % self.__dict__
class Entry:
# A symbol table entry in a Scope or ModuleNamespace.
#
......@@ -76,6 +84,8 @@ class Entry:
# defined_in_pxd boolean Is defined in a .pxd file (not just declared)
# api boolean Generate C API for C class or function
# utility_code string Utility code needed when this entry is used
#
# buffer_aux BufferAux or None Extra information needed for buffer variables
borrowed = 0
init = ""
......@@ -117,6 +127,7 @@ class Entry:
api = 0
utility_code = None
is_overridable = 0
buffer_aux = None
def __init__(self, name, cname, type, pos = None, init = None):
self.name = name
......
cdef extern from "Python.h":
ctypedef struct PyObject
ctypedef struct Py_buffer:
void *buf
Py_ssize_t len
int readonly
char *format
int ndim
Py_ssize_t *shape
Py_ssize_t *strides
Py_ssize_t *suboffsets
Py_ssize_t itemsize
void *internal
int PyObject_GetBuffer(PyObject* obj, Py_buffer* view, int flags) except -1
void PyObject_ReleaseBuffer(PyObject* obj, Py_buffer* view)
# int PyObject_GetBuffer(PyObject *obj, Py_buffer *view,
# int flags)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment