Commit 1ab79216 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Buffer access working for builtin numeric types.

parent 79d4bd3d
......@@ -9,94 +9,84 @@ import PyrexTypes
from sets import Set as set
class PureCFuncNode(Node):
def __init__(self, pos, cname, type, c_code):
def __init__(self, pos, cname, type, c_code, visibility='private'):
self.pos = pos
self.cname = cname
self.type = type
self.c_code = c_code
self.visibility = visibility
def analyse_types(self, env):
self.entry = env.declare_cfunction(
"<pure c function:%s>" % self.cname,
self.type, self.pos, cname=self.cname,
defining=True)
defining=True, visibility=self.visibility)
def generate_function_definitions(self, env, code, transforms):
# TODO: Fix constness, don't hack it
assert self.type.optional_arg_count == 0
visibility = self.entry.visibility
if visibility != 'private':
storage_class = "%s " % Naming.extern_c_macro
else:
storage_class = "static "
arg_decls = [arg.declaration_code() for arg in self.type.args]
sig = self.type.return_type.declaration_code(
self.type.function_header_code(self.cname, ", ".join(arg_decls)))
code.putln("")
code.putln("%s {" % sig)
code.putln("%s%s {" % (storage_class, sig))
code.put(self.c_code)
code.putln("}")
def generate_execution_code(self, code):
pass
tschecker_functype = PyrexTypes.CFuncType(
PyrexTypes.c_char_ptr_type,
[PyrexTypes.CFuncTypeArg(EncodedString("ts"), PyrexTypes.c_char_ptr_type,
(0, 0, None), cname="ts")],
exception_value = "NULL"
)
tsprefix = "__Pyx_tsc"
class BufferTransform(CythonTransform):
"""
Run after type analysis. Takes care of the buffer functionality.
Expects to be run on the full module. If you need to process a fragment
one should look into refactoring this transform.
"""
# Abbreviations:
# "ts" means typestring and/or typestring checking stuff
scope = None
tschecker_functype = PyrexTypes.CFuncType(
PyrexTypes.c_char_ptr_type,
[PyrexTypes.CFuncTypeArg(EncodedString("ts"), PyrexTypes.c_char_ptr_type,
(0, 0, None), cname="ts")],
exception_value = "NULL"
)
#
# Entry point
#
def __call__(self, node):
assert isinstance(node, ModuleNode)
cymod = self.context.modules[u'__cython__']
self.bufstruct_type = cymod.entries[u'Py_buffer'].type
self.tscheckers = {}
self.ts_funcs = []
self.ts_item_checkers = {}
self.module_scope = node.scope
self.module_pos = node.pos
result = super(BufferTransform, self).__call__(node)
result.body.stats += [node for node in self.tscheckers.values()]
# Register ts stuff
if "endian.h" not in node.scope.include_files:
node.scope.include_files.append("endian.h")
result.body.stats += self.ts_funcs
return result
def tschecker_simple(self, dtype):
char = dtype.typestring
return """
if (*ts != '%s') {
PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch");
return NULL;
} else return ts + 1;
""" % char
def tschecker(self, dtype):
# Creates a type string checker function for the given type.
# Each checker is created as a function entry in the module scope
# and a PureCNode and put in the self.ts_checkers dict.
# Also the entry is returned.
#
# TODO: __eq__ and __hash__ for types
funcnode = self.tscheckers.get(dtype, None)
if funcnode is None:
assert dtype.is_int or dtype.is_float or dtype.is_struct_or_union
# Use prefixes to seperate user defined types from builtins
# (consider "typedef float unsigned_int")
builtin = not (dtype.is_struct_or_union or dtype.is_typedef)
if not builtin:
prefix = "user"
else:
prefix = "builtin"
cname = "check_typestring_%s_%s" % (prefix,
dtype.declaration_code("").replace(" ", "_"))
if dtype.typestring is not None and len(dtype.typestring) == 1:
code = self.tschecker_simple(dtype)
else:
assert False
funcnode = PureCFuncNode(self.module_pos, cname,
self.tschecker_functype, code)
funcnode.analyse_types(self.module_scope)
self.tscheckers[dtype] = funcnode
return funcnode.entry
# Basic operations for transforms
#
def handle_scope(self, node, scope):
# For all buffers, insert extra variables in the scope.
# The variables are also accessible from the buffer_info
......@@ -136,17 +126,6 @@ class BufferTransform(CythonTransform):
entry.buffer_aux.temp_var = temp_var
self.scope = scope
def visit_ModuleNode(self, node):
self.handle_scope(node, node.scope)
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
self.handle_scope(node, node.local_scope)
self.visitchildren(node)
return node
# Notes: The cast to <char*> gets around Cython not supporting const types
acquire_buffer_fragment = TreeFragment(u"""
TMP = LHS
......@@ -215,27 +194,45 @@ class BufferTransform(CythonTransform):
return result
buffer_access = TreeFragment(u"""
(<unsigned char*>(BUF.buf + OFFSET))[0]
""")
def buffer_index(self, node):
pos = node.pos
bufaux = node.base.entry.buffer_aux
assert bufaux is not None
# indices * strides...
to_sum = [ IntBinopNode(node.pos, operator='*',
to_sum = [ IntBinopNode(pos, operator='*',
operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index),
operand2=NameNode(node.pos, name=stride.name))
for index, stride in zip(node.indices, bufaux.stridevars)]
# then sum them
expr = to_sum[0]
for next in to_sum[1:]:
expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next)
# then sum them with the buffer pointer
expr = AttributeNode(pos,
obj=NameNode(pos, name=bufaux.buffer_info_var.name),
attribute=EncodedString("buf"))
for next in to_sum:
expr = AddNode(pos, operator='+', operand1=expr, operand2=next)
tmp= self.buffer_access.substitute({
'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name),
'OFFSET': expr
}, pos=node.pos)
casted = TypecastNode(pos, operand=expr,
type=PyrexTypes.c_ptr_type(node.base.entry.type.buffer_options.dtype))
result = IndexNode(pos, base=casted, index=IntNode(pos, value='0'))
return result
return tmp.stats[0].expr
#
# Transforms
#
def visit_ModuleNode(self, node):
self.handle_scope(node, node.scope)
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
self.handle_scope(node, node.local_scope)
self.visitchildren(node)
return node
def visit_SingleAssignmentNode(self, node):
# On assignments, two buffer-related things can happen:
......@@ -254,9 +251,6 @@ class BufferTransform(CythonTransform):
else:
return node
buffer_access = TreeFragment(u"""
(<unsigned char*>(BUF.buf + OFFSET))[0]
""")
def visit_IndexNode(self, node):
# Only occurs when the IndexNode is an rvalue
if node.is_buffer_access:
......@@ -268,3 +262,201 @@ class BufferTransform(CythonTransform):
else:
return node
#
# Utils for creating type string checkers
#
def new_ts_func(self, name, code):
cname = "%s_%s" % (tsprefix, name)
funcnode = PureCFuncNode(self.module_pos, cname, tschecker_functype, code)
funcnode.analyse_types(self.module_scope)
self.ts_funcs.append(funcnode)
return funcnode
def mangle_dtype_name(self, dtype):
# Use prefixes to seperate user defined types from builtins
# (consider "typedef float unsigned_int")
return dtype.declaration_code("").replace(" ", "_")
def get_ts_check_item(self, dtype):
# See if we can consume one (unnamed) dtype as next item
funcnode = self.ts_item_checkers.get(dtype)
if funcnode is None:
char = dtype.typestring
funcnode = self.new_ts_func("item_%s" % self.mangle_dtype_name(dtype), """\
if (*ts != '%s') {
PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch (rejecting on '%%s')", ts);
return NULL;
} else return ts + 1;
""" % char)
self.ts_item_checkers[dtype] = funcnode
return funcnode.entry.cname
ts_consume_whitespace_cname = None
ts_check_endian_cname = None
def ensure_ts_utils(self):
# Makes sure that the typechecker utils are in scope
# (and constructs them if not)
if self.ts_consume_whitespace_cname is None:
self.ts_consume_whitespace_cname = self.new_ts_func("consume_whitespace", """\
while (1) {
switch (*ts) {
case 10:
case 13:
case ' ':
++ts;
default:
return ts;
}
}
""").entry.cname
if self.ts_check_endian_cname is None:
self.ts_check_endian_cname = self.new_ts_func("check_endian", """\
int ok = 1;
switch (*ts) {
case '@':
case '=':
++ts; break;
case '<':
if (__BYTE_ORDER == __LITTLE_ENDIAN) ++ts;
else ok = 0;
break;
case '>':
case '!':
if (__BYTE_ORDER == __BIG_ENDIAN) ++ts;
else ok = 0;
break;
}
if (!ok) {
PyErr_Format(PyExc_TypeError, "Data has wrong endianness (rejecting on '%s')", ts);
return NULL;
}
return ts;
""").entry.cname
def create_ts_check_simple(self, dtype):
# Check whole string for single unnamed item
consume_whitespace = self.ts_consume_whitespace_cname
check_endian = self.ts_check_endian_cname
check_item = self.get_ts_check_item(dtype)
return self.new_ts_func("simple_%s" % self.mangle_dtype_name(dtype), """\
ts = %(consume_whitespace)s(ts);
ts = %(check_endian)s(ts);
if (!ts) return NULL;
ts = %(consume_whitespace)s(ts);
ts = %(check_item)s(ts);
if (!ts) return NULL;
ts = %(consume_whitespace)s(ts);
if (*ts != 0) {
PyErr_Format(PyExc_TypeError, "Data too long (rejecting on '%%s')", ts);
return NULL;
}
return ts;
""" % locals())
def tschecker(self, dtype):
# Creates a type string checker function for the given type.
# Each checker is created as a function entry in the module scope
# and a PureCNode and put in the self.ts_checkers dict.
# Also the entry is returned.
#
# TODO: __eq__ and __hash__ for types
self.ensure_ts_utils()
funcnode = self.tscheckers.get(dtype)
if funcnode is None:
assert dtype.is_int or dtype.is_float or dtype.is_struct_or_union
if dtype.is_struct_or_union:
assert False
elif dtype.is_typedef:
assert False
else:
funcnode = self.create_ts_check_simple(dtype)
self.tscheckers[dtype] = funcnode
return funcnode.entry
# TODO:
# - buf must be NULL before getting new buffer
## get_buffer_func_type = PyrexTypes.CFuncType(
## PyrexTypes.c_int_type,
## [PyrexTypes.CFuncTypeArg(EncodedString("obj"), PyrexTypes.py_object_type, (0, 0, None), cname="obj"),
## PyrexTypes.CFuncTypeArg(EncodedString("view"), PyrexTypes.c_py_buffer_ptr_type, (0, 0, None), cname="view"),
## PyrexTypes.CFuncTypeArg(EncodedString("flags"), PyrexTypes.c_int_type, (0, 0, None), cname="flags"),
## ],
## exception_value = "-1"
## )
## numpy_get_buffer_body = """
## PyArrayObject *arr = (PyArrayObject*)obj;
## PyArray_Descr *type = (PyArray_Descr*)arr->descr;
## view->buf = arr->data;
## view->readonly = 0; /*fixme*/
## view->format = "B"; /*fixme*/
## view->ndim = arr->nd;
## view->strides = arr->strides;
## view->shape = arr->dimensions;
## view->suboffsets = 0;
## view->itemsize = type->elsize;
## view->internal = 0;
## return 0;
## """
# will be refactored
## code.put("""
## static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
## /* This function is always called after a type-check */
## PyArrayObject *arr = (PyArrayObject*)obj;
## PyArray_Descr *type = (PyArray_Descr*)arr->descr;
## view->buf = arr->data;
## view->readonly = 0; /*fixme*/
## view->format = "B"; /*fixme*/
## view->ndim = arr->nd;
## view->strides = arr->strides;
## view->shape = arr->dimensions;
## view->suboffsets = 0;
## view->itemsize = type->elsize;
## view->internal = 0;
## return 0;
## }
## static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
## }
## """)
## # For now, hard-code numpy imported as "numpy"
## ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
## types = [
## (ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer")
## ]
## # typeptr_cname = ndarrtype.typeptr_cname
## code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
## clause = "if"
## for t, get, release in types:
## code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
## clause = "else if"
## code.putln("else {")
## code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
## code.putln("return -1;")
## code.putln("}")
## code.putln("}")
## code.putln("")
## code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
## clause = "if"
## for t, get, release in types:
## code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release))
## clause = "else if"
## code.putln("}")
## code.putln("")
......@@ -2839,13 +2839,18 @@ def unop_node(pos, operator, operand):
class TypecastNode(ExprNode):
# C type cast
#
# operand ExprNode
# base_type CBaseTypeNode
# declarator CDeclaratorNode
# operand ExprNode
#
# If used from a transform, one can if wanted specify the attribute
# "type" directly and leave base_type and declarator to None
subexprs = ['operand']
base_type = declarator = type = None
def analyse_types(self, env):
if self.type is None:
base_type = self.base_type.analyse(env)
_, self.type = self.declarator.analyse(base_type, env)
if self.type.is_cfunction:
......
......@@ -1955,24 +1955,64 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
# will be refactored
code.put("""
static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
/* This function is always called after a type-check */
/* This function is always called after a type-check; safe to cast */
PyArrayObject *arr = (PyArrayObject*)obj;
PyArray_Descr *type = (PyArray_Descr*)arr->descr;
view->buf = arr->data;
view->readonly = 0; /*fixme*/
view->format = "B"; /*fixme*/
view->ndim = arr->nd;
view->strides = arr->strides;
view->shape = arr->dimensions;
view->suboffsets = 0;
int typenum = PyArray_TYPE(obj);
if (!PyTypeNum_ISNUMBER(typenum)) {
PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
return -1;
}
/*
NumPy format codes doesn't completely match buffer codes;
seems safest to retranslate.
01234567890123456789012345*/
const char* base_codes = "?bBhHiIlLqQfdgfdgO";
/*
enum NPY_TYPES { NPY_BOOL=0,
NPY_BYTE, NPY_UBYTE,
NPY_SHORT, NPY_USHORT,
NPY_INT, NPY_UINT,
NPY_LONG, NPY_ULONG,
NPY_LONGLONG, NPY_ULONGLONG,
NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE,
NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE,
NPY_OBJECT=17,
NPY_STRING, NPY_UNICODE,
NPY_VOID,
NPY_NTYPES,
NPY_NOTYPE,
NPY_CHAR, special flag
NPY_USERDEF=256 leave room for characters
*/
char* format = (char*)malloc(4);
char* fp = format;
*fp++ = type->byteorder;
if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
*fp++ = base_codes[typenum];
*fp = 0;
view->buf = arr->data;
view->readonly = !PyArray_ISWRITEABLE(obj);
view->ndim = PyArray_NDIM(arr);
view->strides = PyArray_STRIDES(arr);
view->shape = PyArray_DIMS(arr);
view->suboffsets = NULL;
view->format = format;
view->itemsize = type->elsize;
view->internal = 0;
return 0;
}
static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
free((char*)view->format);
view->format = NULL;
}
""")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment