Commit b35bb6c0 authored by Robert Bradshaw

Merge in Dag's work

parents bbb832bc eaa6d477
...
@@ -17,32 +17,34 @@ special_chars = [(u'<', u'\xF0', u'&lt;'),
 class AnnotationCCodeWriter(CCodeWriter):
 
-    def __init__(self, f):
-        CCodeWriter.__init__(self, self)
-        self.buffer = StringIO()
-        self.real_f = f
-        self.annotations = []
-        self.last_pos = None
-        self.code = {}
-
-    def getvalue(self):
-        return self.real_f.getvalue()
+    def __init__(self, create_from=None, buffer=None):
+        CCodeWriter.__init__(self, create_from, buffer)
+        self.annotation_buffer = StringIO()
+        if create_from is None:
+            self.annotations = []
+            self.last_pos = None
+            self.code = {}
+        else:
+            # When creating an insertion point, keep references to the same database
+            self.annotation_buffer = create_from.annotation_buffer
+            self.annotations = create_from.annotations
+            self.code = create_from.code
+
+    def create_new(self, create_from, buffer):
+        return AnnotationCCodeWriter(create_from, buffer)
 
     def write(self, s):
-        self.real_f.write(s)
-        self.buffer.write(s)
+        CCodeWriter.write(self, s)
+        self.annotation_buffer.write(s)
 
     def mark_pos(self, pos):
 #        if pos is not None:
 #            CCodeWriter.mark_pos(self, pos)
 #        return
         if self.last_pos:
-            try:
-                code = self.code[self.last_pos[1]]
-            except KeyError:
-                code = ""
-            self.code[self.last_pos[1]] = code + self.buffer.getvalue()
-            self.buffer = StringIO()
+            code = self.code.get(self.last_pos[1], "")
+            self.code[self.last_pos[1]] = code + self.annotation_buffer.getvalue()
+            self.annotation_buffer = StringIO()
         self.last_pos = pos
 
     def annotate(self, pos, item):
...
...
@@ -33,6 +33,7 @@ class CompileError(PyrexError):
     def __init__(self, position = None, message = ""):
         self.position = position
         self.message_only = message
+        self.reported = False
         # Deprecated and withdrawn in 2.6:
         #   self.message = message
         if position:
@@ -88,17 +89,23 @@ def close_listing_file():
         listing_file.close()
         listing_file = None
 
-def error(position, message):
-    #print "Errors.error:", repr(position), repr(message) ###
+def report_error(err):
     global num_errors
-    err = CompileError(position, message)
-    # if position is not None: raise Exception(err) # debug
+    # See Main.py for why dual reporting occurs. Quick fix for now.
+    if err.reported: return
+    err.reported = True
     line = "%s\n" % err
     if listing_file:
        listing_file.write(line)
     if echo_file:
         echo_file.write(line)
     num_errors = num_errors + 1
+
+def error(position, message):
+    #print "Errors.error:", repr(position), repr(message) ###
+    err = CompileError(position, message)
+    # if position is not None: raise Exception(err) # debug
+    report_error(err)
     return err
 
 LEVEL=1 # warn about all errors level 1 or higher
...
...
@@ -46,6 +46,13 @@ class Context:
         self.pyxs = {}
         self.include_directories = include_directories
         self.future_directives = set()
+
+        import os.path
+        standard_include_path = os.path.abspath(
+            os.path.join(os.path.dirname(__file__), '..', 'Includes'))
+        self.include_directories = include_directories + [standard_include_path]
 
     def find_module(self, module_name,
             relative_to = None, pos = None, need_pxd = 1):
@@ -323,14 +330,18 @@ class Context:
             verbose_flag = options.show_version,
             cplus = options.cplus)
 
+    def nonfatal_error(self, exc):
+        return Errors.report_error(exc)
+
     def run_pipeline(self, pipeline, source):
         errors_occurred = False
         data = source
         try:
             for phase in pipeline:
                 data = phase(data)
-        except CompileError:
+        except CompileError, err:
             errors_occurred = True
+            Errors.report_error(err)
         return (errors_occurred, data)
 
 def create_parse(context):
@@ -358,22 +369,25 @@ def create_default_pipeline(context, options, result):
     from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
     from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
     from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
-    from Optimize import FlattenInListTransform, SwitchTransform
-    from Buffer import BufferTransform
+    from Optimize import FlattenInListTransform, SwitchTransform, OptimizeRefcounting
+    from Buffer import IntroduceBufferAuxiliaryVars
     from ModuleNode import check_c_classes
+    def printit(x): print x.dump()
     return [
         create_parse(context),
+        # printit,
         NormalizeTree(context),
         PostParse(context),
         FlattenInListTransform(),
         WithTransform(context),
         DecoratorTransform(context),
         AnalyseDeclarationsTransform(context),
+        IntroduceBufferAuxiliaryVars(context),
         check_c_classes,
         AnalyseExpressionsTransform(context),
-        BufferTransform(context),
+        # BufferTransform(context),
         SwitchTransform(),
+        OptimizeRefcounting(context),
         # CreateClosureClasses(context),
         create_generate_code(context, options, result)
         ]
@@ -607,6 +621,7 @@ def main(command_line = 0):
     else:
         options = CompilationOptions(default_options)
         sources = args
+
     if options.show_version:
         sys.stderr.write("Cython version %s\n" % Version.version)
     if options.working_path!="":
...
...
@@ -97,7 +97,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         h_extension_types = h_entries(env.c_class_entries)
         if h_types or h_vars or h_funcs or h_extension_types:
             result.h_file = replace_suffix(result.c_file, ".h")
-            h_code = Code.CCodeWriter(open_new_file(result.h_file))
+            h_code = Code.CCodeWriter()
             if options.generate_pxi:
                 result.i_file = replace_suffix(result.c_file, ".pxi")
                 i_code = Code.PyrexCodeWriter(result.i_file)
@@ -129,6 +129,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
                 h_code.putln("PyMODINIT_FUNC init%s(void);" % env.module_name)
                 h_code.putln("")
                 h_code.putln("#endif")
+
+            h_code.copyto(open_new_file(result.h_file))
 
     def generate_public_declaration(self, entry, h_code, i_code):
         h_code.putln("%s %s;" % (
@@ -156,7 +158,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
                 has_api_extension_types = 1
         if api_funcs or has_api_extension_types:
             result.api_file = replace_suffix(result.c_file, "_api.h")
-            h_code = Code.CCodeWriter(open_new_file(result.api_file))
+            h_code = Code.CCodeWriter()
             name = self.api_name(env)
             guard = Naming.api_guard_prefix + name
             h_code.put_h_guard(guard)
@@ -209,6 +211,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
             h_code.putln("}")
             h_code.putln("")
             h_code.putln("#endif")
+
+            h_code.copyto(open_new_file(result.api_file))
 
     def generate_cclass_header_code(self, type, h_code):
         h_code.putln("%s DL_IMPORT(PyTypeObject) %s;" % (
@@ -232,12 +236,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
     def generate_c_code(self, env, options, result):
         modules = self.referenced_modules
         if Options.annotate or options.annotate:
-            code = Annotate.AnnotationCCodeWriter(StringIO())
+            code = Annotate.AnnotationCCodeWriter()
         else:
-            code = Code.CCodeWriter(StringIO())
-        code.h = Code.CCodeWriter(StringIO())
-        code.init_labels()
-        self.generate_module_preamble(env, modules, code.h)
+            code = Code.CCodeWriter()
+        h_code = code.insertion_point()
+        self.generate_module_preamble(env, modules, h_code)
 
         code.putln("")
         code.putln("/* Implementation of %s */" % env.qualified_name)
@@ -259,15 +262,13 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         code.mark_pos(None)
         self.generate_module_cleanup_func(env, code)
         self.generate_filename_table(code)
-        self.generate_utility_functions(env, code)
-        self.generate_buffer_compatability_functions(env, code)
+        self.generate_utility_functions(env, code, h_code)
 
-        self.generate_declarations_for_modules(env, modules, code.h)
+        self.generate_declarations_for_modules(env, modules, h_code)
+        h_code.write('\n')
 
         f = open_new_file(result.c_file)
-        f.write(code.h.f.getvalue())
-        f.write("\n")
-        f.write(code.f.getvalue())
+        code.copyto(f)
         f.close()
         result.c_file_generated = 1
         if Options.annotate or options.annotate:
@@ -441,8 +442,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         code.putln("  #define PyBUF_ANY_CONTIGUOUS (0x0080 | PyBUF_STRIDES)")
         code.putln("  #define PyBUF_INDIRECT (0x0100 | PyBUF_STRIDES)")
         code.putln("")
-        code.putln("  static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags);")
-        code.putln("  static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view);")
         code.putln("#endif")
 
         code.put(builtin_module_name_utility_code[0])
@@ -1485,6 +1484,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
             code.putln("0")
         code.putln("};")
         code.putln()
+        code.enter_cfunc_scope() # as we need labels
         code.putln("static int %s(PyObject *o, PyObject* py_name, char *name) {" % Naming.import_star_set)
         code.putln("char** type_name = %s_type_names;" % Naming.import_star)
         code.putln("while (*type_name) {")
@@ -1535,8 +1535,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         code.putln("return -1;")
         code.putln("}")
         code.putln(import_star_utility_code)
+        code.exit_cfunc_scope() # done with labels
 
     def generate_module_init_func(self, imported_modules, env, code):
+        code.enter_cfunc_scope()
         code.putln("")
         header2 = "PyMODINIT_FUNC init%s(void)" % env.module_name
         header3 = "PyMODINIT_FUNC PyInit_%s(void)" % env.module_name
@@ -1548,8 +1550,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         code.putln(header3)
         code.putln("#endif")
         code.putln("{")
-        code.put_var_declarations(env.temp_entries)
-
+        tempdecl_code = code.insertion_point()
 
         code.putln("%s = PyTuple_New(0); %s" % (Naming.empty_tuple, code.error_goto_if_null(Naming.empty_tuple, self.pos)));
         code.putln("/*--- Libary function declarations ---*/")
@@ -1590,6 +1591,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         code.putln("/*--- Execution code ---*/")
         code.mark_pos(None)
         self.body.generate_execution_code(code)
+
         if Options.generate_cleanup_code:
@@ -1609,6 +1611,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         code.putln("return NULL;")
         code.putln("#endif")
         code.putln('}')
+
+        tempdecl_code.put_var_declarations(env.temp_entries)
+        tempdecl_code.put_temp_declarations(code.func)
+
+        code.exit_cfunc_scope()
 
     def generate_module_cleanup_func(self, env, code):
         if not Options.generate_cleanup_code:
@@ -1951,7 +1958,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
             "%s = &%s;" % (
                 type.typeptr_cname, type.typeobj_cname))
 
-    def generate_utility_functions(self, env, code):
+    def generate_utility_functions(self, env, code, h_code):
         code.putln("")
         code.putln("/* Runtime support code */")
         code.putln("")
@@ -1960,94 +1967,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
             (Naming.filetable_cname, Naming.filenames_cname))
         code.putln("}")
         for utility_code in env.utility_code_used:
-            code.h.put(utility_code[0])
+            h_code.put(utility_code[0])
             code.put(utility_code[1])
         code.put(PyrexTypes.type_conversion_functions)
         code.putln("")
 
-    def generate_buffer_compatability_functions(self, env, code):
-        # will be refactored
-        try:
-            env.entries[u'numpy']
-            code.put("""
-static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
-  /* This function is always called after a type-check; safe to cast */
-  PyArrayObject *arr = (PyArrayObject*)obj;
-  PyArray_Descr *type = (PyArray_Descr*)arr->descr;
-  int typenum = PyArray_TYPE(obj);
-  if (!PyTypeNum_ISNUMBER(typenum)) {
-    PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
-    return -1;
-  }
-  /*
-  NumPy format codes doesn't completely match buffer codes;
-  seems safest to retranslate.
-                            01234567890123456789012345*/
-  const char* base_codes = "?bBhHiIlLqQfdgfdgO";
-  char* format = (char*)malloc(4);
-  char* fp = format;
-  *fp++ = type->byteorder;
-  if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
-  *fp++ = base_codes[typenum];
-  *fp = 0;
-  view->buf = arr->data;
-  view->readonly = !PyArray_ISWRITEABLE(obj);
-  view->ndim = PyArray_NDIM(arr);
-  view->strides = PyArray_STRIDES(arr);
-  view->shape = PyArray_DIMS(arr);
-  view->suboffsets = NULL;
-  view->format = format;
-  view->itemsize = type->elsize;
-  view->internal = 0;
-  return 0;
-}
-
-static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
-  free((char*)view->format);
-  view->format = NULL;
-}
-""")
-        except KeyError:
-            pass
-
-        # For now, hard-code numpy imported as "numpy"
-        types = []
-        try:
-            ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
-            types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
-        except KeyError:
-            pass
-
-        code.putln("#if PY_VERSION_HEX < 0x02060000")
-        code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
-        if len(types) > 0:
-            clause = "if"
-            for t, get, release in types:
-                code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
-                clause = "else if"
-            code.putln("else {")
-        code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
-        code.putln("return -1;")
-        if len(types) > 0: code.putln("}")
-        code.putln("}")
-        code.putln("")
-        code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
-        if len(types) > 0:
-            clause = "if"
-            for t, get, release in types:
-                code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release))
-                clause = "else if"
-        code.putln("}")
-        code.putln("")
-        code.putln("#endif")
-
 #------------------------------------------------------------------------------------
 #
 #  Runtime support code
...
...
@@ -8,6 +8,9 @@
 pyrex_prefix = "__pyx_"
 
+codewriter_temp_prefix = pyrex_prefix + "t_"
+
 temp_prefix = u"__cyt_"
 
 builtin_prefix = pyrex_prefix + "builtin_"
@@ -31,6 +34,10 @@ prop_set_prefix = pyrex_prefix + "setprop_"
 type_prefix = pyrex_prefix + "t_"
 typeobj_prefix = pyrex_prefix + "type_"
 var_prefix = pyrex_prefix + "v_"
+bufstruct_prefix = pyrex_prefix + "bstruct_"
+bufstride_prefix = pyrex_prefix + "bstride_"
+bufshape_prefix = pyrex_prefix + "bshape_"
+bufsuboffset_prefix = pyrex_prefix + "boffset_"
 vtable_prefix = pyrex_prefix + "vtable_"
 vtabptr_prefix = pyrex_prefix + "vtabptr_"
 vtabstruct_prefix = pyrex_prefix + "vtabstruct_"
...
...
@@ -134,3 +134,16 @@ class FlattenInListTransform(Visitor.VisitorTransform):
     def visit_Node(self, node):
         self.visitchildren(node)
         return node
+
+class OptimizeRefcounting(Visitor.CythonTransform):
+    def visit_SingleAssignmentNode(self, node):
+        if node.first:
+            lhs = node.lhs
+            if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
+                # Have variable initialized to 0 rather than None
+                lhs.entry.init_to_none = False
+                lhs.entry.init = 0
+                # Set a flag in NameNode to skip the decref
+                lhs.skip_assignment_decref = True
+        return node
...
@@ -82,7 +82,9 @@ ERR_BUF_DUP = '"%s" buffer option already supplied'
 ERR_BUF_MISSING = '"%s" missing'
 ERR_BUF_INT = '"%s" must be an integer'
 ERR_BUF_NONNEG = '"%s" must be non-negative'
+ERR_CDEF_INCLASS = 'Cannot assign default value to cdef class attributes'
+ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables'
+ERR_BUF_MODEHELP = 'Only allowed buffer modes are "full" or "strided" (as a compile-time string)'
 
 class PostParse(CythonTransform):
     """
     Basic interpretation of the parse tree, as well as validity
@@ -91,6 +93,9 @@ class PostParse(CythonTransform):
     as such).
 
     Specifically:
+    - Default values to cdef assignments are turned into single
+      assignments following the declaration (everywhere but in class
+      bodies, where they raise a compile error)
     - CBufferAccessTypeNode has its options interpreted:
       Any first positional argument goes into the "dtype" attribute,
       any "ndim" keyword argument goes into the "ndim" attribute and
@@ -101,12 +106,65 @@ class PostParse(CythonTransform):
     if a more pure Abstract Syntax Tree is wanted.
     """
 
-    buffer_options = ("dtype", "ndim") # ordered!
+    # Track our context.
+    scope_type = None # can be either of 'module', 'function', 'class'
+
+    def visit_ModuleNode(self, node):
+        self.scope_type = 'module'
+        self.visitchildren(node)
+        return node
+
+    def visit_ClassDefNode(self, node):
+        prev = self.scope_type
+        self.scope_type = 'class'
+        self.visitchildren(node)
+        self.scope_type = prev
+        return node
+
+    def visit_FuncDefNode(self, node):
+        prev = self.scope_type
+        self.scope_type = 'function'
+        self.visitchildren(node)
+        self.scope_type = prev
+        return node
+
+    # cdef variables
+    def visit_CVarDefNode(self, node):
+        # This assumes only plain names and pointers are assignable on
+        # declaration. Also, it makes use of the fact that a cdef decl
+        # must appear before the first use, so we don't have to deal with
+        # "i = 3; cdef int i = i" and can simply move the nodes around.
+        try:
+            self.visitchildren(node)
+        except PostParseError, e:
+            # An error in a cdef clause is ok, simply remove the declaration
+            # and try to move on to report more errors
+            self.context.nonfatal_error(e)
+            return None
+        stats = [node]
+        for decl in node.declarators:
+            while isinstance(decl, CPtrDeclaratorNode):
+                decl = decl.base
+            if isinstance(decl, CNameDeclaratorNode):
+                if decl.default is not None:
+                    if self.scope_type == 'class':
+                        raise PostParseError(decl.pos, ERR_CDEF_INCLASS)
+                    stats.append(SingleAssignmentNode(node.pos,
+                        lhs=NameNode(node.pos, name=decl.name),
+                        rhs=decl.default, first=True))
+                    decl.default = None
+        return stats
+
+    # buffer access
+    buffer_options = ("dtype", "ndim", "mode") # ordered!
     def visit_CBufferAccessTypeNode(self, node):
+        if not self.scope_type == 'function':
+            raise PostParseError(node.pos, ERR_BUF_LOCALONLY)
         options = {}
         # Fetch positional arguments
         if len(node.positional_args) > len(self.buffer_options):
-            self.context.error(ERR_BUF_TOO_MANY)
+            raise PostParseError(node.pos, ERR_BUF_TOO_MANY)
         for arg, unicode_name in zip(node.positional_args, self.buffer_options):
             name = str(unicode_name)
             options[name] = arg
@@ -114,21 +172,19 @@ class PostParse(CythonTransform):
         for item in node.keyword_args.key_value_pairs:
             name = str(item.key.value)
             if not name in self.buffer_options:
-                raise PostParseError(item.key.pos,
-                                     ERR_BUF_UNKNOWN % name)
+                raise PostParseError(item.key.pos, ERR_BUF_OPTION_UNKNOWN % name)
             if name in options.keys():
-                raise PostParseError(item.key.pos,
-                                     ERR_BUF_DUP % key)
+                raise PostParseError(item.key.pos, ERR_BUF_DUP % key)
             options[name] = item.value
-        provided = options.keys()
 
         # get dtype
         dtype = options.get("dtype")
-        if dtype is None: raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype')
+        if dtype is None:
+            raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype')
         node.dtype_node = dtype
 
         # get ndim
-        if "ndim" in provided:
+        if "ndim" in options:
             ndimnode = options["ndim"]
             if not isinstance(ndimnode, IntNode):
                 # Compile-time values (DEF) are currently resolved by the parser,
@@ -140,7 +196,18 @@ class PostParse(CythonTransform):
             node.ndim = int(ndimnode.value)
         else:
             node.ndim = 1
 
+        if "mode" in options:
+            modenode = options["mode"]
+            if not isinstance(modenode, StringNode):
+                raise PostParseError(modenode.pos, ERR_BUF_MODEHELP)
+            mode = modenode.value
+            if not mode in ('full', 'strided'):
+                raise PostParseError(modenode.pos, ERR_BUF_MODEHELP)
+            node.mode = mode
+        else:
+            node.mode = 'full'
+
         # We're done with the parse tree args
         node.positional_args = None
         node.keyword_args = None
@@ -253,6 +320,13 @@ class AnalyseDeclarationsTransform(CythonTransform):
         self.env_stack.pop()
         return node
 
+    # Some nodes are no longer needed after declaration
+    # analysis and can be dropped. The analysis was performed
+    # on these nodes in a seperate recursive process from the
+    # enclosing function or module, so we can simply drop them.
+    def visit_CVarDefNode(self, node):
+        return None
+
 class AnalyseExpressionsTransform(CythonTransform):
 
     def visit_ModuleNode(self, node):
         node.body.analyse_expressions(node.scope)
@@ -263,7 +337,7 @@ class AnalyseExpressionsTransform(CythonTransform):
         node.body.analyse_expressions(node.local_scope)
         self.visitchildren(node)
         return node
 
 class MarkClosureVisitor(CythonTransform):
     needs_closure = False
...
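For reference, a small usage sketch (not part of the commit) of the declaration forms the PostParse changes above accept; the function and variable names are illustrative only:

    # Buffer types are only allowed as function locals; dtype is required,
    # ndim defaults to 1 and mode defaults to 'full' ('strided' is the other choice).
    def f(obj):
        cdef object[int] a                            # dtype only, ndim defaults to 1
        cdef object[unsigned short int, 3] b          # positional dtype and ndim
        cdef object[int, ndim=2, mode='strided'] c    # keyword options
        cdef int i = 0    # a cdef default value becomes an assignment right after the declaration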
...
@@ -1620,9 +1620,7 @@ def p_c_simple_base_type(s, self_flag, nonempty):
     # Treat trailing [] on type as buffer access
-    if 0: # s.sy == '[':
-        if is_basic:
-            s.error("Basic C types do not support buffer access")
+    if not is_basic and s.sy == '[':
         return p_buffer_access(s, type_node)
     else:
         return type_node
...
...
@@ -6,21 +6,6 @@ from Cython import Utils
 import Naming
 import copy
 
-class BufferOptions:
-    # dtype   PyrexType
-    # ndim    int
-    def __init__(self, dtype, ndim):
-        self.dtype = dtype
-        self.ndim = ndim
-
-def create_buffer_type(base_type, buffer_options):
-    # Make a shallow copy of base_type and then annotate it
-    # with the buffer information
-    result = copy.copy(base_type)
-    result.buffer_options = buffer_options
-    return result
-
 class BaseType:
     #
@@ -57,6 +42,7 @@ class PyrexType(BaseType):
     #  is_unicode            boolean     Is a UTF-8 encoded C char * type
     #  is_returncode         boolean     Is used only to signal exceptions
     #  is_error              boolean     Is the dummy error type
+    #  is_buffer             boolean     Is buffer access type
     #  has_attributes        boolean     Has C dot-selectable attributes
     #  default_value         string      Initial value
     #  parsetuple_format     string      Format char for PyArg_ParseTuple
@@ -106,11 +92,11 @@ class PyrexType(BaseType):
     is_unicode = 0
     is_returncode = 0
     is_error = 0
+    is_buffer = 0
     has_attributes = 0
     default_value = ""
     parsetuple_format = ""
     pymemberdef_typecode = None
-    buffer_options = None # can contain a BufferOptions instance
 
     def resolve(self):
         # If a typedef, returns the base type.
@@ -202,6 +188,37 @@ class CTypedefType(BaseType):
     def __getattr__(self, name):
         return getattr(self.typedef_base_type, name)
 
+class BufferType(BaseType):
+    #
+    #  Delegates most attribute
+    #  lookups to the base type. ANYTHING NOT DEFINED
+    #  HERE IS DELEGATED!
+    #
+    #  dtype      PyrexType
+    #  ndim       int
+    #  mode       str
+    #  is_buffer  boolean
+    #  writable   boolean
+
+    is_buffer = 1
+    writable = True
+    def __init__(self, base, dtype, ndim, mode):
+        self.base = base
+        self.dtype = dtype
+        self.ndim = ndim
+        self.buffer_ptr_type = CPtrType(dtype)
+        self.mode = mode
+
+    def as_argument_type(self):
+        return self
+
+    def __getattr__(self, name):
+        return getattr(self.base, name)
+
+    def __repr__(self):
+        return "<BufferType %r>" % self.base
+
 class PyObjectType(PyrexType):
     #
     #  Base class for all Python object types (reference-counted).
@@ -927,7 +944,7 @@ class CEnumType(CType):
     #  name           string
     #  cname          string or None
     #  typedef_flag   boolean
 
     is_enum = 1
     signed = 1
     rank = -1 # Ranks below any integer type
...
...
@@ -15,11 +15,14 @@ from TypeSlots import \
      get_special_method_signature, get_property_accessor_signature
 import ControlFlow
 import __builtin__
+from sets import Set as set
 
 possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
 nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
 
 class BufferAux:
+    writable_needed = False
+
     def __init__(self, buffer_info_var, stridevars, shapevars, tschecker):
         self.buffer_info_var = buffer_info_var
         self.stridevars = stridevars
@@ -141,7 +144,7 @@ class Entry:
     def redeclared(self, pos):
         error(pos, "'%s' does not match previous declaration" % self.name)
         error(self.pos, "Previous declaration is here")
 
 class Scope:
     # name              string             Unqualified name
     # outer_scope       Scope or None      Enclosing scope
@@ -216,6 +219,7 @@ class Scope:
         self.num_to_entry = {}
         self.obj_to_entry = {}
         self.pystring_entries = []
+        self.buffer_entries = []
         self.control_flow = ControlFlow.LinearControlFlow()
 
     def start_branching(self, pos):
@@ -616,9 +620,12 @@ class Scope:
         return [entry for entry in self.temp_entries
             if entry not in self.free_temp_entries]
 
-    def use_utility_code(self, new_code):
-        self.global_scope().use_utility_code(new_code)
+    def use_utility_code(self, new_code, name=None):
+        self.global_scope().use_utility_code(new_code, name)
+
+    def has_utility_code(self, name):
+        return self.global_scope().has_utility_code(name)
 
     def generate_library_function_declarations(self, code):
         # Generate extern decls for C library funcs used.
         #if self.pow_function_used:
@@ -738,6 +745,7 @@ class ModuleScope(Scope):
     # doc_cname            string             C name of module doc string
     # const_counter        integer            Counter for naming constants
     # utility_code_used    [string]           Utility code to be included
+    # utility_code_names   set(string)        (Optional) names for named (often generated) utility code
     # default_entries      [Entry]            Function argument default entries
     # python_include_files [string]           Standard Python headers to be included
     # include_files        [string]           Other C headers to be included
@@ -772,6 +780,7 @@ class ModuleScope(Scope):
         self.doc_cname = Naming.moddoc_cname
         self.const_counter = 1
         self.utility_code_used = []
+        self.utility_code_names = set()
        self.default_entries = []
         self.module_entries = {}
         self.python_include_files = ["Python.h", "structmember.h"]
@@ -930,13 +939,25 @@ class ModuleScope(Scope):
         self.const_counter = n + 1
         return "%s%s%d" % (Naming.const_prefix, prefix, n)
 
-    def use_utility_code(self, new_code):
+    def use_utility_code(self, new_code, name=None):
         #  Add string to list of utility code to be included,
-        #  if not already there (tested using 'is').
+        #  if not already there (tested using the provided name,
+        #  or 'is' if name=None -- if the utility code is dynamically
+        #  generated, use the name, otherwise it is not needed).
+        if name is not None:
+            if name in self.utility_code_names:
+                return
         for old_code in self.utility_code_used:
             if old_code is new_code:
                 return
         self.utility_code_used.append(new_code)
+        self.utility_code_names.add(name)
+
+    def has_utility_code(self, name):
+        # Checks if utility code (that is registered by name) has
+        # previously been registered. This is useful if the utility code
+        # is dynamically generated to avoid re-generation.
+        return name in self.utility_code_names
 
     def declare_c_class(self, name, pos, defining = 0, implementing = 0,
             module_name = None, base_type = None, objstruct_cname = None,
...
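A minimal sketch (not from the commit) of how a code-generation step might use the named registration above to avoid emitting dynamically generated utility code twice; the helper name and the code strings are hypothetical, only use_utility_code/has_utility_code come from the diff:

    def put_typecheck_helper(scope, dtype_name):
        # Hypothetical caller: utility code entries are (proto, impl) pairs,
        # registered under a stable name so re-generation is skipped.
        name = "typecheck_" + dtype_name
        if scope.has_utility_code(name):
            return
        proto = "static int __pyx_check_%s(PyObject*); /*proto*/" % dtype_name
        impl = "static int __pyx_check_%s(PyObject* o) { return 1; }" % dtype_name
        scope.use_utility_code((proto, impl), name=name)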
...
@@ -47,21 +47,35 @@ class TestBufferParsing(CythonTest):
         self.not_parseable("Non-keyword arg following keyword arg",
                            u"cdef object[foo=1, 2] x")
 
+# See also tests/error/e_bufaccess.pyx and tets/run/bufaccess.pyx
 class TestBufferOptions(CythonTest):
     # Tests the full parsing of the options within the brackets
 
-    def parse_opts(self, opts):
-        s = u"cdef object[%s] x" % opts
-        root = self.fragment(s, pipeline=[PostParse(self)]).root
-        buftype = root.stats[0].base_type
-        self.assert_(isinstance(buftype, CBufferAccessTypeNode))
-        self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode))
-        self.assertEqual(u"object", buftype.base_type_node.name)
-        return buftype
+    def nonfatal_error(self, error):
+        # We're passing self as context to transform to trap this
+        self.error = error
+        self.assert_(self.expect_error)
+
+    def parse_opts(self, opts, expect_error=False):
+        s = u"def f():\n cdef object[%s] x" % opts
+        self.expect_error = expect_error
+        root = self.fragment(s, pipeline=[NormalizeTree(self), PostParse(self)]).root
+        if not expect_error:
+            vardef = root.stats[0].body.stats[0]
+            assert isinstance(vardef, CVarDefNode) # use normal assert as this is to validate the test code
+            buftype = vardef.base_type
+            self.assert_(isinstance(buftype, CBufferAccessTypeNode))
+            self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode))
+            self.assertEqual(u"object", buftype.base_type_node.name)
+            return buftype
+        else:
+            self.assert_(len(root.stats[0].body.stats) == 0)
 
     def non_parse(self, expected_err, opts):
-        e = self.should_fail(lambda: self.parse_opts(opts))
-        self.assertEqual(expected_err, e.message_only)
+        self.parse_opts(opts, expect_error=True)
+        # e = self.should_fail(lambda: self.parse_opts(opts))
+        self.assertEqual(expected_err, self.error.message_only)
 
     def test_basic(self):
         buf = self.parse_opts(u"unsigned short int, 3")
@@ -86,10 +100,12 @@ class TestBufferOptions(CythonTest):
     def test_use_DEF(self):
         t = self.fragment(u"""
         DEF ndim = 3
-        cdef object[int, ndim] x
-        cdef object[ndim=ndim, dtype=int] y
-        """, pipeline=[PostParse(self)]).root
-        self.assert_(t.stats[1].base_type.ndim == 3)
-        self.assert_(t.stats[2].base_type.ndim == 3)
+        def f():
+            cdef object[int, ndim] x
+            cdef object[ndim=ndim, dtype=int] y
+        """, pipeline=[NormalizeTree(self), PostParse(self)]).root
+        stats = t.stats[0].body.stats
+        self.assert_(stats[0].base_type.ndim == 3)
+        self.assert_(stats[1].base_type.ndim == 3)
 
-    # add exotic and impossible combinations as they come along
+    # add exotic and impossible combinations as they come along...
...
@@ -6,6 +6,8 @@ import re
 from cStringIO import StringIO
 from Scanning import PyrexScanner, StringSourceDescriptor
 from Symtab import BuiltinScope, ModuleScope
+import Symtab
+import PyrexTypes
 from Visitor import VisitorTransform, temp_name_handle
 from Nodes import Node, StatListNode
 from ExprNodes import NameNode
@@ -108,11 +110,16 @@ class TemplateTransform(VisitorTransform):
         self.substitutions = substitutions
         tempdict = {}
         for key in temps:
-            tempdict[key] = temp_name_handle(key)
-        self.temps = tempdict
+            tempdict[key] = temp_name_handle(key) # pending result_code refactor: Symtab.new_temp(PyrexTypes.py_object_type, key)
+        self.temp_key_to_entries = tempdict
         self.pos = pos
         return super(TemplateTransform, self).__call__(node)
 
+    def get_pos(self, node):
+        if self.pos:
+            return self.pos
+        else:
+            return node.pos
+
     def visit_Node(self, node):
         if node is None:
@@ -135,11 +142,11 @@ class TemplateTransform(VisitorTransform):
 
     def visit_NameNode(self, node):
-        tempname = self.temps.get(node.name)
-        if tempname is not None:
+        tempentry = self.temp_key_to_entries.get(node.name)
+        if tempentry is not None:
             # Replace name with temporary
-            node.name = tempname
-            return self.visit_Node(node)
+            return NameNode(self.get_pos(node), name=tempentry)
+            # Pending result_code refactor: return NameNode(self.get_pos(node), entry=tempentry)
         else:
             return self.try_substitution(node, node.name)
@@ -157,6 +164,7 @@ def copy_code_tree(node):
 INDENT_RE = re.compile(ur"^ *")
 def strip_common_indent(lines):
     "Strips empty lines and common indentation from the list of strings given in lines"
+    # TODO: Facilitate textwrap.indent instead
     lines = [x for x in lines if x.strip() != u""]
     minindent = min([len(INDENT_RE.match(x).group(0)) for x in lines])
     lines = [x[minindent:] for x in lines]
...
# Please see the Python header files (object.h) for docs
cdef extern from "Python.h":
ctypedef void PyObject
ctypedef struct bufferinfo:
void *buf
Py_ssize_t len
Py_ssize_t itemsize
int readonly
int ndim
char *format
Py_ssize_t *shape
Py_ssize_t *strides
Py_ssize_t *suboffsets
void *internal
ctypedef bufferinfo Py_buffer
cdef enum:
PyBUF_SIMPLE,
PyBUF_WRITABLE,
PyBUF_WRITEABLE, # backwards compatability
PyBUF_FORMAT,
PyBUF_ND,
PyBUF_STRIDES,
PyBUF_C_CONTIGUOUS,
PyBUF_F_CONTIGUOUS,
PyBUF_ANY_CONTIGUOUS,
PyBUF_INDIRECT,
PyBUF_CONTIG,
PyBUF_CONTIG_RO,
PyBUF_STRIDED,
PyBUF_STRIDED_RO,
PyBUF_RECORDS,
PyBUF_RECORDS_RO,
PyBUF_FULL,
PyBUF_FULL_RO,
PyBUF_READ,
PyBUF_WRITE,
PyBUF_SHADOW
int PyObject_CheckBuffer(PyObject* obj)
int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags)
void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view)
void* PyBuffer_GetPointer(Py_buffer *view, Py_ssize_t *indices)
int PyBuffer_SizeFromFormat(char *) # actually const char
int PyBuffer_ToContiguous(void *buf, Py_buffer *view, Py_ssize_t len, char fort)
int PyBuffer_FromContiguous(Py_buffer *view, void *buf, Py_ssize_t len, char fort)
int PyObject_CopyData(PyObject *dest, PyObject *src)
int PyBuffer_IsContiguous(Py_buffer *view, char fort)
void PyBuffer_FillContiguousStrides(int ndims,
Py_ssize_t *shape,
Py_ssize_t *strides,
int itemsize,
char fort)
int PyBuffer_FillInfo(Py_buffer *view, void *buf,
Py_ssize_t len, int readonly,
int flags)
PyObject* PyObject_Format(PyObject* obj,
PyObject *format_spec)
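
A small usage sketch (not part of the commit) of the declarations above; the cimport module name and the function are assumptions for illustration:

    # Hypothetical .pyx snippet
    from python_buffer cimport Py_buffer, PyObject, PyObject_CheckBuffer, PyObject_GetBuffer, PyObject_ReleaseBuffer, PyBUF_SIMPLE

    def buffer_length(obj):
        # Acquire a simple read-only view on any object exposing the buffer interface.
        cdef Py_buffer view
        if not PyObject_CheckBuffer(<PyObject*>obj):
            raise TypeError("object does not support the buffer interface")
        if PyObject_GetBuffer(<PyObject*>obj, &view, PyBUF_SIMPLE) < 0:
            raise RuntimeError("could not acquire the buffer")
        try:
            return view.len
        finally:
            PyObject_ReleaseBuffer(<PyObject*>obj, &view)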
from cStringIO import StringIO
class StringIOTree(object):
"""
See module docs.
"""
def __init__(self, stream=None):
self.prepended_children = []
if stream is None: stream = StringIO()
self.stream = stream
def getvalue(self):
return ("".join([x.getvalue() for x in self.prepended_children]) +
self.stream.getvalue())
def copyto(self, target):
"""Potentially cheaper than getvalue as no string concatenation
needs to happen."""
for child in self.prepended_children:
child.copyto(target)
target.write(self.stream.getvalue())
def write(self, what):
self.stream.write(what)
def insertion_point(self):
# Save what we have written until now
# (would it be more efficient to check with len(self.stream.getvalue())?
# leaving it out for now)
self.prepended_children.append(StringIOTree(self.stream))
# Construct the new forked object to return
other = StringIOTree()
self.prepended_children.append(other)
self.stream = StringIO()
return other
__doc__ = r"""
Implements a buffer with insertion points. When you know you need to
"get back" to a place and write more later, simply call insertion_point()
at that spot and get a new StringIOTree object that is "left behind".
EXAMPLE:
>>> a = StringIOTree()
>>> a.write('first\n')
>>> b = a.insertion_point()
>>> a.write('third\n')
>>> b.write('second\n')
>>> print a.getvalue()
first
second
third
<BLANKLINE>
>>> c = b.insertion_point()
>>> d = c.insertion_point()
>>> d.write('alpha\n')
>>> b.write('gamma\n')
>>> c.write('beta\n')
>>> print b.getvalue()
second
alpha
beta
gamma
<BLANKLINE>
>>> out = StringIO()
>>> a.copyto(out)
>>> print out.getvalue()
first
second
alpha
beta
gamma
third
<BLANKLINE>
"""
if __name__ == "__main__":
import doctest
doctest.testmod()
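
The pattern the compiler switches to in ModuleNode.generate_c_code is the one exercised by the doctests above; a standalone sketch (the import path is assumed, not given by the diff):

    from StringIOTree import StringIOTree   # hypothetical import path

    code = StringIOTree()
    h_code = code.insertion_point()          # declarations get prepended here later
    code.write("int main(void) { return f(); }\n")
    h_code.write("static int f(void) { return 0; }\n")
    out = open("example.c", "w")             # illustrative output file
    code.copyto(out)                         # streams the declarations first, then the body
    out.close()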
...
@@ -46,12 +46,13 @@ class ErrorWriter(object):
 
 class TestBuilder(object):
     def __init__(self, rootdir, workdir, selectors, annotate,
-                 cleanup_workdir, with_pyregr):
+                 cleanup_workdir, cleanup_sharedlibs, with_pyregr):
         self.rootdir = rootdir
         self.workdir = workdir
         self.selectors = selectors
         self.annotate = annotate
         self.cleanup_workdir = cleanup_workdir
+        self.cleanup_sharedlibs = cleanup_sharedlibs
         self.with_pyregr = with_pyregr
 
     def build_suite(self):
@@ -84,6 +85,7 @@ class TestBuilder(object):
         for filename in filenames:
             if not (filename.endswith(".pyx") or filename.endswith(".py")):
                 continue
+            if filename.startswith('.'): continue # certain emacs backup files
             if context == 'pyregr' and not filename.startswith('test_'):
                 continue
             module = os.path.splitext(filename)[0]
@@ -99,25 +101,29 @@ class TestBuilder(object):
                 test = build_test(
                     path, workdir, module,
                     annotate=self.annotate,
-                    cleanup_workdir=self.cleanup_workdir)
+                    cleanup_workdir=self.cleanup_workdir,
+                    cleanup_sharedlibs=self.cleanup_sharedlibs)
             else:
                 test = CythonCompileTestCase(
                     path, workdir, module,
                     expect_errors=expect_errors,
                     annotate=self.annotate,
-                    cleanup_workdir=self.cleanup_workdir)
+                    cleanup_workdir=self.cleanup_workdir,
+                    cleanup_sharedlibs=self.cleanup_sharedlibs)
             suite.addTest(test)
         return suite
 
 class CythonCompileTestCase(unittest.TestCase):
     def __init__(self, directory, workdir, module,
-                 expect_errors=False, annotate=False, cleanup_workdir=True):
+                 expect_errors=False, annotate=False, cleanup_workdir=True,
+                 cleanup_sharedlibs=True):
         self.directory = directory
         self.workdir = workdir
         self.module = module
         self.expect_errors = expect_errors
         self.annotate = annotate
         self.cleanup_workdir = cleanup_workdir
+        self.cleanup_sharedlibs = cleanup_sharedlibs
         unittest.TestCase.__init__(self)
 
     def shortDescription(self):
@@ -125,10 +131,13 @@ class CythonCompileTestCase(unittest.TestCase):
     def tearDown(self):
         cleanup_c_files = WITH_CYTHON and self.cleanup_workdir
+        cleanup_lib_files = self.cleanup_sharedlibs
         if os.path.exists(self.workdir):
             for rmfile in os.listdir(self.workdir):
                 if not cleanup_c_files and rmfile[-2:] in (".c", ".h"):
                     continue
+                if not cleanup_lib_files and rmfile.endswith(".so") or rmfile.endswith(".dll"):
+                    continue
                 if self.annotate and rmfile.endswith(".html"):
                     continue
                 try:
@@ -278,7 +287,7 @@ class CythonUnitTestCase(CythonCompileTestCase):
         except Exception:
             pass
 
-def collect_unittests(path, suite, selectors):
+def collect_unittests(path, module_prefix, suite, selectors):
     def file_matches(filename):
         return filename.startswith("Test") and filename.endswith(".py")
 
@@ -304,7 +313,7 @@ def collect_unittests(path, suite, selectors):
             for f in filenames:
                 if file_matches(f):
                     filepath = os.path.join(dirpath, f)[:-len(".py")]
-                    modulename = filepath[len(path)+1:].replace(os.path.sep, '.')
+                    modulename = module_prefix + filepath[len(path)+1:].replace(os.path.sep, '.')
                     if not [ 1 for match in selectors if match(modulename) ]:
                         continue
                     module = __import__(modulename)
@@ -312,18 +321,50 @@ def collect_unittests(path, suite, selectors):
                         module = getattr(module, x)
                     suite.addTests([loader.loadTestsFromModule(module)])
 
+def collect_doctests(path, module_prefix, suite, selectors):
+    def package_matches(dirname):
+        return dirname not in ("Mac", "Distutils", "Plex")
+    def file_matches(filename):
+        return (filename.endswith(".py") and not ('~' in filename
+                or '#' in filename or filename.startswith('.')))
+    import doctest, types
+    for dirpath, dirnames, filenames in os.walk(path):
+        parentname = os.path.split(dirpath)[-1]
+        if package_matches(parentname):
+            for f in filenames:
+                if file_matches(f):
+                    if not f.endswith('.py'): continue
+                    filepath = os.path.join(dirpath, f)[:-len(".py")]
+                    modulename = module_prefix + filepath[len(path)+1:].replace(os.path.sep, '.')
+                    if not [ 1 for match in selectors if match(modulename) ]:
+                        continue
+                    module = __import__(modulename)
+                    for x in modulename.split('.')[1:]:
+                        module = getattr(module, x)
+                    if hasattr(module, "__doc__") or hasattr(module, "__test__"):
+                        try:
+                            suite.addTests(doctest.DocTestSuite(module))
+                        except ValueError: # no tests
+                            pass
+
 if __name__ == '__main__':
     from optparse import OptionParser
     parser = OptionParser()
     parser.add_option("--no-cleanup", dest="cleanup_workdir",
                       action="store_false", default=True,
                       help="do not delete the generated C files (allows passing --no-cython on next run)")
+    parser.add_option("--no-cleanup-sharedlibs", dest="cleanup_sharedlibs",
+                      action="store_false", default=True,
+                      help="do not delete the generated shared libary files (allows manual module experimentation)")
     parser.add_option("--no-cython", dest="with_cython",
                       action="store_false", default=True,
                       help="do not run the Cython compiler, only the C compiler")
     parser.add_option("--no-unit", dest="unittests",
                       action="store_false", default=True,
                       help="do not run the unit tests")
+    parser.add_option("--no-doctest", dest="doctests",
+                      action="store_false", default=True,
+                      help="do not run the doctests")
     parser.add_option("--no-file", dest="filetests",
                       action="store_false", default=True,
                       help="do not run the file based tests")
@@ -360,6 +401,8 @@ if __name__ == '__main__':
     # RUN ALL TESTS!
     ROOTDIR = os.path.join(os.getcwd(), os.path.dirname(sys.argv[0]), 'tests')
     WORKDIR = os.path.join(os.getcwd(), 'BUILD')
+    UNITTEST_MODULE = "Cython"
+    UNITTEST_ROOT = os.path.join(os.getcwd(), UNITTEST_MODULE)
     if WITH_CYTHON:
         if os.path.exists(WORKDIR):
             shutil.rmtree(WORKDIR, ignore_errors=True)
@@ -382,13 +425,15 @@ if __name__ == '__main__':
     test_suite = unittest.TestSuite()
 
     if options.unittests:
-        collect_unittests(os.getcwd(), test_suite, selectors)
+        collect_unittests(UNITTEST_ROOT, UNITTEST_MODULE + ".", test_suite, selectors)
+
+    if options.doctests:
+        collect_doctests(UNITTEST_ROOT, UNITTEST_MODULE + ".", test_suite, selectors)
 
     if options.filetests:
         filetests = TestBuilder(ROOTDIR, WORKDIR, selectors,
-                                options.annotate_source,
-                                options.cleanup_workdir,
-                                options.pyregr)
+                                options.annotate_source, options.cleanup_workdir,
+                                options.cleanup_sharedlibs, options.pyregr)
         test_suite.addTests([filetests.build_suite()])
 
     unittest.TextTestRunner(verbosity=options.verbosity).run(test_suite)
...
cdef extern:
cdef func(int[])
cdef object[int] buf
cdef class A:
cdef object[int] buf
def f():
cdef object[fakeoption=True] buf1
cdef object[int, -1] buf1b
cdef object[ndim=-1] buf2
cdef object[int, 'a'] buf3
cdef object[int,2,3,4,5,6] buf4
cdef object[int, 2, 'foo'] buf5
cdef object[int, 2, well] buf6
_ERRORS = u"""
1:11: Buffer types only allowed as function local variables
3:15: Buffer types only allowed as function local variables
6:27: "fakeoption" is not a buffer option
7:22: "ndim" must be non-negative
8:15: "dtype" missing
9:21: "ndim" must be an integer
10:15: Too many buffer options
11:24: Only allowed buffer modes are "full" or "strided" (as a compile-time string)
12:28: Only allowed buffer modes are "full" or "strided" (as a compile-time string)
"""
cdef class A:
cdef int value = 3
_ERRORS = u"""
2:13: Cannot assign default value to cdef class attributes
"""
__doc__ = """
>>> test(1, 2)
4 1 2 2 0 7 8
"""
cdef int g = 7
def test(x, int y):
if True:
before = 0
cdef int a = 4, b = x, c = y, *p = &y
cdef object o = int(8)
print a, b, c, p[0], before, g, o
# Also test that pruning cdefs doesn't hurt
def empty():
cdef int i