Commit b35bb6c0 authored by Robert Bradshaw's avatar Robert Bradshaw

Merge in Dag's work

parents bbb832bc eaa6d477
......@@ -17,32 +17,34 @@ special_chars = [(u'<', u'\xF0', u'&lt;'),
class AnnotationCCodeWriter(CCodeWriter):
def __init__(self, f):
CCodeWriter.__init__(self, self)
self.buffer = StringIO()
self.real_f = f
def __init__(self, create_from=None, buffer=None):
CCodeWriter.__init__(self, create_from, buffer)
self.annotation_buffer = StringIO()
if create_from is None:
self.annotations = []
self.last_pos = None
self.code = {}
else:
# When creating an insertion point, keep references to the same database
self.annotation_buffer = create_from.annotation_buffer
self.annotations = create_from.annotations
self.code = create_from.code
def getvalue(self):
return self.real_f.getvalue()
def create_new(self, create_from, buffer):
return AnnotationCCodeWriter(create_from, buffer)
def write(self, s):
self.real_f.write(s)
self.buffer.write(s)
CCodeWriter.write(self, s)
self.annotation_buffer.write(s)
def mark_pos(self, pos):
# if pos is not None:
# CCodeWriter.mark_pos(self, pos)
# return
if self.last_pos:
try:
code = self.code[self.last_pos[1]]
except KeyError:
code = ""
self.code[self.last_pos[1]] = code + self.buffer.getvalue()
self.buffer = StringIO()
code = self.code.get(self.last_pos[1], "")
self.code[self.last_pos[1]] = code + self.annotation_buffer.getvalue()
self.annotation_buffer = StringIO()
self.last_pos = pos
def annotate(self, pos, item):
......
This diff is collapsed.
This diff is collapsed.
......@@ -33,6 +33,7 @@ class CompileError(PyrexError):
def __init__(self, position = None, message = ""):
self.position = position
self.message_only = message
self.reported = False
# Deprecated and withdrawn in 2.6:
# self.message = message
if position:
......@@ -88,17 +89,23 @@ def close_listing_file():
listing_file.close()
listing_file = None
def error(position, message):
#print "Errors.error:", repr(position), repr(message) ###
def report_error(err):
global num_errors
err = CompileError(position, message)
# if position is not None: raise Exception(err) # debug
# See Main.py for why dual reporting occurs. Quick fix for now.
if err.reported: return
err.reported = True
line = "%s\n" % err
if listing_file:
listing_file.write(line)
if echo_file:
echo_file.write(line)
num_errors = num_errors + 1
def error(position, message):
#print "Errors.error:", repr(position), repr(message) ###
err = CompileError(position, message)
# if position is not None: raise Exception(err) # debug
report_error(err)
return err
LEVEL=1 # warn about all errors level 1 or higher
......
......@@ -806,6 +806,15 @@ class NameNode(AtomicExprNode):
# interned_cname string
is_name = 1
skip_assignment_decref = False
entry = None
def create_analysed_rvalue(pos, env, entry):
node = NameNode(pos)
node.analyse_types(env, entry=entry)
return node
create_analysed_rvalue = staticmethod(create_analysed_rvalue)
def compile_time_value(self, denv):
try:
......@@ -834,6 +843,8 @@ class NameNode(AtomicExprNode):
def analyse_as_module(self, env):
# Try to interpret this as a reference to a cimported module.
# Returns the module scope, or None.
entry = self.entry
if not entry:
entry = env.lookup(self.name)
if entry and entry.as_module:
return entry.as_module
......@@ -842,6 +853,8 @@ class NameNode(AtomicExprNode):
def analyse_as_extension_type(self, env):
# Try to interpret this as a reference to an extension type.
# Returns the extension type, or None.
entry = self.entry
if not entry:
entry = env.lookup(self.name)
if entry and entry.is_type and entry.type.is_extension_type:
return entry.type
......@@ -849,6 +862,7 @@ class NameNode(AtomicExprNode):
return None
def analyse_target_declaration(self, env):
if not self.entry:
self.entry = env.lookup_here(self.name)
if not self.entry:
self.entry = env.declare_var(self.name, py_object_type, self.pos)
......@@ -860,6 +874,7 @@ class NameNode(AtomicExprNode):
env.use_utility_code(type_cache_invalidation_code)
def analyse_types(self, env):
if self.entry is None:
self.entry = env.lookup(self.name)
if not self.entry:
self.entry = env.declare_builtin(self.name, self.pos)
......@@ -875,6 +890,12 @@ class NameNode(AtomicExprNode):
% self.name)
self.type = PyrexTypes.error_type
self.entry.used = 1
if self.entry.type.is_buffer:
# Have an rhs temp just in case. All rhs I could
# think of had a single symbol result_code but better
# safe than sorry. Feel free to change this.
import Buffer
Buffer.used_buffer_aux_vars(self.entry)
def analyse_rvalue_entry(self, env):
#print "NameNode.analyse_rvalue_entry:", self.name ###
......@@ -1018,12 +1039,23 @@ class NameNode(AtomicExprNode):
rhs.generate_disposal_code(code)
else:
if self.type.is_buffer:
# Generate code for doing the buffer release/acquisition.
# This might raise an exception in which case the assignment (done
# below) will not happen.
#
# The reason this is not in a typetest-like node is because the
# variables that the acquired buffer info is stored to is allocated
# per entry and coupled with it.
self.generate_acquire_buffer(rhs, code)
if self.type.is_pyobject:
rhs.make_owned_reference(code)
#print "NameNode.generate_assignment_code: to", self.name ###
#print "...from", rhs ###
#print "...LHS type", self.type, "ctype", self.ctype() ###
#print "...RHS type", rhs.type, "ctype", rhs.ctype() ###
rhs.make_owned_reference(code)
if not self.skip_assignment_decref:
if entry.is_local and not Options.init_local_none:
initalized = entry.scope.control_flow.get_state((entry.name, 'initalized'), self.pos)
if initalized is True:
......@@ -1038,6 +1070,19 @@ class NameNode(AtomicExprNode):
print("...generating post-assignment code for %s" % rhs)
rhs.generate_post_assignment_code(code)
def generate_acquire_buffer(self, rhs, code):
rhstmp = code.func.allocate_temp(self.entry.type)
buffer_aux = self.entry.buffer_aux
bufstruct = buffer_aux.buffer_info_var.cname
code.putln('%s = %s;' % (rhstmp, rhs.result_as(self.ctype())))
import Buffer
Buffer.put_assign_to_buffer(self.result_code, rhstmp, buffer_aux, self.entry.type,
is_initialized=not self.skip_assignment_decref,
pos=self.pos, code=code)
code.putln("%s = 0;" % rhstmp)
code.func.release_temp(rhstmp)
def generate_deletion_code(self, code):
if self.entry is None:
return # There was an error earlier
......@@ -1297,39 +1342,45 @@ class IndexNode(ExprNode):
self.analyse_base_and_index_types(env, setting = 1)
def analyse_base_and_index_types(self, env, getting = 0, setting = 0):
# Note: This might be cleaned up by having IndexNode
# parsed in a saner way and only construct the tuple if
# needed.
self.is_buffer_access = False
self.base.analyse_types(env)
skip_child_analysis = False
buffer_access = False
if self.base.type.buffer_options is not None:
if self.base.type.is_buffer:
assert isinstance(self.base, NameNode)
if isinstance(self.index, TupleNode):
indices = self.index.args
else:
indices = [self.index]
if len(indices) == self.base.type.buffer_options.ndim:
if len(indices) == self.base.type.ndim:
buffer_access = True
skip_child_analysis = True
for x in indices:
x.analyse_types(env)
if not x.type.is_int:
buffer_access = False
if buffer_access:
# self.indices = [
# x.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env)
# for x in indices]
self.indices = indices
self.index = None
self.type = self.base.type.buffer_options.dtype
self.is_temp = 1
self.type = self.base.type.dtype
self.is_buffer_access = True
# Note: This might be cleaned up by having IndexNode
# parsed in a saner way and only construct the tuple if
# needed.
if not buffer_access:
if getting:
# we only need a temp because result_code isn't refactored to
# generation time, but this seems an ok shortcut to take
self.is_temp = True
if setting:
if not self.base.entry.type.writable:
error(self.pos, "Writing to readonly buffer")
else:
self.base.entry.buffer_aux.writable_needed = True
else:
if isinstance(self.index, TupleNode):
self.index.analyse_types(env, skip_children=skip_child_analysis)
elif not skip_child_analysis:
......@@ -1374,7 +1425,7 @@ class IndexNode(ExprNode):
def calculate_result_code(self):
if self.is_buffer_access:
return "<not needed>"
return "<not used>"
else:
return "(%s[%s])" % (
self.base.result_code, self.index.result_code)
......@@ -1393,17 +1444,22 @@ class IndexNode(ExprNode):
if self.index is not None:
self.index.generate_evaluation_code(code)
else:
for i in self.indices: i.generate_evaluation_code(code)
for i in self.indices:
i.generate_evaluation_code(code)
def generate_subexpr_disposal_code(self, code):
self.base.generate_disposal_code(code)
if self.index is not None:
self.index.generate_disposal_code(code)
else:
for i in self.indices: i.generate_disposal_code(code)
for i in self.indices:
i.generate_disposal_code(code)
def generate_result_code(self, code):
if self.type.is_pyobject:
if self.is_buffer_access:
valuecode = self.buffer_access_code(code)
code.putln("%s = %s;" % (self.result_code, valuecode))
elif self.type.is_pyobject:
if self.index.type.is_int:
function = "__Pyx_GetItemInt"
index_code = self.index.result_code
......@@ -1439,7 +1495,10 @@ class IndexNode(ExprNode):
def generate_assignment_code(self, rhs, code):
self.generate_subexpr_evaluation_code(code)
if self.type.is_pyobject:
if self.is_buffer_access:
valuecode = self.buffer_access_code(code)
code.putln("%s = %s;" % (valuecode, rhs.result_code))
elif self.type.is_pyobject:
self.generate_setitem_code(rhs.py_result(), code)
else:
code.putln(
......@@ -1465,6 +1524,19 @@ class IndexNode(ExprNode):
code.error_goto(self.pos)))
self.generate_subexpr_disposal_code(code)
def buffer_access_code(self, code):
# Assign indices to temps
index_temps = [code.func.allocate_temp(i.type) for i in self.indices]
for temp, index in zip(index_temps, self.indices):
code.putln("%s = %s;" % (temp, index.result_code))
# Generate buffer access code using these temps
import Buffer
valuecode = Buffer.put_access(entry=self.base.entry,
index_signeds=[i.type.signed for i in self.indices],
index_cnames=index_temps,
pos=self.pos, code=code)
return valuecode
class SliceIndexNode(ExprNode):
# 2-element slice indexing
......
......@@ -47,6 +47,13 @@ class Context:
self.include_directories = include_directories
self.future_directives = set()
import os.path
standard_include_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), '..', 'Includes'))
self.include_directories = include_directories + [standard_include_path]
def find_module(self, module_name,
relative_to = None, pos = None, need_pxd = 1):
# Finds and returns the module scope corresponding to
......@@ -323,14 +330,18 @@ class Context:
verbose_flag = options.show_version,
cplus = options.cplus)
def nonfatal_error(self, exc):
return Errors.report_error(exc)
def run_pipeline(self, pipeline, source):
errors_occurred = False
data = source
try:
for phase in pipeline:
data = phase(data)
except CompileError:
except CompileError, err:
errors_occurred = True
Errors.report_error(err)
return (errors_occurred, data)
def create_parse(context):
......@@ -358,22 +369,25 @@ def create_default_pipeline(context, options, result):
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from Optimize import FlattenInListTransform, SwitchTransform
from Buffer import BufferTransform
from Optimize import FlattenInListTransform, SwitchTransform, OptimizeRefcounting
from Buffer import IntroduceBufferAuxiliaryVars
from ModuleNode import check_c_classes
def printit(x): print x.dump()
return [
create_parse(context),
# printit,
NormalizeTree(context),
PostParse(context),
FlattenInListTransform(),
WithTransform(context),
DecoratorTransform(context),
AnalyseDeclarationsTransform(context),
IntroduceBufferAuxiliaryVars(context),
check_c_classes,
AnalyseExpressionsTransform(context),
BufferTransform(context),
# BufferTransform(context),
SwitchTransform(),
OptimizeRefcounting(context),
# CreateClosureClasses(context),
create_generate_code(context, options, result)
]
......@@ -607,6 +621,7 @@ def main(command_line = 0):
else:
options = CompilationOptions(default_options)
sources = args
if options.show_version:
sys.stderr.write("Cython version %s\n" % Version.version)
if options.working_path!="":
......
......@@ -97,7 +97,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
h_extension_types = h_entries(env.c_class_entries)
if h_types or h_vars or h_funcs or h_extension_types:
result.h_file = replace_suffix(result.c_file, ".h")
h_code = Code.CCodeWriter(open_new_file(result.h_file))
h_code = Code.CCodeWriter()
if options.generate_pxi:
result.i_file = replace_suffix(result.c_file, ".pxi")
i_code = Code.PyrexCodeWriter(result.i_file)
......@@ -130,6 +130,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
h_code.putln("")
h_code.putln("#endif")
h_code.copyto(open_new_file(result.h_file))
def generate_public_declaration(self, entry, h_code, i_code):
h_code.putln("%s %s;" % (
Naming.extern_c_macro,
......@@ -156,7 +158,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
has_api_extension_types = 1
if api_funcs or has_api_extension_types:
result.api_file = replace_suffix(result.c_file, "_api.h")
h_code = Code.CCodeWriter(open_new_file(result.api_file))
h_code = Code.CCodeWriter()
name = self.api_name(env)
guard = Naming.api_guard_prefix + name
h_code.put_h_guard(guard)
......@@ -210,6 +212,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
h_code.putln("")
h_code.putln("#endif")
h_code.copyto(open_new_file(result.api_file))
def generate_cclass_header_code(self, type, h_code):
h_code.putln("%s DL_IMPORT(PyTypeObject) %s;" % (
Naming.extern_c_macro,
......@@ -232,12 +236,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def generate_c_code(self, env, options, result):
modules = self.referenced_modules
if Options.annotate or options.annotate:
code = Annotate.AnnotationCCodeWriter(StringIO())
code = Annotate.AnnotationCCodeWriter()
else:
code = Code.CCodeWriter(StringIO())
code.h = Code.CCodeWriter(StringIO())
code.init_labels()
self.generate_module_preamble(env, modules, code.h)
code = Code.CCodeWriter()
h_code = code.insertion_point()
self.generate_module_preamble(env, modules, h_code)
code.putln("")
code.putln("/* Implementation of %s */" % env.qualified_name)
......@@ -259,15 +262,13 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.mark_pos(None)
self.generate_module_cleanup_func(env, code)
self.generate_filename_table(code)
self.generate_utility_functions(env, code)
self.generate_buffer_compatability_functions(env, code)
self.generate_utility_functions(env, code, h_code)
self.generate_declarations_for_modules(env, modules, code.h)
self.generate_declarations_for_modules(env, modules, h_code)
h_code.write('\n')
f = open_new_file(result.c_file)
f.write(code.h.f.getvalue())
f.write("\n")
f.write(code.f.getvalue())
code.copyto(f)
f.close()
result.c_file_generated = 1
if Options.annotate or options.annotate:
......@@ -441,8 +442,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln(" #define PyBUF_ANY_CONTIGUOUS (0x0080 | PyBUF_STRIDES)")
code.putln(" #define PyBUF_INDIRECT (0x0100 | PyBUF_STRIDES)")
code.putln("")
code.putln(" static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags);")
code.putln(" static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view);")
code.putln("#endif")
code.put(builtin_module_name_utility_code[0])
......@@ -1485,6 +1484,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("0")
code.putln("};")
code.putln()
code.enter_cfunc_scope() # as we need labels
code.putln("static int %s(PyObject *o, PyObject* py_name, char *name) {" % Naming.import_star_set)
code.putln("char** type_name = %s_type_names;" % Naming.import_star)
code.putln("while (*type_name) {")
......@@ -1535,8 +1535,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("return -1;")
code.putln("}")
code.putln(import_star_utility_code)
code.exit_cfunc_scope() # done with labels
def generate_module_init_func(self, imported_modules, env, code):
code.enter_cfunc_scope()
code.putln("")
header2 = "PyMODINIT_FUNC init%s(void)" % env.module_name
header3 = "PyMODINIT_FUNC PyInit_%s(void)" % env.module_name
......@@ -1548,8 +1550,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln(header3)
code.putln("#endif")
code.putln("{")
code.put_var_declarations(env.temp_entries)
tempdecl_code = code.insertion_point()
code.putln("%s = PyTuple_New(0); %s" % (Naming.empty_tuple, code.error_goto_if_null(Naming.empty_tuple, self.pos)));
code.putln("/*--- Libary function declarations ---*/")
......@@ -1590,6 +1591,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("/*--- Execution code ---*/")
code.mark_pos(None)
self.body.generate_execution_code(code)
if Options.generate_cleanup_code:
......@@ -1610,6 +1612,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("#endif")
code.putln('}')
tempdecl_code.put_var_declarations(env.temp_entries)
tempdecl_code.put_temp_declarations(code.func)
code.exit_cfunc_scope()
def generate_module_cleanup_func(self, env, code):
if not Options.generate_cleanup_code:
return
......@@ -1951,7 +1958,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
"%s = &%s;" % (
type.typeptr_cname, type.typeobj_cname))
def generate_utility_functions(self, env, code):
def generate_utility_functions(self, env, code, h_code):
code.putln("")
code.putln("/* Runtime support code */")
code.putln("")
......@@ -1960,94 +1967,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
(Naming.filetable_cname, Naming.filenames_cname))
code.putln("}")
for utility_code in env.utility_code_used:
code.h.put(utility_code[0])
h_code.put(utility_code[0])
code.put(utility_code[1])
code.put(PyrexTypes.type_conversion_functions)
code.putln("")
def generate_buffer_compatability_functions(self, env, code):
# will be refactored
try:
env.entries[u'numpy']
code.put("""
static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
/* This function is always called after a type-check; safe to cast */
PyArrayObject *arr = (PyArrayObject*)obj;
PyArray_Descr *type = (PyArray_Descr*)arr->descr;
int typenum = PyArray_TYPE(obj);
if (!PyTypeNum_ISNUMBER(typenum)) {
PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
return -1;
}
/*
NumPy format codes doesn't completely match buffer codes;
seems safest to retranslate.
01234567890123456789012345*/
const char* base_codes = "?bBhHiIlLqQfdgfdgO";
char* format = (char*)malloc(4);
char* fp = format;
*fp++ = type->byteorder;
if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
*fp++ = base_codes[typenum];
*fp = 0;
view->buf = arr->data;
view->readonly = !PyArray_ISWRITEABLE(obj);
view->ndim = PyArray_NDIM(arr);
view->strides = PyArray_STRIDES(arr);
view->shape = PyArray_DIMS(arr);
view->suboffsets = NULL;
view->format = format;
view->itemsize = type->elsize;
view->internal = 0;
return 0;
}
static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
free((char*)view->format);
view->format = NULL;
}
""")
except KeyError:
pass
# For now, hard-code numpy imported as "numpy"
types = []
try:
ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
except KeyError:
pass
code.putln("#if PY_VERSION_HEX < 0x02060000")
code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
if len(types) > 0:
clause = "if"
for t, get, release in types:
code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
clause = "else if"
code.putln("else {")
code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
code.putln("return -1;")
if len(types) > 0: code.putln("}")
code.putln("}")
code.putln("")
code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
if len(types) > 0:
clause = "if"
for t, get, release in types:
code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release))
clause = "else if"
code.putln("}")
code.putln("")
code.putln("#endif")
#------------------------------------------------------------------------------------
#
# Runtime support code
......
......@@ -8,6 +8,9 @@
pyrex_prefix = "__pyx_"
codewriter_temp_prefix = pyrex_prefix + "t_"
temp_prefix = u"__cyt_"
builtin_prefix = pyrex_prefix + "builtin_"
......@@ -31,6 +34,10 @@ prop_set_prefix = pyrex_prefix + "setprop_"
type_prefix = pyrex_prefix + "t_"
typeobj_prefix = pyrex_prefix + "type_"
var_prefix = pyrex_prefix + "v_"
bufstruct_prefix = pyrex_prefix + "bstruct_"
bufstride_prefix = pyrex_prefix + "bstride_"
bufshape_prefix = pyrex_prefix + "bshape_"
bufsuboffset_prefix = pyrex_prefix + "boffset_"
vtable_prefix = pyrex_prefix + "vtable_"
vtabptr_prefix = pyrex_prefix + "vtabptr_"
vtabstruct_prefix = pyrex_prefix + "vtabstruct_"
......
......@@ -73,6 +73,7 @@ class Node(object):
is_name = 0
is_literal = 0
temps = None
# All descandants should set child_attrs to a list of the attributes
# containing nodes considered "children" in the tree. Each such attribute
......@@ -102,7 +103,7 @@ class Node(object):
for attrname in result.child_attrs:
value = getattr(result, attrname)
if isinstance(value, list):
setattr(result, attrname, value)
setattr(result, attrname, [x for x in value])
return result
......@@ -252,6 +253,11 @@ class StatListNode(Node):
child_attrs = ["stats"]
def create_analysed(pos, env, *args, **kw):
node = StatListNode(pos, *args, **kw)
return node # No node-specific analysis necesarry
create_analysed = staticmethod(create_analysed)
def analyse_control_flow(self, env):
for stat in self.stats:
stat.analyse_control_flow(env)
......@@ -344,12 +350,6 @@ class CDeclaratorNode(Node):
calling_convention = ""
def analyse_expressions(self, env):
pass
def generate_execution_code(self, env):
pass
class CNameDeclaratorNode(CDeclaratorNode):
# name string The Pyrex name being declared
......@@ -368,29 +368,6 @@ class CNameDeclaratorNode(CDeclaratorNode):
self.type = base_type
return self, base_type
def analyse_expressions(self, env):
self.entry = env.lookup(self.name)
if self.default is not None:
env.control_flow.set_state(self.default.end_pos(), (self.entry.name, 'initalized'), True)
env.control_flow.set_state(self.default.end_pos(), (self.entry.name, 'source'), 'assignment')
self.entry.used = 1
if self.type.is_pyobject:
self.entry.init_to_none = False
self.entry.init = 0
self.default.analyse_types(env)
self.default = self.default.coerce_to(self.type, env)
self.default.allocate_temps(env)
self.default.release_temp(env)
def generate_execution_code(self, code):
if self.default is not None:
self.default.generate_evaluation_code(code)
if self.type.is_pyobject:
self.default.make_owned_reference(code)
code.putln('%s = %s;' % (self.entry.cname, self.default.result_as(self.entry.type)))
self.default.generate_post_assignment_code(code)
code.putln()
class CPtrDeclaratorNode(CDeclaratorNode):
# base CDeclaratorNode
......@@ -403,12 +380,6 @@ class CPtrDeclaratorNode(CDeclaratorNode):
ptr_type = PyrexTypes.c_ptr_type(base_type)
return self.base.analyse(ptr_type, env, nonempty = nonempty)
def analyse_expressions(self, env):
self.base.analyse_expressions(env)
def generate_execution_code(self, env):
self.base.generate_execution_code(env)
class CArrayDeclaratorNode(CDeclaratorNode):
# base CDeclaratorNode
# dimension ExprNode
......@@ -629,8 +600,8 @@ class CBufferAccessTypeNode(Node):
def analyse(self, env):
base_type = self.base_type_node.analyse(env)
dtype = self.dtype_node.analyse(env)
options = PyrexTypes.BufferOptions(dtype=dtype, ndim=self.ndim)
self.type = PyrexTypes.create_buffer_type(base_type, options)
self.type = PyrexTypes.BufferType(base_type, dtype=dtype, ndim=self.ndim,
mode=self.mode)
return self.type
class CComplexBaseTypeNode(CBaseTypeNode):
......@@ -685,14 +656,6 @@ class CVarDefNode(StatNode):
dest_scope.declare_var(name, type, declarator.pos,
cname = cname, visibility = self.visibility, is_cdef = 1)
def analyse_expressions(self, env):
for declarator in self.declarators:
declarator.analyse_expressions(env)
def generate_execution_code(self, code):
for declarator in self.declarators:
declarator.generate_execution_code(code)
class CStructOrUnionDefNode(StatNode):
# name string
......@@ -866,9 +829,14 @@ class FuncDefNode(StatNode, BlockNode):
return lenv
def generate_function_definitions(self, env, code, transforms):
# Generate C code for header and body of function
code.init_labels()
import Buffer
lenv = self.local_scope
# Generate C code for header and body of function
code.enter_cfunc_scope()
code.return_from_error_cleanup_label = code.new_label()
# ----- Top-level constants used by this function
code.mark_pos(self.pos)
self.generate_interned_num_decls(lenv, code)
......@@ -899,7 +867,7 @@ class FuncDefNode(StatNode, BlockNode):
(self.return_type.declaration_code(
Naming.retval_cname),
init))
code.put_var_declarations(lenv.temp_entries)
tempvardecl_code = code.insertion_point()
self.generate_keyword_list(code)
# ----- Extern library function declarations
lenv.generate_library_function_declarations(code)
......@@ -914,6 +882,8 @@ class FuncDefNode(StatNode, BlockNode):
for entry in lenv.arg_entries:
if entry.type.is_pyobject and lenv.control_flow.get_state((entry.name, 'source')) != 'arg':
code.put_var_incref(entry)
if entry.type.is_buffer:
Buffer.put_acquire_arg_buffer(entry, code, self.pos)
# ----- Initialise local variables
for entry in lenv.var_entries:
if entry.type.is_pyobject and entry.init_to_none and entry.used:
......@@ -934,12 +904,23 @@ class FuncDefNode(StatNode, BlockNode):
val = self.return_type.default_value
if val:
code.putln("%s = %s;" % (Naming.retval_cname, val))
#code.putln("goto %s;" % code.return_label)
# ----- Error cleanup
if code.error_label in code.labels_used:
code.put_goto(code.return_label)
code.put_label(code.error_label)
code.put_var_xdecrefs(lenv.temp_entries)
# Clean up buffers -- this calls a Python function
# so need to save and restore error state
buffers_present = len(lenv.buffer_entries) > 0
if buffers_present:
code.putln("{ PyObject *__pyx_type, *__pyx_value, *__pyx_tb;")
code.putln("PyErr_Fetch(&__pyx_type, &__pyx_value, &__pyx_tb);")
for entry in lenv.buffer_entries:
code.putln("%s;" % Buffer.get_release_buffer_code(entry))
#code.putln("%s = 0;" % entry.cname)
code.putln("PyErr_Restore(__pyx_type, __pyx_value, __pyx_tb);}")
err_val = self.error_value()
exc_check = self.caller_will_check_exceptions()
if err_val is not None or exc_check:
......@@ -957,8 +938,18 @@ class FuncDefNode(StatNode, BlockNode):
"%s = %s;" % (
Naming.retval_cname,
err_val))
# ----- Return cleanup
if buffers_present:
# Else, non-error return will be an empty clause
code.put_goto(code.return_from_error_cleanup_label)
# ----- Non-error return cleanup
# PS! If adding something here, modify the conditions for the
# goto statement in error cleanup above
code.put_label(code.return_label)
for entry in lenv.buffer_entries:
code.putln("%s;" % Buffer.get_release_buffer_code(entry))
# ----- Return cleanup for both error and no-error return
code.put_label(code.return_from_error_cleanup_label)
if not Options.init_local_none:
for entry in lenv.var_entries:
if lenv.control_flow.get_state((entry.name, 'initalized')) is not True:
......@@ -976,7 +967,11 @@ class FuncDefNode(StatNode, BlockNode):
if not self.return_type.is_void:
code.putln("return %s;" % Naming.retval_cname)
code.putln("}")
# ----- Go back and insert temp variable declarations
tempvardecl_code.put_var_declarations(lenv.temp_entries)
tempvardecl_code.put_temp_declarations(code.func)
# ----- Python version
code.exit_cfunc_scope()
if self.py_func:
self.py_func.generate_function_definitions(env, code, transforms)
self.generate_optarg_wrapper_function(env, code)
......@@ -2231,8 +2226,10 @@ class SingleAssignmentNode(AssignmentNode):
#
# lhs ExprNode Left hand side
# rhs ExprNode Right hand side
# first bool Is this guaranteed the first assignment to lhs?
child_attrs = ["lhs", "rhs"]
first = False
def analyse_declarations(self, env):
self.lhs.analyse_target_declaration(env)
......@@ -3344,6 +3341,7 @@ class TryExceptStatNode(StatNode):
self.gil_check(env)
def analyse_expressions(self, env):
self.body.analyse_expressions(env)
self.cleanup_list = env.free_temp_entries[:]
for except_clause in self.except_clauses:
......@@ -3524,6 +3522,12 @@ class TryFinallyStatNode(StatNode):
# continue in the try block, since we have no problem
# handling it.
def create_analysed(pos, env, body, finally_clause):
node = TryFinallyStatNode(pos, body=body, finally_clause=finally_clause)
node.cleanup_list = []
return node
create_analysed = staticmethod(create_analysed)
def analyse_control_flow(self, env):
env.start_branching(self.pos)
self.body.analyse_control_flow(env)
......
......@@ -134,3 +134,16 @@ class FlattenInListTransform(Visitor.VisitorTransform):
def visit_Node(self, node):
self.visitchildren(node)
return node
class OptimizeRefcounting(Visitor.CythonTransform):
def visit_SingleAssignmentNode(self, node):
if node.first:
lhs = node.lhs
if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
# Have variable initialized to 0 rather than None
lhs.entry.init_to_none = False
lhs.entry.init = 0
# Set a flag in NameNode to skip the decref
lhs.skip_assignment_decref = True
return node
......@@ -82,7 +82,9 @@ ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_INT = '"%s" must be an integer'
ERR_BUF_NONNEG = '"%s" must be non-negative'
ERR_CDEF_INCLASS = 'Cannot assign default value to cdef class attributes'
ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables'
ERR_BUF_MODEHELP = 'Only allowed buffer modes are "full" or "strided" (as a compile-time string)'
class PostParse(CythonTransform):
"""
Basic interpretation of the parse tree, as well as validity
......@@ -91,6 +93,9 @@ class PostParse(CythonTransform):
as such).
Specifically:
- Default values to cdef assignments are turned into single
assignments following the declaration (everywhere but in class
bodies, where they raise a compile error)
- CBufferAccessTypeNode has its options interpreted:
Any first positional argument goes into the "dtype" attribute,
any "ndim" keyword argument goes into the "ndim" attribute and
......@@ -101,12 +106,65 @@ class PostParse(CythonTransform):
if a more pure Abstract Syntax Tree is wanted.
"""
buffer_options = ("dtype", "ndim") # ordered!
# Track our context.
scope_type = None # can be either of 'module', 'function', 'class'
def visit_ModuleNode(self, node):
self.scope_type = 'module'
self.visitchildren(node)
return node
def visit_ClassDefNode(self, node):
prev = self.scope_type
self.scope_type = 'class'
self.visitchildren(node)
self.scope_type = prev
return node
def visit_FuncDefNode(self, node):
prev = self.scope_type
self.scope_type = 'function'
self.visitchildren(node)
self.scope_type = prev
return node
# cdef variables
def visit_CVarDefNode(self, node):
# This assumes only plain names and pointers are assignable on
# declaration. Also, it makes use of the fact that a cdef decl
# must appear before the first use, so we don't have to deal with
# "i = 3; cdef int i = i" and can simply move the nodes around.
try:
self.visitchildren(node)
except PostParseError, e:
# An error in a cdef clause is ok, simply remove the declaration
# and try to move on to report more errors
self.context.nonfatal_error(e)
return None
stats = [node]
for decl in node.declarators:
while isinstance(decl, CPtrDeclaratorNode):
decl = decl.base
if isinstance(decl, CNameDeclaratorNode):
if decl.default is not None:
if self.scope_type == 'class':
raise PostParseError(decl.pos, ERR_CDEF_INCLASS)
stats.append(SingleAssignmentNode(node.pos,
lhs=NameNode(node.pos, name=decl.name),
rhs=decl.default, first=True))
decl.default = None
return stats
# buffer access
buffer_options = ("dtype", "ndim", "mode") # ordered!
def visit_CBufferAccessTypeNode(self, node):
if not self.scope_type == 'function':
raise PostParseError(node.pos, ERR_BUF_LOCALONLY)
options = {}
# Fetch positional arguments
if len(node.positional_args) > len(self.buffer_options):
self.context.error(ERR_BUF_TOO_MANY)
raise PostParseError(node.pos, ERR_BUF_TOO_MANY)
for arg, unicode_name in zip(node.positional_args, self.buffer_options):
name = str(unicode_name)
options[name] = arg
......@@ -114,21 +172,19 @@ class PostParse(CythonTransform):
for item in node.keyword_args.key_value_pairs:
name = str(item.key.value)
if not name in self.buffer_options:
raise PostParseError(item.key.pos,
ERR_BUF_UNKNOWN % name)
raise PostParseError(item.key.pos, ERR_BUF_OPTION_UNKNOWN % name)
if name in options.keys():
raise PostParseError(item.key.pos,
ERR_BUF_DUP % key)
raise PostParseError(item.key.pos, ERR_BUF_DUP % key)
options[name] = item.value
provided = options.keys()
# get dtype
dtype = options.get("dtype")
if dtype is None: raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype')
if dtype is None:
raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype')
node.dtype_node = dtype
# get ndim
if "ndim" in provided:
if "ndim" in options:
ndimnode = options["ndim"]
if not isinstance(ndimnode, IntNode):
# Compile-time values (DEF) are currently resolved by the parser,
......@@ -141,6 +197,17 @@ class PostParse(CythonTransform):
else:
node.ndim = 1
if "mode" in options:
modenode = options["mode"]
if not isinstance(modenode, StringNode):
raise PostParseError(modenode.pos, ERR_BUF_MODEHELP)
mode = modenode.value
if not mode in ('full', 'strided'):
raise PostParseError(modenode.pos, ERR_BUF_MODEHELP)
node.mode = mode
else:
node.mode = 'full'
# We're done with the parse tree args
node.positional_args = None
node.keyword_args = None
......@@ -253,6 +320,13 @@ class AnalyseDeclarationsTransform(CythonTransform):
self.env_stack.pop()
return node
# Some nodes are no longer needed after declaration
# analysis and can be dropped. The analysis was performed
# on these nodes in a seperate recursive process from the
# enclosing function or module, so we can simply drop them.
def visit_CVarDefNode(self, node):
return None
class AnalyseExpressionsTransform(CythonTransform):
def visit_ModuleNode(self, node):
node.body.analyse_expressions(node.scope)
......
......@@ -1620,9 +1620,7 @@ def p_c_simple_base_type(s, self_flag, nonempty):
# Treat trailing [] on type as buffer access
if 0: # s.sy == '[':
if is_basic:
s.error("Basic C types do not support buffer access")
if not is_basic and s.sy == '[':
return p_buffer_access(s, type_node)
else:
return type_node
......
......@@ -6,21 +6,6 @@ from Cython import Utils
import Naming
import copy
class BufferOptions:
# dtype PyrexType
# ndim int
def __init__(self, dtype, ndim):
self.dtype = dtype
self.ndim = ndim
def create_buffer_type(base_type, buffer_options):
# Make a shallow copy of base_type and then annotate it
# with the buffer information
result = copy.copy(base_type)
result.buffer_options = buffer_options
return result
class BaseType:
#
......@@ -57,6 +42,7 @@ class PyrexType(BaseType):
# is_unicode boolean Is a UTF-8 encoded C char * type
# is_returncode boolean Is used only to signal exceptions
# is_error boolean Is the dummy error type
# is_buffer boolean Is buffer access type
# has_attributes boolean Has C dot-selectable attributes
# default_value string Initial value
# parsetuple_format string Format char for PyArg_ParseTuple
......@@ -106,11 +92,11 @@ class PyrexType(BaseType):
is_unicode = 0
is_returncode = 0
is_error = 0
is_buffer = 0
has_attributes = 0
default_value = ""
parsetuple_format = ""
pymemberdef_typecode = None
buffer_options = None # can contain a BufferOptions instance
def resolve(self):
# If a typedef, returns the base type.
......@@ -202,6 +188,37 @@ class CTypedefType(BaseType):
def __getattr__(self, name):
return getattr(self.typedef_base_type, name)
class BufferType(BaseType):
#
# Delegates most attribute
# lookups to the base type. ANYTHING NOT DEFINED
# HERE IS DELEGATED!
# dtype PyrexType
# ndim int
# mode str
# is_buffer boolean
# writable boolean
is_buffer = 1
writable = True
def __init__(self, base, dtype, ndim, mode):
self.base = base
self.dtype = dtype
self.ndim = ndim
self.buffer_ptr_type = CPtrType(dtype)
self.mode = mode
def as_argument_type(self):
return self
def __getattr__(self, name):
return getattr(self.base, name)
def __repr__(self):
return "<BufferType %r>" % self.base
class PyObjectType(PyrexType):
#
# Base class for all Python object types (reference-counted).
......
......@@ -15,11 +15,14 @@ from TypeSlots import \
get_special_method_signature, get_property_accessor_signature
import ControlFlow
import __builtin__
from sets import Set as set
possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
class BufferAux:
writable_needed = False
def __init__(self, buffer_info_var, stridevars, shapevars, tschecker):
self.buffer_info_var = buffer_info_var
self.stridevars = stridevars
......@@ -216,6 +219,7 @@ class Scope:
self.num_to_entry = {}
self.obj_to_entry = {}
self.pystring_entries = []
self.buffer_entries = []
self.control_flow = ControlFlow.LinearControlFlow()
def start_branching(self, pos):
......@@ -616,8 +620,11 @@ class Scope:
return [entry for entry in self.temp_entries
if entry not in self.free_temp_entries]
def use_utility_code(self, new_code):
self.global_scope().use_utility_code(new_code)
def use_utility_code(self, new_code, name=None):
self.global_scope().use_utility_code(new_code, name)
def has_utility_code(self, name):
return self.global_scope().has_utility_code(name)
def generate_library_function_declarations(self, code):
# Generate extern decls for C library funcs used.
......@@ -738,6 +745,7 @@ class ModuleScope(Scope):
# doc_cname string C name of module doc string
# const_counter integer Counter for naming constants
# utility_code_used [string] Utility code to be included
# utility_code_names set(string) (Optional) names for named (often generated) utility code
# default_entries [Entry] Function argument default entries
# python_include_files [string] Standard Python headers to be included
# include_files [string] Other C headers to be included
......@@ -772,6 +780,7 @@ class ModuleScope(Scope):
self.doc_cname = Naming.moddoc_cname
self.const_counter = 1
self.utility_code_used = []
self.utility_code_names = set()
self.default_entries = []
self.module_entries = {}
self.python_include_files = ["Python.h", "structmember.h"]
......@@ -930,13 +939,25 @@ class ModuleScope(Scope):
self.const_counter = n + 1
return "%s%s%d" % (Naming.const_prefix, prefix, n)
def use_utility_code(self, new_code):
def use_utility_code(self, new_code, name=None):
# Add string to list of utility code to be included,
# if not already there (tested using 'is').
# if not already there (tested using the provided name,
# or 'is' if name=None -- if the utility code is dynamically
# generated, use the name, otherwise it is not needed).
if name is not None:
if name in self.utility_code_names:
return
for old_code in self.utility_code_used:
if old_code is new_code:
return
self.utility_code_used.append(new_code)
self.utility_code_names.add(name)
def has_utility_code(self, name):
# Checks if utility code (that is registered by name) has
# previously been registered. This is useful if the utility code
# is dynamically generated to avoid re-generation.
return name in self.utility_code_names
def declare_c_class(self, name, pos, defining = 0, implementing = 0,
module_name = None, base_type = None, objstruct_cname = None,
......
......@@ -47,21 +47,35 @@ class TestBufferParsing(CythonTest):
self.not_parseable("Non-keyword arg following keyword arg",
u"cdef object[foo=1, 2] x")
# See also tests/error/e_bufaccess.pyx and tets/run/bufaccess.pyx
class TestBufferOptions(CythonTest):
# Tests the full parsing of the options within the brackets
def parse_opts(self, opts):
s = u"cdef object[%s] x" % opts
root = self.fragment(s, pipeline=[PostParse(self)]).root
buftype = root.stats[0].base_type
def nonfatal_error(self, error):
# We're passing self as context to transform to trap this
self.error = error
self.assert_(self.expect_error)
def parse_opts(self, opts, expect_error=False):
s = u"def f():\n cdef object[%s] x" % opts
self.expect_error = expect_error
root = self.fragment(s, pipeline=[NormalizeTree(self), PostParse(self)]).root
if not expect_error:
vardef = root.stats[0].body.stats[0]
assert isinstance(vardef, CVarDefNode) # use normal assert as this is to validate the test code
buftype = vardef.base_type
self.assert_(isinstance(buftype, CBufferAccessTypeNode))
self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode))
self.assertEqual(u"object", buftype.base_type_node.name)
return buftype
else:
self.assert_(len(root.stats[0].body.stats) == 0)
def non_parse(self, expected_err, opts):
e = self.should_fail(lambda: self.parse_opts(opts))
self.assertEqual(expected_err, e.message_only)
self.parse_opts(opts, expect_error=True)
# e = self.should_fail(lambda: self.parse_opts(opts))
self.assertEqual(expected_err, self.error.message_only)
def test_basic(self):
buf = self.parse_opts(u"unsigned short int, 3")
......@@ -86,10 +100,12 @@ class TestBufferOptions(CythonTest):
def test_use_DEF(self):
t = self.fragment(u"""
DEF ndim = 3
def f():
cdef object[int, ndim] x
cdef object[ndim=ndim, dtype=int] y
""", pipeline=[PostParse(self)]).root
self.assert_(t.stats[1].base_type.ndim == 3)
self.assert_(t.stats[2].base_type.ndim == 3)
""", pipeline=[NormalizeTree(self), PostParse(self)]).root
stats = t.stats[0].body.stats
self.assert_(stats[0].base_type.ndim == 3)
self.assert_(stats[1].base_type.ndim == 3)
# add exotic and impossible combinations as they come along
# add exotic and impossible combinations as they come along...
......@@ -6,6 +6,8 @@ import re
from cStringIO import StringIO
from Scanning import PyrexScanner, StringSourceDescriptor
from Symtab import BuiltinScope, ModuleScope
import Symtab
import PyrexTypes
from Visitor import VisitorTransform, temp_name_handle
from Nodes import Node, StatListNode
from ExprNodes import NameNode
......@@ -108,11 +110,16 @@ class TemplateTransform(VisitorTransform):
self.substitutions = substitutions
tempdict = {}
for key in temps:
tempdict[key] = temp_name_handle(key)
self.temps = tempdict
tempdict[key] = temp_name_handle(key) # pending result_code refactor: Symtab.new_temp(PyrexTypes.py_object_type, key)
self.temp_key_to_entries = tempdict
self.pos = pos
return super(TemplateTransform, self).__call__(node)
def get_pos(self, node):
if self.pos:
return self.pos
else:
return node.pos
def visit_Node(self, node):
if node is None:
......@@ -135,11 +142,11 @@ class TemplateTransform(VisitorTransform):
def visit_NameNode(self, node):
tempname = self.temps.get(node.name)
if tempname is not None:
tempentry = self.temp_key_to_entries.get(node.name)
if tempentry is not None:
# Replace name with temporary
node.name = tempname
return self.visit_Node(node)
return NameNode(self.get_pos(node), name=tempentry)
# Pending result_code refactor: return NameNode(self.get_pos(node), entry=tempentry)
else:
return self.try_substitution(node, node.name)
......@@ -157,6 +164,7 @@ def copy_code_tree(node):
INDENT_RE = re.compile(ur"^ *")
def strip_common_indent(lines):
"Strips empty lines and common indentation from the list of strings given in lines"
# TODO: Facilitate textwrap.indent instead
lines = [x for x in lines if x.strip() != u""]
minindent = min([len(INDENT_RE.match(x).group(0)) for x in lines])
lines = [x[minindent:] for x in lines]
......
# Please see the Python header files (object.h) for docs
cdef extern from "Python.h":
ctypedef void PyObject
ctypedef struct bufferinfo:
void *buf
Py_ssize_t len
Py_ssize_t itemsize
int readonly
int ndim
char *format
Py_ssize_t *shape
Py_ssize_t *strides
Py_ssize_t *suboffsets
void *internal
ctypedef bufferinfo Py_buffer
cdef enum:
PyBUF_SIMPLE,
PyBUF_WRITABLE,
PyBUF_WRITEABLE, # backwards compatability
PyBUF_FORMAT,
PyBUF_ND,
PyBUF_STRIDES,
PyBUF_C_CONTIGUOUS,
PyBUF_F_CONTIGUOUS,
PyBUF_ANY_CONTIGUOUS,
PyBUF_INDIRECT,
PyBUF_CONTIG,
PyBUF_CONTIG_RO,
PyBUF_STRIDED,
PyBUF_STRIDED_RO,
PyBUF_RECORDS,
PyBUF_RECORDS_RO,
PyBUF_FULL,
PyBUF_FULL_RO,
PyBUF_READ,
PyBUF_WRITE,
PyBUF_SHADOW
int PyObject_CheckBuffer(PyObject* obj)
int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags)
void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view)
void* PyBuffer_GetPointer(Py_buffer *view, Py_ssize_t *indices)
int PyBuffer_SizeFromFormat(char *) # actually const char
int PyBuffer_ToContiguous(void *buf, Py_buffer *view, Py_ssize_t len, char fort)
int PyBuffer_FromContiguous(Py_buffer *view, void *buf, Py_ssize_t len, char fort)
int PyObject_CopyData(PyObject *dest, PyObject *src)
int PyBuffer_IsContiguous(Py_buffer *view, char fort)
void PyBuffer_FillContiguousStrides(int ndims,
Py_ssize_t *shape,
Py_ssize_t *strides,
int itemsize,
char fort)
int PyBuffer_FillInfo(Py_buffer *view, void *buf,
Py_ssize_t len, int readonly,
int flags)
PyObject* PyObject_Format(PyObject* obj,
PyObject *format_spec)
from cStringIO import StringIO
class StringIOTree(object):
"""
See module docs.
"""
def __init__(self, stream=None):
self.prepended_children = []
if stream is None: stream = StringIO()
self.stream = stream
def getvalue(self):
return ("".join([x.getvalue() for x in self.prepended_children]) +
self.stream.getvalue())
def copyto(self, target):
"""Potentially cheaper than getvalue as no string concatenation
needs to happen."""
for child in self.prepended_children:
child.copyto(target)
target.write(self.stream.getvalue())
def write(self, what):
self.stream.write(what)
def insertion_point(self):
# Save what we have written until now
# (would it be more efficient to check with len(self.stream.getvalue())?
# leaving it out for now)
self.prepended_children.append(StringIOTree(self.stream))
# Construct the new forked object to return
other = StringIOTree()
self.prepended_children.append(other)
self.stream = StringIO()
return other
__doc__ = r"""
Implements a buffer with insertion points. When you know you need to
"get back" to a place and write more later, simply call insertion_point()
at that spot and get a new StringIOTree object that is "left behind".
EXAMPLE:
>>> a = StringIOTree()
>>> a.write('first\n')
>>> b = a.insertion_point()
>>> a.write('third\n')
>>> b.write('second\n')
>>> print a.getvalue()
first
second
third
<BLANKLINE>
>>> c = b.insertion_point()
>>> d = c.insertion_point()
>>> d.write('alpha\n')
>>> b.write('gamma\n')
>>> c.write('beta\n')
>>> print b.getvalue()
second
alpha
beta
gamma
<BLANKLINE>
>>> out = StringIO()
>>> a.copyto(out)
>>> print out.getvalue()
first
second
alpha
beta
gamma
third
<BLANKLINE>
"""
if __name__ == "__main__":
import doctest
doctest.testmod()
......@@ -46,12 +46,13 @@ class ErrorWriter(object):
class TestBuilder(object):
def __init__(self, rootdir, workdir, selectors, annotate,
cleanup_workdir, with_pyregr):
cleanup_workdir, cleanup_sharedlibs, with_pyregr):
self.rootdir = rootdir
self.workdir = workdir
self.selectors = selectors
self.annotate = annotate
self.cleanup_workdir = cleanup_workdir
self.cleanup_sharedlibs = cleanup_sharedlibs
self.with_pyregr = with_pyregr
def build_suite(self):
......@@ -84,6 +85,7 @@ class TestBuilder(object):
for filename in filenames:
if not (filename.endswith(".pyx") or filename.endswith(".py")):
continue
if filename.startswith('.'): continue # certain emacs backup files
if context == 'pyregr' and not filename.startswith('test_'):
continue
module = os.path.splitext(filename)[0]
......@@ -99,25 +101,29 @@ class TestBuilder(object):
test = build_test(
path, workdir, module,
annotate=self.annotate,
cleanup_workdir=self.cleanup_workdir)
cleanup_workdir=self.cleanup_workdir,
cleanup_sharedlibs=self.cleanup_sharedlibs)
else:
test = CythonCompileTestCase(
path, workdir, module,
expect_errors=expect_errors,
annotate=self.annotate,
cleanup_workdir=self.cleanup_workdir)
cleanup_workdir=self.cleanup_workdir,
cleanup_sharedlibs=self.cleanup_sharedlibs)
suite.addTest(test)
return suite
class CythonCompileTestCase(unittest.TestCase):
def __init__(self, directory, workdir, module,
expect_errors=False, annotate=False, cleanup_workdir=True):
expect_errors=False, annotate=False, cleanup_workdir=True,
cleanup_sharedlibs=True):
self.directory = directory
self.workdir = workdir
self.module = module
self.expect_errors = expect_errors
self.annotate = annotate
self.cleanup_workdir = cleanup_workdir
self.cleanup_sharedlibs = cleanup_sharedlibs
unittest.TestCase.__init__(self)
def shortDescription(self):
......@@ -125,10 +131,13 @@ class CythonCompileTestCase(unittest.TestCase):
def tearDown(self):
cleanup_c_files = WITH_CYTHON and self.cleanup_workdir
cleanup_lib_files = self.cleanup_sharedlibs
if os.path.exists(self.workdir):
for rmfile in os.listdir(self.workdir):
if not cleanup_c_files and rmfile[-2:] in (".c", ".h"):
continue
if not cleanup_lib_files and rmfile.endswith(".so") or rmfile.endswith(".dll"):
continue
if self.annotate and rmfile.endswith(".html"):
continue
try:
......@@ -278,7 +287,7 @@ class CythonUnitTestCase(CythonCompileTestCase):
except Exception:
pass
def collect_unittests(path, suite, selectors):
def collect_unittests(path, module_prefix, suite, selectors):
def file_matches(filename):
return filename.startswith("Test") and filename.endswith(".py")
......@@ -304,7 +313,7 @@ def collect_unittests(path, suite, selectors):
for f in filenames:
if file_matches(f):
filepath = os.path.join(dirpath, f)[:-len(".py")]
modulename = filepath[len(path)+1:].replace(os.path.sep, '.')
modulename = module_prefix + filepath[len(path)+1:].replace(os.path.sep, '.')
if not [ 1 for match in selectors if match(modulename) ]:
continue
module = __import__(modulename)
......@@ -312,18 +321,50 @@ def collect_unittests(path, suite, selectors):
module = getattr(module, x)
suite.addTests([loader.loadTestsFromModule(module)])
def collect_doctests(path, module_prefix, suite, selectors):
def package_matches(dirname):
return dirname not in ("Mac", "Distutils", "Plex")
def file_matches(filename):
return (filename.endswith(".py") and not ('~' in filename
or '#' in filename or filename.startswith('.')))
import doctest, types
for dirpath, dirnames, filenames in os.walk(path):
parentname = os.path.split(dirpath)[-1]
if package_matches(parentname):
for f in filenames:
if file_matches(f):
if not f.endswith('.py'): continue
filepath = os.path.join(dirpath, f)[:-len(".py")]
modulename = module_prefix + filepath[len(path)+1:].replace(os.path.sep, '.')
if not [ 1 for match in selectors if match(modulename) ]:
continue
module = __import__(modulename)
for x in modulename.split('.')[1:]:
module = getattr(module, x)
if hasattr(module, "__doc__") or hasattr(module, "__test__"):
try:
suite.addTests(doctest.DocTestSuite(module))
except ValueError: # no tests
pass
if __name__ == '__main__':
from optparse import OptionParser
parser = OptionParser()
parser.add_option("--no-cleanup", dest="cleanup_workdir",
action="store_false", default=True,
help="do not delete the generated C files (allows passing --no-cython on next run)")
parser.add_option("--no-cleanup-sharedlibs", dest="cleanup_sharedlibs",
action="store_false", default=True,
help="do not delete the generated shared libary files (allows manual module experimentation)")
parser.add_option("--no-cython", dest="with_cython",
action="store_false", default=True,
help="do not run the Cython compiler, only the C compiler")
parser.add_option("--no-unit", dest="unittests",
action="store_false", default=True,
help="do not run the unit tests")
parser.add_option("--no-doctest", dest="doctests",
action="store_false", default=True,
help="do not run the doctests")
parser.add_option("--no-file", dest="filetests",
action="store_false", default=True,
help="do not run the file based tests")
......@@ -360,6 +401,8 @@ if __name__ == '__main__':
# RUN ALL TESTS!
ROOTDIR = os.path.join(os.getcwd(), os.path.dirname(sys.argv[0]), 'tests')
WORKDIR = os.path.join(os.getcwd(), 'BUILD')
UNITTEST_MODULE = "Cython"
UNITTEST_ROOT = os.path.join(os.getcwd(), UNITTEST_MODULE)
if WITH_CYTHON:
if os.path.exists(WORKDIR):
shutil.rmtree(WORKDIR, ignore_errors=True)
......@@ -382,13 +425,15 @@ if __name__ == '__main__':
test_suite = unittest.TestSuite()
if options.unittests:
collect_unittests(os.getcwd(), test_suite, selectors)
collect_unittests(UNITTEST_ROOT, UNITTEST_MODULE + ".", test_suite, selectors)
if options.doctests:
collect_doctests(UNITTEST_ROOT, UNITTEST_MODULE + ".", test_suite, selectors)
if options.filetests:
filetests = TestBuilder(ROOTDIR, WORKDIR, selectors,
options.annotate_source,
options.cleanup_workdir,
options.pyregr)
options.annotate_source, options.cleanup_workdir,
options.cleanup_sharedlibs, options.pyregr)
test_suite.addTests([filetests.build_suite()])
unittest.TextTestRunner(verbosity=options.verbosity).run(test_suite)
......
cdef extern:
cdef func(int[])
cdef object[int] buf
cdef class A:
cdef object[int] buf
def f():
cdef object[fakeoption=True] buf1
cdef object[int, -1] buf1b
cdef object[ndim=-1] buf2
cdef object[int, 'a'] buf3
cdef object[int,2,3,4,5,6] buf4
cdef object[int, 2, 'foo'] buf5
cdef object[int, 2, well] buf6
_ERRORS = u"""
1:11: Buffer types only allowed as function local variables
3:15: Buffer types only allowed as function local variables
6:27: "fakeoption" is not a buffer option
7:22: "ndim" must be non-negative
8:15: "dtype" missing
9:21: "ndim" must be an integer
10:15: Too many buffer options
11:24: Only allowed buffer modes are "full" or "strided" (as a compile-time string)
12:28: Only allowed buffer modes are "full" or "strided" (as a compile-time string)
"""
cdef class A:
cdef int value = 3
_ERRORS = u"""
2:13: Cannot assign default value to cdef class attributes
"""
This diff is collapsed.
__doc__ = """
>>> test(1, 2)
4 1 2 2 0 7 8
"""
cdef int g = 7
def test(x, int y):
if True:
before = 0
cdef int a = 4, b = x, c = y, *p = &y
cdef object o = int(8)
print a, b, c, p[0], before, g, o
# Also test that pruning cdefs doesn't hurt
def empty():
cdef int i
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment