Commit b35bb6c0 authored by Robert Bradshaw's avatar Robert Bradshaw

Merge in Dag's work

parents bbb832bc eaa6d477
......@@ -17,32 +17,34 @@ special_chars = [(u'<', u'\xF0', u'&lt;'),
class AnnotationCCodeWriter(CCodeWriter):
def __init__(self, f):
CCodeWriter.__init__(self, self)
self.buffer = StringIO()
self.real_f = f
def __init__(self, create_from=None, buffer=None):
CCodeWriter.__init__(self, create_from, buffer)
self.annotation_buffer = StringIO()
if create_from is None:
self.annotations = []
self.last_pos = None
self.code = {}
else:
# When creating an insertion point, keep references to the same database
self.annotation_buffer = create_from.annotation_buffer
self.annotations = create_from.annotations
self.code = create_from.code
def getvalue(self):
return self.real_f.getvalue()
def create_new(self, create_from, buffer):
return AnnotationCCodeWriter(create_from, buffer)
def write(self, s):
self.real_f.write(s)
self.buffer.write(s)
CCodeWriter.write(self, s)
self.annotation_buffer.write(s)
def mark_pos(self, pos):
# if pos is not None:
# CCodeWriter.mark_pos(self, pos)
# return
if self.last_pos:
try:
code = self.code[self.last_pos[1]]
except KeyError:
code = ""
self.code[self.last_pos[1]] = code + self.buffer.getvalue()
self.buffer = StringIO()
code = self.code.get(self.last_pos[1], "")
self.code[self.last_pos[1]] = code + self.annotation_buffer.getvalue()
self.annotation_buffer = StringIO()
self.last_pos = pos
def annotate(self, pos, item):
......
......@@ -7,84 +7,33 @@ from Cython.Utils import EncodedString
from Cython.Compiler.Errors import CompileError
import PyrexTypes
from sets import Set as set
import textwrap
class PureCFuncNode(Node):
def __init__(self, pos, cname, type, c_code, visibility='private'):
self.pos = pos
self.cname = cname
self.type = type
self.c_code = c_code
self.visibility = visibility
def analyse_types(self, env):
self.entry = env.declare_cfunction(
"<pure c function:%s>" % self.cname,
self.type, self.pos, cname=self.cname,
defining=True, visibility=self.visibility)
def generate_function_definitions(self, env, code, transforms):
assert self.type.optional_arg_count == 0
visibility = self.entry.visibility
if visibility != 'private':
storage_class = "%s " % Naming.extern_c_macro
else:
storage_class = "static "
arg_decls = [arg.declaration_code() for arg in self.type.args]
sig = self.type.return_type.declaration_code(
self.type.function_header_code(self.cname, ", ".join(arg_decls)))
code.putln("")
code.putln("%s%s {" % (storage_class, sig))
code.put(self.c_code)
code.putln("}")
def generate_execution_code(self, code):
pass
tschecker_functype = PyrexTypes.CFuncType(
PyrexTypes.c_char_ptr_type,
[PyrexTypes.CFuncTypeArg(EncodedString("ts"), PyrexTypes.c_char_ptr_type,
(0, 0, None), cname="ts")],
exception_value = "NULL"
)
tsprefix = "__Pyx_tsc"
class BufferTransform(CythonTransform):
"""
Run after type analysis. Takes care of the buffer functionality.
def dedent(text, reindent=0):
text = textwrap.dedent(text)
if reindent > 0:
indent = " " * reindent
text = '\n'.join([indent + x for x in text.split('\n')])
return text
Expects to be run on the full module. If you need to process a fragment
one should look into refactoring this transform.
"""
# Abbreviations:
# "ts" means typestring and/or typestring checking stuff
scope = None
class IntroduceBufferAuxiliaryVars(CythonTransform):
#
# Entry point
#
buffers_exists = False
def __call__(self, node):
assert isinstance(node, ModuleNode)
try:
cymod = self.context.modules[u'__cython__']
except KeyError:
# No buffer fun for this module
return node
self.bufstruct_type = cymod.entries[u'Py_buffer'].type
self.tscheckers = {}
self.ts_funcs = []
self.ts_item_checkers = {}
self.module_scope = node.scope
self.module_pos = node.pos
result = super(BufferTransform, self).__call__(node)
# Register ts stuff
self.max_ndim = 0
result = super(IntroduceBufferAuxiliaryVars, self).__call__(node)
if self.buffers_exists:
if "endian.h" not in node.scope.include_files:
node.scope.include_files.append("endian.h")
result.body.stats += self.ts_funcs
use_py2_buffer_functions(node.scope)
use_empty_bufstruct_code(node.scope, self.max_ndim)
node.scope.use_utility_code(access_utility_code)
return result
......@@ -95,136 +44,60 @@ class BufferTransform(CythonTransform):
# For all buffers, insert extra variables in the scope.
# The variables are also accessible from the buffer_info
# on the buffer entry
bufvars = [(name, entry) for name, entry
bufvars = [entry for name, entry
in scope.entries.iteritems()
if entry.type.buffer_options is not None]
if entry.type.is_buffer]
if len(bufvars) > 0:
self.buffers_exists = True
for name, entry in bufvars:
bufopts = entry.type.buffer_options
if isinstance(node, ModuleNode) and len(bufvars) > 0:
# for now...note that pos is wrong
raise CompileError(node.pos, "Buffer vars not allowed in module scope")
for entry in bufvars:
name = entry.name
buftype = entry.type
if buftype.ndim > self.max_ndim:
self.max_ndim = buftype.ndim
# Get or make a type string checker
tschecker = self.tschecker(bufopts.dtype)
tschecker = buffer_type_checker(buftype.dtype, scope)
# Declare auxiliary vars
bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name),
self.bufstruct_type, node.pos)
temp_var = scope.declare_var(temp_name_handle(u"%s_tmp" % name),
entry.type, node.pos)
stridevars = []
shapevars = []
for idx in range(bufopts.ndim):
# stride
varname = temp_name_handle(u"%s_%s%d" % (name, "stride", idx))
var = scope.declare_var(varname, PyrexTypes.c_int_type, node.pos, is_cdef=True)
stridevars.append(var)
# shape
varname = temp_name_handle(u"%s_%s%d" % (name, "shape", idx))
var = scope.declare_var(varname, PyrexTypes.c_uint_type, node.pos, is_cdef=True)
shapevars.append(var)
entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars,
shapevars, tschecker)
entry.buffer_aux.temp_var = temp_var
self.scope = scope
cname = scope.mangle(Naming.bufstruct_prefix, name)
bufinfo = scope.declare_var(name="$%s" % cname, cname=cname,
type=PyrexTypes.c_py_buffer_type, pos=node.pos)
# Notes: The cast to <char*> gets around Cython not supporting const types
acquire_buffer_fragment = TreeFragment(u"""
TMP = LHS
if TMP is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
TMP = RHS
if TMP is not None:
__cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
TSCHECKER(<char*>BUFINFO.format)
ASSIGN_AUX
LHS = TMP
""")
bufinfo.used = True
fetch_strides = TreeFragment(u"""
TARGET = BUFINFO.strides[IDX]
""")
def var(prefix, idx, initval):
cname = scope.mangle(prefix, "%d_%s" % (idx, name))
result = scope.declare_var("$%s" % cname, PyrexTypes.c_py_ssize_t_type,
node.pos, cname=cname, is_cdef=True)
fetch_shape = TreeFragment(u"""
TARGET = BUFINFO.shape[IDX]
""")
def reacquire_buffer(self, node):
bufaux = node.lhs.entry.buffer_aux
auxass = []
for idx, entry in enumerate(bufaux.stridevars):
entry.used = True
ass = self.fetch_strides.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx)),
})
auxass.append(ass)
for idx, entry in enumerate(bufaux.shapevars):
entry.used = True
ass = self.fetch_shape.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx))
})
auxass.append(ass)
bufaux.buffer_info_var.used = True
acq = self.acquire_buffer_fragment.substitute({
u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
u"LHS" : node.lhs,
u"RHS": node.rhs,
u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name),
u"TSCHECKER": NameNode(node.pos, name=bufaux.tschecker.name)
}, pos=node.pos)
# Note: The below should probably be refactored into something
# like fragment.substitute(..., context=self.context), with
# TreeFragment getting context.pipeline_until_now() and
# applying it on the fragment.
acq.analyse_declarations(self.scope)
acq.analyse_expressions(self.scope)
stats = acq.stats
return stats
def assign_into_buffer(self, node):
result = SingleAssignmentNode(node.pos,
rhs=self.visit(node.rhs),
lhs=self.buffer_index(node.lhs))
result.analyse_expressions(self.scope)
result.init = initval
if entry.is_arg:
result.used = True
return result
def buffer_index(self, node):
pos = node.pos
bufaux = node.base.entry.buffer_aux
assert bufaux is not None
# indices * strides...
to_sum = [ IntBinopNode(pos, operator='*',
operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index),
operand2=NameNode(node.pos, name=stride.name))
for index, stride in zip(node.indices, bufaux.stridevars)]
# then sum them with the buffer pointer
expr = AttributeNode(pos,
obj=NameNode(pos, name=bufaux.buffer_info_var.name),
attribute=EncodedString("buf"))
for next in to_sum:
expr = AddNode(pos, operator='+', operand1=expr, operand2=next)
stridevars = [var(Naming.bufstride_prefix, i, "0") for i in range(entry.type.ndim)]
shapevars = [var(Naming.bufshape_prefix, i, "0") for i in range(entry.type.ndim)]
entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, tschecker)
mode = entry.type.mode
if mode == 'full':
suboffsetvars = [var(Naming.bufsuboffset_prefix, i, "-1") for i in range(entry.type.ndim)]
entry.buffer_aux.lookup = get_buf_lookup_full(scope, entry.type.ndim)
elif mode == 'strided':
suboffsetvars = None
entry.buffer_aux.lookup = get_buf_lookup_strided(scope, entry.type.ndim)
casted = TypecastNode(pos, operand=expr,
type=PyrexTypes.c_ptr_type(node.base.entry.type.buffer_options.dtype))
result = IndexNode(pos, base=casted, index=IntNode(pos, value='0'))
return result
entry.buffer_aux.suboffsetvars = suboffsetvars
entry.buffer_aux.get_buffer_cname = tschecker
scope.buffer_entries = bufvars
self.scope = scope
#
# Transforms
#
def visit_ModuleNode(self, node):
self.handle_scope(node, node.scope)
self.visitchildren(node)
......@@ -235,102 +108,541 @@ class BufferTransform(CythonTransform):
self.visitchildren(node)
return node
def visit_SingleAssignmentNode(self, node):
# On assignments, two buffer-related things can happen:
# a) A buffer variable is assigned to (reacquisition)
# b) Buffer access assignment: arr[...] = ...
# Since we don't allow nested buffers, these don't overlap.
self.visitchildren(node)
# Only acquire buffers on vars (not attributes) for now.
if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
# Is buffer variable
return self.reacquire_buffer(node)
elif (isinstance(node.lhs, IndexNode) and
isinstance(node.lhs.base, NameNode) and
node.lhs.base.entry.buffer_aux is not None):
return self.assign_into_buffer(node)
def get_flags(buffer_aux, buffer_type):
flags = 'PyBUF_FORMAT'
if buffer_type.mode == 'full':
flags += '| PyBUF_INDIRECT'
elif buffer_type.mode == 'strided':
flags += '| PyBUF_STRIDES'
else:
return node
assert False
if buffer_aux.writable_needed: flags += "| PyBUF_WRITABLE"
return flags
def used_buffer_aux_vars(entry):
buffer_aux = entry.buffer_aux
buffer_aux.buffer_info_var.used = True
for s in buffer_aux.shapevars: s.used = True
for s in buffer_aux.stridevars: s.used = True
for s in buffer_aux.suboffsetvars: s.used = True
def put_unpack_buffer_aux_into_scope(buffer_aux, mode, code):
# Generate code to copy the needed struct info into local
# variables.
bufstruct = buffer_aux.buffer_info_var.cname
varspec = [("strides", buffer_aux.stridevars),
("shape", buffer_aux.shapevars)]
if mode == 'full':
varspec.append(("suboffsets", buffer_aux.suboffsetvars))
for field, vars in varspec:
code.putln(" ".join(["%s = %s.%s[%d];" %
(s.cname, bufstruct, field, idx)
for idx, s in enumerate(vars)]))
def getbuffer_cond_code(obj_cname, buffer_aux, flags, ndim):
bufstruct = buffer_aux.buffer_info_var.cname
return "%s(%s, &%s, %s, %d) == -1" % (
buffer_aux.get_buffer_cname, obj_cname, bufstruct, flags, ndim)
def put_acquire_arg_buffer(entry, code, pos):
buffer_aux = entry.buffer_aux
cname = entry.cname
bufstruct = buffer_aux.buffer_info_var.cname
flags = get_flags(buffer_aux, entry.type)
# Acquire any new buffer
code.putln(code.error_goto_if(getbuffer_cond_code(cname,
buffer_aux,
flags,
entry.type.ndim),
pos))
# An exception raised in arg parsing cannot be catched, so no
# need to do care about the buffer then.
put_unpack_buffer_aux_into_scope(buffer_aux, entry.type.mode, code)
#def put_release_buffer_normal(entry, code):
# code.putln("if (%s != Py_None) PyObject_ReleaseBuffer(%s, &%s);" % (
# entry.cname,
# entry.cname,
# entry.buffer_aux.buffer_info_var.cname))
def get_release_buffer_code(entry):
return "if (%s != Py_None) PyObject_ReleaseBuffer(%s, &%s)" % (
entry.cname,
entry.cname,
entry.buffer_aux.buffer_info_var.cname)
def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
is_initialized, pos, code):
"""
Generate code for reassigning a buffer variables. This only deals with getting
the buffer auxiliary structure and variables set up correctly, the assignment
itself and refcounting is the responsibility of the caller.
def visit_IndexNode(self, node):
# Only occurs when the IndexNode is an rvalue
if node.is_buffer_access:
assert node.index is None
assert node.indices is not None
result = self.buffer_index(node)
result.analyse_expressions(self.scope)
return result
However, the assignment operation may throw an exception so that the reassignment
never happens.
Depending on the circumstances there are two possible outcomes:
- Old buffer released, new acquired, rhs assigned to lhs
- Old buffer released, new acquired which fails, reaqcuire old lhs buffer
(which may or may not succeed).
"""
bufstruct = buffer_aux.buffer_info_var.cname
flags = get_flags(buffer_aux, buffer_type)
getbuffer = "%s(%%s, &%s, %s, %d)" % (buffer_aux.get_buffer_cname,
# note: object is filled in later
bufstruct,
flags,
buffer_type.ndim)
if is_initialized:
# Release any existing buffer
code.put('if (%s != Py_None) ' % lhs_cname)
code.begin_block();
code.putln('PyObject_ReleaseBuffer(%s, &%s);' % (
lhs_cname, bufstruct))
code.end_block()
# Acquire
retcode_cname = code.func.allocate_temp(PyrexTypes.c_int_type)
code.putln("%s = %s;" % (retcode_cname, getbuffer % rhs_cname))
code.putln('if (%s) ' % (code.unlikely("%s < 0" % retcode_cname)))
# If acquisition failed, attempt to reacquire the old buffer
# before raising the exception. A failure of reacquisition
# will cause the reacquisition exception to be reported, one
# can consider working around this later.
code.begin_block()
type, value, tb = [code.func.allocate_temp(PyrexTypes.py_object_type)
for i in range(3)]
code.putln('PyErr_Fetch(&%s, &%s, &%s);' % (type, value, tb))
code.put('if (%s) ' % code.unlikely("%s == -1" % (getbuffer % lhs_cname)))
code.begin_block()
code.putln('Py_XDECREF(%s); Py_XDECREF(%s); Py_XDECREF(%s);' % (type, value, tb))
code.putln('__Pyx_RaiseBufferFallbackError();')
code.putln('} else {')
code.putln('PyErr_Restore(%s, %s, %s);' % (type, value, tb))
for t in (type, value, tb):
code.func.release_temp(t)
code.end_block()
# Unpack indices
code.end_block()
put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
code.putln(code.error_goto_if_neg(retcode_cname, pos))
code.func.release_temp(retcode_cname)
else:
return node
# Our entry had no previous value, so set to None when acquisition fails.
# In this case, auxiliary vars should be set up right in initialization to a zero-buffer,
# so it suffices to set the buf field to NULL.
code.putln('if (%s) {' % code.unlikely("%s == -1" % (getbuffer % rhs_cname)))
code.putln('%s = Py_None; Py_INCREF(Py_None); %s.buf = NULL;' % (lhs_cname, bufstruct))
code.putln(code.error_goto(pos))
code.put('} else {')
# Unpack indices
put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
code.putln('}')
def put_access(entry, index_signeds, index_cnames, pos, code):
"""Returns a c string which can be used to access the buffer
for reading or writing.
As the bounds checking can have any number of combinations of unsigned
arguments, smart optimizations etc. we insert it directly in the function
body. The lookup however is delegated to a inline function that is instantiated
once per ndim (lookup with suboffsets tend to get quite complicated).
"""
bufaux = entry.buffer_aux
bufstruct = bufaux.buffer_info_var.cname
# Check bounds and fix negative indices
boundscheck = True
nonegs = True
tmp_cname = code.func.allocate_temp(PyrexTypes.c_int_type)
if boundscheck:
code.putln("%s = -1;" % tmp_cname)
for idx, (signed, cname, shape) in enumerate(zip(index_signeds, index_cnames,
bufaux.shapevars)):
if signed != 0:
nonegs = False
# not unsigned, deal with negative index
code.putln("if (%s < 0) {" % cname)
code.putln("%s += %s;" % (cname, shape.cname))
if boundscheck:
code.putln("if (%s) %s = %d;" % (
code.unlikely("%s < 0" % cname), tmp_cname, idx))
code.put("} else ")
else:
if idx > 0: code.put("else ")
if boundscheck:
# check bounds in positive direction
code.putln("if (%s) %s = %d;" % (
code.unlikely("%s >= %s" % (cname, shape.cname)),
tmp_cname, idx))
if boundscheck:
code.put("if (%s) " % code.unlikely("%s != -1" % tmp_cname))
code.begin_block()
code.putln('__Pyx_BufferIndexError(%s);' % tmp_cname)
code.putln(code.error_goto(pos))
code.end_block()
code.func.release_temp(tmp_cname)
# Create buffer lookup and return it
params = []
if entry.type.mode == 'full':
for i, s, o in zip(index_cnames, bufaux.stridevars, bufaux.suboffsetvars):
params.append(i)
params.append(s.cname)
params.append(o.cname)
else:
for i, s in zip(index_cnames, bufaux.stridevars):
params.append(i)
params.append(s.cname)
ptrcode = "%s(%s.buf, %s)" % (bufaux.lookup, bufstruct,
", ".join(params))
valuecode = "*%s" % entry.type.buffer_ptr_type.cast_code(ptrcode)
return valuecode
#
# Utils for creating type string checkers
#
def new_ts_func(self, name, code):
cname = "%s_%s" % (tsprefix, name)
funcnode = PureCFuncNode(self.module_pos, cname, tschecker_functype, code)
funcnode.analyse_types(self.module_scope)
self.ts_funcs.append(funcnode)
return funcnode
def mangle_dtype_name(self, dtype):
def use_empty_bufstruct_code(env, max_ndim):
code = dedent("""
Py_ssize_t __Pyx_zeros[] = {%s};
Py_ssize_t __Pyx_minusones[] = {%s};
""") % (", ".join(["0"] * max_ndim), ", ".join(["-1"] * max_ndim))
env.use_utility_code([code, ""])
def get_buf_lookup_strided(env, nd):
"""
Generates and registers as utility a buffer lookup function for the right number
of dimensions. The function gives back a void* at the right location.
"""
name = "__Pyx_BufPtrStrided_%dd" % nd
if not env.has_utility_code(name):
# _i_ndex, _s_tride
args = ", ".join(["i%d, s%d" % (i, i) for i in range(nd)])
offset = " + ".join(["i%d * s%d" % (i, i) for i in range(nd)])
proto = dedent("""\
#define %s(buf, %s) ((char*)buf + %s)
""") % (name, args, offset)
env.use_utility_code([proto, ""], name=name)
return name
def get_buf_lookup_full(env, nd):
"""
Generates and registers as utility a buffer lookup function for the right number
of dimensions. The function gives back a void* at the right location.
"""
name = "__Pyx_BufPtrFull_%dd" % nd
if not env.has_utility_code(name):
# _i_ndex, _s_tride, sub_o_ffset
args = ", ".join(["Py_ssize_t i%d, Py_ssize_t s%d, Py_ssize_t o%d" % (i, i, i) for i in range(nd)])
proto = dedent("""\
static INLINE void* %s(void* buf, %s);
""") % (name, args)
func = dedent("""
static INLINE void* %s(void* buf, %s) {
char* ptr = (char*)buf;
""") % (name, args) + "".join([dedent("""\
ptr += s%d * i%d;
if (o%d >= 0) ptr = *((char**)ptr) + o%d;
""") % (i, i, i, i) for i in range(nd)]
) + "\nreturn ptr;\n}"
env.use_utility_code([proto, func], name=name)
return name
#
# Utils for creating type string checkers
#
def mangle_dtype_name(dtype):
# Use prefixes to seperate user defined types from builtins
# (consider "typedef float unsigned_int")
return dtype.declaration_code("").replace(" ", "_")
if dtype.typestring is None:
prefix = "nn_"
else:
prefix = ""
return prefix + dtype.declaration_code("").replace(" ", "_")
def get_ts_check_item(self, dtype):
def get_ts_check_item(dtype, env):
# See if we can consume one (unnamed) dtype as next item
funcnode = self.ts_item_checkers.get(dtype)
if funcnode is None:
# Put native types and structs in seperate namespaces (as one could create a struct named unsigned_int...)
name = "__Pyx_BufferTypestringCheck_item_%s" % mangle_dtype_name(dtype)
if not env.has_utility_code(name):
char = dtype.typestring
if char is not None and len(char) > 1:
if char is not None:
# Can use direct comparison
funcnode = self.new_ts_func("natitem_%s" % self.mangle_dtype_name(dtype), """\
code = dedent("""\
if (*ts != '%s') {
PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch (rejecting on '%%s')", ts);
PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (rejecting on '%%s')", ts);
return NULL;
} else return ts + 1;
""" % char)
""", 2) % char
else:
# Must deduce sign and length; rely on int vs. float to be correctly declared
# Cannot trust declared size; but rely on int vs float and
# signed/unsigned to be correctly declared
ctype = dtype.declaration_code("")
code = """\
code = dedent("""\
int ok;
switch (*ts) {"""
switch (*ts) {""", 2)
if dtype.is_int:
types = [
('b', 'char'), ('h', 'short'), ('i', 'int'),
('l', 'long'), ('q', 'long long')
]
code += "".join(["""\
case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;
case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;
""" % (char, ctype, against, ctype, char.upper(), ctype, "unsigned " + against, ctype) for
char, against in types])
code += """\
if dtype.signed == 0:
code += "".join(["\n case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;" %
(char.upper(), ctype, against, ctype) for char, against in types])
else:
code += "".join(["\n case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" %
(char, ctype, against, ctype) for char, against in types])
code += dedent("""\
default: ok = 0;
}
if (!ok) {
PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch (rejecting on '%s')", ts);
PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (rejecting on '%s')", ts);
return NULL;
} else return ts + 1;
"""
""", 2)
env.use_utility_code([dedent("""\
static const char* %s(const char* ts); /*proto*/
""") % name, dedent("""
static const char* %s(const char* ts) {
%s
}
""") % (name, code)], name=name)
funcnode = self.new_ts_func("tdefitem_%s" % self.mangle_dtype_name(dtype), code)
return name
self.ts_item_checkers[dtype] = funcnode
return funcnode.entry.cname
def get_getbuffer_code(dtype, env):
"""
Generate a utility function for getting a buffer for the given dtype.
The function will:
- Call PyObject_GetBuffer
- Check that ndim matched the expected value
- Check that the format string is right
- Set suboffsets to all -1 if it is returned as NULL.
"""
ts_consume_whitespace_cname = None
ts_check_endian_cname = None
name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype)
if not env.has_utility_code(name):
env.use_utility_code(acquire_utility_code)
itemchecker = get_ts_check_item(dtype, env)
utilcode = [dedent("""
static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd); /*proto*/
""") % name, dedent("""
static int %(name)s(PyObject* obj, Py_buffer* buf, int flags, int nd) {
const char* ts;
if (obj == Py_None) {
__Pyx_ZeroBuffer(buf);
return 0;
}
buf->buf = NULL;
if (PyObject_GetBuffer(obj, buf, flags) == -1) goto fail;
if (buf->ndim != nd) {
__Pyx_BufferNdimError(buf, nd);
goto fail;
}
ts = buf->format;
ts = __Pyx_ConsumeWhitespace(ts);
ts = __Pyx_BufferTypestringCheckEndian(ts);
if (!ts) goto fail;
ts = __Pyx_ConsumeWhitespace(ts);
ts = %(itemchecker)s(ts);
if (!ts) goto fail;
ts = __Pyx_ConsumeWhitespace(ts);
if (*ts != 0) {
PyErr_Format(PyExc_ValueError,
"Expected non-struct buffer data type (rejecting on '%%s')", ts);
goto fail;
}
if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones;
return 0;
fail:;
__Pyx_ZeroBuffer(buf);
return -1;
}""") % locals()]
env.use_utility_code(utilcode, name)
return name
def buffer_type_checker(dtype, env):
# Creates a type checker function for the given type.
if dtype.is_struct_or_union:
assert False
elif dtype.is_int or dtype.is_float:
# This includes simple typedef-ed types
funcname = get_getbuffer_code(dtype, env)
else:
assert False
return funcname
def ensure_ts_utils(self):
# Makes sure that the typechecker utils are in scope
# (and constructs them if not)
if self.ts_consume_whitespace_cname is None:
self.ts_consume_whitespace_cname = self.new_ts_func("consume_whitespace", """\
def use_py2_buffer_functions(env):
# will be refactored
try:
env.entries[u'numpy']
env.use_utility_code(["","""
static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
/* This function is always called after a type-check; safe to cast */
PyArrayObject *arr = (PyArrayObject*)obj;
PyArray_Descr *type = (PyArray_Descr*)arr->descr;
int typenum = PyArray_TYPE(obj);
if (!PyTypeNum_ISNUMBER(typenum)) {
PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
return -1;
}
/*
NumPy format codes doesn't completely match buffer codes;
seems safest to retranslate.
01234567890123456789012345*/
const char* base_codes = "?bBhHiIlLqQfdgfdgO";
char* format = (char*)malloc(4);
char* fp = format;
*fp++ = type->byteorder;
if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
*fp++ = base_codes[typenum];
*fp = 0;
view->buf = arr->data;
view->readonly = !PyArray_ISWRITEABLE(obj);
view->ndim = PyArray_NDIM(arr);
view->strides = PyArray_STRIDES(arr);
view->shape = PyArray_DIMS(arr);
view->suboffsets = NULL;
view->format = format;
view->itemsize = type->elsize;
view->internal = 0;
return 0;
}
static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
free((char*)view->format);
view->format = NULL;
}
"""])
except KeyError:
pass
codename = "PyObject_GetBuffer" # just a representative unique key
# Search all types for __getbuffer__ overloads
types = []
def find_buffer_types(scope):
for m in scope.cimported_modules:
find_buffer_types(m)
for e in scope.type_entries:
t = e.type
if t.is_extension_type:
release = get = None
for x in t.scope.pyfunc_entries:
if x.name == u"__getbuffer__": get = x.func_cname
elif x.name == u"__releasebuffer__": release = x.func_cname
if get:
types.append((t.typeptr_cname, get, release))
find_buffer_types(env)
# For now, hard-code numpy imported as "numpy"
try:
ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
except KeyError:
pass
code = dedent("""
#if PY_VERSION_HEX < 0x02060000
static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {
""")
if len(types) > 0:
clause = "if"
for t, get, release in types:
code += " %s (PyObject_TypeCheck(obj, %s)) return %s(obj, view, flags);\n" % (clause, t, get)
clause = "else if"
code += " else {\n"
code += dedent("""\
PyErr_Format(PyExc_TypeError, "'%100s' does not have the buffer interface", Py_TYPE(obj)->tp_name);
return -1;
""", 2)
if len(types) > 0: code += " }"
code += dedent("""
}
static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {
""")
if len(types) > 0:
clause = "if"
for t, get, release in types:
if release:
code += "%s (PyObject_TypeCheck(obj, %s)) %s(obj, view);" % (clause, t, release)
clause = "else if"
code += dedent("""
}
#endif
""")
env.use_utility_code([dedent("""\
#if PY_VERSION_HEX < 0x02060000
static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags);
static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view);
#endif
""") ,code], codename)
#
# Static utility code
#
# Utility function to set the right exception
# The caller should immediately goto_error
access_utility_code = [
"""\
static void __Pyx_BufferIndexError(int axis); /*proto*/
""","""\
static void __Pyx_BufferIndexError(int axis) {
PyErr_Format(PyExc_IndexError,
"Out of bounds on buffer access (axis %d)", axis);
}
"""]
#
# Buffer type checking. Utility code for checking that acquired
# buffers match our assumptions. We only need to check ndim and
# the format string; the access mode/flags is checked by the
# exporter.
#
acquire_utility_code = ["""\
static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf); /*proto*/
static INLINE const char* __Pyx_ConsumeWhitespace(const char* ts); /*proto*/
static INLINE const char* __Pyx_BufferTypestringCheckEndian(const char* ts); /*proto*/
static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim); /*proto*/
static void __Pyx_RaiseBufferFallbackError(void); /*proto*/
""", """
static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf) {
buf->buf = NULL;
buf->strides = __Pyx_zeros;
buf->shape = __Pyx_zeros;
buf->suboffsets = __Pyx_minusones;
}
static INLINE const char* __Pyx_ConsumeWhitespace(const char* ts) {
while (1) {
switch (*ts) {
case 10:
......@@ -341,9 +653,9 @@ class BufferTransform(CythonTransform):
return ts;
}
}
""").entry.cname
if self.ts_check_endian_cname is None:
self.ts_check_endian_cname = self.new_ts_func("check_endian", """\
}
static INLINE const char* __Pyx_BufferTypestringCheckEndian(const char* ts) {
int ok = 1;
switch (*ts) {
case '@':
......@@ -360,55 +672,21 @@ class BufferTransform(CythonTransform):
break;
}
if (!ok) {
PyErr_Format(PyExc_TypeError, "Data has wrong endianness (rejecting on '%s')", ts);
return NULL;
}
return ts;
""").entry.cname
def create_ts_check_simple(self, dtype):
# Check whole string for single unnamed item
consume_whitespace = self.ts_consume_whitespace_cname
check_endian = self.ts_check_endian_cname
check_item = self.get_ts_check_item(dtype)
return self.new_ts_func("simple_%s" % self.mangle_dtype_name(dtype), """\
ts = %(consume_whitespace)s(ts);
ts = %(check_endian)s(ts);
if (!ts) return NULL;
ts = %(consume_whitespace)s(ts);
ts = %(check_item)s(ts);
if (!ts) return NULL;
ts = %(consume_whitespace)s(ts);
if (*ts != 0) {
PyErr_Format(PyExc_TypeError, "Data too long (rejecting on '%%s')", ts);
PyErr_Format(PyExc_ValueError, "Buffer has wrong endianness (rejecting on '%s')", ts);
return NULL;
}
return ts;
""" % locals())
def tschecker(self, dtype):
# Creates a type string checker function for the given type.
# Each checker is created as a function entry in the module scope
# and a PureCNode and put in the self.ts_checkers dict.
# Also the entry is returned.
#
# TODO: __eq__ and __hash__ for types
self.ensure_ts_utils()
funcnode = self.tscheckers.get(dtype)
if funcnode is None:
if dtype.is_struct_or_union:
assert False
elif dtype.is_int or dtype.is_float:
# This includes simple typedef-ed types
funcnode = self.create_ts_check_simple(dtype)
else:
assert False
self.tscheckers[dtype] = funcnode
return funcnode.entry
}
static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) {
PyErr_Format(PyExc_ValueError,
"Buffer has wrong number of dimensions (expected %d, got %d)",
expected_ndim, buffer->ndim);
}
# TODO:
# - buf must be NULL before getting new buffer
static void __Pyx_RaiseBufferFallbackError(void) {
PyErr_Format(PyExc_ValueError,
"Buffer acquisition failed on assignment; and then reacquiring the old buffer failed too!");
}
"""]
......@@ -9,9 +9,142 @@ from Cython.Utils import open_new_file, open_source_file
from PyrexTypes import py_object_type, typecast
from TypeSlots import method_coexist
from Scanning import SourceDescriptor
from Cython.StringIOTree import StringIOTree
from sets import Set as set
class FunctionContext(object):
# Not used for now, perhaps later
def __init__(self, names_taken=set()):
self.names_taken = names_taken
self.error_label = None
self.label_counter = 0
self.labels_used = {}
self.return_label = self.new_label()
self.new_error_label()
self.continue_label = None
self.break_label = None
self.temps_allocated = [] # of (name, type)
self.temps_free = {} # type -> list of free vars
self.temps_used_type = {} # name -> type
self.temp_counter = 0
def new_label(self):
n = self.label_counter
self.label_counter = n + 1
return "%s%d" % (Naming.label_prefix, n)
def new_error_label(self):
old_err_lbl = self.error_label
self.error_label = self.new_label()
return old_err_lbl
def get_loop_labels(self):
return (
self.continue_label,
self.break_label)
def set_loop_labels(self, labels):
(self.continue_label,
self.break_label) = labels
def new_loop_labels(self):
old_labels = self.get_loop_labels()
self.set_loop_labels(
(self.new_label(),
self.new_label()))
return old_labels
def get_all_labels(self):
return (
self.continue_label,
self.break_label,
self.return_label,
self.error_label)
def set_all_labels(self, labels):
(self.continue_label,
self.break_label,
self.return_label,
self.error_label) = labels
def all_new_labels(self):
old_labels = self.get_all_labels()
new_labels = []
for old_label in old_labels:
if old_label:
new_labels.append(self.new_label())
else:
new_labels.append(old_label)
self.set_all_labels(new_labels)
return old_labels
def use_label(self, lbl):
self.labels_used[lbl] = 1
def label_used(self, lbl):
return lbl in self.labels_used
def allocate_temp(self, type):
"""
Allocates a temporary (which may create a new one or get a previously
allocated and released one of the same type). Type is simply registered
and handed back, but will usually be a PyrexType.
A C string referring to the variable is returned.
"""
freelist = self.temps_free.get(type)
if freelist is not None and len(freelist) > 0:
result = freelist.pop()
else:
while True:
self.temp_counter += 1
result = "%s%d" % (Naming.codewriter_temp_prefix, self.temp_counter)
if not result in self.names_taken: break
self.temps_allocated.append((result, type))
self.temps_used_type[result] = type
return result
def release_temp(self, name):
"""
Releases a temporary so that it can be reused by other code needing
a temp of the same type.
"""
type = self.temps_used_type[name]
freelist = self.temps_free.get(type)
if freelist is None:
freelist = []
self.temps_free[type] = freelist
freelist.append(name)
def funccontext_property(name):
def get(self):
return getattr(self.func, name)
def set(self, value):
setattr(self.func, name, value)
return property(get, set)
class CCodeWriter(object):
"""
Utility class to output C code.
When creating an insertion point one must care about the state that is
kept:
- formatting state (level, bol) is cloned and used in insertion points
as well
- labels, temps, exc_vars: One must construct a scope in which these can
exist by calling enter_cfunc_scope/exit_cfunc_scope (these are for
sanity checking and forward compatabilty). Created insertion points
looses this scope and cannot access it.
- marker: Not copied to insertion point
- filename_table, filename_list, input_file_contents: All codewriters
coming from the same root share the same instances simultaneously.
"""
class CCodeWriter:
# f file output file
# buffer StringIOTree
# level int indentation level
# bol bool beginning of line?
# marker string comment to emit before next line
......@@ -19,6 +152,7 @@ class CCodeWriter:
# error_label string error catch point label
# continue_label string loop continue point label
# break_label string loop break point label
# return_from_error_cleanup_label string
# label_counter integer counter for naming labels
# in_try_finally boolean inside try of try...finally
# filename_table {string : int} for finding filename table indexes
......@@ -27,36 +161,94 @@ class CCodeWriter:
# input_file_contents dict contents (=list of lines) of any file that was used as input
# to create this output C code. This is
# used to annotate the comments.
# func FunctionContext contains labels and temps context info
in_try_finally = 0
def __init__(self, f):
#self.f = open_new_file(outfile_name)
self.f = f
self._write = f.write
self.level = 0
self.bol = 1
def __init__(self, create_from=None, buffer=None):
if buffer is None: buffer = StringIOTree()
self.buffer = buffer
self.marker = None
self.last_marker_line = 0
self.label_counter = 1
self.error_label = None
self.func = None
if create_from is None:
# Root CCodeWriter
self.level = 0
self.bol = 1
self.filename_table = {}
self.filename_list = []
self.exc_vars = None
self.input_file_contents = {}
else:
# Clone formatting state
c = create_from
self.level = c.level
self.bol = c.bol
# Note: NOT copying but sharing instance
self.filename_table = c.filename_table
self.filename_list = []
self.input_file_contents = c.input_file_contents
# Leave other state alone
def create_new(self, create_from, buffer):
# polymorphic constructor -- very slightly more versatile
# than using __class__
return CCodeWriter(create_from, buffer)
def copyto(self, f):
self.buffer.copyto(f)
def getvalue(self):
return self.buffer.getvalue()
def write(self, s):
self.buffer.write(s)
def insertion_point(self):
other = self.create_new(create_from=self, buffer=self.buffer.insertion_point())
return other
# Properties delegated to function scope
label_counter = funccontext_property("label_counter")
return_label = funccontext_property("return_label")
error_label = funccontext_property("error_label")
labels_used = funccontext_property("labels_used")
continue_label = funccontext_property("continue_label")
break_label = funccontext_property("break_label")
# Functions delegated to function scope
def new_label(self): return self.func.new_label()
def new_error_label(self): return self.func.new_error_label()
def get_loop_labels(self): return self.func.get_loop_labels()
def set_loop_labels(self, labels): return self.func.set_loop_labels(labels)
def new_loop_labels(self): return self.func.new_loop_labels()
def get_all_labels(self): return self.func.get_all_labels()
def set_all_labels(self, labels): return self.func.set_all_labels(labels)
def all_new_labels(self): return self.func.all_new_labels()
def use_label(self, lbl): return self.func.use_label(lbl)
def label_used(self, lbl): return self.func.label_used(lbl)
def enter_cfunc_scope(self):
self.func = FunctionContext()
def exit_cfunc_scope(self):
self.func = None
def putln(self, code = ""):
if self.marker and self.bol:
self.emit_marker()
if code:
self.put(code)
self._write("\n");
self.write("\n");
self.bol = 1
def emit_marker(self):
self._write("\n");
self.write("\n");
self.indent()
self._write("/* %s */\n" % self.marker[1])
self.write("/* %s */\n" % self.marker[1])
self.last_marker_line = self.marker[0]
self.marker = None
......@@ -68,7 +260,7 @@ class CCodeWriter:
self.level -= 1
if self.bol:
self.indent()
self._write(code)
self.write(code)
self.bol = 0
if dl > 0:
self.level += dl
......@@ -90,7 +282,7 @@ class CCodeWriter:
self.putln("}")
def indent(self):
self._write(" " * self.level)
self.write(" " * self.level)
def get_py_version_hex(self, pyversion):
return "0x%02X%02X%02X%02X" % (tuple(pyversion) + (0,0,0,0))[:4]
......@@ -123,76 +315,13 @@ class CCodeWriter:
source_desc.get_escaped_description(), line, u'\n'.join(lines))
self.marker = (line, marker)
def init_labels(self):
self.label_counter = 0
self.labels_used = {}
self.return_label = self.new_label()
self.new_error_label()
self.continue_label = None
self.break_label = None
def new_label(self):
n = self.label_counter
self.label_counter = n + 1
return "%s%d" % (Naming.label_prefix, n)
def new_error_label(self):
old_err_lbl = self.error_label
self.error_label = self.new_label()
return old_err_lbl
def get_loop_labels(self):
return (
self.continue_label,
self.break_label)
def set_loop_labels(self, labels):
(self.continue_label,
self.break_label) = labels
def new_loop_labels(self):
old_labels = self.get_loop_labels()
self.set_loop_labels(
(self.new_label(),
self.new_label()))
return old_labels
def get_all_labels(self):
return (
self.continue_label,
self.break_label,
self.return_label,
self.error_label)
def set_all_labels(self, labels):
(self.continue_label,
self.break_label,
self.return_label,
self.error_label) = labels
def all_new_labels(self):
old_labels = self.get_all_labels()
new_labels = []
for old_label in old_labels:
if old_label:
new_labels.append(self.new_label())
else:
new_labels.append(old_label)
self.set_all_labels(new_labels)
return old_labels
def use_label(self, lbl):
self.labels_used[lbl] = 1
def label_used(self, lbl):
return lbl in self.labels_used
def put_label(self, lbl):
if lbl in self.labels_used:
if lbl in self.func.labels_used:
self.putln("%s:;" % lbl)
def put_goto(self, lbl):
self.use_label(lbl)
self.func.use_label(lbl)
self.putln("goto %s;" % lbl)
def put_var_declarations(self, entries, static = 0, dll_linkage = None,
......@@ -211,7 +340,7 @@ class CCodeWriter:
#print "...private and not definition, skipping" ###
return
if not entry.used and visibility == "private":
#print "not used and private, skipping" ###
#print "not used and private, skipping", entry.cname ###
return
storage_class = ""
if visibility == 'extern':
......@@ -232,6 +361,14 @@ class CCodeWriter:
self.put(" = %s" % entry.type.literal_code(entry.init))
self.putln(";")
def put_temp_declarations(self, func_context):
for name, type in func_context.temps_allocated:
decl = type.declaration_code(name)
if type.is_pyobject:
self.putln("%s = NULL;" % decl)
else:
self.putln("%s;" % decl)
def entry_as_pyobject(self, entry):
type = entry.type
if (not entry.is_self_arg and not entry.type.is_complete()) \
......@@ -337,9 +474,15 @@ class CCodeWriter:
self.putln("#ifndef %s" % guard)
self.putln("#define %s" % guard)
def unlikely(self, cond):
if Options.gcc_branch_hints:
return 'unlikely(%s)' % cond
else:
return cond
def error_goto(self, pos):
lbl = self.error_label
self.use_label(lbl)
lbl = self.func.error_label
self.func.use_label(lbl)
if Options.c_line_in_traceback:
cinfo = " %s = %s;" % (Naming.clineno_cname, Naming.line_c_macro)
else:
......@@ -354,10 +497,7 @@ class CCodeWriter:
lbl)
def error_goto_if(self, cond, pos):
if Options.gcc_branch_hints:
return "if (unlikely(%s)) %s" % (cond, self.error_goto(pos))
else:
return "if (%s) %s" % (cond, self.error_goto(pos))
return "if (%s) %s" % (self.unlikely(cond), self.error_goto(pos))
def error_goto_if_null(self, cname, pos):
return self.error_goto_if("!%s" % cname, pos)
......
......@@ -33,6 +33,7 @@ class CompileError(PyrexError):
def __init__(self, position = None, message = ""):
self.position = position
self.message_only = message
self.reported = False
# Deprecated and withdrawn in 2.6:
# self.message = message
if position:
......@@ -88,17 +89,23 @@ def close_listing_file():
listing_file.close()
listing_file = None
def error(position, message):
#print "Errors.error:", repr(position), repr(message) ###
def report_error(err):
global num_errors
err = CompileError(position, message)
# if position is not None: raise Exception(err) # debug
# See Main.py for why dual reporting occurs. Quick fix for now.
if err.reported: return
err.reported = True
line = "%s\n" % err
if listing_file:
listing_file.write(line)
if echo_file:
echo_file.write(line)
num_errors = num_errors + 1
def error(position, message):
#print "Errors.error:", repr(position), repr(message) ###
err = CompileError(position, message)
# if position is not None: raise Exception(err) # debug
report_error(err)
return err
LEVEL=1 # warn about all errors level 1 or higher
......
......@@ -806,6 +806,15 @@ class NameNode(AtomicExprNode):
# interned_cname string
is_name = 1
skip_assignment_decref = False
entry = None
def create_analysed_rvalue(pos, env, entry):
node = NameNode(pos)
node.analyse_types(env, entry=entry)
return node
create_analysed_rvalue = staticmethod(create_analysed_rvalue)
def compile_time_value(self, denv):
try:
......@@ -834,6 +843,8 @@ class NameNode(AtomicExprNode):
def analyse_as_module(self, env):
# Try to interpret this as a reference to a cimported module.
# Returns the module scope, or None.
entry = self.entry
if not entry:
entry = env.lookup(self.name)
if entry and entry.as_module:
return entry.as_module
......@@ -842,6 +853,8 @@ class NameNode(AtomicExprNode):
def analyse_as_extension_type(self, env):
# Try to interpret this as a reference to an extension type.
# Returns the extension type, or None.
entry = self.entry
if not entry:
entry = env.lookup(self.name)
if entry and entry.is_type and entry.type.is_extension_type:
return entry.type
......@@ -849,6 +862,7 @@ class NameNode(AtomicExprNode):
return None
def analyse_target_declaration(self, env):
if not self.entry:
self.entry = env.lookup_here(self.name)
if not self.entry:
self.entry = env.declare_var(self.name, py_object_type, self.pos)
......@@ -860,6 +874,7 @@ class NameNode(AtomicExprNode):
env.use_utility_code(type_cache_invalidation_code)
def analyse_types(self, env):
if self.entry is None:
self.entry = env.lookup(self.name)
if not self.entry:
self.entry = env.declare_builtin(self.name, self.pos)
......@@ -875,6 +890,12 @@ class NameNode(AtomicExprNode):
% self.name)
self.type = PyrexTypes.error_type
self.entry.used = 1
if self.entry.type.is_buffer:
# Have an rhs temp just in case. All rhs I could
# think of had a single symbol result_code but better
# safe than sorry. Feel free to change this.
import Buffer
Buffer.used_buffer_aux_vars(self.entry)
def analyse_rvalue_entry(self, env):
#print "NameNode.analyse_rvalue_entry:", self.name ###
......@@ -1018,12 +1039,23 @@ class NameNode(AtomicExprNode):
rhs.generate_disposal_code(code)
else:
if self.type.is_buffer:
# Generate code for doing the buffer release/acquisition.
# This might raise an exception in which case the assignment (done
# below) will not happen.
#
# The reason this is not in a typetest-like node is because the
# variables that the acquired buffer info is stored to is allocated
# per entry and coupled with it.
self.generate_acquire_buffer(rhs, code)
if self.type.is_pyobject:
rhs.make_owned_reference(code)
#print "NameNode.generate_assignment_code: to", self.name ###
#print "...from", rhs ###
#print "...LHS type", self.type, "ctype", self.ctype() ###
#print "...RHS type", rhs.type, "ctype", rhs.ctype() ###
rhs.make_owned_reference(code)
if not self.skip_assignment_decref:
if entry.is_local and not Options.init_local_none:
initalized = entry.scope.control_flow.get_state((entry.name, 'initalized'), self.pos)
if initalized is True:
......@@ -1038,6 +1070,19 @@ class NameNode(AtomicExprNode):
print("...generating post-assignment code for %s" % rhs)
rhs.generate_post_assignment_code(code)
def generate_acquire_buffer(self, rhs, code):
rhstmp = code.func.allocate_temp(self.entry.type)
buffer_aux = self.entry.buffer_aux
bufstruct = buffer_aux.buffer_info_var.cname
code.putln('%s = %s;' % (rhstmp, rhs.result_as(self.ctype())))
import Buffer
Buffer.put_assign_to_buffer(self.result_code, rhstmp, buffer_aux, self.entry.type,
is_initialized=not self.skip_assignment_decref,
pos=self.pos, code=code)
code.putln("%s = 0;" % rhstmp)
code.func.release_temp(rhstmp)
def generate_deletion_code(self, code):
if self.entry is None:
return # There was an error earlier
......@@ -1297,39 +1342,45 @@ class IndexNode(ExprNode):
self.analyse_base_and_index_types(env, setting = 1)
def analyse_base_and_index_types(self, env, getting = 0, setting = 0):
# Note: This might be cleaned up by having IndexNode
# parsed in a saner way and only construct the tuple if
# needed.
self.is_buffer_access = False
self.base.analyse_types(env)
skip_child_analysis = False
buffer_access = False
if self.base.type.buffer_options is not None:
if self.base.type.is_buffer:
assert isinstance(self.base, NameNode)
if isinstance(self.index, TupleNode):
indices = self.index.args
else:
indices = [self.index]
if len(indices) == self.base.type.buffer_options.ndim:
if len(indices) == self.base.type.ndim:
buffer_access = True
skip_child_analysis = True
for x in indices:
x.analyse_types(env)
if not x.type.is_int:
buffer_access = False
if buffer_access:
# self.indices = [
# x.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env)
# for x in indices]
self.indices = indices
self.index = None
self.type = self.base.type.buffer_options.dtype
self.is_temp = 1
self.type = self.base.type.dtype
self.is_buffer_access = True
# Note: This might be cleaned up by having IndexNode
# parsed in a saner way and only construct the tuple if
# needed.
if not buffer_access:
if getting:
# we only need a temp because result_code isn't refactored to
# generation time, but this seems an ok shortcut to take
self.is_temp = True
if setting:
if not self.base.entry.type.writable:
error(self.pos, "Writing to readonly buffer")
else:
self.base.entry.buffer_aux.writable_needed = True
else:
if isinstance(self.index, TupleNode):
self.index.analyse_types(env, skip_children=skip_child_analysis)
elif not skip_child_analysis:
......@@ -1374,7 +1425,7 @@ class IndexNode(ExprNode):
def calculate_result_code(self):
if self.is_buffer_access:
return "<not needed>"
return "<not used>"
else:
return "(%s[%s])" % (
self.base.result_code, self.index.result_code)
......@@ -1393,17 +1444,22 @@ class IndexNode(ExprNode):
if self.index is not None:
self.index.generate_evaluation_code(code)
else:
for i in self.indices: i.generate_evaluation_code(code)
for i in self.indices:
i.generate_evaluation_code(code)
def generate_subexpr_disposal_code(self, code):
self.base.generate_disposal_code(code)
if self.index is not None:
self.index.generate_disposal_code(code)
else:
for i in self.indices: i.generate_disposal_code(code)
for i in self.indices:
i.generate_disposal_code(code)
def generate_result_code(self, code):
if self.type.is_pyobject:
if self.is_buffer_access:
valuecode = self.buffer_access_code(code)
code.putln("%s = %s;" % (self.result_code, valuecode))
elif self.type.is_pyobject:
if self.index.type.is_int:
function = "__Pyx_GetItemInt"
index_code = self.index.result_code
......@@ -1439,7 +1495,10 @@ class IndexNode(ExprNode):
def generate_assignment_code(self, rhs, code):
self.generate_subexpr_evaluation_code(code)
if self.type.is_pyobject:
if self.is_buffer_access:
valuecode = self.buffer_access_code(code)
code.putln("%s = %s;" % (valuecode, rhs.result_code))
elif self.type.is_pyobject:
self.generate_setitem_code(rhs.py_result(), code)
else:
code.putln(
......@@ -1465,6 +1524,19 @@ class IndexNode(ExprNode):
code.error_goto(self.pos)))
self.generate_subexpr_disposal_code(code)
def buffer_access_code(self, code):
# Assign indices to temps
index_temps = [code.func.allocate_temp(i.type) for i in self.indices]
for temp, index in zip(index_temps, self.indices):
code.putln("%s = %s;" % (temp, index.result_code))
# Generate buffer access code using these temps
import Buffer
valuecode = Buffer.put_access(entry=self.base.entry,
index_signeds=[i.type.signed for i in self.indices],
index_cnames=index_temps,
pos=self.pos, code=code)
return valuecode
class SliceIndexNode(ExprNode):
# 2-element slice indexing
......
......@@ -47,6 +47,13 @@ class Context:
self.include_directories = include_directories
self.future_directives = set()
import os.path
standard_include_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), '..', 'Includes'))
self.include_directories = include_directories + [standard_include_path]
def find_module(self, module_name,
relative_to = None, pos = None, need_pxd = 1):
# Finds and returns the module scope corresponding to
......@@ -323,14 +330,18 @@ class Context:
verbose_flag = options.show_version,
cplus = options.cplus)
def nonfatal_error(self, exc):
return Errors.report_error(exc)
def run_pipeline(self, pipeline, source):
errors_occurred = False
data = source
try:
for phase in pipeline:
data = phase(data)
except CompileError:
except CompileError, err:
errors_occurred = True
Errors.report_error(err)
return (errors_occurred, data)
def create_parse(context):
......@@ -358,22 +369,25 @@ def create_default_pipeline(context, options, result):
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from Optimize import FlattenInListTransform, SwitchTransform
from Buffer import BufferTransform
from Optimize import FlattenInListTransform, SwitchTransform, OptimizeRefcounting
from Buffer import IntroduceBufferAuxiliaryVars
from ModuleNode import check_c_classes
def printit(x): print x.dump()
return [
create_parse(context),
# printit,
NormalizeTree(context),
PostParse(context),
FlattenInListTransform(),
WithTransform(context),
DecoratorTransform(context),
AnalyseDeclarationsTransform(context),
IntroduceBufferAuxiliaryVars(context),
check_c_classes,
AnalyseExpressionsTransform(context),
BufferTransform(context),
# BufferTransform(context),
SwitchTransform(),
OptimizeRefcounting(context),
# CreateClosureClasses(context),
create_generate_code(context, options, result)
]
......@@ -607,6 +621,7 @@ def main(command_line = 0):
else:
options = CompilationOptions(default_options)
sources = args
if options.show_version:
sys.stderr.write("Cython version %s\n" % Version.version)
if options.working_path!="":
......
......@@ -97,7 +97,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
h_extension_types = h_entries(env.c_class_entries)
if h_types or h_vars or h_funcs or h_extension_types:
result.h_file = replace_suffix(result.c_file, ".h")
h_code = Code.CCodeWriter(open_new_file(result.h_file))
h_code = Code.CCodeWriter()
if options.generate_pxi:
result.i_file = replace_suffix(result.c_file, ".pxi")
i_code = Code.PyrexCodeWriter(result.i_file)
......@@ -130,6 +130,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
h_code.putln("")
h_code.putln("#endif")
h_code.copyto(open_new_file(result.h_file))
def generate_public_declaration(self, entry, h_code, i_code):
h_code.putln("%s %s;" % (
Naming.extern_c_macro,
......@@ -156,7 +158,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
has_api_extension_types = 1
if api_funcs or has_api_extension_types:
result.api_file = replace_suffix(result.c_file, "_api.h")
h_code = Code.CCodeWriter(open_new_file(result.api_file))
h_code = Code.CCodeWriter()
name = self.api_name(env)
guard = Naming.api_guard_prefix + name
h_code.put_h_guard(guard)
......@@ -210,6 +212,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
h_code.putln("")
h_code.putln("#endif")
h_code.copyto(open_new_file(result.api_file))
def generate_cclass_header_code(self, type, h_code):
h_code.putln("%s DL_IMPORT(PyTypeObject) %s;" % (
Naming.extern_c_macro,
......@@ -232,12 +236,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def generate_c_code(self, env, options, result):
modules = self.referenced_modules
if Options.annotate or options.annotate:
code = Annotate.AnnotationCCodeWriter(StringIO())
code = Annotate.AnnotationCCodeWriter()
else:
code = Code.CCodeWriter(StringIO())
code.h = Code.CCodeWriter(StringIO())
code.init_labels()
self.generate_module_preamble(env, modules, code.h)
code = Code.CCodeWriter()
h_code = code.insertion_point()
self.generate_module_preamble(env, modules, h_code)
code.putln("")
code.putln("/* Implementation of %s */" % env.qualified_name)
......@@ -259,15 +262,13 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.mark_pos(None)
self.generate_module_cleanup_func(env, code)
self.generate_filename_table(code)
self.generate_utility_functions(env, code)
self.generate_buffer_compatability_functions(env, code)
self.generate_utility_functions(env, code, h_code)
self.generate_declarations_for_modules(env, modules, code.h)
self.generate_declarations_for_modules(env, modules, h_code)
h_code.write('\n')
f = open_new_file(result.c_file)
f.write(code.h.f.getvalue())
f.write("\n")
f.write(code.f.getvalue())
code.copyto(f)
f.close()
result.c_file_generated = 1
if Options.annotate or options.annotate:
......@@ -441,8 +442,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln(" #define PyBUF_ANY_CONTIGUOUS (0x0080 | PyBUF_STRIDES)")
code.putln(" #define PyBUF_INDIRECT (0x0100 | PyBUF_STRIDES)")
code.putln("")
code.putln(" static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags);")
code.putln(" static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view);")
code.putln("#endif")
code.put(builtin_module_name_utility_code[0])
......@@ -1485,6 +1484,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("0")
code.putln("};")
code.putln()
code.enter_cfunc_scope() # as we need labels
code.putln("static int %s(PyObject *o, PyObject* py_name, char *name) {" % Naming.import_star_set)
code.putln("char** type_name = %s_type_names;" % Naming.import_star)
code.putln("while (*type_name) {")
......@@ -1535,8 +1535,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("return -1;")
code.putln("}")
code.putln(import_star_utility_code)
code.exit_cfunc_scope() # done with labels
def generate_module_init_func(self, imported_modules, env, code):
code.enter_cfunc_scope()
code.putln("")
header2 = "PyMODINIT_FUNC init%s(void)" % env.module_name
header3 = "PyMODINIT_FUNC PyInit_%s(void)" % env.module_name
......@@ -1548,8 +1550,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln(header3)
code.putln("#endif")
code.putln("{")
code.put_var_declarations(env.temp_entries)
tempdecl_code = code.insertion_point()
code.putln("%s = PyTuple_New(0); %s" % (Naming.empty_tuple, code.error_goto_if_null(Naming.empty_tuple, self.pos)));
code.putln("/*--- Libary function declarations ---*/")
......@@ -1590,6 +1591,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("/*--- Execution code ---*/")
code.mark_pos(None)
self.body.generate_execution_code(code)
if Options.generate_cleanup_code:
......@@ -1610,6 +1612,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("#endif")
code.putln('}')
tempdecl_code.put_var_declarations(env.temp_entries)
tempdecl_code.put_temp_declarations(code.func)
code.exit_cfunc_scope()
def generate_module_cleanup_func(self, env, code):
if not Options.generate_cleanup_code:
return
......@@ -1951,7 +1958,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
"%s = &%s;" % (
type.typeptr_cname, type.typeobj_cname))
def generate_utility_functions(self, env, code):
def generate_utility_functions(self, env, code, h_code):
code.putln("")
code.putln("/* Runtime support code */")
code.putln("")
......@@ -1960,94 +1967,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
(Naming.filetable_cname, Naming.filenames_cname))
code.putln("}")
for utility_code in env.utility_code_used:
code.h.put(utility_code[0])
h_code.put(utility_code[0])
code.put(utility_code[1])
code.put(PyrexTypes.type_conversion_functions)
code.putln("")
def generate_buffer_compatability_functions(self, env, code):
# will be refactored
try:
env.entries[u'numpy']
code.put("""
static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
/* This function is always called after a type-check; safe to cast */
PyArrayObject *arr = (PyArrayObject*)obj;
PyArray_Descr *type = (PyArray_Descr*)arr->descr;
int typenum = PyArray_TYPE(obj);
if (!PyTypeNum_ISNUMBER(typenum)) {
PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
return -1;
}
/*
NumPy format codes doesn't completely match buffer codes;
seems safest to retranslate.
01234567890123456789012345*/
const char* base_codes = "?bBhHiIlLqQfdgfdgO";
char* format = (char*)malloc(4);
char* fp = format;
*fp++ = type->byteorder;
if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
*fp++ = base_codes[typenum];
*fp = 0;
view->buf = arr->data;
view->readonly = !PyArray_ISWRITEABLE(obj);
view->ndim = PyArray_NDIM(arr);
view->strides = PyArray_STRIDES(arr);
view->shape = PyArray_DIMS(arr);
view->suboffsets = NULL;
view->format = format;
view->itemsize = type->elsize;
view->internal = 0;
return 0;
}
static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
free((char*)view->format);
view->format = NULL;
}
""")
except KeyError:
pass
# For now, hard-code numpy imported as "numpy"
types = []
try:
ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
except KeyError:
pass
code.putln("#if PY_VERSION_HEX < 0x02060000")
code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
if len(types) > 0:
clause = "if"
for t, get, release in types:
code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
clause = "else if"
code.putln("else {")
code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
code.putln("return -1;")
if len(types) > 0: code.putln("}")
code.putln("}")
code.putln("")
code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
if len(types) > 0:
clause = "if"
for t, get, release in types:
code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release))
clause = "else if"
code.putln("}")
code.putln("")
code.putln("#endif")
#------------------------------------------------------------------------------------
#
# Runtime support code
......
......@@ -8,6 +8,9 @@
pyrex_prefix = "__pyx_"
codewriter_temp_prefix = pyrex_prefix + "t_"
temp_prefix = u"__cyt_"
builtin_prefix = pyrex_prefix + "builtin_"
......@@ -31,6 +34,10 @@ prop_set_prefix = pyrex_prefix + "setprop_"
type_prefix = pyrex_prefix + "t_"
typeobj_prefix = pyrex_prefix + "type_"
var_prefix = pyrex_prefix + "v_"
bufstruct_prefix = pyrex_prefix + "bstruct_"
bufstride_prefix = pyrex_prefix + "bstride_"
bufshape_prefix = pyrex_prefix + "bshape_"
bufsuboffset_prefix = pyrex_prefix + "boffset_"
vtable_prefix = pyrex_prefix + "vtable_"
vtabptr_prefix = pyrex_prefix + "vtabptr_"
vtabstruct_prefix = pyrex_prefix + "vtabstruct_"
......
......@@ -73,6 +73,7 @@ class Node(object):
is_name = 0
is_literal = 0
temps = None
# All descandants should set child_attrs to a list of the attributes
# containing nodes considered "children" in the tree. Each such attribute
......@@ -102,7 +103,7 @@ class Node(object):
for attrname in result.child_attrs:
value = getattr(result, attrname)
if isinstance(value, list):
setattr(result, attrname, value)
setattr(result, attrname, [x for x in value])
return result
......@@ -252,6 +253,11 @@ class StatListNode(Node):
child_attrs = ["stats"]
def create_analysed(pos, env, *args, **kw):
node = StatListNode(pos, *args, **kw)
return node # No node-specific analysis necesarry
create_analysed = staticmethod(create_analysed)
def analyse_control_flow(self, env):
for stat in self.stats:
stat.analyse_control_flow(env)
......@@ -344,12 +350,6 @@ class CDeclaratorNode(Node):
calling_convention = ""
def analyse_expressions(self, env):
pass
def generate_execution_code(self, env):
pass
class CNameDeclaratorNode(CDeclaratorNode):
# name string The Pyrex name being declared
......@@ -368,29 +368,6 @@ class CNameDeclaratorNode(CDeclaratorNode):
self.type = base_type
return self, base_type
def analyse_expressions(self, env):
self.entry = env.lookup(self.name)
if self.default is not None:
env.control_flow.set_state(self.default.end_pos(), (self.entry.name, 'initalized'), True)
env.control_flow.set_state(self.default.end_pos(), (self.entry.name, 'source'), 'assignment')
self.entry.used = 1
if self.type.is_pyobject:
self.entry.init_to_none = False
self.entry.init = 0
self.default.analyse_types(env)
self.default = self.default.coerce_to(self.type, env)
self.default.allocate_temps(env)
self.default.release_temp(env)
def generate_execution_code(self, code):
if self.default is not None:
self.default.generate_evaluation_code(code)
if self.type.is_pyobject:
self.default.make_owned_reference(code)
code.putln('%s = %s;' % (self.entry.cname, self.default.result_as(self.entry.type)))
self.default.generate_post_assignment_code(code)
code.putln()
class CPtrDeclaratorNode(CDeclaratorNode):
# base CDeclaratorNode
......@@ -403,12 +380,6 @@ class CPtrDeclaratorNode(CDeclaratorNode):
ptr_type = PyrexTypes.c_ptr_type(base_type)
return self.base.analyse(ptr_type, env, nonempty = nonempty)
def analyse_expressions(self, env):
self.base.analyse_expressions(env)
def generate_execution_code(self, env):
self.base.generate_execution_code(env)
class CArrayDeclaratorNode(CDeclaratorNode):
# base CDeclaratorNode
# dimension ExprNode
......@@ -629,8 +600,8 @@ class CBufferAccessTypeNode(Node):
def analyse(self, env):
base_type = self.base_type_node.analyse(env)
dtype = self.dtype_node.analyse(env)
options = PyrexTypes.BufferOptions(dtype=dtype, ndim=self.ndim)
self.type = PyrexTypes.create_buffer_type(base_type, options)
self.type = PyrexTypes.BufferType(base_type, dtype=dtype, ndim=self.ndim,
mode=self.mode)
return self.type
class CComplexBaseTypeNode(CBaseTypeNode):
......@@ -685,14 +656,6 @@ class CVarDefNode(StatNode):
dest_scope.declare_var(name, type, declarator.pos,
cname = cname, visibility = self.visibility, is_cdef = 1)
def analyse_expressions(self, env):
for declarator in self.declarators:
declarator.analyse_expressions(env)
def generate_execution_code(self, code):
for declarator in self.declarators:
declarator.generate_execution_code(code)
class CStructOrUnionDefNode(StatNode):
# name string
......@@ -866,9 +829,14 @@ class FuncDefNode(StatNode, BlockNode):
return lenv
def generate_function_definitions(self, env, code, transforms):
# Generate C code for header and body of function
code.init_labels()
import Buffer
lenv = self.local_scope
# Generate C code for header and body of function
code.enter_cfunc_scope()
code.return_from_error_cleanup_label = code.new_label()
# ----- Top-level constants used by this function
code.mark_pos(self.pos)
self.generate_interned_num_decls(lenv, code)
......@@ -899,7 +867,7 @@ class FuncDefNode(StatNode, BlockNode):
(self.return_type.declaration_code(
Naming.retval_cname),
init))
code.put_var_declarations(lenv.temp_entries)
tempvardecl_code = code.insertion_point()
self.generate_keyword_list(code)
# ----- Extern library function declarations
lenv.generate_library_function_declarations(code)
......@@ -914,6 +882,8 @@ class FuncDefNode(StatNode, BlockNode):
for entry in lenv.arg_entries:
if entry.type.is_pyobject and lenv.control_flow.get_state((entry.name, 'source')) != 'arg':
code.put_var_incref(entry)
if entry.type.is_buffer:
Buffer.put_acquire_arg_buffer(entry, code, self.pos)
# ----- Initialise local variables
for entry in lenv.var_entries:
if entry.type.is_pyobject and entry.init_to_none and entry.used:
......@@ -934,12 +904,23 @@ class FuncDefNode(StatNode, BlockNode):
val = self.return_type.default_value
if val:
code.putln("%s = %s;" % (Naming.retval_cname, val))
#code.putln("goto %s;" % code.return_label)
# ----- Error cleanup
if code.error_label in code.labels_used:
code.put_goto(code.return_label)
code.put_label(code.error_label)
code.put_var_xdecrefs(lenv.temp_entries)
# Clean up buffers -- this calls a Python function
# so need to save and restore error state
buffers_present = len(lenv.buffer_entries) > 0
if buffers_present:
code.putln("{ PyObject *__pyx_type, *__pyx_value, *__pyx_tb;")
code.putln("PyErr_Fetch(&__pyx_type, &__pyx_value, &__pyx_tb);")
for entry in lenv.buffer_entries:
code.putln("%s;" % Buffer.get_release_buffer_code(entry))
#code.putln("%s = 0;" % entry.cname)
code.putln("PyErr_Restore(__pyx_type, __pyx_value, __pyx_tb);}")
err_val = self.error_value()
exc_check = self.caller_will_check_exceptions()
if err_val is not None or exc_check:
......@@ -957,8 +938,18 @@ class FuncDefNode(StatNode, BlockNode):
"%s = %s;" % (
Naming.retval_cname,
err_val))
# ----- Return cleanup
if buffers_present:
# Else, non-error return will be an empty clause
code.put_goto(code.return_from_error_cleanup_label)
# ----- Non-error return cleanup
# PS! If adding something here, modify the conditions for the
# goto statement in error cleanup above
code.put_label(code.return_label)
for entry in lenv.buffer_entries:
code.putln("%s;" % Buffer.get_release_buffer_code(entry))
# ----- Return cleanup for both error and no-error return
code.put_label(code.return_from_error_cleanup_label)
if not Options.init_local_none:
for entry in lenv.var_entries:
if lenv.control_flow.get_state((entry.name, 'initalized')) is not True:
......@@ -976,7 +967,11 @@ class FuncDefNode(StatNode, BlockNode):
if not self.return_type.is_void:
code.putln("return %s;" % Naming.retval_cname)
code.putln("}")
# ----- Go back and insert temp variable declarations
tempvardecl_code.put_var_declarations(lenv.temp_entries)
tempvardecl_code.put_temp_declarations(code.func)
# ----- Python version
code.exit_cfunc_scope()
if self.py_func:
self.py_func.generate_function_definitions(env, code, transforms)
self.generate_optarg_wrapper_function(env, code)
......@@ -2231,8 +2226,10 @@ class SingleAssignmentNode(AssignmentNode):
#
# lhs ExprNode Left hand side
# rhs ExprNode Right hand side
# first bool Is this guaranteed the first assignment to lhs?
child_attrs = ["lhs", "rhs"]
first = False
def analyse_declarations(self, env):
self.lhs.analyse_target_declaration(env)
......@@ -3344,6 +3341,7 @@ class TryExceptStatNode(StatNode):
self.gil_check(env)
def analyse_expressions(self, env):
self.body.analyse_expressions(env)
self.cleanup_list = env.free_temp_entries[:]
for except_clause in self.except_clauses:
......@@ -3524,6 +3522,12 @@ class TryFinallyStatNode(StatNode):
# continue in the try block, since we have no problem
# handling it.
def create_analysed(pos, env, body, finally_clause):
node = TryFinallyStatNode(pos, body=body, finally_clause=finally_clause)
node.cleanup_list = []
return node
create_analysed = staticmethod(create_analysed)
def analyse_control_flow(self, env):
env.start_branching(self.pos)
self.body.analyse_control_flow(env)
......
......@@ -134,3 +134,16 @@ class FlattenInListTransform(Visitor.VisitorTransform):
def visit_Node(self, node):
self.visitchildren(node)
return node
class OptimizeRefcounting(Visitor.CythonTransform):
def visit_SingleAssignmentNode(self, node):
if node.first:
lhs = node.lhs
if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
# Have variable initialized to 0 rather than None
lhs.entry.init_to_none = False
lhs.entry.init = 0
# Set a flag in NameNode to skip the decref
lhs.skip_assignment_decref = True
return node
......@@ -82,7 +82,9 @@ ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_INT = '"%s" must be an integer'
ERR_BUF_NONNEG = '"%s" must be non-negative'
ERR_CDEF_INCLASS = 'Cannot assign default value to cdef class attributes'
ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables'
ERR_BUF_MODEHELP = 'Only allowed buffer modes are "full" or "strided" (as a compile-time string)'
class PostParse(CythonTransform):
"""
Basic interpretation of the parse tree, as well as validity
......@@ -91,6 +93,9 @@ class PostParse(CythonTransform):
as such).
Specifically:
- Default values to cdef assignments are turned into single
assignments following the declaration (everywhere but in class
bodies, where they raise a compile error)
- CBufferAccessTypeNode has its options interpreted:
Any first positional argument goes into the "dtype" attribute,
any "ndim" keyword argument goes into the "ndim" attribute and
......@@ -101,12 +106,65 @@ class PostParse(CythonTransform):
if a more pure Abstract Syntax Tree is wanted.
"""
buffer_options = ("dtype", "ndim") # ordered!
# Track our context.
scope_type = None # can be either of 'module', 'function', 'class'
def visit_ModuleNode(self, node):
self.scope_type = 'module'
self.visitchildren(node)
return node
def visit_ClassDefNode(self, node):
prev = self.scope_type
self.scope_type = 'class'
self.visitchildren(node)
self.scope_type = prev
return node
def visit_FuncDefNode(self, node):
prev = self.scope_type
self.scope_type = 'function'
self.visitchildren(node)
self.scope_type = prev
return node
# cdef variables
def visit_CVarDefNode(self, node):
# This assumes only plain names and pointers are assignable on
# declaration. Also, it makes use of the fact that a cdef decl
# must appear before the first use, so we don't have to deal with
# "i = 3; cdef int i = i" and can simply move the nodes around.
try:
self.visitchildren(node)
except PostParseError, e:
# An error in a cdef clause is ok, simply remove the declaration
# and try to move on to report more errors
self.context.nonfatal_error(e)
return None
stats = [node]
for decl in node.declarators:
while isinstance(decl, CPtrDeclaratorNode):
decl = decl.base
if isinstance(decl, CNameDeclaratorNode):
if decl.default is not None:
if self.scope_type == 'class':
raise PostParseError(decl.pos, ERR_CDEF_INCLASS)
stats.append(SingleAssignmentNode(node.pos,
lhs=NameNode(node.pos, name=decl.name),
rhs=decl.default, first=True))
decl.default = None
return stats
# buffer access
buffer_options = ("dtype", "ndim", "mode") # ordered!
def visit_CBufferAccessTypeNode(self, node):
if not self.scope_type == 'function':
raise PostParseError(node.pos, ERR_BUF_LOCALONLY)
options = {}
# Fetch positional arguments
if len(node.positional_args) > len(self.buffer_options):
self.context.error(ERR_BUF_TOO_MANY)
raise PostParseError(node.pos, ERR_BUF_TOO_MANY)
for arg, unicode_name in zip(node.positional_args, self.buffer_options):
name = str(unicode_name)
options[name] = arg
......@@ -114,21 +172,19 @@ class PostParse(CythonTransform):
for item in node.keyword_args.key_value_pairs:
name = str(item.key.value)
if not name in self.buffer_options:
raise PostParseError(item.key.pos,
ERR_BUF_UNKNOWN % name)
raise PostParseError(item.key.pos, ERR_BUF_OPTION_UNKNOWN % name)
if name in options.keys():
raise PostParseError(item.key.pos,
ERR_BUF_DUP % key)
raise PostParseError(item.key.pos, ERR_BUF_DUP % key)
options[name] = item.value
provided = options.keys()
# get dtype
dtype = options.get("dtype")
if dtype is None: raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype')
if dtype is None:
raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype')
node.dtype_node = dtype
# get ndim
if "ndim" in provided:
if "ndim" in options:
ndimnode = options["ndim"]
if not isinstance(ndimnode, IntNode):
# Compile-time values (DEF) are currently resolved by the parser,
......@@ -141,6 +197,17 @@ class PostParse(CythonTransform):
else:
node.ndim = 1
if "mode" in options:
modenode = options["mode"]
if not isinstance(modenode, StringNode):
raise PostParseError(modenode.pos, ERR_BUF_MODEHELP)
mode = modenode.value
if not mode in ('full', 'strided'):
raise PostParseError(modenode.pos, ERR_BUF_MODEHELP)
node.mode = mode
else:
node.mode = 'full'
# We're done with the parse tree args
node.positional_args = None
node.keyword_args = None
......@@ -253,6 +320,13 @@ class AnalyseDeclarationsTransform(CythonTransform):
self.env_stack.pop()
return node
# Some nodes are no longer needed after declaration
# analysis and can be dropped. The analysis was performed
# on these nodes in a seperate recursive process from the
# enclosing function or module, so we can simply drop them.
def visit_CVarDefNode(self, node):
return None
class AnalyseExpressionsTransform(CythonTransform):
def visit_ModuleNode(self, node):
node.body.analyse_expressions(node.scope)
......
......@@ -1620,9 +1620,7 @@ def p_c_simple_base_type(s, self_flag, nonempty):
# Treat trailing [] on type as buffer access
if 0: # s.sy == '[':
if is_basic:
s.error("Basic C types do not support buffer access")
if not is_basic and s.sy == '[':
return p_buffer_access(s, type_node)
else:
return type_node
......
......@@ -6,21 +6,6 @@ from Cython import Utils
import Naming
import copy
class BufferOptions:
# dtype PyrexType
# ndim int
def __init__(self, dtype, ndim):
self.dtype = dtype
self.ndim = ndim
def create_buffer_type(base_type, buffer_options):
# Make a shallow copy of base_type and then annotate it
# with the buffer information
result = copy.copy(base_type)
result.buffer_options = buffer_options
return result
class BaseType:
#
......@@ -57,6 +42,7 @@ class PyrexType(BaseType):
# is_unicode boolean Is a UTF-8 encoded C char * type
# is_returncode boolean Is used only to signal exceptions
# is_error boolean Is the dummy error type
# is_buffer boolean Is buffer access type
# has_attributes boolean Has C dot-selectable attributes
# default_value string Initial value
# parsetuple_format string Format char for PyArg_ParseTuple
......@@ -106,11 +92,11 @@ class PyrexType(BaseType):
is_unicode = 0
is_returncode = 0
is_error = 0
is_buffer = 0
has_attributes = 0
default_value = ""
parsetuple_format = ""
pymemberdef_typecode = None
buffer_options = None # can contain a BufferOptions instance
def resolve(self):
# If a typedef, returns the base type.
......@@ -202,6 +188,37 @@ class CTypedefType(BaseType):
def __getattr__(self, name):
return getattr(self.typedef_base_type, name)
class BufferType(BaseType):
#
# Delegates most attribute
# lookups to the base type. ANYTHING NOT DEFINED
# HERE IS DELEGATED!
# dtype PyrexType
# ndim int
# mode str
# is_buffer boolean
# writable boolean
is_buffer = 1
writable = True
def __init__(self, base, dtype, ndim, mode):
self.base = base
self.dtype = dtype
self.ndim = ndim
self.buffer_ptr_type = CPtrType(dtype)
self.mode = mode
def as_argument_type(self):
return self
def __getattr__(self, name):
return getattr(self.base, name)
def __repr__(self):
return "<BufferType %r>" % self.base
class PyObjectType(PyrexType):
#
# Base class for all Python object types (reference-counted).
......
......@@ -15,11 +15,14 @@ from TypeSlots import \
get_special_method_signature, get_property_accessor_signature
import ControlFlow
import __builtin__
from sets import Set as set
possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
class BufferAux:
writable_needed = False
def __init__(self, buffer_info_var, stridevars, shapevars, tschecker):
self.buffer_info_var = buffer_info_var
self.stridevars = stridevars
......@@ -216,6 +219,7 @@ class Scope:
self.num_to_entry = {}
self.obj_to_entry = {}
self.pystring_entries = []
self.buffer_entries = []
self.control_flow = ControlFlow.LinearControlFlow()
def start_branching(self, pos):
......@@ -616,8 +620,11 @@ class Scope:
return [entry for entry in self.temp_entries
if entry not in self.free_temp_entries]
def use_utility_code(self, new_code):
self.global_scope().use_utility_code(new_code)
def use_utility_code(self, new_code, name=None):
self.global_scope().use_utility_code(new_code, name)
def has_utility_code(self, name):
return self.global_scope().has_utility_code(name)
def generate_library_function_declarations(self, code):
# Generate extern decls for C library funcs used.
......@@ -738,6 +745,7 @@ class ModuleScope(Scope):
# doc_cname string C name of module doc string
# const_counter integer Counter for naming constants
# utility_code_used [string] Utility code to be included
# utility_code_names set(string) (Optional) names for named (often generated) utility code
# default_entries [Entry] Function argument default entries
# python_include_files [string] Standard Python headers to be included
# include_files [string] Other C headers to be included
......@@ -772,6 +780,7 @@ class ModuleScope(Scope):
self.doc_cname = Naming.moddoc_cname
self.const_counter = 1
self.utility_code_used = []
self.utility_code_names = set()
self.default_entries = []
self.module_entries = {}
self.python_include_files = ["Python.h", "structmember.h"]
......@@ -930,13 +939,25 @@ class ModuleScope(Scope):
self.const_counter = n + 1
return "%s%s%d" % (Naming.const_prefix, prefix, n)
def use_utility_code(self, new_code):
def use_utility_code(self, new_code, name=None):
# Add string to list of utility code to be included,
# if not already there (tested using 'is').
# if not already there (tested using the provided name,
# or 'is' if name=None -- if the utility code is dynamically
# generated, use the name, otherwise it is not needed).
if name is not None:
if name in self.utility_code_names:
return
for old_code in self.utility_code_used:
if old_code is new_code:
return
self.utility_code_used.append(new_code)
self.utility_code_names.add(name)
def has_utility_code(self, name):
# Checks if utility code (that is registered by name) has
# previously been registered. This is useful if the utility code
# is dynamically generated to avoid re-generation.
return name in self.utility_code_names
def declare_c_class(self, name, pos, defining = 0, implementing = 0,
module_name = None, base_type = None, objstruct_cname = None,
......
......@@ -47,21 +47,35 @@ class TestBufferParsing(CythonTest):
self.not_parseable("Non-keyword arg following keyword arg",
u"cdef object[foo=1, 2] x")
# See also tests/error/e_bufaccess.pyx and tets/run/bufaccess.pyx
class TestBufferOptions(CythonTest):
# Tests the full parsing of the options within the brackets
def parse_opts(self, opts):
s = u"cdef object[%s] x" % opts
root = self.fragment(s, pipeline=[PostParse(self)]).root
buftype = root.stats[0].base_type
def nonfatal_error(self, error):
# We're passing self as context to transform to trap this
self.error = error
self.assert_(self.expect_error)
def parse_opts(self, opts, expect_error=False):
s = u"def f():\n cdef object[%s] x" % opts
self.expect_error = expect_error
root = self.fragment(s, pipeline=[NormalizeTree(self), PostParse(self)]).root
if not expect_error:
vardef = root.stats[0].body.stats[0]
assert isinstance(vardef, CVarDefNode) # use normal assert as this is to validate the test code
buftype = vardef.base_type
self.assert_(isinstance(buftype, CBufferAccessTypeNode))
self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode))
self.assertEqual(u"object", buftype.base_type_node.name)
return buftype
else:
self.assert_(len(root.stats[0].body.stats) == 0)
def non_parse(self, expected_err, opts):
e = self.should_fail(lambda: self.parse_opts(opts))
self.assertEqual(expected_err, e.message_only)
self.parse_opts(opts, expect_error=True)
# e = self.should_fail(lambda: self.parse_opts(opts))
self.assertEqual(expected_err, self.error.message_only)
def test_basic(self):
buf = self.parse_opts(u"unsigned short int, 3")
......@@ -86,10 +100,12 @@ class TestBufferOptions(CythonTest):
def test_use_DEF(self):
t = self.fragment(u"""
DEF ndim = 3
def f():
cdef object[int, ndim] x
cdef object[ndim=ndim, dtype=int] y
""", pipeline=[PostParse(self)]).root
self.assert_(t.stats[1].base_type.ndim == 3)
self.assert_(t.stats[2].base_type.ndim == 3)
""", pipeline=[NormalizeTree(self), PostParse(self)]).root
stats = t.stats[0].body.stats
self.assert_(stats[0].base_type.ndim == 3)
self.assert_(stats[1].base_type.ndim == 3)
# add exotic and impossible combinations as they come along
# add exotic and impossible combinations as they come along...
......@@ -6,6 +6,8 @@ import re
from cStringIO import StringIO
from Scanning import PyrexScanner, StringSourceDescriptor
from Symtab import BuiltinScope, ModuleScope
import Symtab
import PyrexTypes
from Visitor import VisitorTransform, temp_name_handle
from Nodes import Node, StatListNode
from ExprNodes import NameNode
......@@ -108,11 +110,16 @@ class TemplateTransform(VisitorTransform):
self.substitutions = substitutions
tempdict = {}
for key in temps:
tempdict[key] = temp_name_handle(key)
self.temps = tempdict
tempdict[key] = temp_name_handle(key) # pending result_code refactor: Symtab.new_temp(PyrexTypes.py_object_type, key)
self.temp_key_to_entries = tempdict
self.pos = pos
return super(TemplateTransform, self).__call__(node)
def get_pos(self, node):
if self.pos:
return self.pos
else:
return node.pos
def visit_Node(self, node):
if node is None:
......@@ -135,11 +142,11 @@ class TemplateTransform(VisitorTransform):
def visit_NameNode(self, node):
tempname = self.temps.get(node.name)
if tempname is not None:
tempentry = self.temp_key_to_entries.get(node.name)
if tempentry is not None:
# Replace name with temporary
node.name = tempname
return self.visit_Node(node)
return NameNode(self.get_pos(node), name=tempentry)
# Pending result_code refactor: return NameNode(self.get_pos(node), entry=tempentry)
else:
return self.try_substitution(node, node.name)
......@@ -157,6 +164,7 @@ def copy_code_tree(node):
INDENT_RE = re.compile(ur"^ *")
def strip_common_indent(lines):
"Strips empty lines and common indentation from the list of strings given in lines"
# TODO: Facilitate textwrap.indent instead
lines = [x for x in lines if x.strip() != u""]
minindent = min([len(INDENT_RE.match(x).group(0)) for x in lines])
lines = [x[minindent:] for x in lines]
......
# Please see the Python header files (object.h) for docs
cdef extern from "Python.h":
ctypedef void PyObject
ctypedef struct bufferinfo:
void *buf
Py_ssize_t len
Py_ssize_t itemsize
int readonly
int ndim
char *format
Py_ssize_t *shape
Py_ssize_t *strides
Py_ssize_t *suboffsets
void *internal
ctypedef bufferinfo Py_buffer
cdef enum:
PyBUF_SIMPLE,
PyBUF_WRITABLE,
PyBUF_WRITEABLE, # backwards compatability
PyBUF_FORMAT,
PyBUF_ND,
PyBUF_STRIDES,
PyBUF_C_CONTIGUOUS,
PyBUF_F_CONTIGUOUS,
PyBUF_ANY_CONTIGUOUS,
PyBUF_INDIRECT,
PyBUF_CONTIG,
PyBUF_CONTIG_RO,
PyBUF_STRIDED,
PyBUF_STRIDED_RO,
PyBUF_RECORDS,
PyBUF_RECORDS_RO,
PyBUF_FULL,
PyBUF_FULL_RO,
PyBUF_READ,
PyBUF_WRITE,
PyBUF_SHADOW
int PyObject_CheckBuffer(PyObject* obj)
int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags)
void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view)
void* PyBuffer_GetPointer(Py_buffer *view, Py_ssize_t *indices)
int PyBuffer_SizeFromFormat(char *) # actually const char
int PyBuffer_ToContiguous(void *buf, Py_buffer *view, Py_ssize_t len, char fort)
int PyBuffer_FromContiguous(Py_buffer *view, void *buf, Py_ssize_t len, char fort)
int PyObject_CopyData(PyObject *dest, PyObject *src)
int PyBuffer_IsContiguous(Py_buffer *view, char fort)
void PyBuffer_FillContiguousStrides(int ndims,
Py_ssize_t *shape,
Py_ssize_t *strides,
int itemsize,
char fort)
int PyBuffer_FillInfo(Py_buffer *view, void *buf,
Py_ssize_t len, int readonly,
int flags)
PyObject* PyObject_Format(PyObject* obj,
PyObject *format_spec)
from cStringIO import StringIO
class StringIOTree(object):
"""
See module docs.
"""
def __init__(self, stream=None):
self.prepended_children = []
if stream is None: stream = StringIO()
self.stream = stream
def getvalue(self):
return ("".join([x.getvalue() for x in self.prepended_children]) +
self.stream.getvalue())
def copyto(self, target):
"""Potentially cheaper than getvalue as no string concatenation
needs to happen."""
for child in self.prepended_children:
child.copyto(target)
target.write(self.stream.getvalue())
def write(self, what):
self.stream.write(what)
def insertion_point(self):
# Save what we have written until now
# (would it be more efficient to check with len(self.stream.getvalue())?
# leaving it out for now)
self.prepended_children.append(StringIOTree(self.stream))
# Construct the new forked object to return
other = StringIOTree()
self.prepended_children.append(other)
self.stream = StringIO()
return other
__doc__ = r"""
Implements a buffer with insertion points. When you know you need to
"get back" to a place and write more later, simply call insertion_point()
at that spot and get a new StringIOTree object that is "left behind".
EXAMPLE:
>>> a = StringIOTree()
>>> a.write('first\n')
>>> b = a.insertion_point()
>>> a.write('third\n')
>>> b.write('second\n')
>>> print a.getvalue()
first
second
third
<BLANKLINE>
>>> c = b.insertion_point()
>>> d = c.insertion_point()
>>> d.write('alpha\n')
>>> b.write('gamma\n')
>>> c.write('beta\n')
>>> print b.getvalue()
second
alpha
beta
gamma
<BLANKLINE>
>>> out = StringIO()
>>> a.copyto(out)
>>> print out.getvalue()
first
second
alpha
beta
gamma
third
<BLANKLINE>
"""
if __name__ == "__main__":
import doctest
doctest.testmod()
......@@ -46,12 +46,13 @@ class ErrorWriter(object):
class TestBuilder(object):
def __init__(self, rootdir, workdir, selectors, annotate,
cleanup_workdir, with_pyregr):
cleanup_workdir, cleanup_sharedlibs, with_pyregr):
self.rootdir = rootdir
self.workdir = workdir
self.selectors = selectors
self.annotate = annotate
self.cleanup_workdir = cleanup_workdir
self.cleanup_sharedlibs = cleanup_sharedlibs
self.with_pyregr = with_pyregr
def build_suite(self):
......@@ -84,6 +85,7 @@ class TestBuilder(object):
for filename in filenames:
if not (filename.endswith(".pyx") or filename.endswith(".py")):
continue
if filename.startswith('.'): continue # certain emacs backup files
if context == 'pyregr' and not filename.startswith('test_'):
continue
module = os.path.splitext(filename)[0]
......@@ -99,25 +101,29 @@ class TestBuilder(object):
test = build_test(
path, workdir, module,
annotate=self.annotate,
cleanup_workdir=self.cleanup_workdir)
cleanup_workdir=self.cleanup_workdir,
cleanup_sharedlibs=self.cleanup_sharedlibs)
else:
test = CythonCompileTestCase(
path, workdir, module,
expect_errors=expect_errors,
annotate=self.annotate,
cleanup_workdir=self.cleanup_workdir)
cleanup_workdir=self.cleanup_workdir,
cleanup_sharedlibs=self.cleanup_sharedlibs)
suite.addTest(test)
return suite
class CythonCompileTestCase(unittest.TestCase):
def __init__(self, directory, workdir, module,
expect_errors=False, annotate=False, cleanup_workdir=True):
expect_errors=False, annotate=False, cleanup_workdir=True,
cleanup_sharedlibs=True):
self.directory = directory
self.workdir = workdir
self.module = module
self.expect_errors = expect_errors
self.annotate = annotate
self.cleanup_workdir = cleanup_workdir
self.cleanup_sharedlibs = cleanup_sharedlibs
unittest.TestCase.__init__(self)
def shortDescription(self):
......@@ -125,10 +131,13 @@ class CythonCompileTestCase(unittest.TestCase):
def tearDown(self):
cleanup_c_files = WITH_CYTHON and self.cleanup_workdir
cleanup_lib_files = self.cleanup_sharedlibs
if os.path.exists(self.workdir):
for rmfile in os.listdir(self.workdir):
if not cleanup_c_files and rmfile[-2:] in (".c", ".h"):
continue
if not cleanup_lib_files and rmfile.endswith(".so") or rmfile.endswith(".dll"):
continue
if self.annotate and rmfile.endswith(".html"):
continue
try:
......@@ -278,7 +287,7 @@ class CythonUnitTestCase(CythonCompileTestCase):
except Exception:
pass
def collect_unittests(path, suite, selectors):
def collect_unittests(path, module_prefix, suite, selectors):
def file_matches(filename):
return filename.startswith("Test") and filename.endswith(".py")
......@@ -304,7 +313,7 @@ def collect_unittests(path, suite, selectors):
for f in filenames:
if file_matches(f):
filepath = os.path.join(dirpath, f)[:-len(".py")]
modulename = filepath[len(path)+1:].replace(os.path.sep, '.')
modulename = module_prefix + filepath[len(path)+1:].replace(os.path.sep, '.')
if not [ 1 for match in selectors if match(modulename) ]:
continue
module = __import__(modulename)
......@@ -312,18 +321,50 @@ def collect_unittests(path, suite, selectors):
module = getattr(module, x)
suite.addTests([loader.loadTestsFromModule(module)])
def collect_doctests(path, module_prefix, suite, selectors):
def package_matches(dirname):
return dirname not in ("Mac", "Distutils", "Plex")
def file_matches(filename):
return (filename.endswith(".py") and not ('~' in filename
or '#' in filename or filename.startswith('.')))
import doctest, types
for dirpath, dirnames, filenames in os.walk(path):
parentname = os.path.split(dirpath)[-1]
if package_matches(parentname):
for f in filenames:
if file_matches(f):
if not f.endswith('.py'): continue
filepath = os.path.join(dirpath, f)[:-len(".py")]
modulename = module_prefix + filepath[len(path)+1:].replace(os.path.sep, '.')
if not [ 1 for match in selectors if match(modulename) ]:
continue
module = __import__(modulename)
for x in modulename.split('.')[1:]:
module = getattr(module, x)
if hasattr(module, "__doc__") or hasattr(module, "__test__"):
try:
suite.addTests(doctest.DocTestSuite(module))
except ValueError: # no tests
pass
if __name__ == '__main__':
from optparse import OptionParser
parser = OptionParser()
parser.add_option("--no-cleanup", dest="cleanup_workdir",
action="store_false", default=True,
help="do not delete the generated C files (allows passing --no-cython on next run)")
parser.add_option("--no-cleanup-sharedlibs", dest="cleanup_sharedlibs",
action="store_false", default=True,
help="do not delete the generated shared libary files (allows manual module experimentation)")
parser.add_option("--no-cython", dest="with_cython",
action="store_false", default=True,
help="do not run the Cython compiler, only the C compiler")
parser.add_option("--no-unit", dest="unittests",
action="store_false", default=True,
help="do not run the unit tests")
parser.add_option("--no-doctest", dest="doctests",
action="store_false", default=True,
help="do not run the doctests")
parser.add_option("--no-file", dest="filetests",
action="store_false", default=True,
help="do not run the file based tests")
......@@ -360,6 +401,8 @@ if __name__ == '__main__':
# RUN ALL TESTS!
ROOTDIR = os.path.join(os.getcwd(), os.path.dirname(sys.argv[0]), 'tests')
WORKDIR = os.path.join(os.getcwd(), 'BUILD')
UNITTEST_MODULE = "Cython"
UNITTEST_ROOT = os.path.join(os.getcwd(), UNITTEST_MODULE)
if WITH_CYTHON:
if os.path.exists(WORKDIR):
shutil.rmtree(WORKDIR, ignore_errors=True)
......@@ -382,13 +425,15 @@ if __name__ == '__main__':
test_suite = unittest.TestSuite()
if options.unittests:
collect_unittests(os.getcwd(), test_suite, selectors)
collect_unittests(UNITTEST_ROOT, UNITTEST_MODULE + ".", test_suite, selectors)
if options.doctests:
collect_doctests(UNITTEST_ROOT, UNITTEST_MODULE + ".", test_suite, selectors)
if options.filetests:
filetests = TestBuilder(ROOTDIR, WORKDIR, selectors,
options.annotate_source,
options.cleanup_workdir,
options.pyregr)
options.annotate_source, options.cleanup_workdir,
options.cleanup_sharedlibs, options.pyregr)
test_suite.addTests([filetests.build_suite()])
unittest.TextTestRunner(verbosity=options.verbosity).run(test_suite)
......
cdef extern:
cdef func(int[])
cdef object[int] buf
cdef class A:
cdef object[int] buf
def f():
cdef object[fakeoption=True] buf1
cdef object[int, -1] buf1b
cdef object[ndim=-1] buf2
cdef object[int, 'a'] buf3
cdef object[int,2,3,4,5,6] buf4
cdef object[int, 2, 'foo'] buf5
cdef object[int, 2, well] buf6
_ERRORS = u"""
1:11: Buffer types only allowed as function local variables
3:15: Buffer types only allowed as function local variables
6:27: "fakeoption" is not a buffer option
7:22: "ndim" must be non-negative
8:15: "dtype" missing
9:21: "ndim" must be an integer
10:15: Too many buffer options
11:24: Only allowed buffer modes are "full" or "strided" (as a compile-time string)
12:28: Only allowed buffer modes are "full" or "strided" (as a compile-time string)
"""
cdef class A:
cdef int value = 3
_ERRORS = u"""
2:13: Cannot assign default value to cdef class attributes
"""
# Tests the buffer access syntax functionality by constructing
# mock buffer objects.
#
# Note that the buffers are mock objects created for testing
# the buffer access behaviour -- for instance there is no flag
# checking in the buffer objects (why test our test case?), rather
# what we want to test is what is passed into the flags argument.
#
cimport stdlib
cimport python_buffer
# Add all test_X function docstrings as unit tests
__test__ = {}
setup_string = """
>>> A = IntMockBuffer("A", range(6))
>>> B = IntMockBuffer("B", range(6))
>>> C = IntMockBuffer("C", range(6), (2,3))
>>> E = ErrorBuffer("E")
"""
def testcase(func):
__test__[func.__name__] = setup_string + func.__doc__
return func
def testcas(a):
pass
#
# Buffer acquire and release tests
#
@testcase
def acquire_release(o1, o2):
"""
>>> acquire_release(A, B)
acquired A
released A
acquired B
released B
>>> acquire_release(None, None)
>>> acquire_release(None, B)
acquired B
released B
"""
cdef object[int] buf
buf = o1
buf = o2
@testcase
def acquire_raise(o):
"""
Apparently, doctest won't handle mixed exceptions and print
stats, so need to circumvent this.
>>> A.resetlog()
>>> acquire_raise(A)
Traceback (most recent call last):
...
Exception: on purpose
>>> A.printlog()
acquired A
released A
"""
cdef object[int] buf
buf = o
raise Exception("on purpose")
@testcase
def acquire_failure1():
"""
>>> acquire_failure1()
acquired working
0 3
0 3
released working
"""
cdef object[int] buf
buf = IntMockBuffer("working", range(4))
print buf[0], buf[3]
try:
buf = ErrorBuffer()
assert False
except Exception:
print buf[0], buf[3]
@testcase
def acquire_failure2():
"""
>>> acquire_failure2()
acquired working
0 3
0 3
released working
"""
cdef object[int] buf = IntMockBuffer("working", range(4))
print buf[0], buf[3]
try:
buf = ErrorBuffer()
assert False
except Exception:
print buf[0], buf[3]
@testcase
def acquire_failure3():
"""
>>> acquire_failure3()
acquired working
0 3
released working
acquired working
0 3
released working
"""
cdef object[int] buf
buf = IntMockBuffer("working", range(4))
print buf[0], buf[3]
try:
buf = 3
assert False
except Exception:
print buf[0], buf[3]
@testcase
def acquire_failure4():
"""
>>> acquire_failure4()
acquired working
0 3
released working
acquired working
0 3
released working
"""
cdef object[int] buf = IntMockBuffer("working", range(4))
print buf[0], buf[3]
try:
buf = 2
assert False
except Exception:
print buf[0], buf[3]
@testcase
def acquire_failure5():
"""
>>> acquire_failure5()
Traceback (most recent call last):
...
ValueError: Buffer acquisition failed on assignment; and then reacquiring the old buffer failed too!
"""
cdef object[int] buf
buf = IntMockBuffer("working", range(4))
buf.fail = True
buf = 3
@testcase
def acquire_nonbuffer1(first, second=None):
"""
>>> acquire_nonbuffer1(3)
Traceback (most recent call last):
...
TypeError: 'int' does not have the buffer interface
>>> acquire_nonbuffer1(type)
Traceback (most recent call last):
...
TypeError: 'type' does not have the buffer interface
>>> acquire_nonbuffer1(None, 2)
Traceback (most recent call last):
...
TypeError: 'int' does not have the buffer interface
"""
cdef object[int] buf
buf = first
buf = second
@testcase
def acquire_nonbuffer2():
"""
>>> acquire_nonbuffer2()
acquired working
0 3
released working
acquired working
0 3
released working
"""
cdef object[int] buf = IntMockBuffer("working", range(4))
print buf[0], buf[3]
try:
buf = ErrorBuffer
assert False
except Exception:
print buf[0], buf[3]
@testcase
def as_argument(object[int] bufarg, int n):
"""
>>> as_argument(A, 6)
acquired A
0 1 2 3 4 5
released A
"""
cdef int i
for i in range(n):
print bufarg[i],
print
@testcase
def as_argument_defval(object[int] bufarg=IntMockBuffer('default', range(6)), int n=6):
"""
>>> as_argument_defval()
acquired default
0 1 2 3 4 5
released default
>>> as_argument_defval(A, 6)
acquired A
0 1 2 3 4 5
released A
"""
cdef int i
for i in range(n):
print bufarg[i],
print
@testcase
def cdef_assignment(obj, n):
"""
>>> cdef_assignment(A, 6)
acquired A
0 1 2 3 4 5
released A
"""
cdef object[int] buf = obj
cdef int i
for i in range(n):
print buf[i],
print
@testcase
def forin_assignment(objs, int pick):
"""
>>> as_argument_defval()
acquired default
0 1 2 3 4 5
released default
>>> as_argument_defval(A, 6)
acquired A
0 1 2 3 4 5
released A
"""
cdef object[int] buf
for buf in objs:
print buf[pick]
@testcase
def cascaded_buffer_assignment(obj):
"""
>>> cascaded_buffer_assignment(A)
acquired A
acquired A
released A
released A
"""
cdef object[int] a, b
a = b = obj
@testcase
def tuple_buffer_assignment1(a, b):
"""
>>> tuple_buffer_assignment1(A, B)
acquired A
acquired B
released A
released B
"""
cdef object[int] x, y
x, y = a, b
@testcase
def tuple_buffer_assignment2(tup):
"""
>>> tuple_buffer_assignment2((A, B))
acquired A
acquired B
released A
released B
"""
cdef object[int] x, y
x, y = tup
@testcase
def explicitly_release_buffer():
"""
>>> explicitly_release_buffer()
acquired A
released A
After release
"""
cdef object[int] x = IntMockBuffer("A", range(10))
x = None
print "After release"
#
# Index bounds checking
#
@testcase
def get_int_2d(object[int, 2] buf, int i, int j):
"""
>>> get_int_2d(C, 1, 1)
acquired C
released C
4
Check negative indexing:
>>> get_int_2d(C, -1, 0)
acquired C
released C
3
>>> get_int_2d(C, -1, -2)
acquired C
released C
4
>>> get_int_2d(C, -2, -3)
acquired C
released C
0
Out-of-bounds errors:
>>> get_int_2d(C, 2, 0)
Traceback (most recent call last):
...
IndexError: Out of bounds on buffer access (axis 0)
>>> get_int_2d(C, 0, -4)
Traceback (most recent call last):
...
IndexError: Out of bounds on buffer access (axis 1)
"""
return buf[i, j]
@testcase
def get_int_2d_uintindex(object[int, 2] buf, unsigned int i, unsigned int j):
"""
Unsigned indexing:
>>> get_int_2d_uintindex(C, 0, 0)
acquired C
released C
0
>>> get_int_2d_uintindex(C, 1, 2)
acquired C
released C
5
"""
# This is most interesting with regards to the C code
# generated.
return buf[i, j]
@testcase
def set_int_2d(object[int, 2] buf, int i, int j, int value):
"""
Uses get_int_2d to read back the value afterwards. For pure
unit test, one should support reading in MockBuffer instead.
>>> set_int_2d(C, 1, 1, 10)
acquired C
released C
>>> get_int_2d(C, 1, 1)
acquired C
released C
10
Check negative indexing:
>>> set_int_2d(C, -1, 0, 3)
acquired C
released C
>>> get_int_2d(C, -1, 0)
acquired C
released C
3
>>> set_int_2d(C, -1, -2, 8)
acquired C
released C
>>> get_int_2d(C, -1, -2)
acquired C
released C
8
>>> set_int_2d(C, -2, -3, 9)
acquired C
released C
>>> get_int_2d(C, -2, -3)
acquired C
released C
9
Out-of-bounds errors:
>>> set_int_2d(C, 2, 0, 19)
Traceback (most recent call last):
...
IndexError: Out of bounds on buffer access (axis 0)
>>> set_int_2d(C, 0, -4, 19)
Traceback (most recent call last):
...
IndexError: Out of bounds on buffer access (axis 1)
"""
buf[i, j] = value
#
# Buffer type mismatch examples. Varying the type and access
# method simultaneously, the odds of an interaction is virtually
# zero.
#
@testcase
def fmtst1(buf):
"""
>>> fmtst1(IntMockBuffer("A", range(3)))
Traceback (most recent call last):
...
ValueError: Buffer datatype mismatch (rejecting on 'i')
"""
cdef object[float] a = buf
@testcase
def fmtst2(object[int] buf):
"""
>>> fmtst2(FloatMockBuffer("A", range(3)))
Traceback (most recent call last):
...
ValueError: Buffer datatype mismatch (rejecting on 'f')
"""
@testcase
def ndim1(object[int, 2] buf):
"""
>>> ndim1(IntMockBuffer("A", range(3)))
Traceback (most recent call last):
...
ValueError: Buffer has wrong number of dimensions (expected 2, got 1)
"""
#
# Test which flags are passed.
#
@testcase
def readonly(obj):
"""
>>> R = UnsignedShortMockBuffer("R", range(27), shape=(3, 3, 3))
>>> readonly(R)
acquired R
25
released R
>>> R.recieved_flags
['FORMAT', 'INDIRECT', 'ND', 'STRIDES']
"""
cdef object[unsigned short int, 3] buf = obj
print buf[2, 2, 1]
@testcase
def writable(obj):
"""
>>> R = UnsignedShortMockBuffer("R", range(27), shape=(3, 3, 3))
>>> writable(R)
acquired R
released R
>>> R.recieved_flags
['FORMAT', 'INDIRECT', 'ND', 'STRIDES', 'WRITABLE']
"""
cdef object[unsigned short int, 3] buf = obj
buf[2, 2, 1] = 23
@testcase
def strided(object[int, 1, 'strided'] buf):
"""
>>> A = IntMockBuffer("A", range(4))
>>> strided(A)
acquired A
released A
2
>>> A.recieved_flags
['FORMAT', 'ND', 'STRIDES']
"""
return buf[2]
#
# Coercions
#
@testcase
def coercions(object[unsigned char] uc):
"""
TODO
"""
print type(uc[0])
uc[0] = -1
print uc[0]
uc[0] = <int>3.14
print uc[0]
#
# Testing that accessing data using various types of buffer access
# all works.
#
@testcase
def printbuf_int_2d(o, shape):
"""
Strided:
>>> printbuf_int_2d(IntMockBuffer("A", range(6), (2,3)), (2,3))
acquired A
0 1 2
3 4 5
released A
>>> printbuf_int_2d(IntMockBuffer("A", range(100), (3,3), strides=(20,5)), (3,3))
acquired A
0 5 10
20 25 30
40 45 50
released A
Indirect:
>>> printbuf_int_2d(IntMockBuffer("A", [[1,2],[3,4]]), (2,2))
acquired A
1 2
3 4
released A
"""
# should make shape builtin
cdef object[int, 2] buf
buf = o
cdef int i, j
for i in range(shape[0]):
for j in range(shape[1]):
print buf[i, j],
print
@testcase
def printbuf_float(o, shape):
"""
>>> printbuf_float(FloatMockBuffer("F", [1.0, 1.25, 0.75, 1.0]), (4,))
acquired F
1.0 1.25 0.75 1.0
released F
"""
# should make shape builtin
cdef object[float] buf
buf = o
cdef int i, j
for i in range(shape[0]):
print buf[i],
print
available_flags = (
('FORMAT', python_buffer.PyBUF_FORMAT),
('INDIRECT', python_buffer.PyBUF_INDIRECT),
('ND', python_buffer.PyBUF_ND),
('STRIDES', python_buffer.PyBUF_STRIDES),
('WRITABLE', python_buffer.PyBUF_WRITABLE)
)
cimport stdio
cdef class MockBuffer:
cdef object format
cdef void* buffer
cdef int len, itemsize, ndim
cdef Py_ssize_t* strides
cdef Py_ssize_t* shape
cdef Py_ssize_t* suboffsets
cdef object label, log
cdef readonly object recieved_flags
cdef public object fail
def __init__(self, label, data, shape=None, strides=None, format=None):
self.label = label
self.log = ""
self.itemsize = self.get_itemsize()
if format is None: format = self.get_default_format()
if shape is None: shape = (len(data),)
if strides is None:
strides = []
cumprod = 1
rshape = list(shape)
rshape.reverse()
for s in rshape:
strides.append(cumprod)
cumprod *= s
strides.reverse()
strides = [x * self.itemsize for x in strides]
suboffsets = [-1] * len(shape)
datashape = [len(data)]
p = data
while True:
p = p[0]
if isinstance(p, list): datashape.append(len(p))
else: break
if len(datashape) > 1:
# indirect access
self.ndim = len(datashape)
shape = datashape
self.buffer = self.create_indirect_buffer(data, shape)
suboffsets = [0] * (self.ndim-1) + [-1]
strides = [sizeof(void*)] * (self.ndim-1) + [self.itemsize]
self.suboffsets = self.list_to_sizebuf(suboffsets)
# printf("%ld; %ld %ld %ld %ld %ld", i0, s0, o0, i1, s1, o1);
else:
# strided and/or simple access
self.buffer = self.create_buffer(data)
self.ndim = len(shape)
self.suboffsets = NULL
self.format = format
self.len = len(data) * self.itemsize
self.strides = self.list_to_sizebuf(strides)
self.shape = self.list_to_sizebuf(shape)
def __dealloc__(self):
stdlib.free(self.strides)
stdlib.free(self.shape)
if self.suboffsets != NULL:
stdlib.free(self.suboffsets)
# must recursively free indirect...
else:
stdlib.free(self.buffer)
cdef void* create_buffer(self, data):
cdef char* buf = <char*>stdlib.malloc(len(data) * self.itemsize)
cdef char* it = buf
for value in data:
self.write(it, value)
it += self.itemsize
return buf
cdef void* create_indirect_buffer(self, data, shape):
cdef void** buf
assert shape[0] == len(data)
if len(shape) == 1:
return self.create_buffer(data)
else:
shape = shape[1:]
buf = <void**>stdlib.malloc(len(data) * sizeof(void*))
for idx, subdata in enumerate(data):
buf[idx] = self.create_indirect_buffer(subdata, shape)
return buf
cdef Py_ssize_t* list_to_sizebuf(self, l):
cdef Py_ssize_t* buf = <Py_ssize_t*>stdlib.malloc(len(l) * sizeof(Py_ssize_t))
for i, x in enumerate(l):
buf[i] = x
return buf
def __getbuffer__(MockBuffer self, Py_buffer* buffer, int flags):
if self.fail:
raise ValueError("Failing on purpose")
if buffer is NULL:
print u"locking!"
return
self.recieved_flags = []
cdef int value
for name, value in available_flags:
if (value & flags) == value:
self.recieved_flags.append(name)
buffer.buf = self.buffer
buffer.len = self.len
buffer.readonly = 0
buffer.format = <char*>self.format
buffer.ndim = self.ndim
buffer.shape = self.shape
buffer.strides = self.strides
buffer.suboffsets = self.suboffsets
buffer.itemsize = self.itemsize
buffer.internal = NULL
msg = "acquired %s" % self.label
print msg
self.log += msg + "\n"
def __releasebuffer__(MockBuffer self, Py_buffer* buffer):
msg = "released %s" % self.label
print msg
self.log += msg + "\n"
def printlog(self):
print self.log,
def resetlog(self):
self.log = ""
cdef int write(self, char* buf, object value) except -1: raise Exception()
cdef get_itemsize(self):
print "ERROR, not subclassed", self.__class__
cdef get_default_format(self):
print "ERROR, not subclassed", self.__class__
cdef class FloatMockBuffer(MockBuffer):
cdef int write(self, char* buf, object value) except -1:
(<float*>buf)[0] = <float>value
return 0
cdef get_itemsize(self): return sizeof(float)
cdef get_default_format(self): return "=f"
cdef class IntMockBuffer(MockBuffer):
cdef int write(self, char* buf, object value) except -1:
(<int*>buf)[0] = <int>value
return 0
cdef get_itemsize(self): return sizeof(int)
cdef get_default_format(self): return "=i"
cdef class UnsignedShortMockBuffer(MockBuffer):
cdef int write(self, char* buf, object value) except -1:
(<unsigned short*>buf)[0] = <unsigned short>value
return 0
cdef get_itemsize(self): return sizeof(unsigned short)
cdef get_default_format(self): return "=H"
cdef class ErrorBuffer:
cdef object label
def __init__(self, label):
self.label = label
def __getbuffer__(MockBuffer self, Py_buffer* buffer, int flags):
raise Exception("acquiring %s" % self.label)
def __releasebuffer__(MockBuffer self, Py_buffer* buffer):
raise Exception("releasing %s" % self.label)
__doc__ = """
>>> test(1, 2)
4 1 2 2 0 7 8
"""
cdef int g = 7
def test(x, int y):
if True:
before = 0
cdef int a = 4, b = x, c = y, *p = &y
cdef object o = int(8)
print a, b, c, p[0], before, g, o
# Also test that pruning cdefs doesn't hurt
def empty():
cdef int i
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment