Commit 3da50c6d authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Buffer assignment appears to be working

parent c5ba9581
......@@ -5,18 +5,97 @@ from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Utils import EncodedString
from Cython.Compiler.Errors import CompileError
import PyrexTypes
from sets import Set as set
class PureCFuncNode(Node):
def __init__(self, pos, cname, type, c_code):
self.pos = pos
self.cname = cname
self.type = type
self.c_code = c_code
def analyse_types(self, env):
self.entry = env.declare_cfunction(
"<pure c function:%s>" % self.cname,
self.type, self.pos, cname=self.cname,
defining=True)
def generate_function_definitions(self, env, code, transforms):
# TODO: Fix constness, don't hack it
assert self.type.optional_arg_count == 0
arg_decls = [arg.declaration_code() for arg in self.type.args]
sig = self.type.return_type.declaration_code(
self.type.function_header_code(self.cname, ", ".join(arg_decls)))
code.putln("")
code.putln("%s {" % sig)
code.put(self.c_code)
code.putln("}")
def generate_execution_code(self, code):
pass
class BufferTransform(CythonTransform):
"""
Run after type analysis. Takes care of the buffer functionality.
"""
scope = None
tschecker_functype = PyrexTypes.CFuncType(
PyrexTypes.c_char_ptr_type,
[PyrexTypes.CFuncTypeArg(EncodedString("ts"), PyrexTypes.c_char_ptr_type,
(0, 0, None), cname="ts")],
exception_value = "NULL"
)
def __call__(self, node):
cymod = self.context.modules[u'__cython__']
self.buffer_type = cymod.entries[u'Py_buffer'].type
return super(BufferTransform, self).__call__(node)
self.bufstruct_type = cymod.entries[u'Py_buffer'].type
self.tscheckers = {}
self.module_scope = node.scope
self.module_pos = node.pos
result = super(BufferTransform, self).__call__(node)
result.body.stats += [node for node in self.tscheckers.values()]
return result
def tschecker_simple(self, dtype):
char = dtype.typestring
return """
if (*ts != '%s') {
PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch");
return NULL;
} else return ts + 1;
""" % char
def tschecker(self, dtype):
# Creates a type string checker function for the given type.
# Each checker is created as a function entry in the module scope
# and a PureCNode and put in the self.ts_checkers dict.
# Also the entry is returned.
#
# TODO: __eq__ and __hash__ for types
funcnode = self.tscheckers.get(dtype, None)
if funcnode is None:
assert dtype.is_int or dtype.is_float or dtype.is_struct_or_union
# Use prefixes to seperate user defined types from builtins
# (consider "typedef float unsigned_int")
builtin = not (dtype.is_struct_or_union or dtype.is_typedef)
if not builtin:
prefix = "user"
else:
prefix = "builtin"
cname = "check_typestring_%s_%s" % (prefix,
dtype.declaration_code("").replace(" ", "_"))
if dtype.typestring is not None and len(dtype.typestring) == 1:
code = self.tschecker_simple(dtype)
else:
assert False
funcnode = PureCFuncNode(self.module_pos, cname,
self.tschecker_functype, code)
funcnode.analyse_types(self.module_scope)
self.tscheckers[dtype] = funcnode
return funcnode.entry
def handle_scope(self, node, scope):
# For all buffers, insert extra variables in the scope.
......@@ -27,11 +106,15 @@ class BufferTransform(CythonTransform):
if entry.type.buffer_options is not None]
for name, entry in bufvars:
# Variable has buffer opts, declare auxiliary vars
bufopts = entry.type.buffer_options
# Get or make a type string checker
tschecker = self.tschecker(bufopts.dtype)
# Declare auxiliary vars
bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name),
self.buffer_type, node.pos)
self.bufstruct_type, node.pos)
temp_var = scope.declare_var(temp_name_handle(u"%s_tmp" % name),
entry.type, node.pos)
......@@ -49,7 +132,7 @@ class BufferTransform(CythonTransform):
var = scope.declare_var(varname, PyrexTypes.c_uint_type, node.pos, is_cdef=True)
shapevars.append(var)
entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars,
shapevars)
shapevars, tschecker)
entry.buffer_aux.temp_var = temp_var
self.scope = scope
......@@ -64,13 +147,16 @@ class BufferTransform(CythonTransform):
self.visitchildren(node)
return node
# Notes: The cast to <char*> gets around Cython not supporting const types
acquire_buffer_fragment = TreeFragment(u"""
TMP = LHS
if TMP is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
TMP = RHS
__cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
ASSIGN_AUX
if TMP is not None:
__cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
TSCHECKER(<char*>BUFINFO.format)
ASSIGN_AUX
LHS = TMP
""")
......@@ -82,22 +168,6 @@ class BufferTransform(CythonTransform):
TARGET = BUFINFO.shape[IDX]
""")
def visit_SingleAssignmentNode(self, node):
# On assignments, two buffer-related things can happen:
# a) A buffer variable is assigned to (reacquisition)
# b) Buffer access assignment: arr[...] = ...
# Since we don't allow nested buffers, these don't overlap.
self.visitchildren(node)
# Only acquire buffers on vars (not attributes) for now.
if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
# Is buffer variable
return self.reacquire_buffer(node)
elif (isinstance(node.lhs, IndexNode) and
isinstance(node.lhs.base, NameNode) and
node.lhs.base.entry.buffer_aux is not None):
return self.assign_into_buffer(node)
def reacquire_buffer(self, node):
bufaux = node.lhs.entry.buffer_aux
auxass = []
......@@ -106,7 +176,7 @@ class BufferTransform(CythonTransform):
ass = self.fetch_strides.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx))
u"IDX": IntNode(node.pos, value=EncodedString(idx)),
})
auxass.append(ass)
......@@ -125,7 +195,8 @@ class BufferTransform(CythonTransform):
u"LHS" : node.lhs,
u"RHS": node.rhs,
u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name)
u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name),
u"TSCHECKER": NameNode(node.pos, name=bufaux.tschecker.name)
}, pos=node.pos)
# Note: The below should probably be refactored into something
# like fragment.substitute(..., context=self.context), with
......@@ -165,6 +236,24 @@ class BufferTransform(CythonTransform):
return tmp.stats[0].expr
def visit_SingleAssignmentNode(self, node):
# On assignments, two buffer-related things can happen:
# a) A buffer variable is assigned to (reacquisition)
# b) Buffer access assignment: arr[...] = ...
# Since we don't allow nested buffers, these don't overlap.
self.visitchildren(node)
# Only acquire buffers on vars (not attributes) for now.
if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
# Is buffer variable
return self.reacquire_buffer(node)
elif (isinstance(node.lhs, IndexNode) and
isinstance(node.lhs.base, NameNode) and
node.lhs.base.entry.buffer_aux is not None):
return self.assign_into_buffer(node)
else:
return node
buffer_access = TreeFragment(u"""
(<unsigned char*>(BUF.buf + OFFSET))[0]
""")
......@@ -179,10 +268,3 @@ class BufferTransform(CythonTransform):
else:
return node
def visit_CallNode(self, node):
### print node.dump()
return node
# def visit_FuncDefNode(self, node):
# print node.dump()
......@@ -1251,8 +1251,19 @@ class IndexNode(ExprNode):
#
# base ExprNode
# index ExprNode
# indices [ExprNode]
# is_buffer_access boolean Whether this is a buffer access.
#
# indices is used on buffer access, index on non-buffer access.
# The former contains a clean list of index parameters, the
# latter whatever Python object is needed for index access.
subexprs = ['base', 'index']
subexprs = ['base', 'index', 'indices']
indices = None
def __init__(self, pos, index, *args, **kw):
ExprNode.__init__(self, pos, index=index, *args, **kw)
self._index = index
def compile_time_value(self, denv):
base = self.base.compile_time_value(denv)
......@@ -1273,7 +1284,7 @@ class IndexNode(ExprNode):
def analyse_target_types(self, env):
self.analyse_base_and_index_types(env, setting = 1)
def analyse_base_and_index_types(self, env, getting = 0, setting = 0):
self.is_buffer_access = False
......@@ -1282,19 +1293,20 @@ class IndexNode(ExprNode):
if self.base.type.buffer_options is not None:
if isinstance(self.index, TupleNode):
indices = self.index.args
# is_int_indices = 0 == sum([1 for i in self.index.args if not i.type.is_int])
else:
# is_int_indices = self.index.type.is_int
indices = [self.index]
all_ints = True
for index in indices:
index.analyse_types(env)
if not index.type.is_int:
for x in indices:
x.analyse_types(env)
if not x.type.is_int:
all_ints = False
if all_ints:
# self.indices = [
# x.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env)
# for x in indices]
self.indices = indices
self.index = None
self.type = self.base.type.buffer_options.dtype
self.type = self.base.type.buffer_options.dtype
self.is_temp = 1
self.is_buffer_access = True
......@@ -3935,6 +3947,10 @@ class CoerceToTempNode(CoercionNode):
gil_message = "Creating temporary Python reference"
def analyse_types(self, env):
# The arg is always already analysed
pass
def generate_result_code(self, code):
#self.arg.generate_evaluation_code(code) # Already done
# by generic generate_subexpr_evaluation_code!
......
......@@ -184,10 +184,10 @@ class Node(object):
attrs = [(key, value) for key, value in self.__dict__.iteritems() if key not in filter_out]
if len(attrs) == 0:
return "<%s>" % self.__class__.__name__
return "<%s (%d)>" % (self.__class__.__name__, id(self))
else:
indent = " " * level
res = "<%s\n" % (self.__class__.__name__)
res = "<%s (%d)\n" % (self.__class__.__name__, id(self))
for key, value in attrs:
res += "%s %s: %s\n" % (indent, key, dump_child(value, level + 1))
res += "%s>" % indent
......
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
......@@ -51,12 +50,6 @@ class NormalizeTree(CythonTransform):
else:
return node
def visit_PassStatNode(self, node):
if not self.is_in_statlist:
return StatListNode(pos=node.pos, stats=[])
else:
return []
def visit_StatListNode(self, node):
self.is_in_statlist = True
self.visitchildren(node)
......@@ -72,6 +65,18 @@ class NormalizeTree(CythonTransform):
def visit_CStructOrUnionDefNode(self, node):
return self.visit_StatNode(node, True)
# Eliminate PassStatNode
def visit_PassStatNode(self, node):
if not self.is_in_statlist:
return StatListNode(pos=node.pos, stats=[])
else:
return []
# Eliminate CascadedAssignmentNode
def visit_CascadedAssignmentNode(self, node):
tmpname = temp_name_handle()
class PostParseError(CompileError): pass
......
......@@ -61,6 +61,7 @@ class PyrexType(BaseType):
# default_value string Initial value
# parsetuple_format string Format char for PyArg_ParseTuple
# pymemberdef_typecode string Type code for PyMemberDef struct
# typestring string String char defining the type (see Python struct module)
#
# declaration_code(entity_code,
# for_display = 0, dll_linkage = None, pyrex = 0)
......@@ -416,9 +417,10 @@ class CNumericType(CType):
sign_words = ("unsigned ", "", "signed ")
def __init__(self, rank, signed = 1, pymemberdef_typecode = None):
def __init__(self, rank, signed = 1, pymemberdef_typecode = None, typestring = None):
self.rank = rank
self.signed = signed
self.typestring = typestring
ptf = self.parsetuple_formats[signed][rank]
if ptf == '?':
ptf = None
......@@ -451,8 +453,9 @@ class CIntType(CNumericType):
from_py_function = "__pyx_PyInt_AsLong"
exception_value = -1
def __init__(self, rank, signed, pymemberdef_typecode = None, is_returncode = 0):
CNumericType.__init__(self, rank, signed, pymemberdef_typecode)
def __init__(self, rank, signed, pymemberdef_typecode = None, is_returncode = 0,
typestring=None):
CNumericType.__init__(self, rank, signed, pymemberdef_typecode, typestring=typestring)
self.is_returncode = is_returncode
if self.from_py_function == '__pyx_PyInt_AsLong':
self.from_py_function = self.get_type_conversion()
......@@ -543,8 +546,8 @@ class CFloatType(CNumericType):
to_py_function = "PyFloat_FromDouble"
from_py_function = "__pyx_PyFloat_AsDouble"
def __init__(self, rank, pymemberdef_typecode = None):
CNumericType.__init__(self, rank, 1, pymemberdef_typecode)
def __init__(self, rank, pymemberdef_typecode = None, typestring=None):
CNumericType.__init__(self, rank, 1, pymemberdef_typecode, typestring = typestring)
def assignable_from_resolved_type(self, src_type):
return src_type.is_numeric or src_type is error_type
......@@ -852,9 +855,12 @@ class CFuncTypeArg:
# type PyrexType
# pos source file position
def __init__(self, name, type, pos):
def __init__(self, name, type, pos, cname=None):
self.name = name
self.cname = Naming.var_prefix + name
if cname is not None:
self.cname = cname
else:
self.cname = Naming.var_prefix + name
self.type = type
self.pos = pos
self.not_none = False
......@@ -1050,29 +1056,29 @@ c_void_type = CVoidType()
c_void_ptr_type = CPtrType(c_void_type)
c_void_ptr_ptr_type = CPtrType(c_void_ptr_type)
c_uchar_type = CIntType(0, 0, "T_UBYTE")
c_ushort_type = CIntType(1, 0, "T_USHORT")
c_uint_type = CUIntType(2, 0, "T_UINT")
c_ulong_type = CULongType(3, 0, "T_ULONG")
c_ulonglong_type = CULongLongType(4, 0, "T_ULONGLONG")
c_char_type = CIntType(0, 1, "T_CHAR")
c_short_type = CIntType(1, 1, "T_SHORT")
c_int_type = CIntType(2, 1, "T_INT")
c_long_type = CIntType(3, 1, "T_LONG")
c_longlong_type = CLongLongType(4, 1, "T_LONGLONG")
c_uchar_type = CIntType(0, 0, "T_UBYTE", typestring="B")
c_ushort_type = CIntType(1, 0, "T_USHORT", typestring="H")
c_uint_type = CUIntType(2, 0, "T_UINT", typestring="I")
c_ulong_type = CULongType(3, 0, "T_ULONG", typestring="L")
c_ulonglong_type = CULongLongType(4, 0, "T_ULONGLONG", typestring="Q")
c_char_type = CIntType(0, 1, "T_CHAR", typestring="b")
c_short_type = CIntType(1, 1, "T_SHORT", typestring="h")
c_int_type = CIntType(2, 1, "T_INT", typestring="i")
c_long_type = CIntType(3, 1, "T_LONG", typestring="l")
c_longlong_type = CLongLongType(4, 1, "T_LONGLONG", typestring="q")
c_py_ssize_t_type = CPySSizeTType(5, 1)
c_bint_type = CBIntType(2, 1, "T_INT")
c_bint_type = CBIntType(2, 1, "T_INT", typestring="i")
c_schar_type = CIntType(0, 2, "T_CHAR")
c_sshort_type = CIntType(1, 2, "T_SHORT")
c_sint_type = CIntType(2, 2, "T_INT")
c_slong_type = CIntType(3, 2, "T_LONG")
c_slonglong_type = CLongLongType(4, 2, "T_LONGLONG")
c_schar_type = CIntType(0, 2, "T_CHAR", typestring="b")
c_sshort_type = CIntType(1, 2, "T_SHORT", typestring="h")
c_sint_type = CIntType(2, 2, "T_INT", typestring="i")
c_slong_type = CIntType(3, 2, "T_LONG", typestring="l")
c_slonglong_type = CLongLongType(4, 2, "T_LONGLONG", typestring="q")
c_float_type = CFloatType(6, "T_FLOAT")
c_double_type = CFloatType(7, "T_DOUBLE")
c_longdouble_type = CFloatType(8)
c_float_type = CFloatType(6, "T_FLOAT", typestring="f")
c_double_type = CFloatType(7, "T_DOUBLE", typestring="d")
c_longdouble_type = CFloatType(8, typestring="g")
c_null_ptr_type = CNullPtrType(c_void_type)
c_char_array_type = CCharArrayType(None)
......
......@@ -20,10 +20,12 @@ possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
class BufferAux:
def __init__(self, buffer_info_var, stridevars, shapevars):
def __init__(self, buffer_info_var, stridevars, shapevars, tschecker):
self.buffer_info_var = buffer_info_var
self.stridevars = stridevars
self.shapevars = shapevars
self.tschecker = tschecker
def __repr__(self):
return "<BufferAux %r>" % self.__dict__
......
......@@ -181,10 +181,14 @@ def replace_node(ptr, value):
getattr(parent, attrname)[listidx] = value
tmpnamectr = 0
def temp_name_handle(description):
def temp_name_handle(description=None):
global tmpnamectr
tmpnamectr += 1
return EncodedString(Naming.temp_prefix + u"%d_%s" % (tmpnamectr, description))
if description is not None:
name = u"%d_%s" % (tmpnamectr, description)
else:
name = u"%d" % tmpnamectr
return EncodedString(Naming.temp_prefix + name)
def get_temp_name_handle_desc(handle):
if not handle.startswith(u"__cyt_"):
......@@ -198,7 +202,7 @@ class PrintTree(TreeVisitor):
Subclass and override repr_of to provide more information
about nodes. """
def __init__(self):
Transform.__init__(self)
TreeVisitor.__init__(self)
self._indent = ""
def indent(self):
......@@ -208,6 +212,7 @@ class PrintTree(TreeVisitor):
def __call__(self, tree, phase=None):
print("Parse tree dump at phase '%s'" % phase)
self.visit(tree)
# Don't do anything about process_list, the defaults gives
# nice-looking name[idx] nodes which will visually appear
......
......@@ -19,5 +19,10 @@ cdef extern from "Python.h":
int PyObject_GetBuffer(PyObject* obj, Py_buffer* view, int flags) except -1
void PyObject_ReleaseBuffer(PyObject* obj, Py_buffer* view)
void PyErr_Format(int, char*, ...)
enum:
PyExc_TypeError
# int PyObject_GetBuffer(PyObject *obj, Py_buffer *view,
# int flags)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment