Commit f9d5e812 authored by Robert Bradshaw's avatar Robert Bradshaw

merge latest cython-devel into cpp branch

parents bcf2668c baaee17d
......@@ -61,7 +61,7 @@ class CodeWriter(TreeVisitor):
self.startline(s)
self.endline()
def comma_seperated_list(self, items, output_rhs=False):
def comma_separated_list(self, items, output_rhs=False):
if len(items) > 0:
for item in items[:-1]:
self.visit(item)
......@@ -82,7 +82,7 @@ class CodeWriter(TreeVisitor):
def visit_FuncDefNode(self, node):
self.startline(u"def %s(" % node.name)
self.comma_seperated_list(node.args)
self.comma_separated_list(node.args)
self.endline(u"):")
self.indent()
self.visit(node.body)
......@@ -167,7 +167,7 @@ class CodeWriter(TreeVisitor):
def visit_PrintStatNode(self, node):
self.startline(u"print ")
self.comma_seperated_list(node.arg_tuple.args)
self.comma_separated_list(node.arg_tuple.args)
if not node.append_newline:
self.put(u",")
self.endline()
......@@ -181,7 +181,7 @@ class CodeWriter(TreeVisitor):
self.startline(u"cdef ")
self.visit(node.base_type)
self.put(u" ")
self.comma_seperated_list(node.declarators, output_rhs=True)
self.comma_separated_list(node.declarators, output_rhs=True)
self.endline()
def visit_ForInStatNode(self, node):
......@@ -200,12 +200,12 @@ class CodeWriter(TreeVisitor):
self.dedent()
def visit_SequenceNode(self, node):
self.comma_seperated_list(node.args) # Might need to discover whether we need () around tuples...hmm...
self.comma_separated_list(node.args) # Might need to discover whether we need () around tuples...hmm...
def visit_SimpleCallNode(self, node):
self.visit(node.function)
self.put(u"(")
self.comma_seperated_list(node.args)
self.comma_separated_list(node.args)
self.put(")")
def visit_GeneralCallNode(self, node):
......@@ -215,7 +215,7 @@ class CodeWriter(TreeVisitor):
if isinstance(posarg, AsTupleNode):
self.visit(posarg.arg)
else:
self.comma_seperated_list(posarg)
self.comma_separated_list(posarg)
if node.keyword_args is not None or node.starstar_arg is not None:
raise Exception("Not implemented yet")
self.put(u")")
......
......@@ -6,6 +6,7 @@ from PyrexTypes import py_object_type
from Builtin import dict_type
from StringEncoding import EncodedString
import Naming
import Symtab
class AutoTestDictTransform(ScopeTrackingTransform):
# Handles autotestdict directive
......@@ -82,6 +83,16 @@ class AutoTestDictTransform(ScopeTrackingTransform):
type=py_object_type,
is_py_attr=True,
is_temp=True)
if isinstance(node.entry.scope, Symtab.PropertyScope):
new_node = AttributeNode(pos, obj=parent,
attribute=node.entry.scope.name,
type=py_object_type,
is_py_attr=True,
is_temp=True)
parent = new_node
name = "%s.%s.%s" % (clsname, node.entry.scope.name,
node.entry.name)
else:
name = "%s.%s" % (clsname, node.entry.name)
else:
assert False
......
......@@ -723,8 +723,8 @@ typedef struct {
} __Pyx_BufFmt_StackElem;
static CYTHON_INLINE int __Pyx_GetBufferAndValidate(Py_buffer* buf, PyObject* obj, __Pyx_TypeInfo* dtype, int flags, int nd, int cast, __Pyx_BufFmt_StackElem* stack);
static CYTHON_INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info);
static int __Pyx_GetBufferAndValidate(Py_buffer* buf, PyObject* obj, __Pyx_TypeInfo* dtype, int flags, int nd, int cast, __Pyx_BufFmt_StackElem* stack);
""", impl="""
static CYTHON_INLINE int __Pyx_IsLittleEndian(void) {
unsigned int n = 1;
......@@ -1131,7 +1131,7 @@ static CYTHON_INLINE void __Pyx_ZeroBuffer(Py_buffer* buf) {
buf->suboffsets = __Pyx_minusones;
}
static int __Pyx_GetBufferAndValidate(Py_buffer* buf, PyObject* obj, __Pyx_TypeInfo* dtype, int flags, int nd, int cast, __Pyx_BufFmt_StackElem* stack) {
static CYTHON_INLINE int __Pyx_GetBufferAndValidate(Py_buffer* buf, PyObject* obj, __Pyx_TypeInfo* dtype, int flags, int nd, int cast, __Pyx_BufFmt_StackElem* stack) {
if (obj == Py_None) {
__Pyx_ZeroBuffer(buf);
return 0;
......
......@@ -31,7 +31,7 @@ builtin_function_table = [
('intern', "O", "O", "__Pyx_Intern"),
('isinstance', "OO", "b", "PyObject_IsInstance"),
('issubclass', "OO", "b", "PyObject_IsSubclass"),
('iter', "O", "O", "PyObject_GetIter"),
#('iter', "O", "O", "PyObject_GetIter"), # optimised later on
('len', "O", "Z", "PyObject_Length"),
('locals', "", "O", "__pyx_locals"),
#('map', "", "", ""),
......
......@@ -2034,6 +2034,10 @@ class IndexNode(ExprNode):
else:
function = "__Pyx_GetItemInt"
code.globalstate.use_utility_code(getitem_int_utility_code)
else:
if self.base.type is dict_type:
function = "__Pyx_PyDict_GetItem"
code.globalstate.use_utility_code(getitem_dict_utility_code)
else:
function = "PyObject_GetItem"
index_code = self.index.py_result()
......@@ -2274,7 +2278,7 @@ class SliceIndexNode(ExprNode):
self.base.py_result(),
self.start_code(),
self.stop_code(),
rhs.result()))
rhs.py_result()))
else:
start_offset = ''
if self.start:
......@@ -3438,8 +3442,6 @@ class SequenceNode(ExprNode):
# allocates the temps in a rather hacky way -- the assignment
# is evaluated twice, within each if-block.
code.globalstate.use_utility_code(unpacking_utility_code)
if rhs.type is tuple_type:
tuple_check = "likely(%s != Py_None)"
else:
......@@ -3475,6 +3477,8 @@ class SequenceNode(ExprNode):
rhs.py_result(), len(self.args)))
code.putln(code.error_goto(self.pos))
else:
code.globalstate.use_utility_code(unpacking_utility_code)
self.iterator.allocate(code)
code.putln(
"%s = PyObject_GetIter(%s); %s" % (
......@@ -6445,8 +6449,64 @@ impl = ""
#------------------------------------------------------------------------------------
# If the is_unsigned flag is set, we need to do some extra work to make
# sure the index doesn't become negative.
raise_noneattr_error_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE void __Pyx_RaiseNoneAttributeError(const char* attrname);
""",
impl = '''
static CYTHON_INLINE void __Pyx_RaiseNoneAttributeError(const char* attrname) {
PyErr_Format(PyExc_AttributeError, "'NoneType' object has no attribute '%s'", attrname);
}
''')
raise_noneindex_error_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE void __Pyx_RaiseNoneIndexingError(void);
""",
impl = '''
static CYTHON_INLINE void __Pyx_RaiseNoneIndexingError(void) {
PyErr_SetString(PyExc_TypeError, "'NoneType' object is unsubscriptable");
}
''')
raise_none_iter_error_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE void __Pyx_RaiseNoneNotIterableError(void);
""",
impl = '''
static CYTHON_INLINE void __Pyx_RaiseNoneNotIterableError(void) {
PyErr_SetString(PyExc_TypeError, "'NoneType' object is not iterable");
}
''')
#------------------------------------------------------------------------------------
getitem_dict_utility_code = UtilityCode(
proto = """
#if PY_MAJOR_VERSION >= 3
static PyObject *__Pyx_PyDict_GetItem(PyObject *d, PyObject* key) {
PyObject *value;
if (unlikely(d == Py_None)) {
__Pyx_RaiseNoneIndexingError();
return NULL;
}
value = PyDict_GetItemWithError(d, key);
if (unlikely(!value)) {
if (!PyErr_Occurred())
PyErr_SetObject(PyExc_KeyError, key);
return NULL;
}
Py_INCREF(value);
return value;
}
#else
#define __Pyx_PyDict_GetItem(d, key) PyObject_GetItem(d, key)
#endif
""",
requires = [raise_noneindex_error_utility_code])
#------------------------------------------------------------------------------------
getitem_int_utility_code = UtilityCode(
proto = """
......@@ -6575,36 +6635,6 @@ impl = """
#------------------------------------------------------------------------------------
raise_noneattr_error_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE void __Pyx_RaiseNoneAttributeError(const char* attrname);
""",
impl = '''
static CYTHON_INLINE void __Pyx_RaiseNoneAttributeError(const char* attrname) {
PyErr_Format(PyExc_AttributeError, "'NoneType' object has no attribute '%s'", attrname);
}
''')
raise_noneindex_error_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE void __Pyx_RaiseNoneIndexingError(void);
""",
impl = '''
static CYTHON_INLINE void __Pyx_RaiseNoneIndexingError(void) {
PyErr_SetString(PyExc_TypeError, "'NoneType' object is unsubscriptable");
}
''')
raise_none_iter_error_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE void __Pyx_RaiseNoneNotIterableError(void);
""",
impl = '''
static CYTHON_INLINE void __Pyx_RaiseNoneNotIterableError(void) {
PyErr_SetString(PyExc_TypeError, "'NoneType' object is not iterable");
}
''')
raise_too_many_values_to_unpack = UtilityCode(
proto = """
static CYTHON_INLINE void __Pyx_RaiseTooManyValuesError(void);
......
......@@ -9,5 +9,6 @@ def _get_feature(name):
unicode_literals = _get_feature("unicode_literals")
with_statement = _get_feature("with_statement")
division = _get_feature("division")
print_function = _get_feature("print_function")
del _get_feature
......@@ -88,7 +88,7 @@ class Context(object):
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
from TypeInference import MarkAssignments, MarkOverflowingArithmatic
from TypeInference import MarkAssignments, MarkOverflowingArithmetic
from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck
from AnalysedTreeTransforms import AutoTestDictTransform
from AutoDocTransforms import EmbedSignature
......@@ -135,7 +135,7 @@ class Context(object):
EmbedSignature(self),
EarlyReplaceBuiltinCalls(self),
MarkAssignments(self),
MarkOverflowingArithmatic(self),
MarkOverflowingArithmetic(self),
TransformBuiltinMethods(self),
IntroduceBufferAuxiliaryVars(self),
_check_c_declarations,
......
......@@ -1726,7 +1726,7 @@ class PyArgDeclNode(Node):
class DecoratorNode(Node):
# A decorator
#
# decorator NameNode or CallNode
# decorator NameNode or CallNode or AttributeNode
child_attrs = ['decorator']
......@@ -2025,7 +2025,7 @@ class DefNode(FuncDefNode):
def declare_python_arg(self, env, arg):
if arg:
if env.directives['infer_types'] != 'none':
if env.directives['infer_types'] != False:
type = PyrexTypes.unspecified_type
else:
type = py_object_type
......@@ -2441,7 +2441,7 @@ class DefNode(FuncDefNode):
# it looks funny to separate the init-to-0 from setting the
# default value, but C89 needs this
code.putln("PyObject* values[%d] = {%s};" % (
max_args, ','.join(['0']*max_args)))
max_args, ','.join('0'*max_args)))
for i, default_value in default_args:
code.putln('values[%d] = %s;' % (i, default_value))
......@@ -3292,7 +3292,7 @@ class ParallelAssignmentNode(AssignmentNode):
class InPlaceAssignmentNode(AssignmentNode):
# An in place arithmatic operand:
# An in place arithmetic operand:
#
# a += b
# a -= b
......@@ -3447,11 +3447,15 @@ class PrintStatNode(StatNode):
# print statement
#
# arg_tuple TupleNode
# stream ExprNode or None (stdout)
# append_newline boolean
child_attrs = ["arg_tuple"]
child_attrs = ["arg_tuple", "stream"]
def analyse_expressions(self, env):
if self.stream:
self.stream.analyse_expressions(env)
self.stream = self.stream.coerce_to_pyobject(env)
self.arg_tuple.analyse_expressions(env)
self.arg_tuple = self.arg_tuple.coerce_to_pyobject(env)
env.use_utility_code(printing_utility_code)
......@@ -3462,12 +3466,18 @@ class PrintStatNode(StatNode):
gil_message = "Python print statement"
def generate_execution_code(self, code):
if self.stream:
self.stream.generate_evaluation_code(code)
stream_result = self.stream.py_result()
else:
stream_result = '0'
if len(self.arg_tuple.args) == 1 and self.append_newline:
arg = self.arg_tuple.args[0]
arg.generate_evaluation_code(code)
code.putln(
"if (__Pyx_PrintOne(%s) < 0) %s" % (
"if (__Pyx_PrintOne(%s, %s) < 0) %s" % (
stream_result,
arg.py_result(),
code.error_goto(self.pos)))
arg.generate_disposal_code(code)
......@@ -3475,14 +3485,21 @@ class PrintStatNode(StatNode):
else:
self.arg_tuple.generate_evaluation_code(code)
code.putln(
"if (__Pyx_Print(%s, %d) < 0) %s" % (
"if (__Pyx_Print(%s, %s, %d) < 0) %s" % (
stream_result,
self.arg_tuple.py_result(),
self.append_newline,
code.error_goto(self.pos)))
self.arg_tuple.generate_disposal_code(code)
self.arg_tuple.free_temps(code)
if self.stream:
self.stream.generate_disposal_code(code)
self.stream.free_temps(code)
def annotate(self, code):
if self.stream:
self.stream.annotate(code)
self.arg_tuple.annotate(code)
......@@ -5028,12 +5045,18 @@ else:
printing_utility_code = UtilityCode(
proto = """
static int __Pyx_Print(PyObject *, int); /*proto*/
static int __Pyx_Print(PyObject*, PyObject *, int); /*proto*/
#if PY_MAJOR_VERSION >= 3
static PyObject* %s = 0;
static PyObject* %s = 0;
#endif
""" % (Naming.print_function, Naming.print_function_kwargs),
cleanup = """
#if PY_MAJOR_VERSION >= 3
Py_CLEAR(%s);
Py_CLEAR(%s);
#endif
""" % (Naming.print_function, Naming.print_function_kwargs),
impl = r"""
#if PY_MAJOR_VERSION < 3
static PyObject *__Pyx_GetStdout(void) {
......@@ -5044,13 +5067,14 @@ static PyObject *__Pyx_GetStdout(void) {
return f;
}
static int __Pyx_Print(PyObject *arg_tuple, int newline) {
PyObject *f;
static int __Pyx_Print(PyObject* f, PyObject *arg_tuple, int newline) {
PyObject* v;
int i;
if (!f) {
if (!(f = __Pyx_GetStdout()))
return -1;
}
for (i=0; i < PyTuple_GET_SIZE(arg_tuple); i++) {
if (PyFile_SoftSpace(f, 1)) {
if (PyFile_WriteString(" ", f) < 0)
......@@ -5078,22 +5102,38 @@ static int __Pyx_Print(PyObject *arg_tuple, int newline) {
#else /* Python 3 has a print function */
static int __Pyx_Print(PyObject *arg_tuple, int newline) {
static int __Pyx_Print(PyObject* stream, PyObject *arg_tuple, int newline) {
PyObject* kwargs = 0;
PyObject* result = 0;
PyObject* end_string;
if (!%(PRINT_FUNCTION)s) {
if (unlikely(!%(PRINT_FUNCTION)s)) {
%(PRINT_FUNCTION)s = __Pyx_GetAttrString(%(BUILTINS)s, "print");
if (!%(PRINT_FUNCTION)s)
return -1;
}
if (stream) {
kwargs = PyDict_New();
if (unlikely(!kwargs))
return -1;
if (unlikely(PyDict_SetItemString(kwargs, "file", stream) < 0))
goto bad;
if (!newline) {
if (!%(PRINT_KWARGS)s) {
end_string = PyUnicode_FromStringAndSize(" ", 1);
if (unlikely(!end_string))
goto bad;
if (PyDict_SetItemString(kwargs, "end", end_string) < 0) {
Py_DECREF(end_string);
goto bad;
}
Py_DECREF(end_string);
}
} else if (!newline) {
if (unlikely(!%(PRINT_KWARGS)s)) {
%(PRINT_KWARGS)s = PyDict_New();
if (!%(PRINT_KWARGS)s)
if (unlikely(!%(PRINT_KWARGS)s))
return -1;
end_string = PyUnicode_FromStringAndSize(" ", 1);
if (!end_string)
if (unlikely(!end_string))
return -1;
if (PyDict_SetItemString(%(PRINT_KWARGS)s, "end", end_string) < 0) {
Py_DECREF(end_string);
......@@ -5104,10 +5144,16 @@ static int __Pyx_Print(PyObject *arg_tuple, int newline) {
kwargs = %(PRINT_KWARGS)s;
}
result = PyObject_Call(%(PRINT_FUNCTION)s, arg_tuple, kwargs);
if (unlikely(kwargs) && (kwargs != %(PRINT_KWARGS)s))
Py_DECREF(kwargs);
if (!result)
return -1;
Py_DECREF(result);
return 0;
bad:
if (kwargs != %(PRINT_KWARGS)s)
Py_XDECREF(kwargs);
return -1;
}
#endif
......@@ -5119,15 +5165,16 @@ static int __Pyx_Print(PyObject *arg_tuple, int newline) {
printing_one_utility_code = UtilityCode(
proto = """
static int __Pyx_PrintOne(PyObject *o); /*proto*/
static int __Pyx_PrintOne(PyObject* stream, PyObject *o); /*proto*/
""",
impl = r"""
#if PY_MAJOR_VERSION < 3
static int __Pyx_PrintOne(PyObject *o) {
PyObject *f;
static int __Pyx_PrintOne(PyObject* f, PyObject *o) {
if (!f) {
if (!(f = __Pyx_GetStdout()))
return -1;
}
if (PyFile_SoftSpace(f, 0)) {
if (PyFile_WriteString(" ", f) < 0)
return -1;
......@@ -5139,19 +5186,19 @@ static int __Pyx_PrintOne(PyObject *o) {
return 0;
/* the line below is just to avoid compiler
* compiler warnings about unused functions */
return __Pyx_Print(NULL, 0);
return __Pyx_Print(f, NULL, 0);
}
#else /* Python 3 has a print function */
static int __Pyx_PrintOne(PyObject *o) {
static int __Pyx_PrintOne(PyObject* stream, PyObject *o) {
int res;
PyObject* arg_tuple = PyTuple_New(1);
if (unlikely(!arg_tuple))
return -1;
Py_INCREF(o);
PyTuple_SET_ITEM(arg_tuple, 0, o);
res = __Pyx_Print(arg_tuple, 1);
res = __Pyx_Print(stream, arg_tuple, 1);
Py_DECREF(arg_tuple);
return res;
}
......
......@@ -1121,6 +1121,9 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
exception_check = True)
def _handle_simple_function_float(self, node, pos_args):
"""Transform float() into either a C type cast or a faster C
function call.
"""
# Note: this requires the float() function to be typed as
# returning a C 'double'
if len(pos_args) != 1:
......@@ -1158,6 +1161,8 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
])
def _handle_simple_function_getattr(self, node, pos_args):
"""Replace 2/3 argument forms of getattr() by C-API calls.
"""
if len(pos_args) == 2:
return ExprNodes.PythonCapiCallNode(
node.pos, "PyObject_GetAttr", self.PyObject_GetAttr2_func_type,
......@@ -1173,16 +1178,42 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
self._error_wrong_arg_count('getattr', node, pos_args, '2 or 3')
return node
PyObject_GetIter_func_type = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
])
PyCallIter_New_func_type = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("sentinel", PyrexTypes.py_object_type, None),
])
def _handle_simple_function_iter(self, node, pos_args):
"""Replace 1/2 argument forms of iter() by C-API calls.
"""
if len(pos_args) == 1:
return ExprNodes.PythonCapiCallNode(
node.pos, "PyObject_GetIter", self.PyObject_GetIter_func_type,
args = pos_args,
is_temp = node.is_temp)
elif len(pos_args) == 2:
return ExprNodes.PythonCapiCallNode(
node.pos, "PyCallIter_New", self.PyCallIter_New_func_type,
args = pos_args,
is_temp = node.is_temp)
else:
self._error_wrong_arg_count('iter', node, pos_args, '1 or 2')
return node
Pyx_strlen_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_size_t_type, [
PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
])
def _handle_simple_function_len(self, node, pos_args):
# note: this only works because we already replaced len() by
# PyObject_Length() which returns a Py_ssize_t instead of a
# Python object, so we can return a plain size_t instead
# without caring about Python object conversion etc.
"""Replace len(char*) by the equivalent call to strlen().
"""
if len(pos_args) != 1:
self._error_wrong_arg_count('len', node, pos_args, 1)
return node
......@@ -1191,6 +1222,13 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
arg = arg.arg
if not arg.type.is_string:
return node
if not node.type.is_numeric:
# this optimisation only works when we already replaced
# len() by PyObject_Length() which returns a Py_ssize_t
# instead of a Python object, so we can return a plain
# size_t instead without caring about Python object
# conversion etc.
return node
node = ExprNodes.PythonCapiCallNode(
node.pos, "strlen", self.Pyx_strlen_func_type,
args = [arg],
......@@ -1205,6 +1243,8 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
])
def _handle_simple_function_type(self, node, pos_args):
"""Replace type(o) by a macro call to Py_TYPE(o).
"""
if len(pos_args) != 1:
return node
node = ExprNodes.PythonCapiCallNode(
......@@ -1269,7 +1309,9 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
])
def _handle_simple_method_object_append(self, node, args, is_unbound_method):
# X.append() is almost always referring to a list
"""Optimistic optimisation as X.append() is almost always
referring to a list.
"""
if len(args) != 2:
return node
......@@ -1292,7 +1334,9 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
])
def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
# X.pop([n]) is almost always referring to a list
"""Optimistic optimisation as X.pop([n]) is almost always
referring to a list.
"""
if len(args) == 1:
return ExprNodes.PythonCapiCallNode(
node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
......@@ -1322,6 +1366,8 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
exception_value = "-1")
def _handle_simple_method_list_append(self, node, args, is_unbound_method):
"""Call PyList_Append() instead of l.append().
"""
if len(args) != 2:
self._error_wrong_arg_count('list.append', node, args, 2)
return node
......@@ -1336,6 +1382,8 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
exception_value = "-1")
def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
"""Call PyList_Sort() instead of the 0-argument l.sort().
"""
if len(args) != 1:
return node
return self._substitute_method_call(
......@@ -1343,6 +1391,8 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
'sort', is_unbound_method, args)
def _handle_simple_method_list_reverse(self, node, args, is_unbound_method):
"""Call PyList_Reverse() instead of l.reverse().
"""
if len(args) != 1:
self._error_wrong_arg_count('list.reverse', node, args, 1)
return node
......@@ -1350,6 +1400,28 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
node, "PyList_Reverse", self.single_param_func_type,
'reverse', is_unbound_method, args)
Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
],
exception_value = "NULL")
def _handle_simple_method_dict_get(self, node, args, is_unbound_method):
"""Replace dict.get() by a call to PyDict_GetItem().
"""
if len(args) == 2:
args.append(ExprNodes.NoneNode(node.pos))
elif len(args) != 3:
self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
return node
return self._substitute_method_call(
node, "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
'get', is_unbound_method, args,
utility_code = dict_getitem_default_utility_code)
PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
Builtin.bytes_type, [
PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
......@@ -1371,6 +1443,9 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
for name in _special_encodings ]
def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
"""Replace unicode.encode(...) by a direct C-API call to the
corresponding codec.
"""
if len(args) < 1 or len(args) > 3:
self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
return node
......@@ -1436,6 +1511,9 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
exception_value = "NULL")
def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
"""Replace char*.decode() by a direct C-API call to the
corresponding codec, possibly resoving a slice on the char*.
"""
if len(args) < 1 or len(args) > 3:
self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
return node
......@@ -1575,7 +1653,8 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
return (encoding, encoding_node, error_handling, error_handling_node)
def _substitute_method_call(self, node, name, func_type,
attr_name, is_unbound_method, args=()):
attr_name, is_unbound_method, args=(),
utility_code=None):
args = list(args)
if args:
self_arg = args[0]
......@@ -1592,10 +1671,46 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
return ExprNodes.PythonCapiCallNode(
node.pos, name, func_type,
args = args,
is_temp = node.is_temp
is_temp = node.is_temp,
utility_code = utility_code
)
dict_getitem_default_utility_code = UtilityCode(
proto = '''
static PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObject* default_value) {
PyObject* value;
#if PY_MAJOR_VERSION >= 3
value = PyDict_GetItemWithError(d, key);
if (unlikely(!value)) {
if (unlikely(PyErr_Occurred()))
return NULL;
value = default_value;
}
Py_INCREF(value);
#else
if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) {
/* these presumably have safe hash functions */
value = PyDict_GetItem(d, key);
if (unlikely(!value)) {
value = default_value;
}
Py_INCREF(value);
} else {
PyObject *m;
m = __Pyx_GetAttrString(d, "get");
if (!m) return NULL;
value = PyObject_CallFunctionObjArgs(m, key,
(default_value == Py_None) ? NULL : default_value, NULL);
Py_DECREF(m);
}
#endif
return value;
}
''',
impl = ""
)
append_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
......
......@@ -128,7 +128,7 @@ def parse_directive_value(name, value, relaxed_bool=False):
def parse_directive_list(s, relaxed_bool=False, ignore_unknown=False):
"""
Parses a comma-seperated list of pragma options. Whitespace
Parses a comma-separated list of pragma options. Whitespace
is not considered.
>>> parse_directive_list(' ')
......
......@@ -242,6 +242,231 @@ class PostParse(CythonTransform):
self.context.nonfatal_error(e)
return None
# Split parallel assignments (a,b = b,a) into separate partial
# assignments that are executed rhs-first using temps. This
# optimisation is best applied before type analysis so that known
# types on rhs and lhs can be matched directly.
def visit_SingleAssignmentNode(self, node):
self.visitchildren(node)
return self._visit_assignment_node(node, [node.lhs, node.rhs])
def visit_CascadedAssignmentNode(self, node):
self.visitchildren(node)
return self._visit_assignment_node(node, node.lhs_list + [node.rhs])
def _visit_assignment_node(self, node, expr_list):
"""Flatten parallel assignments into separate single
assignments or cascaded assignments.
"""
if sum([ 1 for expr in expr_list if expr.is_sequence_constructor ]) < 2:
# no parallel assignments => nothing to do
return node
expr_list_list = []
flatten_parallel_assignments(expr_list, expr_list_list)
temp_refs = []
eliminate_rhs_duplicates(expr_list_list, temp_refs)
nodes = []
for expr_list in expr_list_list:
lhs_list = expr_list[:-1]
rhs = expr_list[-1]
if len(lhs_list) == 1:
node = Nodes.SingleAssignmentNode(rhs.pos,
lhs = lhs_list[0], rhs = rhs)
else:
node = Nodes.CascadedAssignmentNode(rhs.pos,
lhs_list = lhs_list, rhs = rhs)
nodes.append(node)
if len(nodes) == 1:
assign_node = nodes[0]
else:
assign_node = Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)
if temp_refs:
duplicates_and_temps = [ (temp.expression, temp)
for temp in temp_refs ]
sort_common_subsequences(duplicates_and_temps)
for _, temp_ref in duplicates_and_temps[::-1]:
assign_node = LetNode(temp_ref, assign_node)
return assign_node
def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
"""Replace rhs items by LetRefNodes if they appear more than once.
Creates a sequence of LetRefNodes that set up the required temps
and appends them to ref_node_sequence. The input list is modified
in-place.
"""
seen_nodes = set()
ref_nodes = {}
def find_duplicates(node):
if node.is_literal or node.is_name:
# no need to replace those; can't include attributes here
# as their access is not necessarily side-effect free
return
if node in seen_nodes:
if node not in ref_nodes:
ref_node = LetRefNode(node)
ref_nodes[node] = ref_node
ref_node_sequence.append(ref_node)
else:
seen_nodes.add(node)
if node.is_sequence_constructor:
for item in node.args:
find_duplicates(item)
for expr_list in expr_list_list:
rhs = expr_list[-1]
find_duplicates(rhs)
if not ref_nodes:
return
def substitute_nodes(node):
if node in ref_nodes:
return ref_nodes[node]
elif node.is_sequence_constructor:
node.args = map(substitute_nodes, node.args)
return node
# replace nodes inside of the common subexpressions
for node in ref_nodes:
if node.is_sequence_constructor:
node.args = map(substitute_nodes, node.args)
# replace common subexpressions on all rhs items
for expr_list in expr_list_list:
expr_list[-1] = substitute_nodes(expr_list[-1])
def sort_common_subsequences(items):
"""Sort items/subsequences so that all items and subsequences that
an item contains appear before the item itself. This implies a
partial order, and the sort must be stable to preserve the
original order as much as possible, so we use a simple insertion
sort.
"""
def contains(seq, x):
for item in seq:
if item is x:
return True
elif item.is_sequence_constructor and contains(item.args, x):
return True
return False
def lower_than(a,b):
return b.is_sequence_constructor and contains(b.args, a)
for pos, item in enumerate(items):
new_pos = pos
key = item[0]
for i in xrange(pos-1, -1, -1):
if lower_than(key, items[i][0]):
new_pos = i
if new_pos != pos:
for i in xrange(pos, new_pos, -1):
items[i] = items[i-1]
items[new_pos] = item
def flatten_parallel_assignments(input, output):
# The input is a list of expression nodes, representing the LHSs
# and RHS of one (possibly cascaded) assignment statement. For
# sequence constructors, rearranges the matching parts of both
# sides into a list of equivalent assignments between the
# individual elements. This transformation is applied
# recursively, so that nested structures get matched as well.
rhs = input[-1]
if not rhs.is_sequence_constructor or not sum([lhs.is_sequence_constructor for lhs in input[:-1]]):
output.append(input)
return
complete_assignments = []
rhs_size = len(rhs.args)
lhs_targets = [ [] for _ in xrange(rhs_size) ]
starred_assignments = []
for lhs in input[:-1]:
if not lhs.is_sequence_constructor:
if lhs.is_starred:
error(lhs.pos, "starred assignment target must be in a list or tuple")
complete_assignments.append(lhs)
continue
lhs_size = len(lhs.args)
starred_targets = sum([1 for expr in lhs.args if expr.is_starred])
if starred_targets > 1:
error(lhs.pos, "more than 1 starred expression in assignment")
output.append([lhs,rhs])
continue
elif lhs_size - starred_targets > rhs_size:
error(lhs.pos, "need more than %d value%s to unpack"
% (rhs_size, (rhs_size != 1) and 's' or ''))
output.append([lhs,rhs])
continue
elif starred_targets:
map_starred_assignment(lhs_targets, starred_assignments,
lhs.args, rhs.args)
elif lhs_size < rhs_size:
error(lhs.pos, "too many values to unpack (expected %d, got %d)"
% (lhs_size, rhs_size))
output.append([lhs,rhs])
continue
else:
for targets, expr in zip(lhs_targets, lhs.args):
targets.append(expr)
if complete_assignments:
complete_assignments.append(rhs)
output.append(complete_assignments)
# recursively flatten partial assignments
for cascade, rhs in zip(lhs_targets, rhs.args):
if cascade:
cascade.append(rhs)
flatten_parallel_assignments(cascade, output)
# recursively flatten starred assignments
for cascade in starred_assignments:
if cascade[0].is_sequence_constructor:
flatten_parallel_assignments(cascade, output)
else:
output.append(cascade)
def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args):
# Appends the fixed-position LHS targets to the target list that
# appear left and right of the starred argument.
#
# The starred_assignments list receives a new tuple
# (lhs_target, rhs_values_list) that maps the remaining arguments
# (those that match the starred target) to a list.
# left side of the starred target
for i, (targets, expr) in enumerate(zip(lhs_targets, lhs_args)):
if expr.is_starred:
starred = i
lhs_remaining = len(lhs_args) - i - 1
break
targets.append(expr)
else:
raise InternalError("no starred arg found when splitting starred assignment")
# right side of the starred target
for i, (targets, expr) in enumerate(zip(lhs_targets[-lhs_remaining:],
lhs_args[-lhs_remaining:])):
targets.append(expr)
# the starred target itself, must be assigned a (potentially empty) list
target = lhs_args[starred].target # unpack starred node
starred_rhs = rhs_args[starred:]
if lhs_remaining:
starred_rhs = starred_rhs[:-lhs_remaining]
if starred_rhs:
pos = starred_rhs[0].pos
else:
pos = target.pos
starred_assignments.append([
target, ExprNodes.ListNode(pos=pos, args=starred_rhs)])
class PxdPostParse(CythonTransform, SkipDeclarations):
"""
Basic interpretation/validity checking that should only be
......
......@@ -59,8 +59,6 @@ cpdef p_testlist(PyrexScanner s)
#
#-------------------------------------------------------
cpdef flatten_parallel_assignments(input, output)
cpdef p_global_statement(PyrexScanner s)
cpdef p_expression_or_assignment(PyrexScanner s)
cpdef p_print_statement(PyrexScanner s)
......
......@@ -917,138 +917,29 @@ def p_expression_or_assignment(s):
return Nodes.PassStatNode(expr.pos)
else:
return Nodes.ExprStatNode(expr.pos, expr = expr)
else:
expr_list_list = []
flatten_parallel_assignments(expr_list, expr_list_list)
nodes = []
for expr_list in expr_list_list:
lhs_list = expr_list[:-1]
rhs = expr_list[-1]
if len(lhs_list) == 1:
node = Nodes.SingleAssignmentNode(rhs.pos,
lhs = lhs_list[0], rhs = rhs)
else:
node = Nodes.CascadedAssignmentNode(rhs.pos,
lhs_list = lhs_list, rhs = rhs)
nodes.append(node)
if len(nodes) == 1:
return nodes[0]
else:
return Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)
def flatten_parallel_assignments(input, output):
# The input is a list of expression nodes, representing the LHSs
# and RHS of one (possibly cascaded) assignment statement. For
# sequence constructors, rearranges the matching parts of both
# sides into a list of equivalent assignments between the
# individual elements. This transformation is applied
# recursively, so that nested structures get matched as well.
rhs = input[-1]
if not rhs.is_sequence_constructor or not sum([lhs.is_sequence_constructor for lhs in input[:-1]]):
output.append(input)
return
complete_assignments = []
rhs_size = len(rhs.args)
lhs_targets = [ [] for _ in range(rhs_size) ]
starred_assignments = []
for lhs in input[:-1]:
if not lhs.is_sequence_constructor:
if lhs.is_starred:
error(lhs.pos, "starred assignment target must be in a list or tuple")
complete_assignments.append(lhs)
continue
lhs_size = len(lhs.args)
starred_targets = sum([1 for expr in lhs.args if expr.is_starred])
if starred_targets:
if starred_targets > 1:
error(lhs.pos, "more than 1 starred expression in assignment")
output.append([lhs,rhs])
continue
elif lhs_size - starred_targets > rhs_size:
error(lhs.pos, "need more than %d value%s to unpack"
% (rhs_size, (rhs_size != 1) and 's' or ''))
output.append([lhs,rhs])
continue
map_starred_assignment(lhs_targets, starred_assignments,
lhs.args, rhs.args)
else:
if lhs_size > rhs_size:
error(lhs.pos, "need more than %d value%s to unpack"
% (rhs_size, (rhs_size != 1) and 's' or ''))
output.append([lhs,rhs])
continue
elif lhs_size < rhs_size:
error(lhs.pos, "too many values to unpack (expected %d, got %d)"
% (lhs_size, rhs_size))
output.append([lhs,rhs])
continue
else:
for targets, expr in zip(lhs_targets, lhs.args):
targets.append(expr)
if complete_assignments:
complete_assignments.append(rhs)
output.append(complete_assignments)
# recursively flatten partial assignments
for cascade, rhs in zip(lhs_targets, rhs.args):
if cascade:
cascade.append(rhs)
flatten_parallel_assignments(cascade, output)
# recursively flatten starred assignments
for cascade in starred_assignments:
if cascade[0].is_sequence_constructor:
flatten_parallel_assignments(cascade, output)
else:
output.append(cascade)
def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args):
# Appends the fixed-position LHS targets to the target list that
# appear left and right of the starred argument.
#
# The starred_assignments list receives a new tuple
# (lhs_target, rhs_values_list) that maps the remaining arguments
# (those that match the starred target) to a list.
# left side of the starred target
for i, (targets, expr) in enumerate(zip(lhs_targets, lhs_args)):
if expr.is_starred:
starred = i
lhs_remaining = len(lhs_args) - i - 1
break
targets.append(expr)
else:
raise InternalError("no starred arg found when splitting starred assignment")
# right side of the starred target
for i, (targets, expr) in enumerate(zip(lhs_targets[-lhs_remaining:],
lhs_args[-lhs_remaining:])):
targets.append(expr)
# the starred target itself, must be assigned a (potentially empty) list
target = lhs_args[starred].target # unpack starred node
starred_rhs = rhs_args[starred:]
if lhs_remaining:
starred_rhs = starred_rhs[:-lhs_remaining]
if starred_rhs:
pos = starred_rhs[0].pos
rhs = expr_list[-1]
if len(expr_list) == 2:
return Nodes.SingleAssignmentNode(rhs.pos,
lhs = expr_list[0], rhs = rhs)
else:
pos = target.pos
starred_assignments.append([
target, ExprNodes.ListNode(pos=pos, args=starred_rhs)])
return Nodes.CascadedAssignmentNode(rhs.pos,
lhs_list = expr_list[:-1], rhs = rhs)
def p_print_statement(s):
# s.sy == 'print'
pos = s.position()
ends_with_comma = 0
s.next()
if s.sy == '>>':
s.error("'print >>' not yet implemented")
s.next()
stream = p_simple_expr(s)
if s.sy == ',':
s.next()
ends_with_comma = s.sy in ('NEWLINE', 'EOF')
else:
stream = None
args = []
ends_with_comma = 0
if s.sy not in ('NEWLINE', 'EOF'):
args.append(p_simple_expr(s))
while s.sy == ',':
......@@ -1059,7 +950,8 @@ def p_print_statement(s):
args.append(p_simple_expr(s))
arg_tuple = ExprNodes.TupleNode(pos, args = args)
return Nodes.PrintStatNode(pos,
arg_tuple = arg_tuple, append_newline = not ends_with_comma)
arg_tuple = arg_tuple, stream = stream,
append_newline = not ends_with_comma)
def p_exec_statement(s):
# s.sy == 'exec'
......
......@@ -164,13 +164,13 @@ class PyrexType(BaseType):
return 1
def create_typedef_type(cname, base_type, is_external=0):
def create_typedef_type(name, base_type, cname, is_external=0):
if base_type.is_complex:
if is_external:
raise ValueError("Complex external typedefs not supported")
return base_type
else:
return CTypedefType(cname, base_type, is_external)
return CTypedefType(name, base_type, cname, is_external)
class CTypedefType(BaseType):
#
......@@ -180,6 +180,7 @@ class CTypedefType(BaseType):
# HERE IS DELEGATED!
#
# qualified_name string
# typedef_name string
# typedef_cname string
# typedef_base_type PyrexType
# typedef_is_external bool
......@@ -191,8 +192,9 @@ class CTypedefType(BaseType):
from_py_utility_code = None
def __init__(self, cname, base_type, is_external=0):
def __init__(self, name, base_type, cname, is_external=0):
assert not base_type.is_complex
self.typedef_name = name
self.typedef_cname = cname
self.typedef_base_type = base_type
self.typedef_is_external = is_external
......@@ -214,19 +216,12 @@ class CTypedefType(BaseType):
def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0):
name = self.declaration_name(for_display, pyrex)
if pyrex or for_display:
base_code = name
base_code = self.typedef_name
else:
base_code = public_decl(name, dll_linkage)
base_code = public_decl(self.typedef_cname, dll_linkage)
return self.base_declaration_code(base_code, entity_code)
def declaration_name(self, for_display = 0, pyrex = 0):
if pyrex or for_display:
return self.qualified_name
else:
return self.typedef_cname
def as_argument_type(self):
return self
......@@ -242,7 +237,7 @@ class CTypedefType(BaseType):
return "<CTypedefType %s>" % self.typedef_cname
def __str__(self):
return self.declaration_name(for_display = 1)
return self.typedef_name
def _create_utility_code(self, template_utility_code,
template_function_name):
......@@ -999,7 +994,7 @@ class CComplexType(CNumericType):
env.use_utility_code(complex_real_imag_utility_code)
for utility_code in (complex_type_utility_code,
complex_from_parts_utility_code,
complex_arithmatic_utility_code):
complex_arithmetic_utility_code):
env.use_utility_code(
utility_code.specialize(
self,
......@@ -1168,7 +1163,7 @@ static %(type)s __Pyx_PyComplex_As_%(type_name)s(PyObject* o) {
}
""")
complex_arithmatic_utility_code = UtilityCode(
complex_arithmetic_utility_code = UtilityCode(
proto="""
#if CYTHON_CCOMPLEX
#define __Pyx_c_eq%(m)s(a, b) ((a)==(b))
......
......@@ -10,13 +10,15 @@ import codecs
from time import time
import cython
cython.declare(EncodedString=object, string_prefixes=object, raw_prefixes=object, IDENT=object)
cython.declare(EncodedString=object, string_prefixes=object, raw_prefixes=object, IDENT=object,
print_function=object)
from Cython import Plex, Utils
from Cython.Plex.Scanners import Scanner
from Cython.Plex.Errors import UnrecognizedInput
from Errors import CompileError, error
from Lexicon import string_prefixes, raw_prefixes, make_lexicon, IDENT
from Future import print_function
from StringEncoding import EncodedString
......@@ -61,7 +63,7 @@ def build_resword_dict():
d[word] = 1
return d
cython.declare(resword_dict=object)
cython.declare(resword_dict=dict)
resword_dict = build_resword_dict()
#------------------------------------------------------------------
......@@ -345,6 +347,10 @@ class PyrexScanner(Scanner):
self.error("Unrecognized character")
if sy == IDENT:
if systring in resword_dict:
if systring == 'print' and \
print_function in self.context.future_directives:
systring = EncodedString(systring)
else:
sy = systring
else:
systring = EncodedString(systring)
......
......@@ -119,7 +119,7 @@ class Entry(object):
# inline_func_in_pxd boolean Hacky special case for inline function in pxd file.
# Ideally this should not be necesarry.
# assignments [ExprNode] List of expressions that get assigned to this entry.
# might_overflow boolean In an arithmatic expression that could cause
# might_overflow boolean In an arithmetic expression that could cause
# overflow (used for type inference).
inline_func_in_pxd = False
......@@ -359,7 +359,8 @@ class Scope(object):
else:
cname = self.mangle(Naming.type_prefix, name)
try:
type = PyrexTypes.create_typedef_type(cname, base_type, (visibility == 'extern'))
type = PyrexTypes.create_typedef_type(name, base_type, cname,
(visibility == 'extern'))
except ValueError, e:
error(pos, e.message)
type = PyrexTypes.error_type
......
......@@ -50,13 +50,13 @@ class TestNormalizeTree(TransformTest):
""")
self.assertLines(u"""
(root): StatListNode
stats[0]: ParallelAssignmentNode
stats[0]: SingleAssignmentNode
lhs: NameNode
rhs: NameNode
stats[1]: SingleAssignmentNode
lhs: NameNode
rhs: NameNode
lhs: TupleNode
args[0]: NameNode
args[1]: NameNode
rhs: TupleNode
args[0]: NameNode
args[1]: NameNode
""", self.treetypes(t))
def test_wrap_offagain(self):
......
......@@ -112,7 +112,7 @@ class MarkAssignments(CythonTransform):
self.visitchildren(node)
return node
class MarkOverflowingArithmatic(CythonTransform):
class MarkOverflowingArithmetic(CythonTransform):
# It may be possible to integrate this with the above for
# performance improvements (though likely not worth it).
......@@ -122,7 +122,7 @@ class MarkOverflowingArithmatic(CythonTransform):
def __call__(self, root):
self.env_stack = []
self.env = root.scope
return super(MarkOverflowingArithmatic, self).__call__(root)
return super(MarkOverflowingArithmetic, self).__call__(root)
def visit_safe_node(self, node):
self.might_overflow, saved = False, self.might_overflow
......@@ -262,11 +262,12 @@ class SimpleAssignmentTypeInferer:
def find_spanning_type(type1, type2):
if type1 is type2:
return type1
result_type = type1
elif type1 is PyrexTypes.c_bint_type or type2 is PyrexTypes.c_bint_type:
# type inference can break the coercion back to a Python bool
# if it returns an arbitrary int type here
return py_object_type
else:
result_type = PyrexTypes.spanning_type(type1, type2)
if result_type in (PyrexTypes.c_double_type, PyrexTypes.c_float_type, Builtin.float_type):
# Python's float type is just a C double, so it's safe to
......
......@@ -130,6 +130,9 @@ class ResultRefNode(AtomicExprNode):
def infer_type(self, env):
return self.expression.infer_type(env)
def is_simple(self):
return True
def result(self):
return self.result_code
......@@ -222,7 +225,8 @@ class LetNode(Nodes.StatNode, LetNodeMixin):
# BLOCK (can modify temp)
# if temp is an object, decref
#
# To be used after analysis phase, does no analysis.
# Usually used after analysis phase, but forwards analysis methods
# to its children
child_attrs = ['temp_expression', 'body']
......@@ -231,6 +235,17 @@ class LetNode(Nodes.StatNode, LetNodeMixin):
self.pos = body.pos
self.body = body
def analyse_control_flow(self, env):
self.body.analyse_control_flow(env)
def analyse_declarations(self, env):
self.temp_expression.analyse_declarations(env)
self.body.analyse_declarations(env)
def analyse_expressions(self, env):
self.temp_expression.analyse_expressions(env)
self.body.analyse_expressions(env)
def generate_execution_code(self, code):
self.setup_temp_expr(code)
self.body.generate_execution_code(code)
......
# -*- coding: utf-8 -*-
"""unittest-xml-reporting is a PyUnit-based TestRunner that can export test
results to XML files that can be consumed by a wide range of tools, such as
build systems, IDEs and Continuous Integration servers.
This module provides the XMLTestRunner class, which is heavily based on the
default TextTestRunner. This makes the XMLTestRunner very simple to use.
The script below, adapted from the unittest documentation, shows how to use
XMLTestRunner in a very simple way. In fact, the only difference between this
script and the original one is the last line:
import random
import unittest
import xmlrunner
class TestSequenceFunctions(unittest.TestCase):
def setUp(self):
self.seq = range(10)
def test_shuffle(self):
# make sure the shuffled sequence does not lose any elements
random.shuffle(self.seq)
self.seq.sort()
self.assertEqual(self.seq, range(10))
def test_choice(self):
element = random.choice(self.seq)
self.assert_(element in self.seq)
def test_sample(self):
self.assertRaises(ValueError, random.sample, self.seq, 20)
for element in random.sample(self.seq, 5):
self.assert_(element in self.seq)
if __name__ == '__main__':
unittest.main(testRunner=xmlrunner.XMLTestRunner(output='test-reports'))
"""
import os
import sys
import time
from unittest import TestResult, _TextTestResult, TextTestRunner
from cStringIO import StringIO
class _TestInfo(object):
"""This class is used to keep useful information about the execution of a
test method.
"""
# Possible test outcomes
(SUCCESS, FAILURE, ERROR) = range(3)
def __init__(self, test_result, test_method, outcome=SUCCESS, err=None):
"Create a new instance of _TestInfo."
self.test_result = test_result
self.test_method = test_method
self.outcome = outcome
self.err = err
def get_elapsed_time(self):
"""Return the time that shows how long the test method took to
execute.
"""
return self.test_result.stop_time - self.test_result.start_time
def get_description(self):
"Return a text representation of the test method."
return self.test_result.getDescription(self.test_method)
def get_error_info(self):
"""Return a text representation of an exception thrown by a test
method.
"""
if not self.err:
return ''
return self.test_result._exc_info_to_string(self.err, \
self.test_method)
class _XMLTestResult(_TextTestResult):
"""A test result class that can express test results in a XML report.
Used by XMLTestRunner.
"""
def __init__(self, stream=sys.stderr, descriptions=1, verbosity=1, \
elapsed_times=True):
"Create a new instance of _XMLTestResult."
_TextTestResult.__init__(self, stream, descriptions, verbosity)
self.successes = []
self.callback = None
self.elapsed_times = elapsed_times
def _prepare_callback(self, test_info, target_list, verbose_str,
short_str):
"""Append a _TestInfo to the given target list and sets a callback
method to be called by stopTest method.
"""
target_list.append(test_info)
def callback():
"""This callback prints the test method outcome to the stream,
as well as the elapsed time.
"""
# Ignore the elapsed times for a more reliable unit testing
if not self.elapsed_times:
self.start_time = self.stop_time = 0
if self.showAll:
self.stream.writeln('%s (%.3fs)' % \
(verbose_str, test_info.get_elapsed_time()))
elif self.dots:
self.stream.write(short_str)
self.callback = callback
def startTest(self, test):
"Called before execute each test method."
self.start_time = time.time()
TestResult.startTest(self, test)
if self.showAll:
self.stream.write(' ' + self.getDescription(test))
self.stream.write(" ... ")
def stopTest(self, test):
"Called after execute each test method."
_TextTestResult.stopTest(self, test)
self.stop_time = time.time()
if self.callback and callable(self.callback):
self.callback()
self.callback = None
def addSuccess(self, test):
"Called when a test executes successfully."
self._prepare_callback(_TestInfo(self, test), \
self.successes, 'OK', '.')
def addFailure(self, test, err):
"Called when a test method fails."
self._prepare_callback(_TestInfo(self, test, _TestInfo.FAILURE, err), \
self.failures, 'FAIL', 'F')
def addError(self, test, err):
"Called when a test method raises an error."
self._prepare_callback(_TestInfo(self, test, _TestInfo.ERROR, err), \
self.errors, 'ERROR', 'E')
def printErrorList(self, flavour, errors):
"Write some information about the FAIL or ERROR to the stream."
for test_info in errors:
self.stream.writeln(self.separator1)
self.stream.writeln('%s [%.3fs]: %s' % \
(flavour, test_info.get_elapsed_time(), \
test_info.get_description()))
self.stream.writeln(self.separator2)
self.stream.writeln('%s' % test_info.get_error_info())
def _get_info_by_testcase(self):
"""This method organizes test results by TestCase module. This
information is used during the report generation, where a XML report
will be generated for each TestCase.
"""
tests_by_testcase = {}
for tests in (self.successes, self.failures, self.errors):
for test_info in tests:
testcase = type(test_info.test_method)
# Ignore module name if it is '__main__'
module = testcase.__module__ + '.'
if module == '__main__.':
module = ''
testcase_name = module + testcase.__name__
if not tests_by_testcase.has_key(testcase_name):
tests_by_testcase[testcase_name] = []
tests_by_testcase[testcase_name].append(test_info)
return tests_by_testcase
def _report_testsuite(suite_name, tests, xml_document):
"Appends the testsuite section to the XML document."
testsuite = xml_document.createElement('testsuite')
xml_document.appendChild(testsuite)
testsuite.setAttribute('name', str(suite_name))
testsuite.setAttribute('tests', str(len(tests)))
testsuite.setAttribute('time', '%.3f' % \
sum(map(lambda e: e.get_elapsed_time(), tests)))
failures = filter(lambda e: e.outcome==_TestInfo.FAILURE, tests)
testsuite.setAttribute('failures', str(len(failures)))
errors = filter(lambda e: e.outcome==_TestInfo.ERROR, tests)
testsuite.setAttribute('errors', str(len(errors)))
return testsuite
_report_testsuite = staticmethod(_report_testsuite)
def _report_testcase(suite_name, test_result, xml_testsuite, xml_document):
"Appends a testcase section to the XML document."
testcase = xml_document.createElement('testcase')
xml_testsuite.appendChild(testcase)
testcase.setAttribute('classname', str(suite_name))
testcase.setAttribute('name', str(test_result.test_method.shortDescription() or test_result.test_method._testMethodName))
testcase.setAttribute('time', '%.3f' % test_result.get_elapsed_time())
if (test_result.outcome != _TestInfo.SUCCESS):
elem_name = ('failure', 'error')[test_result.outcome-1]
failure = xml_document.createElement(elem_name)
testcase.appendChild(failure)
failure.setAttribute('type', str(test_result.err[0].__name__))
failure.setAttribute('message', str(test_result.err[1].message))
error_info = test_result.get_error_info()
failureText = xml_document.createCDATASection(error_info)
failure.appendChild(failureText)
_report_testcase = staticmethod(_report_testcase)
def _report_output(test_runner, xml_testsuite, xml_document):
"Appends the system-out and system-err sections to the XML document."
systemout = xml_document.createElement('system-out')
xml_testsuite.appendChild(systemout)
stdout = test_runner.stdout.getvalue()
systemout_text = xml_document.createCDATASection(stdout)
systemout.appendChild(systemout_text)
systemerr = xml_document.createElement('system-err')
xml_testsuite.appendChild(systemerr)
stderr = test_runner.stderr.getvalue()
systemerr_text = xml_document.createCDATASection(stderr)
systemerr.appendChild(systemerr_text)
_report_output = staticmethod(_report_output)
def generate_reports(self, test_runner):
"Generates the XML reports to a given XMLTestRunner object."
from xml.dom.minidom import Document
all_results = self._get_info_by_testcase()
if type(test_runner.output) == str and not \
os.path.exists(test_runner.output):
os.makedirs(test_runner.output)
for suite, tests in all_results.items():
doc = Document()
# Build the XML file
testsuite = _XMLTestResult._report_testsuite(suite, tests, doc)
for test in tests:
_XMLTestResult._report_testcase(suite, test, testsuite, doc)
_XMLTestResult._report_output(test_runner, testsuite, doc)
xml_content = doc.toprettyxml(indent='\t')
if type(test_runner.output) is str:
report_file = file('%s%sTEST-%s.xml' % \
(test_runner.output, os.sep, suite), 'w')
try:
report_file.write(xml_content)
finally:
report_file.close()
else:
# Assume that test_runner.output is a stream
test_runner.output.write(xml_content)
class XMLTestRunner(TextTestRunner):
"""A test runner class that outputs the results in JUnit like XML files.
"""
def __init__(self, output='.', stream=sys.stderr, descriptions=True, \
verbose=False, elapsed_times=True):
"Create a new instance of XMLTestRunner."
verbosity = (1, 2)[verbose]
TextTestRunner.__init__(self, stream, descriptions, verbosity)
self.output = output
self.elapsed_times = elapsed_times
def _make_result(self):
"""Create the TestResult object which will be used to store
information about the executed tests.
"""
return _XMLTestResult(self.stream, self.descriptions, \
self.verbosity, self.elapsed_times)
def _patch_standard_output(self):
"""Replace the stdout and stderr streams with string-based streams
in order to capture the tests' output.
"""
(self.old_stdout, self.old_stderr) = (sys.stdout, sys.stderr)
(sys.stdout, sys.stderr) = (self.stdout, self.stderr) = \
(StringIO(), StringIO())
def _restore_standard_output(self):
"Restore the stdout and stderr streams."
(sys.stdout, sys.stderr) = (self.old_stdout, self.old_stderr)
def run(self, test):
"Run the given test case or test suite."
try:
# Prepare the test execution
self._patch_standard_output()
result = self._make_result()
# Print a nice header
self.stream.writeln()
self.stream.writeln('Running tests...')
self.stream.writeln(result.separator2)
# Execute tests
start_time = time.time()
test(result)
stop_time = time.time()
time_taken = stop_time - start_time
# Print results
result.printErrors()
self.stream.writeln(result.separator2)
run = result.testsRun
self.stream.writeln("Ran %d test%s in %.3fs" %
(run, run != 1 and "s" or "", time_taken))
self.stream.writeln()
# Error traces
if not result.wasSuccessful():
self.stream.write("FAILED (")
failed, errored = (len(result.failures), len(result.errors))
if failed:
self.stream.write("failures=%d" % failed)
if errored:
if failed:
self.stream.write(", ")
self.stream.write("errors=%d" % errored)
self.stream.writeln(")")
else:
self.stream.writeln("OK")
# Generate reports
self.stream.writeln()
self.stream.writeln('Generating XML reports...')
result.generate_reports(self)
finally:
self._restore_standard_output()
return result
Welcome to Cython!
=================
Cython (http://www.cython.org) is based on Pyrex, but supports more
cutting edge functionality and optimizations.
Cython (http://cython.org) is a language that makes writing C extensions for
the Python language as easy as Python itself. Cython is based on the
well-known Pyrex, but supports more cutting edge functionality and
optimizations.
The Cython language is very close to the Python language, but Cython
additionally supports calling C functions and declaring C types on variables
and class attributes. This allows the compiler to generate very efficient C
code from Cython code.
This makes Cython the ideal language for wrapping external C libraries, and
for fast C modules that speed up the execution of Python code.
LICENSE:
The original Pyrex program was licensed "free of restrictions" (see
below). Cython itself is licensed under the
below). Cython itself is licensed under the permissive
PYTHON SOFTWARE FOUNDATION LICENSE
http://www.python.org/psf/license/
Apache License
See LICENSE.txt.
--------------------------
There are TWO mercurial (hg) repositories included with Cython:
* Various project files, documentation, etc. (in the top level directory)
* The main codebase itself (in Cython/)
We keep these separate for easier merging with the Pyrex project.
To see the change history for Cython code itself, go to the Cython
directory and type
$ hg log
Note that Cython used to ship the Mercurial (hg) repository in its source
distribution, but no longer does so due to space constraints. To get the
full source history, make sure you have hg installed, then step into the
base directory of the Cython source distribution and type
This requires that you have installed Mercurial.
make repo
Alternatively, check out the latest developer repository from
-- William Stein (wstein@gmail.com)
http://hg.cython.org/cython-devel
xxxx
The following is from Pyrex:
......
......@@ -48,8 +48,12 @@ EXT_DEP_INCLUDES = [
]
VER_DEP_MODULES = {
# tests are excluded if 'CurrentPythonVersion OP VersionTuple', i.e.
# (2,4) : (operator.le, ...) excludes ... when PyVer <= 2.4.x
(2,4) : (operator.le, lambda x: x in ['run.extern_builtins_T258'
]),
(2,6) : (operator.lt, lambda x: x in ['run.print_function'
]),
(3,): (operator.ge, lambda x: x in ['run.non_future_division',
'compile.extsetslice',
'compile.extdelslice']),
......@@ -269,12 +273,18 @@ class CythonCompileTestCase(unittest.TestCase):
target = '%s.%s' % (module_name, self.language)
return target
def find_source_files(self, test_directory, module_name):
def copy_related_files(self, test_directory, target_directory, module_name):
is_related = re.compile('%s_.*[.].*' % module_name).match
for filename in os.listdir(test_directory):
if is_related(filename):
shutil.copy(os.path.join(test_directory, filename),
target_directory)
def find_source_files(self, workdir, module_name):
is_related = re.compile('%s_.*[.]%s' % (module_name, self.language)).match
return [self.build_target_filename(module_name)] + [
os.path.join(test_directory, filename)
for filename in os.listdir(test_directory)
if is_related(filename) and os.path.isfile(os.path.join(test_directory, filename)) ]
filename for filename in os.listdir(workdir)
if is_related(filename) and os.path.isfile(os.path.join(workdir, filename)) ]
def split_source_and_output(self, test_directory, module, workdir):
source_file = os.path.join(test_directory, module) + '.pyx'
......@@ -329,9 +339,10 @@ class CythonCompileTestCase(unittest.TestCase):
for match, get_additional_include_dirs in EXT_DEP_INCLUDES:
if match(module):
ext_include_dirs += get_additional_include_dirs()
self.copy_related_files(test_directory, workdir, module)
extension = Extension(
module,
sources = self.find_source_files(test_directory, module),
sources = self.find_source_files(workdir, module),
include_dirs = ext_include_dirs,
extra_compile_args = CFLAGS,
)
......@@ -427,9 +438,9 @@ class CythonRunTestCase(CythonCompileTestCase):
if tests is None:
# importing failed, try to fake a test class
tests = _FakeClass(
failureException=None,
shortDescription = self.shortDescription,
**{module_name: None})
failureException=sys.exc_info()[1],
_shortDescription=self.shortDescription(),
module_name=None)
partial_result.addError(tests, sys.exc_info())
result_code = 1
output = open(result_file, 'wb')
......@@ -444,6 +455,13 @@ class CythonRunTestCase(CythonCompileTestCase):
try:
cid, result_code = os.waitpid(child_id, 0)
# os.waitpid returns the child's result code in the
# upper byte of result_code, and the signal it was
# killed by in the lower byte
if result_code & 255:
raise Exception("Tests in module '%s' were unexpectedly killed by signal %d"%
(module_name, result_code & 255))
result_code = result_code >> 8
if result_code in (0,1):
input = open(result_file, 'rb')
try:
......@@ -452,7 +470,7 @@ class CythonRunTestCase(CythonCompileTestCase):
input.close()
if result_code:
raise Exception("Tests in module '%s' exited with status %d" %
(module_name, result_code >> 8))
(module_name, result_code))
finally:
try: os.unlink(result_file)
except: pass
......@@ -484,7 +502,7 @@ class PartialTestResult(_TextTestResult):
if attr_name == '_dt_test':
test_case._dt_test = _FakeClass(
name=test_case._dt_test.name)
else:
elif attr_name != '_shortDescription':
setattr(test_case, attr_name, None)
def data(self):
......@@ -497,7 +515,7 @@ class PartialTestResult(_TextTestResult):
"""Static method for merging the result back into the main
result object.
"""
errors, failures, tests_run, output = data
failures, errors, tests_run, output = data
if output:
result.stream.write(output)
result.errors.extend(errors)
......@@ -717,7 +735,12 @@ if __name__ == '__main__':
help="display test progress, pass twice to print test names")
parser.add_option("-T", "--ticket", dest="tickets",
action="append",
help="a bug ticket number to run the respective test in 'tests/bugs'")
help="a bug ticket number to run the respective test in 'tests/*'")
parser.add_option("--xml-output", dest="xml_output_dir", metavar="DIR",
help="write test results in XML to directory DIR")
parser.add_option("--exit-ok", dest="exit_ok", default=False,
action="store_true",
help="exit without error code even on test failures")
options, cmd_args = parser.parse_args()
......@@ -871,7 +894,14 @@ if __name__ == '__main__':
os.path.join(sys.prefix, 'lib', 'python'+sys.version[:3], 'test'),
'pyregr'))
result = unittest.TextTestRunner(verbosity=options.verbosity).run(test_suite)
if options.xml_output_dir:
from Cython.Tests.xmlrunner import XMLTestRunner
test_runner = XMLTestRunner(output=options.xml_output_dir,
verbose=options.verbosity > 0)
else:
test_runner = unittest.TextTestRunner(verbosity=options.verbosity)
result = test_runner.run(test_suite)
if options.coverage:
coverage.stop()
......@@ -891,4 +921,7 @@ if __name__ == '__main__':
import refnanny
sys.stderr.write("\n".join([repr(x) for x in refnanny.reflog]))
if options.exit_ok:
sys.exit(0)
else:
sys.exit(not result.wasSuccessful())
......@@ -3,13 +3,13 @@ from distutils.sysconfig import get_python_lib
import os, os.path
import sys
if 'sdist' in sys.argv:
if 'sdist' in sys.argv and sys.platform != "win32":
# Record the current revision in .hgrev
import subprocess # os.popen is cleaner but depricated
changset = subprocess.Popen("hg log --rev tip | grep changeset",
shell=True,
stdout=subprocess.PIPE).stdout.read()
rev = changset.split(':')[-1].strip()
rev = changset.decode('ISO-8859-1').split(':')[-1].strip()
hgrev = open('.hgrev', 'w')
hgrev.write(rev)
hgrev.close()
......
......@@ -9,5 +9,7 @@ missing_baseclass_in_predecl_T262
cfunc_call_tuple_args_T408
cascaded_list_unpacking_T467
compile.cpp_operators
cppwrap
cpp_overload_wrapper
# Pyrex regression tests that don't current work:
pyregr.test_threadsignals
pyregr.test_module
cdef int raiseit():
raise IndexError
if False: raiseit()
_ERRORS = u"""
FIXME: provide a good error message here.
......
......@@ -7,6 +7,6 @@ cdef spamfunc spam
grail = spam # type mismatch
spam = grail # type mismatch
_ERRORS = u"""
7:28: Cannot assign type 'e_excvalfunctype.spamfunc' to 'e_excvalfunctype.grailfunc'
8:28: Cannot assign type 'e_excvalfunctype.grailfunc' to 'e_excvalfunctype.spamfunc'
7:28: Cannot assign type 'spamfunc' to 'grailfunc'
8:28: Cannot assign type 'grailfunc' to 'spamfunc'
"""
# invalid syntax (as handled by the parser)
def syntax():
*a, *b = 1,2,3,4,5
# wrong size RHS (as handled by the parser)
def length1():
......@@ -27,12 +22,11 @@ def length_recursive():
_ERRORS = u"""
5:4: more than 1 starred expression in assignment
10:4: too many values to unpack (expected 2, got 3)
13:4: need more than 1 value to unpack
16:4: need more than 0 values to unpack
19:4: need more than 0 values to unpack
22:4: need more than 0 values to unpack
23:4: need more than 1 value to unpack
26:6: need more than 1 value to unpack
5:4: too many values to unpack (expected 2, got 3)
8:4: need more than 1 value to unpack
11:4: need more than 0 values to unpack
14:4: need more than 0 values to unpack
17:4: need more than 0 values to unpack
18:4: need more than 1 value to unpack
21:6: need more than 1 value to unpack
"""
# invalid syntax (as handled by the parser)
def syntax():
*a, *b = 1,2,3,4,5
_ERRORS = u"""
5:4: more than 1 starred expression in assignment
5:8: more than 1 starred expression in assignment
"""
......@@ -109,3 +109,5 @@ cdef class MyCdefClass:
>>> True
False
"""
cdeffunc()
......@@ -53,6 +53,8 @@ def printbuf():
"""
cdef object[int, ndim=2] buf
print buf
return
buf[0,0] = 0
@testcase
def acquire_release(o1, o2):
......@@ -798,7 +800,7 @@ def printbuf_td_cy_int(object[td_cy_int] buf, shape):
>>> printbuf_td_cy_int(ShortMockBuffer(None, range(3)), (3,))
Traceback (most recent call last):
...
ValueError: Buffer dtype mismatch, expected 'bufaccess.td_cy_int' but got 'short'
ValueError: Buffer dtype mismatch, expected 'td_cy_int' but got 'short'
"""
cdef int i
for i in range(shape[0]):
......@@ -813,7 +815,7 @@ def printbuf_td_h_short(object[td_h_short] buf, shape):
>>> printbuf_td_h_short(IntMockBuffer(None, range(3)), (3,))
Traceback (most recent call last):
...
ValueError: Buffer dtype mismatch, expected 'bufaccess.td_h_short' but got 'int'
ValueError: Buffer dtype mismatch, expected 'td_h_short' but got 'int'
"""
cdef int i
for i in range(shape[0]):
......@@ -828,7 +830,7 @@ def printbuf_td_h_cy_short(object[td_h_cy_short] buf, shape):
>>> printbuf_td_h_cy_short(IntMockBuffer(None, range(3)), (3,))
Traceback (most recent call last):
...
ValueError: Buffer dtype mismatch, expected 'bufaccess.td_h_cy_short' but got 'int'
ValueError: Buffer dtype mismatch, expected 'td_h_cy_short' but got 'int'
"""
cdef int i
for i in range(shape[0]):
......@@ -843,7 +845,7 @@ def printbuf_td_h_ushort(object[td_h_ushort] buf, shape):
>>> printbuf_td_h_ushort(ShortMockBuffer(None, range(3)), (3,))
Traceback (most recent call last):
...
ValueError: Buffer dtype mismatch, expected 'bufaccess.td_h_ushort' but got 'short'
ValueError: Buffer dtype mismatch, expected 'td_h_ushort' but got 'short'
"""
cdef int i
for i in range(shape[0]):
......@@ -858,7 +860,7 @@ def printbuf_td_h_double(object[td_h_double] buf, shape):
>>> printbuf_td_h_double(FloatMockBuffer(None, [0.25, 1, 3.125]), (3,))
Traceback (most recent call last):
...
ValueError: Buffer dtype mismatch, expected 'bufaccess.td_h_double' but got 'float'
ValueError: Buffer dtype mismatch, expected 'td_h_double' but got 'float'
"""
cdef int i
for i in range(shape[0]):
......
......@@ -3,6 +3,10 @@ __doc__ = """
7
>>> lentest_char_c()
7
>>> lentest_char_c_short()
7
>>> lentest_char_c_float()
7.0
>>> lentest_uchar()
7
......@@ -36,6 +40,20 @@ def lentest_char_c():
cdef Py_ssize_t l = len(s)
return l
@cython.test_assert_path_exists(
"//PythonCapiCallNode",
)
def lentest_char_c_short():
cdef short l = len(s)
return l
@cython.test_assert_path_exists(
"//PythonCapiCallNode",
)
def lentest_char_c_float():
cdef float l = len(s)
return l
@cython.test_assert_path_exists(
"//PythonCapiCallNode",
......
......@@ -77,9 +77,16 @@ def test_attr_int(TestExtInt e):
else:
return False
ctypedef union _aux:
int i
void *p
cdef class TestExtPtr:
cdef void* p
def __init__(self, int i): self.p = <void*>i
def __init__(self, int i):
cdef _aux aux
aux.i = i
self.p = aux.p
def test_attr_ptr(TestExtPtr e):
"""
......
def get(dict d, key):
"""
>>> d = { 1: 10 }
>>> d.get(1)
10
>>> get(d, 1)
10
>>> d.get(2) is None
True
>>> get(d, 2) is None
True
>>> d.get((1,2)) is None
True
>>> get(d, (1,2)) is None
True
>>> class Unhashable:
... def __hash__(self):
... raise ValueError
>>> d.get(Unhashable())
Traceback (most recent call last):
ValueError
>>> get(d, Unhashable())
Traceback (most recent call last):
ValueError
>>> None.get(1)
Traceback (most recent call last):
...
AttributeError: 'NoneType' object has no attribute 'get'
>>> get(None, 1)
Traceback (most recent call last):
...
AttributeError: 'NoneType' object has no attribute 'get'
"""
return d.get(key)
def get_default(dict d, key, default):
"""
>>> d = { 1: 10 }
>>> d.get(1, 2)
10
>>> get_default(d, 1, 2)
10
>>> d.get(2, 2)
2
>>> get_default(d, 2, 2)
2
>>> d.get((1,2), 2)
2
>>> get_default(d, (1,2), 2)
2
>>> class Unhashable:
... def __hash__(self):
... raise ValueError
>>> d.get(Unhashable(), 2)
Traceback (most recent call last):
ValueError
>>> get_default(d, Unhashable(), 2)
Traceback (most recent call last):
ValueError
"""
return d.get(key, default)
def test(dict d, index):
"""
>>> d = { 1: 10 }
>>> test(d, 1)
10
>>> test(d, 2)
Traceback (most recent call last):
...
KeyError: 2
>>> test(d, (1,2))
Traceback (most recent call last):
...
KeyError: (1, 2)
>>> class Unhashable:
... def __hash__(self):
... raise ValueError
>>> test(d, Unhashable())
Traceback (most recent call last):
...
ValueError
>>> test(None, 1)
Traceback (most recent call last):
...
TypeError: 'NoneType' object is unsubscriptable
"""
return d[index]
cdef class Subscriptable:
def __getitem__(self, key):
return key
\ No newline at end of file
......@@ -135,6 +135,12 @@ __doc__ = ur"""
>>> print (f_D.__doc__)
f_D(long double D) -> long double
>>> print (f_my_i.__doc__)
f_my_i(MyInt i) -> MyInt
>>> print (f_my_f.__doc__)
f_my_f(MyFloat f) -> MyFloat
"""
cdef class Ext:
......@@ -279,3 +285,11 @@ cpdef double f_d(double d):
cpdef long double f_D(long double D):
return D
ctypedef int MyInt
cpdef MyInt f_my_i(MyInt i):
return i
ctypedef float MyFloat
cpdef MyFloat f_my_f(MyFloat f):
return f
......@@ -3,12 +3,17 @@ cdef class Spam:
property eggs:
def __get__(self):
"""
This is the docstring for Spam.eggs.__get__
"""
return 42
def tomato():
"""
>>> tomato()
42
>>> sorted(__test__.keys())
[u'Spam.eggs.__get__ (line 5)', u'tomato (line 11)']
"""
cdef Spam spam
cdef object lettuce
......
......@@ -19,7 +19,7 @@ def range_loop_indices():
Optimized integer for loops using range() should follow Python behavior,
and leave the index variable with the last value of the range.
"""
cdef int i, j, k=0, l, m
cdef int i, j, k=0, l=10, m=10
for i in range(10): pass
for j in range(2,10): pass
for k in range(0,10,get_step()): pass
......
def call_iter1(x):
"""
>>> [ i for i in iter([1,2,3]) ]
[1, 2, 3]
>>> [ i for i in call_iter1([1,2,3]) ]
[1, 2, 3]
"""
return iter(x)
class Ints(object):
def __init__(self):
self.i = 0
def __call__(self):
self.i += 1
if self.i > 10:
raise ValueError
return self.i
def call_iter2(x, sentinel):
"""
>>> [ i for i in iter(Ints(), 3) ]
[1, 2]
>>> [ i for i in call_iter2(Ints(), 3) ]
[1, 2]
"""
return iter(x, sentinel)
......@@ -13,3 +13,5 @@ def myfunc():
for i from 0 <= i < A.shape[0]:
A[i, :] /= 2
return A[0,0]
include "numpy_common.pxi"
cdef extern from *:
void import_array()
void import_umath()
if 0:
import_array()
import_umath()
......@@ -428,3 +428,5 @@ def test_point_record():
test[i].x = i
test[i].y = -i
print repr(test).replace('<', '!').replace('>', '!')
include "numpy_common.pxi"
......@@ -107,6 +107,60 @@ def swap_attr_values(A a, A b):
a.x, a.y, b.x, b.y = b.y, b.x, a.y, a.x # reverse
cdef class B:
cdef readonly A a1
cdef readonly A a2
def __init__(self, x1, y1, x2, y2):
self.a1, self.a2 = A(x1, y1), A(x2, y2)
@cython.test_assert_path_exists(
"//ParallelAssignmentNode",
"//ParallelAssignmentNode/SingleAssignmentNode",
"//ParallelAssignmentNode/SingleAssignmentNode/CoerceToTempNode",
"//ParallelAssignmentNode/SingleAssignmentNode/CoerceToTempNode[@use_managed_ref=False]",
"//ParallelAssignmentNode/SingleAssignmentNode//AttributeNode/NameNode",
"//ParallelAssignmentNode/SingleAssignmentNode//AttributeNode[@use_managed_ref=False]/NameNode",
)
@cython.test_fail_if_path_exists(
"//ParallelAssignmentNode/SingleAssignmentNode/CoerceToTempNode[@use_managed_ref=True]",
"//ParallelAssignmentNode/SingleAssignmentNode/AttributeNode[@use_managed_ref=True]",
)
def swap_recursive_attr_values(B a, B b):
"""
>>> a, b = B(1,2,3,4), B(5,6,7,8)
>>> a.a1.x, a.a1.y, a.a2.x, a.a2.y
(1, 2, 3, 4)
>>> b.a1.x, b.a1.y, b.a2.x, b.a2.y
(5, 6, 7, 8)
>>> swap_recursive_attr_values(a,b)
>>> a.a1.x, a.a1.y, a.a2.x, a.a2.y
(2, 1, 4, 4)
>>> b.a1.x, b.a1.y, b.a2.x, b.a2.y
(6, 5, 8, 8)
# compatibility test
>>> class A:
... def __init__(self, x, y):
... self.x, self.y = x, y
>>> class B:
... def __init__(self, x1, y1, x2, y2):
... self.a1, self.a2 = A(x1, y1), A(x2, y2)
>>> a, b = B(1,2,3,4), B(5,6,7,8)
>>> a.a1, a.a2 = a.a2, a.a1
>>> b.a1, b.a2 = b.a2, b.a1
>>> a.a1, a.a1.x, a.a2.y, a.a2, a.a1.y, a.a2.x = a.a2, a.a2.y, a.a1.x, a.a1, a.a2.x, a.a1.y
>>> b.a1, b.a1.x, b.a2.y, b.a2, b.a1.y, b.a2.x = b.a2, b.a2.y, b.a1.x, b.a1, b.a2.x, b.a1.y
>>> a.a1.x, a.a1.y, a.a2.x, a.a2.y
(2, 1, 4, 4)
>>> b.a1.x, b.a1.y, b.a2.x, b.a2.y
(6, 5, 8, 8)
"""
a.a1, a.a2 = a.a2, a.a1
b.a1, b.a2 = b.a2, b.a1
a.a1, a.a1.x, a.a2.y, a.a2, a.a1.y, a.a2.x = a.a2, a.a2.y, a.a1.x, a.a1, a.a2.x, a.a1.y
b.a1, b.a1.x, b.a2.y, b.a2, b.a1.y, b.a2.x = b.a2, b.a2.y, b.a1.x, b.a1, b.a2.x, b.a1.y
@cython.test_assert_path_exists(
# "//ParallelAssignmentNode",
# "//ParallelAssignmentNode/SingleAssignmentNode",
......
def f(a, b):
def print_to_stdout(a, b):
"""
>>> f(1, 'test')
>>> print_to_stdout(1, 'test')
<BLANKLINE>
1
1 test
......@@ -14,3 +14,29 @@ def f(a, b):
print a, b
print a, b,
print 42, u"spam"
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
def print_to_stringio(stream, a, b):
"""
>>> stream = StringIO()
>>> print_to_stringio(stream, 1, 'test')
>>> print(stream.getvalue())
<BLANKLINE>
1
1 test
1 test
1 test 42 spam
<BLANKLINE>
"""
print >> stream
print >> stream, a
print >> stream, a,
print >> stream, b
print >> stream, a, b
print >> stream, a, b,
print >> stream, 42, u"spam"
# Py2.6 and later only!
from __future__ import print_function
def print_to_stdout(a, b):
"""
>>> print_to_stdout(1, 'test')
<BLANKLINE>
1
1 test
1 test
1 test 42 spam
"""
print()
print(a)
print(a, end=' ')
print(b)
print(a, b)
print(a, b, end=' ')
print(42, u"spam")
def print_assign(a, b):
"""
>>> print_assign(1, 'test')
<BLANKLINE>
1
1 test
1 test
1 test 42 spam
"""
x = print
x()
x(a)
x(a, end=' ')
x(b)
x(a, b)
x(a, b, end=' ')
x(42, u"spam")
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
def print_to_stringio(stream, a, b):
"""
>>> stream = StringIO()
>>> print_to_stringio(stream, 1, 'test')
>>> print(stream.getvalue())
<BLANKLINE>
1
1 test
1 test
1 test 42 spam
<BLANKLINE>
"""
print(file=stream)
print(a, file=stream)
print(a, end=' ', file=stream)
print(b, file=stream)
print(a, b, file=stream)
print(a, b, end=' ', file=stream)
print(42, u"spam", file=stream)
cdef extern from *:
ctypedef class __builtin__.list [ object PyListObject ]:
pass
def slice_of_typed_value():
"""
>>> slice_of_typed_value()
[1, 2, 3]
"""
cdef object a = []
cdef list L = [1, 2, 3]
a[:] = L
return a
......@@ -261,7 +261,7 @@ def safe_only():
res = ~d
assert typeof(d) == "long", typeof(d)
# potentially overflowing arithmatic
# potentially overflowing arithmetic
e = 1
e += 1
assert typeof(e) == "Python object", typeof(e)
......
cimport cppwrap_lib
cimport cpp_overload_wrapper_lib as cppwrap_lib
cdef class DoubleKeeper:
"""
>>> d = DoubleKeeper()
>>> d.get_number()
1.0
>>> d.set_number(5.5)
>>> d.get_number()
5.5
>>> d.set_number(0)
>>> d.get_number()
0.0
"""
cdef cppwrap_lib.DoubleKeeper* keeper
def __cinit__(self, number=None):
......@@ -23,14 +34,33 @@ cdef class DoubleKeeper:
return self.keeper.get_number()
def transmogrify(self, double value):
"""
>>> d = DoubleKeeper(5.5)
>>> d.transmogrify(1.0)
5.5
>>> d.transmogrify(2.0)
11.0
"""
return self.keeper.transmogrify(value)
def voidfunc():
"""
>>> voidfunc()
"""
cppwrap_lib.voidfunc()
def doublefunc(double x, double y, double z):
"""
>>> doublefunc(1.0, 2.0, 3.0) == 1.0 + 2.0 + 3.0
True
"""
return cppwrap_lib.doublefunc(x, y, z)
def transmogrify_from_cpp(DoubleKeeper obj not None, double value):
"""
>>> d = DoubleKeeper(2.0)
>>> d.transmogrify(3.0) == 6.0
True
"""
return cppwrap_lib.transmogrify_from_cpp(obj.keeper, value)
#include "cppwrap_lib.h"
#include "cpp_overload_wrapper_lib.h"
void voidfunc (void)
{
......
cdef extern from "testapi.h":
cdef extern from "cpp_overload_wrapper_lib.h":
void voidfunc()
double doublefunc(double a, double b, double c)
......
......@@ -2,6 +2,18 @@
cimport cppwrap_lib
cdef class DoubleKeeper:
"""
>>> d = DoubleKeeper(1.0)
>>> d.get_number() == 1.0
True
>>> d.get_number() == 2.0
False
>>> d.set_number(2.0)
>>> d.get_number() == 2.0
True
>>> d.transmogrify(3.0) == 6.0
True
"""
cdef cppwrap_lib.DoubleKeeper* keeper
def __cinit__(self, double number):
......@@ -21,10 +33,22 @@ cdef class DoubleKeeper:
def voidfunc():
"""
>>> voidfunc()
"""
cppwrap_lib.voidfunc()
def doublefunc(double x, double y, double z):
"""
>>> doublefunc(1.0, 2.0, 3.0) == 1.0 + 2.0 + 3.0
True
"""
return cppwrap_lib.doublefunc(x, y, z)
def transmogrify_from_cpp(DoubleKeeper obj not None, double value):
"""
>>> d = DoubleKeeper(2.0)
>>> d.transmogrify(3.0) == 6.0
True
"""
return cppwrap_lib.transmogrify_from_cpp(obj.keeper, value)
......@@ -10,12 +10,6 @@ double doublefunc (double a, double b, double c)
return a + b + c;
}
DoubleKeeper::DoubleKeeper ()
: number (1.0)
{
}
DoubleKeeper::DoubleKeeper (double factor)
: number (factor)
{
......@@ -35,11 +29,6 @@ void DoubleKeeper::set_number (double f)
number = f;
}
void DoubleKeeper::set_number ()
{
number = 1.0;
}
double
DoubleKeeper::transmogrify (double value) const
{
......
cdef extern from "testapi.h":
cdef extern from "cppwrap_lib.h":
void voidfunc()
double doublefunc(double a, double b, double c)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment