Commit f952904c authored by Stefan Behnel's avatar Stefan Behnel

merged in quick fix by Dag

parents 01a5a332 07d40c12
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Utils import EncodedString
from Cython.Compiler.Errors import CompileError
import PyrexTypes
from sets import Set as set
class PureCFuncNode(Node):
def __init__(self, pos, cname, type, c_code, visibility='private'):
self.pos = pos
self.cname = cname
self.type = type
self.c_code = c_code
self.visibility = visibility
def analyse_types(self, env):
self.entry = env.declare_cfunction(
"<pure c function:%s>" % self.cname,
self.type, self.pos, cname=self.cname,
defining=True, visibility=self.visibility)
def generate_function_definitions(self, env, code, transforms):
assert self.type.optional_arg_count == 0
visibility = self.entry.visibility
if visibility != 'private':
storage_class = "%s " % Naming.extern_c_macro
else:
storage_class = "static "
arg_decls = [arg.declaration_code() for arg in self.type.args]
sig = self.type.return_type.declaration_code(
self.type.function_header_code(self.cname, ", ".join(arg_decls)))
code.putln("")
code.putln("%s%s {" % (storage_class, sig))
code.put(self.c_code)
code.putln("}")
def generate_execution_code(self, code):
pass
tschecker_functype = PyrexTypes.CFuncType(
PyrexTypes.c_char_ptr_type,
[PyrexTypes.CFuncTypeArg(EncodedString("ts"), PyrexTypes.c_char_ptr_type,
(0, 0, None), cname="ts")],
exception_value = "NULL"
)
tsprefix = "__Pyx_tsc"
class BufferTransform(CythonTransform):
"""
Run after type analysis. Takes care of the buffer functionality.
Expects to be run on the full module. If you need to process a fragment
one should look into refactoring this transform.
"""
# Abbreviations:
# "ts" means typestring and/or typestring checking stuff
scope = None
#
# Entry point
#
def __call__(self, node):
assert isinstance(node, ModuleNode)
try:
cymod = self.context.modules[u'__cython__']
except KeyError:
# No buffer fun for this module
return node
self.bufstruct_type = cymod.entries[u'Py_buffer'].type
self.tscheckers = {}
self.ts_funcs = []
self.ts_item_checkers = {}
self.module_scope = node.scope
self.module_pos = node.pos
result = super(BufferTransform, self).__call__(node)
# Register ts stuff
if "endian.h" not in node.scope.include_files:
node.scope.include_files.append("endian.h")
result.body.stats += self.ts_funcs
return result
#
# Basic operations for transforms
#
def handle_scope(self, node, scope):
# For all buffers, insert extra variables in the scope.
# The variables are also accessible from the buffer_info
# on the buffer entry
bufvars = [(name, entry) for name, entry
in scope.entries.iteritems()
if entry.type.buffer_options is not None]
for name, entry in bufvars:
bufopts = entry.type.buffer_options
# Get or make a type string checker
tschecker = self.tschecker(bufopts.dtype)
# Declare auxiliary vars
bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name),
self.bufstruct_type, node.pos)
temp_var = scope.declare_var(temp_name_handle(u"%s_tmp" % name),
entry.type, node.pos)
stridevars = []
shapevars = []
for idx in range(bufopts.ndim):
# stride
varname = temp_name_handle(u"%s_%s%d" % (name, "stride", idx))
var = scope.declare_var(varname, PyrexTypes.c_int_type, node.pos, is_cdef=True)
stridevars.append(var)
# shape
varname = temp_name_handle(u"%s_%s%d" % (name, "shape", idx))
var = scope.declare_var(varname, PyrexTypes.c_uint_type, node.pos, is_cdef=True)
shapevars.append(var)
entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars,
shapevars, tschecker)
entry.buffer_aux.temp_var = temp_var
self.scope = scope
# Notes: The cast to <char*> gets around Cython not supporting const types
acquire_buffer_fragment = TreeFragment(u"""
TMP = LHS
if TMP is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
TMP = RHS
if TMP is not None:
__cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
TSCHECKER(<char*>BUFINFO.format)
ASSIGN_AUX
LHS = TMP
""")
fetch_strides = TreeFragment(u"""
TARGET = BUFINFO.strides[IDX]
""")
fetch_shape = TreeFragment(u"""
TARGET = BUFINFO.shape[IDX]
""")
def reacquire_buffer(self, node):
bufaux = node.lhs.entry.buffer_aux
auxass = []
for idx, entry in enumerate(bufaux.stridevars):
entry.used = True
ass = self.fetch_strides.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx)),
})
auxass.append(ass)
for idx, entry in enumerate(bufaux.shapevars):
entry.used = True
ass = self.fetch_shape.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx))
})
auxass.append(ass)
bufaux.buffer_info_var.used = True
acq = self.acquire_buffer_fragment.substitute({
u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
u"LHS" : node.lhs,
u"RHS": node.rhs,
u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name),
u"TSCHECKER": NameNode(node.pos, name=bufaux.tschecker.name)
}, pos=node.pos)
# Note: The below should probably be refactored into something
# like fragment.substitute(..., context=self.context), with
# TreeFragment getting context.pipeline_until_now() and
# applying it on the fragment.
acq.analyse_declarations(self.scope)
acq.analyse_expressions(self.scope)
stats = acq.stats
return stats
def assign_into_buffer(self, node):
result = SingleAssignmentNode(node.pos,
rhs=self.visit(node.rhs),
lhs=self.buffer_index(node.lhs))
result.analyse_expressions(self.scope)
return result
def buffer_index(self, node):
pos = node.pos
bufaux = node.base.entry.buffer_aux
assert bufaux is not None
# indices * strides...
to_sum = [ IntBinopNode(pos, operator='*',
operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index),
operand2=NameNode(node.pos, name=stride.name))
for index, stride in zip(node.indices, bufaux.stridevars)]
# then sum them with the buffer pointer
expr = AttributeNode(pos,
obj=NameNode(pos, name=bufaux.buffer_info_var.name),
attribute=EncodedString("buf"))
for next in to_sum:
expr = AddNode(pos, operator='+', operand1=expr, operand2=next)
casted = TypecastNode(pos, operand=expr,
type=PyrexTypes.c_ptr_type(node.base.entry.type.buffer_options.dtype))
result = IndexNode(pos, base=casted, index=IntNode(pos, value='0'))
return result
#
# Transforms
#
def visit_ModuleNode(self, node):
self.handle_scope(node, node.scope)
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
self.handle_scope(node, node.local_scope)
self.visitchildren(node)
return node
def visit_SingleAssignmentNode(self, node):
# On assignments, two buffer-related things can happen:
# a) A buffer variable is assigned to (reacquisition)
# b) Buffer access assignment: arr[...] = ...
# Since we don't allow nested buffers, these don't overlap.
self.visitchildren(node)
# Only acquire buffers on vars (not attributes) for now.
if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
# Is buffer variable
return self.reacquire_buffer(node)
elif (isinstance(node.lhs, IndexNode) and
isinstance(node.lhs.base, NameNode) and
node.lhs.base.entry.buffer_aux is not None):
return self.assign_into_buffer(node)
else:
return node
def visit_IndexNode(self, node):
# Only occurs when the IndexNode is an rvalue
if node.is_buffer_access:
assert node.index is None
assert node.indices is not None
result = self.buffer_index(node)
result.analyse_expressions(self.scope)
return result
else:
return node
#
# Utils for creating type string checkers
#
def new_ts_func(self, name, code):
cname = "%s_%s" % (tsprefix, name)
funcnode = PureCFuncNode(self.module_pos, cname, tschecker_functype, code)
funcnode.analyse_types(self.module_scope)
self.ts_funcs.append(funcnode)
return funcnode
def mangle_dtype_name(self, dtype):
# Use prefixes to seperate user defined types from builtins
# (consider "typedef float unsigned_int")
return dtype.declaration_code("").replace(" ", "_")
def get_ts_check_item(self, dtype):
# See if we can consume one (unnamed) dtype as next item
funcnode = self.ts_item_checkers.get(dtype)
if funcnode is None:
char = dtype.typestring
if char is not None and len(char) > 1:
# Can use direct comparison
funcnode = self.new_ts_func("natitem_%s" % self.mangle_dtype_name(dtype), """\
if (*ts != '%s') {
PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch (rejecting on '%%s')", ts);
return NULL;
} else return ts + 1;
""" % char)
else:
# Must deduce sign and length; rely on int vs. float to be correctly declared
ctype = dtype.declaration_code("")
code = """\
int ok;
switch (*ts) {"""
if dtype.is_int:
types = [
('b', 'char'), ('h', 'short'), ('i', 'int'),
('l', 'long'), ('q', 'long long')
]
code += "".join(["""\
case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;
case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;
""" % (char, ctype, against, ctype, char.upper(), ctype, "unsigned " + against, ctype) for
char, against in types])
code += """\
default: ok = 0;
}
if (!ok) {
PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch (rejecting on '%s')", ts);
return NULL;
} else return ts + 1;
"""
funcnode = self.new_ts_func("tdefitem_%s" % self.mangle_dtype_name(dtype), code)
self.ts_item_checkers[dtype] = funcnode
return funcnode.entry.cname
ts_consume_whitespace_cname = None
ts_check_endian_cname = None
def ensure_ts_utils(self):
# Makes sure that the typechecker utils are in scope
# (and constructs them if not)
if self.ts_consume_whitespace_cname is None:
self.ts_consume_whitespace_cname = self.new_ts_func("consume_whitespace", """\
while (1) {
switch (*ts) {
case 10:
case 13:
case ' ':
++ts;
default:
return ts;
}
}
""").entry.cname
if self.ts_check_endian_cname is None:
self.ts_check_endian_cname = self.new_ts_func("check_endian", """\
int ok = 1;
switch (*ts) {
case '@':
case '=':
++ts; break;
case '<':
if (__BYTE_ORDER == __LITTLE_ENDIAN) ++ts;
else ok = 0;
break;
case '>':
case '!':
if (__BYTE_ORDER == __BIG_ENDIAN) ++ts;
else ok = 0;
break;
}
if (!ok) {
PyErr_Format(PyExc_TypeError, "Data has wrong endianness (rejecting on '%s')", ts);
return NULL;
}
return ts;
""").entry.cname
def create_ts_check_simple(self, dtype):
# Check whole string for single unnamed item
consume_whitespace = self.ts_consume_whitespace_cname
check_endian = self.ts_check_endian_cname
check_item = self.get_ts_check_item(dtype)
return self.new_ts_func("simple_%s" % self.mangle_dtype_name(dtype), """\
ts = %(consume_whitespace)s(ts);
ts = %(check_endian)s(ts);
if (!ts) return NULL;
ts = %(consume_whitespace)s(ts);
ts = %(check_item)s(ts);
if (!ts) return NULL;
ts = %(consume_whitespace)s(ts);
if (*ts != 0) {
PyErr_Format(PyExc_TypeError, "Data too long (rejecting on '%%s')", ts);
return NULL;
}
return ts;
""" % locals())
def tschecker(self, dtype):
# Creates a type string checker function for the given type.
# Each checker is created as a function entry in the module scope
# and a PureCNode and put in the self.ts_checkers dict.
# Also the entry is returned.
#
# TODO: __eq__ and __hash__ for types
self.ensure_ts_utils()
funcnode = self.tscheckers.get(dtype)
if funcnode is None:
if dtype.is_struct_or_union:
assert False
elif dtype.is_int or dtype.is_float:
# This includes simple typedef-ed types
funcnode = self.create_ts_check_simple(dtype)
else:
assert False
self.tscheckers[dtype] = funcnode
return funcnode.entry
# TODO:
# - buf must be NULL before getting new buffer
...@@ -1251,8 +1251,19 @@ class IndexNode(ExprNode): ...@@ -1251,8 +1251,19 @@ class IndexNode(ExprNode):
# #
# base ExprNode # base ExprNode
# index ExprNode # index ExprNode
# indices [ExprNode]
# is_buffer_access boolean Whether this is a buffer access.
#
# indices is used on buffer access, index on non-buffer access.
# The former contains a clean list of index parameters, the
# latter whatever Python object is needed for index access.
subexprs = ['base', 'index'] subexprs = ['base', 'index', 'indices']
indices = None
def __init__(self, pos, index, *args, **kw):
ExprNode.__init__(self, pos, index=index, *args, **kw)
self._index = index
def compile_time_value(self, denv): def compile_time_value(self, denv):
base = self.base.compile_time_value(denv) base = self.base.compile_time_value(denv)
...@@ -1273,33 +1284,45 @@ class IndexNode(ExprNode): ...@@ -1273,33 +1284,45 @@ class IndexNode(ExprNode):
def analyse_target_types(self, env): def analyse_target_types(self, env):
self.analyse_base_and_index_types(env, setting = 1) self.analyse_base_and_index_types(env, setting = 1)
def analyse_base_and_index_types(self, env, getting = 0, setting = 0): def analyse_base_and_index_types(self, env, getting = 0, setting = 0):
self.is_buffer_access = False self.is_buffer_access = False
self.base.analyse_types(env) self.base.analyse_types(env)
skip_child_analysis = False
buffer_access = False
if self.base.type.buffer_options is not None: if self.base.type.buffer_options is not None:
if isinstance(self.index, TupleNode): if isinstance(self.index, TupleNode):
indices = self.index.args indices = self.index.args
# is_int_indices = 0 == sum([1 for i in self.index.args if not i.type.is_int])
else: else:
# is_int_indices = self.index.type.is_int
indices = [self.index] indices = [self.index]
all_ints = True if len(indices) == self.base.type.buffer_options.ndim:
for index in indices: buffer_access = True
index.analyse_types(env) skip_child_analysis = True
if not index.type.is_int: for x in indices:
all_ints = False x.analyse_types(env)
if all_ints: if not x.type.is_int:
self.indices = indices buffer_access = False
self.index = None if buffer_access:
self.type = self.base.type.buffer_options.dtype # self.indices = [
self.is_temp = 1 # x.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env)
self.is_buffer_access = True # for x in indices]
self.indices = indices
self.index = None
self.type = self.base.type.buffer_options.dtype
self.is_temp = 1
self.is_buffer_access = True
# Note: This might be cleaned up by having IndexNode
# parsed in a saner way and only construct the tuple if
# needed.
if not self.is_buffer_access: if not buffer_access:
self.index.analyse_types(env) # ok to analyse as tuple if isinstance(self.index, TupleNode):
self.index.analyse_types(env, skip_children=skip_child_analysis)
elif not skip_child_analysis:
self.index.analyse_types(env)
if self.base.type.is_pyobject: if self.base.type.is_pyobject:
if self.index.type.is_int: if self.index.type.is_int:
self.original_index_type = self.index.type self.original_index_type = self.index.type
...@@ -1339,8 +1362,11 @@ class IndexNode(ExprNode): ...@@ -1339,8 +1362,11 @@ class IndexNode(ExprNode):
return 1 return 1
def calculate_result_code(self): def calculate_result_code(self):
return "(%s[%s])" % ( if self.is_buffer_access:
self.base.result_code, self.index.result_code) return "<not needed>"
else:
return "(%s[%s])" % (
self.base.result_code, self.index.result_code)
def index_unsigned_parameter(self): def index_unsigned_parameter(self):
if self.index.type.is_int: if self.index.type.is_int:
...@@ -2197,10 +2223,10 @@ class SequenceNode(ExprNode): ...@@ -2197,10 +2223,10 @@ class SequenceNode(ExprNode):
for arg in self.args: for arg in self.args:
arg.analyse_target_declaration(env) arg.analyse_target_declaration(env)
def analyse_types(self, env): def analyse_types(self, env, skip_children=False):
for i in range(len(self.args)): for i in range(len(self.args)):
arg = self.args[i] arg = self.args[i]
arg.analyse_types(env) if not skip_children: arg.analyse_types(env)
self.args[i] = arg.coerce_to_pyobject(env) self.args[i] = arg.coerce_to_pyobject(env)
self.type = py_object_type self.type = py_object_type
self.gil_check(env) self.gil_check(env)
...@@ -2305,12 +2331,12 @@ class TupleNode(SequenceNode): ...@@ -2305,12 +2331,12 @@ class TupleNode(SequenceNode):
gil_message = "Constructing Python tuple" gil_message = "Constructing Python tuple"
def analyse_types(self, env): def analyse_types(self, env, skip_children=False):
if len(self.args) == 0: if len(self.args) == 0:
self.is_temp = 0 self.is_temp = 0
self.is_literal = 1 self.is_literal = 1
else: else:
SequenceNode.analyse_types(self, env) SequenceNode.analyse_types(self, env, skip_children)
self.type = tuple_type self.type = tuple_type
def calculate_result_code(self): def calculate_result_code(self):
...@@ -2813,15 +2839,20 @@ def unop_node(pos, operator, operand): ...@@ -2813,15 +2839,20 @@ def unop_node(pos, operator, operand):
class TypecastNode(ExprNode): class TypecastNode(ExprNode):
# C type cast # C type cast
# #
# operand ExprNode
# base_type CBaseTypeNode # base_type CBaseTypeNode
# declarator CDeclaratorNode # declarator CDeclaratorNode
# operand ExprNode #
# If used from a transform, one can if wanted specify the attribute
# "type" directly and leave base_type and declarator to None
subexprs = ['operand'] subexprs = ['operand']
base_type = declarator = type = None
def analyse_types(self, env): def analyse_types(self, env):
base_type = self.base_type.analyse(env) if self.type is None:
_, self.type = self.declarator.analyse(base_type, env) base_type = self.base_type.analyse(env)
_, self.type = self.declarator.analyse(base_type, env)
if self.type.is_cfunction: if self.type.is_cfunction:
error(self.pos, error(self.pos,
"Cannot cast to a function type") "Cannot cast to a function type")
...@@ -3842,6 +3873,10 @@ class CoerceToPyTypeNode(CoercionNode): ...@@ -3842,6 +3873,10 @@ class CoerceToPyTypeNode(CoercionNode):
gil_message = "Converting to Python object" gil_message = "Converting to Python object"
def analyse_types(self, env):
# The arg is always already analysed
pass
def generate_result_code(self, code): def generate_result_code(self, code):
function = self.arg.type.to_py_function function = self.arg.type.to_py_function
code.putln('%s = %s(%s); %s' % ( code.putln('%s = %s(%s); %s' % (
...@@ -3866,6 +3901,10 @@ class CoerceFromPyTypeNode(CoercionNode): ...@@ -3866,6 +3901,10 @@ class CoerceFromPyTypeNode(CoercionNode):
error(arg.pos, error(arg.pos,
"Obtaining char * from temporary Python value") "Obtaining char * from temporary Python value")
def analyse_types(self, env):
# The arg is always already analysed
pass
def generate_result_code(self, code): def generate_result_code(self, code):
function = self.type.from_py_function function = self.type.from_py_function
operand = self.arg.py_result() operand = self.arg.py_result()
...@@ -3924,6 +3963,10 @@ class CoerceToTempNode(CoercionNode): ...@@ -3924,6 +3963,10 @@ class CoerceToTempNode(CoercionNode):
gil_message = "Creating temporary Python reference" gil_message = "Creating temporary Python reference"
def analyse_types(self, env):
# The arg is always already analysed
pass
def generate_result_code(self, code): def generate_result_code(self, code):
#self.arg.generate_evaluation_code(code) # Already done #self.arg.generate_evaluation_code(code) # Already done
# by generic generate_subexpr_evaluation_code! # by generic generate_subexpr_evaluation_code!
......
...@@ -354,9 +354,10 @@ def create_generate_code(context, options, result): ...@@ -354,9 +354,10 @@ def create_generate_code(context, options, result):
return generate_code return generate_code
def create_default_pipeline(context, options, result): def create_default_pipeline(context, options, result):
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, BufferTransform from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor
from Buffer import BufferTransform
from ModuleNode import check_c_classes from ModuleNode import check_c_classes
return [ return [
......
...@@ -1953,53 +1953,81 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1953,53 +1953,81 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def generate_buffer_compatability_functions(self, env, code): def generate_buffer_compatability_functions(self, env, code):
# will be refactored # will be refactored
code.put(""" try:
env.entries[u'numpy']
code.put("""
static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) { static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
/* This function is always called after a type-check */ /* This function is always called after a type-check; safe to cast */
PyArrayObject *arr = (PyArrayObject*)obj; PyArrayObject *arr = (PyArrayObject*)obj;
PyArray_Descr *type = (PyArray_Descr*)arr->descr; PyArray_Descr *type = (PyArray_Descr*)arr->descr;
int typenum = PyArray_TYPE(obj);
if (!PyTypeNum_ISNUMBER(typenum)) {
PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
return -1;
}
/*
NumPy format codes doesn't completely match buffer codes;
seems safest to retranslate.
01234567890123456789012345*/
const char* base_codes = "?bBhHiIlLqQfdgfdgO";
char* format = (char*)malloc(4);
char* fp = format;
*fp++ = type->byteorder;
if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
*fp++ = base_codes[typenum];
*fp = 0;
view->buf = arr->data; view->buf = arr->data;
view->readonly = 0; /*fixme*/ view->readonly = !PyArray_ISWRITEABLE(obj);
view->format = "B"; /*fixme*/ view->ndim = PyArray_NDIM(arr);
view->ndim = arr->nd; view->strides = PyArray_STRIDES(arr);
view->strides = arr->strides; view->shape = PyArray_DIMS(arr);
view->shape = arr->dimensions; view->suboffsets = NULL;
view->suboffsets = 0; view->format = format;
view->itemsize = type->elsize; view->itemsize = type->elsize;
view->internal = 0; view->internal = 0;
return 0; return 0;
} }
static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) { static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
free((char*)view->format);
view->format = NULL;
} }
""") """)
except KeyError:
pass
# For now, hard-code numpy imported as "numpy" # For now, hard-code numpy imported as "numpy"
ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type types = []
types = [ try:
(ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer") ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
] types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
except KeyError:
# typeptr_cname = ndarrtype.typeptr_cname pass
code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {") code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
clause = "if" if len(types) > 0:
for t, get, release in types: clause = "if"
code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get)) for t, get, release in types:
clause = "else if" code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
code.putln("else {") clause = "else if"
code.putln("else {")
code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);") code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
code.putln("return -1;") code.putln("return -1;")
code.putln("}") if len(types) > 0: code.putln("}")
code.putln("}") code.putln("}")
code.putln("") code.putln("")
code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {") code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
clause = "if" if len(types) > 0:
for t, get, release in types: clause = "if"
code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release)) for t, get, release in types:
clause = "else if" code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release))
clause = "else if"
code.putln("}") code.putln("}")
code.putln("") code.putln("")
......
...@@ -184,10 +184,10 @@ class Node(object): ...@@ -184,10 +184,10 @@ class Node(object):
attrs = [(key, value) for key, value in self.__dict__.iteritems() if key not in filter_out] attrs = [(key, value) for key, value in self.__dict__.iteritems() if key not in filter_out]
if len(attrs) == 0: if len(attrs) == 0:
return "<%s>" % self.__class__.__name__ return "<%s (%d)>" % (self.__class__.__name__, id(self))
else: else:
indent = " " * level indent = " " * level
res = "<%s\n" % (self.__class__.__name__) res = "<%s (%d)\n" % (self.__class__.__name__, id(self))
for key, value in attrs: for key, value in attrs:
res += "%s %s: %s\n" % (indent, key, dump_child(value, level + 1)) res += "%s %s: %s\n" % (indent, key, dump_child(value, level + 1))
res += "%s>" % indent res += "%s>" % indent
......
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
...@@ -51,12 +50,6 @@ class NormalizeTree(CythonTransform): ...@@ -51,12 +50,6 @@ class NormalizeTree(CythonTransform):
else: else:
return node return node
def visit_PassStatNode(self, node):
if not self.is_in_statlist:
return StatListNode(pos=node.pos, stats=[])
else:
return []
def visit_StatListNode(self, node): def visit_StatListNode(self, node):
self.is_in_statlist = True self.is_in_statlist = True
self.visitchildren(node) self.visitchildren(node)
...@@ -72,6 +65,18 @@ class NormalizeTree(CythonTransform): ...@@ -72,6 +65,18 @@ class NormalizeTree(CythonTransform):
def visit_CStructOrUnionDefNode(self, node): def visit_CStructOrUnionDefNode(self, node):
return self.visit_StatNode(node, True) return self.visit_StatNode(node, True)
# Eliminate PassStatNode
def visit_PassStatNode(self, node):
if not self.is_in_statlist:
return StatListNode(pos=node.pos, stats=[])
else:
return []
# Eliminate CascadedAssignmentNode
def visit_CascadedAssignmentNode(self, node):
tmpname = temp_name_handle()
class PostParseError(CompileError): pass class PostParseError(CompileError): pass
...@@ -146,169 +151,6 @@ class PostParse(CythonTransform): ...@@ -146,169 +151,6 @@ class PostParse(CythonTransform):
node.keyword_args = None node.keyword_args = None
return node return node
class BufferTransform(CythonTransform):
"""
Run after type analysis. Takes care of the buffer functionality.
"""
scope = None
def __call__(self, node):
cymod = self.context.modules[u'__cython__']
self.buffer_type = cymod.entries[u'Py_buffer'].type
return super(BufferTransform, self).__call__(node)
def handle_scope(self, node, scope):
# For all buffers, insert extra variables in the scope.
# The variables are also accessible from the buffer_info
# on the buffer entry
bufvars = [(name, entry) for name, entry
in scope.entries.iteritems()
if entry.type.buffer_options is not None]
for name, entry in bufvars:
# Variable has buffer opts, declare auxiliary vars
bufopts = entry.type.buffer_options
bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name),
self.buffer_type, node.pos)
temp_var = scope.declare_var(temp_name_handle(u"%s_tmp" % name),
entry.type, node.pos)
stridevars = []
shapevars = []
for idx in range(bufopts.ndim):
# stride
varname = temp_name_handle(u"%s_%s%d" % (name, "stride", idx))
var = scope.declare_var(varname, PyrexTypes.c_int_type, node.pos, is_cdef=True)
stridevars.append(var)
# shape
varname = temp_name_handle(u"%s_%s%d" % (name, "shape", idx))
var = scope.declare_var(varname, PyrexTypes.c_uint_type, node.pos, is_cdef=True)
shapevars.append(var)
entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars,
shapevars)
entry.buffer_aux.temp_var = temp_var
self.scope = scope
def visit_ModuleNode(self, node):
self.handle_scope(node, node.scope)
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
self.handle_scope(node, node.local_scope)
self.visitchildren(node)
return node
acquire_buffer_fragment = TreeFragment(u"""
TMP = LHS
if TMP is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
TMP = RHS
__cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
ASSIGN_AUX
LHS = TMP
""")
fetch_strides = TreeFragment(u"""
TARGET = BUFINFO.strides[IDX]
""")
fetch_shape = TreeFragment(u"""
TARGET = BUFINFO.shape[IDX]
""")
# ass = SingleAssignmentNode(pos=node.pos,
# lhs=NameNode(node.pos, name=entry.name),
# rhs=IndexNode(node.pos,
# base=AttributeNode(node.pos,
# obj=NameNode(node.pos, name=bufaux.buffer_info_var.name),
# attribute=EncodedString("strides")),
# index=IntNode(node.pos, value=EncodedString(idx))))
# print ass.dump()
def visit_SingleAssignmentNode(self, node):
self.visitchildren(node)
bufaux = node.lhs.entry.buffer_aux
if bufaux is not None:
auxass = []
for idx, entry in enumerate(bufaux.stridevars):
entry.used = True
ass = self.fetch_strides.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx))
})
auxass.append(ass)
for idx, entry in enumerate(bufaux.shapevars):
entry.used = True
ass = self.fetch_shape.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx))
})
auxass.append(ass)
bufaux.buffer_info_var.used = True
acq = self.acquire_buffer_fragment.substitute({
u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
u"LHS" : node.lhs,
u"RHS": node.rhs,
u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name)
}, pos=node.pos)
# Note: The below should probably be refactored into something
# like fragment.substitute(..., context=self.context), with
# TreeFragment getting context.pipeline_until_now() and
# applying it on the fragment.
acq.analyse_declarations(self.scope)
acq.analyse_expressions(self.scope)
stats = acq.stats
# stats += [node] # Do assignment after successful buffer acquisition
# print acq.dump()
return stats
else:
return node
buffer_access = TreeFragment(u"""
(<unsigned char*>(BUF.buf + OFFSET))[0]
""")
def visit_IndexNode(self, node):
if node.is_buffer_access:
assert node.index is None
assert node.indices is not None
bufaux = node.base.entry.buffer_aux
assert bufaux is not None
to_sum = [ IntBinopNode(node.pos, operator='*', operand1=index,
operand2=NameNode(node.pos, name=stride.name))
for index, stride in zip(node.indices, bufaux.stridevars)]
print to_sum
indices = node.indices
# reduce * on indices
expr = to_sum[0]
for next in to_sum[1:]:
expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next)
tmp= self.buffer_access.substitute({
'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name),
'OFFSET': expr
})
tmp.analyse_expressions(self.scope)
return tmp.stats[0].expr
else:
return node
def visit_CallNode(self, node):
### print node.dump()
return node
# def visit_FuncDefNode(self, node):
# print node.dump()
class WithTransform(CythonTransform): class WithTransform(CythonTransform):
# EXCINFO is manually set to a variable that contains # EXCINFO is manually set to a variable that contains
......
...@@ -61,6 +61,7 @@ class PyrexType(BaseType): ...@@ -61,6 +61,7 @@ class PyrexType(BaseType):
# default_value string Initial value # default_value string Initial value
# parsetuple_format string Format char for PyArg_ParseTuple # parsetuple_format string Format char for PyArg_ParseTuple
# pymemberdef_typecode string Type code for PyMemberDef struct # pymemberdef_typecode string Type code for PyMemberDef struct
# typestring string String char defining the type (see Python struct module)
# #
# declaration_code(entity_code, # declaration_code(entity_code,
# for_display = 0, dll_linkage = None, pyrex = 0) # for_display = 0, dll_linkage = None, pyrex = 0)
...@@ -416,9 +417,10 @@ class CNumericType(CType): ...@@ -416,9 +417,10 @@ class CNumericType(CType):
sign_words = ("unsigned ", "", "signed ") sign_words = ("unsigned ", "", "signed ")
def __init__(self, rank, signed = 1, pymemberdef_typecode = None): def __init__(self, rank, signed = 1, pymemberdef_typecode = None, typestring = None):
self.rank = rank self.rank = rank
self.signed = signed self.signed = signed
self.typestring = typestring
ptf = self.parsetuple_formats[signed][rank] ptf = self.parsetuple_formats[signed][rank]
if ptf == '?': if ptf == '?':
ptf = None ptf = None
...@@ -451,8 +453,9 @@ class CIntType(CNumericType): ...@@ -451,8 +453,9 @@ class CIntType(CNumericType):
from_py_function = "__pyx_PyInt_AsLong" from_py_function = "__pyx_PyInt_AsLong"
exception_value = -1 exception_value = -1
def __init__(self, rank, signed, pymemberdef_typecode = None, is_returncode = 0): def __init__(self, rank, signed, pymemberdef_typecode = None, is_returncode = 0,
CNumericType.__init__(self, rank, signed, pymemberdef_typecode) typestring=None):
CNumericType.__init__(self, rank, signed, pymemberdef_typecode, typestring=typestring)
self.is_returncode = is_returncode self.is_returncode = is_returncode
if self.from_py_function == '__pyx_PyInt_AsLong': if self.from_py_function == '__pyx_PyInt_AsLong':
self.from_py_function = self.get_type_conversion() self.from_py_function = self.get_type_conversion()
...@@ -543,8 +546,8 @@ class CFloatType(CNumericType): ...@@ -543,8 +546,8 @@ class CFloatType(CNumericType):
to_py_function = "PyFloat_FromDouble" to_py_function = "PyFloat_FromDouble"
from_py_function = "__pyx_PyFloat_AsDouble" from_py_function = "__pyx_PyFloat_AsDouble"
def __init__(self, rank, pymemberdef_typecode = None): def __init__(self, rank, pymemberdef_typecode = None, typestring=None):
CNumericType.__init__(self, rank, 1, pymemberdef_typecode) CNumericType.__init__(self, rank, 1, pymemberdef_typecode, typestring = typestring)
def assignable_from_resolved_type(self, src_type): def assignable_from_resolved_type(self, src_type):
return src_type.is_numeric or src_type is error_type return src_type.is_numeric or src_type is error_type
...@@ -852,9 +855,12 @@ class CFuncTypeArg: ...@@ -852,9 +855,12 @@ class CFuncTypeArg:
# type PyrexType # type PyrexType
# pos source file position # pos source file position
def __init__(self, name, type, pos): def __init__(self, name, type, pos, cname=None):
self.name = name self.name = name
self.cname = Naming.var_prefix + name if cname is not None:
self.cname = cname
else:
self.cname = Naming.var_prefix + name
self.type = type self.type = type
self.pos = pos self.pos = pos
self.not_none = False self.not_none = False
...@@ -1050,29 +1056,29 @@ c_void_type = CVoidType() ...@@ -1050,29 +1056,29 @@ c_void_type = CVoidType()
c_void_ptr_type = CPtrType(c_void_type) c_void_ptr_type = CPtrType(c_void_type)
c_void_ptr_ptr_type = CPtrType(c_void_ptr_type) c_void_ptr_ptr_type = CPtrType(c_void_ptr_type)
c_uchar_type = CIntType(0, 0, "T_UBYTE") c_uchar_type = CIntType(0, 0, "T_UBYTE", typestring="B")
c_ushort_type = CIntType(1, 0, "T_USHORT") c_ushort_type = CIntType(1, 0, "T_USHORT", typestring="H")
c_uint_type = CUIntType(2, 0, "T_UINT") c_uint_type = CUIntType(2, 0, "T_UINT", typestring="I")
c_ulong_type = CULongType(3, 0, "T_ULONG") c_ulong_type = CULongType(3, 0, "T_ULONG", typestring="L")
c_ulonglong_type = CULongLongType(4, 0, "T_ULONGLONG") c_ulonglong_type = CULongLongType(4, 0, "T_ULONGLONG", typestring="Q")
c_char_type = CIntType(0, 1, "T_CHAR") c_char_type = CIntType(0, 1, "T_CHAR", typestring="b")
c_short_type = CIntType(1, 1, "T_SHORT") c_short_type = CIntType(1, 1, "T_SHORT", typestring="h")
c_int_type = CIntType(2, 1, "T_INT") c_int_type = CIntType(2, 1, "T_INT", typestring="i")
c_long_type = CIntType(3, 1, "T_LONG") c_long_type = CIntType(3, 1, "T_LONG", typestring="l")
c_longlong_type = CLongLongType(4, 1, "T_LONGLONG") c_longlong_type = CLongLongType(4, 1, "T_LONGLONG", typestring="q")
c_py_ssize_t_type = CPySSizeTType(5, 1) c_py_ssize_t_type = CPySSizeTType(5, 1)
c_bint_type = CBIntType(2, 1, "T_INT") c_bint_type = CBIntType(2, 1, "T_INT", typestring="i")
c_schar_type = CIntType(0, 2, "T_CHAR") c_schar_type = CIntType(0, 2, "T_CHAR", typestring="b")
c_sshort_type = CIntType(1, 2, "T_SHORT") c_sshort_type = CIntType(1, 2, "T_SHORT", typestring="h")
c_sint_type = CIntType(2, 2, "T_INT") c_sint_type = CIntType(2, 2, "T_INT", typestring="i")
c_slong_type = CIntType(3, 2, "T_LONG") c_slong_type = CIntType(3, 2, "T_LONG", typestring="l")
c_slonglong_type = CLongLongType(4, 2, "T_LONGLONG") c_slonglong_type = CLongLongType(4, 2, "T_LONGLONG", typestring="q")
c_float_type = CFloatType(6, "T_FLOAT") c_float_type = CFloatType(6, "T_FLOAT", typestring="f")
c_double_type = CFloatType(7, "T_DOUBLE") c_double_type = CFloatType(7, "T_DOUBLE", typestring="d")
c_longdouble_type = CFloatType(8) c_longdouble_type = CFloatType(8, typestring="g")
c_null_ptr_type = CNullPtrType(c_void_type) c_null_ptr_type = CNullPtrType(c_void_type)
c_char_array_type = CCharArrayType(None) c_char_array_type = CCharArrayType(None)
...@@ -1087,7 +1093,8 @@ c_returncode_type = CIntType(2, 1, "T_INT", is_returncode = 1) ...@@ -1087,7 +1093,8 @@ c_returncode_type = CIntType(2, 1, "T_INT", is_returncode = 1)
c_anon_enum_type = CAnonEnumType(-1, 1) c_anon_enum_type = CAnonEnumType(-1, 1)
# the Py_buffer type is defined in Builtin.py # the Py_buffer type is defined in Builtin.py
c_py_buffer_ptr_type = CPtrType(CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer")) c_py_buffer_type = CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer")
c_py_buffer_ptr_type = CPtrType(c_py_buffer_type)
error_type = ErrorType() error_type = ErrorType()
......
...@@ -20,10 +20,12 @@ possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match ...@@ -20,10 +20,12 @@ possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
class BufferAux: class BufferAux:
def __init__(self, buffer_info_var, stridevars, shapevars): def __init__(self, buffer_info_var, stridevars, shapevars, tschecker):
self.buffer_info_var = buffer_info_var self.buffer_info_var = buffer_info_var
self.stridevars = stridevars self.stridevars = stridevars
self.shapevars = shapevars self.shapevars = shapevars
self.tschecker = tschecker
def __repr__(self): def __repr__(self):
return "<BufferAux %r>" % self.__dict__ return "<BufferAux %r>" % self.__dict__
......
...@@ -181,10 +181,14 @@ def replace_node(ptr, value): ...@@ -181,10 +181,14 @@ def replace_node(ptr, value):
getattr(parent, attrname)[listidx] = value getattr(parent, attrname)[listidx] = value
tmpnamectr = 0 tmpnamectr = 0
def temp_name_handle(description): def temp_name_handle(description=None):
global tmpnamectr global tmpnamectr
tmpnamectr += 1 tmpnamectr += 1
return EncodedString(Naming.temp_prefix + u"%d_%s" % (tmpnamectr, description)) if description is not None:
name = u"%d_%s" % (tmpnamectr, description)
else:
name = u"%d" % tmpnamectr
return EncodedString(Naming.temp_prefix + name)
def get_temp_name_handle_desc(handle): def get_temp_name_handle_desc(handle):
if not handle.startswith(u"__cyt_"): if not handle.startswith(u"__cyt_"):
...@@ -198,7 +202,7 @@ class PrintTree(TreeVisitor): ...@@ -198,7 +202,7 @@ class PrintTree(TreeVisitor):
Subclass and override repr_of to provide more information Subclass and override repr_of to provide more information
about nodes. """ about nodes. """
def __init__(self): def __init__(self):
Transform.__init__(self) TreeVisitor.__init__(self)
self._indent = "" self._indent = ""
def indent(self): def indent(self):
...@@ -208,6 +212,7 @@ class PrintTree(TreeVisitor): ...@@ -208,6 +212,7 @@ class PrintTree(TreeVisitor):
def __call__(self, tree, phase=None): def __call__(self, tree, phase=None):
print("Parse tree dump at phase '%s'" % phase) print("Parse tree dump at phase '%s'" % phase)
self.visit(tree)
# Don't do anything about process_list, the defaults gives # Don't do anything about process_list, the defaults gives
# nice-looking name[idx] nodes which will visually appear # nice-looking name[idx] nodes which will visually appear
......
...@@ -19,5 +19,10 @@ cdef extern from "Python.h": ...@@ -19,5 +19,10 @@ cdef extern from "Python.h":
int PyObject_GetBuffer(PyObject* obj, Py_buffer* view, int flags) except -1 int PyObject_GetBuffer(PyObject* obj, Py_buffer* view, int flags) except -1
void PyObject_ReleaseBuffer(PyObject* obj, Py_buffer* view) void PyObject_ReleaseBuffer(PyObject* obj, Py_buffer* view)
void PyErr_Format(int, char*, ...)
enum:
PyExc_TypeError
# int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, # int PyObject_GetBuffer(PyObject *obj, Py_buffer *view,
# int flags) # int flags)
cdef extern from "Python.h":
ctypedef int Py_intptr_t
cdef extern from "numpy/arrayobject.h":
ctypedef class numpy.ndarray [object PyArrayObject]:
cdef char *data
cdef int nd
cdef Py_intptr_t *dimensions
cdef Py_intptr_t *strides
cdef object base
# descr not implemented yet here...
cdef int flags
cdef int itemsize
cdef object weakreflist
ctypedef unsigned int npy_uint8
ctypedef unsigned int npy_uint16
ctypedef unsigned int npy_uint32
ctypedef unsigned int npy_uint64
ctypedef unsigned int npy_uint96
ctypedef unsigned int npy_uint128
ctypedef signed int npy_int64
ctypedef float npy_float32
ctypedef float npy_float64
ctypedef float npy_float80
ctypedef float npy_float96
ctypedef float npy_float128
ctypedef npy_int64 Tint64
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment