Commit b35bb6c0 authored by Robert Bradshaw's avatar Robert Bradshaw

Merge in Dag's work

parents bbb832bc eaa6d477
@@ -17,32 +17,34 @@ special_chars = [(u'<', u'\xF0', u'&lt;'),
class AnnotationCCodeWriter(CCodeWriter):

    def __init__(self, f):
        CCodeWriter.__init__(self, self)
        self.buffer = StringIO()
        self.real_f = f
        self.annotations = []
        self.last_pos = None
        self.code = {}

    def getvalue(self):
        return self.real_f.getvalue()

    def write(self, s):
        self.real_f.write(s)
        self.buffer.write(s)

    def mark_pos(self, pos):
#        if pos is not None:
#            CCodeWriter.mark_pos(self, pos)
#        return
        if self.last_pos:
            try:
                code = self.code[self.last_pos[1]]
            except KeyError:
                code = ""
            self.code[self.last_pos[1]] = code + self.buffer.getvalue()
            self.buffer = StringIO()
        self.last_pos = pos

class AnnotationCCodeWriter(CCodeWriter):

    def __init__(self, create_from=None, buffer=None):
        CCodeWriter.__init__(self, create_from, buffer)
        self.annotation_buffer = StringIO()
        if create_from is None:
            self.annotations = []
            self.last_pos = None
            self.code = {}
        else:
            # When creating an insertion point, keep references to the same database
            self.annotation_buffer = create_from.annotation_buffer
            self.annotations = create_from.annotations
            self.code = create_from.code

    def create_new(self, create_from, buffer):
        return AnnotationCCodeWriter(create_from, buffer)

    def write(self, s):
        CCodeWriter.write(self, s)
        self.annotation_buffer.write(s)

    def mark_pos(self, pos):
#        if pos is not None:
#            CCodeWriter.mark_pos(self, pos)
#        return
        if self.last_pos:
            code = self.code.get(self.last_pos[1], "")
            self.code[self.last_pos[1]] = code + self.annotation_buffer.getvalue()
            self.annotation_buffer = StringIO()
        self.last_pos = pos

    def annotate(self, pos, item):
...
@@ -7,84 +7,33 @@ from Cython.Utils import EncodedString
from Cython.Compiler.Errors import CompileError
import PyrexTypes
from sets import Set as set

class PureCFuncNode(Node):
    def __init__(self, pos, cname, type, c_code, visibility='private'):
        self.pos = pos
        self.cname = cname
        self.type = type
        self.c_code = c_code
        self.visibility = visibility

    def analyse_types(self, env):
        self.entry = env.declare_cfunction(
            "<pure c function:%s>" % self.cname,
            self.type, self.pos, cname=self.cname,
            defining=True, visibility=self.visibility)

    def generate_function_definitions(self, env, code, transforms):
        assert self.type.optional_arg_count == 0
        visibility = self.entry.visibility
        if visibility != 'private':
            storage_class = "%s " % Naming.extern_c_macro
        else:
            storage_class = "static "
        arg_decls = [arg.declaration_code() for arg in self.type.args]
        sig = self.type.return_type.declaration_code(
            self.type.function_header_code(self.cname, ", ".join(arg_decls)))
        code.putln("")
        code.putln("%s%s {" % (storage_class, sig))
        code.put(self.c_code)
        code.putln("}")

    def generate_execution_code(self, code):
        pass

tschecker_functype = PyrexTypes.CFuncType(
    PyrexTypes.c_char_ptr_type,
    [PyrexTypes.CFuncTypeArg(EncodedString("ts"), PyrexTypes.c_char_ptr_type,
                             (0, 0, None), cname="ts")],
    exception_value = "NULL"
)

tsprefix = "__Pyx_tsc"

class BufferTransform(CythonTransform):
    """
    Run after type analysis. Takes care of the buffer functionality.

    Expects to be run on the full module. If you need to process a fragment
    one should look into refactoring this transform.
    """
    # Abbreviations:
    # "ts" means typestring and/or typestring checking stuff
    scope = None

    #
    # Entry point
    #

    def __call__(self, node):
        assert isinstance(node, ModuleNode)
        try:
            cymod = self.context.modules[u'__cython__']
        except KeyError:
            # No buffer fun for this module
            return node
        self.bufstruct_type = cymod.entries[u'Py_buffer'].type
        self.tscheckers = {}
        self.ts_funcs = []
        self.ts_item_checkers = {}
        self.module_scope = node.scope
        self.module_pos = node.pos
        result = super(BufferTransform, self).__call__(node)
        # Register ts stuff
        if "endian.h" not in node.scope.include_files:
            node.scope.include_files.append("endian.h")
        result.body.stats += self.ts_funcs
        return result

import textwrap

def dedent(text, reindent=0):
    text = textwrap.dedent(text)
    if reindent > 0:
        indent = " " * reindent
        text = '\n'.join([indent + x for x in text.split('\n')])
    return text

class IntroduceBufferAuxiliaryVars(CythonTransform):

    #
    # Entry point
    #

    buffers_exists = False

    def __call__(self, node):
        assert isinstance(node, ModuleNode)
        self.max_ndim = 0
        result = super(IntroduceBufferAuxiliaryVars, self).__call__(node)
        if self.buffers_exists:
            if "endian.h" not in node.scope.include_files:
                node.scope.include_files.append("endian.h")
            use_py2_buffer_functions(node.scope)
            use_empty_bufstruct_code(node.scope, self.max_ndim)
            node.scope.use_utility_code(access_utility_code)

        return result
@@ -95,136 +44,60 @@ class BufferTransform(CythonTransform):
        # For all buffers, insert extra variables in the scope.
        # The variables are also accessible from the buffer_info
        # on the buffer entry

        bufvars = [(name, entry) for name, entry
                   in scope.entries.iteritems()
                   if entry.type.buffer_options is not None]

        for name, entry in bufvars:
            bufopts = entry.type.buffer_options
            # Get or make a type string checker
            tschecker = self.tschecker(bufopts.dtype)

            # Declare auxiliary vars
            bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name),
                                        self.bufstruct_type, node.pos)

            temp_var = scope.declare_var(temp_name_handle(u"%s_tmp" % name),
                                         entry.type, node.pos)

            stridevars = []
            shapevars = []
            for idx in range(bufopts.ndim):
                # stride
                varname = temp_name_handle(u"%s_%s%d" % (name, "stride", idx))
                var = scope.declare_var(varname, PyrexTypes.c_int_type, node.pos, is_cdef=True)
                stridevars.append(var)
                # shape
                varname = temp_name_handle(u"%s_%s%d" % (name, "shape", idx))
                var = scope.declare_var(varname, PyrexTypes.c_uint_type, node.pos, is_cdef=True)
                shapevars.append(var)
            entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars,
                                                shapevars, tschecker)
            entry.buffer_aux.temp_var = temp_var
        self.scope = scope

    # Notes: The cast to <char*> gets around Cython not supporting const types
    acquire_buffer_fragment = TreeFragment(u"""
        TMP = LHS
        if TMP is not None:
            __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
        TMP = RHS
        if TMP is not None:
            __cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
            TSCHECKER(<char*>BUFINFO.format)
            ASSIGN_AUX
        LHS = TMP
    """)

    fetch_strides = TreeFragment(u"""
        TARGET = BUFINFO.strides[IDX]
    """)

    fetch_shape = TreeFragment(u"""
        TARGET = BUFINFO.shape[IDX]
    """)

    def reacquire_buffer(self, node):
        bufaux = node.lhs.entry.buffer_aux
        auxass = []
        for idx, entry in enumerate(bufaux.stridevars):
            entry.used = True
            ass = self.fetch_strides.substitute({
                u"TARGET": NameNode(node.pos, name=entry.name),
                u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
                u"IDX": IntNode(node.pos, value=EncodedString(idx)),
            })
            auxass.append(ass)

        for idx, entry in enumerate(bufaux.shapevars):
            entry.used = True
            ass = self.fetch_shape.substitute({
                u"TARGET": NameNode(node.pos, name=entry.name),
                u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
                u"IDX": IntNode(node.pos, value=EncodedString(idx))
            })
            auxass.append(ass)

        bufaux.buffer_info_var.used = True
        acq = self.acquire_buffer_fragment.substitute({
            u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
            u"LHS" : node.lhs,
            u"RHS": node.rhs,
            u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
            u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name),
            u"TSCHECKER": NameNode(node.pos, name=bufaux.tschecker.name)
        }, pos=node.pos)
        # Note: The below should probably be refactored into something
        # like fragment.substitute(..., context=self.context), with
        # TreeFragment getting context.pipeline_until_now() and
        # applying it on the fragment.
        acq.analyse_declarations(self.scope)
        acq.analyse_expressions(self.scope)
        stats = acq.stats
        return stats

    def assign_into_buffer(self, node):
        result = SingleAssignmentNode(node.pos,
                                      rhs=self.visit(node.rhs),
                                      lhs=self.buffer_index(node.lhs))
        result.analyse_expressions(self.scope)
        return result

    def buffer_index(self, node):
        pos = node.pos
        bufaux = node.base.entry.buffer_aux
        assert bufaux is not None
        # indices * strides...
        to_sum = [ IntBinopNode(pos, operator='*',
                                operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index),
                                operand2=NameNode(node.pos, name=stride.name))
                   for index, stride in zip(node.indices, bufaux.stridevars)]

        # then sum them with the buffer pointer
        expr = AttributeNode(pos,
                             obj=NameNode(pos, name=bufaux.buffer_info_var.name),
                             attribute=EncodedString("buf"))
        for next in to_sum:
            expr = AddNode(pos, operator='+', operand1=expr, operand2=next)

        casted = TypecastNode(pos, operand=expr,
                              type=PyrexTypes.c_ptr_type(node.base.entry.type.buffer_options.dtype))
        result = IndexNode(pos, base=casted, index=IntNode(pos, value='0'))
        return result

    #
    # Transforms
    #

        bufvars = [entry for name, entry
                   in scope.entries.iteritems()
                   if entry.type.is_buffer]
        if len(bufvars) > 0:
            self.buffers_exists = True

        if isinstance(node, ModuleNode) and len(bufvars) > 0:
            # for now...note that pos is wrong
            raise CompileError(node.pos, "Buffer vars not allowed in module scope")

        for entry in bufvars:
            name = entry.name
            buftype = entry.type
            if buftype.ndim > self.max_ndim:
                self.max_ndim = buftype.ndim

            # Get or make a type string checker
            tschecker = buffer_type_checker(buftype.dtype, scope)

            # Declare auxiliary vars
            cname = scope.mangle(Naming.bufstruct_prefix, name)
            bufinfo = scope.declare_var(name="$%s" % cname, cname=cname,
                                        type=PyrexTypes.c_py_buffer_type, pos=node.pos)
            bufinfo.used = True

            def var(prefix, idx, initval):
                cname = scope.mangle(prefix, "%d_%s" % (idx, name))
                result = scope.declare_var("$%s" % cname, PyrexTypes.c_py_ssize_t_type,
                                           node.pos, cname=cname, is_cdef=True)

                result.init = initval
                if entry.is_arg:
                    result.used = True
                return result

            stridevars = [var(Naming.bufstride_prefix, i, "0") for i in range(entry.type.ndim)]
            shapevars = [var(Naming.bufshape_prefix, i, "0") for i in range(entry.type.ndim)]
            entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, tschecker)

            mode = entry.type.mode
            if mode == 'full':
                suboffsetvars = [var(Naming.bufsuboffset_prefix, i, "-1") for i in range(entry.type.ndim)]
                entry.buffer_aux.lookup = get_buf_lookup_full(scope, entry.type.ndim)
            elif mode == 'strided':
                suboffsetvars = None
                entry.buffer_aux.lookup = get_buf_lookup_strided(scope, entry.type.ndim)
            entry.buffer_aux.suboffsetvars = suboffsetvars
            entry.buffer_aux.get_buffer_cname = tschecker

        scope.buffer_entries = bufvars
        self.scope = scope

    def visit_ModuleNode(self, node):
        self.handle_scope(node, node.scope)
        self.visitchildren(node)
@@ -235,102 +108,541 @@ class BufferTransform(CythonTransform):
        self.visitchildren(node)
        return node

    def visit_SingleAssignmentNode(self, node):
        # On assignments, two buffer-related things can happen:
        # a) A buffer variable is assigned to (reacquisition)
        # b) Buffer access assignment: arr[...] = ...
        # Since we don't allow nested buffers, these don't overlap.
        self.visitchildren(node)
        # Only acquire buffers on vars (not attributes) for now.
        if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
            # Is buffer variable
            return self.reacquire_buffer(node)
        elif (isinstance(node.lhs, IndexNode) and
              isinstance(node.lhs.base, NameNode) and
              node.lhs.base.entry.buffer_aux is not None):
            return self.assign_into_buffer(node)
        else:
            return node

    def visit_IndexNode(self, node):
        # Only occurs when the IndexNode is an rvalue
        if node.is_buffer_access:
            assert node.index is None
            assert node.indices is not None
            result = self.buffer_index(node)
            result.analyse_expressions(self.scope)
            return result
        else:
            return node

def get_flags(buffer_aux, buffer_type):
    flags = 'PyBUF_FORMAT'
    if buffer_type.mode == 'full':
        flags += '| PyBUF_INDIRECT'
    elif buffer_type.mode == 'strided':
        flags += '| PyBUF_STRIDES'
    else:
        assert False
    if buffer_aux.writable_needed: flags += "| PyBUF_WRITABLE"
    return flags

def used_buffer_aux_vars(entry):
    buffer_aux = entry.buffer_aux
    buffer_aux.buffer_info_var.used = True
    for s in buffer_aux.shapevars: s.used = True
    for s in buffer_aux.stridevars: s.used = True
    for s in buffer_aux.suboffsetvars: s.used = True

def put_unpack_buffer_aux_into_scope(buffer_aux, mode, code):
    # Generate code to copy the needed struct info into local
    # variables.
bufstruct = buffer_aux.buffer_info_var.cname
varspec = [("strides", buffer_aux.stridevars),
("shape", buffer_aux.shapevars)]
if mode == 'full':
varspec.append(("suboffsets", buffer_aux.suboffsetvars))
for field, vars in varspec:
code.putln(" ".join(["%s = %s.%s[%d];" %
(s.cname, bufstruct, field, idx)
for idx, s in enumerate(vars)]))
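# Illustration (not from the module itself): for a hypothetical 2-d buffer whose
# Py_buffer struct cname is __pyx_bstruct_arr and whose stride temps are
# __pyx_bstride_0_arr / __pyx_bstride_1_arr, the join above emits a single line:
#
#   __pyx_bstride_0_arr = __pyx_bstruct_arr.strides[0]; __pyx_bstride_1_arr = __pyx_bstruct_arr.strides[1];
#
# The real cnames come from scope.mangle and the Naming prefixes; the ones shown
# here are made up for the example.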
def getbuffer_cond_code(obj_cname, buffer_aux, flags, ndim):
bufstruct = buffer_aux.buffer_info_var.cname
return "%s(%s, &%s, %s, %d) == -1" % (
buffer_aux.get_buffer_cname, obj_cname, bufstruct, flags, ndim)
def put_acquire_arg_buffer(entry, code, pos):
buffer_aux = entry.buffer_aux
cname = entry.cname
bufstruct = buffer_aux.buffer_info_var.cname
flags = get_flags(buffer_aux, entry.type)
# Acquire any new buffer
code.putln(code.error_goto_if(getbuffer_cond_code(cname,
buffer_aux,
flags,
entry.type.ndim),
pos))
    # An exception raised in arg parsing cannot be caught, so no
    # need to care about the buffer then.
put_unpack_buffer_aux_into_scope(buffer_aux, entry.type.mode, code)
#def put_release_buffer_normal(entry, code):
# code.putln("if (%s != Py_None) PyObject_ReleaseBuffer(%s, &%s);" % (
# entry.cname,
# entry.cname,
# entry.buffer_aux.buffer_info_var.cname))
def get_release_buffer_code(entry):
return "if (%s != Py_None) PyObject_ReleaseBuffer(%s, &%s)" % (
entry.cname,
entry.cname,
entry.buffer_aux.buffer_info_var.cname)
    #
    # Utils for creating type string checkers
    #
    def new_ts_func(self, name, code):
        cname = "%s_%s" % (tsprefix, name)
        funcnode = PureCFuncNode(self.module_pos, cname, tschecker_functype, code)
        funcnode.analyse_types(self.module_scope)
        self.ts_funcs.append(funcnode)
        return funcnode

    def mangle_dtype_name(self, dtype):
        # Use prefixes to separate user defined types from builtins
        # (consider "typedef float unsigned_int")
        return dtype.declaration_code("").replace(" ", "_")

def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
                         is_initialized, pos, code):
    """
    Generate code for reassigning a buffer variable. This only deals with getting
    the buffer auxiliary structure and variables set up correctly, the assignment
    itself and refcounting is the responsibility of the caller.

    However, the assignment operation may throw an exception so that the reassignment
    never happens.

    Depending on the circumstances there are two possible outcomes:
    - Old buffer released, new acquired, rhs assigned to lhs
    - Old buffer released, new acquired which fails, reacquire old lhs buffer
      (which may or may not succeed).
    """

    bufstruct = buffer_aux.buffer_info_var.cname
    flags = get_flags(buffer_aux, buffer_type)

    getbuffer = "%s(%%s, &%s, %s, %d)" % (buffer_aux.get_buffer_cname,
                                          # note: object is filled in later
                                          bufstruct,
                                          flags,
                                          buffer_type.ndim)
if is_initialized:
# Release any existing buffer
code.put('if (%s != Py_None) ' % lhs_cname)
code.begin_block();
code.putln('PyObject_ReleaseBuffer(%s, &%s);' % (
lhs_cname, bufstruct))
code.end_block()
# Acquire
retcode_cname = code.func.allocate_temp(PyrexTypes.c_int_type)
code.putln("%s = %s;" % (retcode_cname, getbuffer % rhs_cname))
code.putln('if (%s) ' % (code.unlikely("%s < 0" % retcode_cname)))
# If acquisition failed, attempt to reacquire the old buffer
# before raising the exception. A failure of reacquisition
# will cause the reacquisition exception to be reported, one
# can consider working around this later.
code.begin_block()
type, value, tb = [code.func.allocate_temp(PyrexTypes.py_object_type)
for i in range(3)]
code.putln('PyErr_Fetch(&%s, &%s, &%s);' % (type, value, tb))
code.put('if (%s) ' % code.unlikely("%s == -1" % (getbuffer % lhs_cname)))
code.begin_block()
code.putln('Py_XDECREF(%s); Py_XDECREF(%s); Py_XDECREF(%s);' % (type, value, tb))
code.putln('__Pyx_RaiseBufferFallbackError();')
code.putln('} else {')
code.putln('PyErr_Restore(%s, %s, %s);' % (type, value, tb))
for t in (type, value, tb):
code.func.release_temp(t)
code.end_block()
# Unpack indices
code.end_block()
put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
code.putln(code.error_goto_if_neg(retcode_cname, pos))
code.func.release_temp(retcode_cname)
else:
# Our entry had no previous value, so set to None when acquisition fails.
# In this case, auxiliary vars should be set up right in initialization to a zero-buffer,
# so it suffices to set the buf field to NULL.
code.putln('if (%s) {' % code.unlikely("%s == -1" % (getbuffer % rhs_cname)))
code.putln('%s = Py_None; Py_INCREF(Py_None); %s.buf = NULL;' % (lhs_cname, bufstruct))
code.putln(code.error_goto(pos))
code.put('} else {')
# Unpack indices
put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
code.putln('}')
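# Rough Python-level sketch (illustration only, never executed by the compiler)
# of the control flow put_assign_to_buffer() generates when the lhs is already
# initialized; acquire/release stand in for the generated C helpers.
def _buffer_reassignment_sketch(lhs_obj, rhs_obj, acquire, release):
    release(lhs_obj)                  # PyObject_ReleaseBuffer on the old buffer
    try:
        acquire(rhs_obj)              # __Pyx_GetBuffer_<dtype> on the new object
    except Exception:
        try:
            acquire(lhs_obj)          # best-effort fallback to the old buffer
        except Exception:
            raise RuntimeError("buffer fallback failed")  # __Pyx_RaiseBufferFallbackError
        raise                         # re-report the original acquisition error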
def put_access(entry, index_signeds, index_cnames, pos, code):
"""Returns a c string which can be used to access the buffer
for reading or writing.
As the bounds checking can have any number of combinations of unsigned
arguments, smart optimizations etc. we insert it directly in the function
    body. The lookup however is delegated to an inline function that is instantiated
once per ndim (lookup with suboffsets tend to get quite complicated).
"""
bufaux = entry.buffer_aux
bufstruct = bufaux.buffer_info_var.cname
# Check bounds and fix negative indices
boundscheck = True
nonegs = True
tmp_cname = code.func.allocate_temp(PyrexTypes.c_int_type)
if boundscheck:
code.putln("%s = -1;" % tmp_cname)
for idx, (signed, cname, shape) in enumerate(zip(index_signeds, index_cnames,
bufaux.shapevars)):
if signed != 0:
nonegs = False
# not unsigned, deal with negative index
code.putln("if (%s < 0) {" % cname)
code.putln("%s += %s;" % (cname, shape.cname))
if boundscheck:
code.putln("if (%s) %s = %d;" % (
code.unlikely("%s < 0" % cname), tmp_cname, idx))
code.put("} else ")
else:
if idx > 0: code.put("else ")
if boundscheck:
# check bounds in positive direction
code.putln("if (%s) %s = %d;" % (
code.unlikely("%s >= %s" % (cname, shape.cname)),
tmp_cname, idx))
if boundscheck:
code.put("if (%s) " % code.unlikely("%s != -1" % tmp_cname))
code.begin_block()
code.putln('__Pyx_BufferIndexError(%s);' % tmp_cname)
code.putln(code.error_goto(pos))
code.end_block()
code.func.release_temp(tmp_cname)
# Create buffer lookup and return it
params = []
if entry.type.mode == 'full':
for i, s, o in zip(index_cnames, bufaux.stridevars, bufaux.suboffsetvars):
params.append(i)
params.append(s.cname)
params.append(o.cname)
else:
for i, s in zip(index_cnames, bufaux.stridevars):
params.append(i)
params.append(s.cname)
ptrcode = "%s(%s.buf, %s)" % (bufaux.lookup, bufstruct,
", ".join(params))
valuecode = "*%s" % entry.type.buffer_ptr_type.cast_code(ptrcode)
return valuecode
def use_empty_bufstruct_code(env, max_ndim):
code = dedent("""
Py_ssize_t __Pyx_zeros[] = {%s};
Py_ssize_t __Pyx_minusones[] = {%s};
""") % (", ".join(["0"] * max_ndim), ", ".join(["-1"] * max_ndim))
env.use_utility_code([code, ""])
def get_buf_lookup_strided(env, nd):
"""
Generates and registers as utility a buffer lookup function for the right number
of dimensions. The function gives back a void* at the right location.
"""
name = "__Pyx_BufPtrStrided_%dd" % nd
if not env.has_utility_code(name):
# _i_ndex, _s_tride
args = ", ".join(["i%d, s%d" % (i, i) for i in range(nd)])
offset = " + ".join(["i%d * s%d" % (i, i) for i in range(nd)])
proto = dedent("""\
#define %s(buf, %s) ((char*)buf + %s)
""") % (name, args, offset)
env.use_utility_code([proto, ""], name=name)
    return name

    def get_ts_check_item(self, dtype):
        # See if we can consume one (unnamed) dtype as next item
        funcnode = self.ts_item_checkers.get(dtype)
        if funcnode is None:
            char = dtype.typestring
            if char is not None and len(char) > 1:

def get_buf_lookup_full(env, nd):
    """
    Generates and registers as utility a buffer lookup function for the right number
of dimensions. The function gives back a void* at the right location.
"""
name = "__Pyx_BufPtrFull_%dd" % nd
if not env.has_utility_code(name):
# _i_ndex, _s_tride, sub_o_ffset
args = ", ".join(["Py_ssize_t i%d, Py_ssize_t s%d, Py_ssize_t o%d" % (i, i, i) for i in range(nd)])
proto = dedent("""\
static INLINE void* %s(void* buf, %s);
""") % (name, args)
func = dedent("""
static INLINE void* %s(void* buf, %s) {
char* ptr = (char*)buf;
""") % (name, args) + "".join([dedent("""\
ptr += s%d * i%d;
if (o%d >= 0) ptr = *((char**)ptr) + o%d;
""") % (i, i, i, i) for i in range(nd)]
) + "\nreturn ptr;\n}"
env.use_utility_code([proto, func], name=name)
return name
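# Illustration (not part of the module): the two lookup helpers are pure address
# arithmetic. Rebuilding the strided macro exactly as in get_buf_lookup_strided
# for nd == 2 gives:
#
#   #define __Pyx_BufPtrStrided_2d(buf, i0, s0, i1, s1) ((char*)buf + i0 * s0 + i1 * s1)
#
# while the "full" variant above additionally follows a suboffset indirection
# per dimension when o<i> >= 0.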
#
# Utils for creating type string checkers
#
def mangle_dtype_name(dtype):
    # Use prefixes to separate user defined types from builtins
# (consider "typedef float unsigned_int")
if dtype.typestring is None:
prefix = "nn_"
else:
prefix = ""
return prefix + dtype.declaration_code("").replace(" ", "_")
def get_ts_check_item(dtype, env):
# See if we can consume one (unnamed) dtype as next item
    # Put native types and structs in separate namespaces (as one could create a struct named unsigned_int...)
name = "__Pyx_BufferTypestringCheck_item_%s" % mangle_dtype_name(dtype)
if not env.has_utility_code(name):
char = dtype.typestring
if char is not None:
                # Can use direct comparison
                funcnode = self.new_ts_func("natitem_%s" % self.mangle_dtype_name(dtype), """\
  if (*ts != '%s') {
    PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch (rejecting on '%%s')", ts);
    return NULL;
  } else return ts + 1;
""" % char)
            else:
                # Must deduce sign and length; rely on int vs. float to be correctly declared
                ctype = dtype.declaration_code("")
                code = """\
  int ok;
  switch (*ts) {"""
                if dtype.is_int:
                    types = [
                        ('b', 'char'), ('h', 'short'), ('i', 'int'),
                        ('l', 'long'), ('q', 'long long')
                    ]
                    code += "".join(["""\
    case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;
    case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;
""" % (char, ctype, against, ctype, char.upper(), ctype, "unsigned " + against, ctype) for
                                     char, against in types])
                code += """\
  default: ok = 0;

            # Can use direct comparison
            code = dedent("""\
                if (*ts != '%s') {
                  PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (rejecting on '%%s')", ts);
                  return NULL;
                } else return ts + 1;
            """, 2) % char
        else:
            # Cannot trust declared size; but rely on int vs float and
            # signed/unsigned to be correctly declared
            ctype = dtype.declaration_code("")
            code = dedent("""\
                int ok;
                switch (*ts) {""", 2)
            if dtype.is_int:
                types = [
                    ('b', 'char'), ('h', 'short'), ('i', 'int'),
                    ('l', 'long'), ('q', 'long long')
                ]
                if dtype.signed == 0:
                    code += "".join(["\n case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;" %
                                     (char.upper(), ctype, against, ctype) for char, against in types])
                else:
                    code += "".join(["\n case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" %
                                     (char, ctype, against, ctype) for char, against in types])
            code += dedent("""\
                default: ok = 0;
                }
                if (!ok) {
                  PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (rejecting on '%s')", ts);
                  return NULL;
                } else return ts + 1;
            """, 2)
env.use_utility_code([dedent("""\
static const char* %s(const char* ts); /*proto*/
""") % name, dedent("""
static const char* %s(const char* ts) {
%s
}
""") % (name, code)], name=name)
return name
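# Illustration, assuming a plain C int whose buffer typestring is 'i': the
# utility registered above then reduces to roughly
#
#   static const char* __Pyx_BufferTypestringCheck_item_int(const char* ts) {
#     if (*ts != 'i') {
#       PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (rejecting on '%s')", ts);
#       return NULL;
#     } else return ts + 1;
#   }
#
# Typedef'd integer types take the switch(*ts) branch instead and accept any
# format code whose size and signedness match the declared C type.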
def get_getbuffer_code(dtype, env):
"""
Generate a utility function for getting a buffer for the given dtype.
The function will:
- Call PyObject_GetBuffer
    - Check that ndim matches the expected value
- Check that the format string is right
- Set suboffsets to all -1 if it is returned as NULL.
"""
name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype)
if not env.has_utility_code(name):
env.use_utility_code(acquire_utility_code)
itemchecker = get_ts_check_item(dtype, env)
utilcode = [dedent("""
static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd); /*proto*/
""") % name, dedent("""
static int %(name)s(PyObject* obj, Py_buffer* buf, int flags, int nd) {
const char* ts;
if (obj == Py_None) {
__Pyx_ZeroBuffer(buf);
return 0;
}
buf->buf = NULL;
if (PyObject_GetBuffer(obj, buf, flags) == -1) goto fail;
if (buf->ndim != nd) {
__Pyx_BufferNdimError(buf, nd);
goto fail;
}
ts = buf->format;
ts = __Pyx_ConsumeWhitespace(ts);
ts = __Pyx_BufferTypestringCheckEndian(ts);
if (!ts) goto fail;
ts = __Pyx_ConsumeWhitespace(ts);
ts = %(itemchecker)s(ts);
if (!ts) goto fail;
ts = __Pyx_ConsumeWhitespace(ts);
if (*ts != 0) {
PyErr_Format(PyExc_ValueError,
"Expected non-struct buffer data type (rejecting on '%%s')", ts);
goto fail;
}
if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones;
return 0;
fail:;
__Pyx_ZeroBuffer(buf);
return -1;
}""") % locals()]
env.use_utility_code(utilcode, name)
return name
def buffer_type_checker(dtype, env):
# Creates a type checker function for the given type.
if dtype.is_struct_or_union:
assert False
elif dtype.is_int or dtype.is_float:
# This includes simple typedef-ed types
funcname = get_getbuffer_code(dtype, env)
else:
assert False
return funcname
def use_py2_buffer_functions(env):
# will be refactored
try:
env.entries[u'numpy']
env.use_utility_code(["","""
static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
/* This function is always called after a type-check; safe to cast */
PyArrayObject *arr = (PyArrayObject*)obj;
PyArray_Descr *type = (PyArray_Descr*)arr->descr;
int typenum = PyArray_TYPE(obj);
if (!PyTypeNum_ISNUMBER(typenum)) {
PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
return -1;
  }
  if (!ok) {
    PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch (rejecting on '%s')", ts);
    return NULL;
  } else return ts + 1;
"""
                funcnode = self.new_ts_func("tdefitem_%s" % self.mangle_dtype_name(dtype), code)

            self.ts_item_checkers[dtype] = funcnode
        return funcnode.entry.cname

    ts_consume_whitespace_cname = None
    ts_check_endian_cname = None

    def ensure_ts_utils(self):
        # Makes sure that the typechecker utils are in scope
        # (and constructs them if not)
        if self.ts_consume_whitespace_cname is None:
            self.ts_consume_whitespace_cname = self.new_ts_func("consume_whitespace", """\

  }

  /*
  NumPy format codes don't completely match buffer codes;
  seems safest to retranslate.
                            01234567890123456789012345*/
  const char* base_codes = "?bBhHiIlLqQfdgfdgO";

  char* format = (char*)malloc(4);
  char* fp = format;
  *fp++ = type->byteorder;
  if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
  *fp++ = base_codes[typenum];
  *fp = 0;

  view->buf = arr->data;
  view->readonly = !PyArray_ISWRITEABLE(obj);
  view->ndim = PyArray_NDIM(arr);
  view->strides = PyArray_STRIDES(arr);
  view->shape = PyArray_DIMS(arr);
  view->suboffsets = NULL;
view->format = format;
view->itemsize = type->elsize;
view->internal = 0;
return 0;
}
static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
free((char*)view->format);
view->format = NULL;
}
"""])
except KeyError:
pass
codename = "PyObject_GetBuffer" # just a representative unique key
# Search all types for __getbuffer__ overloads
types = []
def find_buffer_types(scope):
for m in scope.cimported_modules:
find_buffer_types(m)
for e in scope.type_entries:
t = e.type
if t.is_extension_type:
release = get = None
for x in t.scope.pyfunc_entries:
if x.name == u"__getbuffer__": get = x.func_cname
elif x.name == u"__releasebuffer__": release = x.func_cname
if get:
types.append((t.typeptr_cname, get, release))
find_buffer_types(env)
# For now, hard-code numpy imported as "numpy"
try:
ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
except KeyError:
pass
code = dedent("""
#if PY_VERSION_HEX < 0x02060000
static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {
""")
if len(types) > 0:
clause = "if"
for t, get, release in types:
code += " %s (PyObject_TypeCheck(obj, %s)) return %s(obj, view, flags);\n" % (clause, t, get)
clause = "else if"
code += " else {\n"
code += dedent("""\
PyErr_Format(PyExc_TypeError, "'%100s' does not have the buffer interface", Py_TYPE(obj)->tp_name);
return -1;
""", 2)
if len(types) > 0: code += " }"
code += dedent("""
}
static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {
""")
if len(types) > 0:
clause = "if"
for t, get, release in types:
if release:
code += "%s (PyObject_TypeCheck(obj, %s)) %s(obj, view);" % (clause, t, release)
clause = "else if"
code += dedent("""
}
#endif
""")
env.use_utility_code([dedent("""\
#if PY_VERSION_HEX < 0x02060000
static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags);
static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view);
#endif
""") ,code], codename)
#
# Static utility code
#
# Utility function to set the right exception
# The caller should immediately goto_error
access_utility_code = [
"""\
static void __Pyx_BufferIndexError(int axis); /*proto*/
""","""\
static void __Pyx_BufferIndexError(int axis) {
PyErr_Format(PyExc_IndexError,
"Out of bounds on buffer access (axis %d)", axis);
}
"""]
#
# Buffer type checking. Utility code for checking that acquired
# buffers match our assumptions. We only need to check ndim and
# the format string; the access mode/flags is checked by the
# exporter.
#
acquire_utility_code = ["""\
static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf); /*proto*/
static INLINE const char* __Pyx_ConsumeWhitespace(const char* ts); /*proto*/
static INLINE const char* __Pyx_BufferTypestringCheckEndian(const char* ts); /*proto*/
static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim); /*proto*/
static void __Pyx_RaiseBufferFallbackError(void); /*proto*/
""", """
static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf) {
buf->buf = NULL;
buf->strides = __Pyx_zeros;
buf->shape = __Pyx_zeros;
buf->suboffsets = __Pyx_minusones;
}
static INLINE const char* __Pyx_ConsumeWhitespace(const char* ts) {
while (1) { while (1) {
switch (*ts) { switch (*ts) {
case 10: case 10:
...@@ -341,9 +653,9 @@ class BufferTransform(CythonTransform): ...@@ -341,9 +653,9 @@ class BufferTransform(CythonTransform):
return ts; return ts;
} }
} }
""").entry.cname }
if self.ts_check_endian_cname is None:
self.ts_check_endian_cname = self.new_ts_func("check_endian", """\ static INLINE const char* __Pyx_BufferTypestringCheckEndian(const char* ts) {
int ok = 1; int ok = 1;
switch (*ts) { switch (*ts) {
case '@': case '@':
...@@ -360,55 +672,21 @@ class BufferTransform(CythonTransform): ...@@ -360,55 +672,21 @@ class BufferTransform(CythonTransform):
break; break;
} }
if (!ok) { if (!ok) {
PyErr_Format(PyExc_TypeError, "Data has wrong endianness (rejecting on '%s')", ts); PyErr_Format(PyExc_ValueError, "Buffer has wrong endianness (rejecting on '%s')", ts);
return NULL; return NULL;
} }
return ts; return ts;
""").entry.cname }
def create_ts_check_simple(self, dtype): static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) {
# Check whole string for single unnamed item PyErr_Format(PyExc_ValueError,
consume_whitespace = self.ts_consume_whitespace_cname "Buffer has wrong number of dimensions (expected %d, got %d)",
check_endian = self.ts_check_endian_cname expected_ndim, buffer->ndim);
check_item = self.get_ts_check_item(dtype) }
return self.new_ts_func("simple_%s" % self.mangle_dtype_name(dtype), """\
ts = %(consume_whitespace)s(ts); static void __Pyx_RaiseBufferFallbackError(void) {
ts = %(check_endian)s(ts); PyErr_Format(PyExc_ValueError,
if (!ts) return NULL; "Buffer acquisition failed on assignment; and then reacquiring the old buffer failed too!");
ts = %(consume_whitespace)s(ts); }
ts = %(check_item)s(ts);
if (!ts) return NULL;
ts = %(consume_whitespace)s(ts);
if (*ts != 0) {
PyErr_Format(PyExc_TypeError, "Data too long (rejecting on '%%s')", ts);
return NULL;
}
return ts;
""" % locals())
def tschecker(self, dtype):
# Creates a type string checker function for the given type.
# Each checker is created as a function entry in the module scope
# and a PureCNode and put in the self.ts_checkers dict.
# Also the entry is returned.
#
# TODO: __eq__ and __hash__ for types
self.ensure_ts_utils()
funcnode = self.tscheckers.get(dtype)
if funcnode is None:
if dtype.is_struct_or_union:
assert False
elif dtype.is_int or dtype.is_float:
# This includes simple typedef-ed types
funcnode = self.create_ts_check_simple(dtype)
else:
assert False
self.tscheckers[dtype] = funcnode
return funcnode.entry
# TODO:
# - buf must be NULL before getting new buffer
"""]
@@ -9,9 +9,142 @@ from Cython.Utils import open_new_file, open_source_file
from PyrexTypes import py_object_type, typecast from PyrexTypes import py_object_type, typecast
from TypeSlots import method_coexist from TypeSlots import method_coexist
from Scanning import SourceDescriptor from Scanning import SourceDescriptor
from Cython.StringIOTree import StringIOTree
from sets import Set as set
class CCodeWriter:
class FunctionContext(object):
# Not used for now, perhaps later
def __init__(self, names_taken=set()):
self.names_taken = names_taken
self.error_label = None
self.label_counter = 0
self.labels_used = {}
self.return_label = self.new_label()
self.new_error_label()
self.continue_label = None
self.break_label = None
self.temps_allocated = [] # of (name, type)
self.temps_free = {} # type -> list of free vars
self.temps_used_type = {} # name -> type
self.temp_counter = 0
def new_label(self):
n = self.label_counter
self.label_counter = n + 1
return "%s%d" % (Naming.label_prefix, n)
def new_error_label(self):
old_err_lbl = self.error_label
self.error_label = self.new_label()
return old_err_lbl
def get_loop_labels(self):
return (
self.continue_label,
self.break_label)
def set_loop_labels(self, labels):
(self.continue_label,
self.break_label) = labels
def new_loop_labels(self):
old_labels = self.get_loop_labels()
self.set_loop_labels(
(self.new_label(),
self.new_label()))
return old_labels
def get_all_labels(self):
return (
self.continue_label,
self.break_label,
self.return_label,
self.error_label)
def set_all_labels(self, labels):
(self.continue_label,
self.break_label,
self.return_label,
self.error_label) = labels
def all_new_labels(self):
old_labels = self.get_all_labels()
new_labels = []
for old_label in old_labels:
if old_label:
new_labels.append(self.new_label())
else:
new_labels.append(old_label)
self.set_all_labels(new_labels)
return old_labels
def use_label(self, lbl):
self.labels_used[lbl] = 1
def label_used(self, lbl):
return lbl in self.labels_used
def allocate_temp(self, type):
"""
Allocates a temporary (which may create a new one or get a previously
allocated and released one of the same type). Type is simply registered
and handed back, but will usually be a PyrexType.
A C string referring to the variable is returned.
"""
freelist = self.temps_free.get(type)
if freelist is not None and len(freelist) > 0:
result = freelist.pop()
else:
while True:
self.temp_counter += 1
result = "%s%d" % (Naming.codewriter_temp_prefix, self.temp_counter)
if not result in self.names_taken: break
self.temps_allocated.append((result, type))
self.temps_used_type[result] = type
return result
def release_temp(self, name):
"""
Releases a temporary so that it can be reused by other code needing
a temp of the same type.
"""
type = self.temps_used_type[name]
freelist = self.temps_free.get(type)
if freelist is None:
freelist = []
self.temps_free[type] = freelist
freelist.append(name)
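# Usage sketch for the temp bookkeeping above (illustrative; actual temp names
# depend on Naming.codewriter_temp_prefix):
#
#   func = FunctionContext()
#   t1 = func.allocate_temp(PyrexTypes.c_int_type)
#   t2 = func.allocate_temp(PyrexTypes.c_int_type)   # a second, distinct name
#   func.release_temp(t1)
#   t3 = func.allocate_temp(PyrexTypes.c_int_type)   # reuses t1's name from the freelist
#   assert t3 == t1 and t2 != t1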
def funccontext_property(name):
def get(self):
return getattr(self.func, name)
def set(self, value):
setattr(self.func, name, value)
return property(get, set)
class CCodeWriter(object):
"""
Utility class to output C code.
When creating an insertion point one must care about the state that is
kept:
- formatting state (level, bol) is cloned and used in insertion points
as well
- labels, temps, exc_vars: One must construct a scope in which these can
exist by calling enter_cfunc_scope/exit_cfunc_scope (these are for
      sanity checking and forward compatibility). Created insertion points
      lose this scope and cannot access it.
- marker: Not copied to insertion point
- filename_table, filename_list, input_file_contents: All codewriters
coming from the same root share the same instances simultaneously.
"""
    # f                   file            output file
    # buffer              StringIOTree
    # level               int             indentation level
    # bol                 bool            beginning of line?
    # marker              string          comment to emit before next line
@@ -19,6 +152,7 @@ class CCodeWriter:
    # error_label         string          error catch point label
    # continue_label      string          loop continue point label
    # break_label         string          loop break point label
    # return_from_error_cleanup_label string
    # label_counter       integer         counter for naming labels
    # in_try_finally      boolean         inside try of try...finally
    # filename_table      {string : int}  for finding filename table indexes
@@ -26,37 +160,95 @@ class CCodeWriter:
    # exc_vars            (string * 3)    exception variables for reraise, or None
    # input_file_contents dict            contents (=list of lines) of any file that was used as input
    #                                     to create this output C code.  This is
    #                                     used to annotate the comments.
    # func                FunctionContext contains labels and temps context info

    in_try_finally = 0
    def __init__(self, f):
        #self.f = open_new_file(outfile_name)
        self.f = f
        self._write = f.write
        self.level = 0
        self.bol = 1
        self.marker = None
        self.last_marker_line = 0
        self.label_counter = 1
        self.error_label = None
        self.filename_table = {}
        self.filename_list = []
        self.exc_vars = None
        self.input_file_contents = {}

    def __init__(self, create_from=None, buffer=None):
        if buffer is None: buffer = StringIOTree()
        self.buffer = buffer
        self.marker = None
        self.last_marker_line = 0
        self.func = None
        if create_from is None:
            # Root CCodeWriter
            self.level = 0
            self.bol = 1
            self.filename_table = {}
self.filename_list = []
self.exc_vars = None
self.input_file_contents = {}
else:
# Clone formatting state
c = create_from
self.level = c.level
self.bol = c.bol
# Note: NOT copying but sharing instance
self.filename_table = c.filename_table
self.filename_list = []
self.input_file_contents = c.input_file_contents
# Leave other state alone
def create_new(self, create_from, buffer):
# polymorphic constructor -- very slightly more versatile
# than using __class__
return CCodeWriter(create_from, buffer)
def copyto(self, f):
self.buffer.copyto(f)
def getvalue(self):
return self.buffer.getvalue()
def write(self, s):
self.buffer.write(s)
def insertion_point(self):
other = self.create_new(create_from=self, buffer=self.buffer.insertion_point())
return other
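# Sketch of the insertion-point workflow described in the class docstring
# (illustrative; relies on StringIOTree splicing output at the point where the
# insertion point was taken):
#
#   w = CCodeWriter()
#   w.putln("/* declarations: */")
#   decls = w.insertion_point()
#   w.putln("/* function bodies */")
#   decls.putln("/* emitted later, but appears before the bodies */")
#   # w.getvalue() now contains the declaration text between the two markers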
# Properties delegated to function scope
label_counter = funccontext_property("label_counter")
return_label = funccontext_property("return_label")
error_label = funccontext_property("error_label")
labels_used = funccontext_property("labels_used")
continue_label = funccontext_property("continue_label")
break_label = funccontext_property("break_label")
# Functions delegated to function scope
def new_label(self): return self.func.new_label()
def new_error_label(self): return self.func.new_error_label()
def get_loop_labels(self): return self.func.get_loop_labels()
def set_loop_labels(self, labels): return self.func.set_loop_labels(labels)
def new_loop_labels(self): return self.func.new_loop_labels()
def get_all_labels(self): return self.func.get_all_labels()
def set_all_labels(self, labels): return self.func.set_all_labels(labels)
def all_new_labels(self): return self.func.all_new_labels()
def use_label(self, lbl): return self.func.use_label(lbl)
def label_used(self, lbl): return self.func.label_used(lbl)
def enter_cfunc_scope(self):
self.func = FunctionContext()
def exit_cfunc_scope(self):
self.func = None
    def putln(self, code = ""):
        if self.marker and self.bol:
            self.emit_marker()
        if code:
            self.put(code)
        self._write("\n");
        self.bol = 1

    def emit_marker(self):
        self._write("\n");
        self.indent()
        self._write("/* %s */\n" % self.marker[1])
        self.last_marker_line = self.marker[0]
        self.marker = None

    def putln(self, code = ""):
        if self.marker and self.bol:
            self.emit_marker()
        if code:
            self.put(code)
        self.write("\n");
        self.bol = 1

    def emit_marker(self):
        self.write("\n");
        self.indent()
        self.write("/* %s */\n" % self.marker[1])
        self.last_marker_line = self.marker[0]
        self.marker = None
...@@ -68,7 +260,7 @@ class CCodeWriter: ...@@ -68,7 +260,7 @@ class CCodeWriter:
self.level -= 1 self.level -= 1
if self.bol: if self.bol:
self.indent() self.indent()
        self._write(code)
        self.write(code)
self.bol = 0 self.bol = 0
if dl > 0: if dl > 0:
self.level += dl self.level += dl
...@@ -90,7 +282,7 @@ class CCodeWriter: ...@@ -90,7 +282,7 @@ class CCodeWriter:
self.putln("}") self.putln("}")
def indent(self): def indent(self):
self._write(" " * self.level) self.write(" " * self.level)
def get_py_version_hex(self, pyversion): def get_py_version_hex(self, pyversion):
return "0x%02X%02X%02X%02X" % (tuple(pyversion) + (0,0,0,0))[:4] return "0x%02X%02X%02X%02X" % (tuple(pyversion) + (0,0,0,0))[:4]
...@@ -123,76 +315,13 @@ class CCodeWriter: ...@@ -123,76 +315,13 @@ class CCodeWriter:
source_desc.get_escaped_description(), line, u'\n'.join(lines)) source_desc.get_escaped_description(), line, u'\n'.join(lines))
self.marker = (line, marker) self.marker = (line, marker)
def init_labels(self):
self.label_counter = 0
self.labels_used = {}
self.return_label = self.new_label()
self.new_error_label()
self.continue_label = None
self.break_label = None
def new_label(self):
n = self.label_counter
self.label_counter = n + 1
return "%s%d" % (Naming.label_prefix, n)
def new_error_label(self):
old_err_lbl = self.error_label
self.error_label = self.new_label()
return old_err_lbl
def get_loop_labels(self):
return (
self.continue_label,
self.break_label)
def set_loop_labels(self, labels):
(self.continue_label,
self.break_label) = labels
def new_loop_labels(self):
old_labels = self.get_loop_labels()
self.set_loop_labels(
(self.new_label(),
self.new_label()))
return old_labels
def get_all_labels(self):
return (
self.continue_label,
self.break_label,
self.return_label,
self.error_label)
def set_all_labels(self, labels):
(self.continue_label,
self.break_label,
self.return_label,
self.error_label) = labels
def all_new_labels(self):
old_labels = self.get_all_labels()
new_labels = []
for old_label in old_labels:
if old_label:
new_labels.append(self.new_label())
else:
new_labels.append(old_label)
self.set_all_labels(new_labels)
return old_labels
def use_label(self, lbl):
self.labels_used[lbl] = 1
def label_used(self, lbl):
return lbl in self.labels_used
    def put_label(self, lbl):
        if lbl in self.labels_used:
            self.putln("%s:;" % lbl)

    def put_goto(self, lbl):
        self.use_label(lbl)
        self.putln("goto %s;" % lbl)

    def put_label(self, lbl):
        if lbl in self.func.labels_used:
            self.putln("%s:;" % lbl)

    def put_goto(self, lbl):
        self.func.use_label(lbl)
        self.putln("goto %s;" % lbl)

    def put_var_declarations(self, entries, static = 0, dll_linkage = None,
...@@ -211,7 +340,7 @@ class CCodeWriter: ...@@ -211,7 +340,7 @@ class CCodeWriter:
#print "...private and not definition, skipping" ### #print "...private and not definition, skipping" ###
return return
if not entry.used and visibility == "private": if not entry.used and visibility == "private":
#print "not used and private, skipping" ### #print "not used and private, skipping", entry.cname ###
return return
storage_class = "" storage_class = ""
if visibility == 'extern': if visibility == 'extern':
...@@ -231,7 +360,15 @@ class CCodeWriter: ...@@ -231,7 +360,15 @@ class CCodeWriter:
if entry.init is not None: if entry.init is not None:
self.put(" = %s" % entry.type.literal_code(entry.init)) self.put(" = %s" % entry.type.literal_code(entry.init))
self.putln(";") self.putln(";")
def put_temp_declarations(self, func_context):
for name, type in func_context.temps_allocated:
decl = type.declaration_code(name)
if type.is_pyobject:
self.putln("%s = NULL;" % decl)
else:
self.putln("%s;" % decl)
def entry_as_pyobject(self, entry): def entry_as_pyobject(self, entry):
type = entry.type type = entry.type
if (not entry.is_self_arg and not entry.type.is_complete()) \ if (not entry.is_self_arg and not entry.type.is_complete()) \
...@@ -337,9 +474,15 @@ class CCodeWriter: ...@@ -337,9 +474,15 @@ class CCodeWriter:
self.putln("#ifndef %s" % guard) self.putln("#ifndef %s" % guard)
self.putln("#define %s" % guard) self.putln("#define %s" % guard)
def unlikely(self, cond):
if Options.gcc_branch_hints:
return 'unlikely(%s)' % cond
else:
return cond
def error_goto(self, pos): def error_goto(self, pos):
lbl = self.error_label lbl = self.func.error_label
self.use_label(lbl) self.func.use_label(lbl)
if Options.c_line_in_traceback: if Options.c_line_in_traceback:
cinfo = " %s = %s;" % (Naming.clineno_cname, Naming.line_c_macro) cinfo = " %s = %s;" % (Naming.clineno_cname, Naming.line_c_macro)
else: else:
...@@ -352,12 +495,9 @@ class CCodeWriter: ...@@ -352,12 +495,9 @@ class CCodeWriter:
pos[1], pos[1],
cinfo, cinfo,
lbl) lbl)
    def error_goto_if(self, cond, pos):
        if Options.gcc_branch_hints:
            return "if (unlikely(%s)) %s" % (cond, self.error_goto(pos))
        else:
            return "if (%s) %s" % (cond, self.error_goto(pos))

    def error_goto_if(self, cond, pos):
        return "if (%s) %s" % (self.unlikely(cond), self.error_goto(pos))
def error_goto_if_null(self, cname, pos): def error_goto_if_null(self, cname, pos):
return self.error_goto_if("!%s" % cname, pos) return self.error_goto_if("!%s" % cname, pos)
......
@@ -33,6 +33,7 @@ class CompileError(PyrexError):
def __init__(self, position = None, message = ""): def __init__(self, position = None, message = ""):
self.position = position self.position = position
self.message_only = message self.message_only = message
self.reported = False
# Deprecated and withdrawn in 2.6: # Deprecated and withdrawn in 2.6:
# self.message = message # self.message = message
if position: if position:
...@@ -88,17 +89,23 @@ def close_listing_file(): ...@@ -88,17 +89,23 @@ def close_listing_file():
listing_file.close() listing_file.close()
listing_file = None listing_file = None
def error(position, message):
    #print "Errors.error:", repr(position), repr(message) ###
    global num_errors
    err = CompileError(position, message)
    # if position is not None: raise Exception(err) # debug
    line = "%s\n" % err
    if listing_file:
        listing_file.write(line)
    if echo_file:
        echo_file.write(line)
    num_errors = num_errors + 1

def report_error(err):
    global num_errors
    # See Main.py for why dual reporting occurs. Quick fix for now.
    if err.reported: return
    err.reported = True
    line = "%s\n" % err
    if listing_file:
        listing_file.write(line)
    if echo_file:
        echo_file.write(line)
    num_errors = num_errors + 1

def error(position, message):
#print "Errors.error:", repr(position), repr(message) ###
err = CompileError(position, message)
# if position is not None: raise Exception(err) # debug
report_error(err)
    return err
LEVEL=1 # warn about all errors level 1 or higher LEVEL=1 # warn about all errors level 1 or higher
......
@@ -806,6 +806,15 @@ class NameNode(AtomicExprNode):
# interned_cname string # interned_cname string
is_name = 1 is_name = 1
skip_assignment_decref = False
entry = None
def create_analysed_rvalue(pos, env, entry):
node = NameNode(pos)
node.analyse_types(env, entry=entry)
return node
create_analysed_rvalue = staticmethod(create_analysed_rvalue)
def compile_time_value(self, denv): def compile_time_value(self, denv):
try: try:
...@@ -834,7 +843,9 @@ class NameNode(AtomicExprNode): ...@@ -834,7 +843,9 @@ class NameNode(AtomicExprNode):
def analyse_as_module(self, env): def analyse_as_module(self, env):
# Try to interpret this as a reference to a cimported module. # Try to interpret this as a reference to a cimported module.
# Returns the module scope, or None. # Returns the module scope, or None.
        entry = env.lookup(self.name)
        entry = self.entry
        if not entry:
            entry = env.lookup(self.name)
if entry and entry.as_module: if entry and entry.as_module:
return entry.as_module return entry.as_module
return None return None
...@@ -842,14 +853,17 @@ class NameNode(AtomicExprNode): ...@@ -842,14 +853,17 @@ class NameNode(AtomicExprNode):
def analyse_as_extension_type(self, env): def analyse_as_extension_type(self, env):
# Try to interpret this as a reference to an extension type. # Try to interpret this as a reference to an extension type.
# Returns the extension type, or None. # Returns the extension type, or None.
        entry = env.lookup(self.name)
        entry = self.entry
        if not entry:
            entry = env.lookup(self.name)
if entry and entry.is_type and entry.type.is_extension_type: if entry and entry.is_type and entry.type.is_extension_type:
return entry.type return entry.type
else: else:
return None return None
def analyse_target_declaration(self, env): def analyse_target_declaration(self, env):
        self.entry = env.lookup_here(self.name)
        if not self.entry:
            self.entry = env.lookup_here(self.name)
if not self.entry: if not self.entry:
self.entry = env.declare_var(self.name, py_object_type, self.pos) self.entry = env.declare_var(self.name, py_object_type, self.pos)
env.control_flow.set_state(self.pos, (self.name, 'initalized'), True) env.control_flow.set_state(self.pos, (self.name, 'initalized'), True)
...@@ -860,7 +874,8 @@ class NameNode(AtomicExprNode): ...@@ -860,7 +874,8 @@ class NameNode(AtomicExprNode):
env.use_utility_code(type_cache_invalidation_code) env.use_utility_code(type_cache_invalidation_code)
    def analyse_types(self, env):
        self.entry = env.lookup(self.name)
        if self.entry is None:
            self.entry = env.lookup(self.name)
if not self.entry: if not self.entry:
self.entry = env.declare_builtin(self.name, self.pos) self.entry = env.declare_builtin(self.name, self.pos)
if not self.entry: if not self.entry:
...@@ -875,7 +890,13 @@ class NameNode(AtomicExprNode): ...@@ -875,7 +890,13 @@ class NameNode(AtomicExprNode):
% self.name) % self.name)
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
self.entry.used = 1 self.entry.used = 1
if self.entry.type.is_buffer:
# Have an rhs temp just in case. All rhs I could
# think of had a single symbol result_code but better
# safe than sorry. Feel free to change this.
import Buffer
Buffer.used_buffer_aux_vars(self.entry)
def analyse_rvalue_entry(self, env): def analyse_rvalue_entry(self, env):
#print "NameNode.analyse_rvalue_entry:", self.name ### #print "NameNode.analyse_rvalue_entry:", self.name ###
#print "Entry:", self.entry.__dict__ ### #print "Entry:", self.entry.__dict__ ###
...@@ -990,7 +1011,7 @@ class NameNode(AtomicExprNode): ...@@ -990,7 +1011,7 @@ class NameNode(AtomicExprNode):
entry = self.entry entry = self.entry
if entry is None: if entry is None:
return # There was an error earlier return # There was an error earlier
# is_pyglobal seems to be True for module level-globals only. # is_pyglobal seems to be True for module level-globals only.
# We use this to access class->tp_dict if necessary. # We use this to access class->tp_dict if necessary.
if entry.is_pyglobal: if entry.is_pyglobal:
...@@ -1018,25 +1039,49 @@ class NameNode(AtomicExprNode): ...@@ -1018,25 +1039,49 @@ class NameNode(AtomicExprNode):
rhs.generate_disposal_code(code) rhs.generate_disposal_code(code)
else: else:
if self.type.is_buffer:
# Generate code for doing the buffer release/acquisition.
# This might raise an exception in which case the assignment (done
# below) will not happen.
#
# The reason this is not in a typetest-like node is because the
# variables that the acquired buffer info is stored to is allocated
# per entry and coupled with it.
self.generate_acquire_buffer(rhs, code)
            if self.type.is_pyobject:
                #print "NameNode.generate_assignment_code: to", self.name ###
                #print "...from", rhs ###
                #print "...LHS type", self.type, "ctype", self.ctype() ###
                #print "...RHS type", rhs.type, "ctype", rhs.ctype() ###
                rhs.make_owned_reference(code)
                if entry.is_local and not Options.init_local_none:
                    initalized = entry.scope.control_flow.get_state((entry.name, 'initalized'), self.pos)
                    if initalized is True:
                        code.put_decref(self.result_code, self.ctype())
                    elif initalized is None:
                        code.put_xdecref(self.result_code, self.ctype())
                else:
                    code.put_decref(self.result_code, self.ctype())
            code.putln('%s = %s;' % (self.result_code, rhs.result_as(self.ctype())))

            if self.type.is_pyobject:
                rhs.make_owned_reference(code)
                #print "NameNode.generate_assignment_code: to", self.name ###
                #print "...from", rhs ###
                #print "...LHS type", self.type, "ctype", self.ctype() ###
                #print "...RHS type", rhs.type, "ctype", rhs.ctype() ###
                if not self.skip_assignment_decref:
                    if entry.is_local and not Options.init_local_none:
                        initalized = entry.scope.control_flow.get_state((entry.name, 'initalized'), self.pos)
                        if initalized is True:
                            code.put_decref(self.result_code, self.ctype())
                        elif initalized is None:
                            code.put_xdecref(self.result_code, self.ctype())
                        else:
                            code.put_decref(self.result_code, self.ctype())
            code.putln('%s = %s;' % (self.result_code, rhs.result_as(self.ctype())))
if debug_disposal_code: if debug_disposal_code:
print("NameNode.generate_assignment_code:") print("NameNode.generate_assignment_code:")
print("...generating post-assignment code for %s" % rhs) print("...generating post-assignment code for %s" % rhs)
rhs.generate_post_assignment_code(code) rhs.generate_post_assignment_code(code)
def generate_acquire_buffer(self, rhs, code):
rhstmp = code.func.allocate_temp(self.entry.type)
buffer_aux = self.entry.buffer_aux
bufstruct = buffer_aux.buffer_info_var.cname
code.putln('%s = %s;' % (rhstmp, rhs.result_as(self.ctype())))
import Buffer
Buffer.put_assign_to_buffer(self.result_code, rhstmp, buffer_aux, self.entry.type,
is_initialized=not self.skip_assignment_decref,
pos=self.pos, code=code)
code.putln("%s = 0;" % rhstmp)
code.func.release_temp(rhstmp)
def generate_deletion_code(self, code): def generate_deletion_code(self, code):
if self.entry is None: if self.entry is None:
...@@ -1297,39 +1342,45 @@ class IndexNode(ExprNode): ...@@ -1297,39 +1342,45 @@ class IndexNode(ExprNode):
self.analyse_base_and_index_types(env, setting = 1) self.analyse_base_and_index_types(env, setting = 1)
    def analyse_base_and_index_types(self, env, getting = 0, setting = 0):
        self.is_buffer_access = False
        self.base.analyse_types(env)
        skip_child_analysis = False
        buffer_access = False
        if self.base.type.buffer_options is not None:
            if isinstance(self.index, TupleNode):
                indices = self.index.args
            else:
                indices = [self.index]
            if len(indices) == self.base.type.buffer_options.ndim:
                buffer_access = True
                skip_child_analysis = True
                for x in indices:
                    x.analyse_types(env)
                    if not x.type.is_int:
                        buffer_access = False
        if buffer_access:
#            self.indices = [
#                x.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env)
#                for x in indices]
            self.indices = indices
            self.index = None
            self.type = self.base.type.buffer_options.dtype
            self.is_temp = 1
            self.is_buffer_access = True

        # Note: This might be cleaned up by having IndexNode
        # parsed in a saner way and only construct the tuple if
        # needed.
        if not buffer_access:

    def analyse_base_and_index_types(self, env, getting = 0, setting = 0):
        # Note: This might be cleaned up by having IndexNode
        # parsed in a saner way and only construct the tuple if
        # needed.
        self.is_buffer_access = False
        self.base.analyse_types(env)
        skip_child_analysis = False
        buffer_access = False
        if self.base.type.is_buffer:
            assert isinstance(self.base, NameNode)
            if isinstance(self.index, TupleNode):
                indices = self.index.args
            else:
                indices = [self.index]
            if len(indices) == self.base.type.ndim:
                buffer_access = True
                skip_child_analysis = True
                for x in indices:
                    x.analyse_types(env)
                    if not x.type.is_int:
                        buffer_access = False

        if buffer_access:
            self.indices = indices
            self.index = None
            self.type = self.base.type.dtype
            self.is_buffer_access = True
            if getting:
                # we only need a temp because result_code isn't refactored to
                # generation time, but this seems an ok shortcut to take
                self.is_temp = True
            if setting:
                if not self.base.entry.type.writable:
                    error(self.pos, "Writing to readonly buffer")
                else:
                    self.base.entry.buffer_aux.writable_needed = True
        else:
if isinstance(self.index, TupleNode): if isinstance(self.index, TupleNode):
self.index.analyse_types(env, skip_children=skip_child_analysis) self.index.analyse_types(env, skip_children=skip_child_analysis)
elif not skip_child_analysis: elif not skip_child_analysis:
...@@ -1374,7 +1425,7 @@ class IndexNode(ExprNode): ...@@ -1374,7 +1425,7 @@ class IndexNode(ExprNode):
def calculate_result_code(self): def calculate_result_code(self):
if self.is_buffer_access: if self.is_buffer_access:
return "<not needed>" return "<not used>"
else: else:
return "(%s[%s])" % ( return "(%s[%s])" % (
self.base.result_code, self.index.result_code) self.base.result_code, self.index.result_code)
...@@ -1393,17 +1444,22 @@ class IndexNode(ExprNode): ...@@ -1393,17 +1444,22 @@ class IndexNode(ExprNode):
if self.index is not None: if self.index is not None:
self.index.generate_evaluation_code(code) self.index.generate_evaluation_code(code)
else: else:
for i in self.indices: i.generate_evaluation_code(code) for i in self.indices:
i.generate_evaluation_code(code)
def generate_subexpr_disposal_code(self, code): def generate_subexpr_disposal_code(self, code):
self.base.generate_disposal_code(code) self.base.generate_disposal_code(code)
if self.index is not None: if self.index is not None:
self.index.generate_disposal_code(code) self.index.generate_disposal_code(code)
else: else:
for i in self.indices: i.generate_disposal_code(code) for i in self.indices:
i.generate_disposal_code(code)
def generate_result_code(self, code): def generate_result_code(self, code):
if self.type.is_pyobject: if self.is_buffer_access:
valuecode = self.buffer_access_code(code)
code.putln("%s = %s;" % (self.result_code, valuecode))
elif self.type.is_pyobject:
if self.index.type.is_int: if self.index.type.is_int:
function = "__Pyx_GetItemInt" function = "__Pyx_GetItemInt"
index_code = self.index.result_code index_code = self.index.result_code
...@@ -1439,7 +1495,10 @@ class IndexNode(ExprNode): ...@@ -1439,7 +1495,10 @@ class IndexNode(ExprNode):
def generate_assignment_code(self, rhs, code): def generate_assignment_code(self, rhs, code):
self.generate_subexpr_evaluation_code(code) self.generate_subexpr_evaluation_code(code)
if self.type.is_pyobject: if self.is_buffer_access:
valuecode = self.buffer_access_code(code)
code.putln("%s = %s;" % (valuecode, rhs.result_code))
elif self.type.is_pyobject:
self.generate_setitem_code(rhs.py_result(), code) self.generate_setitem_code(rhs.py_result(), code)
else: else:
code.putln( code.putln(
...@@ -1465,6 +1524,19 @@ class IndexNode(ExprNode): ...@@ -1465,6 +1524,19 @@ class IndexNode(ExprNode):
code.error_goto(self.pos))) code.error_goto(self.pos)))
self.generate_subexpr_disposal_code(code) self.generate_subexpr_disposal_code(code)
def buffer_access_code(self, code):
# Assign indices to temps
index_temps = [code.func.allocate_temp(i.type) for i in self.indices]
for temp, index in zip(index_temps, self.indices):
code.putln("%s = %s;" % (temp, index.result_code))
# Generate buffer access code using these temps
import Buffer
valuecode = Buffer.put_access(entry=self.base.entry,
index_signeds=[i.type.signed for i in self.indices],
index_cnames=index_temps,
pos=self.pos, code=code)
return valuecode
class SliceIndexNode(ExprNode): class SliceIndexNode(ExprNode):
# 2-element slice indexing # 2-element slice indexing
......
...@@ -46,6 +46,13 @@ class Context: ...@@ -46,6 +46,13 @@ class Context:
self.pyxs = {} self.pyxs = {}
self.include_directories = include_directories self.include_directories = include_directories
self.future_directives = set() self.future_directives = set()
import os.path
standard_include_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), '..', 'Includes'))
self.include_directories = include_directories + [standard_include_path]
def find_module(self, module_name, def find_module(self, module_name,
relative_to = None, pos = None, need_pxd = 1): relative_to = None, pos = None, need_pxd = 1):
...@@ -323,14 +330,18 @@ class Context: ...@@ -323,14 +330,18 @@ class Context:
verbose_flag = options.show_version, verbose_flag = options.show_version,
cplus = options.cplus) cplus = options.cplus)
def nonfatal_error(self, exc):
return Errors.report_error(exc)
def run_pipeline(self, pipeline, source): def run_pipeline(self, pipeline, source):
errors_occurred = False errors_occurred = False
data = source data = source
try: try:
for phase in pipeline: for phase in pipeline:
data = phase(data) data = phase(data)
except CompileError: except CompileError, err:
errors_occurred = True errors_occurred = True
Errors.report_error(err)
return (errors_occurred, data) return (errors_occurred, data)
def create_parse(context): def create_parse(context):
...@@ -358,22 +369,25 @@ def create_default_pipeline(context, options, result): ...@@ -358,22 +369,25 @@ def create_default_pipeline(context, options, result):
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from Optimize import FlattenInListTransform, SwitchTransform from Optimize import FlattenInListTransform, SwitchTransform, OptimizeRefcounting
from Buffer import BufferTransform from Buffer import IntroduceBufferAuxiliaryVars
from ModuleNode import check_c_classes from ModuleNode import check_c_classes
def printit(x): print x.dump()
return [ return [
create_parse(context), create_parse(context),
# printit,
NormalizeTree(context), NormalizeTree(context),
PostParse(context), PostParse(context),
FlattenInListTransform(), FlattenInListTransform(),
WithTransform(context), WithTransform(context),
DecoratorTransform(context), DecoratorTransform(context),
AnalyseDeclarationsTransform(context), AnalyseDeclarationsTransform(context),
IntroduceBufferAuxiliaryVars(context),
check_c_classes, check_c_classes,
AnalyseExpressionsTransform(context), AnalyseExpressionsTransform(context),
BufferTransform(context), # BufferTransform(context),
SwitchTransform(), SwitchTransform(),
OptimizeRefcounting(context),
# CreateClosureClasses(context), # CreateClosureClasses(context),
create_generate_code(context, options, result) create_generate_code(context, options, result)
] ]
...@@ -607,6 +621,7 @@ def main(command_line = 0): ...@@ -607,6 +621,7 @@ def main(command_line = 0):
else: else:
options = CompilationOptions(default_options) options = CompilationOptions(default_options)
sources = args sources = args
if options.show_version: if options.show_version:
sys.stderr.write("Cython version %s\n" % Version.version) sys.stderr.write("Cython version %s\n" % Version.version)
if options.working_path!="": if options.working_path!="":
......
...@@ -97,7 +97,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -97,7 +97,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
h_extension_types = h_entries(env.c_class_entries) h_extension_types = h_entries(env.c_class_entries)
if h_types or h_vars or h_funcs or h_extension_types: if h_types or h_vars or h_funcs or h_extension_types:
result.h_file = replace_suffix(result.c_file, ".h") result.h_file = replace_suffix(result.c_file, ".h")
h_code = Code.CCodeWriter(open_new_file(result.h_file)) h_code = Code.CCodeWriter()
if options.generate_pxi: if options.generate_pxi:
result.i_file = replace_suffix(result.c_file, ".pxi") result.i_file = replace_suffix(result.c_file, ".pxi")
i_code = Code.PyrexCodeWriter(result.i_file) i_code = Code.PyrexCodeWriter(result.i_file)
...@@ -129,6 +129,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -129,6 +129,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
h_code.putln("PyMODINIT_FUNC init%s(void);" % env.module_name) h_code.putln("PyMODINIT_FUNC init%s(void);" % env.module_name)
h_code.putln("") h_code.putln("")
h_code.putln("#endif") h_code.putln("#endif")
h_code.copyto(open_new_file(result.h_file))
def generate_public_declaration(self, entry, h_code, i_code): def generate_public_declaration(self, entry, h_code, i_code):
h_code.putln("%s %s;" % ( h_code.putln("%s %s;" % (
...@@ -156,7 +158,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -156,7 +158,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
has_api_extension_types = 1 has_api_extension_types = 1
if api_funcs or has_api_extension_types: if api_funcs or has_api_extension_types:
result.api_file = replace_suffix(result.c_file, "_api.h") result.api_file = replace_suffix(result.c_file, "_api.h")
h_code = Code.CCodeWriter(open_new_file(result.api_file)) h_code = Code.CCodeWriter()
name = self.api_name(env) name = self.api_name(env)
guard = Naming.api_guard_prefix + name guard = Naming.api_guard_prefix + name
h_code.put_h_guard(guard) h_code.put_h_guard(guard)
...@@ -209,6 +211,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -209,6 +211,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
h_code.putln("}") h_code.putln("}")
h_code.putln("") h_code.putln("")
h_code.putln("#endif") h_code.putln("#endif")
h_code.copyto(open_new_file(result.api_file))
def generate_cclass_header_code(self, type, h_code): def generate_cclass_header_code(self, type, h_code):
h_code.putln("%s DL_IMPORT(PyTypeObject) %s;" % ( h_code.putln("%s DL_IMPORT(PyTypeObject) %s;" % (
...@@ -232,12 +236,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -232,12 +236,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def generate_c_code(self, env, options, result): def generate_c_code(self, env, options, result):
modules = self.referenced_modules modules = self.referenced_modules
if Options.annotate or options.annotate: if Options.annotate or options.annotate:
code = Annotate.AnnotationCCodeWriter(StringIO()) code = Annotate.AnnotationCCodeWriter()
else: else:
code = Code.CCodeWriter(StringIO()) code = Code.CCodeWriter()
code.h = Code.CCodeWriter(StringIO()) h_code = code.insertion_point()
code.init_labels() self.generate_module_preamble(env, modules, h_code)
self.generate_module_preamble(env, modules, code.h)
code.putln("") code.putln("")
code.putln("/* Implementation of %s */" % env.qualified_name) code.putln("/* Implementation of %s */" % env.qualified_name)
...@@ -259,15 +262,13 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -259,15 +262,13 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.mark_pos(None) code.mark_pos(None)
self.generate_module_cleanup_func(env, code) self.generate_module_cleanup_func(env, code)
self.generate_filename_table(code) self.generate_filename_table(code)
self.generate_utility_functions(env, code) self.generate_utility_functions(env, code, h_code)
self.generate_buffer_compatability_functions(env, code)
self.generate_declarations_for_modules(env, modules, code.h) self.generate_declarations_for_modules(env, modules, h_code)
h_code.write('\n')
f = open_new_file(result.c_file) f = open_new_file(result.c_file)
f.write(code.h.f.getvalue()) code.copyto(f)
f.write("\n")
f.write(code.f.getvalue())
f.close() f.close()
result.c_file_generated = 1 result.c_file_generated = 1
if Options.annotate or options.annotate: if Options.annotate or options.annotate:
...@@ -441,8 +442,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -441,8 +442,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln(" #define PyBUF_ANY_CONTIGUOUS (0x0080 | PyBUF_STRIDES)") code.putln(" #define PyBUF_ANY_CONTIGUOUS (0x0080 | PyBUF_STRIDES)")
code.putln(" #define PyBUF_INDIRECT (0x0100 | PyBUF_STRIDES)") code.putln(" #define PyBUF_INDIRECT (0x0100 | PyBUF_STRIDES)")
code.putln("") code.putln("")
code.putln(" static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags);")
code.putln(" static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view);")
code.putln("#endif") code.putln("#endif")
code.put(builtin_module_name_utility_code[0]) code.put(builtin_module_name_utility_code[0])
...@@ -1485,6 +1484,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1485,6 +1484,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("0") code.putln("0")
code.putln("};") code.putln("};")
code.putln() code.putln()
code.enter_cfunc_scope() # as we need labels
code.putln("static int %s(PyObject *o, PyObject* py_name, char *name) {" % Naming.import_star_set) code.putln("static int %s(PyObject *o, PyObject* py_name, char *name) {" % Naming.import_star_set)
code.putln("char** type_name = %s_type_names;" % Naming.import_star) code.putln("char** type_name = %s_type_names;" % Naming.import_star)
code.putln("while (*type_name) {") code.putln("while (*type_name) {")
...@@ -1535,8 +1535,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1535,8 +1535,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("return -1;") code.putln("return -1;")
code.putln("}") code.putln("}")
code.putln(import_star_utility_code) code.putln(import_star_utility_code)
code.exit_cfunc_scope() # done with labels
def generate_module_init_func(self, imported_modules, env, code): def generate_module_init_func(self, imported_modules, env, code):
code.enter_cfunc_scope()
code.putln("") code.putln("")
header2 = "PyMODINIT_FUNC init%s(void)" % env.module_name header2 = "PyMODINIT_FUNC init%s(void)" % env.module_name
header3 = "PyMODINIT_FUNC PyInit_%s(void)" % env.module_name header3 = "PyMODINIT_FUNC PyInit_%s(void)" % env.module_name
...@@ -1548,8 +1550,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1548,8 +1550,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln(header3) code.putln(header3)
code.putln("#endif") code.putln("#endif")
code.putln("{") code.putln("{")
tempdecl_code = code.insertion_point()
code.put_var_declarations(env.temp_entries)
code.putln("%s = PyTuple_New(0); %s" % (Naming.empty_tuple, code.error_goto_if_null(Naming.empty_tuple, self.pos))); code.putln("%s = PyTuple_New(0); %s" % (Naming.empty_tuple, code.error_goto_if_null(Naming.empty_tuple, self.pos)));
code.putln("/*--- Libary function declarations ---*/") code.putln("/*--- Libary function declarations ---*/")
...@@ -1590,6 +1591,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1590,6 +1591,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("/*--- Execution code ---*/") code.putln("/*--- Execution code ---*/")
code.mark_pos(None) code.mark_pos(None)
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
if Options.generate_cleanup_code: if Options.generate_cleanup_code:
...@@ -1609,6 +1611,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1609,6 +1611,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("return NULL;") code.putln("return NULL;")
code.putln("#endif") code.putln("#endif")
code.putln('}') code.putln('}')
tempdecl_code.put_var_declarations(env.temp_entries)
tempdecl_code.put_temp_declarations(code.func)
code.exit_cfunc_scope()
def generate_module_cleanup_func(self, env, code): def generate_module_cleanup_func(self, env, code):
if not Options.generate_cleanup_code: if not Options.generate_cleanup_code:
...@@ -1951,7 +1958,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1951,7 +1958,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
"%s = &%s;" % ( "%s = &%s;" % (
type.typeptr_cname, type.typeobj_cname)) type.typeptr_cname, type.typeobj_cname))
def generate_utility_functions(self, env, code): def generate_utility_functions(self, env, code, h_code):
code.putln("") code.putln("")
code.putln("/* Runtime support code */") code.putln("/* Runtime support code */")
code.putln("") code.putln("")
...@@ -1960,94 +1967,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1960,94 +1967,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
(Naming.filetable_cname, Naming.filenames_cname)) (Naming.filetable_cname, Naming.filenames_cname))
code.putln("}") code.putln("}")
for utility_code in env.utility_code_used: for utility_code in env.utility_code_used:
code.h.put(utility_code[0]) h_code.put(utility_code[0])
code.put(utility_code[1]) code.put(utility_code[1])
code.put(PyrexTypes.type_conversion_functions) code.put(PyrexTypes.type_conversion_functions)
code.putln("") code.putln("")
def generate_buffer_compatability_functions(self, env, code):
# will be refactored
try:
env.entries[u'numpy']
code.put("""
static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
/* This function is always called after a type-check; safe to cast */
PyArrayObject *arr = (PyArrayObject*)obj;
PyArray_Descr *type = (PyArray_Descr*)arr->descr;
int typenum = PyArray_TYPE(obj);
if (!PyTypeNum_ISNUMBER(typenum)) {
PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
return -1;
}
/*
NumPy format codes doesn't completely match buffer codes;
seems safest to retranslate.
01234567890123456789012345*/
const char* base_codes = "?bBhHiIlLqQfdgfdgO";
char* format = (char*)malloc(4);
char* fp = format;
*fp++ = type->byteorder;
if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
*fp++ = base_codes[typenum];
*fp = 0;
view->buf = arr->data;
view->readonly = !PyArray_ISWRITEABLE(obj);
view->ndim = PyArray_NDIM(arr);
view->strides = PyArray_STRIDES(arr);
view->shape = PyArray_DIMS(arr);
view->suboffsets = NULL;
view->format = format;
view->itemsize = type->elsize;
view->internal = 0;
return 0;
}
static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
free((char*)view->format);
view->format = NULL;
}
""")
except KeyError:
pass
# For now, hard-code numpy imported as "numpy"
types = []
try:
ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
except KeyError:
pass
code.putln("#if PY_VERSION_HEX < 0x02060000")
code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
if len(types) > 0:
clause = "if"
for t, get, release in types:
code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
clause = "else if"
code.putln("else {")
code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
code.putln("return -1;")
if len(types) > 0: code.putln("}")
code.putln("}")
code.putln("")
code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
if len(types) > 0:
clause = "if"
for t, get, release in types:
code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release))
clause = "else if"
code.putln("}")
code.putln("")
code.putln("#endif")
#------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------
# #
# Runtime support code # Runtime support code
......
...@@ -8,6 +8,9 @@ ...@@ -8,6 +8,9 @@
pyrex_prefix = "__pyx_" pyrex_prefix = "__pyx_"
codewriter_temp_prefix = pyrex_prefix + "t_"
temp_prefix = u"__cyt_" temp_prefix = u"__cyt_"
builtin_prefix = pyrex_prefix + "builtin_" builtin_prefix = pyrex_prefix + "builtin_"
...@@ -31,6 +34,10 @@ prop_set_prefix = pyrex_prefix + "setprop_" ...@@ -31,6 +34,10 @@ prop_set_prefix = pyrex_prefix + "setprop_"
type_prefix = pyrex_prefix + "t_" type_prefix = pyrex_prefix + "t_"
typeobj_prefix = pyrex_prefix + "type_" typeobj_prefix = pyrex_prefix + "type_"
var_prefix = pyrex_prefix + "v_" var_prefix = pyrex_prefix + "v_"
bufstruct_prefix = pyrex_prefix + "bstruct_"
bufstride_prefix = pyrex_prefix + "bstride_"
bufshape_prefix = pyrex_prefix + "bshape_"
bufsuboffset_prefix = pyrex_prefix + "boffset_"
vtable_prefix = pyrex_prefix + "vtable_" vtable_prefix = pyrex_prefix + "vtable_"
vtabptr_prefix = pyrex_prefix + "vtabptr_" vtabptr_prefix = pyrex_prefix + "vtabptr_"
vtabstruct_prefix = pyrex_prefix + "vtabstruct_" vtabstruct_prefix = pyrex_prefix + "vtabstruct_"
......
...@@ -73,6 +73,7 @@ class Node(object): ...@@ -73,6 +73,7 @@ class Node(object):
is_name = 0 is_name = 0
is_literal = 0 is_literal = 0
temps = None
# All descendants should set child_attrs to a list of the attributes # All descendants should set child_attrs to a list of the attributes
# containing nodes considered "children" in the tree. Each such attribute # containing nodes considered "children" in the tree. Each such attribute
...@@ -102,7 +103,7 @@ class Node(object): ...@@ -102,7 +103,7 @@ class Node(object):
for attrname in result.child_attrs: for attrname in result.child_attrs:
value = getattr(result, attrname) value = getattr(result, attrname)
if isinstance(value, list): if isinstance(value, list):
setattr(result, attrname, value) setattr(result, attrname, [x for x in value])
return result return result
...@@ -251,6 +252,11 @@ class StatListNode(Node): ...@@ -251,6 +252,11 @@ class StatListNode(Node):
# stats a list of StatNode # stats a list of StatNode
child_attrs = ["stats"] child_attrs = ["stats"]
def create_analysed(pos, env, *args, **kw):
node = StatListNode(pos, *args, **kw)
return node # No node-specific analysis necessary
create_analysed = staticmethod(create_analysed)
def analyse_control_flow(self, env): def analyse_control_flow(self, env):
for stat in self.stats: for stat in self.stats:
...@@ -344,12 +350,6 @@ class CDeclaratorNode(Node): ...@@ -344,12 +350,6 @@ class CDeclaratorNode(Node):
calling_convention = "" calling_convention = ""
def analyse_expressions(self, env):
pass
def generate_execution_code(self, env):
pass
class CNameDeclaratorNode(CDeclaratorNode): class CNameDeclaratorNode(CDeclaratorNode):
# name string The Pyrex name being declared # name string The Pyrex name being declared
...@@ -368,29 +368,6 @@ class CNameDeclaratorNode(CDeclaratorNode): ...@@ -368,29 +368,6 @@ class CNameDeclaratorNode(CDeclaratorNode):
self.type = base_type self.type = base_type
return self, base_type return self, base_type
def analyse_expressions(self, env):
self.entry = env.lookup(self.name)
if self.default is not None:
env.control_flow.set_state(self.default.end_pos(), (self.entry.name, 'initalized'), True)
env.control_flow.set_state(self.default.end_pos(), (self.entry.name, 'source'), 'assignment')
self.entry.used = 1
if self.type.is_pyobject:
self.entry.init_to_none = False
self.entry.init = 0
self.default.analyse_types(env)
self.default = self.default.coerce_to(self.type, env)
self.default.allocate_temps(env)
self.default.release_temp(env)
def generate_execution_code(self, code):
if self.default is not None:
self.default.generate_evaluation_code(code)
if self.type.is_pyobject:
self.default.make_owned_reference(code)
code.putln('%s = %s;' % (self.entry.cname, self.default.result_as(self.entry.type)))
self.default.generate_post_assignment_code(code)
code.putln()
class CPtrDeclaratorNode(CDeclaratorNode): class CPtrDeclaratorNode(CDeclaratorNode):
# base CDeclaratorNode # base CDeclaratorNode
...@@ -403,12 +380,6 @@ class CPtrDeclaratorNode(CDeclaratorNode): ...@@ -403,12 +380,6 @@ class CPtrDeclaratorNode(CDeclaratorNode):
ptr_type = PyrexTypes.c_ptr_type(base_type) ptr_type = PyrexTypes.c_ptr_type(base_type)
return self.base.analyse(ptr_type, env, nonempty = nonempty) return self.base.analyse(ptr_type, env, nonempty = nonempty)
def analyse_expressions(self, env):
self.base.analyse_expressions(env)
def generate_execution_code(self, env):
self.base.generate_execution_code(env)
class CArrayDeclaratorNode(CDeclaratorNode): class CArrayDeclaratorNode(CDeclaratorNode):
# base CDeclaratorNode # base CDeclaratorNode
# dimension ExprNode # dimension ExprNode
...@@ -629,8 +600,8 @@ class CBufferAccessTypeNode(Node): ...@@ -629,8 +600,8 @@ class CBufferAccessTypeNode(Node):
def analyse(self, env): def analyse(self, env):
base_type = self.base_type_node.analyse(env) base_type = self.base_type_node.analyse(env)
dtype = self.dtype_node.analyse(env) dtype = self.dtype_node.analyse(env)
options = PyrexTypes.BufferOptions(dtype=dtype, ndim=self.ndim) self.type = PyrexTypes.BufferType(base_type, dtype=dtype, ndim=self.ndim,
self.type = PyrexTypes.create_buffer_type(base_type, options) mode=self.mode)
return self.type return self.type
class CComplexBaseTypeNode(CBaseTypeNode): class CComplexBaseTypeNode(CBaseTypeNode):
...@@ -685,14 +656,6 @@ class CVarDefNode(StatNode): ...@@ -685,14 +656,6 @@ class CVarDefNode(StatNode):
dest_scope.declare_var(name, type, declarator.pos, dest_scope.declare_var(name, type, declarator.pos,
cname = cname, visibility = self.visibility, is_cdef = 1) cname = cname, visibility = self.visibility, is_cdef = 1)
def analyse_expressions(self, env):
for declarator in self.declarators:
declarator.analyse_expressions(env)
def generate_execution_code(self, code):
for declarator in self.declarators:
declarator.generate_execution_code(code)
class CStructOrUnionDefNode(StatNode): class CStructOrUnionDefNode(StatNode):
# name string # name string
...@@ -866,9 +829,14 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -866,9 +829,14 @@ class FuncDefNode(StatNode, BlockNode):
return lenv return lenv
def generate_function_definitions(self, env, code, transforms): def generate_function_definitions(self, env, code, transforms):
# Generate C code for header and body of function import Buffer
code.init_labels()
lenv = self.local_scope lenv = self.local_scope
# Generate C code for header and body of function
code.enter_cfunc_scope()
code.return_from_error_cleanup_label = code.new_label()
# ----- Top-level constants used by this function # ----- Top-level constants used by this function
code.mark_pos(self.pos) code.mark_pos(self.pos)
self.generate_interned_num_decls(lenv, code) self.generate_interned_num_decls(lenv, code)
...@@ -899,7 +867,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -899,7 +867,7 @@ class FuncDefNode(StatNode, BlockNode):
(self.return_type.declaration_code( (self.return_type.declaration_code(
Naming.retval_cname), Naming.retval_cname),
init)) init))
code.put_var_declarations(lenv.temp_entries) tempvardecl_code = code.insertion_point()
self.generate_keyword_list(code) self.generate_keyword_list(code)
# ----- Extern library function declarations # ----- Extern library function declarations
lenv.generate_library_function_declarations(code) lenv.generate_library_function_declarations(code)
...@@ -914,7 +882,9 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -914,7 +882,9 @@ class FuncDefNode(StatNode, BlockNode):
for entry in lenv.arg_entries: for entry in lenv.arg_entries:
if entry.type.is_pyobject and lenv.control_flow.get_state((entry.name, 'source')) != 'arg': if entry.type.is_pyobject and lenv.control_flow.get_state((entry.name, 'source')) != 'arg':
code.put_var_incref(entry) code.put_var_incref(entry)
# ----- Initialise local variables if entry.type.is_buffer:
Buffer.put_acquire_arg_buffer(entry, code, self.pos)
# ----- Initialise local variables
for entry in lenv.var_entries: for entry in lenv.var_entries:
if entry.type.is_pyobject and entry.init_to_none and entry.used: if entry.type.is_pyobject and entry.init_to_none and entry.used:
code.put_init_var_to_py_none(entry) code.put_init_var_to_py_none(entry)
...@@ -934,12 +904,23 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -934,12 +904,23 @@ class FuncDefNode(StatNode, BlockNode):
val = self.return_type.default_value val = self.return_type.default_value
if val: if val:
code.putln("%s = %s;" % (Naming.retval_cname, val)) code.putln("%s = %s;" % (Naming.retval_cname, val))
#code.putln("goto %s;" % code.return_label)
# ----- Error cleanup # ----- Error cleanup
if code.error_label in code.labels_used: if code.error_label in code.labels_used:
code.put_goto(code.return_label) code.put_goto(code.return_label)
code.put_label(code.error_label) code.put_label(code.error_label)
code.put_var_xdecrefs(lenv.temp_entries) code.put_var_xdecrefs(lenv.temp_entries)
# Clean up buffers -- this calls a Python function
# so need to save and restore error state
buffers_present = len(lenv.buffer_entries) > 0
if buffers_present:
code.putln("{ PyObject *__pyx_type, *__pyx_value, *__pyx_tb;")
code.putln("PyErr_Fetch(&__pyx_type, &__pyx_value, &__pyx_tb);")
for entry in lenv.buffer_entries:
code.putln("%s;" % Buffer.get_release_buffer_code(entry))
#code.putln("%s = 0;" % entry.cname)
code.putln("PyErr_Restore(__pyx_type, __pyx_value, __pyx_tb);}")
err_val = self.error_value() err_val = self.error_value()
exc_check = self.caller_will_check_exceptions() exc_check = self.caller_will_check_exceptions()
if err_val is not None or exc_check: if err_val is not None or exc_check:
...@@ -957,8 +938,18 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -957,8 +938,18 @@ class FuncDefNode(StatNode, BlockNode):
"%s = %s;" % ( "%s = %s;" % (
Naming.retval_cname, Naming.retval_cname,
err_val)) err_val))
# ----- Return cleanup if buffers_present:
# Else, non-error return will be an empty clause
code.put_goto(code.return_from_error_cleanup_label)
# ----- Non-error return cleanup
# PS! If adding something here, modify the conditions for the
# goto statement in error cleanup above
code.put_label(code.return_label) code.put_label(code.return_label)
for entry in lenv.buffer_entries:
code.putln("%s;" % Buffer.get_release_buffer_code(entry))
# ----- Return cleanup for both error and no-error return
code.put_label(code.return_from_error_cleanup_label)
if not Options.init_local_none: if not Options.init_local_none:
for entry in lenv.var_entries: for entry in lenv.var_entries:
if lenv.control_flow.get_state((entry.name, 'initalized')) is not True: if lenv.control_flow.get_state((entry.name, 'initalized')) is not True:
...@@ -976,7 +967,11 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -976,7 +967,11 @@ class FuncDefNode(StatNode, BlockNode):
if not self.return_type.is_void: if not self.return_type.is_void:
code.putln("return %s;" % Naming.retval_cname) code.putln("return %s;" % Naming.retval_cname)
code.putln("}") code.putln("}")
# ----- Go back and insert temp variable declarations
tempvardecl_code.put_var_declarations(lenv.temp_entries)
tempvardecl_code.put_temp_declarations(code.func)
# ----- Python version # ----- Python version
code.exit_cfunc_scope()
if self.py_func: if self.py_func:
self.py_func.generate_function_definitions(env, code, transforms) self.py_func.generate_function_definitions(env, code, transforms)
self.generate_optarg_wrapper_function(env, code) self.generate_optarg_wrapper_function(env, code)
...@@ -2231,8 +2226,10 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -2231,8 +2226,10 @@ class SingleAssignmentNode(AssignmentNode):
# #
# lhs ExprNode Left hand side # lhs ExprNode Left hand side
# rhs ExprNode Right hand side # rhs ExprNode Right hand side
# first bool Is this guaranteed the first assignment to lhs?
child_attrs = ["lhs", "rhs"] child_attrs = ["lhs", "rhs"]
first = False
def analyse_declarations(self, env): def analyse_declarations(self, env):
self.lhs.analyse_target_declaration(env) self.lhs.analyse_target_declaration(env)
...@@ -2265,7 +2262,7 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -2265,7 +2262,7 @@ class SingleAssignmentNode(AssignmentNode):
# self.lhs.allocate_target_temps(env) # self.lhs.allocate_target_temps(env)
# self.lhs.release_target_temp(env) # self.lhs.release_target_temp(env)
# self.rhs.release_temp(env) # self.rhs.release_temp(env)
def generate_rhs_evaluation_code(self, code): def generate_rhs_evaluation_code(self, code):
self.rhs.generate_evaluation_code(code) self.rhs.generate_evaluation_code(code)
...@@ -3344,6 +3341,7 @@ class TryExceptStatNode(StatNode): ...@@ -3344,6 +3341,7 @@ class TryExceptStatNode(StatNode):
self.gil_check(env) self.gil_check(env)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.body.analyse_expressions(env) self.body.analyse_expressions(env)
self.cleanup_list = env.free_temp_entries[:] self.cleanup_list = env.free_temp_entries[:]
for except_clause in self.except_clauses: for except_clause in self.except_clauses:
...@@ -3523,6 +3521,12 @@ class TryFinallyStatNode(StatNode): ...@@ -3523,6 +3521,12 @@ class TryFinallyStatNode(StatNode):
# There doesn't seem to be any point in disallowing # There doesn't seem to be any point in disallowing
# continue in the try block, since we have no problem # continue in the try block, since we have no problem
# handling it. # handling it.
def create_analysed(pos, env, body, finally_clause):
node = TryFinallyStatNode(pos, body=body, finally_clause=finally_clause)
node.cleanup_list = []
return node
create_analysed = staticmethod(create_analysed)
def analyse_control_flow(self, env): def analyse_control_flow(self, env):
env.start_branching(self.pos) env.start_branching(self.pos)
......
...@@ -134,3 +134,16 @@ class FlattenInListTransform(Visitor.VisitorTransform): ...@@ -134,3 +134,16 @@ class FlattenInListTransform(Visitor.VisitorTransform):
def visit_Node(self, node): def visit_Node(self, node):
self.visitchildren(node) self.visitchildren(node)
return node return node
class OptimizeRefcounting(Visitor.CythonTransform):
def visit_SingleAssignmentNode(self, node):
if node.first:
lhs = node.lhs
if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
# Have variable initialized to 0 rather than None
lhs.entry.init_to_none = False
lhs.entry.init = 0
# Set a flag in NameNode to skip the decref
lhs.skip_assignment_decref = True
return node
...@@ -82,7 +82,9 @@ ERR_BUF_DUP = '"%s" buffer option already supplied' ...@@ -82,7 +82,9 @@ ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing' ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_INT = '"%s" must be an integer' ERR_BUF_INT = '"%s" must be an integer'
ERR_BUF_NONNEG = '"%s" must be non-negative' ERR_BUF_NONNEG = '"%s" must be non-negative'
ERR_CDEF_INCLASS = 'Cannot assign default value to cdef class attributes'
ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables'
ERR_BUF_MODEHELP = 'Only allowed buffer modes are "full" or "strided" (as a compile-time string)'
class PostParse(CythonTransform): class PostParse(CythonTransform):
""" """
Basic interpretation of the parse tree, as well as validity Basic interpretation of the parse tree, as well as validity
...@@ -91,6 +93,9 @@ class PostParse(CythonTransform): ...@@ -91,6 +93,9 @@ class PostParse(CythonTransform):
as such). as such).
Specifically: Specifically:
- Default values on cdef declarations are turned into single
assignments following the declaration (everywhere but in class
bodies, where they raise a compile error)
- CBufferAccessTypeNode has its options interpreted: - CBufferAccessTypeNode has its options interpreted:
Any first positional argument goes into the "dtype" attribute, Any first positional argument goes into the "dtype" attribute,
any "ndim" keyword argument goes into the "ndim" attribute and any "ndim" keyword argument goes into the "ndim" attribute and
...@@ -101,12 +106,65 @@ class PostParse(CythonTransform): ...@@ -101,12 +106,65 @@ class PostParse(CythonTransform):
if a more pure Abstract Syntax Tree is wanted. if a more pure Abstract Syntax Tree is wanted.
""" """
buffer_options = ("dtype", "ndim") # ordered! # Track our context.
scope_type = None # can be either of 'module', 'function', 'class'
def visit_ModuleNode(self, node):
self.scope_type = 'module'
self.visitchildren(node)
return node
def visit_ClassDefNode(self, node):
prev = self.scope_type
self.scope_type = 'class'
self.visitchildren(node)
self.scope_type = prev
return node
def visit_FuncDefNode(self, node):
prev = self.scope_type
self.scope_type = 'function'
self.visitchildren(node)
self.scope_type = prev
return node
# cdef variables
def visit_CVarDefNode(self, node):
# This assumes only plain names and pointers are assignable on
# declaration. Also, it makes use of the fact that a cdef decl
# must appear before the first use, so we don't have to deal with
# "i = 3; cdef int i = i" and can simply move the nodes around.
try:
self.visitchildren(node)
except PostParseError, e:
# An error in a cdef clause is ok, simply remove the declaration
# and try to move on to report more errors
self.context.nonfatal_error(e)
return None
stats = [node]
for decl in node.declarators:
while isinstance(decl, CPtrDeclaratorNode):
decl = decl.base
if isinstance(decl, CNameDeclaratorNode):
if decl.default is not None:
if self.scope_type == 'class':
raise PostParseError(decl.pos, ERR_CDEF_INCLASS)
stats.append(SingleAssignmentNode(node.pos,
lhs=NameNode(node.pos, name=decl.name),
rhs=decl.default, first=True))
decl.default = None
return stats
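To make the rewrite concrete, a minimal before/after sketch in Cython; the variable name and default value are illustrative only:
# before, as written by the user inside a function body:
#     cdef int i = 3
# after this visitor: the declaration is kept with its default cleared,
# followed by a first assignment appended to the returned statement list:
#     cdef int i
#     i = 3    # SingleAssignmentNode(lhs=NameNode('i'), rhs=<default>, first=True)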
# buffer access
buffer_options = ("dtype", "ndim", "mode") # ordered!
def visit_CBufferAccessTypeNode(self, node): def visit_CBufferAccessTypeNode(self, node):
if not self.scope_type == 'function':
raise PostParseError(node.pos, ERR_BUF_LOCALONLY)
options = {} options = {}
# Fetch positional arguments # Fetch positional arguments
if len(node.positional_args) > len(self.buffer_options): if len(node.positional_args) > len(self.buffer_options):
self.context.error(ERR_BUF_TOO_MANY) raise PostParseError(node.pos, ERR_BUF_TOO_MANY)
for arg, unicode_name in zip(node.positional_args, self.buffer_options): for arg, unicode_name in zip(node.positional_args, self.buffer_options):
name = str(unicode_name) name = str(unicode_name)
options[name] = arg options[name] = arg
...@@ -114,21 +172,19 @@ class PostParse(CythonTransform): ...@@ -114,21 +172,19 @@ class PostParse(CythonTransform):
for item in node.keyword_args.key_value_pairs: for item in node.keyword_args.key_value_pairs:
name = str(item.key.value) name = str(item.key.value)
if not name in self.buffer_options: if not name in self.buffer_options:
raise PostParseError(item.key.pos, raise PostParseError(item.key.pos, ERR_BUF_OPTION_UNKNOWN % name)
ERR_BUF_UNKNOWN % name)
if name in options.keys(): if name in options.keys():
raise PostParseError(item.key.pos, raise PostParseError(item.key.pos, ERR_BUF_DUP % name)
ERR_BUF_DUP % key)
options[name] = item.value options[name] = item.value
provided = options.keys()
# get dtype # get dtype
dtype = options.get("dtype") dtype = options.get("dtype")
if dtype is None: raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype') if dtype is None:
raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype')
node.dtype_node = dtype node.dtype_node = dtype
# get ndim # get ndim
if "ndim" in provided: if "ndim" in options:
ndimnode = options["ndim"] ndimnode = options["ndim"]
if not isinstance(ndimnode, IntNode): if not isinstance(ndimnode, IntNode):
# Compile-time values (DEF) are currently resolved by the parser, # Compile-time values (DEF) are currently resolved by the parser,
...@@ -140,7 +196,18 @@ class PostParse(CythonTransform): ...@@ -140,7 +196,18 @@ class PostParse(CythonTransform):
node.ndim = int(ndimnode.value) node.ndim = int(ndimnode.value)
else: else:
node.ndim = 1 node.ndim = 1
if "mode" in options:
modenode = options["mode"]
if not isinstance(modenode, StringNode):
raise PostParseError(modenode.pos, ERR_BUF_MODEHELP)
mode = modenode.value
if not mode in ('full', 'strided'):
raise PostParseError(modenode.pos, ERR_BUF_MODEHELP)
node.mode = mode
else:
node.mode = 'full'
# We're done with the parse tree args # We're done with the parse tree args
node.positional_args = None node.positional_args = None
node.keyword_args = None node.keyword_args = None
...@@ -253,6 +320,13 @@ class AnalyseDeclarationsTransform(CythonTransform): ...@@ -253,6 +320,13 @@ class AnalyseDeclarationsTransform(CythonTransform):
self.env_stack.pop() self.env_stack.pop()
return node return node
# Some nodes are no longer needed after declaration
# analysis and can be dropped. The analysis was performed
# on these nodes in a separate recursive process from the
# enclosing function or module, so we can simply drop them.
def visit_CVarDefNode(self, node):
return None
class AnalyseExpressionsTransform(CythonTransform): class AnalyseExpressionsTransform(CythonTransform):
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
node.body.analyse_expressions(node.scope) node.body.analyse_expressions(node.scope)
...@@ -263,7 +337,7 @@ class AnalyseExpressionsTransform(CythonTransform): ...@@ -263,7 +337,7 @@ class AnalyseExpressionsTransform(CythonTransform):
node.body.analyse_expressions(node.local_scope) node.body.analyse_expressions(node.local_scope)
self.visitchildren(node) self.visitchildren(node)
return node return node
class MarkClosureVisitor(CythonTransform): class MarkClosureVisitor(CythonTransform):
needs_closure = False needs_closure = False
......
...@@ -1620,9 +1620,7 @@ def p_c_simple_base_type(s, self_flag, nonempty): ...@@ -1620,9 +1620,7 @@ def p_c_simple_base_type(s, self_flag, nonempty):
# Treat trailing [] on type as buffer access # Treat trailing [] on type as buffer access
if 0: # s.sy == '[': if not is_basic and s.sy == '[':
if is_basic:
s.error("Basic C types do not support buffer access")
return p_buffer_access(s, type_node) return p_buffer_access(s, type_node)
else: else:
return type_node return type_node
......
...@@ -6,21 +6,6 @@ from Cython import Utils ...@@ -6,21 +6,6 @@ from Cython import Utils
import Naming import Naming
import copy import copy
class BufferOptions:
# dtype PyrexType
# ndim int
def __init__(self, dtype, ndim):
self.dtype = dtype
self.ndim = ndim
def create_buffer_type(base_type, buffer_options):
# Make a shallow copy of base_type and then annotate it
# with the buffer information
result = copy.copy(base_type)
result.buffer_options = buffer_options
return result
class BaseType: class BaseType:
# #
...@@ -57,6 +42,7 @@ class PyrexType(BaseType): ...@@ -57,6 +42,7 @@ class PyrexType(BaseType):
# is_unicode boolean Is a UTF-8 encoded C char * type # is_unicode boolean Is a UTF-8 encoded C char * type
# is_returncode boolean Is used only to signal exceptions # is_returncode boolean Is used only to signal exceptions
# is_error boolean Is the dummy error type # is_error boolean Is the dummy error type
# is_buffer boolean Is buffer access type
# has_attributes boolean Has C dot-selectable attributes # has_attributes boolean Has C dot-selectable attributes
# default_value string Initial value # default_value string Initial value
# parsetuple_format string Format char for PyArg_ParseTuple # parsetuple_format string Format char for PyArg_ParseTuple
...@@ -106,11 +92,11 @@ class PyrexType(BaseType): ...@@ -106,11 +92,11 @@ class PyrexType(BaseType):
is_unicode = 0 is_unicode = 0
is_returncode = 0 is_returncode = 0
is_error = 0 is_error = 0
is_buffer = 0
has_attributes = 0 has_attributes = 0
default_value = "" default_value = ""
parsetuple_format = "" parsetuple_format = ""
pymemberdef_typecode = None pymemberdef_typecode = None
buffer_options = None # can contain a BufferOptions instance
def resolve(self): def resolve(self):
# If a typedef, returns the base type. # If a typedef, returns the base type.
...@@ -202,6 +188,37 @@ class CTypedefType(BaseType): ...@@ -202,6 +188,37 @@ class CTypedefType(BaseType):
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.typedef_base_type, name) return getattr(self.typedef_base_type, name)
class BufferType(BaseType):
#
# Delegates most attribute
# lookups to the base type. ANYTHING NOT DEFINED
# HERE IS DELEGATED!
# dtype PyrexType
# ndim int
# mode str
# is_buffer boolean
# writable boolean
is_buffer = 1
writable = True
def __init__(self, base, dtype, ndim, mode):
self.base = base
self.dtype = dtype
self.ndim = ndim
self.buffer_ptr_type = CPtrType(dtype)
self.mode = mode
def as_argument_type(self):
return self
def __getattr__(self, name):
return getattr(self.base, name)
def __repr__(self):
return "<BufferType %r>" % self.base
class PyObjectType(PyrexType): class PyObjectType(PyrexType):
# #
# Base class for all Python object types (reference-counted). # Base class for all Python object types (reference-counted).
...@@ -927,7 +944,7 @@ class CEnumType(CType): ...@@ -927,7 +944,7 @@ class CEnumType(CType):
# name string # name string
# cname string or None # cname string or None
# typedef_flag boolean # typedef_flag boolean
is_enum = 1 is_enum = 1
signed = 1 signed = 1
rank = -1 # Ranks below any integer type rank = -1 # Ranks below any integer type
......
...@@ -15,11 +15,14 @@ from TypeSlots import \ ...@@ -15,11 +15,14 @@ from TypeSlots import \
get_special_method_signature, get_property_accessor_signature get_special_method_signature, get_property_accessor_signature
import ControlFlow import ControlFlow
import __builtin__ import __builtin__
from sets import Set as set
possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
nice_identifier = re.compile('^[a-zA-Z0-9_]+$').match
class BufferAux: class BufferAux:
writable_needed = False
def __init__(self, buffer_info_var, stridevars, shapevars, tschecker): def __init__(self, buffer_info_var, stridevars, shapevars, tschecker):
self.buffer_info_var = buffer_info_var self.buffer_info_var = buffer_info_var
self.stridevars = stridevars self.stridevars = stridevars
...@@ -141,7 +144,7 @@ class Entry: ...@@ -141,7 +144,7 @@ class Entry:
def redeclared(self, pos): def redeclared(self, pos):
error(pos, "'%s' does not match previous declaration" % self.name) error(pos, "'%s' does not match previous declaration" % self.name)
error(self.pos, "Previous declaration is here") error(self.pos, "Previous declaration is here")
class Scope: class Scope:
# name string Unqualified name # name string Unqualified name
# outer_scope Scope or None Enclosing scope # outer_scope Scope or None Enclosing scope
...@@ -216,6 +219,7 @@ class Scope: ...@@ -216,6 +219,7 @@ class Scope:
self.num_to_entry = {} self.num_to_entry = {}
self.obj_to_entry = {} self.obj_to_entry = {}
self.pystring_entries = [] self.pystring_entries = []
self.buffer_entries = []
self.control_flow = ControlFlow.LinearControlFlow() self.control_flow = ControlFlow.LinearControlFlow()
def start_branching(self, pos): def start_branching(self, pos):
...@@ -616,9 +620,12 @@ class Scope: ...@@ -616,9 +620,12 @@ class Scope:
return [entry for entry in self.temp_entries return [entry for entry in self.temp_entries
if entry not in self.free_temp_entries] if entry not in self.free_temp_entries]
def use_utility_code(self, new_code): def use_utility_code(self, new_code, name=None):
self.global_scope().use_utility_code(new_code) self.global_scope().use_utility_code(new_code, name)
def has_utility_code(self, name):
return self.global_scope().has_utility_code(name)
def generate_library_function_declarations(self, code): def generate_library_function_declarations(self, code):
# Generate extern decls for C library funcs used. # Generate extern decls for C library funcs used.
#if self.pow_function_used: #if self.pow_function_used:
...@@ -738,6 +745,7 @@ class ModuleScope(Scope): ...@@ -738,6 +745,7 @@ class ModuleScope(Scope):
# doc_cname string C name of module doc string # doc_cname string C name of module doc string
# const_counter integer Counter for naming constants # const_counter integer Counter for naming constants
# utility_code_used [string] Utility code to be included # utility_code_used [string] Utility code to be included
# utility_code_names set(string) (Optional) names for named (often generated) utility code
# default_entries [Entry] Function argument default entries # default_entries [Entry] Function argument default entries
# python_include_files [string] Standard Python headers to be included # python_include_files [string] Standard Python headers to be included
# include_files [string] Other C headers to be included # include_files [string] Other C headers to be included
...@@ -772,6 +780,7 @@ class ModuleScope(Scope): ...@@ -772,6 +780,7 @@ class ModuleScope(Scope):
self.doc_cname = Naming.moddoc_cname self.doc_cname = Naming.moddoc_cname
self.const_counter = 1 self.const_counter = 1
self.utility_code_used = [] self.utility_code_used = []
self.utility_code_names = set()
self.default_entries = [] self.default_entries = []
self.module_entries = {} self.module_entries = {}
self.python_include_files = ["Python.h", "structmember.h"] self.python_include_files = ["Python.h", "structmember.h"]
...@@ -930,13 +939,25 @@ class ModuleScope(Scope): ...@@ -930,13 +939,25 @@ class ModuleScope(Scope):
self.const_counter = n + 1 self.const_counter = n + 1
return "%s%s%d" % (Naming.const_prefix, prefix, n) return "%s%s%d" % (Naming.const_prefix, prefix, n)
def use_utility_code(self, new_code): def use_utility_code(self, new_code, name=None):
# Add string to list of utility code to be included, # Add string to list of utility code to be included,
# if not already there (tested using 'is'). # if not already there (tested using the provided name,
# or 'is' if name=None -- if the utility code is dynamically
# generated, use the name, otherwise it is not needed).
if name is not None:
if name in self.utility_code_names:
return
for old_code in self.utility_code_used: for old_code in self.utility_code_used:
if old_code is new_code: if old_code is new_code:
return return
self.utility_code_used.append(new_code) self.utility_code_used.append(new_code)
self.utility_code_names.add(name)
def has_utility_code(self, name):
# Checks if utility code (that is registered by name) has
# previously been registered. This is useful if the utility code
# is dynamically generated to avoid re-generation.
return name in self.utility_code_names
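A hedged usage sketch of the name-based registration introduced above; the scope variable, the (proto, impl) pair and the helper name are placeholders rather than code that exists in this changeset:
# scope: any Scope; proto/impl: the usual (declarations, implementation) pair
helper_name = "generated_helper_for_int"     # placeholder name
if not scope.has_utility_code(helper_name):
    # registering under a name makes a second registration a no-op, so
    # dynamically generated helpers can be requested from many call sites
    scope.use_utility_code((proto, impl), name=helper_name)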
def declare_c_class(self, name, pos, defining = 0, implementing = 0, def declare_c_class(self, name, pos, defining = 0, implementing = 0,
module_name = None, base_type = None, objstruct_cname = None, module_name = None, base_type = None, objstruct_cname = None,
......
...@@ -47,21 +47,35 @@ class TestBufferParsing(CythonTest): ...@@ -47,21 +47,35 @@ class TestBufferParsing(CythonTest):
self.not_parseable("Non-keyword arg following keyword arg", self.not_parseable("Non-keyword arg following keyword arg",
u"cdef object[foo=1, 2] x") u"cdef object[foo=1, 2] x")
# See also tests/error/e_bufaccess.pyx and tets/run/bufaccess.pyx
class TestBufferOptions(CythonTest): class TestBufferOptions(CythonTest):
# Tests the full parsing of the options within the brackets # Tests the full parsing of the options within the brackets
def parse_opts(self, opts): def nonfatal_error(self, error):
s = u"cdef object[%s] x" % opts # We're passing self as context to transform to trap this
root = self.fragment(s, pipeline=[PostParse(self)]).root self.error = error
buftype = root.stats[0].base_type self.assert_(self.expect_error)
self.assert_(isinstance(buftype, CBufferAccessTypeNode))
self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode)) def parse_opts(self, opts, expect_error=False):
self.assertEqual(u"object", buftype.base_type_node.name) s = u"def f():\n cdef object[%s] x" % opts
return buftype self.expect_error = expect_error
root = self.fragment(s, pipeline=[NormalizeTree(self), PostParse(self)]).root
if not expect_error:
vardef = root.stats[0].body.stats[0]
assert isinstance(vardef, CVarDefNode) # use normal assert as this is to validate the test code
buftype = vardef.base_type
self.assert_(isinstance(buftype, CBufferAccessTypeNode))
self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode))
self.assertEqual(u"object", buftype.base_type_node.name)
return buftype
else:
self.assert_(len(root.stats[0].body.stats) == 0)
def non_parse(self, expected_err, opts): def non_parse(self, expected_err, opts):
e = self.should_fail(lambda: self.parse_opts(opts)) self.parse_opts(opts, expect_error=True)
self.assertEqual(expected_err, e.message_only) # e = self.should_fail(lambda: self.parse_opts(opts))
self.assertEqual(expected_err, self.error.message_only)
def test_basic(self): def test_basic(self):
buf = self.parse_opts(u"unsigned short int, 3") buf = self.parse_opts(u"unsigned short int, 3")
...@@ -86,10 +100,12 @@ class TestBufferOptions(CythonTest): ...@@ -86,10 +100,12 @@ class TestBufferOptions(CythonTest):
def test_use_DEF(self): def test_use_DEF(self):
t = self.fragment(u""" t = self.fragment(u"""
DEF ndim = 3 DEF ndim = 3
cdef object[int, ndim] x def f():
cdef object[ndim=ndim, dtype=int] y cdef object[int, ndim] x
""", pipeline=[PostParse(self)]).root cdef object[ndim=ndim, dtype=int] y
self.assert_(t.stats[1].base_type.ndim == 3) """, pipeline=[NormalizeTree(self), PostParse(self)]).root
self.assert_(t.stats[2].base_type.ndim == 3) stats = t.stats[0].body.stats
self.assert_(stats[0].base_type.ndim == 3)
self.assert_(stats[1].base_type.ndim == 3)
# add exotic and impossible combinations as they come along # add exotic and impossible combinations as they come along...
...@@ -6,6 +6,8 @@ import re ...@@ -6,6 +6,8 @@ import re
from cStringIO import StringIO from cStringIO import StringIO
from Scanning import PyrexScanner, StringSourceDescriptor from Scanning import PyrexScanner, StringSourceDescriptor
from Symtab import BuiltinScope, ModuleScope from Symtab import BuiltinScope, ModuleScope
import Symtab
import PyrexTypes
from Visitor import VisitorTransform, temp_name_handle from Visitor import VisitorTransform, temp_name_handle
from Nodes import Node, StatListNode from Nodes import Node, StatListNode
from ExprNodes import NameNode from ExprNodes import NameNode
...@@ -108,11 +110,16 @@ class TemplateTransform(VisitorTransform): ...@@ -108,11 +110,16 @@ class TemplateTransform(VisitorTransform):
self.substitutions = substitutions self.substitutions = substitutions
tempdict = {} tempdict = {}
for key in temps: for key in temps:
tempdict[key] = temp_name_handle(key) tempdict[key] = temp_name_handle(key) # pending result_code refactor: Symtab.new_temp(PyrexTypes.py_object_type, key)
self.temps = tempdict self.temp_key_to_entries = tempdict
self.pos = pos self.pos = pos
return super(TemplateTransform, self).__call__(node) return super(TemplateTransform, self).__call__(node)
def get_pos(self, node):
if self.pos:
return self.pos
else:
return node.pos
def visit_Node(self, node): def visit_Node(self, node):
if node is None: if node is None:
...@@ -135,11 +142,11 @@ class TemplateTransform(VisitorTransform): ...@@ -135,11 +142,11 @@ class TemplateTransform(VisitorTransform):
def visit_NameNode(self, node): def visit_NameNode(self, node):
tempname = self.temps.get(node.name) tempentry = self.temp_key_to_entries.get(node.name)
if tempname is not None: if tempentry is not None:
# Replace name with temporary # Replace name with temporary
node.name = tempname return NameNode(self.get_pos(node), name=tempentry)
return self.visit_Node(node) # Pending result_code refactor: return NameNode(self.get_pos(node), entry=tempentry)
else: else:
return self.try_substitution(node, node.name) return self.try_substitution(node, node.name)
...@@ -157,6 +164,7 @@ def copy_code_tree(node): ...@@ -157,6 +164,7 @@ def copy_code_tree(node):
INDENT_RE = re.compile(ur"^ *") INDENT_RE = re.compile(ur"^ *")
def strip_common_indent(lines): def strip_common_indent(lines):
"Strips empty lines and common indentation from the list of strings given in lines" "Strips empty lines and common indentation from the list of strings given in lines"
# TODO: use textwrap.dedent here instead
lines = [x for x in lines if x.strip() != u""] lines = [x for x in lines if x.strip() != u""]
minindent = min([len(INDENT_RE.match(x).group(0)) for x in lines]) minindent = min([len(INDENT_RE.match(x).group(0)) for x in lines])
lines = [x[minindent:] for x in lines] lines = [x[minindent:] for x in lines]
......
# Please see the Python header files (object.h) for docs
cdef extern from "Python.h":
ctypedef void PyObject
ctypedef struct bufferinfo:
void *buf
Py_ssize_t len
Py_ssize_t itemsize
int readonly
int ndim
char *format
Py_ssize_t *shape
Py_ssize_t *strides
Py_ssize_t *suboffsets
void *internal
ctypedef bufferinfo Py_buffer
cdef enum:
PyBUF_SIMPLE,
PyBUF_WRITABLE,
PyBUF_WRITEABLE, # backwards compatibility
PyBUF_FORMAT,
PyBUF_ND,
PyBUF_STRIDES,
PyBUF_C_CONTIGUOUS,
PyBUF_F_CONTIGUOUS,
PyBUF_ANY_CONTIGUOUS,
PyBUF_INDIRECT,
PyBUF_CONTIG,
PyBUF_CONTIG_RO,
PyBUF_STRIDED,
PyBUF_STRIDED_RO,
PyBUF_RECORDS,
PyBUF_RECORDS_RO,
PyBUF_FULL,
PyBUF_FULL_RO,
PyBUF_READ,
PyBUF_WRITE,
PyBUF_SHADOW
int PyObject_CheckBuffer(PyObject* obj)
int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags)
void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view)
void* PyBuffer_GetPointer(Py_buffer *view, Py_ssize_t *indices)
int PyBuffer_SizeFromFormat(char *) # actually const char
int PyBuffer_ToContiguous(void *buf, Py_buffer *view, Py_ssize_t len, char fort)
int PyBuffer_FromContiguous(Py_buffer *view, void *buf, Py_ssize_t len, char fort)
int PyObject_CopyData(PyObject *dest, PyObject *src)
int PyBuffer_IsContiguous(Py_buffer *view, char fort)
void PyBuffer_FillContiguousStrides(int ndims,
Py_ssize_t *shape,
Py_ssize_t *strides,
int itemsize,
char fort)
int PyBuffer_FillInfo(Py_buffer *view, void *buf,
Py_ssize_t len, int readonly,
int flags)
PyObject* PyObject_Format(PyObject* obj,
PyObject *format_spec)
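These are only declarations; as a hedged, untested sketch (the helper name and the double cast through <void*> are assumptions of this illustration, not part of the commit), a Cython module could use them to acquire and release the simplest possible view of an object's buffer:
cimport python_buffer

def describe_buffer(obj):
    # Request a plain, unstrided view, report its size, and always release it.
    cdef python_buffer.Py_buffer view
    if not python_buffer.PyObject_CheckBuffer(<python_buffer.PyObject*><void*>obj):
        raise TypeError("object does not support the buffer interface")
    if python_buffer.PyObject_GetBuffer(<python_buffer.PyObject*><void*>obj,
                                        &view, python_buffer.PyBUF_SIMPLE) == -1:
        raise RuntimeError("buffer acquisition failed")
    try:
        print view.len, view.itemsize
    finally:
        python_buffer.PyObject_ReleaseBuffer(<python_buffer.PyObject*><void*>obj, &view)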
from cStringIO import StringIO
class StringIOTree(object):
"""
See module docs.
"""
def __init__(self, stream=None):
self.prepended_children = []
if stream is None: stream = StringIO()
self.stream = stream
def getvalue(self):
return ("".join([x.getvalue() for x in self.prepended_children]) +
self.stream.getvalue())
def copyto(self, target):
"""Potentially cheaper than getvalue as no string concatenation
needs to happen."""
for child in self.prepended_children:
child.copyto(target)
target.write(self.stream.getvalue())
def write(self, what):
self.stream.write(what)
def insertion_point(self):
# Save what we have written until now
# (would it be more efficient to check with len(self.stream.getvalue())?
# leaving it out for now)
self.prepended_children.append(StringIOTree(self.stream))
# Construct the new forked object to return
other = StringIOTree()
self.prepended_children.append(other)
self.stream = StringIO()
return other
__doc__ = r"""
Implements a buffer with insertion points. When you know you need to
"get back" to a place and write more later, simply call insertion_point()
at that spot and get a new StringIOTree object that is "left behind".
EXAMPLE:
>>> a = StringIOTree()
>>> a.write('first\n')
>>> b = a.insertion_point()
>>> a.write('third\n')
>>> b.write('second\n')
>>> print a.getvalue()
first
second
third
<BLANKLINE>
>>> c = b.insertion_point()
>>> d = c.insertion_point()
>>> d.write('alpha\n')
>>> b.write('gamma\n')
>>> c.write('beta\n')
>>> print b.getvalue()
second
alpha
beta
gamma
<BLANKLINE>
>>> out = StringIO()
>>> a.copyto(out)
>>> print out.getvalue()
first
second
alpha
beta
gamma
third
<BLANKLINE>
"""
if __name__ == "__main__":
import doctest
doctest.testmod()
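One hedged example of why insertion points matter for code generation (illustrative only; the C snippet and names are made up, not taken from this commit): a writer can emit a function body first and only later decide which declarations must appear above it.
root = StringIOTree()
root.write('void f(void) {\n')
decls = root.insertion_point()   # declarations get filled in afterwards
root.write('    x = 42;\n}\n')
decls.write('    int x;\n')
print root.getvalue()
# void f(void) {
#     int x;
#     x = 42;
# }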
...@@ -46,12 +46,13 @@ class ErrorWriter(object): ...@@ -46,12 +46,13 @@ class ErrorWriter(object):
class TestBuilder(object): class TestBuilder(object):
def __init__(self, rootdir, workdir, selectors, annotate, def __init__(self, rootdir, workdir, selectors, annotate,
cleanup_workdir, with_pyregr): cleanup_workdir, cleanup_sharedlibs, with_pyregr):
self.rootdir = rootdir self.rootdir = rootdir
self.workdir = workdir self.workdir = workdir
self.selectors = selectors self.selectors = selectors
self.annotate = annotate self.annotate = annotate
self.cleanup_workdir = cleanup_workdir self.cleanup_workdir = cleanup_workdir
self.cleanup_sharedlibs = cleanup_sharedlibs
self.with_pyregr = with_pyregr self.with_pyregr = with_pyregr
def build_suite(self): def build_suite(self):
...@@ -84,6 +85,7 @@ class TestBuilder(object): ...@@ -84,6 +85,7 @@ class TestBuilder(object):
for filename in filenames: for filename in filenames:
if not (filename.endswith(".pyx") or filename.endswith(".py")): if not (filename.endswith(".pyx") or filename.endswith(".py")):
continue continue
if filename.startswith('.'): continue # certain emacs backup files
if context == 'pyregr' and not filename.startswith('test_'): if context == 'pyregr' and not filename.startswith('test_'):
continue continue
module = os.path.splitext(filename)[0] module = os.path.splitext(filename)[0]
...@@ -99,25 +101,29 @@ class TestBuilder(object): ...@@ -99,25 +101,29 @@ class TestBuilder(object):
test = build_test( test = build_test(
path, workdir, module, path, workdir, module,
annotate=self.annotate, annotate=self.annotate,
cleanup_workdir=self.cleanup_workdir) cleanup_workdir=self.cleanup_workdir,
cleanup_sharedlibs=self.cleanup_sharedlibs)
else: else:
test = CythonCompileTestCase( test = CythonCompileTestCase(
path, workdir, module, path, workdir, module,
expect_errors=expect_errors, expect_errors=expect_errors,
annotate=self.annotate, annotate=self.annotate,
cleanup_workdir=self.cleanup_workdir) cleanup_workdir=self.cleanup_workdir,
cleanup_sharedlibs=self.cleanup_sharedlibs)
suite.addTest(test) suite.addTest(test)
return suite return suite
class CythonCompileTestCase(unittest.TestCase): class CythonCompileTestCase(unittest.TestCase):
def __init__(self, directory, workdir, module, def __init__(self, directory, workdir, module,
expect_errors=False, annotate=False, cleanup_workdir=True): expect_errors=False, annotate=False, cleanup_workdir=True,
cleanup_sharedlibs=True):
self.directory = directory self.directory = directory
self.workdir = workdir self.workdir = workdir
self.module = module self.module = module
self.expect_errors = expect_errors self.expect_errors = expect_errors
self.annotate = annotate self.annotate = annotate
self.cleanup_workdir = cleanup_workdir self.cleanup_workdir = cleanup_workdir
self.cleanup_sharedlibs = cleanup_sharedlibs
unittest.TestCase.__init__(self) unittest.TestCase.__init__(self)
def shortDescription(self): def shortDescription(self):
...@@ -125,10 +131,13 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -125,10 +131,13 @@ class CythonCompileTestCase(unittest.TestCase):
def tearDown(self): def tearDown(self):
cleanup_c_files = WITH_CYTHON and self.cleanup_workdir cleanup_c_files = WITH_CYTHON and self.cleanup_workdir
cleanup_lib_files = self.cleanup_sharedlibs
if os.path.exists(self.workdir): if os.path.exists(self.workdir):
for rmfile in os.listdir(self.workdir): for rmfile in os.listdir(self.workdir):
if not cleanup_c_files and rmfile[-2:] in (".c", ".h"): if not cleanup_c_files and rmfile[-2:] in (".c", ".h"):
continue continue
if not cleanup_lib_files and (rmfile.endswith(".so") or rmfile.endswith(".dll")):
continue
if self.annotate and rmfile.endswith(".html"): if self.annotate and rmfile.endswith(".html"):
continue continue
try: try:
...@@ -278,7 +287,7 @@ class CythonUnitTestCase(CythonCompileTestCase): ...@@ -278,7 +287,7 @@ class CythonUnitTestCase(CythonCompileTestCase):
except Exception: except Exception:
pass pass
def collect_unittests(path, suite, selectors): def collect_unittests(path, module_prefix, suite, selectors):
def file_matches(filename): def file_matches(filename):
return filename.startswith("Test") and filename.endswith(".py") return filename.startswith("Test") and filename.endswith(".py")
...@@ -304,7 +313,7 @@ def collect_unittests(path, suite, selectors): ...@@ -304,7 +313,7 @@ def collect_unittests(path, suite, selectors):
for f in filenames: for f in filenames:
if file_matches(f): if file_matches(f):
filepath = os.path.join(dirpath, f)[:-len(".py")] filepath = os.path.join(dirpath, f)[:-len(".py")]
modulename = filepath[len(path)+1:].replace(os.path.sep, '.') modulename = module_prefix + filepath[len(path)+1:].replace(os.path.sep, '.')
if not [ 1 for match in selectors if match(modulename) ]: if not [ 1 for match in selectors if match(modulename) ]:
continue continue
module = __import__(modulename) module = __import__(modulename)
...@@ -312,18 +321,50 @@ def collect_unittests(path, suite, selectors): ...@@ -312,18 +321,50 @@ def collect_unittests(path, suite, selectors):
module = getattr(module, x) module = getattr(module, x)
suite.addTests([loader.loadTestsFromModule(module)]) suite.addTests([loader.loadTestsFromModule(module)])
def collect_doctests(path, module_prefix, suite, selectors):
def package_matches(dirname):
return dirname not in ("Mac", "Distutils", "Plex")
def file_matches(filename):
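# skip editor backup/lock files ('~', '#') and hidden files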
return (filename.endswith(".py") and not ('~' in filename
or '#' in filename or filename.startswith('.')))
import doctest, types
for dirpath, dirnames, filenames in os.walk(path):
parentname = os.path.split(dirpath)[-1]
if package_matches(parentname):
for f in filenames:
if file_matches(f):
if not f.endswith('.py'): continue
filepath = os.path.join(dirpath, f)[:-len(".py")]
modulename = module_prefix + filepath[len(path)+1:].replace(os.path.sep, '.')
if not [ 1 for match in selectors if match(modulename) ]:
continue
module = __import__(modulename)
for x in modulename.split('.')[1:]:
module = getattr(module, x)
if hasattr(module, "__doc__") or hasattr(module, "__test__"):
try:
suite.addTests(doctest.DocTestSuite(module))
except ValueError: # no tests
pass
if __name__ == '__main__': if __name__ == '__main__':
from optparse import OptionParser from optparse import OptionParser
parser = OptionParser() parser = OptionParser()
parser.add_option("--no-cleanup", dest="cleanup_workdir", parser.add_option("--no-cleanup", dest="cleanup_workdir",
action="store_false", default=True, action="store_false", default=True,
help="do not delete the generated C files (allows passing --no-cython on next run)") help="do not delete the generated C files (allows passing --no-cython on next run)")
parser.add_option("--no-cleanup-sharedlibs", dest="cleanup_sharedlibs",
action="store_false", default=True,
help="do not delete the generated shared libary files (allows manual module experimentation)")
parser.add_option("--no-cython", dest="with_cython", parser.add_option("--no-cython", dest="with_cython",
action="store_false", default=True, action="store_false", default=True,
help="do not run the Cython compiler, only the C compiler") help="do not run the Cython compiler, only the C compiler")
parser.add_option("--no-unit", dest="unittests", parser.add_option("--no-unit", dest="unittests",
action="store_false", default=True, action="store_false", default=True,
help="do not run the unit tests") help="do not run the unit tests")
parser.add_option("--no-doctest", dest="doctests",
action="store_false", default=True,
help="do not run the doctests")
parser.add_option("--no-file", dest="filetests", parser.add_option("--no-file", dest="filetests",
action="store_false", default=True, action="store_false", default=True,
help="do not run the file based tests") help="do not run the file based tests")
...@@ -360,6 +401,8 @@ if __name__ == '__main__': ...@@ -360,6 +401,8 @@ if __name__ == '__main__':
# RUN ALL TESTS! # RUN ALL TESTS!
ROOTDIR = os.path.join(os.getcwd(), os.path.dirname(sys.argv[0]), 'tests') ROOTDIR = os.path.join(os.getcwd(), os.path.dirname(sys.argv[0]), 'tests')
WORKDIR = os.path.join(os.getcwd(), 'BUILD') WORKDIR = os.path.join(os.getcwd(), 'BUILD')
UNITTEST_MODULE = "Cython"
UNITTEST_ROOT = os.path.join(os.getcwd(), UNITTEST_MODULE)
if WITH_CYTHON: if WITH_CYTHON:
if os.path.exists(WORKDIR): if os.path.exists(WORKDIR):
shutil.rmtree(WORKDIR, ignore_errors=True) shutil.rmtree(WORKDIR, ignore_errors=True)
...@@ -382,13 +425,15 @@ if __name__ == '__main__': ...@@ -382,13 +425,15 @@ if __name__ == '__main__':
test_suite = unittest.TestSuite() test_suite = unittest.TestSuite()
if options.unittests: if options.unittests:
collect_unittests(os.getcwd(), test_suite, selectors) collect_unittests(UNITTEST_ROOT, UNITTEST_MODULE + ".", test_suite, selectors)
if options.doctests:
collect_doctests(UNITTEST_ROOT, UNITTEST_MODULE + ".", test_suite, selectors)
if options.filetests: if options.filetests:
filetests = TestBuilder(ROOTDIR, WORKDIR, selectors, filetests = TestBuilder(ROOTDIR, WORKDIR, selectors,
options.annotate_source, options.annotate_source, options.cleanup_workdir,
options.cleanup_workdir, options.cleanup_sharedlibs, options.pyregr)
options.pyregr)
test_suite.addTests([filetests.build_suite()]) test_suite.addTests([filetests.build_suite()])
unittest.TextTestRunner(verbosity=options.verbosity).run(test_suite) unittest.TextTestRunner(verbosity=options.verbosity).run(test_suite)
......
cdef extern:
cdef func(int[])
cdef object[int] buf
cdef class A:
cdef object[int] buf
def f():
cdef object[fakeoption=True] buf1
cdef object[int, -1] buf1b
cdef object[ndim=-1] buf2
cdef object[int, 'a'] buf3
cdef object[int,2,3,4,5,6] buf4
cdef object[int, 2, 'foo'] buf5
cdef object[int, 2, well] buf6
_ERRORS = u"""
1:11: Buffer types only allowed as function local variables
3:15: Buffer types only allowed as function local variables
6:27: "fakeoption" is not a buffer option
7:22: "ndim" must be non-negative
8:15: "dtype" missing
9:21: "ndim" must be an integer
10:15: Too many buffer options
11:24: Only allowed buffer modes are "full" or "strided" (as a compile-time string)
12:28: Only allowed buffer modes are "full" or "strided" (as a compile-time string)
"""
cdef class A:
cdef int value = 3
_ERRORS = u"""
2:13: Cannot assign default value to cdef class attributes
"""
# Tests the buffer access syntax functionality by constructing
# mock buffer objects.
#
# Note that the buffers are mock objects created for testing
# the buffer access behaviour -- for instance there is no flag
# checking in the buffer objects (why test our test case?), rather
# what we want to test is what is passed into the flags argument.
#
cimport stdlib
cimport python_buffer
# Add all test_X function docstrings as unit tests
__test__ = {}
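# doctest consults the module-level __test__ mapping in addition to docstrings;
# the testcase decorator below registers each test function's docstring here
# with the shared setup_string prepended.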
setup_string = """
>>> A = IntMockBuffer("A", range(6))
>>> B = IntMockBuffer("B", range(6))
>>> C = IntMockBuffer("C", range(6), (2,3))
>>> E = ErrorBuffer("E")
"""
def testcase(func):
__test__[func.__name__] = setup_string + func.__doc__
return func
def testcas(a):
pass
#
# Buffer acquire and release tests
#
@testcase
def acquire_release(o1, o2):
"""
>>> acquire_release(A, B)
acquired A
released A
acquired B
released B
>>> acquire_release(None, None)
>>> acquire_release(None, B)
acquired B
released B
"""
cdef object[int] buf
buf = o1
buf = o2
@testcase
def acquire_raise(o):
"""
Apparently, doctest won't handle mixed exceptions and print
stats, so we need to circumvent this.
>>> A.resetlog()
>>> acquire_raise(A)
Traceback (most recent call last):
...
Exception: on purpose
>>> A.printlog()
acquired A
released A
"""
cdef object[int] buf
buf = o
raise Exception("on purpose")
@testcase
def acquire_failure1():
"""
>>> acquire_failure1()
acquired working
0 3
0 3
released working
"""
cdef object[int] buf
buf = IntMockBuffer("working", range(4))
print buf[0], buf[3]
try:
buf = ErrorBuffer()
assert False
except Exception:
print buf[0], buf[3]
@testcase
def acquire_failure2():
"""
>>> acquire_failure2()
acquired working
0 3
0 3
released working
"""
cdef object[int] buf = IntMockBuffer("working", range(4))
print buf[0], buf[3]
try:
buf = ErrorBuffer()
assert False
except Exception:
print buf[0], buf[3]
@testcase
def acquire_failure3():
"""
>>> acquire_failure3()
acquired working
0 3
released working
acquired working
0 3
released working
"""
cdef object[int] buf
buf = IntMockBuffer("working", range(4))
print buf[0], buf[3]
try:
buf = 3
assert False
except Exception:
print buf[0], buf[3]
@testcase
def acquire_failure4():
"""
>>> acquire_failure4()
acquired working
0 3
released working
acquired working
0 3
released working
"""
cdef object[int] buf = IntMockBuffer("working", range(4))
print buf[0], buf[3]
try:
buf = 2
assert False
except Exception:
print buf[0], buf[3]
@testcase
def acquire_failure5():
"""
>>> acquire_failure5()
Traceback (most recent call last):
...
ValueError: Buffer acquisition failed on assignment; and then reacquiring the old buffer failed too!
"""
cdef object[int] buf
buf = IntMockBuffer("working", range(4))
buf.fail = True
buf = 3
@testcase
def acquire_nonbuffer1(first, second=None):
"""
>>> acquire_nonbuffer1(3)
Traceback (most recent call last):
...
TypeError: 'int' does not have the buffer interface
>>> acquire_nonbuffer1(type)
Traceback (most recent call last):
...
TypeError: 'type' does not have the buffer interface
>>> acquire_nonbuffer1(None, 2)
Traceback (most recent call last):
...
TypeError: 'int' does not have the buffer interface
"""
cdef object[int] buf
buf = first
buf = second
@testcase
def acquire_nonbuffer2():
"""
>>> acquire_nonbuffer2()
acquired working
0 3
released working
acquired working
0 3
released working
"""
cdef object[int] buf = IntMockBuffer("working", range(4))
print buf[0], buf[3]
try:
buf = ErrorBuffer
assert False
except Exception:
print buf[0], buf[3]
@testcase
def as_argument(object[int] bufarg, int n):
"""
>>> as_argument(A, 6)
acquired A
0 1 2 3 4 5
released A
"""
cdef int i
for i in range(n):
print bufarg[i],
print
@testcase
def as_argument_defval(object[int] bufarg=IntMockBuffer('default', range(6)), int n=6):
"""
>>> as_argument_defval()
acquired default
0 1 2 3 4 5
released default
>>> as_argument_defval(A, 6)
acquired A
0 1 2 3 4 5
released A
"""
cdef int i
for i in range(n):
print bufarg[i],
print
@testcase
def cdef_assignment(obj, n):
"""
>>> cdef_assignment(A, 6)
acquired A
0 1 2 3 4 5
released A
"""
cdef object[int] buf = obj
cdef int i
for i in range(n):
print buf[i],
print
@testcase
def forin_assignment(objs, int pick):
"""
>>> as_argument_defval()
acquired default
0 1 2 3 4 5
released default
>>> as_argument_defval(A, 6)
acquired A
0 1 2 3 4 5
released A
"""
cdef object[int] buf
for buf in objs:
print buf[pick]
@testcase
def cascaded_buffer_assignment(obj):
"""
>>> cascaded_buffer_assignment(A)
acquired A
acquired A
released A
released A
"""
cdef object[int] a, b
a = b = obj
@testcase
def tuple_buffer_assignment1(a, b):
"""
>>> tuple_buffer_assignment1(A, B)
acquired A
acquired B
released A
released B
"""
cdef object[int] x, y
x, y = a, b
@testcase
def tuple_buffer_assignment2(tup):
"""
>>> tuple_buffer_assignment2((A, B))
acquired A
acquired B
released A
released B
"""
cdef object[int] x, y
x, y = tup
@testcase
def explicitly_release_buffer():
"""
>>> explicitly_release_buffer()
acquired A
released A
After release
"""
cdef object[int] x = IntMockBuffer("A", range(10))
x = None
print "After release"
#
# Index bounds checking
#
@testcase
def get_int_2d(object[int, 2] buf, int i, int j):
"""
>>> get_int_2d(C, 1, 1)
acquired C
released C
4
Check negative indexing:
>>> get_int_2d(C, -1, 0)
acquired C
released C
3
>>> get_int_2d(C, -1, -2)
acquired C
released C
4
>>> get_int_2d(C, -2, -3)
acquired C
released C
0
Out-of-bounds errors:
>>> get_int_2d(C, 2, 0)
Traceback (most recent call last):
...
IndexError: Out of bounds on buffer access (axis 0)
>>> get_int_2d(C, 0, -4)
Traceback (most recent call last):
...
IndexError: Out of bounds on buffer access (axis 1)
"""
return buf[i, j]
@testcase
def get_int_2d_uintindex(object[int, 2] buf, unsigned int i, unsigned int j):
"""
Unsigned indexing:
>>> get_int_2d_uintindex(C, 0, 0)
acquired C
released C
0
>>> get_int_2d_uintindex(C, 1, 2)
acquired C
released C
5
"""
# This is most interesting with regards to the C code
# generated.
return buf[i, j]
@testcase
def set_int_2d(object[int, 2] buf, int i, int j, int value):
"""
Uses get_int_2d to read back the value afterwards. For a pure
unit test, one should support reading in MockBuffer instead.
>>> set_int_2d(C, 1, 1, 10)
acquired C
released C
>>> get_int_2d(C, 1, 1)
acquired C
released C
10
Check negative indexing:
>>> set_int_2d(C, -1, 0, 3)
acquired C
released C
>>> get_int_2d(C, -1, 0)
acquired C
released C
3
>>> set_int_2d(C, -1, -2, 8)
acquired C
released C
>>> get_int_2d(C, -1, -2)
acquired C
released C
8
>>> set_int_2d(C, -2, -3, 9)
acquired C
released C
>>> get_int_2d(C, -2, -3)
acquired C
released C
9
Out-of-bounds errors:
>>> set_int_2d(C, 2, 0, 19)
Traceback (most recent call last):
...
IndexError: Out of bounds on buffer access (axis 0)
>>> set_int_2d(C, 0, -4, 19)
Traceback (most recent call last):
...
IndexError: Out of bounds on buffer access (axis 1)
"""
buf[i, j] = value
#
# Buffer type mismatch examples. Varying the type and access
# method simultaneously, the odds of an interaction are virtually
# zero.
#
@testcase
def fmtst1(buf):
"""
>>> fmtst1(IntMockBuffer("A", range(3)))
Traceback (most recent call last):
...
ValueError: Buffer datatype mismatch (rejecting on 'i')
"""
cdef object[float] a = buf
@testcase
def fmtst2(object[int] buf):
"""
>>> fmtst2(FloatMockBuffer("A", range(3)))
Traceback (most recent call last):
...
ValueError: Buffer datatype mismatch (rejecting on 'f')
"""
@testcase
def ndim1(object[int, 2] buf):
"""
>>> ndim1(IntMockBuffer("A", range(3)))
Traceback (most recent call last):
...
ValueError: Buffer has wrong number of dimensions (expected 2, got 1)
"""
#
# Test which flags are passed.
#
@testcase
def readonly(obj):
"""
>>> R = UnsignedShortMockBuffer("R", range(27), shape=(3, 3, 3))
>>> readonly(R)
acquired R
25
released R
>>> R.recieved_flags
['FORMAT', 'INDIRECT', 'ND', 'STRIDES']
"""
cdef object[unsigned short int, 3] buf = obj
print buf[2, 2, 1]
@testcase
def writable(obj):
"""
>>> R = UnsignedShortMockBuffer("R", range(27), shape=(3, 3, 3))
>>> writable(R)
acquired R
released R
>>> R.recieved_flags
['FORMAT', 'INDIRECT', 'ND', 'STRIDES', 'WRITABLE']
"""
cdef object[unsigned short int, 3] buf = obj
buf[2, 2, 1] = 23
@testcase
def strided(object[int, 1, 'strided'] buf):
"""
>>> A = IntMockBuffer("A", range(4))
>>> strided(A)
acquired A
released A
2
>>> A.recieved_flags
['FORMAT', 'ND', 'STRIDES']
"""
return buf[2]
#
# Coercions
#
@testcase
def coercions(object[unsigned char] uc):
"""
TODO
"""
print type(uc[0])
uc[0] = -1
print uc[0]
uc[0] = <int>3.14
print uc[0]
#
# Testing that accessing data using various types of buffer access
# all works.
#
@testcase
def printbuf_int_2d(o, shape):
"""
Strided:
>>> printbuf_int_2d(IntMockBuffer("A", range(6), (2,3)), (2,3))
acquired A
0 1 2
3 4 5
released A
>>> printbuf_int_2d(IntMockBuffer("A", range(100), (3,3), strides=(20,5)), (3,3))
acquired A
0 5 10
20 25 30
40 45 50
released A
Indirect:
>>> printbuf_int_2d(IntMockBuffer("A", [[1,2],[3,4]]), (2,2))
acquired A
1 2
3 4
released A
"""
# should make shape builtin
cdef object[int, 2] buf
buf = o
cdef int i, j
for i in range(shape[0]):
for j in range(shape[1]):
print buf[i, j],
print
@testcase
def printbuf_float(o, shape):
"""
>>> printbuf_float(FloatMockBuffer("F", [1.0, 1.25, 0.75, 1.0]), (4,))
acquired F
1.0 1.25 0.75 1.0
released F
"""
# should make shape builtin
cdef object[float] buf
buf = o
cdef int i, j
for i in range(shape[0]):
print buf[i],
print
available_flags = (
('FORMAT', python_buffer.PyBUF_FORMAT),
('INDIRECT', python_buffer.PyBUF_INDIRECT),
('ND', python_buffer.PyBUF_ND),
('STRIDES', python_buffer.PyBUF_STRIDES),
('WRITABLE', python_buffer.PyBUF_WRITABLE)
)
cimport stdio
cdef class MockBuffer:
cdef object format
cdef void* buffer
cdef int len, itemsize, ndim
cdef Py_ssize_t* strides
cdef Py_ssize_t* shape
cdef Py_ssize_t* suboffsets
cdef object label, log
cdef readonly object recieved_flags
cdef public object fail
def __init__(self, label, data, shape=None, strides=None, format=None):
self.label = label
self.log = ""
self.itemsize = self.get_itemsize()
if format is None: format = self.get_default_format()
if shape is None: shape = (len(data),)
if strides is None:
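# default to C-contiguous strides: a reverse cumulative product of the shape,
# scaled to bytes by itemsize a few lines below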
strides = []
cumprod = 1
rshape = list(shape)
rshape.reverse()
for s in rshape:
strides.append(cumprod)
cumprod *= s
strides.reverse()
strides = [x * self.itemsize for x in strides]
suboffsets = [-1] * len(shape)
datashape = [len(data)]
p = data
while True:
p = p[0]
if isinstance(p, list): datashape.append(len(p))
else: break
if len(datashape) > 1:
# indirect access
self.ndim = len(datashape)
shape = datashape
self.buffer = self.create_indirect_buffer(data, shape)
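# for indirect (pointer-chasing) access every dimension but the last stores
# pointers, so those dimensions get pointer-sized strides and suboffset 0;
# only the last dimension holds the items themselves (suboffset -1)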
suboffsets = [0] * (self.ndim-1) + [-1]
strides = [sizeof(void*)] * (self.ndim-1) + [self.itemsize]
self.suboffsets = self.list_to_sizebuf(suboffsets)
# printf("%ld; %ld %ld %ld %ld %ld", i0, s0, o0, i1, s1, o1);
else:
# strided and/or simple access
self.buffer = self.create_buffer(data)
self.ndim = len(shape)
self.suboffsets = NULL
self.format = format
self.len = len(data) * self.itemsize
self.strides = self.list_to_sizebuf(strides)
self.shape = self.list_to_sizebuf(shape)
def __dealloc__(self):
stdlib.free(self.strides)
stdlib.free(self.shape)
if self.suboffsets != NULL:
stdlib.free(self.suboffsets)
# must recursively free indirect...
else:
stdlib.free(self.buffer)
cdef void* create_buffer(self, data):
cdef char* buf = <char*>stdlib.malloc(len(data) * self.itemsize)
cdef char* it = buf
for value in data:
self.write(it, value)
it += self.itemsize
return buf
cdef void* create_indirect_buffer(self, data, shape):
cdef void** buf
assert shape[0] == len(data)
if len(shape) == 1:
return self.create_buffer(data)
else:
shape = shape[1:]
buf = <void**>stdlib.malloc(len(data) * sizeof(void*))
for idx, subdata in enumerate(data):
buf[idx] = self.create_indirect_buffer(subdata, shape)
return buf
cdef Py_ssize_t* list_to_sizebuf(self, l):
cdef Py_ssize_t* buf = <Py_ssize_t*>stdlib.malloc(len(l) * sizeof(Py_ssize_t))
for i, x in enumerate(l):
buf[i] = x
return buf
def __getbuffer__(MockBuffer self, Py_buffer* buffer, int flags):
if self.fail:
raise ValueError("Failing on purpose")
if buffer is NULL:
print u"locking!"
return
self.recieved_flags = []
cdef int value
for name, value in available_flags:
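# record a flag name only when every bit of that flag was requested by the caller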
if (value & flags) == value:
self.recieved_flags.append(name)
buffer.buf = self.buffer
buffer.len = self.len
buffer.readonly = 0
buffer.format = <char*>self.format
buffer.ndim = self.ndim
buffer.shape = self.shape
buffer.strides = self.strides
buffer.suboffsets = self.suboffsets
buffer.itemsize = self.itemsize
buffer.internal = NULL
msg = "acquired %s" % self.label
print msg
self.log += msg + "\n"
def __releasebuffer__(MockBuffer self, Py_buffer* buffer):
msg = "released %s" % self.label
print msg
self.log += msg + "\n"
def printlog(self):
print self.log,
def resetlog(self):
self.log = ""
cdef int write(self, char* buf, object value) except -1: raise Exception()
cdef get_itemsize(self):
print "ERROR, not subclassed", self.__class__
cdef get_default_format(self):
print "ERROR, not subclassed", self.__class__
cdef class FloatMockBuffer(MockBuffer):
cdef int write(self, char* buf, object value) except -1:
(<float*>buf)[0] = <float>value
return 0
cdef get_itemsize(self): return sizeof(float)
cdef get_default_format(self): return "=f"
cdef class IntMockBuffer(MockBuffer):
cdef int write(self, char* buf, object value) except -1:
(<int*>buf)[0] = <int>value
return 0
cdef get_itemsize(self): return sizeof(int)
cdef get_default_format(self): return "=i"
cdef class UnsignedShortMockBuffer(MockBuffer):
cdef int write(self, char* buf, object value) except -1:
(<unsigned short*>buf)[0] = <unsigned short>value
return 0
cdef get_itemsize(self): return sizeof(unsigned short)
cdef get_default_format(self): return "=H"
cdef class ErrorBuffer:
cdef object label
def __init__(self, label):
self.label = label
def __getbuffer__(ErrorBuffer self, Py_buffer* buffer, int flags):
raise Exception("acquiring %s" % self.label)
def __releasebuffer__(ErrorBuffer self, Py_buffer* buffer):
raise Exception("releasing %s" % self.label)
__doc__ = """
>>> test(1, 2)
4 1 2 2 0 7 8
"""
cdef int g = 7
def test(x, int y):
if True:
before = 0
cdef int a = 4, b = x, c = y, *p = &y
cdef object o = int(8)
print a, b, c, p[0], before, g, o
# Also test that pruning cdefs doesn't hurt
def empty():
cdef int i