Commit 1ab79216 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Buffer access working for builtin numeric types.

parent 79d4bd3d
...@@ -9,94 +9,84 @@ import PyrexTypes ...@@ -9,94 +9,84 @@ import PyrexTypes
from sets import Set as set from sets import Set as set
class PureCFuncNode(Node): class PureCFuncNode(Node):
def __init__(self, pos, cname, type, c_code): def __init__(self, pos, cname, type, c_code, visibility='private'):
self.pos = pos self.pos = pos
self.cname = cname self.cname = cname
self.type = type self.type = type
self.c_code = c_code self.c_code = c_code
self.visibility = visibility
def analyse_types(self, env): def analyse_types(self, env):
self.entry = env.declare_cfunction( self.entry = env.declare_cfunction(
"<pure c function:%s>" % self.cname, "<pure c function:%s>" % self.cname,
self.type, self.pos, cname=self.cname, self.type, self.pos, cname=self.cname,
defining=True) defining=True, visibility=self.visibility)
def generate_function_definitions(self, env, code, transforms): def generate_function_definitions(self, env, code, transforms):
# TODO: Fix constness, don't hack it
assert self.type.optional_arg_count == 0 assert self.type.optional_arg_count == 0
visibility = self.entry.visibility
if visibility != 'private':
storage_class = "%s " % Naming.extern_c_macro
else:
storage_class = "static "
arg_decls = [arg.declaration_code() for arg in self.type.args] arg_decls = [arg.declaration_code() for arg in self.type.args]
sig = self.type.return_type.declaration_code( sig = self.type.return_type.declaration_code(
self.type.function_header_code(self.cname, ", ".join(arg_decls))) self.type.function_header_code(self.cname, ", ".join(arg_decls)))
code.putln("") code.putln("")
code.putln("%s {" % sig) code.putln("%s%s {" % (storage_class, sig))
code.put(self.c_code) code.put(self.c_code)
code.putln("}") code.putln("}")
def generate_execution_code(self, code): def generate_execution_code(self, code):
pass pass
tschecker_functype = PyrexTypes.CFuncType(
PyrexTypes.c_char_ptr_type,
[PyrexTypes.CFuncTypeArg(EncodedString("ts"), PyrexTypes.c_char_ptr_type,
(0, 0, None), cname="ts")],
exception_value = "NULL"
)
tsprefix = "__Pyx_tsc"
class BufferTransform(CythonTransform): class BufferTransform(CythonTransform):
""" """
Run after type analysis. Takes care of the buffer functionality. Run after type analysis. Takes care of the buffer functionality.
Expects to be run on the full module. If you need to process a fragment
one should look into refactoring this transform.
""" """
# Abbreviations:
# "ts" means typestring and/or typestring checking stuff
scope = None scope = None
tschecker_functype = PyrexTypes.CFuncType(
PyrexTypes.c_char_ptr_type, #
[PyrexTypes.CFuncTypeArg(EncodedString("ts"), PyrexTypes.c_char_ptr_type, # Entry point
(0, 0, None), cname="ts")], #
exception_value = "NULL"
)
def __call__(self, node): def __call__(self, node):
assert isinstance(node, ModuleNode)
cymod = self.context.modules[u'__cython__'] cymod = self.context.modules[u'__cython__']
self.bufstruct_type = cymod.entries[u'Py_buffer'].type self.bufstruct_type = cymod.entries[u'Py_buffer'].type
self.tscheckers = {} self.tscheckers = {}
self.ts_funcs = []
self.ts_item_checkers = {}
self.module_scope = node.scope self.module_scope = node.scope
self.module_pos = node.pos self.module_pos = node.pos
result = super(BufferTransform, self).__call__(node) result = super(BufferTransform, self).__call__(node)
result.body.stats += [node for node in self.tscheckers.values()] # Register ts stuff
if "endian.h" not in node.scope.include_files:
node.scope.include_files.append("endian.h")
result.body.stats += self.ts_funcs
return result return result
def tschecker_simple(self, dtype):
char = dtype.typestring
return """
if (*ts != '%s') {
PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch");
return NULL;
} else return ts + 1;
""" % char
def tschecker(self, dtype):
# Creates a type string checker function for the given type.
# Each checker is created as a function entry in the module scope
# and a PureCNode and put in the self.ts_checkers dict.
# Also the entry is returned.
# #
# TODO: __eq__ and __hash__ for types # Basic operations for transforms
funcnode = self.tscheckers.get(dtype, None) #
if funcnode is None:
assert dtype.is_int or dtype.is_float or dtype.is_struct_or_union
# Use prefixes to seperate user defined types from builtins
# (consider "typedef float unsigned_int")
builtin = not (dtype.is_struct_or_union or dtype.is_typedef)
if not builtin:
prefix = "user"
else:
prefix = "builtin"
cname = "check_typestring_%s_%s" % (prefix,
dtype.declaration_code("").replace(" ", "_"))
if dtype.typestring is not None and len(dtype.typestring) == 1:
code = self.tschecker_simple(dtype)
else:
assert False
funcnode = PureCFuncNode(self.module_pos, cname,
self.tschecker_functype, code)
funcnode.analyse_types(self.module_scope)
self.tscheckers[dtype] = funcnode
return funcnode.entry
def handle_scope(self, node, scope): def handle_scope(self, node, scope):
# For all buffers, insert extra variables in the scope. # For all buffers, insert extra variables in the scope.
# The variables are also accessible from the buffer_info # The variables are also accessible from the buffer_info
...@@ -136,17 +126,6 @@ class BufferTransform(CythonTransform): ...@@ -136,17 +126,6 @@ class BufferTransform(CythonTransform):
entry.buffer_aux.temp_var = temp_var entry.buffer_aux.temp_var = temp_var
self.scope = scope self.scope = scope
def visit_ModuleNode(self, node):
self.handle_scope(node, node.scope)
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
self.handle_scope(node, node.local_scope)
self.visitchildren(node)
return node
# Notes: The cast to <char*> gets around Cython not supporting const types # Notes: The cast to <char*> gets around Cython not supporting const types
acquire_buffer_fragment = TreeFragment(u""" acquire_buffer_fragment = TreeFragment(u"""
TMP = LHS TMP = LHS
...@@ -215,27 +194,45 @@ class BufferTransform(CythonTransform): ...@@ -215,27 +194,45 @@ class BufferTransform(CythonTransform):
return result return result
buffer_access = TreeFragment(u"""
(<unsigned char*>(BUF.buf + OFFSET))[0]
""")
def buffer_index(self, node): def buffer_index(self, node):
pos = node.pos
bufaux = node.base.entry.buffer_aux bufaux = node.base.entry.buffer_aux
assert bufaux is not None assert bufaux is not None
# indices * strides... # indices * strides...
to_sum = [ IntBinopNode(node.pos, operator='*', to_sum = [ IntBinopNode(pos, operator='*',
operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index), operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index),
operand2=NameNode(node.pos, name=stride.name)) operand2=NameNode(node.pos, name=stride.name))
for index, stride in zip(node.indices, bufaux.stridevars)] for index, stride in zip(node.indices, bufaux.stridevars)]
# then sum them # then sum them with the buffer pointer
expr = to_sum[0] expr = AttributeNode(pos,
for next in to_sum[1:]: obj=NameNode(pos, name=bufaux.buffer_info_var.name),
expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next) attribute=EncodedString("buf"))
for next in to_sum:
expr = AddNode(pos, operator='+', operand1=expr, operand2=next)
tmp= self.buffer_access.substitute({ casted = TypecastNode(pos, operand=expr,
'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name), type=PyrexTypes.c_ptr_type(node.base.entry.type.buffer_options.dtype))
'OFFSET': expr result = IndexNode(pos, base=casted, index=IntNode(pos, value='0'))
}, pos=node.pos)
return result
return tmp.stats[0].expr
#
# Transforms
#
def visit_ModuleNode(self, node):
self.handle_scope(node, node.scope)
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
self.handle_scope(node, node.local_scope)
self.visitchildren(node)
return node
def visit_SingleAssignmentNode(self, node): def visit_SingleAssignmentNode(self, node):
# On assignments, two buffer-related things can happen: # On assignments, two buffer-related things can happen:
...@@ -254,9 +251,6 @@ class BufferTransform(CythonTransform): ...@@ -254,9 +251,6 @@ class BufferTransform(CythonTransform):
else: else:
return node return node
buffer_access = TreeFragment(u"""
(<unsigned char*>(BUF.buf + OFFSET))[0]
""")
def visit_IndexNode(self, node): def visit_IndexNode(self, node):
# Only occurs when the IndexNode is an rvalue # Only occurs when the IndexNode is an rvalue
if node.is_buffer_access: if node.is_buffer_access:
...@@ -268,3 +262,201 @@ class BufferTransform(CythonTransform): ...@@ -268,3 +262,201 @@ class BufferTransform(CythonTransform):
else: else:
return node return node
#
# Utils for creating type string checkers
#
def new_ts_func(self, name, code):
cname = "%s_%s" % (tsprefix, name)
funcnode = PureCFuncNode(self.module_pos, cname, tschecker_functype, code)
funcnode.analyse_types(self.module_scope)
self.ts_funcs.append(funcnode)
return funcnode
def mangle_dtype_name(self, dtype):
# Use prefixes to seperate user defined types from builtins
# (consider "typedef float unsigned_int")
return dtype.declaration_code("").replace(" ", "_")
def get_ts_check_item(self, dtype):
# See if we can consume one (unnamed) dtype as next item
funcnode = self.ts_item_checkers.get(dtype)
if funcnode is None:
char = dtype.typestring
funcnode = self.new_ts_func("item_%s" % self.mangle_dtype_name(dtype), """\
if (*ts != '%s') {
PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch (rejecting on '%%s')", ts);
return NULL;
} else return ts + 1;
""" % char)
self.ts_item_checkers[dtype] = funcnode
return funcnode.entry.cname
ts_consume_whitespace_cname = None
ts_check_endian_cname = None
def ensure_ts_utils(self):
# Makes sure that the typechecker utils are in scope
# (and constructs them if not)
if self.ts_consume_whitespace_cname is None:
self.ts_consume_whitespace_cname = self.new_ts_func("consume_whitespace", """\
while (1) {
switch (*ts) {
case 10:
case 13:
case ' ':
++ts;
default:
return ts;
}
}
""").entry.cname
if self.ts_check_endian_cname is None:
self.ts_check_endian_cname = self.new_ts_func("check_endian", """\
int ok = 1;
switch (*ts) {
case '@':
case '=':
++ts; break;
case '<':
if (__BYTE_ORDER == __LITTLE_ENDIAN) ++ts;
else ok = 0;
break;
case '>':
case '!':
if (__BYTE_ORDER == __BIG_ENDIAN) ++ts;
else ok = 0;
break;
}
if (!ok) {
PyErr_Format(PyExc_TypeError, "Data has wrong endianness (rejecting on '%s')", ts);
return NULL;
}
return ts;
""").entry.cname
def create_ts_check_simple(self, dtype):
# Check whole string for single unnamed item
consume_whitespace = self.ts_consume_whitespace_cname
check_endian = self.ts_check_endian_cname
check_item = self.get_ts_check_item(dtype)
return self.new_ts_func("simple_%s" % self.mangle_dtype_name(dtype), """\
ts = %(consume_whitespace)s(ts);
ts = %(check_endian)s(ts);
if (!ts) return NULL;
ts = %(consume_whitespace)s(ts);
ts = %(check_item)s(ts);
if (!ts) return NULL;
ts = %(consume_whitespace)s(ts);
if (*ts != 0) {
PyErr_Format(PyExc_TypeError, "Data too long (rejecting on '%%s')", ts);
return NULL;
}
return ts;
""" % locals())
def tschecker(self, dtype):
# Creates a type string checker function for the given type.
# Each checker is created as a function entry in the module scope
# and a PureCNode and put in the self.ts_checkers dict.
# Also the entry is returned.
#
# TODO: __eq__ and __hash__ for types
self.ensure_ts_utils()
funcnode = self.tscheckers.get(dtype)
if funcnode is None:
assert dtype.is_int or dtype.is_float or dtype.is_struct_or_union
if dtype.is_struct_or_union:
assert False
elif dtype.is_typedef:
assert False
else:
funcnode = self.create_ts_check_simple(dtype)
self.tscheckers[dtype] = funcnode
return funcnode.entry
# TODO:
# - buf must be NULL before getting new buffer
## get_buffer_func_type = PyrexTypes.CFuncType(
## PyrexTypes.c_int_type,
## [PyrexTypes.CFuncTypeArg(EncodedString("obj"), PyrexTypes.py_object_type, (0, 0, None), cname="obj"),
## PyrexTypes.CFuncTypeArg(EncodedString("view"), PyrexTypes.c_py_buffer_ptr_type, (0, 0, None), cname="view"),
## PyrexTypes.CFuncTypeArg(EncodedString("flags"), PyrexTypes.c_int_type, (0, 0, None), cname="flags"),
## ],
## exception_value = "-1"
## )
## numpy_get_buffer_body = """
## PyArrayObject *arr = (PyArrayObject*)obj;
## PyArray_Descr *type = (PyArray_Descr*)arr->descr;
## view->buf = arr->data;
## view->readonly = 0; /*fixme*/
## view->format = "B"; /*fixme*/
## view->ndim = arr->nd;
## view->strides = arr->strides;
## view->shape = arr->dimensions;
## view->suboffsets = 0;
## view->itemsize = type->elsize;
## view->internal = 0;
## return 0;
## """
# will be refactored
## code.put("""
## static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
## /* This function is always called after a type-check */
## PyArrayObject *arr = (PyArrayObject*)obj;
## PyArray_Descr *type = (PyArray_Descr*)arr->descr;
## view->buf = arr->data;
## view->readonly = 0; /*fixme*/
## view->format = "B"; /*fixme*/
## view->ndim = arr->nd;
## view->strides = arr->strides;
## view->shape = arr->dimensions;
## view->suboffsets = 0;
## view->itemsize = type->elsize;
## view->internal = 0;
## return 0;
## }
## static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
## }
## """)
## # For now, hard-code numpy imported as "numpy"
## ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
## types = [
## (ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer")
## ]
## # typeptr_cname = ndarrtype.typeptr_cname
## code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
## clause = "if"
## for t, get, release in types:
## code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
## clause = "else if"
## code.putln("else {")
## code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
## code.putln("return -1;")
## code.putln("}")
## code.putln("}")
## code.putln("")
## code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
## clause = "if"
## for t, get, release in types:
## code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release))
## clause = "else if"
## code.putln("}")
## code.putln("")
...@@ -2839,13 +2839,18 @@ def unop_node(pos, operator, operand): ...@@ -2839,13 +2839,18 @@ def unop_node(pos, operator, operand):
class TypecastNode(ExprNode): class TypecastNode(ExprNode):
# C type cast # C type cast
# #
# operand ExprNode
# base_type CBaseTypeNode # base_type CBaseTypeNode
# declarator CDeclaratorNode # declarator CDeclaratorNode
# operand ExprNode #
# If used from a transform, one can if wanted specify the attribute
# "type" directly and leave base_type and declarator to None
subexprs = ['operand'] subexprs = ['operand']
base_type = declarator = type = None
def analyse_types(self, env): def analyse_types(self, env):
if self.type is None:
base_type = self.base_type.analyse(env) base_type = self.base_type.analyse(env)
_, self.type = self.declarator.analyse(base_type, env) _, self.type = self.declarator.analyse(base_type, env)
if self.type.is_cfunction: if self.type.is_cfunction:
......
...@@ -1955,24 +1955,64 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1955,24 +1955,64 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
# will be refactored # will be refactored
code.put(""" code.put("""
static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) { static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
/* This function is always called after a type-check */ /* This function is always called after a type-check; safe to cast */
PyArrayObject *arr = (PyArrayObject*)obj; PyArrayObject *arr = (PyArrayObject*)obj;
PyArray_Descr *type = (PyArray_Descr*)arr->descr; PyArray_Descr *type = (PyArray_Descr*)arr->descr;
view->buf = arr->data;
view->readonly = 0; /*fixme*/
view->format = "B"; /*fixme*/
view->ndim = arr->nd;
view->strides = arr->strides;
view->shape = arr->dimensions;
view->suboffsets = 0;
int typenum = PyArray_TYPE(obj);
if (!PyTypeNum_ISNUMBER(typenum)) {
PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
return -1;
}
/*
NumPy format codes doesn't completely match buffer codes;
seems safest to retranslate.
01234567890123456789012345*/
const char* base_codes = "?bBhHiIlLqQfdgfdgO";
/*
enum NPY_TYPES { NPY_BOOL=0,
NPY_BYTE, NPY_UBYTE,
NPY_SHORT, NPY_USHORT,
NPY_INT, NPY_UINT,
NPY_LONG, NPY_ULONG,
NPY_LONGLONG, NPY_ULONGLONG,
NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE,
NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE,
NPY_OBJECT=17,
NPY_STRING, NPY_UNICODE,
NPY_VOID,
NPY_NTYPES,
NPY_NOTYPE,
NPY_CHAR, special flag
NPY_USERDEF=256 leave room for characters
*/
char* format = (char*)malloc(4);
char* fp = format;
*fp++ = type->byteorder;
if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
*fp++ = base_codes[typenum];
*fp = 0;
view->buf = arr->data;
view->readonly = !PyArray_ISWRITEABLE(obj);
view->ndim = PyArray_NDIM(arr);
view->strides = PyArray_STRIDES(arr);
view->shape = PyArray_DIMS(arr);
view->suboffsets = NULL;
view->format = format;
view->itemsize = type->elsize; view->itemsize = type->elsize;
view->internal = 0; view->internal = 0;
return 0; return 0;
} }
static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) { static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
free((char*)view->format);
view->format = NULL;
} }
""") """)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment