Commit 23922c39 authored by Stefan Behnel's avatar Stefan Behnel

merge

parents 0510cac9 920a8f1a
from Cython.Compiler.Visitor import TreeVisitor from Cython.Compiler.Visitor import TreeVisitor, get_temp_name_handle_desc
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
""" """
Serializes a Cython code tree to Cython code. This is primarily useful for Serializes a Cython code tree to Cython code. This is primarily useful for
...@@ -35,6 +36,7 @@ class CodeWriter(TreeVisitor): ...@@ -35,6 +36,7 @@ class CodeWriter(TreeVisitor):
result = LinesResult() result = LinesResult()
self.result = result self.result = result
self.numindents = 0 self.numindents = 0
self.tempnames = {}
def write(self, tree): def write(self, tree):
self.visit(tree) self.visit(tree)
...@@ -58,6 +60,12 @@ class CodeWriter(TreeVisitor): ...@@ -58,6 +60,12 @@ class CodeWriter(TreeVisitor):
self.startline(s) self.startline(s)
self.endline() self.endline()
def putname(self, name):
tmpdesc = get_temp_name_handle_desc(name)
if tmpdesc is not None:
name = self.tempnames.setdefault(tmpdesc, u"$" +tmpdesc)
self.put(name)
def comma_seperated_list(self, items, output_rhs=False): def comma_seperated_list(self, items, output_rhs=False):
if len(items) > 0: if len(items) > 0:
for item in items[:-1]: for item in items[:-1]:
...@@ -116,7 +124,7 @@ class CodeWriter(TreeVisitor): ...@@ -116,7 +124,7 @@ class CodeWriter(TreeVisitor):
self.endline() self.endline()
def visit_NameNode(self, node): def visit_NameNode(self, node):
self.put(node.name) self.putname(node.name)
def visit_IntNode(self, node): def visit_IntNode(self, node):
self.put(node.value) self.put(node.value)
...@@ -185,10 +193,23 @@ class CodeWriter(TreeVisitor): ...@@ -185,10 +193,23 @@ class CodeWriter(TreeVisitor):
self.comma_seperated_list(node.args) # Might need to discover whether we need () around tuples...hmm... self.comma_seperated_list(node.args) # Might need to discover whether we need () around tuples...hmm...
def visit_SimpleCallNode(self, node): def visit_SimpleCallNode(self, node):
self.put(node.function.name + u"(") self.visit(node.function)
self.put(u"(")
self.comma_seperated_list(node.args) self.comma_seperated_list(node.args)
self.put(")") self.put(")")
def visit_GeneralCallNode(self, node):
self.visit(node.function)
self.put(u"(")
posarg = node.positional_args
if isinstance(posarg, AsTupleNode):
self.visit(posarg.arg)
else:
self.comma_seperated_list(posarg)
if node.keyword_args is not None or node.starstar_arg is not None:
raise Exception("Not implemented yet")
self.put(u")")
def visit_ExprStatNode(self, node): def visit_ExprStatNode(self, node):
self.startline() self.startline()
self.visit(node.expr) self.visit(node.expr)
...@@ -197,9 +218,83 @@ class CodeWriter(TreeVisitor): ...@@ -197,9 +218,83 @@ class CodeWriter(TreeVisitor):
def visit_InPlaceAssignmentNode(self, node): def visit_InPlaceAssignmentNode(self, node):
self.startline() self.startline()
self.visit(node.lhs) self.visit(node.lhs)
self.put(" %s= " % node.operator) self.put(u" %s= " % node.operator)
self.visit(node.rhs) self.visit(node.rhs)
self.endline() self.endline()
def visit_WithStatNode(self, node):
self.startline()
self.put(u"with ")
self.visit(node.manager)
if node.target is not None:
self.put(u" as ")
self.visit(node.target)
self.endline(u":")
self.indent()
self.visit(node.body)
self.dedent()
def visit_AttributeNode(self, node):
self.visit(node.obj)
self.put(u".%s" % node.attribute)
def visit_BoolNode(self, node):
self.put(str(node.value))
def visit_TryFinallyStatNode(self, node):
self.line(u"try:")
self.indent()
self.visit(node.body)
self.dedent()
self.line(u"finally:")
self.indent()
self.visit(node.finally_clause)
self.dedent()
def visit_TryExceptStatNode(self, node):
self.line(u"try:")
self.indent()
self.visit(node.body)
self.dedent()
for x in node.except_clauses:
self.visit(x)
if node.else_clause is not None:
self.visit(node.else_clause)
def visit_ExceptClauseNode(self, node):
self.startline(u"except")
if node.pattern is not None:
self.put(u" ")
self.visit(node.pattern)
if node.target is not None:
self.put(u", ")
self.visit(node.target)
self.endline(":")
self.indent()
self.visit(node.body)
self.dedent()
def visit_ReturnStatNode(self, node):
self.startline("return ")
self.visit(node.value)
self.endline()
def visit_DecoratorNode(self, node):
self.startline("@")
self.visit(node.decorator)
self.endline()
def visit_ReraiseStatNode(self, node):
self.line("raise")
def visit_NoneNode(self, node):
self.put(u"None")
def visit_ImportNode(self, node):
self.put(u"(import %s)" % node.module_name.value)
def visit_NotNode(self, node):
self.put(u"(not ")
self.visit(node.operand)
self.put(u")")
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Utils import EncodedString
from Cython.Compiler.Errors import CompileError
import PyrexTypes
from sets import Set as set
class PureCFuncNode(Node):
def __init__(self, pos, cname, type, c_code, visibility='private'):
self.pos = pos
self.cname = cname
self.type = type
self.c_code = c_code
self.visibility = visibility
def analyse_types(self, env):
self.entry = env.declare_cfunction(
"<pure c function:%s>" % self.cname,
self.type, self.pos, cname=self.cname,
defining=True, visibility=self.visibility)
def generate_function_definitions(self, env, code, transforms):
assert self.type.optional_arg_count == 0
visibility = self.entry.visibility
if visibility != 'private':
storage_class = "%s " % Naming.extern_c_macro
else:
storage_class = "static "
arg_decls = [arg.declaration_code() for arg in self.type.args]
sig = self.type.return_type.declaration_code(
self.type.function_header_code(self.cname, ", ".join(arg_decls)))
code.putln("")
code.putln("%s%s {" % (storage_class, sig))
code.put(self.c_code)
code.putln("}")
def generate_execution_code(self, code):
pass
tschecker_functype = PyrexTypes.CFuncType(
PyrexTypes.c_char_ptr_type,
[PyrexTypes.CFuncTypeArg(EncodedString("ts"), PyrexTypes.c_char_ptr_type,
(0, 0, None), cname="ts")],
exception_value = "NULL"
)
tsprefix = "__Pyx_tsc"
class BufferTransform(CythonTransform):
"""
Run after type analysis. Takes care of the buffer functionality.
Expects to be run on the full module. If you need to process a fragment
one should look into refactoring this transform.
"""
# Abbreviations:
# "ts" means typestring and/or typestring checking stuff
scope = None
#
# Entry point
#
def __call__(self, node):
assert isinstance(node, ModuleNode)
try:
cymod = self.context.modules[u'__cython__']
except KeyError:
# No buffer fun for this module
return node
self.bufstruct_type = cymod.entries[u'Py_buffer'].type
self.tscheckers = {}
self.ts_funcs = []
self.ts_item_checkers = {}
self.module_scope = node.scope
self.module_pos = node.pos
result = super(BufferTransform, self).__call__(node)
# Register ts stuff
if "endian.h" not in node.scope.include_files:
node.scope.include_files.append("endian.h")
result.body.stats += self.ts_funcs
return result
#
# Basic operations for transforms
#
def handle_scope(self, node, scope):
# For all buffers, insert extra variables in the scope.
# The variables are also accessible from the buffer_info
# on the buffer entry
bufvars = [(name, entry) for name, entry
in scope.entries.iteritems()
if entry.type.buffer_options is not None]
for name, entry in bufvars:
bufopts = entry.type.buffer_options
# Get or make a type string checker
tschecker = self.tschecker(bufopts.dtype)
# Declare auxiliary vars
bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name),
self.bufstruct_type, node.pos)
temp_var = scope.declare_var(temp_name_handle(u"%s_tmp" % name),
entry.type, node.pos)
stridevars = []
shapevars = []
for idx in range(bufopts.ndim):
# stride
varname = temp_name_handle(u"%s_%s%d" % (name, "stride", idx))
var = scope.declare_var(varname, PyrexTypes.c_int_type, node.pos, is_cdef=True)
stridevars.append(var)
# shape
varname = temp_name_handle(u"%s_%s%d" % (name, "shape", idx))
var = scope.declare_var(varname, PyrexTypes.c_uint_type, node.pos, is_cdef=True)
shapevars.append(var)
entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars,
shapevars, tschecker)
entry.buffer_aux.temp_var = temp_var
self.scope = scope
# Notes: The cast to <char*> gets around Cython not supporting const types
acquire_buffer_fragment = TreeFragment(u"""
TMP = LHS
if TMP is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
TMP = RHS
if TMP is not None:
__cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
TSCHECKER(<char*>BUFINFO.format)
ASSIGN_AUX
LHS = TMP
""")
fetch_strides = TreeFragment(u"""
TARGET = BUFINFO.strides[IDX]
""")
fetch_shape = TreeFragment(u"""
TARGET = BUFINFO.shape[IDX]
""")
def reacquire_buffer(self, node):
bufaux = node.lhs.entry.buffer_aux
auxass = []
for idx, entry in enumerate(bufaux.stridevars):
entry.used = True
ass = self.fetch_strides.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx)),
})
auxass.append(ass)
for idx, entry in enumerate(bufaux.shapevars):
entry.used = True
ass = self.fetch_shape.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx))
})
auxass.append(ass)
bufaux.buffer_info_var.used = True
acq = self.acquire_buffer_fragment.substitute({
u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
u"LHS" : node.lhs,
u"RHS": node.rhs,
u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name),
u"TSCHECKER": NameNode(node.pos, name=bufaux.tschecker.name)
}, pos=node.pos)
# Note: The below should probably be refactored into something
# like fragment.substitute(..., context=self.context), with
# TreeFragment getting context.pipeline_until_now() and
# applying it on the fragment.
acq.analyse_declarations(self.scope)
acq.analyse_expressions(self.scope)
stats = acq.stats
return stats
def assign_into_buffer(self, node):
result = SingleAssignmentNode(node.pos,
rhs=self.visit(node.rhs),
lhs=self.buffer_index(node.lhs))
result.analyse_expressions(self.scope)
return result
def buffer_index(self, node):
pos = node.pos
bufaux = node.base.entry.buffer_aux
assert bufaux is not None
# indices * strides...
to_sum = [ IntBinopNode(pos, operator='*',
operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index),
operand2=NameNode(node.pos, name=stride.name))
for index, stride in zip(node.indices, bufaux.stridevars)]
# then sum them with the buffer pointer
expr = AttributeNode(pos,
obj=NameNode(pos, name=bufaux.buffer_info_var.name),
attribute=EncodedString("buf"))
for next in to_sum:
expr = AddNode(pos, operator='+', operand1=expr, operand2=next)
casted = TypecastNode(pos, operand=expr,
type=PyrexTypes.c_ptr_type(node.base.entry.type.buffer_options.dtype))
result = IndexNode(pos, base=casted, index=IntNode(pos, value='0'))
return result
#
# Transforms
#
def visit_ModuleNode(self, node):
self.handle_scope(node, node.scope)
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
self.handle_scope(node, node.local_scope)
self.visitchildren(node)
return node
def visit_SingleAssignmentNode(self, node):
# On assignments, two buffer-related things can happen:
# a) A buffer variable is assigned to (reacquisition)
# b) Buffer access assignment: arr[...] = ...
# Since we don't allow nested buffers, these don't overlap.
self.visitchildren(node)
# Only acquire buffers on vars (not attributes) for now.
if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
# Is buffer variable
return self.reacquire_buffer(node)
elif (isinstance(node.lhs, IndexNode) and
isinstance(node.lhs.base, NameNode) and
node.lhs.base.entry.buffer_aux is not None):
return self.assign_into_buffer(node)
else:
return node
def visit_IndexNode(self, node):
# Only occurs when the IndexNode is an rvalue
if node.is_buffer_access:
assert node.index is None
assert node.indices is not None
result = self.buffer_index(node)
result.analyse_expressions(self.scope)
return result
else:
return node
#
# Utils for creating type string checkers
#
def new_ts_func(self, name, code):
cname = "%s_%s" % (tsprefix, name)
funcnode = PureCFuncNode(self.module_pos, cname, tschecker_functype, code)
funcnode.analyse_types(self.module_scope)
self.ts_funcs.append(funcnode)
return funcnode
def mangle_dtype_name(self, dtype):
# Use prefixes to seperate user defined types from builtins
# (consider "typedef float unsigned_int")
return dtype.declaration_code("").replace(" ", "_")
def get_ts_check_item(self, dtype):
# See if we can consume one (unnamed) dtype as next item
funcnode = self.ts_item_checkers.get(dtype)
if funcnode is None:
char = dtype.typestring
if char is not None and len(char) > 1:
# Can use direct comparison
funcnode = self.new_ts_func("natitem_%s" % self.mangle_dtype_name(dtype), """\
if (*ts != '%s') {
PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch (rejecting on '%%s')", ts);
return NULL;
} else return ts + 1;
""" % char)
else:
# Must deduce sign and length; rely on int vs. float to be correctly declared
ctype = dtype.declaration_code("")
code = """\
int ok;
switch (*ts) {"""
if dtype.is_int:
types = [
('b', 'char'), ('h', 'short'), ('i', 'int'),
('l', 'long'), ('q', 'long long')
]
code += "".join(["""\
case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;
case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;
""" % (char, ctype, against, ctype, char.upper(), ctype, "unsigned " + against, ctype) for
char, against in types])
code += """\
default: ok = 0;
}
if (!ok) {
PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch (rejecting on '%s')", ts);
return NULL;
} else return ts + 1;
"""
funcnode = self.new_ts_func("tdefitem_%s" % self.mangle_dtype_name(dtype), code)
self.ts_item_checkers[dtype] = funcnode
return funcnode.entry.cname
ts_consume_whitespace_cname = None
ts_check_endian_cname = None
def ensure_ts_utils(self):
# Makes sure that the typechecker utils are in scope
# (and constructs them if not)
if self.ts_consume_whitespace_cname is None:
self.ts_consume_whitespace_cname = self.new_ts_func("consume_whitespace", """\
while (1) {
switch (*ts) {
case 10:
case 13:
case ' ':
++ts;
default:
return ts;
}
}
""").entry.cname
if self.ts_check_endian_cname is None:
self.ts_check_endian_cname = self.new_ts_func("check_endian", """\
int ok = 1;
switch (*ts) {
case '@':
case '=':
++ts; break;
case '<':
if (__BYTE_ORDER == __LITTLE_ENDIAN) ++ts;
else ok = 0;
break;
case '>':
case '!':
if (__BYTE_ORDER == __BIG_ENDIAN) ++ts;
else ok = 0;
break;
}
if (!ok) {
PyErr_Format(PyExc_TypeError, "Data has wrong endianness (rejecting on '%s')", ts);
return NULL;
}
return ts;
""").entry.cname
def create_ts_check_simple(self, dtype):
# Check whole string for single unnamed item
consume_whitespace = self.ts_consume_whitespace_cname
check_endian = self.ts_check_endian_cname
check_item = self.get_ts_check_item(dtype)
return self.new_ts_func("simple_%s" % self.mangle_dtype_name(dtype), """\
ts = %(consume_whitespace)s(ts);
ts = %(check_endian)s(ts);
if (!ts) return NULL;
ts = %(consume_whitespace)s(ts);
ts = %(check_item)s(ts);
if (!ts) return NULL;
ts = %(consume_whitespace)s(ts);
if (*ts != 0) {
PyErr_Format(PyExc_TypeError, "Data too long (rejecting on '%%s')", ts);
return NULL;
}
return ts;
""" % locals())
def tschecker(self, dtype):
# Creates a type string checker function for the given type.
# Each checker is created as a function entry in the module scope
# and a PureCNode and put in the self.ts_checkers dict.
# Also the entry is returned.
#
# TODO: __eq__ and __hash__ for types
self.ensure_ts_utils()
funcnode = self.tscheckers.get(dtype)
if funcnode is None:
if dtype.is_struct_or_union:
assert False
elif dtype.is_int or dtype.is_float:
# This includes simple typedef-ed types
funcnode = self.create_ts_check_simple(dtype)
else:
assert False
self.tscheckers[dtype] = funcnode
return funcnode.entry
# TODO:
# - buf must be NULL before getting new buffer
...@@ -64,12 +64,16 @@ class CCodeWriter: ...@@ -64,12 +64,16 @@ class CCodeWriter:
dl = code.count("{") - code.count("}") dl = code.count("{") - code.count("}")
if dl < 0: if dl < 0:
self.level += dl self.level += dl
elif dl == 0 and code.startswith('}'):
self.level -= 1
if self.bol: if self.bol:
self.indent() self.indent()
self._write(code) self._write(code)
self.bol = 0 self.bol = 0
if dl > 0: if dl > 0:
self.level += dl self.level += dl
elif dl == 0 and code.startswith('}'):
self.level += 1
def increase_indent(self): def increase_indent(self):
self.level = self.level + 1 self.level = self.level + 1
...@@ -200,6 +204,8 @@ class CCodeWriter: ...@@ -200,6 +204,8 @@ class CCodeWriter:
def put_var_declaration(self, entry, static = 0, dll_linkage = None, def put_var_declaration(self, entry, static = 0, dll_linkage = None,
definition = True): definition = True):
#print "Code.put_var_declaration:", entry.name, "definition =", definition ### #print "Code.put_var_declaration:", entry.name, "definition =", definition ###
if entry.in_closure:
return
visibility = entry.visibility visibility = entry.visibility
if visibility == 'private' and not definition: if visibility == 'private' and not definition:
#print "...private and not definition, skipping" ### #print "...private and not definition, skipping" ###
......
...@@ -32,6 +32,7 @@ class CompileError(PyrexError): ...@@ -32,6 +32,7 @@ class CompileError(PyrexError):
def __init__(self, position = None, message = ""): def __init__(self, position = None, message = ""):
self.position = position self.position = position
self.message_only = message
# Deprecated and withdrawn in 2.6: # Deprecated and withdrawn in 2.6:
# self.message = message # self.message = message
if position: if position:
...@@ -91,6 +92,7 @@ def error(position, message): ...@@ -91,6 +92,7 @@ def error(position, message):
#print "Errors.error:", repr(position), repr(message) ### #print "Errors.error:", repr(position), repr(message) ###
global num_errors global num_errors
err = CompileError(position, message) err = CompileError(position, message)
# if position is not None: raise Exception(err) # debug
line = "%s\n" % err line = "%s\n" % err
if listing_file: if listing_file:
listing_file.write(line) listing_file.write(line)
......
...@@ -972,7 +972,7 @@ class NameNode(AtomicExprNode): ...@@ -972,7 +972,7 @@ class NameNode(AtomicExprNode):
if entry.is_builtin: if entry.is_builtin:
namespace = Naming.builtins_cname namespace = Naming.builtins_cname
else: # entry.is_pyglobal else: # entry.is_pyglobal
namespace = entry.namespace_cname namespace = entry.scope.namespace_cname
code.putln( code.putln(
'%s = __Pyx_GetName(%s, %s); %s' % ( '%s = __Pyx_GetName(%s, %s); %s' % (
self.result_code, self.result_code,
...@@ -997,7 +997,7 @@ class NameNode(AtomicExprNode): ...@@ -997,7 +997,7 @@ class NameNode(AtomicExprNode):
# is_pyglobal seems to be True for module level-globals only. # is_pyglobal seems to be True for module level-globals only.
# We use this to access class->tp_dict if necessary. # We use this to access class->tp_dict if necessary.
if entry.is_pyglobal: if entry.is_pyglobal:
namespace = self.entry.namespace_cname namespace = self.entry.scope.namespace_cname
if entry.is_member: if entry.is_member:
# if the entry is a member we have to cheat: SetAttr does not work # if the entry is a member we have to cheat: SetAttr does not work
# on types, so we create a descriptor which is then added to tp_dict # on types, so we create a descriptor which is then added to tp_dict
...@@ -1060,7 +1060,6 @@ class NameNode(AtomicExprNode): ...@@ -1060,7 +1060,6 @@ class NameNode(AtomicExprNode):
else: else:
code.annotate(pos, AnnotationItem('c_call', 'c function', size=len(self.name))) code.annotate(pos, AnnotationItem('c_call', 'c function', size=len(self.name)))
class BackquoteNode(ExprNode): class BackquoteNode(ExprNode):
# `expr` # `expr`
# #
...@@ -1146,16 +1145,22 @@ class IteratorNode(ExprNode): ...@@ -1146,16 +1145,22 @@ class IteratorNode(ExprNode):
def generate_result_code(self, code): def generate_result_code(self, code):
code.putln( code.putln(
"if (PyList_CheckExact(%s)) { %s = 0; %s = %s; Py_INCREF(%s); }" % ( "if (PyList_CheckExact(%s) || PyTuple_CheckExact(%s)) {" % (
self.sequence.py_result(), self.sequence.py_result(),
self.sequence.py_result()))
code.putln(
"%s = 0; %s = %s; Py_INCREF(%s);" % (
self.counter.result_code, self.counter.result_code,
self.result_code, self.result_code,
self.sequence.py_result(), self.sequence.py_result(),
self.result_code)) self.result_code))
code.putln("else { %s = PyObject_GetIter(%s); %s }" % ( code.putln("} else {")
code.putln("%s = -1; %s = PyObject_GetIter(%s); %s" % (
self.counter.result_code,
self.result_code, self.result_code,
self.sequence.py_result(), self.sequence.py_result(),
code.error_goto_if_null(self.result_code, self.pos))) code.error_goto_if_null(self.result_code, self.pos)))
code.putln("}")
class NextNode(AtomicExprNode): class NextNode(AtomicExprNode):
...@@ -1174,15 +1179,19 @@ class NextNode(AtomicExprNode): ...@@ -1174,15 +1179,19 @@ class NextNode(AtomicExprNode):
def generate_result_code(self, code): def generate_result_code(self, code):
code.putln( code.putln(
"if (PyList_CheckExact(%s)) { if (%s >= PyList_GET_SIZE(%s)) break; %s = PyList_GET_ITEM(%s, %s++); Py_INCREF(%s); }" % ( "if (likely(%s != -1)) {" % self.iterator.counter.result_code)
self.iterator.py_result(), code.putln(
"if (%s >= PySequence_Fast_GET_SIZE(%s)) break;" % (
self.iterator.counter.result_code, self.iterator.counter.result_code,
self.iterator.py_result(), self.iterator.py_result()))
code.putln(
"%s = PySequence_Fast_GET_ITEM(%s, %s); Py_INCREF(%s); %s++;" % (
self.result_code, self.result_code,
self.iterator.py_result(), self.iterator.py_result(),
self.iterator.counter.result_code, self.iterator.counter.result_code,
self.result_code)) self.result_code,
code.putln("else {") self.iterator.counter.result_code))
code.putln("} else {")
code.putln( code.putln(
"%s = PyIter_Next(%s);" % ( "%s = PyIter_Next(%s);" % (
self.result_code, self.result_code,
...@@ -1212,6 +1221,9 @@ class ExcValueNode(AtomicExprNode): ...@@ -1212,6 +1221,9 @@ class ExcValueNode(AtomicExprNode):
def generate_result_code(self, code): def generate_result_code(self, code):
pass pass
def analyse_types(self, env):
pass
class TempNode(AtomicExprNode): class TempNode(AtomicExprNode):
# Node created during analyse_types phase # Node created during analyse_types phase
...@@ -1249,8 +1261,19 @@ class IndexNode(ExprNode): ...@@ -1249,8 +1261,19 @@ class IndexNode(ExprNode):
# #
# base ExprNode # base ExprNode
# index ExprNode # index ExprNode
# indices [ExprNode]
# is_buffer_access boolean Whether this is a buffer access.
#
# indices is used on buffer access, index on non-buffer access.
# The former contains a clean list of index parameters, the
# latter whatever Python object is needed for index access.
subexprs = ['base', 'index', 'indices']
indices = None
subexprs = ['base', 'index'] def __init__(self, pos, index, *args, **kw):
ExprNode.__init__(self, pos, index=index, *args, **kw)
self._index = index
def compile_time_value(self, denv): def compile_time_value(self, denv):
base = self.base.compile_time_value(denv) base = self.base.compile_time_value(denv)
...@@ -1273,7 +1296,42 @@ class IndexNode(ExprNode): ...@@ -1273,7 +1296,42 @@ class IndexNode(ExprNode):
self.analyse_base_and_index_types(env, setting = 1) self.analyse_base_and_index_types(env, setting = 1)
def analyse_base_and_index_types(self, env, getting = 0, setting = 0): def analyse_base_and_index_types(self, env, getting = 0, setting = 0):
self.is_buffer_access = False
self.base.analyse_types(env) self.base.analyse_types(env)
skip_child_analysis = False
buffer_access = False
if self.base.type.buffer_options is not None:
if isinstance(self.index, TupleNode):
indices = self.index.args
else:
indices = [self.index]
if len(indices) == self.base.type.buffer_options.ndim:
buffer_access = True
skip_child_analysis = True
for x in indices:
x.analyse_types(env)
if not x.type.is_int:
buffer_access = False
if buffer_access:
# self.indices = [
# x.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env)
# for x in indices]
self.indices = indices
self.index = None
self.type = self.base.type.buffer_options.dtype
self.is_temp = 1
self.is_buffer_access = True
# Note: This might be cleaned up by having IndexNode
# parsed in a saner way and only construct the tuple if
# needed.
if not buffer_access:
if isinstance(self.index, TupleNode):
self.index.analyse_types(env, skip_children=skip_child_analysis)
elif not skip_child_analysis:
self.index.analyse_types(env) self.index.analyse_types(env)
if self.base.type.is_pyobject: if self.base.type.is_pyobject:
if self.index.type.is_int: if self.index.type.is_int:
...@@ -1314,6 +1372,9 @@ class IndexNode(ExprNode): ...@@ -1314,6 +1372,9 @@ class IndexNode(ExprNode):
return 1 return 1
def calculate_result_code(self): def calculate_result_code(self):
if self.is_buffer_access:
return "<not needed>"
else:
return "(%s[%s])" % ( return "(%s[%s])" % (
self.base.result_code, self.index.result_code) self.base.result_code, self.index.result_code)
...@@ -1328,11 +1389,17 @@ class IndexNode(ExprNode): ...@@ -1328,11 +1389,17 @@ class IndexNode(ExprNode):
def generate_subexpr_evaluation_code(self, code): def generate_subexpr_evaluation_code(self, code):
self.base.generate_evaluation_code(code) self.base.generate_evaluation_code(code)
if self.index is not None:
self.index.generate_evaluation_code(code) self.index.generate_evaluation_code(code)
else:
for i in self.indices: i.generate_evaluation_code(code)
def generate_subexpr_disposal_code(self, code): def generate_subexpr_disposal_code(self, code):
self.base.generate_disposal_code(code) self.base.generate_disposal_code(code)
if self.index is not None:
self.index.generate_disposal_code(code) self.index.generate_disposal_code(code)
else:
for i in self.indices: i.generate_disposal_code(code)
def generate_result_code(self, code): def generate_result_code(self, code):
if self.type.is_pyobject: if self.type.is_pyobject:
...@@ -1602,6 +1669,17 @@ class SimpleCallNode(CallNode): ...@@ -1602,6 +1669,17 @@ class SimpleCallNode(CallNode):
func_type = func_type.base_type func_type = func_type.base_type
return func_type return func_type
def exception_checks(self):
func_type = self.function.type
exc_val = func_type.exception_value
exc_check = func_type.exception_check
if exc_val is None and self.function.entry.visibility != 'extern':
return_type = func_type.return_type
if not return_type.is_struct_or_union and not return_type.is_void:
exc_val = return_type.cast_code(Naming.default_error)
exc_check = 1
return exc_val, exc_check
def analyse_c_function_call(self, env): def analyse_c_function_call(self, env):
func_type = self.function_type() func_type = self.function_type()
# Check function type # Check function type
...@@ -1643,12 +1721,13 @@ class SimpleCallNode(CallNode): ...@@ -1643,12 +1721,13 @@ class SimpleCallNode(CallNode):
"Python object cannot be passed as a varargs parameter") "Python object cannot be passed as a varargs parameter")
# Calc result type and code fragment # Calc result type and code fragment
self.type = func_type.return_type self.type = func_type.return_type
if self.type.is_pyobject \
or func_type.exception_value is not None \
or func_type.exception_check:
self.is_temp = 1
if self.type.is_pyobject: if self.type.is_pyobject:
self.is_temp = 1
self.result_ctype = py_object_type self.result_ctype = py_object_type
else:
exc_val, exc_check = self.exception_checks()
if self.type.is_pyobject or exc_val is not None or exc_check:
self.is_temp = 1
# C++ exception handler # C++ exception handler
if func_type.exception_check == '+': if func_type.exception_check == '+':
if func_type.exception_value is None: if func_type.exception_value is None:
...@@ -1713,8 +1792,7 @@ class SimpleCallNode(CallNode): ...@@ -1713,8 +1792,7 @@ class SimpleCallNode(CallNode):
if self.type.is_pyobject: if self.type.is_pyobject:
exc_checks.append("!%s" % self.result_code) exc_checks.append("!%s" % self.result_code)
else: else:
exc_val = func_type.exception_value exc_val, exc_check = self.exception_checks()
exc_check = func_type.exception_check
if exc_val is not None: if exc_val is not None:
exc_checks.append("%s == %s" % (self.result_code, exc_val)) exc_checks.append("%s == %s" % (self.result_code, exc_val))
if exc_check: if exc_check:
...@@ -2166,10 +2244,10 @@ class SequenceNode(ExprNode): ...@@ -2166,10 +2244,10 @@ class SequenceNode(ExprNode):
for arg in self.args: for arg in self.args:
arg.analyse_target_declaration(env) arg.analyse_target_declaration(env)
def analyse_types(self, env): def analyse_types(self, env, skip_children=False):
for i in range(len(self.args)): for i in range(len(self.args)):
arg = self.args[i] arg = self.args[i]
arg.analyse_types(env) if not skip_children: arg.analyse_types(env)
self.args[i] = arg.coerce_to_pyobject(env) self.args[i] = arg.coerce_to_pyobject(env)
self.type = py_object_type self.type = py_object_type
self.gil_check(env) self.gil_check(env)
...@@ -2274,12 +2352,12 @@ class TupleNode(SequenceNode): ...@@ -2274,12 +2352,12 @@ class TupleNode(SequenceNode):
gil_message = "Constructing Python tuple" gil_message = "Constructing Python tuple"
def analyse_types(self, env): def analyse_types(self, env, skip_children=False):
if len(self.args) == 0: if len(self.args) == 0:
self.is_temp = 0 self.is_temp = 0
self.is_literal = 1 self.is_literal = 1
else: else:
SequenceNode.analyse_types(self, env) SequenceNode.analyse_types(self, env, skip_children)
self.type = tuple_type self.type = tuple_type
def calculate_result_code(self): def calculate_result_code(self):
...@@ -2782,13 +2860,18 @@ def unop_node(pos, operator, operand): ...@@ -2782,13 +2860,18 @@ def unop_node(pos, operator, operand):
class TypecastNode(ExprNode): class TypecastNode(ExprNode):
# C type cast # C type cast
# #
# operand ExprNode
# base_type CBaseTypeNode # base_type CBaseTypeNode
# declarator CDeclaratorNode # declarator CDeclaratorNode
# operand ExprNode #
# If used from a transform, one can if wanted specify the attribute
# "type" directly and leave base_type and declarator to None
subexprs = ['operand'] subexprs = ['operand']
base_type = declarator = type = None
def analyse_types(self, env): def analyse_types(self, env):
if self.type is None:
base_type = self.base_type.analyse(env) base_type = self.base_type.analyse(env)
_, self.type = self.declarator.analyse(base_type, env) _, self.type = self.declarator.analyse(base_type, env)
if self.type.is_cfunction: if self.type.is_cfunction:
...@@ -3061,7 +3144,7 @@ class NumBinopNode(BinopNode): ...@@ -3061,7 +3144,7 @@ class NumBinopNode(BinopNode):
"+": "PyNumber_Add", "+": "PyNumber_Add",
"-": "PyNumber_Subtract", "-": "PyNumber_Subtract",
"*": "PyNumber_Multiply", "*": "PyNumber_Multiply",
"/": "PyNumber_Divide", "/": "__Pyx_PyNumber_Divide",
"//": "PyNumber_FloorDivide", "//": "PyNumber_FloorDivide",
"%": "PyNumber_Remainder", "%": "PyNumber_Remainder",
"**": "PyNumber_Power" "**": "PyNumber_Power"
...@@ -3811,6 +3894,10 @@ class CoerceToPyTypeNode(CoercionNode): ...@@ -3811,6 +3894,10 @@ class CoerceToPyTypeNode(CoercionNode):
gil_message = "Converting to Python object" gil_message = "Converting to Python object"
def analyse_types(self, env):
# The arg is always already analysed
pass
def generate_result_code(self, code): def generate_result_code(self, code):
function = self.arg.type.to_py_function function = self.arg.type.to_py_function
code.putln('%s = %s(%s); %s' % ( code.putln('%s = %s(%s); %s' % (
...@@ -3835,6 +3922,10 @@ class CoerceFromPyTypeNode(CoercionNode): ...@@ -3835,6 +3922,10 @@ class CoerceFromPyTypeNode(CoercionNode):
error(arg.pos, error(arg.pos,
"Obtaining char * from temporary Python value") "Obtaining char * from temporary Python value")
def analyse_types(self, env):
# The arg is always already analysed
pass
def generate_result_code(self, code): def generate_result_code(self, code):
function = self.type.from_py_function function = self.type.from_py_function
operand = self.arg.py_result() operand = self.arg.py_result()
...@@ -3893,6 +3984,10 @@ class CoerceToTempNode(CoercionNode): ...@@ -3893,6 +3984,10 @@ class CoerceToTempNode(CoercionNode):
gil_message = "Creating temporary Python reference" gil_message = "Creating temporary Python reference"
def analyse_types(self, env):
# The arg is always already analysed
pass
def generate_result_code(self, code): def generate_result_code(self, code):
#self.arg.generate_evaluation_code(code) # Already done #self.arg.generate_evaluation_code(code) # Already done
# by generic generate_subexpr_evaluation_code! # by generic generate_subexpr_evaluation_code!
...@@ -3945,6 +4040,65 @@ class CloneNode(CoercionNode): ...@@ -3945,6 +4040,65 @@ class CloneNode(CoercionNode):
def release_temp(self, env): def release_temp(self, env):
pass pass
class PersistentNode(ExprNode):
# A PersistentNode is like a CloneNode except it handles the temporary
# allocation itself by keeping track of the number of times it has been
# used.
subexprs = ["arg"]
temp_counter = 0
generate_counter = 0
analyse_counter = 0
result_code = None
def __init__(self, arg, uses):
self.pos = arg.pos
self.arg = arg
self.uses = uses
def analyse_types(self, env):
if self.analyse_counter == 0:
self.arg.analyse_types(env)
self.type = self.arg.type
self.result_ctype = self.arg.result_ctype
self.is_temp = 1
self.analyse_counter += 1
def calculate_result_code(self):
return self.result_code
def generate_evaluation_code(self, code):
if self.generate_counter == 0:
self.arg.generate_evaluation_code(code)
code.putln("%s = %s;" % (
self.result_code, self.arg.result_as(self.ctype())))
if self.type.is_pyobject:
code.put_incref(self.result_code, self.ctype())
self.arg.generate_disposal_code(code)
self.generate_counter += 1
def generate_disposal_code(self, code):
if self.generate_counter == self.uses:
if self.type.is_pyobject:
code.put_decref_clear(self.result_code, self.ctype())
def allocate_temps(self, env, result=None):
if self.temp_counter == 0:
self.arg.allocate_temps(env)
self.allocate_temp(env, result)
self.arg.release_temp(env)
self.temp_counter += 1
def allocate_temp(self, env, result=None):
if result is None:
self.result_code = env.allocate_temp(self.type)
else:
self.result_code = result
def release_temp(self, env):
if self.temp_counter == self.uses:
env.release_temp(self.result_code)
#------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------
# #
# Runtime support code # Runtime support code
......
...@@ -7,5 +7,7 @@ def _get_feature(name): ...@@ -7,5 +7,7 @@ def _get_feature(name):
return object() return object()
unicode_literals = _get_feature("unicode_literals") unicode_literals = _get_feature("unicode_literals")
with_statement = _get_feature("with_statement")
division = _get_feature("division")
del _get_feature del _get_feature
...@@ -65,6 +65,7 @@ def make_lexicon(): ...@@ -65,6 +65,7 @@ def make_lexicon():
escapeseq = Str("\\") + (two_oct | three_oct | two_hex | escapeseq = Str("\\") + (two_oct | three_oct | two_hex |
Str('u') + four_hex | Str('x') + two_hex | AnyChar) Str('u') + four_hex | Str('x') + two_hex | AnyChar)
deco = Str("@")
bra = Any("([{") bra = Any("([{")
ket = Any(")]}") ket = Any(")]}")
punct = Any(":,;+-*/|&<>=.%`~^?") punct = Any(":,;+-*/|&<>=.%`~^?")
...@@ -82,6 +83,7 @@ def make_lexicon(): ...@@ -82,6 +83,7 @@ def make_lexicon():
(longconst, 'LONG'), (longconst, 'LONG'),
(fltconst, 'FLOAT'), (fltconst, 'FLOAT'),
(imagconst, 'IMAG'), (imagconst, 'IMAG'),
(deco, 'DECORATOR'),
(punct | diphthong, TEXT), (punct | diphthong, TEXT),
(bra, Method('open_bracket_action')), (bra, Method('open_bracket_action')),
......
...@@ -25,22 +25,6 @@ from Cython import Utils ...@@ -25,22 +25,6 @@ from Cython import Utils
module_name_pattern = re.compile(r"[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*$") module_name_pattern = re.compile(r"[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*$")
# Note: PHASES and TransformSet should be removed soon; but that's for
# another day and another commit.
PHASES = [
'before_analyse_function', # run in FuncDefNode.generate_function_definitions
'after_analyse_function' # run in FuncDefNode.generate_function_definitions
]
class TransformSet(dict):
def __init__(self):
for name in PHASES:
self[name] = []
def run(self, name, node, **options):
assert name in self, "Transform phase %s not defined" % name
for transform in self[name]:
transform(node, phase=name, **options)
verbose = 0 verbose = 0
class Context: class Context:
...@@ -58,6 +42,8 @@ class Context: ...@@ -58,6 +42,8 @@ class Context:
#self.modules = {"__builtin__" : BuiltinScope()} #self.modules = {"__builtin__" : BuiltinScope()}
import Builtin import Builtin
self.modules = {"__builtin__" : Builtin.builtin_scope} self.modules = {"__builtin__" : Builtin.builtin_scope}
self.pxds = {}
self.pyxs = {}
self.include_directories = include_directories self.include_directories = include_directories
self.future_directives = set() self.future_directives = set()
...@@ -305,55 +291,25 @@ class Context: ...@@ -305,55 +291,25 @@ class Context:
names.reverse() names.reverse()
return ".".join(names) return ".".join(names)
def compile(self, source, options = None, full_module_name = None): def setup_errors(self, options):
# Compile a Pyrex implementation file in this context
# and return a CompilationResult.
if not options:
options = default_options
result = CompilationResult()
cwd = os.getcwd()
source = os.path.join(cwd, source)
result.main_source_file = source
if options.use_listing_file: if options.use_listing_file:
result.listing_file = Utils.replace_suffix(source, ".lis") result.listing_file = Utils.replace_suffix(source, ".lis")
Errors.open_listing_file(result.listing_file, Errors.open_listing_file(result.listing_file,
echo_to_stderr = options.errors_to_stderr) echo_to_stderr = options.errors_to_stderr)
else: else:
Errors.open_listing_file(None) Errors.open_listing_file(None)
if options.output_file:
result.c_file = os.path.join(cwd, options.output_file) def teardown_errors(self, errors_occurred, options, result):
else: source_desc = result.compilation_source.source_desc
if options.cplus: if not isinstance(source_desc, FileSourceDescriptor):
c_suffix = ".cpp" raise RuntimeError("Only file sources for code supported")
else:
c_suffix = ".c"
result.c_file = Utils.replace_suffix(source, c_suffix)
c_stat = None
if result.c_file:
try:
c_stat = os.stat(result.c_file)
except EnvironmentError:
pass
full_module_name = full_module_name or self.extract_module_name(source, options)
source = FileSourceDescriptor(source)
initial_pos = (source, 1, 0)
scope = self.find_module(full_module_name, pos = initial_pos, need_pxd = 0)
errors_occurred = False
try:
tree = self.parse(source, scope, pxd = 0,
full_module_name = full_module_name)
tree.process_implementation(scope, options, result)
except CompileError:
errors_occurred = True
Errors.close_listing_file() Errors.close_listing_file()
result.num_errors = Errors.num_errors result.num_errors = Errors.num_errors
if result.num_errors > 0: if result.num_errors > 0:
errors_occurred = True errors_occurred = True
if errors_occurred and result.c_file: if errors_occurred and result.c_file:
try: try:
Utils.castrate_file(result.c_file, os.stat(source.filename)) Utils.castrate_file(result.c_file, os.stat(source_desc.filename))
except EnvironmentError: except EnvironmentError:
pass pass
result.c_file = None result.c_file = None
...@@ -366,6 +322,103 @@ class Context: ...@@ -366,6 +322,103 @@ class Context:
extra_objects = options.objects, extra_objects = options.objects,
verbose_flag = options.show_version, verbose_flag = options.show_version,
cplus = options.cplus) cplus = options.cplus)
def run_pipeline(self, pipeline, source):
errors_occurred = False
data = source
try:
for phase in pipeline:
data = phase(data)
except CompileError:
errors_occurred = True
return (errors_occurred, data)
def create_parse(context):
def parse(compsrc):
source_desc = compsrc.source_desc
full_module_name = compsrc.full_module_name
initial_pos = (source_desc, 1, 0)
scope = context.find_module(full_module_name, pos = initial_pos, need_pxd = 0)
tree = context.parse(source_desc, scope, pxd = 0, full_module_name = full_module_name)
tree.compilation_source = compsrc
tree.scope = scope
return tree
return parse
def create_generate_code(context, options, result):
def generate_code(module_node):
scope = module_node.scope
module_node.process_implementation(options, result)
result.compilation_source = module_node.compilation_source
return result
return generate_code
def create_default_pipeline(context, options, result):
from Visitor import PrintTree
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from Optimize import FlattenInListTransform, SwitchTransform
from Buffer import BufferTransform
from ModuleNode import check_c_classes
return [
create_parse(context),
NormalizeTree(context),
PostParse(context),
FlattenInListTransform(),
WithTransform(context),
DecoratorTransform(context),
AnalyseDeclarationsTransform(context),
check_c_classes,
AnalyseExpressionsTransform(context),
BufferTransform(context),
SwitchTransform(),
# CreateClosureClasses(context),
create_generate_code(context, options, result)
]
def create_default_resultobj(compilation_source, options):
result = CompilationResult()
result.main_source_file = compilation_source.source_desc.filename
result.compilation_source = compilation_source
source_desc = compilation_source.source_desc
if options.output_file:
result.c_file = os.path.join(compilation_source.cwd, options.output_file)
else:
if options.cplus:
c_suffix = ".cpp"
else:
c_suffix = ".c"
result.c_file = Utils.replace_suffix(source_desc.filename, c_suffix)
# The below doesn't make any sense? Why is it there?
c_stat = None
if result.c_file:
try:
c_stat = os.stat(result.c_file)
except EnvironmentError:
pass
return result
def run_pipeline(source, options, full_module_name = None):
# Set up context
context = Context(options.include_path)
# Set up source object
cwd = os.getcwd()
source_desc = FileSourceDescriptor(os.path.join(cwd, source))
full_module_name = full_module_name or context.extract_module_name(source, options)
source = CompilationSource(source_desc, full_module_name, cwd)
# Set up result object
result = create_default_resultobj(source, options)
# Get pipeline
pipeline = create_default_pipeline(context, options, result)
context.setup_errors(options)
errors_occurred, enddata = context.run_pipeline(pipeline, source)
context.teardown_errors(errors_occurred, options, result)
return result return result
#------------------------------------------------------------------------ #------------------------------------------------------------------------
...@@ -374,6 +427,16 @@ class Context: ...@@ -374,6 +427,16 @@ class Context:
# #
#------------------------------------------------------------------------ #------------------------------------------------------------------------
class CompilationSource(object):
"""
Contains the data necesarry to start up a compilation pipeline for
a single compilation unit.
"""
def __init__(self, source_desc, full_module_name, cwd):
self.source_desc = source_desc
self.full_module_name = full_module_name
self.cwd = cwd
class CompilationOptions: class CompilationOptions:
""" """
Options to the Cython compiler: Options to the Cython compiler:
...@@ -389,7 +452,6 @@ class CompilationOptions: ...@@ -389,7 +452,6 @@ class CompilationOptions:
defaults to true when recursive is true. defaults to true when recursive is true.
verbose boolean Always print source names being compiled verbose boolean Always print source names being compiled
quiet boolean Don't print source names in recursive mode quiet boolean Don't print source names in recursive mode
transforms Transform.TransformSet Transforms to use on the parse tree
Following options are experimental and only used on MacOSX: Following options are experimental and only used on MacOSX:
...@@ -427,6 +489,7 @@ class CompilationResult: ...@@ -427,6 +489,7 @@ class CompilationResult:
object_file string or None Result of compiling the C file object_file string or None Result of compiling the C file
extension_file string or None Result of linking the object file extension_file string or None Result of linking the object file
num_errors integer Number of compilation errors num_errors integer Number of compilation errors
compilation_source CompilationSource
""" """
def __init__(self): def __init__(self):
...@@ -464,8 +527,10 @@ def compile_single(source, options, full_module_name = None): ...@@ -464,8 +527,10 @@ def compile_single(source, options, full_module_name = None):
Always compiles a single file; does not perform timestamp checking or Always compiles a single file; does not perform timestamp checking or
recursion. recursion.
""" """
context = Context(options.include_path) return run_pipeline(source, options, full_module_name)
return context.compile(source, options, full_module_name) # context = Context(options.include_path)
# return context.compile(source, options, full_module_name)
def compile_multiple(sources, options): def compile_multiple(sources, options):
""" """
...@@ -478,21 +543,21 @@ def compile_multiple(sources, options): ...@@ -478,21 +543,21 @@ def compile_multiple(sources, options):
sources = [os.path.abspath(source) for source in sources] sources = [os.path.abspath(source) for source in sources]
processed = set() processed = set()
results = CompilationResultSet() results = CompilationResultSet()
context = Context(options.include_path)
recursive = options.recursive recursive = options.recursive
timestamps = options.timestamps timestamps = options.timestamps
if timestamps is None: if timestamps is None:
timestamps = recursive timestamps = recursive
verbose = options.verbose or ((recursive or timestamps) and not options.quiet) verbose = options.verbose or ((recursive or timestamps) and not options.quiet)
for source in sources: for source in sources:
context = Context(options.include_path) # to be removed later
if source not in processed: if source not in processed:
if not timestamps or context.c_file_out_of_date(source): if not timestamps or context.c_file_out_of_date(source):
if verbose: if verbose:
sys.stderr.write("Compiling %s\n" % source) sys.stderr.write("Compiling %s\n" % source)
result = context.compile(source, options) result = context.compile(source, options)
# Compiling multiple sources in one context doesn't quite # Compiling multiple sources in one context doesn't quite
# work properly yet. # work properly yet.
context = Context(options.include_path) # to be removed later
results.add(source, result) results.add(source, result)
processed.add(source) processed.add(source)
if recursive: if recursive:
...@@ -522,7 +587,10 @@ def compile(source, options = None, c_compile = 0, c_link = 0, ...@@ -522,7 +587,10 @@ def compile(source, options = None, c_compile = 0, c_link = 0,
and not options.recursive: and not options.recursive:
return compile_single(source, options, full_module_name) return compile_single(source, options, full_module_name)
else: else:
return compile_multiple(source, options) # Hack it for wednesday dev1
assert len(source) == 1
return compile_single(source[0], options)
# return compile_multiple(source, options)
#------------------------------------------------------------------------ #------------------------------------------------------------------------
# #
...@@ -571,9 +639,9 @@ default_options = dict( ...@@ -571,9 +639,9 @@ default_options = dict(
output_file = None, output_file = None,
annotate = False, annotate = False,
generate_pxi = 0, generate_pxi = 0,
transforms = TransformSet(),
working_path = "", working_path = "",
recursive = 0, recursive = 0,
transforms = None, # deprecated
timestamps = None, timestamps = None,
verbose = 0, verbose = 0,
quiet = 0) quiet = 0)
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import os, time import os, time
from cStringIO import StringIO from cStringIO import StringIO
from PyrexTypes import CPtrType from PyrexTypes import CPtrType
import Future
try: try:
set set
...@@ -25,6 +26,10 @@ from PyrexTypes import py_object_type ...@@ -25,6 +26,10 @@ from PyrexTypes import py_object_type
from Cython.Utils import open_new_file, replace_suffix from Cython.Utils import open_new_file, replace_suffix
def check_c_classes(module_node):
module_node.scope.check_c_classes()
return module_node
class ModuleNode(Nodes.Node, Nodes.BlockNode): class ModuleNode(Nodes.Node, Nodes.BlockNode):
# doc string or None # doc string or None
# body StatListNode # body StatListNode
...@@ -32,6 +37,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -32,6 +37,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
# referenced_modules [ModuleScope] # referenced_modules [ModuleScope]
# module_temp_cname string # module_temp_cname string
# full_module_name string # full_module_name string
#
# scope The module scope.
# compilation_source A CompilationSource (see Main)
child_attrs = ["body"] child_attrs = ["body"]
...@@ -44,10 +52,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -44,10 +52,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
env.doc = self.doc env.doc = self.doc
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
def process_implementation(self, env, options, result): def process_implementation(self, options, result):
self.analyse_declarations(env) env = self.scope
env.check_c_classes()
self.body.analyse_expressions(env)
env.return_type = PyrexTypes.c_void_type env.return_type = PyrexTypes.c_void_type
self.referenced_modules = [] self.referenced_modules = []
self.find_referenced_modules(env, self.referenced_modules, {}) self.find_referenced_modules(env, self.referenced_modules, {})
...@@ -254,6 +260,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -254,6 +260,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
self.generate_module_cleanup_func(env, code) self.generate_module_cleanup_func(env, code)
self.generate_filename_table(code) self.generate_filename_table(code)
self.generate_utility_functions(env, code) self.generate_utility_functions(env, code)
self.generate_buffer_compatability_functions(env, code)
self.generate_declarations_for_modules(env, modules, code.h) self.generate_declarations_for_modules(env, modules, code.h)
...@@ -433,6 +440,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -433,6 +440,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln(" #define PyBUF_F_CONTIGUOUS (0x0040 | PyBUF_STRIDES)") code.putln(" #define PyBUF_F_CONTIGUOUS (0x0040 | PyBUF_STRIDES)")
code.putln(" #define PyBUF_ANY_CONTIGUOUS (0x0080 | PyBUF_STRIDES)") code.putln(" #define PyBUF_ANY_CONTIGUOUS (0x0080 | PyBUF_STRIDES)")
code.putln(" #define PyBUF_INDIRECT (0x0100 | PyBUF_STRIDES)") code.putln(" #define PyBUF_INDIRECT (0x0100 | PyBUF_STRIDES)")
code.putln("")
code.putln(" static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags);")
code.putln(" static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view);")
code.putln("#endif") code.putln("#endif")
code.put(builtin_module_name_utility_code[0]) code.put(builtin_module_name_utility_code[0])
...@@ -458,8 +468,12 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -458,8 +468,12 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln(" #define PyInt_AsSsize_t PyLong_AsSsize_t") code.putln(" #define PyInt_AsSsize_t PyLong_AsSsize_t")
code.putln(" #define PyInt_AsUnsignedLongMask PyLong_AsUnsignedLongMask") code.putln(" #define PyInt_AsUnsignedLongMask PyLong_AsUnsignedLongMask")
code.putln(" #define PyInt_AsUnsignedLongLongMask PyLong_AsUnsignedLongLongMask") code.putln(" #define PyInt_AsUnsignedLongLongMask PyLong_AsUnsignedLongLongMask")
code.putln(" #define PyNumber_Divide(x,y) PyNumber_TrueDivide(x,y)") code.putln(" #define __Pyx_PyNumber_Divide(x,y) PyNumber_TrueDivide(x,y)")
code.putln("#else") code.putln("#else")
if Future.division in env.context.future_directives:
code.putln(" #define __Pyx_PyNumber_Divide(x,y) PyNumber_TrueDivide(x,y)")
else:
code.putln(" #define __Pyx_PyNumber_Divide(x,y) PyNumber_Divide(x,y)")
code.putln(" #define PyBytes_Type PyString_Type") code.putln(" #define PyBytes_Type PyString_Type")
code.putln("#endif") code.putln("#endif")
...@@ -473,6 +487,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -473,6 +487,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("#ifndef __cdecl") code.putln("#ifndef __cdecl")
code.putln(" #define __cdecl") code.putln(" #define __cdecl")
code.putln("#endif") code.putln("#endif")
code.putln('');
code.putln('#define %s 0xB0000000B000B0BBLL' % Naming.default_error);
code.putln('');
self.generate_extern_c_macro_definition(code) self.generate_extern_c_macro_definition(code)
code.putln("#include <math.h>") code.putln("#include <math.h>")
code.putln("#define %s" % Naming.api_guard_prefix + self.api_name(env)) code.putln("#define %s" % Naming.api_guard_prefix + self.api_name(env))
...@@ -1940,6 +1957,90 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1940,6 +1957,90 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.h.put(utility_code[0]) code.h.put(utility_code[0])
code.put(utility_code[1]) code.put(utility_code[1])
code.put(PyrexTypes.type_conversion_functions) code.put(PyrexTypes.type_conversion_functions)
code.putln("")
def generate_buffer_compatability_functions(self, env, code):
# will be refactored
try:
env.entries[u'numpy']
code.put("""
static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
/* This function is always called after a type-check; safe to cast */
PyArrayObject *arr = (PyArrayObject*)obj;
PyArray_Descr *type = (PyArray_Descr*)arr->descr;
int typenum = PyArray_TYPE(obj);
if (!PyTypeNum_ISNUMBER(typenum)) {
PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
return -1;
}
/*
NumPy format codes doesn't completely match buffer codes;
seems safest to retranslate.
01234567890123456789012345*/
const char* base_codes = "?bBhHiIlLqQfdgfdgO";
char* format = (char*)malloc(4);
char* fp = format;
*fp++ = type->byteorder;
if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
*fp++ = base_codes[typenum];
*fp = 0;
view->buf = arr->data;
view->readonly = !PyArray_ISWRITEABLE(obj);
view->ndim = PyArray_NDIM(arr);
view->strides = PyArray_STRIDES(arr);
view->shape = PyArray_DIMS(arr);
view->suboffsets = NULL;
view->format = format;
view->itemsize = type->elsize;
view->internal = 0;
return 0;
}
static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
free((char*)view->format);
view->format = NULL;
}
""")
except KeyError:
pass
# For now, hard-code numpy imported as "numpy"
types = []
try:
ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
except KeyError:
pass
code.putln("#if PY_VERSION_HEX < 0x02060000")
code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
if len(types) > 0:
clause = "if"
for t, get, release in types:
code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
clause = "else if"
code.putln("else {")
code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
code.putln("return -1;")
if len(types) > 0: code.putln("}")
code.putln("}")
code.putln("")
code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
if len(types) > 0:
clause = "if"
for t, get, release in types:
code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release))
clause = "else if"
code.putln("}")
code.putln("")
code.putln("#endif")
#------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------
# #
......
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
pyrex_prefix = "__pyx_" pyrex_prefix = "__pyx_"
temp_prefix = u"__cyt_"
builtin_prefix = pyrex_prefix + "builtin_" builtin_prefix = pyrex_prefix + "builtin_"
arg_prefix = pyrex_prefix + "arg_" arg_prefix = pyrex_prefix + "arg_"
funcdoc_prefix = pyrex_prefix + "doc_" funcdoc_prefix = pyrex_prefix + "doc_"
...@@ -70,6 +72,9 @@ optional_args_cname = pyrex_prefix + "optional_args" ...@@ -70,6 +72,9 @@ optional_args_cname = pyrex_prefix + "optional_args"
no_opt_args = pyrex_prefix + "no_opt_args" no_opt_args = pyrex_prefix + "no_opt_args"
import_star = pyrex_prefix + "import_star" import_star = pyrex_prefix + "import_star"
import_star_set = pyrex_prefix + "import_star_set" import_star_set = pyrex_prefix + "import_star_set"
cur_scope_cname = pyrex_prefix + "cur_scope"
enc_scope_cname = pyrex_prefix + "enc_scope"
default_error = pyrex_prefix + "ERROR"
line_c_macro = "__LINE__" line_c_macro = "__LINE__"
......
...@@ -10,7 +10,7 @@ import Naming ...@@ -10,7 +10,7 @@ import Naming
import PyrexTypes import PyrexTypes
import TypeSlots import TypeSlots
from PyrexTypes import py_object_type, error_type, CTypedefType, CFuncType from PyrexTypes import py_object_type, error_type, CTypedefType, CFuncType
from Symtab import ModuleScope, LocalScope, \ from Symtab import ModuleScope, LocalScope, GeneratorLocalScope, \
StructOrUnionScope, PyClassScope, CClassScope StructOrUnionScope, PyClassScope, CClassScope
from Cython.Utils import open_new_file, replace_suffix, EncodedString from Cython.Utils import open_new_file, replace_suffix, EncodedString
import Options import Options
...@@ -172,6 +172,26 @@ class Node(object): ...@@ -172,6 +172,26 @@ class Node(object):
self._end_pos = pos self._end_pos = pos
return pos return pos
def dump(self, level=0, filter_out=("pos",)):
def dump_child(x, level):
if isinstance(x, Node):
return x.dump(level)
elif isinstance(x, list):
return "[%s]" % ", ".join([dump_child(item, level) for item in x])
else:
return repr(x)
attrs = [(key, value) for key, value in self.__dict__.iteritems() if key not in filter_out]
if len(attrs) == 0:
return "<%s (%d)>" % (self.__class__.__name__, id(self))
else:
indent = " " * level
res = "<%s (%d)\n" % (self.__class__.__name__, id(self))
for key, value in attrs:
res += "%s %s: %s\n" % (indent, key, dump_child(value, level + 1))
res += "%s>" % indent
return res
class BlockNode: class BlockNode:
# Mixin class for nodes representing a declaration block. # Mixin class for nodes representing a declaration block.
...@@ -334,9 +354,9 @@ class CDeclaratorNode(Node): ...@@ -334,9 +354,9 @@ class CDeclaratorNode(Node):
class CNameDeclaratorNode(CDeclaratorNode): class CNameDeclaratorNode(CDeclaratorNode):
# name string The Pyrex name being declared # name string The Pyrex name being declared
# cname string or None C name, if specified # cname string or None C name, if specified
# rhs ExprNode or None the value assigned on declaration # default ExprNode or None the value assigned on declaration
child_attrs = [] child_attrs = ['default']
def analyse(self, base_type, env, nonempty = 0): def analyse(self, base_type, env, nonempty = 0):
if nonempty and self.name == '': if nonempty and self.name == '':
...@@ -348,25 +368,25 @@ class CNameDeclaratorNode(CDeclaratorNode): ...@@ -348,25 +368,25 @@ class CNameDeclaratorNode(CDeclaratorNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.entry = env.lookup(self.name) self.entry = env.lookup(self.name)
if self.rhs is not None: if self.default is not None:
env.control_flow.set_state(self.rhs.end_pos(), (self.entry.name, 'initalized'), True) env.control_flow.set_state(self.default.end_pos(), (self.entry.name, 'initalized'), True)
env.control_flow.set_state(self.rhs.end_pos(), (self.entry.name, 'source'), 'assignment') env.control_flow.set_state(self.default.end_pos(), (self.entry.name, 'source'), 'assignment')
self.entry.used = 1 self.entry.used = 1
if self.type.is_pyobject: if self.type.is_pyobject:
self.entry.init_to_none = False self.entry.init_to_none = False
self.entry.init = 0 self.entry.init = 0
self.rhs.analyse_types(env) self.default.analyse_types(env)
self.rhs = self.rhs.coerce_to(self.type, env) self.default = self.default.coerce_to(self.type, env)
self.rhs.allocate_temps(env) self.default.allocate_temps(env)
self.rhs.release_temp(env) self.default.release_temp(env)
def generate_execution_code(self, code): def generate_execution_code(self, code):
if self.rhs is not None: if self.default is not None:
self.rhs.generate_evaluation_code(code) self.default.generate_evaluation_code(code)
if self.type.is_pyobject: if self.type.is_pyobject:
self.rhs.make_owned_reference(code) self.default.make_owned_reference(code)
code.putln('%s = %s;' % (self.entry.cname, self.rhs.result_as(self.entry.type))) code.putln('%s = %s;' % (self.entry.cname, self.default.result_as(self.entry.type)))
self.rhs.generate_post_assignment_code(code) self.default.generate_post_assignment_code(code)
code.putln() code.putln()
class CPtrDeclaratorNode(CDeclaratorNode): class CPtrDeclaratorNode(CDeclaratorNode):
...@@ -545,7 +565,6 @@ class CBaseTypeNode(Node): ...@@ -545,7 +565,6 @@ class CBaseTypeNode(Node):
pass pass
class CSimpleBaseTypeNode(CBaseTypeNode): class CSimpleBaseTypeNode(CBaseTypeNode):
# name string # name string
# module_path [string] Qualifying name components # module_path [string] Qualifying name components
...@@ -587,6 +606,30 @@ class CSimpleBaseTypeNode(CBaseTypeNode): ...@@ -587,6 +606,30 @@ class CSimpleBaseTypeNode(CBaseTypeNode):
else: else:
return PyrexTypes.error_type return PyrexTypes.error_type
class CBufferAccessTypeNode(Node):
# After parsing:
# positional_args [ExprNode] List of positional arguments
# keyword_args DictNode Keyword arguments
# base_type_node CBaseTypeNode
# After PostParse:
# dtype_node CBaseTypeNode
# ndim int
# After analysis:
# type PyrexType.PyrexType
child_attrs = ["base_type_node", "positional_args", "keyword_args",
"dtype_node"]
dtype_node = None
def analyse(self, env):
base_type = self.base_type_node.analyse(env)
dtype = self.dtype_node.analyse(env)
options = PyrexTypes.BufferOptions(dtype=dtype, ndim=self.ndim)
self.type = PyrexTypes.create_buffer_type(base_type, options)
return self.type
class CComplexBaseTypeNode(CBaseTypeNode): class CComplexBaseTypeNode(CBaseTypeNode):
# base_type CBaseTypeNode # base_type CBaseTypeNode
...@@ -773,9 +816,11 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -773,9 +816,11 @@ class FuncDefNode(StatNode, BlockNode):
# return_type PyrexType # return_type PyrexType
# #filename string C name of filename string const # #filename string C name of filename string const
# entry Symtab.Entry # entry Symtab.Entry
# needs_closure boolean Whether or not this function has inner functions/classes/yield
py_func = None py_func = None
assmt = None assmt = None
needs_closure = False
def analyse_default_values(self, env): def analyse_default_values(self, env):
genv = env.global_scope() genv = env.global_scope()
...@@ -803,25 +848,27 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -803,25 +848,27 @@ class FuncDefNode(StatNode, BlockNode):
def need_gil_acquisition(self, lenv): def need_gil_acquisition(self, lenv):
return 0 return 0
def generate_function_definitions(self, env, code, transforms): def create_local_scope(self, env):
code.mark_pos(self.pos) genv = env
# Generate C code for header and body of function while env.is_py_class_scope or env.is_c_class_scope:
genv = env.global_scope() env = env.outer_scope
if self.needs_closure:
lenv = GeneratorLocalScope(name = self.entry.name, outer_scope = genv)
else:
lenv = LocalScope(name = self.entry.name, outer_scope = genv) lenv = LocalScope(name = self.entry.name, outer_scope = genv)
lenv.return_type = self.return_type lenv.return_type = self.return_type
type = self.entry.type type = self.entry.type
if type.is_cfunction: if type.is_cfunction:
lenv.nogil = type.nogil and not type.with_gil lenv.nogil = type.nogil and not type.with_gil
self.local_scope = lenv
return lenv
def generate_function_definitions(self, env, code, transforms):
# Generate C code for header and body of function
code.init_labels() code.init_labels()
self.declare_arguments(lenv) lenv = self.local_scope
transforms.run('before_analyse_function', self, env=env, lenv=lenv, genv=genv)
self.body.analyse_control_flow(lenv)
self.body.analyse_declarations(lenv)
self.body.analyse_expressions(lenv)
transforms.run('after_analyse_function', self, env=env, lenv=lenv, genv=genv)
# Code for nested function definitions would go here
# if we supported them, which we probably won't.
# ----- Top-level constants used by this function # ----- Top-level constants used by this function
code.mark_pos(self.pos)
self.generate_interned_num_decls(lenv, code) self.generate_interned_num_decls(lenv, code)
self.generate_interned_string_decls(lenv, code) self.generate_interned_string_decls(lenv, code)
self.generate_py_string_decls(lenv, code) self.generate_py_string_decls(lenv, code)
...@@ -838,7 +885,10 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -838,7 +885,10 @@ class FuncDefNode(StatNode, BlockNode):
self.generate_function_header(code, self.generate_function_header(code,
with_pymethdef = env.is_py_class_scope) with_pymethdef = env.is_py_class_scope)
# ----- Local variable declarations # ----- Local variable declarations
lenv.mangle_closure_cnames(Naming.cur_scope_cname)
self.generate_argument_declarations(lenv, code) self.generate_argument_declarations(lenv, code)
if self.needs_closure:
code.putln("/* TODO: declare and create scope object */")
code.put_var_declarations(lenv.var_entries) code.put_var_declarations(lenv.var_entries)
init = "" init = ""
if not self.return_type.is_void: if not self.return_type.is_void:
...@@ -919,6 +969,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -919,6 +969,7 @@ class FuncDefNode(StatNode, BlockNode):
self.put_stararg_decrefs(code) self.put_stararg_decrefs(code)
if acquire_gil: if acquire_gil:
code.putln("PyGILState_Release(_save);") code.putln("PyGILState_Release(_save);")
code.putln("/* TODO: decref scope object */")
# ----- Return # ----- Return
if not self.return_type.is_void: if not self.return_type.is_void:
code.putln("return %s;" % Naming.retval_cname) code.putln("return %s;" % Naming.retval_cname)
...@@ -975,8 +1026,9 @@ class CFuncDefNode(FuncDefNode): ...@@ -975,8 +1026,9 @@ class CFuncDefNode(FuncDefNode):
# #
# with_gil boolean Acquire GIL around body # with_gil boolean Acquire GIL around body
# type CFuncType # type CFuncType
# py_func wrapper for calling from Python
child_attrs = ["base_type", "declarator", "body"] child_attrs = ["base_type", "declarator", "body", "py_func"]
def unqualified_name(self): def unqualified_name(self):
return self.entry.name return self.entry.name
...@@ -1155,10 +1207,17 @@ class CFuncDefNode(FuncDefNode): ...@@ -1155,10 +1207,17 @@ class CFuncDefNode(FuncDefNode):
if self.return_type.is_pyobject: if self.return_type.is_pyobject:
return "0" return "0"
else: else:
#return None if self.entry.type.exception_value is not None:
return self.entry.type.exception_value return self.entry.type.exception_value
elif self.return_type.is_struct_or_union or self.return_type.is_void:
return None
else:
return self.return_type.cast_code(Naming.default_error)
def caller_will_check_exceptions(self): def caller_will_check_exceptions(self):
if self.entry.type.exception_value is None:
return 1
else:
return self.entry.type.exception_check return self.entry.type.exception_check
def generate_optarg_wrapper_function(self, env, code): def generate_optarg_wrapper_function(self, env, code):
...@@ -1183,13 +1242,19 @@ class PyArgDeclNode(Node): ...@@ -1183,13 +1242,19 @@ class PyArgDeclNode(Node):
# entry Symtab.Entry # entry Symtab.Entry
child_attrs = [] child_attrs = []
pass
class DecoratorNode(Node):
# A decorator
#
# decorator NameNode or CallNode
child_attrs = ['decorator']
class DefNode(FuncDefNode): class DefNode(FuncDefNode):
# A Python function definition. # A Python function definition.
# #
# name string the Python name of the function # name string the Python name of the function
# decorators [DecoratorNode] list of decorators
# args [CArgDeclNode] formal arguments # args [CArgDeclNode] formal arguments
# star_arg PyArgDeclNode or None * argument # star_arg PyArgDeclNode or None * argument
# starstar_arg PyArgDeclNode or None ** argument # starstar_arg PyArgDeclNode or None ** argument
...@@ -1201,13 +1266,14 @@ class DefNode(FuncDefNode): ...@@ -1201,13 +1266,14 @@ class DefNode(FuncDefNode):
# #
# assmt AssignmentNode Function construction/assignment # assmt AssignmentNode Function construction/assignment
child_attrs = ["args", "star_arg", "starstar_arg", "body"] child_attrs = ["args", "star_arg", "starstar_arg", "body", "decorators"]
assmt = None assmt = None
num_kwonly_args = 0 num_kwonly_args = 0
num_required_kw_args = 0 num_required_kw_args = 0
reqd_kw_flags_cname = "0" reqd_kw_flags_cname = "0"
is_wrapper = 0 is_wrapper = 0
decorators = None
def __init__(self, pos, **kwds): def __init__(self, pos, **kwds):
FuncDefNode.__init__(self, pos, **kwds) FuncDefNode.__init__(self, pos, **kwds)
...@@ -1839,6 +1905,8 @@ class OverrideCheckNode(StatNode): ...@@ -1839,6 +1905,8 @@ class OverrideCheckNode(StatNode):
child_attrs = ['body'] child_attrs = ['body']
body = None
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.args = env.arg_entries self.args = env.arg_entries
if self.py_func.is_module_scope: if self.py_func.is_module_scope:
...@@ -1852,9 +1920,7 @@ class OverrideCheckNode(StatNode): ...@@ -1852,9 +1920,7 @@ class OverrideCheckNode(StatNode):
function=self.func_node, function=self.func_node,
args=[ExprNodes.NameNode(self.pos, name=arg.name) for arg in self.args[first_arg:]]) args=[ExprNodes.NameNode(self.pos, name=arg.name) for arg in self.args[first_arg:]])
self.body = ReturnStatNode(self.pos, value=call_node) self.body = ReturnStatNode(self.pos, value=call_node)
# self.func_temp = env.allocate_temp_pyobject()
self.body.analyse_expressions(env) self.body.analyse_expressions(env)
# env.release_temp(self.func_temp)
def generate_execution_code(self, code): def generate_execution_code(self, code):
# Check to see if we are an extension type # Check to see if we are an extension type
...@@ -1869,7 +1935,7 @@ class OverrideCheckNode(StatNode): ...@@ -1869,7 +1935,7 @@ class OverrideCheckNode(StatNode):
code.putln("else {") code.putln("else {")
else: else:
code.putln("else if (unlikely(Py_TYPE(%s)->tp_dictoffset != 0)) {" % self_arg) code.putln("else if (unlikely(Py_TYPE(%s)->tp_dictoffset != 0)) {" % self_arg)
err = code.error_goto_if_null(self_arg, self.pos) err = code.error_goto_if_null(self.func_node.result_code, self.pos)
# need to get attribute manually--scope would return cdef method # need to get attribute manually--scope would return cdef method
code.putln("%s = PyObject_GetAttr(%s, %s); %s" % (self.func_node.result_code, self_arg, self.py_func.interned_attr_cname, err)) code.putln("%s = PyObject_GetAttr(%s, %s); %s" % (self.func_node.result_code, self_arg, self.py_func.interned_attr_cname, err))
# It appears that this type is not anywhere exposed in the Python/C API # It appears that this type is not anywhere exposed in the Python/C API
...@@ -1878,12 +1944,13 @@ class OverrideCheckNode(StatNode): ...@@ -1878,12 +1944,13 @@ class OverrideCheckNode(StatNode):
code.putln('if (!%s || %s) {' % (is_builtin_function_or_method, is_overridden)) code.putln('if (!%s || %s) {' % (is_builtin_function_or_method, is_overridden))
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
code.putln('}') code.putln('}')
# code.put_decref(self.func_temp, PyrexTypes.py_object_type) code.put_decref_clear(self.func_node.result_code, PyrexTypes.py_object_type)
code.putln("}") code.putln("}")
class ClassDefNode(StatNode, BlockNode):
pass
class PyClassDefNode(ClassDefNode):
class PyClassDefNode(StatNode, BlockNode):
# A Python class definition. # A Python class definition.
# #
# name EncodedString Name of the class # name EncodedString Name of the class
...@@ -1916,18 +1983,26 @@ class PyClassDefNode(StatNode, BlockNode): ...@@ -1916,18 +1983,26 @@ class PyClassDefNode(StatNode, BlockNode):
bases = bases, dict = self.dict, doc = doc_node) bases = bases, dict = self.dict, doc = doc_node)
self.target = ExprNodes.NameNode(pos, name = name) self.target = ExprNodes.NameNode(pos, name = name)
def create_scope(self, env):
genv = env
while env.is_py_class_scope or env.is_c_class_scope:
env = env.outer_scope
cenv = self.scope = PyClassScope(name = self.name, outer_scope = genv)
return cenv
def analyse_declarations(self, env): def analyse_declarations(self, env):
self.target.analyse_target_declaration(env) self.target.analyse_target_declaration(env)
cenv = self.create_scope(env)
cenv.class_obj_cname = self.target.entry.cname
self.body.analyse_declarations(cenv)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.dict.analyse_expressions(env) self.dict.analyse_expressions(env)
self.classobj.analyse_expressions(env) self.classobj.analyse_expressions(env)
genv = env.global_scope() genv = env.global_scope()
cenv = PyClassScope(name = self.name, outer_scope = genv) cenv = self.scope
cenv.class_dict_cname = self.dict.result_code cenv.class_dict_cname = self.dict.result_code
cenv.class_obj_cname = self.classobj.result_code cenv.namespace_cname = cenv.class_obj_cname = self.classobj.result_code
self.scope = cenv
self.body.analyse_declarations(cenv)
self.body.analyse_expressions(cenv) self.body.analyse_expressions(cenv)
self.target.analyse_target_expression(env, self.classobj) self.target.analyse_target_expression(env, self.classobj)
self.dict.release_temp(env) self.dict.release_temp(env)
...@@ -1947,7 +2022,7 @@ class PyClassDefNode(StatNode, BlockNode): ...@@ -1947,7 +2022,7 @@ class PyClassDefNode(StatNode, BlockNode):
self.dict.generate_disposal_code(code) self.dict.generate_disposal_code(code)
class CClassDefNode(StatNode, BlockNode): class CClassDefNode(ClassDefNode):
# An extension type definition. # An extension type definition.
# #
# visibility 'private' or 'public' or 'extern' # visibility 'private' or 'public' or 'extern'
...@@ -2356,7 +2431,7 @@ class InPlaceAssignmentNode(AssignmentNode): ...@@ -2356,7 +2431,7 @@ class InPlaceAssignmentNode(AssignmentNode):
# Fortunately, the type of the lhs node is fairly constrained # Fortunately, the type of the lhs node is fairly constrained
# (it must be a NameNode, AttributeNode, or IndexNode). # (it must be a NameNode, AttributeNode, or IndexNode).
child_attrs = ["lhs", "rhs", "dup"] child_attrs = ["lhs", "rhs"]
dup = None dup = None
def analyse_declarations(self, env): def analyse_declarations(self, env):
...@@ -2991,7 +3066,7 @@ class ForInStatNode(LoopNode, StatNode): ...@@ -2991,7 +3066,7 @@ class ForInStatNode(LoopNode, StatNode):
# else_clause StatNode # else_clause StatNode
# item NextNode used internally # item NextNode used internally
child_attrs = ["target", "iterator", "body", "else_clause", "item"] child_attrs = ["target", "iterator", "body", "else_clause"]
item = None item = None
def analyse_declarations(self, env): def analyse_declarations(self, env):
...@@ -3225,6 +3300,18 @@ class ForFromStatNode(LoopNode, StatNode): ...@@ -3225,6 +3300,18 @@ class ForFromStatNode(LoopNode, StatNode):
self.else_clause.annotate(code) self.else_clause.annotate(code)
class WithStatNode(StatNode):
"""
Represents a Python with statement.
This is only used at parse tree level; and is not present in
analysis or generation phases.
"""
# manager The with statement manager object
# target Node (lhs expression)
# body StatNode
child_attrs = ["manager", "target", "body"]
class TryExceptStatNode(StatNode): class TryExceptStatNode(StatNode):
# try .. except statement # try .. except statement
# #
...@@ -3317,17 +3404,26 @@ class ExceptClauseNode(Node): ...@@ -3317,17 +3404,26 @@ class ExceptClauseNode(Node):
# pattern ExprNode # pattern ExprNode
# target ExprNode or None # target ExprNode or None
# body StatNode # body StatNode
# excinfo_target NameNode or None optional target for exception info
# match_flag string result of exception match # match_flag string result of exception match
# exc_value ExcValueNode used internally # exc_value ExcValueNode used internally
# function_name string qualified name of enclosing function # function_name string qualified name of enclosing function
# exc_vars (string * 3) local exception variables # exc_vars (string * 3) local exception variables
child_attrs = ["pattern", "target", "body", "exc_value"] # excinfo_target is never set by the parser, but can be set by a transform
# in order to extract more extensive information about the exception as a
# sys.exc_info()-style tuple into a target variable
child_attrs = ["pattern", "target", "body", "exc_value", "excinfo_target"]
exc_value = None exc_value = None
excinfo_target = None
def analyse_declarations(self, env): def analyse_declarations(self, env):
if self.target: if self.target:
self.target.analyse_target_declaration(env) self.target.analyse_target_declaration(env)
if self.excinfo_target is not None:
self.excinfo_target.analyse_target_declaration(env)
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
def analyse_expressions(self, env): def analyse_expressions(self, env):
...@@ -3345,6 +3441,17 @@ class ExceptClauseNode(Node): ...@@ -3345,6 +3441,17 @@ class ExceptClauseNode(Node):
self.exc_value = ExprNodes.ExcValueNode(self.pos, env, self.exc_vars[1]) self.exc_value = ExprNodes.ExcValueNode(self.pos, env, self.exc_vars[1])
self.exc_value.allocate_temps(env) self.exc_value.allocate_temps(env)
self.target.analyse_target_expression(env, self.exc_value) self.target.analyse_target_expression(env, self.exc_value)
if self.excinfo_target is not None:
import ExprNodes
self.excinfo_tuple = ExprNodes.TupleNode(pos=self.pos, args=[
ExprNodes.ExcValueNode(pos=self.pos, env=env, var=self.exc_vars[0]),
ExprNodes.ExcValueNode(pos=self.pos, env=env, var=self.exc_vars[1]),
ExprNodes.ExcValueNode(pos=self.pos, env=env, var=self.exc_vars[2])
])
self.excinfo_tuple.analyse_expressions(env)
self.excinfo_tuple.allocate_temps(env)
self.excinfo_target.analyse_target_expression(env, self.excinfo_tuple)
self.body.analyse_expressions(env) self.body.analyse_expressions(env)
for var in self.exc_vars: for var in self.exc_vars:
env.release_temp(var) env.release_temp(var)
...@@ -3374,6 +3481,10 @@ class ExceptClauseNode(Node): ...@@ -3374,6 +3481,10 @@ class ExceptClauseNode(Node):
if self.target: if self.target:
self.exc_value.generate_evaluation_code(code) self.exc_value.generate_evaluation_code(code)
self.target.generate_assignment_code(self.exc_value, code) self.target.generate_assignment_code(self.exc_value, code)
if self.excinfo_target is not None:
self.excinfo_tuple.generate_evaluation_code(code)
self.excinfo_target.generate_assignment_code(self.excinfo_tuple, code)
old_exc_vars = code.exc_vars old_exc_vars = code.exc_vars
code.exc_vars = self.exc_vars code.exc_vars = self.exc_vars
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
...@@ -4484,6 +4595,7 @@ bad: ...@@ -4484,6 +4595,7 @@ bad:
Py_XDECREF(*tb); Py_XDECREF(*tb);
return -1; return -1;
} }
"""] """]
#------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------
import Nodes import Nodes
import ExprNodes import ExprNodes
import PyrexTypes
import Visitor import Visitor
def unwrap_node(node):
while isinstance(node, ExprNodes.PersistentNode):
node = node.arg
return node
def is_common_value(a, b): def is_common_value(a, b):
a = unwrap_node(a)
b = unwrap_node(b)
if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode): if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
return a.name == b.name return a.name == b.name
if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode): if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
...@@ -11,13 +18,20 @@ def is_common_value(a, b): ...@@ -11,13 +18,20 @@ def is_common_value(a, b):
return False return False
class SwitchTransformVisitor(Visitor.VisitorTransform): class SwitchTransform(Visitor.VisitorTransform):
"""
This transformation tries to turn long if statements into C switch statements.
The requirement is that every clause be an (or of) var == value, where the var
is common among all clauses and both var and value are not Python objects.
"""
def extract_conditions(self, cond): def extract_conditions(self, cond):
if isinstance(cond, ExprNodes.CoerceToTempNode): if isinstance(cond, ExprNodes.CoerceToTempNode):
cond = cond.arg cond = cond.arg
if isinstance(cond, ExprNodes.TypecastNode):
cond = cond.operand
if (isinstance(cond, ExprNodes.PrimaryCmpNode) if (isinstance(cond, ExprNodes.PrimaryCmpNode)
and cond.cascade is None and cond.cascade is None
and cond.operator == '==' and cond.operator == '=='
...@@ -40,14 +54,8 @@ class SwitchTransformVisitor(Visitor.VisitorTransform): ...@@ -40,14 +54,8 @@ class SwitchTransformVisitor(Visitor.VisitorTransform):
return t1, c1+c2 return t1, c1+c2
return None, None return None, None
def is_common_value(self, a, b):
if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
return a.name == b.name
if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
return not a.is_py_attr and is_common_value(a.obj, b.obj)
return False
def visit_IfStatNode(self, node): def visit_IfStatNode(self, node):
self.visitchildren(node)
if len(node.if_clauses) < 3: if len(node.if_clauses) < 3:
return node return node
common_var = None common_var = None
...@@ -56,19 +64,73 @@ class SwitchTransformVisitor(Visitor.VisitorTransform): ...@@ -56,19 +64,73 @@ class SwitchTransformVisitor(Visitor.VisitorTransform):
var, conditions = self.extract_conditions(if_clause.condition) var, conditions = self.extract_conditions(if_clause.condition)
if var is None: if var is None:
return node return node
elif common_var is not None and not self.is_common_value(var, common_var): elif common_var is not None and not is_common_value(var, common_var):
return node return node
else: else:
common_var = var common_var = var
cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos, cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
conditions = conditions, conditions = conditions,
body = if_clause.body)) body = if_clause.body))
common_var = unwrap_node(common_var)
return Nodes.SwitchStatNode(pos = node.pos, return Nodes.SwitchStatNode(pos = node.pos,
test = common_var, test = common_var,
cases = cases, cases = cases,
else_clause = node.else_clause) else_clause = node.else_clause)
def visit_Node(self, node): def visit_Node(self, node):
self.visitchildren(node) self.visitchildren(node)
return node return node
class FlattenInListTransform(Visitor.VisitorTransform):
"""
This transformation flattens "x in [val1, ..., valn]" into a sequential list
of comparisons.
"""
def visit_PrimaryCmpNode(self, node):
self.visitchildren(node)
if node.cascade is not None:
return node
elif node.operator == 'in':
conjunction = 'or'
eq_or_neq = '=='
elif node.operator == 'not_in':
conjunction = 'and'
eq_or_neq = '!='
else:
return node
if isinstance(node.operand2, ExprNodes.TupleNode) or isinstance(node.operand2, ExprNodes.ListNode):
args = node.operand2.args
if len(args) == 0:
return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in')
else:
lhs = ExprNodes.PersistentNode(node.operand1, len(args))
conds = []
for arg in args:
cond = ExprNodes.PrimaryCmpNode(
pos = node.pos,
operand1 = lhs,
operator = eq_or_neq,
operand2 = arg,
cascade = None)
conds.append(ExprNodes.TypecastNode(
pos = node.pos,
operand = cond,
type = PyrexTypes.c_bint_type))
def concat(left, right):
return ExprNodes.BoolBinopNode(
pos = node.pos,
operator = conjunction,
operand1 = left,
operand2 = right)
return reduce(concat, conds)
else:
return node
def visit_Node(self, node):
self.visitchildren(node)
return node
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Utils import EncodedString
from Cython.Compiler.Errors import CompileError
from sets import Set as set
class NormalizeTree(CythonTransform):
"""
This transform fixes up a few things after parsing
in order to make the parse tree more suitable for
transforms.
a) After parsing, blocks with only one statement will
be represented by that statement, not by a StatListNode.
When doing transforms this is annoying and inconsistent,
as one cannot in general remove a statement in a consistent
way and so on. This transform wraps any single statements
in a StatListNode containing a single statement.
b) The PassStatNode is a noop and serves no purpose beyond
plugging such one-statement blocks; i.e., once parsed a
` "pass" can just as well be represented using an empty
StatListNode. This means less special cases to worry about
in subsequent transforms (one always checks to see if a
StatListNode has no children to see if the block is empty).
"""
def __init__(self, context):
super(NormalizeTree, self).__init__(context)
self.is_in_statlist = False
self.is_in_expr = False
def visit_ExprNode(self, node):
stacktmp = self.is_in_expr
self.is_in_expr = True
self.visitchildren(node)
self.is_in_expr = stacktmp
return node
def visit_StatNode(self, node, is_listcontainer=False):
stacktmp = self.is_in_statlist
self.is_in_statlist = is_listcontainer
self.visitchildren(node)
self.is_in_statlist = stacktmp
if not self.is_in_statlist and not self.is_in_expr:
return StatListNode(pos=node.pos, stats=[node])
else:
return node
def visit_StatListNode(self, node):
self.is_in_statlist = True
self.visitchildren(node)
self.is_in_statlist = False
return node
def visit_ParallelAssignmentNode(self, node):
return self.visit_StatNode(node, True)
def visit_CEnumDefNode(self, node):
return self.visit_StatNode(node, True)
def visit_CStructOrUnionDefNode(self, node):
return self.visit_StatNode(node, True)
# Eliminate PassStatNode
def visit_PassStatNode(self, node):
if not self.is_in_statlist:
return StatListNode(pos=node.pos, stats=[])
else:
return []
# Eliminate CascadedAssignmentNode
def visit_CascadedAssignmentNode(self, node):
tmpname = temp_name_handle()
class PostParseError(CompileError): pass
# error strings checked by unit tests, so define them
ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
ERR_BUF_TOO_MANY = 'Too many buffer options'
ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_INT = '"%s" must be an integer'
ERR_BUF_NONNEG = '"%s" must be non-negative'
class PostParse(CythonTransform):
"""
Basic interpretation of the parse tree, as well as validity
checking that can be done on a very basic level on the parse
tree (while still not being a problem with the basic syntax,
as such).
Specifically:
- CBufferAccessTypeNode has its options interpreted:
Any first positional argument goes into the "dtype" attribute,
any "ndim" keyword argument goes into the "ndim" attribute and
so on. Also it is checked that the option combination is valid.
Note: Currently Parsing.py does a lot of interpretation and
reorganization that can be refactored into this transform
if a more pure Abstract Syntax Tree is wanted.
"""
buffer_options = ("dtype", "ndim") # ordered!
def visit_CBufferAccessTypeNode(self, node):
options = {}
# Fetch positional arguments
if len(node.positional_args) > len(self.buffer_options):
self.context.error(ERR_BUF_TOO_MANY)
for arg, unicode_name in zip(node.positional_args, self.buffer_options):
name = str(unicode_name)
options[name] = arg
# Fetch named arguments
for item in node.keyword_args.key_value_pairs:
name = str(item.key.value)
if not name in self.buffer_options:
raise PostParseError(item.key.pos,
ERR_BUF_UNKNOWN % name)
if name in options.keys():
raise PostParseError(item.key.pos,
ERR_BUF_DUP % key)
options[name] = item.value
provided = options.keys()
# get dtype
dtype = options.get("dtype")
if dtype is None: raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype')
node.dtype_node = dtype
# get ndim
if "ndim" in provided:
ndimnode = options["ndim"]
if not isinstance(ndimnode, IntNode):
# Compile-time values (DEF) are currently resolved by the parser,
# so nothing more to do here
raise PostParseError(ndimnode.pos, ERR_BUF_INT % 'ndim')
ndim_value = int(ndimnode.value)
if ndim_value < 0:
raise PostParseError(ndimnode.pos, ERR_BUF_NONNEG % 'ndim')
node.ndim = int(ndimnode.value)
else:
node.ndim = 1
# We're done with the parse tree args
node.positional_args = None
node.keyword_args = None
return node
class WithTransform(CythonTransform):
# EXCINFO is manually set to a variable that contains
# the exc_info() tuple that can be generated by the enclosing except
# statement.
template_without_target = TreeFragment(u"""
MGR = EXPR
EXIT = MGR.__exit__
MGR.__enter__()
EXC = True
try:
try:
BODY
except:
EXC = False
if not EXIT(*EXCINFO):
raise
finally:
if EXC:
EXIT(None, None, None)
""", temps=[u'MGR', u'EXC', u"EXIT", u"SYS)"],
pipeline=[NormalizeTree(None)])
template_with_target = TreeFragment(u"""
MGR = EXPR
EXIT = MGR.__exit__
VALUE = MGR.__enter__()
EXC = True
try:
try:
TARGET = VALUE
BODY
except:
EXC = False
if not EXIT(*EXCINFO):
raise
finally:
if EXC:
EXIT(None, None, None)
""", temps=[u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"],
pipeline=[NormalizeTree(None)])
def visit_WithStatNode(self, node):
excinfo_name = temp_name_handle('EXCINFO')
excinfo_namenode = NameNode(pos=node.pos, name=excinfo_name)
excinfo_target = NameNode(pos=node.pos, name=excinfo_name)
if node.target is not None:
result = self.template_with_target.substitute({
u'EXPR' : node.manager,
u'BODY' : node.body,
u'TARGET' : node.target,
u'EXCINFO' : excinfo_namenode
}, pos = node.pos)
# Set except excinfo target to EXCINFO
result.stats[4].body.stats[0].except_clauses[0].excinfo_target = excinfo_target
else:
result = self.template_without_target.substitute({
u'EXPR' : node.manager,
u'BODY' : node.body,
u'EXCINFO' : excinfo_namenode
}, pos = node.pos)
# Set except excinfo target to EXCINFO
result.stats[4].body.stats[0].except_clauses[0].excinfo_target = excinfo_target
return result.stats
class DecoratorTransform(CythonTransform):
def visit_DefNode(self, func_node):
if not func_node.decorators:
return func_node
decorator_result = NameNode(func_node.pos, name = func_node.name)
for decorator in func_node.decorators[::-1]:
decorator_result = SimpleCallNode(
decorator.pos,
function = decorator.decorator,
args = [decorator_result])
func_name_node = NameNode(func_node.pos, name = func_node.name)
reassignment = SingleAssignmentNode(
func_node.pos,
lhs = func_name_node,
rhs = decorator_result)
return [func_node, reassignment]
class AnalyseDeclarationsTransform(CythonTransform):
def __call__(self, root):
self.env_stack = [root.scope]
return super(AnalyseDeclarationsTransform, self).__call__(root)
def visit_ModuleNode(self, node):
node.analyse_declarations(self.env_stack[-1])
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
lenv = node.create_local_scope(self.env_stack[-1])
node.body.analyse_control_flow(lenv) # this will be totally refactored
node.declare_arguments(lenv)
node.body.analyse_declarations(lenv)
self.env_stack.append(lenv)
self.visitchildren(node)
self.env_stack.pop()
return node
class AnalyseExpressionsTransform(CythonTransform):
def visit_ModuleNode(self, node):
node.body.analyse_expressions(node.scope)
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
node.body.analyse_expressions(node.local_scope)
self.visitchildren(node)
return node
class MarkClosureVisitor(CythonTransform):
needs_closure = False
def visit_FuncDefNode(self, node):
self.needs_closure = False
self.visitchildren(node)
node.needs_closure = self.needs_closure
self.needs_closure = True
return node
def visit_ClassDefNode(self, node):
self.visitchildren(node)
self.needs_closure = True
return node
def visit_YieldNode(self, node):
self.needs_closure = True
class CreateClosureClasses(CythonTransform):
# Output closure classes in module scope for all functions
# that need it.
def visit_ModuleNode(self, node):
self.module_scope = node.scope
self.visitchildren(node)
return node
def create_class_from_scope(self, node, target_module_scope):
as_name = temp_name_handle("closure")
func_scope = node.local_scope
entry = target_module_scope.declare_c_class(name = as_name,
pos = node.pos, defining = True, implementing = True)
class_scope = entry.type.scope
for entry in func_scope.entries.values():
class_scope.declare_var(pos=node.pos,
name=entry.name,
cname=entry.cname,
type=entry.type,
is_cdef=True)
def visit_FuncDefNode(self, node):
self.create_class_from_scope(node, self.module_scope)
return node
...@@ -312,6 +312,7 @@ def p_call(s, function): ...@@ -312,6 +312,7 @@ def p_call(s, function):
if s.sy != ',': if s.sy != ',':
break break
s.next() s.next()
if s.sy == '*': if s.sy == '*':
s.next() s.next()
star_arg = p_simple_expr(s) star_arg = p_simple_expr(s)
...@@ -1159,13 +1160,13 @@ def p_for_from_step(s): ...@@ -1159,13 +1160,13 @@ def p_for_from_step(s):
inequality_relations = ('<', '<=', '>', '>=') inequality_relations = ('<', '<=', '>', '>=')
def p_for_target(s): def p_target(s, terminator):
pos = s.position() pos = s.position()
expr = p_bit_expr(s) expr = p_bit_expr(s)
if s.sy == ',': if s.sy == ',':
s.next() s.next()
exprs = [expr] exprs = [expr]
while s.sy != 'in': while s.sy != terminator:
exprs.append(p_bit_expr(s)) exprs.append(p_bit_expr(s))
if s.sy != ',': if s.sy != ',':
break break
...@@ -1174,6 +1175,9 @@ def p_for_target(s): ...@@ -1174,6 +1175,9 @@ def p_for_target(s):
else: else:
return expr return expr
def p_for_target(s):
return p_target(s, 'in')
def p_for_iterator(s): def p_for_iterator(s):
pos = s.position() pos = s.position()
expr = p_testlist(s) expr = p_testlist(s)
...@@ -1253,8 +1257,17 @@ def p_with_statement(s): ...@@ -1253,8 +1257,17 @@ def p_with_statement(s):
body = p_suite(s) body = p_suite(s)
return Nodes.GILStatNode(pos, state = state, body = body) return Nodes.GILStatNode(pos, state = state, body = body)
else: else:
s.error("Only 'with gil' and 'with nogil' implemented", manager = p_expr(s)
pos = pos) target = None
if s.sy == 'IDENT' and s.systring == 'as':
s.next()
allow_multi = (s.sy == '(')
target = p_target(s, ':')
if not allow_multi and isinstance(target, ExprNodes.TupleNode):
s.error("Multiple with statement target values not allowed without paranthesis")
body = p_suite(s)
return Nodes.WithStatNode(pos, manager = manager,
target = target, body = body)
def p_simple_statement(s, first_statement = 0): def p_simple_statement(s, first_statement = 0):
#print "p_simple_statement:", s.sy, s.systring ### #print "p_simple_statement:", s.sy, s.systring ###
...@@ -1359,6 +1372,14 @@ def p_statement(s, ctx, first_statement = 0): ...@@ -1359,6 +1372,14 @@ def p_statement(s, ctx, first_statement = 0):
return p_DEF_statement(s) return p_DEF_statement(s)
elif s.sy == 'IF': elif s.sy == 'IF':
return p_IF_statement(s, ctx) return p_IF_statement(s, ctx)
elif s.sy == 'DECORATOR':
if ctx.level not in ('module', 'class', 'c_class', 'property'):
s.error('decorator not allowed here')
s.level = ctx.level
decorators = p_decorators(s)
if s.sy != 'def':
s.error("Decorators can only be followed by functions ")
return p_def_statement(s, decorators)
else: else:
overridable = 0 overridable = 0
if s.sy == 'cdef': if s.sy == 'cdef':
...@@ -1447,6 +1468,71 @@ def p_suite(s, ctx = Ctx(), with_doc = 0, with_pseudo_doc = 0): ...@@ -1447,6 +1468,71 @@ def p_suite(s, ctx = Ctx(), with_doc = 0, with_pseudo_doc = 0):
else: else:
return body return body
def p_positional_and_keyword_args(s, end_sy_set, type_positions=(), type_keywords=()):
"""
Parses positional and keyword arguments. end_sy_set
should contain any s.sy that terminate the argument list.
Argument expansion (* and **) are not allowed.
type_positions and type_keywords specifies which argument
positions and/or names which should be interpreted as
types. Other arguments will be treated as expressions.
Returns: (positional_args, keyword_args)
"""
positional_args = []
keyword_args = []
pos_idx = 0
while s.sy not in end_sy_set:
if s.sy == '*' or s.sy == '**':
s.error('Argument expansion not allowed here.')
was_keyword = False
parsed_type = False
if s.sy == 'IDENT':
# Since we can have either types or expressions as positional args,
# we use a strategy of looking an extra step forward for a '=' and
# if it is a positional arg we backtrack.
ident = s.systring
s.next()
if s.sy == '=':
s.next()
# Is keyword arg
if ident in type_keywords:
arg = p_c_base_type(s)
parsed_type = True
else:
arg = p_simple_expr(s)
keyword_node = ExprNodes.IdentifierStringNode(arg.pos,
value = Utils.EncodedString(ident))
keyword_args.append((keyword_node, arg))
was_keyword = True
else:
s.put_back('IDENT', ident)
if not was_keyword:
if pos_idx in type_positions:
arg = p_c_base_type(s)
parsed_type = True
else:
arg = p_simple_expr(s)
positional_args.append(arg)
pos_idx += 1
if len(keyword_args) > 0:
s.error("Non-keyword arg following keyword arg",
pos = arg.pos)
if s.sy != ',':
if s.sy not in end_sy_set:
if parsed_type:
s.error("Expected: type")
else:
s.error("Expected: expression")
break
s.next()
return positional_args, keyword_args
def p_c_base_type(s, self_flag = 0, nonempty = 0): def p_c_base_type(s, self_flag = 0, nonempty = 0):
# If self_flag is true, this is the base type for the # If self_flag is true, this is the base type for the
# self argument of a C method of an extension type. # self argument of a C method of an extension type.
...@@ -1519,11 +1605,43 @@ def p_c_simple_base_type(s, self_flag, nonempty): ...@@ -1519,11 +1605,43 @@ def p_c_simple_base_type(s, self_flag, nonempty):
else: else:
#print "p_c_simple_base_type: not looking at type at", s.position() #print "p_c_simple_base_type: not looking at type at", s.position()
name = None name = None
return Nodes.CSimpleBaseTypeNode(pos,
type_node = Nodes.CSimpleBaseTypeNode(pos,
name = name, module_path = module_path, name = name, module_path = module_path,
is_basic_c_type = is_basic, signed = signed, is_basic_c_type = is_basic, signed = signed,
longness = longness, is_self_arg = self_flag) longness = longness, is_self_arg = self_flag)
# Treat trailing [] on type as buffer access
if 0: # s.sy == '[':
if is_basic:
s.error("Basic C types do not support buffer access")
return p_buffer_access(s, type_node)
else:
return type_node
def p_buffer_access(s, type_node):
# s.sy == '['
pos = s.position()
s.next()
positional_args, keyword_args = (
p_positional_and_keyword_args(s, (']',), (0,), ('dtype',))
)
s.expect(']')
keyword_dict = ExprNodes.DictNode(pos,
key_value_pairs = [
ExprNodes.DictItemNode(pos=key.pos, key=key, value=value)
for key, value in keyword_args
])
result = Nodes.CBufferAccessTypeNode(pos,
positional_args = positional_args,
keyword_args = keyword_dict,
base_type_node = type_node)
return result
def looking_at_type(s): def looking_at_type(s):
return looking_at_base_type(s) or s.looking_at_type_name() return looking_at_base_type(s) or s.looking_at_type_name()
...@@ -1668,7 +1786,7 @@ def p_c_simple_declarator(s, ctx, empty, is_type, cmethod_flag, ...@@ -1668,7 +1786,7 @@ def p_c_simple_declarator(s, ctx, empty, is_type, cmethod_flag,
name = "" name = ""
cname = None cname = None
result = Nodes.CNameDeclaratorNode(pos, result = Nodes.CNameDeclaratorNode(pos,
name = name, cname = cname, rhs = rhs) name = name, cname = cname, default = rhs)
result.calling_convention = calling_convention result.calling_convention = calling_convention
return result return result
...@@ -1993,7 +2111,21 @@ def p_ctypedef_statement(s, ctx): ...@@ -1993,7 +2111,21 @@ def p_ctypedef_statement(s, ctx):
declarator = declarator, visibility = visibility, declarator = declarator, visibility = visibility,
in_pxd = ctx.level == 'module_pxd') in_pxd = ctx.level == 'module_pxd')
def p_def_statement(s): def p_decorators(s):
decorators = []
while s.sy == 'DECORATOR':
pos = s.position()
s.next()
decorator = ExprNodes.NameNode(
pos, name = Utils.EncodedString(
p_dotted_name(s, as_allowed=0)[2] ))
if s.sy == '(':
decorator = p_call(s, decorator)
decorators.append(Nodes.DecoratorNode(pos, decorator=decorator))
s.expect_newline("Expected a newline after decorator")
return decorators
def p_def_statement(s, decorators=None):
# s.sy == 'def' # s.sy == 'def'
pos = s.position() pos = s.position()
s.next() s.next()
...@@ -2022,7 +2154,7 @@ def p_def_statement(s): ...@@ -2022,7 +2154,7 @@ def p_def_statement(s):
doc, body = p_suite(s, Ctx(level = 'function'), with_doc = 1) doc, body = p_suite(s, Ctx(level = 'function'), with_doc = 1)
return Nodes.DefNode(pos, name = name, args = args, return Nodes.DefNode(pos, name = name, args = args,
star_arg = star_arg, starstar_arg = starstar_arg, star_arg = star_arg, starstar_arg = starstar_arg,
doc = doc, body = body) doc = doc, body = body, decorators = decorators)
def p_py_arg_decl(s): def p_py_arg_decl(s):
pos = s.position() pos = s.position()
......
...@@ -4,6 +4,23 @@ ...@@ -4,6 +4,23 @@
from Cython import Utils from Cython import Utils
import Naming import Naming
import copy
class BufferOptions:
# dtype PyrexType
# ndim int
def __init__(self, dtype, ndim):
self.dtype = dtype
self.ndim = ndim
def create_buffer_type(base_type, buffer_options):
# Make a shallow copy of base_type and then annotate it
# with the buffer information
result = copy.copy(base_type)
result.buffer_options = buffer_options
return result
class BaseType: class BaseType:
# #
...@@ -44,6 +61,7 @@ class PyrexType(BaseType): ...@@ -44,6 +61,7 @@ class PyrexType(BaseType):
# default_value string Initial value # default_value string Initial value
# parsetuple_format string Format char for PyArg_ParseTuple # parsetuple_format string Format char for PyArg_ParseTuple
# pymemberdef_typecode string Type code for PyMemberDef struct # pymemberdef_typecode string Type code for PyMemberDef struct
# typestring string String char defining the type (see Python struct module)
# #
# declaration_code(entity_code, # declaration_code(entity_code,
# for_display = 0, dll_linkage = None, pyrex = 0) # for_display = 0, dll_linkage = None, pyrex = 0)
...@@ -92,6 +110,7 @@ class PyrexType(BaseType): ...@@ -92,6 +110,7 @@ class PyrexType(BaseType):
default_value = "" default_value = ""
parsetuple_format = "" parsetuple_format = ""
pymemberdef_typecode = None pymemberdef_typecode = None
buffer_options = None # can contain a BufferOptions instance
def resolve(self): def resolve(self):
# If a typedef, returns the base type. # If a typedef, returns the base type.
...@@ -183,7 +202,6 @@ class CTypedefType(BaseType): ...@@ -183,7 +202,6 @@ class CTypedefType(BaseType):
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.typedef_base_type, name) return getattr(self.typedef_base_type, name)
class PyObjectType(PyrexType): class PyObjectType(PyrexType):
# #
# Base class for all Python object types (reference-counted). # Base class for all Python object types (reference-counted).
...@@ -399,9 +417,10 @@ class CNumericType(CType): ...@@ -399,9 +417,10 @@ class CNumericType(CType):
sign_words = ("unsigned ", "", "signed ") sign_words = ("unsigned ", "", "signed ")
def __init__(self, rank, signed = 1, pymemberdef_typecode = None): def __init__(self, rank, signed = 1, pymemberdef_typecode = None, typestring = None):
self.rank = rank self.rank = rank
self.signed = signed self.signed = signed
self.typestring = typestring
ptf = self.parsetuple_formats[signed][rank] ptf = self.parsetuple_formats[signed][rank]
if ptf == '?': if ptf == '?':
ptf = None ptf = None
...@@ -434,8 +453,9 @@ class CIntType(CNumericType): ...@@ -434,8 +453,9 @@ class CIntType(CNumericType):
from_py_function = "__pyx_PyInt_AsLong" from_py_function = "__pyx_PyInt_AsLong"
exception_value = -1 exception_value = -1
def __init__(self, rank, signed, pymemberdef_typecode = None, is_returncode = 0): def __init__(self, rank, signed, pymemberdef_typecode = None, is_returncode = 0,
CNumericType.__init__(self, rank, signed, pymemberdef_typecode) typestring=None):
CNumericType.__init__(self, rank, signed, pymemberdef_typecode, typestring=typestring)
self.is_returncode = is_returncode self.is_returncode = is_returncode
if self.from_py_function == '__pyx_PyInt_AsLong': if self.from_py_function == '__pyx_PyInt_AsLong':
self.from_py_function = self.get_type_conversion() self.from_py_function = self.get_type_conversion()
...@@ -526,8 +546,8 @@ class CFloatType(CNumericType): ...@@ -526,8 +546,8 @@ class CFloatType(CNumericType):
to_py_function = "PyFloat_FromDouble" to_py_function = "PyFloat_FromDouble"
from_py_function = "__pyx_PyFloat_AsDouble" from_py_function = "__pyx_PyFloat_AsDouble"
def __init__(self, rank, pymemberdef_typecode = None): def __init__(self, rank, pymemberdef_typecode = None, typestring=None):
CNumericType.__init__(self, rank, 1, pymemberdef_typecode) CNumericType.__init__(self, rank, 1, pymemberdef_typecode, typestring = typestring)
def assignable_from_resolved_type(self, src_type): def assignable_from_resolved_type(self, src_type):
return src_type.is_numeric or src_type is error_type return src_type.is_numeric or src_type is error_type
...@@ -835,8 +855,11 @@ class CFuncTypeArg: ...@@ -835,8 +855,11 @@ class CFuncTypeArg:
# type PyrexType # type PyrexType
# pos source file position # pos source file position
def __init__(self, name, type, pos): def __init__(self, name, type, pos, cname=None):
self.name = name self.name = name
if cname is not None:
self.cname = cname
else:
self.cname = Naming.var_prefix + name self.cname = Naming.var_prefix + name
self.type = type self.type = type
self.pos = pos self.pos = pos
...@@ -1033,29 +1056,29 @@ c_void_type = CVoidType() ...@@ -1033,29 +1056,29 @@ c_void_type = CVoidType()
c_void_ptr_type = CPtrType(c_void_type) c_void_ptr_type = CPtrType(c_void_type)
c_void_ptr_ptr_type = CPtrType(c_void_ptr_type) c_void_ptr_ptr_type = CPtrType(c_void_ptr_type)
c_uchar_type = CIntType(0, 0, "T_UBYTE") c_uchar_type = CIntType(0, 0, "T_UBYTE", typestring="B")
c_ushort_type = CIntType(1, 0, "T_USHORT") c_ushort_type = CIntType(1, 0, "T_USHORT", typestring="H")
c_uint_type = CUIntType(2, 0, "T_UINT") c_uint_type = CUIntType(2, 0, "T_UINT", typestring="I")
c_ulong_type = CULongType(3, 0, "T_ULONG") c_ulong_type = CULongType(3, 0, "T_ULONG", typestring="L")
c_ulonglong_type = CULongLongType(4, 0, "T_ULONGLONG") c_ulonglong_type = CULongLongType(4, 0, "T_ULONGLONG", typestring="Q")
c_char_type = CIntType(0, 1, "T_CHAR") c_char_type = CIntType(0, 1, "T_CHAR", typestring="b")
c_short_type = CIntType(1, 1, "T_SHORT") c_short_type = CIntType(1, 1, "T_SHORT", typestring="h")
c_int_type = CIntType(2, 1, "T_INT") c_int_type = CIntType(2, 1, "T_INT", typestring="i")
c_long_type = CIntType(3, 1, "T_LONG") c_long_type = CIntType(3, 1, "T_LONG", typestring="l")
c_longlong_type = CLongLongType(4, 1, "T_LONGLONG") c_longlong_type = CLongLongType(4, 1, "T_LONGLONG", typestring="q")
c_py_ssize_t_type = CPySSizeTType(5, 1) c_py_ssize_t_type = CPySSizeTType(5, 1)
c_bint_type = CBIntType(2, 1, "T_INT") c_bint_type = CBIntType(2, 1, "T_INT", typestring="i")
c_schar_type = CIntType(0, 2, "T_CHAR") c_schar_type = CIntType(0, 2, "T_CHAR", typestring="b")
c_sshort_type = CIntType(1, 2, "T_SHORT") c_sshort_type = CIntType(1, 2, "T_SHORT", typestring="h")
c_sint_type = CIntType(2, 2, "T_INT") c_sint_type = CIntType(2, 2, "T_INT", typestring="i")
c_slong_type = CIntType(3, 2, "T_LONG") c_slong_type = CIntType(3, 2, "T_LONG", typestring="l")
c_slonglong_type = CLongLongType(4, 2, "T_LONGLONG") c_slonglong_type = CLongLongType(4, 2, "T_LONGLONG", typestring="q")
c_float_type = CFloatType(6, "T_FLOAT") c_float_type = CFloatType(6, "T_FLOAT", typestring="f")
c_double_type = CFloatType(7, "T_DOUBLE") c_double_type = CFloatType(7, "T_DOUBLE", typestring="d")
c_longdouble_type = CFloatType(8) c_longdouble_type = CFloatType(8, typestring="g")
c_null_ptr_type = CNullPtrType(c_void_type) c_null_ptr_type = CNullPtrType(c_void_type)
c_char_array_type = CCharArrayType(None) c_char_array_type = CCharArrayType(None)
...@@ -1070,7 +1093,8 @@ c_returncode_type = CIntType(2, 1, "T_INT", is_returncode = 1) ...@@ -1070,7 +1093,8 @@ c_returncode_type = CIntType(2, 1, "T_INT", is_returncode = 1)
c_anon_enum_type = CAnonEnumType(-1, 1) c_anon_enum_type = CAnonEnumType(-1, 1)
# the Py_buffer type is defined in Builtin.py # the Py_buffer type is defined in Builtin.py
c_py_buffer_ptr_type = CPtrType(CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer")) c_py_buffer_type = CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer")
c_py_buffer_ptr_type = CPtrType(c_py_buffer_type)
error_type = ErrorType() error_type = ErrorType()
......
# #
# Pyrex - Symbol Table # Symbol Table
# #
import re import re
...@@ -19,6 +19,16 @@ import __builtin__ ...@@ -19,6 +19,16 @@ import __builtin__
possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
class BufferAux:
def __init__(self, buffer_info_var, stridevars, shapevars, tschecker):
self.buffer_info_var = buffer_info_var
self.stridevars = stridevars
self.shapevars = shapevars
self.tschecker = tschecker
def __repr__(self):
return "<BufferAux %r>" % self.__dict__
class Entry: class Entry:
# A symbol table entry in a Scope or ModuleNamespace. # A symbol table entry in a Scope or ModuleNamespace.
# #
...@@ -47,6 +57,7 @@ class Entry: ...@@ -47,6 +57,7 @@ class Entry:
# is_self_arg boolean Is the "self" arg of an exttype method # is_self_arg boolean Is the "self" arg of an exttype method
# is_arg boolean Is the arg of a method # is_arg boolean Is the arg of a method
# is_local boolean Is a local variable # is_local boolean Is a local variable
# in_closure boolean Is referenced in an inner scope
# is_readonly boolean Can't be assigned to # is_readonly boolean Can't be assigned to
# func_cname string C func implementing Python func # func_cname string C func implementing Python func
# pos position Source position where declared # pos position Source position where declared
...@@ -75,6 +86,8 @@ class Entry: ...@@ -75,6 +86,8 @@ class Entry:
# defined_in_pxd boolean Is defined in a .pxd file (not just declared) # defined_in_pxd boolean Is defined in a .pxd file (not just declared)
# api boolean Generate C API for C class or function # api boolean Generate C API for C class or function
# utility_code string Utility code needed when this entry is used # utility_code string Utility code needed when this entry is used
#
# buffer_aux BufferAux or None Extra information needed for buffer variables
borrowed = 0 borrowed = 0
init = "" init = ""
...@@ -96,6 +109,7 @@ class Entry: ...@@ -96,6 +109,7 @@ class Entry:
is_self_arg = 0 is_self_arg = 0
is_arg = 0 is_arg = 0
is_local = 0 is_local = 0
in_closure = 0
is_declared_generic = 0 is_declared_generic = 0
is_readonly = 0 is_readonly = 0
func_cname = None func_cname = None
...@@ -115,6 +129,7 @@ class Entry: ...@@ -115,6 +129,7 @@ class Entry:
api = 0 api = 0
utility_code = None utility_code = None
is_overridable = 0 is_overridable = 0
buffer_aux = None
def __init__(self, name, cname, type, pos = None, init = None): def __init__(self, name, cname, type, pos = None, init = None):
self.name = name self.name = name
...@@ -163,6 +178,8 @@ class Scope: ...@@ -163,6 +178,8 @@ class Scope:
in_cinclude = 0 in_cinclude = 0
nogil = 0 nogil = 0
temp_prefix = Naming.pyrex_prefix
def __init__(self, name, outer_scope, parent_scope): def __init__(self, name, outer_scope, parent_scope):
# The outer_scope is the next scope in the lookup chain. # The outer_scope is the next scope in the lookup chain.
# The parent_scope is used to derive the qualified name of this scope. # The parent_scope is used to derive the qualified name of this scope.
...@@ -447,7 +464,14 @@ class Scope: ...@@ -447,7 +464,14 @@ class Scope:
# Look up name in this scope or an enclosing one. # Look up name in this scope or an enclosing one.
# Return None if not found. # Return None if not found.
return (self.lookup_here(name) return (self.lookup_here(name)
or (self.outer_scope and self.outer_scope.lookup(name)) or (self.outer_scope and self.outer_scope.lookup_from_inner(name))
or None)
def lookup_from_inner(self, name):
# Look up name in this scope or an enclosing one.
# This is only called from enclosing scopes.
return (self.lookup_here(name)
or (self.outer_scope and self.outer_scope.lookup_from_inner(name))
or None) or None)
def lookup_here(self, name): def lookup_here(self, name):
...@@ -562,7 +586,7 @@ class Scope: ...@@ -562,7 +586,7 @@ class Scope:
return entry.cname return entry.cname
n = self.temp_counter n = self.temp_counter
self.temp_counter = n + 1 self.temp_counter = n + 1
cname = "%s%d" % (Naming.pyrex_prefix, n) cname = "%s%d" % (self.temp_prefix, n)
entry = Entry("", cname, type) entry = Entry("", cname, type)
entry.used = 1 entry.used = 1
if type.is_pyobject or type == PyrexTypes.c_py_ssize_t_type: if type.is_pyobject or type == PyrexTypes.c_py_ssize_t_type:
...@@ -608,6 +632,9 @@ class Scope: ...@@ -608,6 +632,9 @@ class Scope:
return 0 return 0
class PreImportScope(Scope): class PreImportScope(Scope):
namespace_cname = Naming.preimport_cname
def __init__(self): def __init__(self):
Scope.__init__(self, Options.pre_import, None, None) Scope.__init__(self, Options.pre_import, None, None)
...@@ -615,7 +642,6 @@ class PreImportScope(Scope): ...@@ -615,7 +642,6 @@ class PreImportScope(Scope):
entry = self.declare(name, name, py_object_type, pos) entry = self.declare(name, name, py_object_type, pos)
entry.is_variable = True entry.is_variable = True
entry.is_pyglobal = True entry.is_pyglobal = True
entry.namespace_cname = Naming.preimport_cname
return entry return entry
...@@ -761,6 +787,7 @@ class ModuleScope(Scope): ...@@ -761,6 +787,7 @@ class ModuleScope(Scope):
self.has_extern_class = 0 self.has_extern_class = 0
self.cached_builtins = [] self.cached_builtins = []
self.undeclared_cached_builtins = [] self.undeclared_cached_builtins = []
self.namespace_cname = self.module_cname
def qualifying_scope(self): def qualifying_scope(self):
return self.parent_module return self.parent_module
...@@ -876,7 +903,6 @@ class ModuleScope(Scope): ...@@ -876,7 +903,6 @@ class ModuleScope(Scope):
raise InternalError( raise InternalError(
"Non-cdef global variable is not a generic Python object") "Non-cdef global variable is not a generic Python object")
entry.is_pyglobal = 1 entry.is_pyglobal = 1
entry.namespace_cname = self.module_cname
else: else:
entry.is_cglobal = 1 entry.is_cglobal = 1
self.var_entries.append(entry) self.var_entries.append(entry)
...@@ -1075,7 +1101,6 @@ class ModuleScope(Scope): ...@@ -1075,7 +1101,6 @@ class ModuleScope(Scope):
var_entry.is_readonly = 1 var_entry.is_readonly = 1
entry.as_variable = var_entry entry.as_variable = var_entry
class LocalScope(Scope): class LocalScope(Scope):
def __init__(self, name, outer_scope): def __init__(self, name, outer_scope):
...@@ -1119,6 +1144,33 @@ class LocalScope(Scope): ...@@ -1119,6 +1144,33 @@ class LocalScope(Scope):
entry = self.global_scope().lookup_target(name) entry = self.global_scope().lookup_target(name)
self.entries[name] = entry self.entries[name] = entry
def lookup_from_inner(self, name):
entry = self.lookup_here(name)
if entry:
entry.in_closure = 1
return entry
else:
return (self.outer_scope and self.outer_scope.lookup_from_inner(name)) or None
def mangle_closure_cnames(self, scope_var):
for entry in self.entries.values():
if entry.in_closure:
if not hasattr(entry, 'orig_cname'):
entry.orig_cname = entry.cname
entry.cname = scope_var + "->" + entry.cname
class GeneratorLocalScope(LocalScope):
temp_prefix = Naming.cur_scope_cname + "->" + LocalScope.temp_prefix
def mangle_closure_cnames(self, scope_var):
for entry in self.entries.values() + self.temp_entries:
entry.in_closure = 1
LocalScope.mangle_closure_cnames(self, scope_var)
# def mangle(self, prefix, name):
# return "%s->%s" % (Naming.scope_obj_cname, name)
class StructOrUnionScope(Scope): class StructOrUnionScope(Scope):
# Namespace of a C struct or union. # Namespace of a C struct or union.
...@@ -1198,7 +1250,6 @@ class PyClassScope(ClassScope): ...@@ -1198,7 +1250,6 @@ class PyClassScope(ClassScope):
entry = Scope.declare_var(self, name, type, pos, entry = Scope.declare_var(self, name, type, pos,
cname, visibility, is_cdef) cname, visibility, is_cdef)
entry.is_pyglobal = 1 entry.is_pyglobal = 1
entry.namespace_cname = self.class_obj_cname
return entry return entry
def allocate_temp(self, type): def allocate_temp(self, type):
...@@ -1295,7 +1346,7 @@ class CClassScope(ClassScope): ...@@ -1295,7 +1346,7 @@ class CClassScope(ClassScope):
entry.is_pyglobal = 1 # xxx: is_pyglobal changes behaviour in so many places that entry.is_pyglobal = 1 # xxx: is_pyglobal changes behaviour in so many places that
# I keep it in for now. is_member should be enough # I keep it in for now. is_member should be enough
# later on # later on
entry.namespace_cname = "(PyObject *)%s" % self.parent_type.typeptr_cname self.namespace_cname = "(PyObject *)%s" % self.parent_type.typeptr_cname
entry.interned_cname = self.intern_identifier(name) entry.interned_cname = self.intern_identifier(name)
return entry return entry
......
from Cython.TestUtils import CythonTest
import Cython.Compiler.Errors as Errors
from Cython.Compiler.Nodes import *
from Cython.Compiler.ParseTreeTransforms import *
class TestBufferParsing(CythonTest):
# First, we only test the raw parser, i.e.
# the number and contents of arguments are NOT checked.
# However "dtype"/the first positional argument is special-cased
# to parse a type argument rather than an expression
def parse(self, s):
return self.should_not_fail(lambda: self.fragment(s)).root
def not_parseable(self, expected_error, s):
e = self.should_fail(lambda: self.fragment(s), Errors.CompileError)
self.assertEqual(expected_error, e.message_only)
def test_basic(self):
t = self.parse(u"cdef object[float, 4, ndim=2, foo=foo] x")
bufnode = t.stats[0].base_type
self.assert_(isinstance(bufnode, CBufferAccessTypeNode))
self.assertEqual(2, len(bufnode.positional_args))
# print bufnode.dump()
# should put more here...
def test_type_fail(self):
self.not_parseable("Expected: type",
u"cdef object[2] x")
def test_type_pos(self):
self.parse(u"cdef object[short unsigned int, 3] x")
def test_type_keyword(self):
self.parse(u"cdef object[foo=foo, dtype=short unsigned int] x")
def test_notype_as_expr1(self):
self.not_parseable("Expected: expression",
u"cdef object[foo2=short unsigned int] x")
def test_notype_as_expr2(self):
self.not_parseable("Expected: expression",
u"cdef object[int, short unsigned int] x")
def test_pos_after_key(self):
self.not_parseable("Non-keyword arg following keyword arg",
u"cdef object[foo=1, 2] x")
class TestBufferOptions(CythonTest):
# Tests the full parsing of the options within the brackets
def parse_opts(self, opts):
s = u"cdef object[%s] x" % opts
root = self.fragment(s, pipeline=[PostParse(self)]).root
buftype = root.stats[0].base_type
self.assert_(isinstance(buftype, CBufferAccessTypeNode))
self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode))
self.assertEqual(u"object", buftype.base_type_node.name)
return buftype
def non_parse(self, expected_err, opts):
e = self.should_fail(lambda: self.parse_opts(opts))
self.assertEqual(expected_err, e.message_only)
def test_basic(self):
buf = self.parse_opts(u"unsigned short int, 3")
self.assert_(isinstance(buf.dtype_node, CSimpleBaseTypeNode))
self.assert_(buf.dtype_node.signed == 0 and buf.dtype_node.longness == -1)
self.assertEqual(3, buf.ndim)
def test_dict(self):
buf = self.parse_opts(u"ndim=3, dtype=unsigned short int")
self.assert_(isinstance(buf.dtype_node, CSimpleBaseTypeNode))
self.assert_(buf.dtype_node.signed == 0 and buf.dtype_node.longness == -1)
self.assertEqual(3, buf.ndim)
def test_dtype(self):
self.non_parse(ERR_BUF_MISSING % 'dtype', u"")
def test_ndim(self):
self.parse_opts(u"int, 2")
self.non_parse(ERR_BUF_INT % 'ndim', u"int, 'a'")
self.non_parse(ERR_BUF_NONNEG % 'ndim', u"int, -34")
def test_use_DEF(self):
t = self.fragment(u"""
DEF ndim = 3
cdef object[int, ndim] x
cdef object[ndim=ndim, dtype=int] y
""", pipeline=[PostParse(self)]).root
self.assert_(t.stats[1].base_type.ndim == 3)
self.assert_(t.stats[2].base_type.ndim == 3)
# add exotic and impossible combinations as they come along
import unittest
from Cython.TestUtils import TransformTest
from Cython.Compiler.ParseTreeTransforms import DecoratorTransform
class TestDecorator(TransformTest):
def test_decorator(self):
t = self.run_pipeline([DecoratorTransform(None)], u"""
def decorator(fun):
return fun
@decorator
def decorated():
pass
""")
self.assertCode(u"""
def decorator(fun):
return fun
def decorated():
pass
decorated = decorator(decorated)
""", t)
if __name__ == '__main__':
unittest.main()
from Cython.TestUtils import TransformTest
from Cython.Compiler.ParseTreeTransforms import *
from Cython.Compiler.Nodes import *
class TestNormalizeTree(TransformTest):
def test_parserbehaviour_is_what_we_coded_for(self):
t = self.fragment(u"if x: y").root
self.assertLines(u"""
(root): StatListNode
stats[0]: IfStatNode
if_clauses[0]: IfClauseNode
condition: NameNode
body: ExprStatNode
expr: NameNode
""", self.treetypes(t))
def test_wrap_singlestat(self):
t = self.run_pipeline([NormalizeTree(None)], u"if x: y")
self.assertLines(u"""
(root): StatListNode
stats[0]: IfStatNode
if_clauses[0]: IfClauseNode
condition: NameNode
body: StatListNode
stats[0]: ExprStatNode
expr: NameNode
""", self.treetypes(t))
def test_wrap_multistat(self):
t = self.run_pipeline([NormalizeTree(None)], u"""
if z:
x
y
""")
self.assertLines(u"""
(root): StatListNode
stats[0]: IfStatNode
if_clauses[0]: IfClauseNode
condition: NameNode
body: StatListNode
stats[0]: ExprStatNode
expr: NameNode
stats[1]: ExprStatNode
expr: NameNode
""", self.treetypes(t))
def test_statinexpr(self):
t = self.run_pipeline([NormalizeTree(None)], u"""
a, b = x, y
""")
self.assertLines(u"""
(root): StatListNode
stats[0]: ParallelAssignmentNode
stats[0]: SingleAssignmentNode
lhs: NameNode
rhs: NameNode
stats[1]: SingleAssignmentNode
lhs: NameNode
rhs: NameNode
""", self.treetypes(t))
def test_wrap_offagain(self):
t = self.run_pipeline([NormalizeTree(None)], u"""
x
y
if z:
x
""")
self.assertLines(u"""
(root): StatListNode
stats[0]: ExprStatNode
expr: NameNode
stats[1]: ExprStatNode
expr: NameNode
stats[2]: IfStatNode
if_clauses[0]: IfClauseNode
condition: NameNode
body: StatListNode
stats[0]: ExprStatNode
expr: NameNode
""", self.treetypes(t))
def test_pass_eliminated(self):
t = self.run_pipeline([NormalizeTree(None)], u"pass")
self.assert_(len(t.stats) == 0)
class TestWithTransform(TransformTest):
def test_simplified(self):
t = self.run_pipeline([WithTransform(None)], u"""
with x:
y = z ** 3
""")
self.assertCode(u"""
$MGR = x
$EXIT = $MGR.__exit__
$MGR.__enter__()
$EXC = True
try:
try:
y = z ** 3
except:
$EXC = False
if (not $EXIT($EXCINFO)):
raise
finally:
if $EXC:
$EXIT(None, None, None)
""", t)
def test_basic(self):
t = self.run_pipeline([WithTransform(None)], u"""
with x as y:
y = z ** 3
""")
self.assertCode(u"""
$MGR = x
$EXIT = $MGR.__exit__
$VALUE = $MGR.__enter__()
$EXC = True
try:
try:
y = $VALUE
y = z ** 3
except:
$EXC = False
if (not $EXIT($EXCINFO)):
raise
finally:
if $EXC:
$EXIT(None, None, None)
""", t)
if __name__ == "__main__":
import unittest
unittest.main()
from Cython.TestUtils import CythonTest from Cython.TestUtils import CythonTest
from Cython.Compiler.TreeFragment import * from Cython.Compiler.TreeFragment import *
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
import Cython.Compiler.Naming as Naming
class TestTreeFragments(CythonTest): class TestTreeFragments(CythonTest):
def test_basic(self): def test_basic(self):
F = self.fragment(u"x = 4") F = self.fragment(u"x = 4")
T = F.copy() T = F.copy()
...@@ -12,15 +14,15 @@ class TestTreeFragments(CythonTest): ...@@ -12,15 +14,15 @@ class TestTreeFragments(CythonTest):
F = self.fragment(u"if True: x = 4") F = self.fragment(u"if True: x = 4")
T1 = F.root T1 = F.root
T2 = F.copy() T2 = F.copy()
self.assertEqual("x", T2.body.if_clauses[0].body.lhs.name) self.assertEqual("x", T2.stats[0].if_clauses[0].body.lhs.name)
T2.body.if_clauses[0].body.lhs.name = "other" T2.stats[0].if_clauses[0].body.lhs.name = "other"
self.assertEqual("x", T1.body.if_clauses[0].body.lhs.name) self.assertEqual("x", T1.stats[0].if_clauses[0].body.lhs.name)
def test_substitutions_are_copied(self): def test_substitutions_are_copied(self):
T = self.fragment(u"y + y").substitute({"y": NameNode(pos=None, name="x")}) T = self.fragment(u"y + y").substitute({"y": NameNode(pos=None, name="x")})
self.assertEqual("x", T.body.expr.operand1.name) self.assertEqual("x", T.stats[0].expr.operand1.name)
self.assertEqual("x", T.body.expr.operand2.name) self.assertEqual("x", T.stats[0].expr.operand2.name)
self.assert_(T.body.expr.operand1 is not T.body.expr.operand2) self.assert_(T.stats[0].expr.operand1 is not T.stats[0].expr.operand2)
def test_substitution(self): def test_substitution(self):
F = self.fragment(u"x = 4") F = self.fragment(u"x = 4")
...@@ -32,7 +34,7 @@ class TestTreeFragments(CythonTest): ...@@ -32,7 +34,7 @@ class TestTreeFragments(CythonTest):
F = self.fragment(u"PASS") F = self.fragment(u"PASS")
pass_stat = PassStatNode(pos=None) pass_stat = PassStatNode(pos=None)
T = F.substitute({"PASS" : pass_stat}) T = F.substitute({"PASS" : pass_stat})
self.assert_(isinstance(T.body, PassStatNode), T.body) self.assert_(isinstance(T.stats[0], PassStatNode), T)
def test_pos_is_transferred(self): def test_pos_is_transferred(self):
F = self.fragment(u""" F = self.fragment(u"""
...@@ -40,11 +42,22 @@ class TestTreeFragments(CythonTest): ...@@ -40,11 +42,22 @@ class TestTreeFragments(CythonTest):
x = u * v ** w x = u * v ** w
""") """)
T = F.substitute({"v" : NameNode(pos=None, name="a")}) T = F.substitute({"v" : NameNode(pos=None, name="a")})
v = F.root.body.stats[1].rhs.operand2.operand1 v = F.root.stats[1].rhs.operand2.operand1
a = T.body.stats[1].rhs.operand2.operand1 a = T.stats[1].rhs.operand2.operand1
self.assertEquals(v.pos, a.pos) self.assertEquals(v.pos, a.pos)
def test_temps(self):
import Cython.Compiler.Visitor as v
v.tmpnamectr = 0
F = self.fragment(u"""
TMP
x = TMP
""")
T = F.substitute(temps=[u"TMP"])
s = T.stats
self.assert_(s[0].expr.name == Naming.temp_prefix + u"1_TMP", s[0].expr.name)
self.assert_(s[1].rhs.name == Naming.temp_prefix + u"1_TMP")
self.assert_(s[0].expr.name != u"TMP")
if __name__ == "__main__": if __name__ == "__main__":
import unittest import unittest
......
...@@ -6,8 +6,8 @@ import re ...@@ -6,8 +6,8 @@ import re
from cStringIO import StringIO from cStringIO import StringIO
from Scanning import PyrexScanner, StringSourceDescriptor from Scanning import PyrexScanner, StringSourceDescriptor
from Symtab import BuiltinScope, ModuleScope from Symtab import BuiltinScope, ModuleScope
from Visitor import VisitorTransform from Visitor import VisitorTransform, temp_name_handle
from Nodes import Node from Nodes import Node, StatListNode
from ExprNodes import NameNode from ExprNodes import NameNode
import Parsing import Parsing
import Main import Main
...@@ -53,7 +53,7 @@ def parse_from_strings(name, code, pxds={}): ...@@ -53,7 +53,7 @@ def parse_from_strings(name, code, pxds={}):
buf = StringIO(code.encode(encoding)) buf = StringIO(code.encode(encoding))
scanner = PyrexScanner(buf, code_source, source_encoding = encoding, scanner = PyrexScanner(buf, code_source, source_encoding = encoding,
type_names = scope.type_names, context = context) scope = scope, context = context)
tree = Parsing.p_module(scanner, 0, module_name) tree = Parsing.p_module(scanner, 0, module_name)
return tree return tree
...@@ -93,25 +93,54 @@ class TemplateTransform(VisitorTransform): ...@@ -93,25 +93,54 @@ class TemplateTransform(VisitorTransform):
same way. It is the responsibility of the caller to make sure same way. It is the responsibility of the caller to make sure
that the replacement nodes is a valid expression. that the replacement nodes is a valid expression.
Also a list "temps" should be passed. Any names listed will
be transformed into anonymous, temporary names.
Currently supported for tempnames is:
NameNode
(various function and class definition nodes etc. should be added to this)
Each replacement node gets the position of the substituted node Each replacement node gets the position of the substituted node
recursively applied to every member node. recursively applied to every member node.
""" """
def __call__(self, node, substitutions, temps, pos):
self.substitutions = substitutions
tempdict = {}
for key in temps:
tempdict[key] = temp_name_handle(key)
self.temps = tempdict
self.pos = pos
return super(TemplateTransform, self).__call__(node)
def visit_Node(self, node): def visit_Node(self, node):
if node is None: if node is None:
return node return None
else: else:
c = node.clone_node() c = node.clone_node()
if self.pos is not None:
c.pos = self.pos
self.visitchildren(c) self.visitchildren(c)
return c return c
def try_substitution(self, node, key): def try_substitution(self, node, key):
sub = self.substitutions.get(key) sub = self.substitutions.get(key)
if sub is None: if sub is not None:
return self.visit_Node(node) # make copy as usual pos = self.pos
if pos is None: pos = node.pos
return ApplyPositionAndCopy(pos)(sub)
else: else:
return ApplyPositionAndCopy(node.pos)(sub) return self.visit_Node(node) # make copy as usual
def visit_NameNode(self, node): def visit_NameNode(self, node):
tempname = self.temps.get(node.name)
if tempname is not None:
# Replace name with temporary
node.name = tempname
return self.visit_Node(node)
else:
return self.try_substitution(node, node.name) return self.try_substitution(node, node.name)
def visit_ExprStatNode(self, node): def visit_ExprStatNode(self, node):
...@@ -122,10 +151,6 @@ class TemplateTransform(VisitorTransform): ...@@ -122,10 +151,6 @@ class TemplateTransform(VisitorTransform):
else: else:
return self.visit_Node(node) return self.visit_Node(node)
def __call__(self, node, substitutions):
self.substitutions = substitutions
return super(TemplateTransform, self).__call__(node)
def copy_code_tree(node): def copy_code_tree(node):
return TreeCopier()(node) return TreeCopier()(node)
...@@ -133,12 +158,12 @@ INDENT_RE = re.compile(ur"^ *") ...@@ -133,12 +158,12 @@ INDENT_RE = re.compile(ur"^ *")
def strip_common_indent(lines): def strip_common_indent(lines):
"Strips empty lines and common indentation from the list of strings given in lines" "Strips empty lines and common indentation from the list of strings given in lines"
lines = [x for x in lines if x.strip() != u""] lines = [x for x in lines if x.strip() != u""]
minindent = min(len(INDENT_RE.match(x).group(0)) for x in lines) minindent = min([len(INDENT_RE.match(x).group(0)) for x in lines])
lines = [x[minindent:] for x in lines] lines = [x[minindent:] for x in lines]
return lines return lines
class TreeFragment(object): class TreeFragment(object):
def __init__(self, code, name, pxds={}): def __init__(self, code, name="(tree fragment)", pxds={}, temps=[], pipeline=[]):
if isinstance(code, unicode): if isinstance(code, unicode):
def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n"))) def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n")))
...@@ -147,18 +172,28 @@ class TreeFragment(object): ...@@ -147,18 +172,28 @@ class TreeFragment(object):
for key, value in pxds.iteritems(): for key, value in pxds.iteritems():
fmt_pxds[key] = fmt(value) fmt_pxds[key] = fmt(value)
self.root = parse_from_strings(name, fmt_code, fmt_pxds) t = parse_from_strings(name, fmt_code, fmt_pxds)
mod = t
t = t.body # Make sure a StatListNode is at the top
if not isinstance(t, StatListNode):
t = StatListNode(pos=mod.pos, stats=[t])
for transform in pipeline:
t = transform(t)
self.root = t
elif isinstance(code, Node): elif isinstance(code, Node):
if pxds != {}: raise NotImplementedError() if pxds != {}: raise NotImplementedError()
self.root = code self.root = code
else: else:
raise ValueError("Unrecognized code format (accepts unicode and Node)") raise ValueError("Unrecognized code format (accepts unicode and Node)")
self.temps = temps
def copy(self): def copy(self):
return copy_code_tree(self.root) return copy_code_tree(self.root)
def substitute(self, nodes={}): def substitute(self, nodes={}, temps=[], pos = None):
return TemplateTransform()(self.root, substitutions = nodes) return TemplateTransform()(self.root,
substitutions = nodes,
temps = self.temps + temps, pos = pos)
......
# #
# Tree visitor and transform framework # Tree visitor and transform framework
# #
import inspect
import Nodes import Nodes
import ExprNodes import ExprNodes
import inspect import Naming
from Cython.Utils import EncodedString
class BasicVisitor(object): class BasicVisitor(object):
"""A generic visitor base class which can be used for visiting any kind of object.""" """A generic visitor base class which can be used for visiting any kind of object."""
...@@ -129,7 +131,6 @@ class VisitorTransform(TreeVisitor): ...@@ -129,7 +131,6 @@ class VisitorTransform(TreeVisitor):
was not, an exception will be raised. (Typically you want to ensure that you was not, an exception will be raised. (Typically you want to ensure that you
are within a StatListNode or similar before doing this.) are within a StatListNode or similar before doing this.)
""" """
def visitchildren(self, parent, attrs=None): def visitchildren(self, parent, attrs=None):
result = super(VisitorTransform, self).visitchildren(parent, attrs) result = super(VisitorTransform, self).visitchildren(parent, attrs)
for attr, newnode in result.iteritems(): for attr, newnode in result.iteritems():
...@@ -150,6 +151,19 @@ class VisitorTransform(TreeVisitor): ...@@ -150,6 +151,19 @@ class VisitorTransform(TreeVisitor):
def __call__(self, root): def __call__(self, root):
return self.visit(root) return self.visit(root)
class CythonTransform(VisitorTransform):
"""
Certain common conventions and utilitues for Cython transforms.
"""
def __init__(self, context):
super(CythonTransform, self).__init__()
self.context = context
def visit_Node(self, node):
self.visitchildren(node)
return node
# Utils # Utils
def ensure_statlist(node): def ensure_statlist(node):
if not isinstance(node, Nodes.StatListNode): if not isinstance(node, Nodes.StatListNode):
...@@ -166,12 +180,29 @@ def replace_node(ptr, value): ...@@ -166,12 +180,29 @@ def replace_node(ptr, value):
else: else:
getattr(parent, attrname)[listidx] = value getattr(parent, attrname)[listidx] = value
tmpnamectr = 0
def temp_name_handle(description=None):
global tmpnamectr
tmpnamectr += 1
if description is not None:
name = u"%d_%s" % (tmpnamectr, description)
else:
name = u"%d" % tmpnamectr
return EncodedString(Naming.temp_prefix + name)
def get_temp_name_handle_desc(handle):
if not handle.startswith(u"__cyt_"):
return None
else:
idx = handle.find(u"_", 6)
return handle[idx+1:]
class PrintTree(TreeVisitor): class PrintTree(TreeVisitor):
"""Prints a representation of the tree to standard output. """Prints a representation of the tree to standard output.
Subclass and override repr_of to provide more information Subclass and override repr_of to provide more information
about nodes. """ about nodes. """
def __init__(self): def __init__(self):
Transform.__init__(self) TreeVisitor.__init__(self)
self._indent = "" self._indent = ""
def indent(self): def indent(self):
...@@ -181,6 +212,8 @@ class PrintTree(TreeVisitor): ...@@ -181,6 +212,8 @@ class PrintTree(TreeVisitor):
def __call__(self, tree, phase=None): def __call__(self, tree, phase=None):
print("Parse tree dump at phase '%s'" % phase) print("Parse tree dump at phase '%s'" % phase)
self.visit(tree)
return tree
# Don't do anything about process_list, the defaults gives # Don't do anything about process_list, the defaults gives
# nice-looking name[idx] nodes which will visually appear # nice-looking name[idx] nodes which will visually appear
......
...@@ -4,8 +4,49 @@ import unittest ...@@ -4,8 +4,49 @@ import unittest
from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.ModuleNode import ModuleNode
import Cython.Compiler.Main as Main import Cython.Compiler.Main as Main
from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
from Cython.Compiler.Visitor import TreeVisitor
class NodeTypeWriter(TreeVisitor):
def __init__(self):
super(NodeTypeWriter, self).__init__()
self._indents = 0
self.result = []
def visit_Node(self, node):
if len(self.access_path) == 0:
name = u"(root)"
else:
tip = self.access_path[-1]
if tip[2] is not None:
name = u"%s[%d]" % tip[1:3]
else:
name = tip[1]
self.result.append(u" " * self._indents +
u"%s: %s" % (name, node.__class__.__name__))
self._indents += 1
self.visitchildren(node)
self._indents -= 1
def treetypes(root):
"""Returns a string representing the tree by class names.
There's a leading and trailing whitespace so that it can be
compared by simple string comparison while still making test
cases look ok."""
w = NodeTypeWriter()
w.visit(root)
return u"\n".join([u""] + w.result + [u""])
class CythonTest(unittest.TestCase): class CythonTest(unittest.TestCase):
def assertLines(self, expected, result):
"Checks that the given strings or lists of strings are equal line by line"
if not isinstance(expected, list): expected = expected.split(u"\n")
if not isinstance(result, list): result = result.split(u"\n")
for idx, (expected_line, result_line) in enumerate(zip(expected, result)):
self.assertEqual(expected_line, result_line, "Line %d:\nExp: %s\nGot: %s" % (idx, expected_line, result_line))
self.assertEqual(len(expected), len(result),
"Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected), u"\n".join(result)))
def assertCode(self, expected, result_tree): def assertCode(self, expected, result_tree):
writer = CodeWriter() writer = CodeWriter()
writer.write(result_tree) writer.write(result_tree)
...@@ -18,13 +59,35 @@ class CythonTest(unittest.TestCase): ...@@ -18,13 +59,35 @@ class CythonTest(unittest.TestCase):
self.assertEqual(len(result_lines), len(expected_lines), self.assertEqual(len(result_lines), len(expected_lines),
"Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected)) "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))
def fragment(self, code, pxds={}): def fragment(self, code, pxds={}, pipeline=[]):
"Simply create a tree fragment using the name of the test-case in parse errors." "Simply create a tree fragment using the name of the test-case in parse errors."
name = self.id() name = self.id()
if name.startswith("__main__."): name = name[len("__main__."):] if name.startswith("__main__."): name = name[len("__main__."):]
name = name.replace(".", "_") name = name.replace(".", "_")
return TreeFragment(code, name, pxds) return TreeFragment(code, name, pxds, pipeline=pipeline)
def treetypes(self, root):
return treetypes(root)
def should_fail(self, func, exc_type=Exception):
"""Calls "func" and fails if it doesn't raise the right exception
(any exception by default). Also returns the exception in question.
"""
try:
func()
self.fail("Expected an exception of type %r" % exc_type)
except exc_type, e:
self.assert_(isinstance(e, exc_type))
return e
def should_not_fail(self, func):
"""Calls func and succeeds if and only if no exception is raised
(i.e. converts exception raising into a failed testcase). Returns
the return value of func."""
try:
return func()
except:
self.fail()
class TransformTest(CythonTest): class TransformTest(CythonTest):
""" """
...@@ -37,8 +100,8 @@ class TransformTest(CythonTest): ...@@ -37,8 +100,8 @@ class TransformTest(CythonTest):
To create a test case: To create a test case:
- Call run_pipeline. The pipeline should at least contain the transform you - Call run_pipeline. The pipeline should at least contain the transform you
are testing; pyx should be either a string (passed to the parser to are testing; pyx should be either a string (passed to the parser to
create a post-parse tree) or a ModuleNode representing input to pipeline. create a post-parse tree) or a node representing input to pipeline.
The result will be a transformed result (usually a ModuleNode). The result will be a transformed result.
- Check that the tree is correct. If wanted, assertCode can be used, which - Check that the tree is correct. If wanted, assertCode can be used, which
takes a code string as expected, and a ModuleNode in result_tree takes a code string as expected, and a ModuleNode in result_tree
...@@ -53,7 +116,6 @@ class TransformTest(CythonTest): ...@@ -53,7 +116,6 @@ class TransformTest(CythonTest):
def run_pipeline(self, pipeline, pyx, pxds={}): def run_pipeline(self, pipeline, pyx, pxds={}):
tree = self.fragment(pyx, pxds).root tree = self.fragment(pyx, pxds).root
assert isinstance(tree, ModuleNode)
# Run pipeline # Run pipeline
for T in pipeline: for T in pipeline:
tree = T(tree) tree = T(tree)
......
...@@ -73,6 +73,9 @@ class TestCodeWriter(CythonTest): ...@@ -73,6 +73,9 @@ class TestCodeWriter(CythonTest):
def test_inplace_assignment(self): def test_inplace_assignment(self):
self.t(u"x += 43") self.t(u"x += 43")
def test_attribute(self):
self.t(u"a.x")
if __name__ == "__main__": if __name__ == "__main__":
import unittest import unittest
unittest.main() unittest.main()
......
...@@ -7,6 +7,7 @@ from Cython.Distutils import build_ext ...@@ -7,6 +7,7 @@ from Cython.Distutils import build_ext
ext_modules=[ ext_modules=[
Extension("primes", ["primes.pyx"]), Extension("primes", ["primes.pyx"]),
Extension("spam", ["spam.pyx"]), Extension("spam", ["spam.pyx"]),
# Extension("optargs", ["optargs.pyx"], language = "c++"),
] ]
for file in glob.glob("*.pyx"): for file in glob.glob("*.pyx"):
......
cdef extern from "Python.h":
ctypedef struct PyObject
ctypedef struct Py_buffer:
void *buf
Py_ssize_t len
int readonly
char *format
int ndim
Py_ssize_t *shape
Py_ssize_t *strides
Py_ssize_t *suboffsets
Py_ssize_t itemsize
void *internal
int PyObject_GetBuffer(PyObject* obj, Py_buffer* view, int flags) except -1
void PyObject_ReleaseBuffer(PyObject* obj, Py_buffer* view)
void PyErr_Format(int, char*, ...)
enum:
PyExc_TypeError
# int PyObject_GetBuffer(PyObject *obj, Py_buffer *view,
# int flags)
cdef extern from "Python.h":
ctypedef int Py_intptr_t
cdef extern from "numpy/arrayobject.h":
ctypedef class numpy.ndarray [object PyArrayObject]:
cdef char *data
cdef int nd
cdef Py_intptr_t *dimensions
cdef Py_intptr_t *strides
cdef object base
# descr not implemented yet here...
cdef int flags
cdef int itemsize
cdef object weakreflist
ctypedef unsigned int npy_uint8
ctypedef unsigned int npy_uint16
ctypedef unsigned int npy_uint32
ctypedef unsigned int npy_uint64
ctypedef unsigned int npy_uint96
ctypedef unsigned int npy_uint128
ctypedef signed int npy_int64
ctypedef float npy_float32
ctypedef float npy_float64
ctypedef float npy_float80
ctypedef float npy_float96
ctypedef float npy_float128
ctypedef npy_int64 Tint64
...@@ -9,8 +9,8 @@ from distutils.core import Extension ...@@ -9,8 +9,8 @@ from distutils.core import Extension
from distutils.command.build_ext import build_ext from distutils.command.build_ext import build_ext
distutils_distro = Distribution() distutils_distro = Distribution()
TEST_DIRS = ['compile', 'errors', 'run'] TEST_DIRS = ['compile', 'errors', 'run', 'pyregr']
TEST_RUN_DIRS = ['run'] TEST_RUN_DIRS = ['run', 'pyregr']
INCLUDE_DIRS = [ d for d in os.getenv('INCLUDE', '').split(os.pathsep) if d ] INCLUDE_DIRS = [ d for d in os.getenv('INCLUDE', '').split(os.pathsep) if d ]
CFLAGS = os.getenv('CFLAGS', '').split() CFLAGS = os.getenv('CFLAGS', '').split()
...@@ -78,15 +78,21 @@ class TestBuilder(object): ...@@ -78,15 +78,21 @@ class TestBuilder(object):
filenames = os.listdir(path) filenames = os.listdir(path)
filenames.sort() filenames.sort()
for filename in filenames: for filename in filenames:
if not filename.endswith(".pyx"): if not (filename.endswith(".pyx") or filename.endswith(".py")):
continue continue
module = filename[:-4] if context == 'pyregr' and not filename.startswith('test_'):
continue
module = os.path.splitext(filename)[0]
fqmodule = "%s.%s" % (context, module) fqmodule = "%s.%s" % (context, module)
if not [ 1 for match in self.selectors if not [ 1 for match in self.selectors
if match(fqmodule) ]: if match(fqmodule) ]:
continue continue
if context in TEST_RUN_DIRS: if context in TEST_RUN_DIRS:
test = CythonRunTestCase( if module.startswith("test_"):
build_test = CythonUnitTestCase
else:
build_test = CythonRunTestCase
test = build_test(
path, workdir, module, path, workdir, module,
annotate=self.annotate, annotate=self.annotate,
cleanup_workdir=self.cleanup_workdir) cleanup_workdir=self.cleanup_workdir)
...@@ -133,11 +139,21 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -133,11 +139,21 @@ class CythonCompileTestCase(unittest.TestCase):
os.makedirs(self.workdir) os.makedirs(self.workdir)
def runTest(self): def runTest(self):
self.runCompileTest()
def runCompileTest(self):
self.compile(self.directory, self.module, self.workdir, self.compile(self.directory, self.module, self.workdir,
self.directory, self.expect_errors, self.annotate) self.directory, self.expect_errors, self.annotate)
def find_module_source_file(self, source_file):
if not os.path.exists(source_file):
source_file = source_file[:-1]
return source_file
def split_source_and_output(self, directory, module, workdir): def split_source_and_output(self, directory, module, workdir):
source_and_output = open(os.path.join(directory, module + '.pyx'), 'rU') source_file = os.path.join(directory, module) + '.pyx'
source_and_output = open(
self.find_module_source_file(source_file), 'rU')
out = open(os.path.join(workdir, module + '.pyx'), 'w') out = open(os.path.join(workdir, module + '.pyx'), 'w')
for line in source_and_output: for line in source_and_output:
last_line = line last_line = line
...@@ -157,7 +173,8 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -157,7 +173,8 @@ class CythonCompileTestCase(unittest.TestCase):
include_dirs = INCLUDE_DIRS[:] include_dirs = INCLUDE_DIRS[:]
if incdir: if incdir:
include_dirs.append(incdir) include_dirs.append(incdir)
source = os.path.join(directory, module + '.pyx') source = self.find_module_source_file(
os.path.join(directory, module + '.pyx'))
target = os.path.join(targetdir, module + '.c') target = os.path.join(targetdir, module + '.c')
options = CompilationOptions( options = CompilationOptions(
pyrex_default_options, pyrex_default_options,
...@@ -228,7 +245,7 @@ class CythonRunTestCase(CythonCompileTestCase): ...@@ -228,7 +245,7 @@ class CythonRunTestCase(CythonCompileTestCase):
result = self.defaultTestResult() result = self.defaultTestResult()
result.startTest(self) result.startTest(self)
try: try:
self.runTest() self.runCompileTest()
doctest.DocTestSuite(self.module).run(result) doctest.DocTestSuite(self.module).run(result)
except Exception: except Exception:
result.addError(self, sys.exc_info()) result.addError(self, sys.exc_info())
...@@ -238,6 +255,48 @@ class CythonRunTestCase(CythonCompileTestCase): ...@@ -238,6 +255,48 @@ class CythonRunTestCase(CythonCompileTestCase):
except Exception: except Exception:
pass pass
class CythonUnitTestCase(CythonCompileTestCase):
def shortDescription(self):
return "compiling and running unit tests in " + self.module
def run(self, result=None):
if result is None:
result = self.defaultTestResult()
result.startTest(self)
try:
self.runCompileTest()
unittest.defaultTestLoader.loadTestsFromName(self.module).run(result)
except Exception:
result.addError(self, sys.exc_info())
result.stopTest(self)
try:
self.tearDown()
except Exception:
pass
def collect_unittests(path, suite, selectors):
def file_matches(filename):
return filename.startswith("Test") and filename.endswith(".py")
def package_matches(dirname):
return dirname == "Tests"
loader = unittest.TestLoader()
for dirpath, dirnames, filenames in os.walk(path):
parentname = os.path.split(dirpath)[-1]
if package_matches(parentname):
for f in filenames:
if file_matches(f):
filepath = os.path.join(dirpath, f)[:-len(".py")]
modulename = filepath[len(path)+1:].replace(os.path.sep, '.')
if not [ 1 for match in selectors if match(modulename) ]:
continue
module = __import__(modulename)
for x in modulename.split('.')[1:]:
module = getattr(module, x)
suite.addTests([loader.loadTestsFromModule(module)])
if __name__ == '__main__': if __name__ == '__main__':
from optparse import OptionParser from optparse import OptionParser
parser = OptionParser() parser = OptionParser()
...@@ -247,6 +306,12 @@ if __name__ == '__main__': ...@@ -247,6 +306,12 @@ if __name__ == '__main__':
parser.add_option("--no-cython", dest="with_cython", parser.add_option("--no-cython", dest="with_cython",
action="store_false", default=True, action="store_false", default=True,
help="do not run the Cython compiler, only the C compiler") help="do not run the Cython compiler, only the C compiler")
parser.add_option("--no-unit", dest="unittests",
action="store_false", default=True,
help="do not run the unit tests")
parser.add_option("--no-file", dest="filetests",
action="store_false", default=True,
help="do not run the file based tests")
parser.add_option("-C", "--coverage", dest="coverage", parser.add_option("-C", "--coverage", dest="coverage",
action="store_true", default=False, action="store_true", default=False,
help="collect source coverage data for the Compiler") help="collect source coverage data for the Compiler")
...@@ -296,9 +361,15 @@ if __name__ == '__main__': ...@@ -296,9 +361,15 @@ if __name__ == '__main__':
if not selectors: if not selectors:
selectors = [ lambda x:True ] selectors = [ lambda x:True ]
tests = TestBuilder(ROOTDIR, WORKDIR, selectors, test_suite = unittest.TestSuite()
if options.unittests:
collect_unittests(os.getcwd(), test_suite, selectors)
if options.filetests:
filetests = TestBuilder(ROOTDIR, WORKDIR, selectors,
options.annotate_source, options.cleanup_workdir) options.annotate_source, options.cleanup_workdir)
test_suite = tests.build_suite() test_suite.addTests([filetests.build_suite()])
unittest.TextTestRunner(verbosity=options.verbosity).run(test_suite) unittest.TextTestRunner(verbosity=options.verbosity).run(test_suite)
......
def f(a, b):
assert a, a+b
cdef class Parrot:
cdef describe(self):
print "This is a parrot."
cdef action(self):
print "Polly wants a cracker!"
cdef extern from "foo.h":
ctypedef long long big_t
cdef void spam(big_t b)
spam(grail)
cdef int f() except -1:
cdef type t
cdef object x
t = buffer
t = enumerate
t = file
t = float
t = int
t = long
t = open
t = property
t = str
t = tuple
t = xrange
x = True
x = False
x = Ellipsis
x = Exception
x = StopIteration
x = StandardError
x = ArithmeticError
x = LookupError
x = AssertionError
x = AssertionError
x = EOFError
x = FloatingPointError
x = EnvironmentError
x = IOError
x = OSError
x = ImportError
x = IndexError
x = KeyError
x = KeyboardInterrupt
x = MemoryError
x = NameError
x = OverflowError
x = RuntimeError
x = NotImplementedError
x = SyntaxError
x = IndentationError
x = TabError
x = ReferenceError
x = SystemError
x = SystemExit
x = TypeError
x = UnboundLocalError
x = UnicodeError
x = UnicodeEncodeError
x = UnicodeDecodeError
x = UnicodeTranslateError
x = ValueError
x = ZeroDivisionError
x = MemoryErrorInst
x = Warning
x = UserWarning
x = DeprecationWarning
x = PendingDeprecationWarning
x = SyntaxWarning
#x = OverflowWarning # Does not seem to exist in 2.5
x = RuntimeWarning
x = FutureWarning
typecheck(x, Exception)
try:
pass
except ValueError:
pass
cdef int f() except -1:
cdef dict d
cdef object x, z
cdef int i
z = dict
d = dict(x)
d = dict(*x)
d.clear()
z = d.copy()
z = d.items()
z = d.keys()
z = d.values()
d.merge(x, i)
d.update(x)
d.merge_pairs(x, i)
cdef int f() except -1:
cdef list l
cdef object x, y, z
z = list
l = list(x)
l = list(*y)
z = l.insert
l.insert(17, 42)
l.append(88)
l.sort()
l.reverse()
z = l.as_tuple()
cdef int f() except -1:
cdef slice s
cdef object z
cdef int i
z = slice
s = slice(1, 2, 3)
z = slice.indices()
i = s.start
i = s.stop
i = s.step
cdef int f() except -1:
cdef type t1, t2
cdef object x
cdef int b
b = typecheck(x, t1)
b = issubtype(t1, t2)
cdef void foo():
cdef int i, j, k
i = j = k
a = b = c
i = j = c
a = b = k
(a, b), c = (d, e), f = (x, y), z
# a, b = p, q = x, y
\ No newline at end of file
cdef extern from "cdefemptysue.h":
cdef struct spam:
pass
ctypedef union eggs:
pass
cdef enum ham:
pass
cdef extern spam s
cdef extern eggs e
cdef extern ham h
cdef extern from "cheese.h":
ctypedef int camembert
struct roquefort:
int x
char *swiss
void cheddar()
class external.runny [object runny_obj]:
cdef int a
def __init__(self):
pass
cdef runny r
r = x
r.a = 42
cdef extern from "cheese.h":
pass
cdef int f():
pass
cdef char *g(int k, float z):
pass
cimport spam
cimport pkg.eggs
cdef spam.Spam yummy
cdef pkg.eggs.Eggs fried
spam.eat(yummy)
spam.tons = 3.14
ova = pkg.eggs
fried = pkg.eggs.Eggs()
from spam cimport Spam
from pkg.eggs cimport Eggs as ova
cdef extern Spam yummy
cdef ova fried
fried = None
from package.inpackage cimport Spam
cdef Spam s2
from cexportfunc cimport f, g
cdef extern from "ctypedefextern.h":
ctypedef int some_int
ctypedef some_int *some_ptr
cdef void spam():
cdef some_int i
cdef some_ptr p
p[0] = i
cdef class Spam:
answer = 42
cdef object f(object x) nogil:
pass
cdef void g(int x) nogil:
cdef object z
z = None
cdef void h(int x) nogil:
p()
cdef object p() nogil:
pass
cdef void r() nogil:
q()
cdef void (*fp)()
cdef void (*fq)() nogil
cdef extern void u()
cdef object m():
global fp, fq
cdef object x, y, obj
cdef int i, j, k
global fred
q()
with nogil:
r()
q()
i = 42
obj = None
17L
7j
asdf
`"Hello"`
import fred
from fred import obj
for x in obj:
pass
obj[i]
obj[i:j]
obj[i:j:k]
obj.fred
(x, y)
[x, y]
{x: y}
obj and x
t(obj)
f(42)
x + obj
-obj
x = y = obj
x, y = y, x
obj[i] = x
obj.fred = x
print obj
del fred
return obj
raise obj
if obj:
pass
while obj:
pass
for x <= obj <= y:
pass
try:
pass
except:
pass
try:
pass
finally:
pass
fq = u
fq = fp
cdef void q():
pass
cdef class C:
pass
cdef void t(C c) nogil:
pass
cdef extern from "foo.h":
int fred()
cdef extern from "externsue.h":
enum Eggs:
runny, firm, hard
struct Spam:
int i
union Soviet:
char c
cdef extern Eggs e
cdef extern Spam s
cdef extern Soviet u
cdef void tomato():
global e
e = runny
e = firm
e = hard
cdef class Widget:
pass
cdef class Container:
pass
cdef Widget w
cdef Container c
w.parent = c
cdef class Spam:
cdef public object eggs
def __getattr__(self, name):
print "Spam getattr:", name
cdef int f() except -1:
g = getattr3
cdef public int grail
cdef public spam(int servings):
pass
cdef public class sandwich [object sandwich, type sandwich_Type]:
cdef int tomato
cdef float lettuce
include "i_public.pxi"
cdef struct S:
int q
cdef int f() except -1:
cdef int i, j, k
cdef float x, y, z
cdef object a, b, c, d, e
cdef int m[3]
cdef S s
global g
i += j + k
x += y + z
x += i
a += b + c
g += a
m[i] += j
a[i] += b + c
a[b + c] += d
(a + b)[c] += d
a[i : j] += b
(a + b)[i : j] += c
a.b += c + d
(a + b).c += d
s.q += i
cdef int f() except -1:
cdef object a, b
cdef char *p
a += b
a -= b
a *= b
a /= b
a %= b
a **= b
a <<= b
a >>= b
a &= b
a ^= b
a |= b
p += 42
p -= 42
p += a
cdef int f() except -1:
cdef object x, y, z
cdef int i
cdef unsigned int ui
z = x[y]
z = x[i]
x[y] = z
x[i] = z
z = x[ui]
x[ui] = z
cdef api float f(float x):
return 0.5 * x * x
cdef int f(int x):
return x * x
cdef int g(int x):
return 5 * x
def f(a, *p, **n):
pass
cdef int f() except -1:
cdef int i, x, y
for x < i < y:
pass
cdef int f() except -1:
cdef int i, x, y
for i from x < i < y:
pass
cimport spam, eggs
cdef extern spam.Spam yummy
cdef eggs.Eggs fried
fried = None
from spam cimport Spam
from eggs cimport Eggs
cdef extern Spam yummy
cdef Eggs fried
fried = None
def f():
cdef list l
l = list()
l.append("second")
l.insert(0, "first")
return l
cdef extern from "l_capi_api.h":
float f(float)
int import_l_capi() except -1
def test():
print f(3.1415)
import_l_capi()
cimport l_cfuncexport
from l_cfuncexport cimport g
print l_cfuncexport.f(42)
print g(42)
class Spam:
"""Spam, glorious spam!"""
cdef int tomato() except -1:
print "Entering tomato"
raise Exception("Eject! Eject! Eject!")
print "Leaving tomato"
cdef void sandwich():
print "Entering sandwich"
tomato()
print "Leaving sandwich"
def snack():
print "Entering snack"
tomato()
print "Leaving snack"
def lunch():
print "Entering lunch"
sandwich()
print "Leaving lunch"
cdef class SpamDish:
cdef int spam
cdef void describe(self):
print "This dish contains", self.spam, "tons of spam."
cdef class FancySpamDish(SpamDish):
cdef int lettuce
cdef void describe(self):
print "This dish contains", self.spam, "tons of spam",
print "and", self.lettuce, "milligrams of lettuce."
cdef void describe_dish(SpamDish d):
d.describe()
def test():
cdef SpamDish s
cdef FancySpamDish ss
s = SpamDish()
s.spam = 42
ss = FancySpamDish()
ss.spam = 88
ss.lettuce = 5
describe_dish(s)
describe_dish(ss)
from b_extimpinherit cimport Parrot
cdef class Norwegian(Parrot):
cdef action(self):
print "This parrot is resting."
cdef plumage(self):
print "Lovely plumage!"
def main():
cdef Parrot p
cdef Norwegian n
p = Parrot()
n = Norwegian()
print "Parrot:"
p.describe()
p.action()
print "Norwegian:"
n.describe()
n.action()
n.plumage()
cdef class Parrot:
cdef object plumage
def __init__(self):
self.plumage = "yellow"
def describe(self):
print "This bird has lovely", self.plumage, "plumage."
cdef class Norwegian(Parrot):
def __init__(self):
self.plumage = "blue"
cdef class Spam:
cdef public int tons
cdef readonly float tastiness
cdef int temperature
def __init__(self, tons, tastiness, temperature):
self.tons = tons
self.tastiness = tastiness
self.temperature = temperature
def get_temperature(self):
return self.temperature
cdef extern from "numeric.h":
struct PyArray_Descr:
int type_num, elsize
char type
ctypedef class Numeric.ArrayType [object PyArrayObject]:
cdef char *data
cdef int nd
cdef int *dimensions, *strides
cdef object base
cdef PyArray_Descr *descr
cdef int flags
def ogle(ArrayType a):
print "No. of dimensions:", a.nd
print " Dim Value"
for i in range(a.nd):
print "%5d %5d" % (i, a.dimensions[i])
print "flags:", a.flags
print "Type no.", a.descr.type_num
print "Element size:", a.descr.elsize
cdef class CheeseShop:
cdef object cheeses
def __cinit__(self):
self.cheeses = []
property cheese:
"A senseless waste of a property."
def __get__(self):
return "We don't have: %s" % self.cheeses
def __set__(self, value):
self.cheeses.append(value)
def __del__(self):
del self.cheeses[:]
cdef class Animal:
cdef object __weakref__
cdef public object name
def test(obj, attr, dflt):
return getattr3(obj, attr, dflt)
import spam
print "Imported spam"
print dir(spam)
import sys
print "Imported sys"
print sys
cdef class Parrot:
cdef void describe(self):
print "This parrot is resting."
cdef class Norwegian(Parrot):
cdef void describe(self):
Parrot.describe(self)
print "Lovely plumage!"
cdef Parrot p1, p2
p1 = Parrot()
p2 = Norwegian()
p1.describe()
p2.describe()
def pd(d):
l = []
i = d.items()
i.sort()
for kv in i:
l.append("%r: %r" % kv)
return "{%s}" % ", ".join(l)
def c(a, b, c):
print "a =", a, "b =", b, "c =", c
def d(a, b, *, c = 88):
print "a =", a, "b =", b, "c =", c
def e(a, b, c = 88, **kwds):
print "a =", a, "b =", b, "c =", c, "kwds =", pd(kwds)
def f(a, b, *, c, d = 42):
print "a =", a, "b =", b, "c =", c, "d =", d
def g(a, b, *, c, d = 42, e = 17, f, **kwds):
print "a =", a, "b =", b, "c =", c, "d =", d, "e =", e, "f =", f, "kwds =", pd(kwds)
def h(a, b, *args, c, d = 42, e = 17, f, **kwds):
print "a =", a, "b =", b, "args =", args, "c =", c, "d =", d, "e =", e, "f =", f, "kwds =", pd(kwds)
class Inquisition(object):
"""Something that nobody expects."""
def __repr__(self):
return "Surprise!"
def f():
print "Spam!"
f()
def foo():
raise Exception
cdef int spam() except -1:
raise Exception("Spam error")
cdef int grail() except -1:
spam()
def tomato():
grail()
seq = [1, [2, 3]]
def f():
a, (b, c) = [1, [2, 3]]
print a
print b
print c
def g():
a, b, c = seq
def h():
a, = seq
def f():
return 42
cdef int g():
cdef object x
return x
DEF nan = float('nan')
DEF inf = float('inf')
DEF minf = -float('inf')
cdef int f() except -1:
cdef float x, y, z
x = nan
y = inf
z = minf
import sys
from Pyrex.Compiler.Main import main
sys.argv[1:] = "-I spam -Ieggs --include-dir ham".split()
main(command_line = 1)
cdef class Spam:
pass
def probe():
pass
cdef class Spam:
pass
cdef int f() except -1:
cdef type t
t = Spam
cdef class C:
cdef int i
cdef int f() except -1:
cdef object x
cdef void *p
cdef int i
x = <object>p
p = <void *>x
x = (<object>p).foo
i = (<C>p).i
(<C>p).i = i
...@@ -2,6 +2,7 @@ def f(): ...@@ -2,6 +2,7 @@ def f():
cdef int int1, int3 cdef int int1, int3
cdef int *ptr1, *ptr2, *ptr3 cdef int *ptr1, *ptr2, *ptr3
ptr1 = ptr2 + ptr3 # error ptr1 = ptr2 + ptr3 # error
_ERRORS = u""" _ERRORS = u"""
/Local/Projects/D/Pyrex/Source/Tests/Errors2/e_addop.pyx:4:13: Invalid operand types for '+' (int *; int *) 4:13: Invalid operand types for '+' (int *; int *)
""" """
__doc__ = """
>>> test()
2
"""
def test():
cdef int x[2][2]
x[0][0] = 1
x[0][1] = 2
x[1][0] = 3
x[1][1] = 4
return f(x)[1]
cdef int* f(int x[2][2]):
return x[0]
__doc__ = u"""
>>> f(1,2)
4
>>> f.HERE
1
>>> g(1,2)
5
>>> g.HERE
5
>>> h(1,2)
6
>>> h.HERE
1
"""
class wrap:
def __init__(self, func):
self.func = func
self.HERE = 1
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
def decorate(func):
try:
func.HERE += 1
except AttributeError:
func = wrap(func)
return func
def decorate2(a,b):
return decorate
@decorate
def f(a,b):
return a+b+1
@decorate
@decorate
@decorate
@decorate
@decorate
def g(a,b):
return a+b+2
@decorate2(1,2)
def h(a,b):
return a+b+3
__doc__ = u"""
>>> test_in('ABC')
1
>>> test_in('abc')
2
>>> test_in('X')
3
>>> test_in('XYZ')
4
>>> test_in('ABCXYZ')
5
>>> test_in('')
5
>>> test_not_in('abc')
1
>>> test_not_in('CDE')
2
>>> test_not_in('CDEF')
3
>>> test_not_in('BCD')
4
"""
def test_in(s):
if s in ('ABC', 'BCD'):
return 1
elif s.upper() in ('ABC', 'BCD'):
return 2
elif len(s) in (1,2):
return 3
elif len(s) in (3,4):
return 4
else:
return 5
def test_not_in(s):
if s not in ('ABC', 'BCD', 'CDE', 'CDEF'):
return 1
elif s.upper() not in ('ABC', 'BCD', 'CDEF'):
return 2
elif len(s) not in [3]:
return 3
elif len(s) not in [1,2]:
return 4
else:
return 5
from __future__ import division
__doc__ = """
>>> from future_division import doit
>>> doit(1,2)
(0.5, 0)
>>> doit(4,3)
(1.3333333333333333, 1)
>>> doit(4,3.0)
(1.3333333333333333, 1.0)
>>> doit(4,2)
(2.0, 2)
"""
def doit(x,y):
return x/y, x//y
__doc__ = u""" __doc__ = u"""
>>> go() >>> go_py()
Spam!
Spam!
Spam!
Spam!
Spam!
>>> go_c()
Spam! Spam!
Spam! Spam!
Spam! Spam!
...@@ -7,7 +14,11 @@ __doc__ = u""" ...@@ -7,7 +14,11 @@ __doc__ = u"""
Spam! Spam!
""" """
def go(): def go_py():
for i in range(5): for i in range(5):
print u"Spam!" print u"Spam!"
def go_c():
cdef int i
for i in range(5):
print u"Spam!"
__doc__ = u"""
>>> test()
5
"""
def test():
a = b = c = 5
return a
__doc__ = u"""
>>> switch_simple_py(1)
1
>>> switch_simple_py(2)
2
>>> switch_simple_py(3)
3
>>> switch_simple_py(4)
8
>>> switch_simple_py(5)
0
>>> switch_py(1)
1
>>> switch_py(2)
2
>>> switch_py(3)
3
>>> switch_py(4)
4
>>> switch_py(5)
4
>>> switch_py(6)
0
>>> switch_py(8)
4
>>> switch_py(10)
10
>>> switch_py(12)
12
>>> switch_py(13)
0
>>> switch_simple_c(1)
1
>>> switch_simple_c(2)
2
>>> switch_simple_c(3)
3
>>> switch_simple_c(4)
8
>>> switch_simple_c(5)
0
>>> switch_c(1)
1
>>> switch_c(2)
2
>>> switch_c(3)
3
>>> switch_c(4)
4
>>> switch_c(5)
4
>>> switch_c(6)
0
>>> switch_c(8)
4
>>> switch_c(10)
10
>>> switch_c(12)
12
>>> switch_c(13)
0
"""
def switch_simple_py(x):
if x == 1:
return 1
elif 2 == x:
return 2
elif x in [3]:
return 3
elif x in (4,):
return 8
else:
return 0
return -1
def switch_py(x):
if x == 1:
return 1
elif 2 == x:
return 2
elif x in [3]:
return 3
elif x in [4,5,7,8]:
return 4
elif x in (10,11):
return 10
elif x in (12,):
return 12
else:
return 0
return -1
def switch_simple_c(int x):
if x == 1:
return 1
elif 2 == x:
return 2
elif x in [3]:
return 3
elif x in (4,):
return 8
else:
return 0
return -1
def switch_c(int x):
if x == 1:
return 1
elif 2 == x:
return 2
elif x in [3]:
return 3
elif x in [4,5,7,8]:
return 4
elif x in (10,11):
return 10
elif x in (12,):
return 12
else:
return 0
return -1
from __future__ import with_statement
__doc__ = u"""
>>> no_as()
enter
hello
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
>>> basic()
enter
value
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
>>> with_exception(None)
enter
value
exit <type 'type'> <class 'withstat.MyException'> <type 'traceback'>
outer except
>>> with_exception(True)
enter
value
exit <type 'type'> <class 'withstat.MyException'> <type 'traceback'>
>>> multitarget()
enter
1 2 3 4 5
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
>>> tupletarget()
enter
(1, 2, (3, (4, 5)))
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
>>> typed()
enter
10
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
"""
class MyException(Exception):
pass
class ContextManager:
def __init__(self, value, exit_ret = None):
self.value = value
self.exit_ret = exit_ret
def __exit__(self, a, b, tb):
print "exit", type(a), type(b), type(tb)
return self.exit_ret
def __enter__(self):
print "enter"
return self.value
def no_as():
with ContextManager("value"):
print "hello"
def basic():
with ContextManager("value") as x:
print x
def with_exception(exit_ret):
try:
with ContextManager("value", exit_ret=exit_ret) as value:
print value
raise MyException()
except:
print "outer except"
def multitarget():
with ContextManager((1, 2, (3, (4, 5)))) as (a, b, (c, (d, e))):
print a, b, c, d, e
def tupletarget():
with ContextManager((1, 2, (3, (4, 5)))) as t:
print t
def typed():
cdef unsigned char i
c = ContextManager(255)
with c as i:
i += 11
print i
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment