Commit 01a5a332 authored by Stefan Behnel's avatar Stefan Behnel

large merge of cython-dagss as of revision 764

parents df65309f 72d54fb4
from Cython.Compiler.Visitor import TreeVisitor from Cython.Compiler.Visitor import TreeVisitor, get_temp_name_handle_desc
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
""" """
Serializes a Cython code tree to Cython code. This is primarily useful for Serializes a Cython code tree to Cython code. This is primarily useful for
...@@ -35,6 +36,7 @@ class CodeWriter(TreeVisitor): ...@@ -35,6 +36,7 @@ class CodeWriter(TreeVisitor):
result = LinesResult() result = LinesResult()
self.result = result self.result = result
self.numindents = 0 self.numindents = 0
self.tempnames = {}
def write(self, tree): def write(self, tree):
self.visit(tree) self.visit(tree)
...@@ -57,6 +59,12 @@ class CodeWriter(TreeVisitor): ...@@ -57,6 +59,12 @@ class CodeWriter(TreeVisitor):
def line(self, s): def line(self, s):
self.startline(s) self.startline(s)
self.endline() self.endline()
def putname(self, name):
tmpdesc = get_temp_name_handle_desc(name)
if tmpdesc is not None:
name = self.tempnames.setdefault(tmpdesc, u"$" +tmpdesc)
self.put(name)
def comma_seperated_list(self, items, output_rhs=False): def comma_seperated_list(self, items, output_rhs=False):
if len(items) > 0: if len(items) > 0:
...@@ -116,7 +124,7 @@ class CodeWriter(TreeVisitor): ...@@ -116,7 +124,7 @@ class CodeWriter(TreeVisitor):
self.endline() self.endline()
def visit_NameNode(self, node): def visit_NameNode(self, node):
self.put(node.name) self.putname(node.name)
def visit_IntNode(self, node): def visit_IntNode(self, node):
self.put(node.value) self.put(node.value)
...@@ -185,10 +193,23 @@ class CodeWriter(TreeVisitor): ...@@ -185,10 +193,23 @@ class CodeWriter(TreeVisitor):
self.comma_seperated_list(node.args) # Might need to discover whether we need () around tuples...hmm... self.comma_seperated_list(node.args) # Might need to discover whether we need () around tuples...hmm...
def visit_SimpleCallNode(self, node): def visit_SimpleCallNode(self, node):
self.put(node.function.name + u"(") self.visit(node.function)
self.put(u"(")
self.comma_seperated_list(node.args) self.comma_seperated_list(node.args)
self.put(")") self.put(")")
def visit_GeneralCallNode(self, node):
self.visit(node.function)
self.put(u"(")
posarg = node.positional_args
if isinstance(posarg, AsTupleNode):
self.visit(posarg.arg)
else:
self.comma_seperated_list(posarg)
if node.keyword_args is not None or node.starstar_arg is not None:
raise Exception("Not implemented yet")
self.put(u")")
def visit_ExprStatNode(self, node): def visit_ExprStatNode(self, node):
self.startline() self.startline()
self.visit(node.expr) self.visit(node.expr)
...@@ -197,9 +218,73 @@ class CodeWriter(TreeVisitor): ...@@ -197,9 +218,73 @@ class CodeWriter(TreeVisitor):
def visit_InPlaceAssignmentNode(self, node): def visit_InPlaceAssignmentNode(self, node):
self.startline() self.startline()
self.visit(node.lhs) self.visit(node.lhs)
self.put(" %s= " % node.operator) self.put(u" %s= " % node.operator)
self.visit(node.rhs) self.visit(node.rhs)
self.endline() self.endline()
def visit_WithStatNode(self, node):
self.startline()
self.put(u"with ")
self.visit(node.manager)
if node.target is not None:
self.put(u" as ")
self.visit(node.target)
self.endline(u":")
self.indent()
self.visit(node.body)
self.dedent()
def visit_AttributeNode(self, node):
self.visit(node.obj)
self.put(u".%s" % node.attribute)
def visit_BoolNode(self, node):
self.put(str(node.value))
def visit_TryFinallyStatNode(self, node):
self.line(u"try:")
self.indent()
self.visit(node.body)
self.dedent()
self.line(u"finally:")
self.indent()
self.visit(node.finally_clause)
self.dedent()
def visit_TryExceptStatNode(self, node):
self.line(u"try:")
self.indent()
self.visit(node.body)
self.dedent()
for x in node.except_clauses:
self.visit(x)
if node.else_clause is not None:
self.visit(node.else_clause)
def visit_ExceptClauseNode(self, node):
self.startline(u"except")
if node.pattern is not None:
self.put(u" ")
self.visit(node.pattern)
if node.target is not None:
self.put(u", ")
self.visit(node.target)
self.endline(":")
self.indent()
self.visit(node.body)
self.dedent()
def visit_ReraiseStatNode(self, node):
self.line("raise")
def visit_NoneNode(self, node):
self.put(u"None")
def visit_ImportNode(self, node):
self.put(u"(import %s)" % node.module_name.value)
def visit_NotNode(self, node):
self.put(u"(not ")
self.visit(node.operand)
self.put(u")")
...@@ -200,6 +200,8 @@ class CCodeWriter: ...@@ -200,6 +200,8 @@ class CCodeWriter:
def put_var_declaration(self, entry, static = 0, dll_linkage = None, def put_var_declaration(self, entry, static = 0, dll_linkage = None,
definition = True): definition = True):
#print "Code.put_var_declaration:", entry.name, "definition =", definition ### #print "Code.put_var_declaration:", entry.name, "definition =", definition ###
if entry.in_closure:
return
visibility = entry.visibility visibility = entry.visibility
if visibility == 'private' and not definition: if visibility == 'private' and not definition:
#print "...private and not definition, skipping" ### #print "...private and not definition, skipping" ###
......
...@@ -32,6 +32,7 @@ class CompileError(PyrexError): ...@@ -32,6 +32,7 @@ class CompileError(PyrexError):
def __init__(self, position = None, message = ""): def __init__(self, position = None, message = ""):
self.position = position self.position = position
self.message_only = message
# Deprecated and withdrawn in 2.6: # Deprecated and withdrawn in 2.6:
# self.message = message # self.message = message
if position: if position:
...@@ -91,6 +92,7 @@ def error(position, message): ...@@ -91,6 +92,7 @@ def error(position, message):
#print "Errors.error:", repr(position), repr(message) ### #print "Errors.error:", repr(position), repr(message) ###
global num_errors global num_errors
err = CompileError(position, message) err = CompileError(position, message)
# if position is not None: raise Exception(err) # debug
line = "%s\n" % err line = "%s\n" % err
if listing_file: if listing_file:
listing_file.write(line) listing_file.write(line)
......
...@@ -972,7 +972,7 @@ class NameNode(AtomicExprNode): ...@@ -972,7 +972,7 @@ class NameNode(AtomicExprNode):
if entry.is_builtin: if entry.is_builtin:
namespace = Naming.builtins_cname namespace = Naming.builtins_cname
else: # entry.is_pyglobal else: # entry.is_pyglobal
namespace = entry.namespace_cname namespace = entry.scope.namespace_cname
code.putln( code.putln(
'%s = __Pyx_GetName(%s, %s); %s' % ( '%s = __Pyx_GetName(%s, %s); %s' % (
self.result_code, self.result_code,
...@@ -997,7 +997,7 @@ class NameNode(AtomicExprNode): ...@@ -997,7 +997,7 @@ class NameNode(AtomicExprNode):
# is_pyglobal seems to be True for module level-globals only. # is_pyglobal seems to be True for module level-globals only.
# We use this to access class->tp_dict if necessary. # We use this to access class->tp_dict if necessary.
if entry.is_pyglobal: if entry.is_pyglobal:
namespace = self.entry.namespace_cname namespace = self.entry.scope.namespace_cname
if entry.is_member: if entry.is_member:
# if the entry is a member we have to cheat: SetAttr does not work # if the entry is a member we have to cheat: SetAttr does not work
# on types, so we create a descriptor which is then added to tp_dict # on types, so we create a descriptor which is then added to tp_dict
...@@ -1060,7 +1060,6 @@ class NameNode(AtomicExprNode): ...@@ -1060,7 +1060,6 @@ class NameNode(AtomicExprNode):
else: else:
code.annotate(pos, AnnotationItem('c_call', 'c function', size=len(self.name))) code.annotate(pos, AnnotationItem('c_call', 'c function', size=len(self.name)))
class BackquoteNode(ExprNode): class BackquoteNode(ExprNode):
# `expr` # `expr`
# #
...@@ -1212,6 +1211,9 @@ class ExcValueNode(AtomicExprNode): ...@@ -1212,6 +1211,9 @@ class ExcValueNode(AtomicExprNode):
def generate_result_code(self, code): def generate_result_code(self, code):
pass pass
def analyse_types(self, env):
pass
class TempNode(AtomicExprNode): class TempNode(AtomicExprNode):
# Node created during analyse_types phase # Node created during analyse_types phase
...@@ -1273,36 +1275,59 @@ class IndexNode(ExprNode): ...@@ -1273,36 +1275,59 @@ class IndexNode(ExprNode):
self.analyse_base_and_index_types(env, setting = 1) self.analyse_base_and_index_types(env, setting = 1)
def analyse_base_and_index_types(self, env, getting = 0, setting = 0): def analyse_base_and_index_types(self, env, getting = 0, setting = 0):
self.is_buffer_access = False
self.base.analyse_types(env) self.base.analyse_types(env)
self.index.analyse_types(env)
if self.base.type.is_pyobject: if self.base.type.buffer_options is not None:
if self.index.type.is_int: if isinstance(self.index, TupleNode):
self.original_index_type = self.index.type indices = self.index.args
self.index = self.index.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env) # is_int_indices = 0 == sum([1 for i in self.index.args if not i.type.is_int])
if getting:
env.use_utility_code(getitem_int_utility_code)
if setting:
env.use_utility_code(setitem_int_utility_code)
else: else:
self.index = self.index.coerce_to_pyobject(env) # is_int_indices = self.index.type.is_int
self.type = py_object_type indices = [self.index]
self.gil_check(env) all_ints = True
self.is_temp = 1 for index in indices:
else: index.analyse_types(env)
if self.base.type.is_ptr or self.base.type.is_array: if not index.type.is_int:
self.type = self.base.type.base_type all_ints = False
if all_ints:
self.indices = indices
self.index = None
self.type = self.base.type.buffer_options.dtype
self.is_temp = 1
self.is_buffer_access = True
if not self.is_buffer_access:
self.index.analyse_types(env) # ok to analyse as tuple
if self.base.type.is_pyobject:
if self.index.type.is_int:
self.original_index_type = self.index.type
self.index = self.index.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env)
if getting:
env.use_utility_code(getitem_int_utility_code)
if setting:
env.use_utility_code(setitem_int_utility_code)
else:
self.index = self.index.coerce_to_pyobject(env)
self.type = py_object_type
self.gil_check(env)
self.is_temp = 1
else: else:
error(self.pos, if self.base.type.is_ptr or self.base.type.is_array:
"Attempting to index non-array type '%s'" % self.type = self.base.type.base_type
self.base.type) else:
self.type = PyrexTypes.error_type error(self.pos,
if self.index.type.is_pyobject: "Attempting to index non-array type '%s'" %
self.index = self.index.coerce_to( self.base.type)
PyrexTypes.c_py_ssize_t_type, env) self.type = PyrexTypes.error_type
if not self.index.type.is_int: if self.index.type.is_pyobject:
error(self.pos, self.index = self.index.coerce_to(
"Invalid index type '%s'" % PyrexTypes.c_py_ssize_t_type, env)
self.index.type) if not self.index.type.is_int:
error(self.pos,
"Invalid index type '%s'" %
self.index.type)
gil_message = "Indexing Python object" gil_message = "Indexing Python object"
...@@ -1328,11 +1353,17 @@ class IndexNode(ExprNode): ...@@ -1328,11 +1353,17 @@ class IndexNode(ExprNode):
def generate_subexpr_evaluation_code(self, code): def generate_subexpr_evaluation_code(self, code):
self.base.generate_evaluation_code(code) self.base.generate_evaluation_code(code)
self.index.generate_evaluation_code(code) if self.index is not None:
self.index.generate_evaluation_code(code)
else:
for i in self.indices: i.generate_evaluation_code(code)
def generate_subexpr_disposal_code(self, code): def generate_subexpr_disposal_code(self, code):
self.base.generate_disposal_code(code) self.base.generate_disposal_code(code)
self.index.generate_disposal_code(code) if self.index is not None:
self.index.generate_disposal_code(code)
else:
for i in self.indices: i.generate_disposal_code(code)
def generate_result_code(self, code): def generate_result_code(self, code):
if self.type.is_pyobject: if self.type.is_pyobject:
......
...@@ -7,5 +7,6 @@ def _get_feature(name): ...@@ -7,5 +7,6 @@ def _get_feature(name):
return object() return object()
unicode_literals = _get_feature("unicode_literals") unicode_literals = _get_feature("unicode_literals")
with_statement = _get_feature("with_statement")
del _get_feature del _get_feature
This diff is collapsed.
...@@ -25,6 +25,10 @@ from PyrexTypes import py_object_type ...@@ -25,6 +25,10 @@ from PyrexTypes import py_object_type
from Cython.Utils import open_new_file, replace_suffix from Cython.Utils import open_new_file, replace_suffix
def check_c_classes(module_node):
module_node.scope.check_c_classes()
return module_node
class ModuleNode(Nodes.Node, Nodes.BlockNode): class ModuleNode(Nodes.Node, Nodes.BlockNode):
# doc string or None # doc string or None
# body StatListNode # body StatListNode
...@@ -32,6 +36,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -32,6 +36,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
# referenced_modules [ModuleScope] # referenced_modules [ModuleScope]
# module_temp_cname string # module_temp_cname string
# full_module_name string # full_module_name string
#
# scope The module scope.
# compilation_source A CompilationSource (see Main)
child_attrs = ["body"] child_attrs = ["body"]
...@@ -44,10 +51,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -44,10 +51,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
env.doc = self.doc env.doc = self.doc
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
def process_implementation(self, env, options, result): def process_implementation(self, options, result):
self.analyse_declarations(env) env = self.scope
env.check_c_classes()
self.body.analyse_expressions(env)
env.return_type = PyrexTypes.c_void_type env.return_type = PyrexTypes.c_void_type
self.referenced_modules = [] self.referenced_modules = []
self.find_referenced_modules(env, self.referenced_modules, {}) self.find_referenced_modules(env, self.referenced_modules, {})
...@@ -254,6 +259,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -254,6 +259,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
self.generate_module_cleanup_func(env, code) self.generate_module_cleanup_func(env, code)
self.generate_filename_table(code) self.generate_filename_table(code)
self.generate_utility_functions(env, code) self.generate_utility_functions(env, code)
self.generate_buffer_compatability_functions(env, code)
self.generate_declarations_for_modules(env, modules, code.h) self.generate_declarations_for_modules(env, modules, code.h)
...@@ -433,6 +439,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -433,6 +439,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln(" #define PyBUF_F_CONTIGUOUS (0x0040 | PyBUF_STRIDES)") code.putln(" #define PyBUF_F_CONTIGUOUS (0x0040 | PyBUF_STRIDES)")
code.putln(" #define PyBUF_ANY_CONTIGUOUS (0x0080 | PyBUF_STRIDES)") code.putln(" #define PyBUF_ANY_CONTIGUOUS (0x0080 | PyBUF_STRIDES)")
code.putln(" #define PyBUF_INDIRECT (0x0100 | PyBUF_STRIDES)") code.putln(" #define PyBUF_INDIRECT (0x0100 | PyBUF_STRIDES)")
code.putln("")
code.putln(" static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags);")
code.putln(" static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view);")
code.putln("#endif") code.putln("#endif")
code.put(builtin_module_name_utility_code[0]) code.put(builtin_module_name_utility_code[0])
...@@ -1940,6 +1949,60 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1940,6 +1949,60 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.h.put(utility_code[0]) code.h.put(utility_code[0])
code.put(utility_code[1]) code.put(utility_code[1])
code.put(PyrexTypes.type_conversion_functions) code.put(PyrexTypes.type_conversion_functions)
code.putln("")
def generate_buffer_compatability_functions(self, env, code):
# will be refactored
code.put("""
static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
/* This function is always called after a type-check */
PyArrayObject *arr = (PyArrayObject*)obj;
PyArray_Descr *type = (PyArray_Descr*)arr->descr;
view->buf = arr->data;
view->readonly = 0; /*fixme*/
view->format = "B"; /*fixme*/
view->ndim = arr->nd;
view->strides = arr->strides;
view->shape = arr->dimensions;
view->suboffsets = 0;
view->itemsize = type->elsize;
view->internal = 0;
return 0;
}
static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
}
""")
# For now, hard-code numpy imported as "numpy"
ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
types = [
(ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer")
]
# typeptr_cname = ndarrtype.typeptr_cname
code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
clause = "if"
for t, get, release in types:
code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
clause = "else if"
code.putln("else {")
code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
code.putln("return -1;")
code.putln("}")
code.putln("}")
code.putln("")
code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
clause = "if"
for t, get, release in types:
code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release))
clause = "else if"
code.putln("}")
code.putln("")
#------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------
# #
......
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
pyrex_prefix = "__pyx_" pyrex_prefix = "__pyx_"
temp_prefix = u"__cyt_"
builtin_prefix = pyrex_prefix + "builtin_" builtin_prefix = pyrex_prefix + "builtin_"
arg_prefix = pyrex_prefix + "arg_" arg_prefix = pyrex_prefix + "arg_"
funcdoc_prefix = pyrex_prefix + "doc_" funcdoc_prefix = pyrex_prefix + "doc_"
...@@ -70,6 +72,8 @@ optional_args_cname = pyrex_prefix + "optional_args" ...@@ -70,6 +72,8 @@ optional_args_cname = pyrex_prefix + "optional_args"
no_opt_args = pyrex_prefix + "no_opt_args" no_opt_args = pyrex_prefix + "no_opt_args"
import_star = pyrex_prefix + "import_star" import_star = pyrex_prefix + "import_star"
import_star_set = pyrex_prefix + "import_star_set" import_star_set = pyrex_prefix + "import_star_set"
cur_scope_cname = pyrex_prefix + "cur_scope"
enc_scope_cname = pyrex_prefix + "enc_scope"
line_c_macro = "__LINE__" line_c_macro = "__LINE__"
......
This diff is collapsed.
This diff is collapsed.
...@@ -312,6 +312,7 @@ def p_call(s, function): ...@@ -312,6 +312,7 @@ def p_call(s, function):
if s.sy != ',': if s.sy != ',':
break break
s.next() s.next()
if s.sy == '*': if s.sy == '*':
s.next() s.next()
star_arg = p_simple_expr(s) star_arg = p_simple_expr(s)
...@@ -1159,13 +1160,13 @@ def p_for_from_step(s): ...@@ -1159,13 +1160,13 @@ def p_for_from_step(s):
inequality_relations = ('<', '<=', '>', '>=') inequality_relations = ('<', '<=', '>', '>=')
def p_for_target(s): def p_target(s, terminator):
pos = s.position() pos = s.position()
expr = p_bit_expr(s) expr = p_bit_expr(s)
if s.sy == ',': if s.sy == ',':
s.next() s.next()
exprs = [expr] exprs = [expr]
while s.sy != 'in': while s.sy != terminator:
exprs.append(p_bit_expr(s)) exprs.append(p_bit_expr(s))
if s.sy != ',': if s.sy != ',':
break break
...@@ -1174,6 +1175,9 @@ def p_for_target(s): ...@@ -1174,6 +1175,9 @@ def p_for_target(s):
else: else:
return expr return expr
def p_for_target(s):
return p_target(s, 'in')
def p_for_iterator(s): def p_for_iterator(s):
pos = s.position() pos = s.position()
expr = p_testlist(s) expr = p_testlist(s)
...@@ -1253,8 +1257,17 @@ def p_with_statement(s): ...@@ -1253,8 +1257,17 @@ def p_with_statement(s):
body = p_suite(s) body = p_suite(s)
return Nodes.GILStatNode(pos, state = state, body = body) return Nodes.GILStatNode(pos, state = state, body = body)
else: else:
s.error("Only 'with gil' and 'with nogil' implemented", manager = p_expr(s)
pos = pos) target = None
if s.sy == 'IDENT' and s.systring == 'as':
s.next()
allow_multi = (s.sy == '(')
target = p_target(s, ':')
if not allow_multi and isinstance(target, ExprNodes.TupleNode):
s.error("Multiple with statement target values not allowed without paranthesis")
body = p_suite(s)
return Nodes.WithStatNode(pos, manager = manager,
target = target, body = body)
def p_simple_statement(s, first_statement = 0): def p_simple_statement(s, first_statement = 0):
#print "p_simple_statement:", s.sy, s.systring ### #print "p_simple_statement:", s.sy, s.systring ###
...@@ -1447,6 +1460,71 @@ def p_suite(s, ctx = Ctx(), with_doc = 0, with_pseudo_doc = 0): ...@@ -1447,6 +1460,71 @@ def p_suite(s, ctx = Ctx(), with_doc = 0, with_pseudo_doc = 0):
else: else:
return body return body
def p_positional_and_keyword_args(s, end_sy_set, type_positions=(), type_keywords=()):
"""
Parses positional and keyword arguments. end_sy_set
should contain any s.sy that terminate the argument list.
Argument expansion (* and **) are not allowed.
type_positions and type_keywords specifies which argument
positions and/or names which should be interpreted as
types. Other arguments will be treated as expressions.
Returns: (positional_args, keyword_args)
"""
positional_args = []
keyword_args = []
pos_idx = 0
while s.sy not in end_sy_set:
if s.sy == '*' or s.sy == '**':
s.error('Argument expansion not allowed here.')
was_keyword = False
parsed_type = False
if s.sy == 'IDENT':
# Since we can have either types or expressions as positional args,
# we use a strategy of looking an extra step forward for a '=' and
# if it is a positional arg we backtrack.
ident = s.systring
s.next()
if s.sy == '=':
s.next()
# Is keyword arg
if ident in type_keywords:
arg = p_c_base_type(s)
parsed_type = True
else:
arg = p_simple_expr(s)
keyword_node = ExprNodes.IdentifierStringNode(arg.pos,
value = Utils.EncodedString(ident))
keyword_args.append((keyword_node, arg))
was_keyword = True
else:
s.put_back('IDENT', ident)
if not was_keyword:
if pos_idx in type_positions:
arg = p_c_base_type(s)
parsed_type = True
else:
arg = p_simple_expr(s)
positional_args.append(arg)
pos_idx += 1
if len(keyword_args) > 0:
s.error("Non-keyword arg following keyword arg",
pos = arg.pos)
if s.sy != ',':
if s.sy not in end_sy_set:
if parsed_type:
s.error("Expected: type")
else:
s.error("Expected: expression")
break
s.next()
return positional_args, keyword_args
def p_c_base_type(s, self_flag = 0, nonempty = 0): def p_c_base_type(s, self_flag = 0, nonempty = 0):
# If self_flag is true, this is the base type for the # If self_flag is true, this is the base type for the
# self argument of a C method of an extension type. # self argument of a C method of an extension type.
...@@ -1519,11 +1597,43 @@ def p_c_simple_base_type(s, self_flag, nonempty): ...@@ -1519,11 +1597,43 @@ def p_c_simple_base_type(s, self_flag, nonempty):
else: else:
#print "p_c_simple_base_type: not looking at type at", s.position() #print "p_c_simple_base_type: not looking at type at", s.position()
name = None name = None
return Nodes.CSimpleBaseTypeNode(pos,
type_node = Nodes.CSimpleBaseTypeNode(pos,
name = name, module_path = module_path, name = name, module_path = module_path,
is_basic_c_type = is_basic, signed = signed, is_basic_c_type = is_basic, signed = signed,
longness = longness, is_self_arg = self_flag) longness = longness, is_self_arg = self_flag)
# Treat trailing [] on type as buffer access
if s.sy == '[':
if is_basic:
p.error("Basic C types do not support buffer access")
return p_buffer_access(s, type_node)
else:
return type_node
def p_buffer_access(s, type_node):
# s.sy == '['
pos = s.position()
s.next()
positional_args, keyword_args = (
p_positional_and_keyword_args(s, (']',), (0,), ('dtype',))
)
s.expect(']')
keyword_dict = ExprNodes.DictNode(pos,
key_value_pairs = [
ExprNodes.DictItemNode(pos=key.pos, key=key, value=value)
for key, value in keyword_args
])
result = Nodes.CBufferAccessTypeNode(pos,
positional_args = positional_args,
keyword_args = keyword_dict,
base_type_node = type_node)
return result
def looking_at_type(s): def looking_at_type(s):
return looking_at_base_type(s) or s.looking_at_type_name() return looking_at_base_type(s) or s.looking_at_type_name()
......
...@@ -4,6 +4,23 @@ ...@@ -4,6 +4,23 @@
from Cython import Utils from Cython import Utils
import Naming import Naming
import copy
class BufferOptions:
# dtype PyrexType
# ndim int
def __init__(self, dtype, ndim):
self.dtype = dtype
self.ndim = ndim
def create_buffer_type(base_type, buffer_options):
# Make a shallow copy of base_type and then annotate it
# with the buffer information
result = copy.copy(base_type)
result.buffer_options = buffer_options
return result
class BaseType: class BaseType:
# #
...@@ -92,6 +109,7 @@ class PyrexType(BaseType): ...@@ -92,6 +109,7 @@ class PyrexType(BaseType):
default_value = "" default_value = ""
parsetuple_format = "" parsetuple_format = ""
pymemberdef_typecode = None pymemberdef_typecode = None
buffer_options = None # can contain a BufferOptions instance
def resolve(self): def resolve(self):
# If a typedef, returns the base type. # If a typedef, returns the base type.
...@@ -183,7 +201,6 @@ class CTypedefType(BaseType): ...@@ -183,7 +201,6 @@ class CTypedefType(BaseType):
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.typedef_base_type, name) return getattr(self.typedef_base_type, name)
class PyObjectType(PyrexType): class PyObjectType(PyrexType):
# #
# Base class for all Python object types (reference-counted). # Base class for all Python object types (reference-counted).
......
# #
# Pyrex - Symbol Table # Symbol Table
# #
import re import re
...@@ -19,6 +19,14 @@ import __builtin__ ...@@ -19,6 +19,14 @@ import __builtin__
possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
class BufferAux:
def __init__(self, buffer_info_var, stridevars, shapevars):
self.buffer_info_var = buffer_info_var
self.stridevars = stridevars
self.shapevars = shapevars
def __repr__(self):
return "<BufferAux %r>" % self.__dict__
class Entry: class Entry:
# A symbol table entry in a Scope or ModuleNamespace. # A symbol table entry in a Scope or ModuleNamespace.
# #
...@@ -47,6 +55,7 @@ class Entry: ...@@ -47,6 +55,7 @@ class Entry:
# is_self_arg boolean Is the "self" arg of an exttype method # is_self_arg boolean Is the "self" arg of an exttype method
# is_arg boolean Is the arg of a method # is_arg boolean Is the arg of a method
# is_local boolean Is a local variable # is_local boolean Is a local variable
# in_closure boolean Is referenced in an inner scope
# is_readonly boolean Can't be assigned to # is_readonly boolean Can't be assigned to
# func_cname string C func implementing Python func # func_cname string C func implementing Python func
# pos position Source position where declared # pos position Source position where declared
...@@ -75,6 +84,8 @@ class Entry: ...@@ -75,6 +84,8 @@ class Entry:
# defined_in_pxd boolean Is defined in a .pxd file (not just declared) # defined_in_pxd boolean Is defined in a .pxd file (not just declared)
# api boolean Generate C API for C class or function # api boolean Generate C API for C class or function
# utility_code string Utility code needed when this entry is used # utility_code string Utility code needed when this entry is used
#
# buffer_aux BufferAux or None Extra information needed for buffer variables
borrowed = 0 borrowed = 0
init = "" init = ""
...@@ -96,6 +107,7 @@ class Entry: ...@@ -96,6 +107,7 @@ class Entry:
is_self_arg = 0 is_self_arg = 0
is_arg = 0 is_arg = 0
is_local = 0 is_local = 0
in_closure = 0
is_declared_generic = 0 is_declared_generic = 0
is_readonly = 0 is_readonly = 0
func_cname = None func_cname = None
...@@ -115,6 +127,7 @@ class Entry: ...@@ -115,6 +127,7 @@ class Entry:
api = 0 api = 0
utility_code = None utility_code = None
is_overridable = 0 is_overridable = 0
buffer_aux = None
def __init__(self, name, cname, type, pos = None, init = None): def __init__(self, name, cname, type, pos = None, init = None):
self.name = name self.name = name
...@@ -163,6 +176,8 @@ class Scope: ...@@ -163,6 +176,8 @@ class Scope:
in_cinclude = 0 in_cinclude = 0
nogil = 0 nogil = 0
temp_prefix = Naming.pyrex_prefix
def __init__(self, name, outer_scope, parent_scope): def __init__(self, name, outer_scope, parent_scope):
# The outer_scope is the next scope in the lookup chain. # The outer_scope is the next scope in the lookup chain.
# The parent_scope is used to derive the qualified name of this scope. # The parent_scope is used to derive the qualified name of this scope.
...@@ -447,7 +462,14 @@ class Scope: ...@@ -447,7 +462,14 @@ class Scope:
# Look up name in this scope or an enclosing one. # Look up name in this scope or an enclosing one.
# Return None if not found. # Return None if not found.
return (self.lookup_here(name) return (self.lookup_here(name)
or (self.outer_scope and self.outer_scope.lookup(name)) or (self.outer_scope and self.outer_scope.lookup_from_inner(name))
or None)
def lookup_from_inner(self, name):
# Look up name in this scope or an enclosing one.
# This is only called from enclosing scopes.
return (self.lookup_here(name)
or (self.outer_scope and self.outer_scope.lookup_from_inner(name))
or None) or None)
def lookup_here(self, name): def lookup_here(self, name):
...@@ -562,7 +584,7 @@ class Scope: ...@@ -562,7 +584,7 @@ class Scope:
return entry.cname return entry.cname
n = self.temp_counter n = self.temp_counter
self.temp_counter = n + 1 self.temp_counter = n + 1
cname = "%s%d" % (Naming.pyrex_prefix, n) cname = "%s%d" % (self.temp_prefix, n)
entry = Entry("", cname, type) entry = Entry("", cname, type)
entry.used = 1 entry.used = 1
if type.is_pyobject or type == PyrexTypes.c_py_ssize_t_type: if type.is_pyobject or type == PyrexTypes.c_py_ssize_t_type:
...@@ -608,6 +630,9 @@ class Scope: ...@@ -608,6 +630,9 @@ class Scope:
return 0 return 0
class PreImportScope(Scope): class PreImportScope(Scope):
namespace_cname = Naming.preimport_cname
def __init__(self): def __init__(self):
Scope.__init__(self, Options.pre_import, None, None) Scope.__init__(self, Options.pre_import, None, None)
...@@ -615,7 +640,6 @@ class PreImportScope(Scope): ...@@ -615,7 +640,6 @@ class PreImportScope(Scope):
entry = self.declare(name, name, py_object_type, pos) entry = self.declare(name, name, py_object_type, pos)
entry.is_variable = True entry.is_variable = True
entry.is_pyglobal = True entry.is_pyglobal = True
entry.namespace_cname = Naming.preimport_cname
return entry return entry
...@@ -761,6 +785,7 @@ class ModuleScope(Scope): ...@@ -761,6 +785,7 @@ class ModuleScope(Scope):
self.has_extern_class = 0 self.has_extern_class = 0
self.cached_builtins = [] self.cached_builtins = []
self.undeclared_cached_builtins = [] self.undeclared_cached_builtins = []
self.namespace_cname = self.module_cname
def qualifying_scope(self): def qualifying_scope(self):
return self.parent_module return self.parent_module
...@@ -876,7 +901,6 @@ class ModuleScope(Scope): ...@@ -876,7 +901,6 @@ class ModuleScope(Scope):
raise InternalError( raise InternalError(
"Non-cdef global variable is not a generic Python object") "Non-cdef global variable is not a generic Python object")
entry.is_pyglobal = 1 entry.is_pyglobal = 1
entry.namespace_cname = self.module_cname
else: else:
entry.is_cglobal = 1 entry.is_cglobal = 1
self.var_entries.append(entry) self.var_entries.append(entry)
...@@ -1075,8 +1099,7 @@ class ModuleScope(Scope): ...@@ -1075,8 +1099,7 @@ class ModuleScope(Scope):
var_entry.is_readonly = 1 var_entry.is_readonly = 1
entry.as_variable = var_entry entry.as_variable = var_entry
class LocalScope(Scope):
class LocalScope(Scope):
def __init__(self, name, outer_scope): def __init__(self, name, outer_scope):
Scope.__init__(self, name, outer_scope, outer_scope) Scope.__init__(self, name, outer_scope, outer_scope)
...@@ -1119,6 +1142,33 @@ class LocalScope(Scope): ...@@ -1119,6 +1142,33 @@ class LocalScope(Scope):
entry = self.global_scope().lookup_target(name) entry = self.global_scope().lookup_target(name)
self.entries[name] = entry self.entries[name] = entry
def lookup_from_inner(self, name):
entry = self.lookup_here(name)
if entry:
entry.in_closure = 1
return entry
else:
return (self.outer_scope and self.outer_scope.lookup_from_inner(name)) or None
def mangle_closure_cnames(self, scope_var):
for entry in self.entries.values():
if entry.in_closure:
if not hasattr(entry, 'orig_cname'):
entry.orig_cname = entry.cname
entry.cname = scope_var + "->" + entry.cname
class GeneratorLocalScope(LocalScope):
temp_prefix = Naming.cur_scope_cname + "->" + LocalScope.temp_prefix
def mangle_closure_cnames(self, scope_var):
for entry in self.entries.values() + self.temp_entries:
entry.in_closure = 1
LocalScope.mangle_closure_cnames(self, scope_var)
# def mangle(self, prefix, name):
# return "%s->%s" % (Naming.scope_obj_cname, name)
class StructOrUnionScope(Scope): class StructOrUnionScope(Scope):
# Namespace of a C struct or union. # Namespace of a C struct or union.
...@@ -1198,7 +1248,6 @@ class PyClassScope(ClassScope): ...@@ -1198,7 +1248,6 @@ class PyClassScope(ClassScope):
entry = Scope.declare_var(self, name, type, pos, entry = Scope.declare_var(self, name, type, pos,
cname, visibility, is_cdef) cname, visibility, is_cdef)
entry.is_pyglobal = 1 entry.is_pyglobal = 1
entry.namespace_cname = self.class_obj_cname
return entry return entry
def allocate_temp(self, type): def allocate_temp(self, type):
...@@ -1295,7 +1344,7 @@ class CClassScope(ClassScope): ...@@ -1295,7 +1344,7 @@ class CClassScope(ClassScope):
entry.is_pyglobal = 1 # xxx: is_pyglobal changes behaviour in so many places that entry.is_pyglobal = 1 # xxx: is_pyglobal changes behaviour in so many places that
# I keep it in for now. is_member should be enough # I keep it in for now. is_member should be enough
# later on # later on
entry.namespace_cname = "(PyObject *)%s" % self.parent_type.typeptr_cname self.namespace_cname = "(PyObject *)%s" % self.parent_type.typeptr_cname
entry.interned_cname = self.intern_identifier(name) entry.interned_cname = self.intern_identifier(name)
return entry return entry
......
from Cython.TestUtils import CythonTest
import Cython.Compiler.Errors as Errors
from Cython.Compiler.Nodes import *
from Cython.Compiler.ParseTreeTransforms import *
class TestBufferParsing(CythonTest):
# First, we only test the raw parser, i.e.
# the number and contents of arguments are NOT checked.
# However "dtype"/the first positional argument is special-cased
# to parse a type argument rather than an expression
def parse(self, s):
return self.should_not_fail(lambda: self.fragment(s)).root
def not_parseable(self, expected_error, s):
e = self.should_fail(lambda: self.fragment(s), Errors.CompileError)
self.assertEqual(expected_error, e.message_only)
def test_basic(self):
t = self.parse(u"cdef object[float, 4, ndim=2, foo=foo] x")
bufnode = t.stats[0].base_type
self.assert_(isinstance(bufnode, CBufferAccessTypeNode))
self.assertEqual(2, len(bufnode.positional_args))
# print bufnode.dump()
# should put more here...
def test_type_fail(self):
self.not_parseable("Expected: type",
u"cdef object[2] x")
def test_type_pos(self):
self.parse(u"cdef object[short unsigned int, 3] x")
def test_type_keyword(self):
self.parse(u"cdef object[foo=foo, dtype=short unsigned int] x")
def test_notype_as_expr1(self):
self.not_parseable("Expected: expression",
u"cdef object[foo2=short unsigned int] x")
def test_notype_as_expr2(self):
self.not_parseable("Expected: expression",
u"cdef object[int, short unsigned int] x")
def test_pos_after_key(self):
self.not_parseable("Non-keyword arg following keyword arg",
u"cdef object[foo=1, 2] x")
class TestBufferOptions(CythonTest):
# Tests the full parsing of the options within the brackets
def parse_opts(self, opts):
s = u"cdef object[%s] x" % opts
root = self.fragment(s, pipeline=[PostParse(self)]).root
buftype = root.stats[0].base_type
self.assert_(isinstance(buftype, CBufferAccessTypeNode))
self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode))
self.assertEqual(u"object", buftype.base_type_node.name)
return buftype
def non_parse(self, expected_err, opts):
e = self.should_fail(lambda: self.parse_opts(opts))
self.assertEqual(expected_err, e.message_only)
def test_basic(self):
buf = self.parse_opts(u"unsigned short int, 3")
self.assert_(isinstance(buf.dtype_node, CSimpleBaseTypeNode))
self.assert_(buf.dtype_node.signed == 0 and buf.dtype_node.longness == -1)
self.assertEqual(3, buf.ndim)
def test_dict(self):
buf = self.parse_opts(u"ndim=3, dtype=unsigned short int")
self.assert_(isinstance(buf.dtype_node, CSimpleBaseTypeNode))
self.assert_(buf.dtype_node.signed == 0 and buf.dtype_node.longness == -1)
self.assertEqual(3, buf.ndim)
def test_dtype(self):
self.non_parse(ERR_BUF_MISSING % 'dtype', u"")
def test_ndim(self):
self.parse_opts(u"int, 2")
self.non_parse(ERR_BUF_INT % 'ndim', u"int, 'a'")
self.non_parse(ERR_BUF_NONNEG % 'ndim', u"int, -34")
def test_use_DEF(self):
t = self.fragment(u"""
DEF ndim = 3
cdef object[int, ndim] x
cdef object[ndim=ndim, dtype=int] y
""", pipeline=[PostParse(self)]).root
self.assert_(t.stats[1].base_type.ndim == 3)
self.assert_(t.stats[2].base_type.ndim == 3)
# add exotic and impossible combinations as they come along
from Cython.TestUtils import TransformTest
from Cython.Compiler.ParseTreeTransforms import *
from Cython.Compiler.Nodes import *
class TestNormalizeTree(TransformTest):
def test_parserbehaviour_is_what_we_coded_for(self):
t = self.fragment(u"if x: y").root
self.assertLines(u"""
(root): StatListNode
stats[0]: IfStatNode
if_clauses[0]: IfClauseNode
condition: NameNode
body: ExprStatNode
expr: NameNode
""", self.treetypes(t))
def test_wrap_singlestat(self):
t = self.run_pipeline([NormalizeTree(None)], u"if x: y")
self.assertLines(u"""
(root): StatListNode
stats[0]: IfStatNode
if_clauses[0]: IfClauseNode
condition: NameNode
body: StatListNode
stats[0]: ExprStatNode
expr: NameNode
""", self.treetypes(t))
def test_wrap_multistat(self):
t = self.run_pipeline([NormalizeTree(None)], u"""
if z:
x
y
""")
self.assertLines(u"""
(root): StatListNode
stats[0]: IfStatNode
if_clauses[0]: IfClauseNode
condition: NameNode
body: StatListNode
stats[0]: ExprStatNode
expr: NameNode
stats[1]: ExprStatNode
expr: NameNode
""", self.treetypes(t))
def test_statinexpr(self):
t = self.run_pipeline([NormalizeTree(None)], u"""
a, b = x, y
""")
self.assertLines(u"""
(root): StatListNode
stats[0]: ParallelAssignmentNode
stats[0]: SingleAssignmentNode
lhs: NameNode
rhs: NameNode
stats[1]: SingleAssignmentNode
lhs: NameNode
rhs: NameNode
""", self.treetypes(t))
def test_wrap_offagain(self):
t = self.run_pipeline([NormalizeTree(None)], u"""
x
y
if z:
x
""")
self.assertLines(u"""
(root): StatListNode
stats[0]: ExprStatNode
expr: NameNode
stats[1]: ExprStatNode
expr: NameNode
stats[2]: IfStatNode
if_clauses[0]: IfClauseNode
condition: NameNode
body: StatListNode
stats[0]: ExprStatNode
expr: NameNode
""", self.treetypes(t))
def test_pass_eliminated(self):
t = self.run_pipeline([NormalizeTree(None)], u"pass")
self.assert_(len(t.stats) == 0)
class TestWithTransform(TransformTest):
def test_simplified(self):
t = self.run_pipeline([WithTransform(None)], u"""
with x:
y = z ** 3
""")
self.assertCode(u"""
$MGR = x
$EXIT = $MGR.__exit__
$MGR.__enter__()
$EXC = True
try:
try:
y = z ** 3
except:
$EXC = False
if (not $EXIT($EXCINFO)):
raise
finally:
if $EXC:
$EXIT(None, None, None)
""", t)
def test_basic(self):
t = self.run_pipeline([WithTransform(None)], u"""
with x as y:
y = z ** 3
""")
self.assertCode(u"""
$MGR = x
$EXIT = $MGR.__exit__
$VALUE = $MGR.__enter__()
$EXC = True
try:
try:
y = $VALUE
y = z ** 3
except:
$EXC = False
if (not $EXIT($EXCINFO)):
raise
finally:
if $EXC:
$EXIT(None, None, None)
""", t)
if __name__ == "__main__":
import unittest
unittest.main()
from Cython.TestUtils import CythonTest from Cython.TestUtils import CythonTest
from Cython.Compiler.TreeFragment import * from Cython.Compiler.TreeFragment import *
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
import Cython.Compiler.Naming as Naming
class TestTreeFragments(CythonTest): class TestTreeFragments(CythonTest):
def test_basic(self): def test_basic(self):
F = self.fragment(u"x = 4") F = self.fragment(u"x = 4")
T = F.copy() T = F.copy()
...@@ -12,15 +14,15 @@ class TestTreeFragments(CythonTest): ...@@ -12,15 +14,15 @@ class TestTreeFragments(CythonTest):
F = self.fragment(u"if True: x = 4") F = self.fragment(u"if True: x = 4")
T1 = F.root T1 = F.root
T2 = F.copy() T2 = F.copy()
self.assertEqual("x", T2.body.if_clauses[0].body.lhs.name) self.assertEqual("x", T2.stats[0].if_clauses[0].body.lhs.name)
T2.body.if_clauses[0].body.lhs.name = "other" T2.stats[0].if_clauses[0].body.lhs.name = "other"
self.assertEqual("x", T1.body.if_clauses[0].body.lhs.name) self.assertEqual("x", T1.stats[0].if_clauses[0].body.lhs.name)
def test_substitutions_are_copied(self): def test_substitutions_are_copied(self):
T = self.fragment(u"y + y").substitute({"y": NameNode(pos=None, name="x")}) T = self.fragment(u"y + y").substitute({"y": NameNode(pos=None, name="x")})
self.assertEqual("x", T.body.expr.operand1.name) self.assertEqual("x", T.stats[0].expr.operand1.name)
self.assertEqual("x", T.body.expr.operand2.name) self.assertEqual("x", T.stats[0].expr.operand2.name)
self.assert_(T.body.expr.operand1 is not T.body.expr.operand2) self.assert_(T.stats[0].expr.operand1 is not T.stats[0].expr.operand2)
def test_substitution(self): def test_substitution(self):
F = self.fragment(u"x = 4") F = self.fragment(u"x = 4")
...@@ -32,7 +34,7 @@ class TestTreeFragments(CythonTest): ...@@ -32,7 +34,7 @@ class TestTreeFragments(CythonTest):
F = self.fragment(u"PASS") F = self.fragment(u"PASS")
pass_stat = PassStatNode(pos=None) pass_stat = PassStatNode(pos=None)
T = F.substitute({"PASS" : pass_stat}) T = F.substitute({"PASS" : pass_stat})
self.assert_(isinstance(T.body, PassStatNode), T.body) self.assert_(isinstance(T.stats[0], PassStatNode), T)
def test_pos_is_transferred(self): def test_pos_is_transferred(self):
F = self.fragment(u""" F = self.fragment(u"""
...@@ -40,11 +42,22 @@ class TestTreeFragments(CythonTest): ...@@ -40,11 +42,22 @@ class TestTreeFragments(CythonTest):
x = u * v ** w x = u * v ** w
""") """)
T = F.substitute({"v" : NameNode(pos=None, name="a")}) T = F.substitute({"v" : NameNode(pos=None, name="a")})
v = F.root.body.stats[1].rhs.operand2.operand1 v = F.root.stats[1].rhs.operand2.operand1
a = T.body.stats[1].rhs.operand2.operand1 a = T.stats[1].rhs.operand2.operand1
self.assertEquals(v.pos, a.pos) self.assertEquals(v.pos, a.pos)
def test_temps(self):
import Cython.Compiler.Visitor as v
v.tmpnamectr = 0
F = self.fragment(u"""
TMP
x = TMP
""")
T = F.substitute(temps=[u"TMP"])
s = T.stats
self.assert_(s[0].expr.name == Naming.temp_prefix + u"1_TMP", s[0].expr.name)
self.assert_(s[1].rhs.name == Naming.temp_prefix + u"1_TMP")
self.assert_(s[0].expr.name != u"TMP")
if __name__ == "__main__": if __name__ == "__main__":
import unittest import unittest
......
...@@ -6,8 +6,8 @@ import re ...@@ -6,8 +6,8 @@ import re
from cStringIO import StringIO from cStringIO import StringIO
from Scanning import PyrexScanner, StringSourceDescriptor from Scanning import PyrexScanner, StringSourceDescriptor
from Symtab import BuiltinScope, ModuleScope from Symtab import BuiltinScope, ModuleScope
from Visitor import VisitorTransform from Visitor import VisitorTransform, temp_name_handle
from Nodes import Node from Nodes import Node, StatListNode
from ExprNodes import NameNode from ExprNodes import NameNode
import Parsing import Parsing
import Main import Main
...@@ -53,7 +53,7 @@ def parse_from_strings(name, code, pxds={}): ...@@ -53,7 +53,7 @@ def parse_from_strings(name, code, pxds={}):
buf = StringIO(code.encode(encoding)) buf = StringIO(code.encode(encoding))
scanner = PyrexScanner(buf, code_source, source_encoding = encoding, scanner = PyrexScanner(buf, code_source, source_encoding = encoding,
type_names = scope.type_names, context = context) scope = scope, context = context)
tree = Parsing.p_module(scanner, 0, module_name) tree = Parsing.p_module(scanner, 0, module_name)
return tree return tree
...@@ -92,27 +92,56 @@ class TemplateTransform(VisitorTransform): ...@@ -92,27 +92,56 @@ class TemplateTransform(VisitorTransform):
if its name is listed in the substitutions dictionary in the if its name is listed in the substitutions dictionary in the
same way. It is the responsibility of the caller to make sure same way. It is the responsibility of the caller to make sure
that the replacement nodes is a valid expression. that the replacement nodes is a valid expression.
Also a list "temps" should be passed. Any names listed will
be transformed into anonymous, temporary names.
Currently supported for tempnames is:
NameNode
(various function and class definition nodes etc. should be added to this)
Each replacement node gets the position of the substituted node Each replacement node gets the position of the substituted node
recursively applied to every member node. recursively applied to every member node.
""" """
def __call__(self, node, substitutions, temps, pos):
self.substitutions = substitutions
tempdict = {}
for key in temps:
tempdict[key] = temp_name_handle(key)
self.temps = tempdict
self.pos = pos
return super(TemplateTransform, self).__call__(node)
def visit_Node(self, node): def visit_Node(self, node):
if node is None: if node is None:
return node return None
else: else:
c = node.clone_node() c = node.clone_node()
if self.pos is not None:
c.pos = self.pos
self.visitchildren(c) self.visitchildren(c)
return c return c
def try_substitution(self, node, key): def try_substitution(self, node, key):
sub = self.substitutions.get(key) sub = self.substitutions.get(key)
if sub is None: if sub is not None:
return self.visit_Node(node) # make copy as usual pos = self.pos
if pos is None: pos = node.pos
return ApplyPositionAndCopy(pos)(sub)
else: else:
return ApplyPositionAndCopy(node.pos)(sub) return self.visit_Node(node) # make copy as usual
def visit_NameNode(self, node): def visit_NameNode(self, node):
return self.try_substitution(node, node.name) tempname = self.temps.get(node.name)
if tempname is not None:
# Replace name with temporary
node.name = tempname
return self.visit_Node(node)
else:
return self.try_substitution(node, node.name)
def visit_ExprStatNode(self, node): def visit_ExprStatNode(self, node):
# If an expression-as-statement consists of only a replaceable # If an expression-as-statement consists of only a replaceable
...@@ -122,10 +151,6 @@ class TemplateTransform(VisitorTransform): ...@@ -122,10 +151,6 @@ class TemplateTransform(VisitorTransform):
else: else:
return self.visit_Node(node) return self.visit_Node(node)
def __call__(self, node, substitutions):
self.substitutions = substitutions
return super(TemplateTransform, self).__call__(node)
def copy_code_tree(node): def copy_code_tree(node):
return TreeCopier()(node) return TreeCopier()(node)
...@@ -133,12 +158,12 @@ INDENT_RE = re.compile(ur"^ *") ...@@ -133,12 +158,12 @@ INDENT_RE = re.compile(ur"^ *")
def strip_common_indent(lines): def strip_common_indent(lines):
"Strips empty lines and common indentation from the list of strings given in lines" "Strips empty lines and common indentation from the list of strings given in lines"
lines = [x for x in lines if x.strip() != u""] lines = [x for x in lines if x.strip() != u""]
minindent = min(len(INDENT_RE.match(x).group(0)) for x in lines) minindent = min([len(INDENT_RE.match(x).group(0)) for x in lines])
lines = [x[minindent:] for x in lines] lines = [x[minindent:] for x in lines]
return lines return lines
class TreeFragment(object): class TreeFragment(object):
def __init__(self, code, name, pxds={}): def __init__(self, code, name="(tree fragment)", pxds={}, temps=[], pipeline=[]):
if isinstance(code, unicode): if isinstance(code, unicode):
def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n"))) def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n")))
...@@ -147,18 +172,28 @@ class TreeFragment(object): ...@@ -147,18 +172,28 @@ class TreeFragment(object):
for key, value in pxds.iteritems(): for key, value in pxds.iteritems():
fmt_pxds[key] = fmt(value) fmt_pxds[key] = fmt(value)
self.root = parse_from_strings(name, fmt_code, fmt_pxds) t = parse_from_strings(name, fmt_code, fmt_pxds)
mod = t
t = t.body # Make sure a StatListNode is at the top
if not isinstance(t, StatListNode):
t = StatListNode(pos=mod.pos, stats=[t])
for transform in pipeline:
t = transform(t)
self.root = t
elif isinstance(code, Node): elif isinstance(code, Node):
if pxds != {}: raise NotImplementedError() if pxds != {}: raise NotImplementedError()
self.root = code self.root = code
else: else:
raise ValueError("Unrecognized code format (accepts unicode and Node)") raise ValueError("Unrecognized code format (accepts unicode and Node)")
self.temps = temps
def copy(self): def copy(self):
return copy_code_tree(self.root) return copy_code_tree(self.root)
def substitute(self, nodes={}): def substitute(self, nodes={}, temps=[], pos = None):
return TemplateTransform()(self.root, substitutions = nodes) return TemplateTransform()(self.root,
substitutions = nodes,
temps = self.temps + temps, pos = pos)
......
# #
# Tree visitor and transform framework # Tree visitor and transform framework
# #
import inspect
import Nodes import Nodes
import ExprNodes import ExprNodes
import inspect import Naming
from Cython.Utils import EncodedString
class BasicVisitor(object): class BasicVisitor(object):
"""A generic visitor base class which can be used for visiting any kind of object.""" """A generic visitor base class which can be used for visiting any kind of object."""
...@@ -129,7 +131,6 @@ class VisitorTransform(TreeVisitor): ...@@ -129,7 +131,6 @@ class VisitorTransform(TreeVisitor):
was not, an exception will be raised. (Typically you want to ensure that you was not, an exception will be raised. (Typically you want to ensure that you
are within a StatListNode or similar before doing this.) are within a StatListNode or similar before doing this.)
""" """
def visitchildren(self, parent, attrs=None): def visitchildren(self, parent, attrs=None):
result = super(VisitorTransform, self).visitchildren(parent, attrs) result = super(VisitorTransform, self).visitchildren(parent, attrs)
for attr, newnode in result.iteritems(): for attr, newnode in result.iteritems():
...@@ -150,6 +151,19 @@ class VisitorTransform(TreeVisitor): ...@@ -150,6 +151,19 @@ class VisitorTransform(TreeVisitor):
def __call__(self, root): def __call__(self, root):
return self.visit(root) return self.visit(root)
class CythonTransform(VisitorTransform):
"""
Certain common conventions and utilitues for Cython transforms.
"""
def __init__(self, context):
super(CythonTransform, self).__init__()
self.context = context
def visit_Node(self, node):
self.visitchildren(node)
return node
# Utils # Utils
def ensure_statlist(node): def ensure_statlist(node):
if not isinstance(node, Nodes.StatListNode): if not isinstance(node, Nodes.StatListNode):
...@@ -166,6 +180,19 @@ def replace_node(ptr, value): ...@@ -166,6 +180,19 @@ def replace_node(ptr, value):
else: else:
getattr(parent, attrname)[listidx] = value getattr(parent, attrname)[listidx] = value
tmpnamectr = 0
def temp_name_handle(description):
global tmpnamectr
tmpnamectr += 1
return EncodedString(Naming.temp_prefix + u"%d_%s" % (tmpnamectr, description))
def get_temp_name_handle_desc(handle):
if not handle.startswith(u"__cyt_"):
return None
else:
idx = handle.find(u"_", 6)
return handle[idx+1:]
class PrintTree(TreeVisitor): class PrintTree(TreeVisitor):
"""Prints a representation of the tree to standard output. """Prints a representation of the tree to standard output.
Subclass and override repr_of to provide more information Subclass and override repr_of to provide more information
......
...@@ -4,8 +4,49 @@ import unittest ...@@ -4,8 +4,49 @@ import unittest
from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.ModuleNode import ModuleNode
import Cython.Compiler.Main as Main import Cython.Compiler.Main as Main
from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
from Cython.Compiler.Visitor import TreeVisitor
class NodeTypeWriter(TreeVisitor):
def __init__(self):
super(NodeTypeWriter, self).__init__()
self._indents = 0
self.result = []
def visit_Node(self, node):
if len(self.access_path) == 0:
name = u"(root)"
else:
tip = self.access_path[-1]
if tip[2] is not None:
name = u"%s[%d]" % tip[1:3]
else:
name = tip[1]
self.result.append(u" " * self._indents +
u"%s: %s" % (name, node.__class__.__name__))
self._indents += 1
self.visitchildren(node)
self._indents -= 1
def treetypes(root):
"""Returns a string representing the tree by class names.
There's a leading and trailing whitespace so that it can be
compared by simple string comparison while still making test
cases look ok."""
w = NodeTypeWriter()
w.visit(root)
return u"\n".join([u""] + w.result + [u""])
class CythonTest(unittest.TestCase): class CythonTest(unittest.TestCase):
def assertLines(self, expected, result):
"Checks that the given strings or lists of strings are equal line by line"
if not isinstance(expected, list): expected = expected.split(u"\n")
if not isinstance(result, list): result = result.split(u"\n")
for idx, (expected_line, result_line) in enumerate(zip(expected, result)):
self.assertEqual(expected_line, result_line, "Line %d:\nExp: %s\nGot: %s" % (idx, expected_line, result_line))
self.assertEqual(len(expected), len(result),
"Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected), u"\n".join(result)))
def assertCode(self, expected, result_tree): def assertCode(self, expected, result_tree):
writer = CodeWriter() writer = CodeWriter()
writer.write(result_tree) writer.write(result_tree)
...@@ -18,13 +59,35 @@ class CythonTest(unittest.TestCase): ...@@ -18,13 +59,35 @@ class CythonTest(unittest.TestCase):
self.assertEqual(len(result_lines), len(expected_lines), self.assertEqual(len(result_lines), len(expected_lines),
"Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected)) "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))
def fragment(self, code, pxds={}): def fragment(self, code, pxds={}, pipeline=[]):
"Simply create a tree fragment using the name of the test-case in parse errors." "Simply create a tree fragment using the name of the test-case in parse errors."
name = self.id() name = self.id()
if name.startswith("__main__."): name = name[len("__main__."):] if name.startswith("__main__."): name = name[len("__main__."):]
name = name.replace(".", "_") name = name.replace(".", "_")
return TreeFragment(code, name, pxds) return TreeFragment(code, name, pxds, pipeline=pipeline)
def treetypes(self, root):
return treetypes(root)
def should_fail(self, func, exc_type=Exception):
"""Calls "func" and fails if it doesn't raise the right exception
(any exception by default). Also returns the exception in question.
"""
try:
func()
self.fail("Expected an exception of type %r" % exc_type)
except exc_type, e:
self.assert_(isinstance(e, exc_type))
return e
def should_not_fail(self, func):
"""Calls func and succeeds if and only if no exception is raised
(i.e. converts exception raising into a failed testcase). Returns
the return value of func."""
try:
return func()
except:
self.fail()
class TransformTest(CythonTest): class TransformTest(CythonTest):
""" """
...@@ -37,8 +100,8 @@ class TransformTest(CythonTest): ...@@ -37,8 +100,8 @@ class TransformTest(CythonTest):
To create a test case: To create a test case:
- Call run_pipeline. The pipeline should at least contain the transform you - Call run_pipeline. The pipeline should at least contain the transform you
are testing; pyx should be either a string (passed to the parser to are testing; pyx should be either a string (passed to the parser to
create a post-parse tree) or a ModuleNode representing input to pipeline. create a post-parse tree) or a node representing input to pipeline.
The result will be a transformed result (usually a ModuleNode). The result will be a transformed result.
- Check that the tree is correct. If wanted, assertCode can be used, which - Check that the tree is correct. If wanted, assertCode can be used, which
takes a code string as expected, and a ModuleNode in result_tree takes a code string as expected, and a ModuleNode in result_tree
...@@ -53,7 +116,6 @@ class TransformTest(CythonTest): ...@@ -53,7 +116,6 @@ class TransformTest(CythonTest):
def run_pipeline(self, pipeline, pyx, pxds={}): def run_pipeline(self, pipeline, pyx, pxds={}):
tree = self.fragment(pyx, pxds).root tree = self.fragment(pyx, pxds).root
assert isinstance(tree, ModuleNode)
# Run pipeline # Run pipeline
for T in pipeline: for T in pipeline:
tree = T(tree) tree = T(tree)
......
...@@ -72,6 +72,9 @@ class TestCodeWriter(CythonTest): ...@@ -72,6 +72,9 @@ class TestCodeWriter(CythonTest):
def test_inplace_assignment(self): def test_inplace_assignment(self):
self.t(u"x += 43") self.t(u"x += 43")
def test_attribute(self):
self.t(u"a.x")
if __name__ == "__main__": if __name__ == "__main__":
import unittest import unittest
......
cdef extern from "Python.h":
ctypedef struct PyObject
ctypedef struct Py_buffer:
void *buf
Py_ssize_t len
int readonly
char *format
int ndim
Py_ssize_t *shape
Py_ssize_t *strides
Py_ssize_t *suboffsets
Py_ssize_t itemsize
void *internal
int PyObject_GetBuffer(PyObject* obj, Py_buffer* view, int flags) except -1
void PyObject_ReleaseBuffer(PyObject* obj, Py_buffer* view)
# int PyObject_GetBuffer(PyObject *obj, Py_buffer *view,
# int flags)
...@@ -238,6 +238,29 @@ class CythonRunTestCase(CythonCompileTestCase): ...@@ -238,6 +238,29 @@ class CythonRunTestCase(CythonCompileTestCase):
except Exception: except Exception:
pass pass
def collect_unittests(path, suite, selectors):
def file_matches(filename):
return filename.startswith("Test") and filename.endswith(".py")
def package_matches(dirname):
return dirname == "Tests"
loader = unittest.TestLoader()
for dirpath, dirnames, filenames in os.walk(path):
parentname = os.path.split(dirpath)[-1]
if package_matches(parentname):
for f in filenames:
if file_matches(f):
filepath = os.path.join(dirpath, f)[:-len(".py")]
modulename = filepath[len(path)+1:].replace(os.path.sep, '.')
if not [ 1 for match in selectors if match(modulename) ]:
continue
module = __import__(modulename)
for x in modulename.split('.')[1:]:
module = getattr(module, x)
suite.addTests(loader.loadTestsFromModule(module))
if __name__ == '__main__': if __name__ == '__main__':
from optparse import OptionParser from optparse import OptionParser
parser = OptionParser() parser = OptionParser()
...@@ -247,6 +270,12 @@ if __name__ == '__main__': ...@@ -247,6 +270,12 @@ if __name__ == '__main__':
parser.add_option("--no-cython", dest="with_cython", parser.add_option("--no-cython", dest="with_cython",
action="store_false", default=True, action="store_false", default=True,
help="do not run the Cython compiler, only the C compiler") help="do not run the Cython compiler, only the C compiler")
parser.add_option("--no-unit", dest="unittests",
action="store_false", default=True,
help="do not run the unit tests")
parser.add_option("--no-file", dest="filetests",
action="store_false", default=True,
help="do not run the file based tests")
parser.add_option("-C", "--coverage", dest="coverage", parser.add_option("-C", "--coverage", dest="coverage",
action="store_true", default=False, action="store_true", default=False,
help="collect source coverage data for the Compiler") help="collect source coverage data for the Compiler")
...@@ -296,9 +325,15 @@ if __name__ == '__main__': ...@@ -296,9 +325,15 @@ if __name__ == '__main__':
if not selectors: if not selectors:
selectors = [ lambda x:True ] selectors = [ lambda x:True ]
tests = TestBuilder(ROOTDIR, WORKDIR, selectors, test_suite = unittest.TestSuite()
options.annotate_source, options.cleanup_workdir)
test_suite = tests.build_suite() if options.unittests:
collect_unittests(os.getcwd(), test_suite, selectors)
if options.filetests:
filetests = TestBuilder(ROOTDIR, WORKDIR, selectors,
options.annotate_source, options.cleanup_workdir)
test_suite.addTests(filetests.build_suite())
unittest.TextTestRunner(verbosity=options.verbosity).run(test_suite) unittest.TextTestRunner(verbosity=options.verbosity).run(test_suite)
......
from __future__ import with_statement
__doc__ = u"""
>>> no_as()
enter
hello
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
>>> basic()
enter
value
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
>>> with_exception(None)
enter
value
exit <type 'type'> <class 'withstat.MyException'> <type 'traceback'>
outer except
>>> with_exception(True)
enter
value
exit <type 'type'> <class 'withstat.MyException'> <type 'traceback'>
>>> multitarget()
enter
1 2 3 4 5
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
>>> tupletarget()
enter
(1, 2, (3, (4, 5)))
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
>>> typed()
enter
10
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
"""
class MyException(Exception):
pass
class ContextManager:
def __init__(self, value, exit_ret = None):
self.value = value
self.exit_ret = exit_ret
def __exit__(self, a, b, tb):
print "exit", type(a), type(b), type(tb)
return self.exit_ret
def __enter__(self):
print "enter"
return self.value
def no_as():
with ContextManager("value"):
print "hello"
def basic():
with ContextManager("value") as x:
print x
def with_exception(exit_ret):
try:
with ContextManager("value", exit_ret=exit_ret) as value:
print value
raise MyException()
except:
print "outer except"
def multitarget():
with ContextManager((1, 2, (3, (4, 5)))) as (a, b, (c, (d, e))):
print a, b, c, d, e
def tupletarget():
with ContextManager((1, 2, (3, (4, 5)))) as t:
print t
def typed():
cdef unsigned char i
c = ContextManager(255)
with c as i:
i += 11
print i
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment