Commit b35bb6c0 authored by Robert Bradshaw's avatar Robert Bradshaw

Merge in Dag's work

parents bbb832bc eaa6d477
...@@ -17,32 +17,34 @@ special_chars = [(u'<', u'\xF0', u'&lt;'), ...@@ -17,32 +17,34 @@ special_chars = [(u'<', u'\xF0', u'&lt;'),
class AnnotationCCodeWriter(CCodeWriter): class AnnotationCCodeWriter(CCodeWriter):
def __init__(self, f): def __init__(self, create_from=None, buffer=None):
CCodeWriter.__init__(self, self) CCodeWriter.__init__(self, create_from, buffer)
self.buffer = StringIO() self.annotation_buffer = StringIO()
self.real_f = f if create_from is None:
self.annotations = [] self.annotations = []
self.last_pos = None self.last_pos = None
self.code = {} self.code = {}
else:
# When creating an insertion point, keep references to the same database
self.annotation_buffer = create_from.annotation_buffer
self.annotations = create_from.annotations
self.code = create_from.code
def getvalue(self): def create_new(self, create_from, buffer):
return self.real_f.getvalue() return AnnotationCCodeWriter(create_from, buffer)
def write(self, s): def write(self, s):
self.real_f.write(s) CCodeWriter.write(self, s)
self.buffer.write(s) self.annotation_buffer.write(s)
def mark_pos(self, pos): def mark_pos(self, pos):
# if pos is not None: # if pos is not None:
# CCodeWriter.mark_pos(self, pos) # CCodeWriter.mark_pos(self, pos)
# return # return
if self.last_pos: if self.last_pos:
try: code = self.code.get(self.last_pos[1], "")
code = self.code[self.last_pos[1]] self.code[self.last_pos[1]] = code + self.annotation_buffer.getvalue()
except KeyError: self.annotation_buffer = StringIO()
code = ""
self.code[self.last_pos[1]] = code + self.buffer.getvalue()
self.buffer = StringIO()
self.last_pos = pos self.last_pos = pos
def annotate(self, pos, item): def annotate(self, pos, item):
......
This diff is collapsed.
This diff is collapsed.
...@@ -33,6 +33,7 @@ class CompileError(PyrexError): ...@@ -33,6 +33,7 @@ class CompileError(PyrexError):
def __init__(self, position = None, message = ""): def __init__(self, position = None, message = ""):
self.position = position self.position = position
self.message_only = message self.message_only = message
self.reported = False
# Deprecated and withdrawn in 2.6: # Deprecated and withdrawn in 2.6:
# self.message = message # self.message = message
if position: if position:
...@@ -88,17 +89,23 @@ def close_listing_file(): ...@@ -88,17 +89,23 @@ def close_listing_file():
listing_file.close() listing_file.close()
listing_file = None listing_file = None
def error(position, message): def report_error(err):
#print "Errors.error:", repr(position), repr(message) ###
global num_errors global num_errors
err = CompileError(position, message) # See Main.py for why dual reporting occurs. Quick fix for now.
# if position is not None: raise Exception(err) # debug if err.reported: return
err.reported = True
line = "%s\n" % err line = "%s\n" % err
if listing_file: if listing_file:
listing_file.write(line) listing_file.write(line)
if echo_file: if echo_file:
echo_file.write(line) echo_file.write(line)
num_errors = num_errors + 1 num_errors = num_errors + 1
def error(position, message):
#print "Errors.error:", repr(position), repr(message) ###
err = CompileError(position, message)
# if position is not None: raise Exception(err) # debug
report_error(err)
return err return err
LEVEL=1 # warn about all errors level 1 or higher LEVEL=1 # warn about all errors level 1 or higher
......
...@@ -806,6 +806,15 @@ class NameNode(AtomicExprNode): ...@@ -806,6 +806,15 @@ class NameNode(AtomicExprNode):
# interned_cname string # interned_cname string
is_name = 1 is_name = 1
skip_assignment_decref = False
entry = None
def create_analysed_rvalue(pos, env, entry):
node = NameNode(pos)
node.analyse_types(env, entry=entry)
return node
create_analysed_rvalue = staticmethod(create_analysed_rvalue)
def compile_time_value(self, denv): def compile_time_value(self, denv):
try: try:
...@@ -834,6 +843,8 @@ class NameNode(AtomicExprNode): ...@@ -834,6 +843,8 @@ class NameNode(AtomicExprNode):
def analyse_as_module(self, env): def analyse_as_module(self, env):
# Try to interpret this as a reference to a cimported module. # Try to interpret this as a reference to a cimported module.
# Returns the module scope, or None. # Returns the module scope, or None.
entry = self.entry
if not entry:
entry = env.lookup(self.name) entry = env.lookup(self.name)
if entry and entry.as_module: if entry and entry.as_module:
return entry.as_module return entry.as_module
...@@ -842,6 +853,8 @@ class NameNode(AtomicExprNode): ...@@ -842,6 +853,8 @@ class NameNode(AtomicExprNode):
def analyse_as_extension_type(self, env): def analyse_as_extension_type(self, env):
# Try to interpret this as a reference to an extension type. # Try to interpret this as a reference to an extension type.
# Returns the extension type, or None. # Returns the extension type, or None.
entry = self.entry
if not entry:
entry = env.lookup(self.name) entry = env.lookup(self.name)
if entry and entry.is_type and entry.type.is_extension_type: if entry and entry.is_type and entry.type.is_extension_type:
return entry.type return entry.type
...@@ -849,6 +862,7 @@ class NameNode(AtomicExprNode): ...@@ -849,6 +862,7 @@ class NameNode(AtomicExprNode):
return None return None
def analyse_target_declaration(self, env): def analyse_target_declaration(self, env):
if not self.entry:
self.entry = env.lookup_here(self.name) self.entry = env.lookup_here(self.name)
if not self.entry: if not self.entry:
self.entry = env.declare_var(self.name, py_object_type, self.pos) self.entry = env.declare_var(self.name, py_object_type, self.pos)
...@@ -860,6 +874,7 @@ class NameNode(AtomicExprNode): ...@@ -860,6 +874,7 @@ class NameNode(AtomicExprNode):
env.use_utility_code(type_cache_invalidation_code) env.use_utility_code(type_cache_invalidation_code)
def analyse_types(self, env): def analyse_types(self, env):
if self.entry is None:
self.entry = env.lookup(self.name) self.entry = env.lookup(self.name)
if not self.entry: if not self.entry:
self.entry = env.declare_builtin(self.name, self.pos) self.entry = env.declare_builtin(self.name, self.pos)
...@@ -875,6 +890,12 @@ class NameNode(AtomicExprNode): ...@@ -875,6 +890,12 @@ class NameNode(AtomicExprNode):
% self.name) % self.name)
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
self.entry.used = 1 self.entry.used = 1
if self.entry.type.is_buffer:
# Have an rhs temp just in case. All rhs I could
# think of had a single symbol result_code but better
# safe than sorry. Feel free to change this.
import Buffer
Buffer.used_buffer_aux_vars(self.entry)
def analyse_rvalue_entry(self, env): def analyse_rvalue_entry(self, env):
#print "NameNode.analyse_rvalue_entry:", self.name ### #print "NameNode.analyse_rvalue_entry:", self.name ###
...@@ -1018,12 +1039,23 @@ class NameNode(AtomicExprNode): ...@@ -1018,12 +1039,23 @@ class NameNode(AtomicExprNode):
rhs.generate_disposal_code(code) rhs.generate_disposal_code(code)
else: else:
if self.type.is_buffer:
# Generate code for doing the buffer release/acquisition.
# This might raise an exception in which case the assignment (done
# below) will not happen.
#
# The reason this is not in a typetest-like node is because the
# variables that the acquired buffer info is stored to is allocated
# per entry and coupled with it.
self.generate_acquire_buffer(rhs, code)
if self.type.is_pyobject: if self.type.is_pyobject:
rhs.make_owned_reference(code)
#print "NameNode.generate_assignment_code: to", self.name ### #print "NameNode.generate_assignment_code: to", self.name ###
#print "...from", rhs ### #print "...from", rhs ###
#print "...LHS type", self.type, "ctype", self.ctype() ### #print "...LHS type", self.type, "ctype", self.ctype() ###
#print "...RHS type", rhs.type, "ctype", rhs.ctype() ### #print "...RHS type", rhs.type, "ctype", rhs.ctype() ###
rhs.make_owned_reference(code) if not self.skip_assignment_decref:
if entry.is_local and not Options.init_local_none: if entry.is_local and not Options.init_local_none:
initalized = entry.scope.control_flow.get_state((entry.name, 'initalized'), self.pos) initalized = entry.scope.control_flow.get_state((entry.name, 'initalized'), self.pos)
if initalized is True: if initalized is True:
...@@ -1038,6 +1070,19 @@ class NameNode(AtomicExprNode): ...@@ -1038,6 +1070,19 @@ class NameNode(AtomicExprNode):
print("...generating post-assignment code for %s" % rhs) print("...generating post-assignment code for %s" % rhs)
rhs.generate_post_assignment_code(code) rhs.generate_post_assignment_code(code)
def generate_acquire_buffer(self, rhs, code):
rhstmp = code.func.allocate_temp(self.entry.type)
buffer_aux = self.entry.buffer_aux
bufstruct = buffer_aux.buffer_info_var.cname
code.putln('%s = %s;' % (rhstmp, rhs.result_as(self.ctype())))
import Buffer
Buffer.put_assign_to_buffer(self.result_code, rhstmp, buffer_aux, self.entry.type,
is_initialized=not self.skip_assignment_decref,
pos=self.pos, code=code)
code.putln("%s = 0;" % rhstmp)
code.func.release_temp(rhstmp)
def generate_deletion_code(self, code): def generate_deletion_code(self, code):
if self.entry is None: if self.entry is None:
return # There was an error earlier return # There was an error earlier
...@@ -1297,39 +1342,45 @@ class IndexNode(ExprNode): ...@@ -1297,39 +1342,45 @@ class IndexNode(ExprNode):
self.analyse_base_and_index_types(env, setting = 1) self.analyse_base_and_index_types(env, setting = 1)
def analyse_base_and_index_types(self, env, getting = 0, setting = 0): def analyse_base_and_index_types(self, env, getting = 0, setting = 0):
# Note: This might be cleaned up by having IndexNode
# parsed in a saner way and only construct the tuple if
# needed.
self.is_buffer_access = False self.is_buffer_access = False
self.base.analyse_types(env) self.base.analyse_types(env)
skip_child_analysis = False skip_child_analysis = False
buffer_access = False buffer_access = False
if self.base.type.buffer_options is not None: if self.base.type.is_buffer:
assert isinstance(self.base, NameNode)
if isinstance(self.index, TupleNode): if isinstance(self.index, TupleNode):
indices = self.index.args indices = self.index.args
else: else:
indices = [self.index] indices = [self.index]
if len(indices) == self.base.type.buffer_options.ndim: if len(indices) == self.base.type.ndim:
buffer_access = True buffer_access = True
skip_child_analysis = True skip_child_analysis = True
for x in indices: for x in indices:
x.analyse_types(env) x.analyse_types(env)
if not x.type.is_int: if not x.type.is_int:
buffer_access = False buffer_access = False
if buffer_access: if buffer_access:
# self.indices = [
# x.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env)
# for x in indices]
self.indices = indices self.indices = indices
self.index = None self.index = None
self.type = self.base.type.buffer_options.dtype self.type = self.base.type.dtype
self.is_temp = 1
self.is_buffer_access = True self.is_buffer_access = True
# Note: This might be cleaned up by having IndexNode if getting:
# parsed in a saner way and only construct the tuple if # we only need a temp because result_code isn't refactored to
# needed. # generation time, but this seems an ok shortcut to take
self.is_temp = True
if not buffer_access: if setting:
if not self.base.entry.type.writable:
error(self.pos, "Writing to readonly buffer")
else:
self.base.entry.buffer_aux.writable_needed = True
else:
if isinstance(self.index, TupleNode): if isinstance(self.index, TupleNode):
self.index.analyse_types(env, skip_children=skip_child_analysis) self.index.analyse_types(env, skip_children=skip_child_analysis)
elif not skip_child_analysis: elif not skip_child_analysis:
...@@ -1374,7 +1425,7 @@ class IndexNode(ExprNode): ...@@ -1374,7 +1425,7 @@ class IndexNode(ExprNode):
def calculate_result_code(self): def calculate_result_code(self):
if self.is_buffer_access: if self.is_buffer_access:
return "<not needed>" return "<not used>"
else: else:
return "(%s[%s])" % ( return "(%s[%s])" % (
self.base.result_code, self.index.result_code) self.base.result_code, self.index.result_code)
...@@ -1393,17 +1444,22 @@ class IndexNode(ExprNode): ...@@ -1393,17 +1444,22 @@ class IndexNode(ExprNode):
if self.index is not None: if self.index is not None:
self.index.generate_evaluation_code(code) self.index.generate_evaluation_code(code)
else: else:
for i in self.indices: i.generate_evaluation_code(code) for i in self.indices:
i.generate_evaluation_code(code)
def generate_subexpr_disposal_code(self, code): def generate_subexpr_disposal_code(self, code):
self.base.generate_disposal_code(code) self.base.generate_disposal_code(code)
if self.index is not None: if self.index is not None:
self.index.generate_disposal_code(code) self.index.generate_disposal_code(code)
else: else:
for i in self.indices: i.generate_disposal_code(code) for i in self.indices:
i.generate_disposal_code(code)
def generate_result_code(self, code): def generate_result_code(self, code):
if self.type.is_pyobject: if self.is_buffer_access:
valuecode = self.buffer_access_code(code)
code.putln("%s = %s;" % (self.result_code, valuecode))
elif self.type.is_pyobject:
if self.index.type.is_int: if self.index.type.is_int:
function = "__Pyx_GetItemInt" function = "__Pyx_GetItemInt"
index_code = self.index.result_code index_code = self.index.result_code
...@@ -1439,7 +1495,10 @@ class IndexNode(ExprNode): ...@@ -1439,7 +1495,10 @@ class IndexNode(ExprNode):
def generate_assignment_code(self, rhs, code): def generate_assignment_code(self, rhs, code):
self.generate_subexpr_evaluation_code(code) self.generate_subexpr_evaluation_code(code)
if self.type.is_pyobject: if self.is_buffer_access:
valuecode = self.buffer_access_code(code)
code.putln("%s = %s;" % (valuecode, rhs.result_code))
elif self.type.is_pyobject:
self.generate_setitem_code(rhs.py_result(), code) self.generate_setitem_code(rhs.py_result(), code)
else: else:
code.putln( code.putln(
...@@ -1465,6 +1524,19 @@ class IndexNode(ExprNode): ...@@ -1465,6 +1524,19 @@ class IndexNode(ExprNode):
code.error_goto(self.pos))) code.error_goto(self.pos)))
self.generate_subexpr_disposal_code(code) self.generate_subexpr_disposal_code(code)
def buffer_access_code(self, code):
# Assign indices to temps
index_temps = [code.func.allocate_temp(i.type) for i in self.indices]
for temp, index in zip(index_temps, self.indices):
code.putln("%s = %s;" % (temp, index.result_code))
# Generate buffer access code using these temps
import Buffer
valuecode = Buffer.put_access(entry=self.base.entry,
index_signeds=[i.type.signed for i in self.indices],
index_cnames=index_temps,
pos=self.pos, code=code)
return valuecode
class SliceIndexNode(ExprNode): class SliceIndexNode(ExprNode):
# 2-element slice indexing # 2-element slice indexing
......
...@@ -47,6 +47,13 @@ class Context: ...@@ -47,6 +47,13 @@ class Context:
self.include_directories = include_directories self.include_directories = include_directories
self.future_directives = set() self.future_directives = set()
import os.path
standard_include_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), '..', 'Includes'))
self.include_directories = include_directories + [standard_include_path]
def find_module(self, module_name, def find_module(self, module_name,
relative_to = None, pos = None, need_pxd = 1): relative_to = None, pos = None, need_pxd = 1):
# Finds and returns the module scope corresponding to # Finds and returns the module scope corresponding to
...@@ -323,14 +330,18 @@ class Context: ...@@ -323,14 +330,18 @@ class Context:
verbose_flag = options.show_version, verbose_flag = options.show_version,
cplus = options.cplus) cplus = options.cplus)
def nonfatal_error(self, exc):
return Errors.report_error(exc)
def run_pipeline(self, pipeline, source): def run_pipeline(self, pipeline, source):
errors_occurred = False errors_occurred = False
data = source data = source
try: try:
for phase in pipeline: for phase in pipeline:
data = phase(data) data = phase(data)
except CompileError: except CompileError, err:
errors_occurred = True errors_occurred = True
Errors.report_error(err)
return (errors_occurred, data) return (errors_occurred, data)
def create_parse(context): def create_parse(context):
...@@ -358,22 +369,25 @@ def create_default_pipeline(context, options, result): ...@@ -358,22 +369,25 @@ def create_default_pipeline(context, options, result):
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from Optimize import FlattenInListTransform, SwitchTransform from Optimize import FlattenInListTransform, SwitchTransform, OptimizeRefcounting
from Buffer import BufferTransform from Buffer import IntroduceBufferAuxiliaryVars
from ModuleNode import check_c_classes from ModuleNode import check_c_classes
def printit(x): print x.dump()
return [ return [
create_parse(context), create_parse(context),
# printit,
NormalizeTree(context), NormalizeTree(context),
PostParse(context), PostParse(context),
FlattenInListTransform(), FlattenInListTransform(),
WithTransform(context), WithTransform(context),
DecoratorTransform(context), DecoratorTransform(context),
AnalyseDeclarationsTransform(context), AnalyseDeclarationsTransform(context),
IntroduceBufferAuxiliaryVars(context),
check_c_classes, check_c_classes,
AnalyseExpressionsTransform(context), AnalyseExpressionsTransform(context),
BufferTransform(context), # BufferTransform(context),
SwitchTransform(), SwitchTransform(),
OptimizeRefcounting(context),
# CreateClosureClasses(context), # CreateClosureClasses(context),
create_generate_code(context, options, result) create_generate_code(context, options, result)
] ]
...@@ -607,6 +621,7 @@ def main(command_line = 0): ...@@ -607,6 +621,7 @@ def main(command_line = 0):
else: else:
options = CompilationOptions(default_options) options = CompilationOptions(default_options)
sources = args sources = args
if options.show_version: if options.show_version:
sys.stderr.write("Cython version %s\n" % Version.version) sys.stderr.write("Cython version %s\n" % Version.version)
if options.working_path!="": if options.working_path!="":
......
...@@ -97,7 +97,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -97,7 +97,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
h_extension_types = h_entries(env.c_class_entries) h_extension_types = h_entries(env.c_class_entries)
if h_types or h_vars or h_funcs or h_extension_types: if h_types or h_vars or h_funcs or h_extension_types:
result.h_file = replace_suffix(result.c_file, ".h") result.h_file = replace_suffix(result.c_file, ".h")
h_code = Code.CCodeWriter(open_new_file(result.h_file)) h_code = Code.CCodeWriter()
if options.generate_pxi: if options.generate_pxi:
result.i_file = replace_suffix(result.c_file, ".pxi") result.i_file = replace_suffix(result.c_file, ".pxi")
i_code = Code.PyrexCodeWriter(result.i_file) i_code = Code.PyrexCodeWriter(result.i_file)
...@@ -130,6 +130,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -130,6 +130,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
h_code.putln("") h_code.putln("")
h_code.putln("#endif") h_code.putln("#endif")
h_code.copyto(open_new_file(result.h_file))
def generate_public_declaration(self, entry, h_code, i_code): def generate_public_declaration(self, entry, h_code, i_code):
h_code.putln("%s %s;" % ( h_code.putln("%s %s;" % (
Naming.extern_c_macro, Naming.extern_c_macro,
...@@ -156,7 +158,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -156,7 +158,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
has_api_extension_types = 1 has_api_extension_types = 1
if api_funcs or has_api_extension_types: if api_funcs or has_api_extension_types:
result.api_file = replace_suffix(result.c_file, "_api.h") result.api_file = replace_suffix(result.c_file, "_api.h")
h_code = Code.CCodeWriter(open_new_file(result.api_file)) h_code = Code.CCodeWriter()
name = self.api_name(env) name = self.api_name(env)
guard = Naming.api_guard_prefix + name guard = Naming.api_guard_prefix + name
h_code.put_h_guard(guard) h_code.put_h_guard(guard)
...@@ -210,6 +212,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -210,6 +212,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
h_code.putln("") h_code.putln("")
h_code.putln("#endif") h_code.putln("#endif")
h_code.copyto(open_new_file(result.api_file))
def generate_cclass_header_code(self, type, h_code): def generate_cclass_header_code(self, type, h_code):
h_code.putln("%s DL_IMPORT(PyTypeObject) %s;" % ( h_code.putln("%s DL_IMPORT(PyTypeObject) %s;" % (
Naming.extern_c_macro, Naming.extern_c_macro,
...@@ -232,12 +236,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -232,12 +236,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def generate_c_code(self, env, options, result): def generate_c_code(self, env, options, result):
modules = self.referenced_modules modules = self.referenced_modules
if Options.annotate or options.annotate: if Options.annotate or options.annotate:
code = Annotate.AnnotationCCodeWriter(StringIO()) code = Annotate.AnnotationCCodeWriter()
else: else:
code = Code.CCodeWriter(StringIO()) code = Code.CCodeWriter()
code.h = Code.CCodeWriter(StringIO()) h_code = code.insertion_point()
code.init_labels() self.generate_module_preamble(env, modules, h_code)
self.generate_module_preamble(env, modules, code.h)
code.putln("") code.putln("")
code.putln("/* Implementation of %s */" % env.qualified_name) code.putln("/* Implementation of %s */" % env.qualified_name)
...@@ -259,15 +262,13 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -259,15 +262,13 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.mark_pos(None) code.mark_pos(None)
self.generate_module_cleanup_func(env, code) self.generate_module_cleanup_func(env, code)
self.generate_filename_table(code) self.generate_filename_table(code)
self.generate_utility_functions(env, code) self.generate_utility_functions(env, code, h_code)
self.generate_buffer_compatability_functions(env, code)
self.generate_declarations_for_modules(env, modules, code.h) self.generate_declarations_for_modules(env, modules, h_code)
h_code.write('\n')
f = open_new_file(result.c_file) f = open_new_file(result.c_file)
f.write(code.h.f.getvalue()) code.copyto(f)
f.write("\n")
f.write(code.f.getvalue())
f.close() f.close()
result.c_file_generated = 1 result.c_file_generated = 1
if Options.annotate or options.annotate: if Options.annotate or options.annotate:
...@@ -441,8 +442,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -441,8 +442,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln(" #define PyBUF_ANY_CONTIGUOUS (0x0080 | PyBUF_STRIDES)") code.putln(" #define PyBUF_ANY_CONTIGUOUS (0x0080 | PyBUF_STRIDES)")
code.putln(" #define PyBUF_INDIRECT (0x0100 | PyBUF_STRIDES)") code.putln(" #define PyBUF_INDIRECT (0x0100 | PyBUF_STRIDES)")
code.putln("") code.putln("")
code.putln(" static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags);")
code.putln(" static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view);")
code.putln("#endif") code.putln("#endif")
code.put(builtin_module_name_utility_code[0]) code.put(builtin_module_name_utility_code[0])
...@@ -1485,6 +1484,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1485,6 +1484,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("0") code.putln("0")
code.putln("};") code.putln("};")
code.putln() code.putln()
code.enter_cfunc_scope() # as we need labels
code.putln("static int %s(PyObject *o, PyObject* py_name, char *name) {" % Naming.import_star_set) code.putln("static int %s(PyObject *o, PyObject* py_name, char *name) {" % Naming.import_star_set)
code.putln("char** type_name = %s_type_names;" % Naming.import_star) code.putln("char** type_name = %s_type_names;" % Naming.import_star)
code.putln("while (*type_name) {") code.putln("while (*type_name) {")
...@@ -1535,8 +1535,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1535,8 +1535,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("return -1;") code.putln("return -1;")
code.putln("}") code.putln("}")
code.putln(import_star_utility_code) code.putln(import_star_utility_code)
code.exit_cfunc_scope() # done with labels
def generate_module_init_func(self, imported_modules, env, code): def generate_module_init_func(self, imported_modules, env, code):
code.enter_cfunc_scope()
code.putln("") code.putln("")
header2 = "PyMODINIT_FUNC init%s(void)" % env.module_name header2 = "PyMODINIT_FUNC init%s(void)" % env.module_name
header3 = "PyMODINIT_FUNC PyInit_%s(void)" % env.module_name header3 = "PyMODINIT_FUNC PyInit_%s(void)" % env.module_name
...@@ -1548,8 +1550,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1548,8 +1550,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln(header3) code.putln(header3)
code.putln("#endif") code.putln("#endif")
code.putln("{") code.putln("{")
tempdecl_code = code.insertion_point()
code.put_var_declarations(env.temp_entries)
code.putln("%s = PyTuple_New(0); %s" % (Naming.empty_tuple, code.error_goto_if_null(Naming.empty_tuple, self.pos))); code.putln("%s = PyTuple_New(0); %s" % (Naming.empty_tuple, code.error_goto_if_null(Naming.empty_tuple, self.pos)));
code.putln("/*--- Libary function declarations ---*/") code.putln("/*--- Libary function declarations ---*/")
...@@ -1590,6 +1591,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1590,6 +1591,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("/*--- Execution code ---*/") code.putln("/*--- Execution code ---*/")
code.mark_pos(None) code.mark_pos(None)
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
if Options.generate_cleanup_code: if Options.generate_cleanup_code:
...@@ -1610,6 +1612,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1610,6 +1612,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("#endif") code.putln("#endif")
code.putln('}') code.putln('}')
tempdecl_code.put_var_declarations(env.temp_entries)
tempdecl_code.put_temp_declarations(code.func)
code.exit_cfunc_scope()
def generate_module_cleanup_func(self, env, code): def generate_module_cleanup_func(self, env, code):
if not Options.generate_cleanup_code: if not Options.generate_cleanup_code:
return return
...@@ -1951,7 +1958,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1951,7 +1958,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
"%s = &%s;" % ( "%s = &%s;" % (
type.typeptr_cname, type.typeobj_cname)) type.typeptr_cname, type.typeobj_cname))
def generate_utility_functions(self, env, code): def generate_utility_functions(self, env, code, h_code):
code.putln("") code.putln("")
code.putln("/* Runtime support code */") code.putln("/* Runtime support code */")
code.putln("") code.putln("")
...@@ -1960,94 +1967,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1960,94 +1967,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
(Naming.filetable_cname, Naming.filenames_cname)) (Naming.filetable_cname, Naming.filenames_cname))
code.putln("}") code.putln("}")
for utility_code in env.utility_code_used: for utility_code in env.utility_code_used:
code.h.put(utility_code[0]) h_code.put(utility_code[0])
code.put(utility_code[1]) code.put(utility_code[1])
code.put(PyrexTypes.type_conversion_functions) code.put(PyrexTypes.type_conversion_functions)
code.putln("") code.putln("")
def generate_buffer_compatability_functions(self, env, code):
# will be refactored
try:
env.entries[u'numpy']
code.put("""
static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
/* This function is always called after a type-check; safe to cast */
PyArrayObject *arr = (PyArrayObject*)obj;
PyArray_Descr *type = (PyArray_Descr*)arr->descr;
int typenum = PyArray_TYPE(obj);
if (!PyTypeNum_ISNUMBER(typenum)) {
PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
return -1;
}
/*
NumPy format codes doesn't completely match buffer codes;
seems safest to retranslate.
01234567890123456789012345*/
const char* base_codes = "?bBhHiIlLqQfdgfdgO";
char* format = (char*)malloc(4);
char* fp = format;
*fp++ = type->byteorder;
if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
*fp++ = base_codes[typenum];
*fp = 0;
view->buf = arr->data;
view->readonly = !PyArray_ISWRITEABLE(obj);
view->ndim = PyArray_NDIM(arr);
view->strides = PyArray_STRIDES(arr);
view->shape = PyArray_DIMS(arr);
view->suboffsets = NULL;
view->format = format;
view->itemsize = type->elsize;
view->internal = 0;
return 0;
}
static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
free((char*)view->format);
view->format = NULL;
}
""")
except KeyError:
pass
# For now, hard-code numpy imported as "numpy"
types = []
try:
ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
except KeyError:
pass
code.putln("#if PY_VERSION_HEX < 0x02060000")
code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
if len(types) > 0:
clause = "if"
for t, get, release in types:
code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
clause = "else if"
code.putln("else {")
code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
code.putln("return -1;")
if len(types) > 0: code.putln("}")
code.putln("}")
code.putln("")
code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
if len(types) > 0:
clause = "if"
for t, get, release in types:
code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release))
clause = "else if"
code.putln("}")
code.putln("")
code.putln("#endif")
#------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------
# #
# Runtime support code # Runtime support code
......
...@@ -8,6 +8,9 @@ ...@@ -8,6 +8,9 @@
pyrex_prefix = "__pyx_" pyrex_prefix = "__pyx_"
codewriter_temp_prefix = pyrex_prefix + "t_"
temp_prefix = u"__cyt_" temp_prefix = u"__cyt_"
builtin_prefix = pyrex_prefix + "builtin_" builtin_prefix = pyrex_prefix + "builtin_"
...@@ -31,6 +34,10 @@ prop_set_prefix = pyrex_prefix + "setprop_" ...@@ -31,6 +34,10 @@ prop_set_prefix = pyrex_prefix + "setprop_"
type_prefix = pyrex_prefix + "t_" type_prefix = pyrex_prefix + "t_"
typeobj_prefix = pyrex_prefix + "type_" typeobj_prefix = pyrex_prefix + "type_"
var_prefix = pyrex_prefix + "v_" var_prefix = pyrex_prefix + "v_"
bufstruct_prefix = pyrex_prefix + "bstruct_"
bufstride_prefix = pyrex_prefix + "bstride_"
bufshape_prefix = pyrex_prefix + "bshape_"
bufsuboffset_prefix = pyrex_prefix + "boffset_"
vtable_prefix = pyrex_prefix + "vtable_" vtable_prefix = pyrex_prefix + "vtable_"
vtabptr_prefix = pyrex_prefix + "vtabptr_" vtabptr_prefix = pyrex_prefix + "vtabptr_"
vtabstruct_prefix = pyrex_prefix + "vtabstruct_" vtabstruct_prefix = pyrex_prefix + "vtabstruct_"
......
...@@ -73,6 +73,7 @@ class Node(object): ...@@ -73,6 +73,7 @@ class Node(object):
is_name = 0 is_name = 0
is_literal = 0 is_literal = 0
temps = None
# All descandants should set child_attrs to a list of the attributes # All descandants should set child_attrs to a list of the attributes
# containing nodes considered "children" in the tree. Each such attribute # containing nodes considered "children" in the tree. Each such attribute
...@@ -102,7 +103,7 @@ class Node(object): ...@@ -102,7 +103,7 @@ class Node(object):
for attrname in result.child_attrs: for attrname in result.child_attrs:
value = getattr(result, attrname) value = getattr(result, attrname)
if isinstance(value, list): if isinstance(value, list):
setattr(result, attrname, value) setattr(result, attrname, [x for x in value])
return result return result
...@@ -252,6 +253,11 @@ class StatListNode(Node): ...@@ -252,6 +253,11 @@ class StatListNode(Node):
child_attrs = ["stats"] child_attrs = ["stats"]
def create_analysed(pos, env, *args, **kw):
node = StatListNode(pos, *args, **kw)
return node # No node-specific analysis necesarry
create_analysed = staticmethod(create_analysed)
def analyse_control_flow(self, env): def analyse_control_flow(self, env):
for stat in self.stats: for stat in self.stats:
stat.analyse_control_flow(env) stat.analyse_control_flow(env)
...@@ -344,12 +350,6 @@ class CDeclaratorNode(Node): ...@@ -344,12 +350,6 @@ class CDeclaratorNode(Node):
calling_convention = "" calling_convention = ""
def analyse_expressions(self, env):
pass
def generate_execution_code(self, env):
pass
class CNameDeclaratorNode(CDeclaratorNode): class CNameDeclaratorNode(CDeclaratorNode):
# name string The Pyrex name being declared # name string The Pyrex name being declared
...@@ -368,29 +368,6 @@ class CNameDeclaratorNode(CDeclaratorNode): ...@@ -368,29 +368,6 @@ class CNameDeclaratorNode(CDeclaratorNode):
self.type = base_type self.type = base_type
return self, base_type return self, base_type
def analyse_expressions(self, env):
self.entry = env.lookup(self.name)
if self.default is not None:
env.control_flow.set_state(self.default.end_pos(), (self.entry.name, 'initalized'), True)
env.control_flow.set_state(self.default.end_pos(), (self.entry.name, 'source'), 'assignment')
self.entry.used = 1
if self.type.is_pyobject:
self.entry.init_to_none = False
self.entry.init = 0
self.default.analyse_types(env)
self.default = self.default.coerce_to(self.type, env)
self.default.allocate_temps(env)
self.default.release_temp(env)
def generate_execution_code(self, code):
if self.default is not None:
self.default.generate_evaluation_code(code)
if self.type.is_pyobject:
self.default.make_owned_reference(code)
code.putln('%s = %s;' % (self.entry.cname, self.default.result_as(self.entry.type)))
self.default.generate_post_assignment_code(code)
code.putln()
class CPtrDeclaratorNode(CDeclaratorNode): class CPtrDeclaratorNode(CDeclaratorNode):
# base CDeclaratorNode # base CDeclaratorNode
...@@ -403,12 +380,6 @@ class CPtrDeclaratorNode(CDeclaratorNode): ...@@ -403,12 +380,6 @@ class CPtrDeclaratorNode(CDeclaratorNode):
ptr_type = PyrexTypes.c_ptr_type(base_type) ptr_type = PyrexTypes.c_ptr_type(base_type)
return self.base.analyse(ptr_type, env, nonempty = nonempty) return self.base.analyse(ptr_type, env, nonempty = nonempty)
def analyse_expressions(self, env):
self.base.analyse_expressions(env)
def generate_execution_code(self, env):
self.base.generate_execution_code(env)
class CArrayDeclaratorNode(CDeclaratorNode): class CArrayDeclaratorNode(CDeclaratorNode):
# base CDeclaratorNode # base CDeclaratorNode
# dimension ExprNode # dimension ExprNode
...@@ -629,8 +600,8 @@ class CBufferAccessTypeNode(Node): ...@@ -629,8 +600,8 @@ class CBufferAccessTypeNode(Node):
def analyse(self, env): def analyse(self, env):
base_type = self.base_type_node.analyse(env) base_type = self.base_type_node.analyse(env)
dtype = self.dtype_node.analyse(env) dtype = self.dtype_node.analyse(env)
options = PyrexTypes.BufferOptions(dtype=dtype, ndim=self.ndim) self.type = PyrexTypes.BufferType(base_type, dtype=dtype, ndim=self.ndim,
self.type = PyrexTypes.create_buffer_type(base_type, options) mode=self.mode)
return self.type return self.type
class CComplexBaseTypeNode(CBaseTypeNode): class CComplexBaseTypeNode(CBaseTypeNode):
...@@ -685,14 +656,6 @@ class CVarDefNode(StatNode): ...@@ -685,14 +656,6 @@ class CVarDefNode(StatNode):
dest_scope.declare_var(name, type, declarator.pos, dest_scope.declare_var(name, type, declarator.pos,
cname = cname, visibility = self.visibility, is_cdef = 1) cname = cname, visibility = self.visibility, is_cdef = 1)
def analyse_expressions(self, env):
for declarator in self.declarators:
declarator.analyse_expressions(env)
def generate_execution_code(self, code):
for declarator in self.declarators:
declarator.generate_execution_code(code)
class CStructOrUnionDefNode(StatNode): class CStructOrUnionDefNode(StatNode):
# name string # name string
...@@ -866,9 +829,14 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -866,9 +829,14 @@ class FuncDefNode(StatNode, BlockNode):
return lenv return lenv
def generate_function_definitions(self, env, code, transforms): def generate_function_definitions(self, env, code, transforms):
# Generate C code for header and body of function import Buffer
code.init_labels()
lenv = self.local_scope lenv = self.local_scope
# Generate C code for header and body of function
code.enter_cfunc_scope()
code.return_from_error_cleanup_label = code.new_label()
# ----- Top-level constants used by this function # ----- Top-level constants used by this function
code.mark_pos(self.pos) code.mark_pos(self.pos)
self.generate_interned_num_decls(lenv, code) self.generate_interned_num_decls(lenv, code)
...@@ -899,7 +867,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -899,7 +867,7 @@ class FuncDefNode(StatNode, BlockNode):
(self.return_type.declaration_code( (self.return_type.declaration_code(
Naming.retval_cname), Naming.retval_cname),
init)) init))
code.put_var_declarations(lenv.temp_entries) tempvardecl_code = code.insertion_point()
self.generate_keyword_list(code) self.generate_keyword_list(code)
# ----- Extern library function declarations # ----- Extern library function declarations
lenv.generate_library_function_declarations(code) lenv.generate_library_function_declarations(code)
...@@ -914,6 +882,8 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -914,6 +882,8 @@ class FuncDefNode(StatNode, BlockNode):
for entry in lenv.arg_entries: for entry in lenv.arg_entries:
if entry.type.is_pyobject and lenv.control_flow.get_state((entry.name, 'source')) != 'arg': if entry.type.is_pyobject and lenv.control_flow.get_state((entry.name, 'source')) != 'arg':
code.put_var_incref(entry) code.put_var_incref(entry)
if entry.type.is_buffer:
Buffer.put_acquire_arg_buffer(entry, code, self.pos)
# ----- Initialise local variables # ----- Initialise local variables
for entry in lenv.var_entries: for entry in lenv.var_entries:
if entry.type.is_pyobject and entry.init_to_none and entry.used: if entry.type.is_pyobject and entry.init_to_none and entry.used:
...@@ -934,12 +904,23 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -934,12 +904,23 @@ class FuncDefNode(StatNode, BlockNode):
val = self.return_type.default_value val = self.return_type.default_value
if val: if val:
code.putln("%s = %s;" % (Naming.retval_cname, val)) code.putln("%s = %s;" % (Naming.retval_cname, val))
#code.putln("goto %s;" % code.return_label)
# ----- Error cleanup # ----- Error cleanup
if code.error_label in code.labels_used: if code.error_label in code.labels_used:
code.put_goto(code.return_label) code.put_goto(code.return_label)
code.put_label(code.error_label) code.put_label(code.error_label)
code.put_var_xdecrefs(lenv.temp_entries) code.put_var_xdecrefs(lenv.temp_entries)
# Clean up buffers -- this calls a Python function
# so need to save and restore error state
buffers_present = len(lenv.buffer_entries) > 0
if buffers_present:
code.putln("{ PyObject *__pyx_type, *__pyx_value, *__pyx_tb;")
code.putln("PyErr_Fetch(&__pyx_type, &__pyx_value, &__pyx_tb);")
for entry in lenv.buffer_entries:
code.putln("%s;" % Buffer.get_release_buffer_code(entry))
#code.putln("%s = 0;" % entry.cname)
code.putln("PyErr_Restore(__pyx_type, __pyx_value, __pyx_tb);}")
err_val = self.error_value() err_val = self.error_value()
exc_check = self.caller_will_check_exceptions() exc_check = self.caller_will_check_exceptions()
if err_val is not None or exc_check: if err_val is not None or exc_check:
...@@ -957,8 +938,18 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -957,8 +938,18 @@ class FuncDefNode(StatNode, BlockNode):
"%s = %s;" % ( "%s = %s;" % (
Naming.retval_cname, Naming.retval_cname,
err_val)) err_val))
# ----- Return cleanup if buffers_present:
# Else, non-error return will be an empty clause
code.put_goto(code.return_from_error_cleanup_label)
# ----- Non-error return cleanup
# PS! If adding something here, modify the conditions for the
# goto statement in error cleanup above
code.put_label(code.return_label) code.put_label(code.return_label)
for entry in lenv.buffer_entries:
code.putln("%s;" % Buffer.get_release_buffer_code(entry))
# ----- Return cleanup for both error and no-error return
code.put_label(code.return_from_error_cleanup_label)
if not Options.init_local_none: if not Options.init_local_none:
for entry in lenv.var_entries: for entry in lenv.var_entries:
if lenv.control_flow.get_state((entry.name, 'initalized')) is not True: if lenv.control_flow.get_state((entry.name, 'initalized')) is not True:
...@@ -976,7 +967,11 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -976,7 +967,11 @@ class FuncDefNode(StatNode, BlockNode):
if not self.return_type.is_void: if not self.return_type.is_void:
code.putln("return %s;" % Naming.retval_cname) code.putln("return %s;" % Naming.retval_cname)
code.putln("}") code.putln("}")
# ----- Go back and insert temp variable declarations
tempvardecl_code.put_var_declarations(lenv.temp_entries)
tempvardecl_code.put_temp_declarations(code.func)
# ----- Python version # ----- Python version
code.exit_cfunc_scope()
if self.py_func: if self.py_func:
self.py_func.generate_function_definitions(env, code, transforms) self.py_func.generate_function_definitions(env, code, transforms)
self.generate_optarg_wrapper_function(env, code) self.generate_optarg_wrapper_function(env, code)
...@@ -2231,8 +2226,10 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -2231,8 +2226,10 @@ class SingleAssignmentNode(AssignmentNode):
# #
# lhs ExprNode Left hand side # lhs ExprNode Left hand side
# rhs ExprNode Right hand side # rhs ExprNode Right hand side
# first bool Is this guaranteed the first assignment to lhs?
child_attrs = ["lhs", "rhs"] child_attrs = ["lhs", "rhs"]
first = False
def analyse_declarations(self, env): def analyse_declarations(self, env):
self.lhs.analyse_target_declaration(env) self.lhs.analyse_target_declaration(env)
...@@ -3344,6 +3341,7 @@ class TryExceptStatNode(StatNode): ...@@ -3344,6 +3341,7 @@ class TryExceptStatNode(StatNode):
self.gil_check(env) self.gil_check(env)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.body.analyse_expressions(env) self.body.analyse_expressions(env)
self.cleanup_list = env.free_temp_entries[:] self.cleanup_list = env.free_temp_entries[:]
for except_clause in self.except_clauses: for except_clause in self.except_clauses:
...@@ -3524,6 +3522,12 @@ class TryFinallyStatNode(StatNode): ...@@ -3524,6 +3522,12 @@ class TryFinallyStatNode(StatNode):
# continue in the try block, since we have no problem # continue in the try block, since we have no problem
# handling it. # handling it.
def create_analysed(pos, env, body, finally_clause):
node = TryFinallyStatNode(pos, body=body, finally_clause=finally_clause)
node.cleanup_list = []
return node
create_analysed = staticmethod(create_analysed)
def analyse_control_flow(self, env): def analyse_control_flow(self, env):
env.start_branching(self.pos) env.start_branching(self.pos)
self.body.analyse_control_flow(env) self.body.analyse_control_flow(env)
......
...@@ -134,3 +134,16 @@ class FlattenInListTransform(Visitor.VisitorTransform): ...@@ -134,3 +134,16 @@ class FlattenInListTransform(Visitor.VisitorTransform):
def visit_Node(self, node): def visit_Node(self, node):
self.visitchildren(node) self.visitchildren(node)
return node return node
class OptimizeRefcounting(Visitor.CythonTransform):
def visit_SingleAssignmentNode(self, node):
if node.first:
lhs = node.lhs
if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
# Have variable initialized to 0 rather than None
lhs.entry.init_to_none = False
lhs.entry.init = 0
# Set a flag in NameNode to skip the decref
lhs.skip_assignment_decref = True
return node
...@@ -82,7 +82,9 @@ ERR_BUF_DUP = '"%s" buffer option already supplied' ...@@ -82,7 +82,9 @@ ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing' ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_INT = '"%s" must be an integer' ERR_BUF_INT = '"%s" must be an integer'
ERR_BUF_NONNEG = '"%s" must be non-negative' ERR_BUF_NONNEG = '"%s" must be non-negative'
ERR_CDEF_INCLASS = 'Cannot assign default value to cdef class attributes'
ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables'
ERR_BUF_MODEHELP = 'Only allowed buffer modes are "full" or "strided" (as a compile-time string)'
class PostParse(CythonTransform): class PostParse(CythonTransform):
""" """
Basic interpretation of the parse tree, as well as validity Basic interpretation of the parse tree, as well as validity
...@@ -91,6 +93,9 @@ class PostParse(CythonTransform): ...@@ -91,6 +93,9 @@ class PostParse(CythonTransform):
as such). as such).
Specifically: Specifically:
- Default values to cdef assignments are turned into single
assignments following the declaration (everywhere but in class
bodies, where they raise a compile error)
- CBufferAccessTypeNode has its options interpreted: - CBufferAccessTypeNode has its options interpreted:
Any first positional argument goes into the "dtype" attribute, Any first positional argument goes into the "dtype" attribute,
any "ndim" keyword argument goes into the "ndim" attribute and any "ndim" keyword argument goes into the "ndim" attribute and
...@@ -101,12 +106,65 @@ class PostParse(CythonTransform): ...@@ -101,12 +106,65 @@ class PostParse(CythonTransform):
if a more pure Abstract Syntax Tree is wanted. if a more pure Abstract Syntax Tree is wanted.
""" """
buffer_options = ("dtype", "ndim") # ordered! # Track our context.
scope_type = None # can be either of 'module', 'function', 'class'
def visit_ModuleNode(self, node):
self.scope_type = 'module'
self.visitchildren(node)
return node
def visit_ClassDefNode(self, node):
prev = self.scope_type
self.scope_type = 'class'
self.visitchildren(node)
self.scope_type = prev
return node
def visit_FuncDefNode(self, node):
prev = self.scope_type
self.scope_type = 'function'
self.visitchildren(node)
self.scope_type = prev
return node
# cdef variables
def visit_CVarDefNode(self, node):
# This assumes only plain names and pointers are assignable on
# declaration. Also, it makes use of the fact that a cdef decl
# must appear before the first use, so we don't have to deal with
# "i = 3; cdef int i = i" and can simply move the nodes around.
try:
self.visitchildren(node)
except PostParseError, e:
# An error in a cdef clause is ok, simply remove the declaration
# and try to move on to report more errors
self.context.nonfatal_error(e)
return None
stats = [node]
for decl in node.declarators:
while isinstance(decl, CPtrDeclaratorNode):
decl = decl.base
if isinstance(decl, CNameDeclaratorNode):
if decl.default is not None:
if self.scope_type == 'class':
raise PostParseError(decl.pos, ERR_CDEF_INCLASS)
stats.append(SingleAssignmentNode(node.pos,
lhs=NameNode(node.pos, name=decl.name),
rhs=decl.default, first=True))
decl.default = None
return stats
# buffer access
buffer_options = ("dtype", "ndim", "mode") # ordered!
def visit_CBufferAccessTypeNode(self, node): def visit_CBufferAccessTypeNode(self, node):
if not self.scope_type == 'function':
raise PostParseError(node.pos, ERR_BUF_LOCALONLY)
options = {} options = {}
# Fetch positional arguments # Fetch positional arguments
if len(node.positional_args) > len(self.buffer_options): if len(node.positional_args) > len(self.buffer_options):
self.context.error(ERR_BUF_TOO_MANY) raise PostParseError(node.pos, ERR_BUF_TOO_MANY)
for arg, unicode_name in zip(node.positional_args, self.buffer_options): for arg, unicode_name in zip(node.positional_args, self.buffer_options):
name = str(unicode_name) name = str(unicode_name)
options[name] = arg options[name] = arg
...@@ -114,21 +172,19 @@ class PostParse(CythonTransform): ...@@ -114,21 +172,19 @@ class PostParse(CythonTransform):
for item in node.keyword_args.key_value_pairs: for item in node.keyword_args.key_value_pairs:
name = str(item.key.value) name = str(item.key.value)
if not name in self.buffer_options: if not name in self.buffer_options:
raise PostParseError(item.key.pos, raise PostParseError(item.key.pos, ERR_BUF_OPTION_UNKNOWN % name)
ERR_BUF_UNKNOWN % name)
if name in options.keys(): if name in options.keys():
raise PostParseError(item.key.pos, raise PostParseError(item.key.pos, ERR_BUF_DUP % key)
ERR_BUF_DUP % key)
options[name] = item.value options[name] = item.value
provided = options.keys()
# get dtype # get dtype
dtype = options.get("dtype") dtype = options.get("dtype")
if dtype is None: raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype') if dtype is None:
raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype')
node.dtype_node = dtype node.dtype_node = dtype
# get ndim # get ndim
if "ndim" in provided: if "ndim" in options:
ndimnode = options["ndim"] ndimnode = options["ndim"]
if not isinstance(ndimnode, IntNode): if not isinstance(ndimnode, IntNode):
# Compile-time values (DEF) are currently resolved by the parser, # Compile-time values (DEF) are currently resolved by the parser,
...@@ -141,6 +197,17 @@ class PostParse(CythonTransform): ...@@ -141,6 +197,17 @@ class PostParse(CythonTransform):
else: else:
node.ndim = 1 node.ndim = 1
if "mode" in options:
modenode = options["mode"]
if not isinstance(modenode, StringNode):
raise PostParseError(modenode.pos, ERR_BUF_MODEHELP)
mode = modenode.value
if not mode in ('full', 'strided'):
raise PostParseError(modenode.pos, ERR_BUF_MODEHELP)
node.mode = mode
else:
node.mode = 'full'
# We're done with the parse tree args # We're done with the parse tree args
node.positional_args = None node.positional_args = None
node.keyword_args = None node.keyword_args = None
...@@ -253,6 +320,13 @@ class AnalyseDeclarationsTransform(CythonTransform): ...@@ -253,6 +320,13 @@ class AnalyseDeclarationsTransform(CythonTransform):
self.env_stack.pop() self.env_stack.pop()
return node return node
# Some nodes are no longer needed after declaration
# analysis and can be dropped. The analysis was performed
# on these nodes in a seperate recursive process from the
# enclosing function or module, so we can simply drop them.
def visit_CVarDefNode(self, node):
return None
class AnalyseExpressionsTransform(CythonTransform): class AnalyseExpressionsTransform(CythonTransform):
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
node.body.analyse_expressions(node.scope) node.body.analyse_expressions(node.scope)
......
...@@ -1620,9 +1620,7 @@ def p_c_simple_base_type(s, self_flag, nonempty): ...@@ -1620,9 +1620,7 @@ def p_c_simple_base_type(s, self_flag, nonempty):
# Treat trailing [] on type as buffer access # Treat trailing [] on type as buffer access
if 0: # s.sy == '[': if not is_basic and s.sy == '[':
if is_basic:
s.error("Basic C types do not support buffer access")
return p_buffer_access(s, type_node) return p_buffer_access(s, type_node)
else: else:
return type_node return type_node
......
...@@ -6,21 +6,6 @@ from Cython import Utils ...@@ -6,21 +6,6 @@ from Cython import Utils
import Naming import Naming
import copy import copy
class BufferOptions:
# dtype PyrexType
# ndim int
def __init__(self, dtype, ndim):
self.dtype = dtype
self.ndim = ndim
def create_buffer_type(base_type, buffer_options):
# Make a shallow copy of base_type and then annotate it
# with the buffer information
result = copy.copy(base_type)
result.buffer_options = buffer_options
return result
class BaseType: class BaseType:
# #
...@@ -57,6 +42,7 @@ class PyrexType(BaseType): ...@@ -57,6 +42,7 @@ class PyrexType(BaseType):
# is_unicode boolean Is a UTF-8 encoded C char * type # is_unicode boolean Is a UTF-8 encoded C char * type
# is_returncode boolean Is used only to signal exceptions # is_returncode boolean Is used only to signal exceptions
# is_error boolean Is the dummy error type # is_error boolean Is the dummy error type
# is_buffer boolean Is buffer access type
# has_attributes boolean Has C dot-selectable attributes # has_attributes boolean Has C dot-selectable attributes
# default_value string Initial value # default_value string Initial value
# parsetuple_format string Format char for PyArg_ParseTuple # parsetuple_format string Format char for PyArg_ParseTuple
...@@ -106,11 +92,11 @@ class PyrexType(BaseType): ...@@ -106,11 +92,11 @@ class PyrexType(BaseType):
is_unicode = 0 is_unicode = 0
is_returncode = 0 is_returncode = 0
is_error = 0 is_error = 0
is_buffer = 0
has_attributes = 0 has_attributes = 0
default_value = "" default_value = ""
parsetuple_format = "" parsetuple_format = ""
pymemberdef_typecode = None pymemberdef_typecode = None
buffer_options = None # can contain a BufferOptions instance
def resolve(self): def resolve(self):
# If a typedef, returns the base type. # If a typedef, returns the base type.
...@@ -202,6 +188,37 @@ class CTypedefType(BaseType): ...@@ -202,6 +188,37 @@ class CTypedefType(BaseType):
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.typedef_base_type, name) return getattr(self.typedef_base_type, name)
class BufferType(BaseType):
#
# Delegates most attribute
# lookups to the base type. ANYTHING NOT DEFINED
# HERE IS DELEGATED!
# dtype PyrexType
# ndim int
# mode str
# is_buffer boolean
# writable boolean
is_buffer = 1
writable = True
def __init__(self, base, dtype, ndim, mode):
self.base = base
self.dtype = dtype
self.ndim = ndim
self.buffer_ptr_type = CPtrType(dtype)
self.mode = mode
def as_argument_type(self):
return self
def __getattr__(self, name):
return getattr(self.base, name)
def __repr__(self):
return "<BufferType %r>" % self.base
class PyObjectType(PyrexType): class PyObjectType(PyrexType):
# #
# Base class for all Python object types (reference-counted). # Base class for all Python object types (reference-counted).
......
...@@ -15,11 +15,14 @@ from TypeSlots import \ ...@@ -15,11 +15,14 @@ from TypeSlots import \
get_special_method_signature, get_property_accessor_signature get_special_method_signature, get_property_accessor_signature
import ControlFlow import ControlFlow
import __builtin__ import __builtin__
from sets import Set as set
possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
class BufferAux: class BufferAux:
writable_needed = False
def __init__(self, buffer_info_var, stridevars, shapevars, tschecker): def __init__(self, buffer_info_var, stridevars, shapevars, tschecker):
self.buffer_info_var = buffer_info_var self.buffer_info_var = buffer_info_var
self.stridevars = stridevars self.stridevars = stridevars
...@@ -216,6 +219,7 @@ class Scope: ...@@ -216,6 +219,7 @@ class Scope:
self.num_to_entry = {} self.num_to_entry = {}
self.obj_to_entry = {} self.obj_to_entry = {}
self.pystring_entries = [] self.pystring_entries = []
self.buffer_entries = []
self.control_flow = ControlFlow.LinearControlFlow() self.control_flow = ControlFlow.LinearControlFlow()
def start_branching(self, pos): def start_branching(self, pos):
...@@ -616,8 +620,11 @@ class Scope: ...@@ -616,8 +620,11 @@ class Scope:
return [entry for entry in self.temp_entries return [entry for entry in self.temp_entries
if entry not in self.free_temp_entries] if entry not in self.free_temp_entries]
def use_utility_code(self, new_code): def use_utility_code(self, new_code, name=None):
self.global_scope().use_utility_code(new_code) self.global_scope().use_utility_code(new_code, name)
def has_utility_code(self, name):
return self.global_scope().has_utility_code(name)
def generate_library_function_declarations(self, code): def generate_library_function_declarations(self, code):
# Generate extern decls for C library funcs used. # Generate extern decls for C library funcs used.
...@@ -738,6 +745,7 @@ class ModuleScope(Scope): ...@@ -738,6 +745,7 @@ class ModuleScope(Scope):
# doc_cname string C name of module doc string # doc_cname string C name of module doc string
# const_counter integer Counter for naming constants # const_counter integer Counter for naming constants
# utility_code_used [string] Utility code to be included # utility_code_used [string] Utility code to be included
# utility_code_names set(string) (Optional) names for named (often generated) utility code
# default_entries [Entry] Function argument default entries # default_entries [Entry] Function argument default entries
# python_include_files [string] Standard Python headers to be included # python_include_files [string] Standard Python headers to be included
# include_files [string] Other C headers to be included # include_files [string] Other C headers to be included
...@@ -772,6 +780,7 @@ class ModuleScope(Scope): ...@@ -772,6 +780,7 @@ class ModuleScope(Scope):
self.doc_cname = Naming.moddoc_cname self.doc_cname = Naming.moddoc_cname
self.const_counter = 1 self.const_counter = 1
self.utility_code_used = [] self.utility_code_used = []
self.utility_code_names = set()
self.default_entries = [] self.default_entries = []
self.module_entries = {} self.module_entries = {}
self.python_include_files = ["Python.h", "structmember.h"] self.python_include_files = ["Python.h", "structmember.h"]
...@@ -930,13 +939,25 @@ class ModuleScope(Scope): ...@@ -930,13 +939,25 @@ class ModuleScope(Scope):
self.const_counter = n + 1 self.const_counter = n + 1
return "%s%s%d" % (Naming.const_prefix, prefix, n) return "%s%s%d" % (Naming.const_prefix, prefix, n)
def use_utility_code(self, new_code): def use_utility_code(self, new_code, name=None):
# Add string to list of utility code to be included, # Add string to list of utility code to be included,
# if not already there (tested using 'is'). # if not already there (tested using the provided name,
# or 'is' if name=None -- if the utility code is dynamically
# generated, use the name, otherwise it is not needed).
if name is not None:
if name in self.utility_code_names:
return
for old_code in self.utility_code_used: for old_code in self.utility_code_used:
if old_code is new_code: if old_code is new_code:
return return
self.utility_code_used.append(new_code) self.utility_code_used.append(new_code)
self.utility_code_names.add(name)
def has_utility_code(self, name):
# Checks if utility code (that is registered by name) has
# previously been registered. This is useful if the utility code
# is dynamically generated to avoid re-generation.
return name in self.utility_code_names
def declare_c_class(self, name, pos, defining = 0, implementing = 0, def declare_c_class(self, name, pos, defining = 0, implementing = 0,
module_name = None, base_type = None, objstruct_cname = None, module_name = None, base_type = None, objstruct_cname = None,
......
...@@ -47,21 +47,35 @@ class TestBufferParsing(CythonTest): ...@@ -47,21 +47,35 @@ class TestBufferParsing(CythonTest):
self.not_parseable("Non-keyword arg following keyword arg", self.not_parseable("Non-keyword arg following keyword arg",
u"cdef object[foo=1, 2] x") u"cdef object[foo=1, 2] x")
# See also tests/error/e_bufaccess.pyx and tets/run/bufaccess.pyx
class TestBufferOptions(CythonTest): class TestBufferOptions(CythonTest):
# Tests the full parsing of the options within the brackets # Tests the full parsing of the options within the brackets
def parse_opts(self, opts): def nonfatal_error(self, error):
s = u"cdef object[%s] x" % opts # We're passing self as context to transform to trap this
root = self.fragment(s, pipeline=[PostParse(self)]).root self.error = error
buftype = root.stats[0].base_type self.assert_(self.expect_error)
def parse_opts(self, opts, expect_error=False):
s = u"def f():\n cdef object[%s] x" % opts
self.expect_error = expect_error
root = self.fragment(s, pipeline=[NormalizeTree(self), PostParse(self)]).root
if not expect_error:
vardef = root.stats[0].body.stats[0]
assert isinstance(vardef, CVarDefNode) # use normal assert as this is to validate the test code
buftype = vardef.base_type
self.assert_(isinstance(buftype, CBufferAccessTypeNode)) self.assert_(isinstance(buftype, CBufferAccessTypeNode))
self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode)) self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode))
self.assertEqual(u"object", buftype.base_type_node.name) self.assertEqual(u"object", buftype.base_type_node.name)
return buftype return buftype
else:
self.assert_(len(root.stats[0].body.stats) == 0)
def non_parse(self, expected_err, opts): def non_parse(self, expected_err, opts):
e = self.should_fail(lambda: self.parse_opts(opts)) self.parse_opts(opts, expect_error=True)
self.assertEqual(expected_err, e.message_only) # e = self.should_fail(lambda: self.parse_opts(opts))
self.assertEqual(expected_err, self.error.message_only)
def test_basic(self): def test_basic(self):
buf = self.parse_opts(u"unsigned short int, 3") buf = self.parse_opts(u"unsigned short int, 3")
...@@ -86,10 +100,12 @@ class TestBufferOptions(CythonTest): ...@@ -86,10 +100,12 @@ class TestBufferOptions(CythonTest):
def test_use_DEF(self): def test_use_DEF(self):
t = self.fragment(u""" t = self.fragment(u"""
DEF ndim = 3 DEF ndim = 3
def f():
cdef object[int, ndim] x cdef object[int, ndim] x
cdef object[ndim=ndim, dtype=int] y cdef object[ndim=ndim, dtype=int] y
""", pipeline=[PostParse(self)]).root """, pipeline=[NormalizeTree(self), PostParse(self)]).root
self.assert_(t.stats[1].base_type.ndim == 3) stats = t.stats[0].body.stats
self.assert_(t.stats[2].base_type.ndim == 3) self.assert_(stats[0].base_type.ndim == 3)
self.assert_(stats[1].base_type.ndim == 3)
# add exotic and impossible combinations as they come along # add exotic and impossible combinations as they come along...
...@@ -6,6 +6,8 @@ import re ...@@ -6,6 +6,8 @@ import re
from cStringIO import StringIO from cStringIO import StringIO
from Scanning import PyrexScanner, StringSourceDescriptor from Scanning import PyrexScanner, StringSourceDescriptor
from Symtab import BuiltinScope, ModuleScope from Symtab import BuiltinScope, ModuleScope
import Symtab
import PyrexTypes
from Visitor import VisitorTransform, temp_name_handle from Visitor import VisitorTransform, temp_name_handle
from Nodes import Node, StatListNode from Nodes import Node, StatListNode
from ExprNodes import NameNode from ExprNodes import NameNode
...@@ -108,11 +110,16 @@ class TemplateTransform(VisitorTransform): ...@@ -108,11 +110,16 @@ class TemplateTransform(VisitorTransform):
self.substitutions = substitutions self.substitutions = substitutions
tempdict = {} tempdict = {}
for key in temps: for key in temps:
tempdict[key] = temp_name_handle(key) tempdict[key] = temp_name_handle(key) # pending result_code refactor: Symtab.new_temp(PyrexTypes.py_object_type, key)
self.temps = tempdict self.temp_key_to_entries = tempdict
self.pos = pos self.pos = pos
return super(TemplateTransform, self).__call__(node) return super(TemplateTransform, self).__call__(node)
def get_pos(self, node):
if self.pos:
return self.pos
else:
return node.pos
def visit_Node(self, node): def visit_Node(self, node):
if node is None: if node is None:
...@@ -135,11 +142,11 @@ class TemplateTransform(VisitorTransform): ...@@ -135,11 +142,11 @@ class TemplateTransform(VisitorTransform):
def visit_NameNode(self, node): def visit_NameNode(self, node):
tempname = self.temps.get(node.name) tempentry = self.temp_key_to_entries.get(node.name)
if tempname is not None: if tempentry is not None:
# Replace name with temporary # Replace name with temporary
node.name = tempname return NameNode(self.get_pos(node), name=tempentry)
return self.visit_Node(node) # Pending result_code refactor: return NameNode(self.get_pos(node), entry=tempentry)
else: else:
return self.try_substitution(node, node.name) return self.try_substitution(node, node.name)
...@@ -157,6 +164,7 @@ def copy_code_tree(node): ...@@ -157,6 +164,7 @@ def copy_code_tree(node):
INDENT_RE = re.compile(ur"^ *") INDENT_RE = re.compile(ur"^ *")
def strip_common_indent(lines): def strip_common_indent(lines):
"Strips empty lines and common indentation from the list of strings given in lines" "Strips empty lines and common indentation from the list of strings given in lines"
# TODO: Facilitate textwrap.indent instead
lines = [x for x in lines if x.strip() != u""] lines = [x for x in lines if x.strip() != u""]
minindent = min([len(INDENT_RE.match(x).group(0)) for x in lines]) minindent = min([len(INDENT_RE.match(x).group(0)) for x in lines])
lines = [x[minindent:] for x in lines] lines = [x[minindent:] for x in lines]
......
# Please see the Python header files (object.h) for docs
cdef extern from "Python.h":
ctypedef void PyObject
ctypedef struct bufferinfo:
void *buf
Py_ssize_t len
Py_ssize_t itemsize
int readonly
int ndim
char *format
Py_ssize_t *shape
Py_ssize_t *strides
Py_ssize_t *suboffsets
void *internal
ctypedef bufferinfo Py_buffer
cdef enum:
PyBUF_SIMPLE,
PyBUF_WRITABLE,
PyBUF_WRITEABLE, # backwards compatability
PyBUF_FORMAT,
PyBUF_ND,
PyBUF_STRIDES,
PyBUF_C_CONTIGUOUS,
PyBUF_F_CONTIGUOUS,
PyBUF_ANY_CONTIGUOUS,
PyBUF_INDIRECT,
PyBUF_CONTIG,
PyBUF_CONTIG_RO,
PyBUF_STRIDED,
PyBUF_STRIDED_RO,
PyBUF_RECORDS,
PyBUF_RECORDS_RO,
PyBUF_FULL,
PyBUF_FULL_RO,
PyBUF_READ,
PyBUF_WRITE,
PyBUF_SHADOW
int PyObject_CheckBuffer(PyObject* obj)
int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags)
void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view)
void* PyBuffer_GetPointer(Py_buffer *view, Py_ssize_t *indices)
int PyBuffer_SizeFromFormat(char *) # actually const char
int PyBuffer_ToContiguous(void *buf, Py_buffer *view, Py_ssize_t len, char fort)
int PyBuffer_FromContiguous(Py_buffer *view, void *buf, Py_ssize_t len, char fort)
int PyObject_CopyData(PyObject *dest, PyObject *src)
int PyBuffer_IsContiguous(Py_buffer *view, char fort)
void PyBuffer_FillContiguousStrides(int ndims,
Py_ssize_t *shape,
Py_ssize_t *strides,
int itemsize,
char fort)
int PyBuffer_FillInfo(Py_buffer *view, void *buf,
Py_ssize_t len, int readonly,
int flags)
PyObject* PyObject_Format(PyObject* obj,
PyObject *format_spec)
from cStringIO import StringIO
class StringIOTree(object):
"""
See module docs.
"""
def __init__(self, stream=None):
self.prepended_children = []
if stream is None: stream = StringIO()
self.stream = stream
def getvalue(self):
return ("".join([x.getvalue() for x in self.prepended_children]) +
self.stream.getvalue())
def copyto(self, target):
"""Potentially cheaper than getvalue as no string concatenation
needs to happen."""
for child in self.prepended_children:
child.copyto(target)
target.write(self.stream.getvalue())
def write(self, what):
self.stream.write(what)
def insertion_point(self):
# Save what we have written until now
# (would it be more efficient to check with len(self.stream.getvalue())?
# leaving it out for now)
self.prepended_children.append(StringIOTree(self.stream))
# Construct the new forked object to return
other = StringIOTree()
self.prepended_children.append(other)
self.stream = StringIO()
return other
__doc__ = r"""
Implements a buffer with insertion points. When you know you need to
"get back" to a place and write more later, simply call insertion_point()
at that spot and get a new StringIOTree object that is "left behind".
EXAMPLE:
>>> a = StringIOTree()
>>> a.write('first\n')
>>> b = a.insertion_point()
>>> a.write('third\n')
>>> b.write('second\n')
>>> print a.getvalue()
first
second
third
<BLANKLINE>
>>> c = b.insertion_point()
>>> d = c.insertion_point()
>>> d.write('alpha\n')
>>> b.write('gamma\n')
>>> c.write('beta\n')
>>> print b.getvalue()
second
alpha
beta
gamma
<BLANKLINE>
>>> out = StringIO()
>>> a.copyto(out)
>>> print out.getvalue()
first
second
alpha
beta
gamma
third
<BLANKLINE>
"""
if __name__ == "__main__":
import doctest
doctest.testmod()
...@@ -46,12 +46,13 @@ class ErrorWriter(object): ...@@ -46,12 +46,13 @@ class ErrorWriter(object):
class TestBuilder(object): class TestBuilder(object):
def __init__(self, rootdir, workdir, selectors, annotate, def __init__(self, rootdir, workdir, selectors, annotate,
cleanup_workdir, with_pyregr): cleanup_workdir, cleanup_sharedlibs, with_pyregr):
self.rootdir = rootdir self.rootdir = rootdir
self.workdir = workdir self.workdir = workdir
self.selectors = selectors self.selectors = selectors
self.annotate = annotate self.annotate = annotate
self.cleanup_workdir = cleanup_workdir self.cleanup_workdir = cleanup_workdir
self.cleanup_sharedlibs = cleanup_sharedlibs
self.with_pyregr = with_pyregr self.with_pyregr = with_pyregr
def build_suite(self): def build_suite(self):
...@@ -84,6 +85,7 @@ class TestBuilder(object): ...@@ -84,6 +85,7 @@ class TestBuilder(object):
for filename in filenames: for filename in filenames:
if not (filename.endswith(".pyx") or filename.endswith(".py")): if not (filename.endswith(".pyx") or filename.endswith(".py")):
continue continue
if filename.startswith('.'): continue # certain emacs backup files
if context == 'pyregr' and not filename.startswith('test_'): if context == 'pyregr' and not filename.startswith('test_'):
continue continue
module = os.path.splitext(filename)[0] module = os.path.splitext(filename)[0]
...@@ -99,25 +101,29 @@ class TestBuilder(object): ...@@ -99,25 +101,29 @@ class TestBuilder(object):
test = build_test( test = build_test(
path, workdir, module, path, workdir, module,
annotate=self.annotate, annotate=self.annotate,
cleanup_workdir=self.cleanup_workdir) cleanup_workdir=self.cleanup_workdir,
cleanup_sharedlibs=self.cleanup_sharedlibs)
else: else:
test = CythonCompileTestCase( test = CythonCompileTestCase(
path, workdir, module, path, workdir, module,
expect_errors=expect_errors, expect_errors=expect_errors,
annotate=self.annotate, annotate=self.annotate,
cleanup_workdir=self.cleanup_workdir) cleanup_workdir=self.cleanup_workdir,
cleanup_sharedlibs=self.cleanup_sharedlibs)
suite.addTest(test) suite.addTest(test)
return suite return suite
class CythonCompileTestCase(unittest.TestCase): class CythonCompileTestCase(unittest.TestCase):
def __init__(self, directory, workdir, module, def __init__(self, directory, workdir, module,
expect_errors=False, annotate=False, cleanup_workdir=True): expect_errors=False, annotate=False, cleanup_workdir=True,
cleanup_sharedlibs=True):
self.directory = directory self.directory = directory
self.workdir = workdir self.workdir = workdir
self.module = module self.module = module
self.expect_errors = expect_errors self.expect_errors = expect_errors
self.annotate = annotate self.annotate = annotate
self.cleanup_workdir = cleanup_workdir self.cleanup_workdir = cleanup_workdir
self.cleanup_sharedlibs = cleanup_sharedlibs
unittest.TestCase.__init__(self) unittest.TestCase.__init__(self)
def shortDescription(self): def shortDescription(self):
...@@ -125,10 +131,13 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -125,10 +131,13 @@ class CythonCompileTestCase(unittest.TestCase):
def tearDown(self): def tearDown(self):
cleanup_c_files = WITH_CYTHON and self.cleanup_workdir cleanup_c_files = WITH_CYTHON and self.cleanup_workdir
cleanup_lib_files = self.cleanup_sharedlibs
if os.path.exists(self.workdir): if os.path.exists(self.workdir):
for rmfile in os.listdir(self.workdir): for rmfile in os.listdir(self.workdir):
if not cleanup_c_files and rmfile[-2:] in (".c", ".h"): if not cleanup_c_files and rmfile[-2:] in (".c", ".h"):
continue continue
if not cleanup_lib_files and rmfile.endswith(".so") or rmfile.endswith(".dll"):
continue
if self.annotate and rmfile.endswith(".html"): if self.annotate and rmfile.endswith(".html"):
continue continue
try: try:
...@@ -278,7 +287,7 @@ class CythonUnitTestCase(CythonCompileTestCase): ...@@ -278,7 +287,7 @@ class CythonUnitTestCase(CythonCompileTestCase):
except Exception: except Exception:
pass pass
def collect_unittests(path, suite, selectors): def collect_unittests(path, module_prefix, suite, selectors):
def file_matches(filename): def file_matches(filename):
return filename.startswith("Test") and filename.endswith(".py") return filename.startswith("Test") and filename.endswith(".py")
...@@ -304,7 +313,7 @@ def collect_unittests(path, suite, selectors): ...@@ -304,7 +313,7 @@ def collect_unittests(path, suite, selectors):
for f in filenames: for f in filenames:
if file_matches(f): if file_matches(f):
filepath = os.path.join(dirpath, f)[:-len(".py")] filepath = os.path.join(dirpath, f)[:-len(".py")]
modulename = filepath[len(path)+1:].replace(os.path.sep, '.') modulename = module_prefix + filepath[len(path)+1:].replace(os.path.sep, '.')
if not [ 1 for match in selectors if match(modulename) ]: if not [ 1 for match in selectors if match(modulename) ]:
continue continue
module = __import__(modulename) module = __import__(modulename)
...@@ -312,18 +321,50 @@ def collect_unittests(path, suite, selectors): ...@@ -312,18 +321,50 @@ def collect_unittests(path, suite, selectors):
module = getattr(module, x) module = getattr(module, x)
suite.addTests([loader.loadTestsFromModule(module)]) suite.addTests([loader.loadTestsFromModule(module)])
def collect_doctests(path, module_prefix, suite, selectors):
def package_matches(dirname):
return dirname not in ("Mac", "Distutils", "Plex")
def file_matches(filename):
return (filename.endswith(".py") and not ('~' in filename
or '#' in filename or filename.startswith('.')))
import doctest, types
for dirpath, dirnames, filenames in os.walk(path):
parentname = os.path.split(dirpath)[-1]
if package_matches(parentname):
for f in filenames:
if file_matches(f):
if not f.endswith('.py'): continue
filepath = os.path.join(dirpath, f)[:-len(".py")]
modulename = module_prefix + filepath[len(path)+1:].replace(os.path.sep, '.')
if not [ 1 for match in selectors if match(modulename) ]:
continue
module = __import__(modulename)
for x in modulename.split('.')[1:]:
module = getattr(module, x)
if hasattr(module, "__doc__") or hasattr(module, "__test__"):
try:
suite.addTests(doctest.DocTestSuite(module))
except ValueError: # no tests
pass
if __name__ == '__main__': if __name__ == '__main__':
from optparse import OptionParser from optparse import OptionParser
parser = OptionParser() parser = OptionParser()
parser.add_option("--no-cleanup", dest="cleanup_workdir", parser.add_option("--no-cleanup", dest="cleanup_workdir",
action="store_false", default=True, action="store_false", default=True,
help="do not delete the generated C files (allows passing --no-cython on next run)") help="do not delete the generated C files (allows passing --no-cython on next run)")
parser.add_option("--no-cleanup-sharedlibs", dest="cleanup_sharedlibs",
action="store_false", default=True,
help="do not delete the generated shared libary files (allows manual module experimentation)")
parser.add_option("--no-cython", dest="with_cython", parser.add_option("--no-cython", dest="with_cython",
action="store_false", default=True, action="store_false", default=True,
help="do not run the Cython compiler, only the C compiler") help="do not run the Cython compiler, only the C compiler")
parser.add_option("--no-unit", dest="unittests", parser.add_option("--no-unit", dest="unittests",
action="store_false", default=True, action="store_false", default=True,
help="do not run the unit tests") help="do not run the unit tests")
parser.add_option("--no-doctest", dest="doctests",
action="store_false", default=True,
help="do not run the doctests")
parser.add_option("--no-file", dest="filetests", parser.add_option("--no-file", dest="filetests",
action="store_false", default=True, action="store_false", default=True,
help="do not run the file based tests") help="do not run the file based tests")
...@@ -360,6 +401,8 @@ if __name__ == '__main__': ...@@ -360,6 +401,8 @@ if __name__ == '__main__':
# RUN ALL TESTS! # RUN ALL TESTS!
ROOTDIR = os.path.join(os.getcwd(), os.path.dirname(sys.argv[0]), 'tests') ROOTDIR = os.path.join(os.getcwd(), os.path.dirname(sys.argv[0]), 'tests')
WORKDIR = os.path.join(os.getcwd(), 'BUILD') WORKDIR = os.path.join(os.getcwd(), 'BUILD')
UNITTEST_MODULE = "Cython"
UNITTEST_ROOT = os.path.join(os.getcwd(), UNITTEST_MODULE)
if WITH_CYTHON: if WITH_CYTHON:
if os.path.exists(WORKDIR): if os.path.exists(WORKDIR):
shutil.rmtree(WORKDIR, ignore_errors=True) shutil.rmtree(WORKDIR, ignore_errors=True)
...@@ -382,13 +425,15 @@ if __name__ == '__main__': ...@@ -382,13 +425,15 @@ if __name__ == '__main__':
test_suite = unittest.TestSuite() test_suite = unittest.TestSuite()
if options.unittests: if options.unittests:
collect_unittests(os.getcwd(), test_suite, selectors) collect_unittests(UNITTEST_ROOT, UNITTEST_MODULE + ".", test_suite, selectors)
if options.doctests:
collect_doctests(UNITTEST_ROOT, UNITTEST_MODULE + ".", test_suite, selectors)
if options.filetests: if options.filetests:
filetests = TestBuilder(ROOTDIR, WORKDIR, selectors, filetests = TestBuilder(ROOTDIR, WORKDIR, selectors,
options.annotate_source, options.annotate_source, options.cleanup_workdir,
options.cleanup_workdir, options.cleanup_sharedlibs, options.pyregr)
options.pyregr)
test_suite.addTests([filetests.build_suite()]) test_suite.addTests([filetests.build_suite()])
unittest.TextTestRunner(verbosity=options.verbosity).run(test_suite) unittest.TextTestRunner(verbosity=options.verbosity).run(test_suite)
......
cdef extern:
cdef func(int[])
cdef object[int] buf
cdef class A:
cdef object[int] buf
def f():
cdef object[fakeoption=True] buf1
cdef object[int, -1] buf1b
cdef object[ndim=-1] buf2
cdef object[int, 'a'] buf3
cdef object[int,2,3,4,5,6] buf4
cdef object[int, 2, 'foo'] buf5
cdef object[int, 2, well] buf6
_ERRORS = u"""
1:11: Buffer types only allowed as function local variables
3:15: Buffer types only allowed as function local variables
6:27: "fakeoption" is not a buffer option
7:22: "ndim" must be non-negative
8:15: "dtype" missing
9:21: "ndim" must be an integer
10:15: Too many buffer options
11:24: Only allowed buffer modes are "full" or "strided" (as a compile-time string)
12:28: Only allowed buffer modes are "full" or "strided" (as a compile-time string)
"""
cdef class A:
cdef int value = 3
_ERRORS = u"""
2:13: Cannot assign default value to cdef class attributes
"""
This diff is collapsed.
__doc__ = """
>>> test(1, 2)
4 1 2 2 0 7 8
"""
cdef int g = 7
def test(x, int y):
if True:
before = 0
cdef int a = 4, b = x, c = y, *p = &y
cdef object o = int(8)
print a, b, c, p[0], before, g, o
# Also test that pruning cdefs doesn't hurt
def empty():
cdef int i
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment