Commit b978c7e0 authored by Robert Bradshaw's avatar Robert Bradshaw

merge

parents 0a0e72a7 fa55b747
...@@ -20,7 +20,7 @@ class EmbedSignature(CythonTransform): ...@@ -20,7 +20,7 @@ class EmbedSignature(CythonTransform):
try: try:
denv = self.denv # XXX denv = self.denv # XXX
ctval = default_val.compile_time_value(self.denv) ctval = default_val.compile_time_value(self.denv)
repr_val = '%r' % ctval repr_val = repr(ctval)
if isinstance(default_val, ExprNodes.UnicodeNode): if isinstance(default_val, ExprNodes.UnicodeNode):
if repr_val[:1] != 'u': if repr_val[:1] != 'u':
return u'u%s' % repr_val return u'u%s' % repr_val
...@@ -28,8 +28,8 @@ class EmbedSignature(CythonTransform): ...@@ -28,8 +28,8 @@ class EmbedSignature(CythonTransform):
if repr_val[:1] != 'b': if repr_val[:1] != 'b':
return u'b%s' % repr_val return u'b%s' % repr_val
elif isinstance(default_val, ExprNodes.StringNode): elif isinstance(default_val, ExprNodes.StringNode):
if repr_val[:1] in ('u', 'b'): if repr_val[:1] in 'ub':
repr_val[1:] return repr_val[1:]
return repr_val return repr_val
except Exception: except Exception:
try: try:
......
...@@ -637,6 +637,8 @@ class GlobalState(object): ...@@ -637,6 +637,8 @@ class GlobalState(object):
def put_cached_builtin_init(self, pos, name, cname): def put_cached_builtin_init(self, pos, name, cname):
w = self.parts['cached_builtins'] w = self.parts['cached_builtins']
interned_cname = self.get_interned_identifier(name).cname interned_cname = self.get_interned_identifier(name).cname
from ExprNodes import get_name_interned_utility_code
self.use_utility_code(get_name_interned_utility_code)
w.putln('%s = __Pyx_GetName(%s, %s); if (!%s) %s' % ( w.putln('%s = __Pyx_GetName(%s, %s); if (!%s) %s' % (
cname, cname,
Naming.builtins_cname, Naming.builtins_cname,
...@@ -667,7 +669,7 @@ class GlobalState(object): ...@@ -667,7 +669,7 @@ class GlobalState(object):
decls_writer = self.parts['decls'] decls_writer = self.parts['decls']
for _, cname, c in c_consts: for _, cname, c in c_consts:
decls_writer.putln('static char %s[] = "%s";' % ( decls_writer.putln('static char %s[] = "%s";' % (
cname, StringEncoding.split_docstring(c.escaped_value))) cname, StringEncoding.split_string_literal(c.escaped_value)))
if c.py_strings is not None: if c.py_strings is not None:
for py_string in c.py_strings.itervalues(): for py_string in c.py_strings.itervalues():
py_strings.append((c.cname, len(py_string.cname), py_string)) py_strings.append((c.cname, len(py_string.cname), py_string))
......
...@@ -1246,8 +1246,8 @@ class NameNode(AtomicExprNode): ...@@ -1246,8 +1246,8 @@ class NameNode(AtomicExprNode):
self.is_temp = 0 self.is_temp = 0
else: else:
self.is_temp = 1 self.is_temp = 1
env.use_utility_code(get_name_interned_utility_code)
self.is_used_as_rvalue = 1 self.is_used_as_rvalue = 1
env.use_utility_code(get_name_interned_utility_code)
def nogil_check(self, env): def nogil_check(self, env):
if self.is_used_as_rvalue: if self.is_used_as_rvalue:
...@@ -1334,6 +1334,7 @@ class NameNode(AtomicExprNode): ...@@ -1334,6 +1334,7 @@ class NameNode(AtomicExprNode):
namespace = Naming.builtins_cname namespace = Naming.builtins_cname
else: # entry.is_pyglobal else: # entry.is_pyglobal
namespace = entry.scope.namespace_cname namespace = entry.scope.namespace_cname
code.globalstate.use_utility_code(get_name_interned_utility_code)
code.putln( code.putln(
'%s = __Pyx_GetName(%s, %s); %s' % ( '%s = __Pyx_GetName(%s, %s); %s' % (
self.result(), self.result(),
...@@ -1554,7 +1555,13 @@ class IteratorNode(ExprNode): ...@@ -1554,7 +1555,13 @@ class IteratorNode(ExprNode):
def analyse_types(self, env): def analyse_types(self, env):
self.sequence.analyse_types(env) self.sequence.analyse_types(env)
self.sequence = self.sequence.coerce_to_pyobject(env) if isinstance(self.sequence, SliceIndexNode) and \
(self.sequence.base.type.is_array or self.sequence.base.type.is_ptr) \
or self.sequence.type.is_array and self.sequence.type.size is not None:
# C array iteration will be transformed later on
pass
else:
self.sequence = self.sequence.coerce_to_pyobject(env)
self.is_temp = 1 self.is_temp = 1
gil_message = "Iterating over Python object" gil_message = "Iterating over Python object"
...@@ -2444,6 +2451,15 @@ class CallNode(ExprNode): ...@@ -2444,6 +2451,15 @@ class CallNode(ExprNode):
self.analyse_types(env) self.analyse_types(env)
self.coerce_to(type, env) self.coerce_to(type, env)
return True return True
elif type and type.is_cpp_class:
for arg in self.args:
arg.analyse_types(env)
constructor = type.scope.lookup("<init>")
self.function = RawCNameExprNode(self.function.pos, constructor.type)
self.function.entry = constructor
self.function.set_cname(type.declaration_code(""))
self.analyse_c_function_call(env)
return True
def nogil_check(self, env): def nogil_check(self, env):
func_type = self.function_type() func_type = self.function_type()
...@@ -2490,6 +2506,8 @@ class SimpleCallNode(CallNode): ...@@ -2490,6 +2506,8 @@ class SimpleCallNode(CallNode):
def infer_type(self, env): def infer_type(self, env):
function = self.function function = self.function
func_type = function.infer_type(env) func_type = function.infer_type(env)
if isinstance(self.function, NewExprNode):
return PyrexTypes.CPtrType(self.function.class_type)
if func_type.is_ptr: if func_type.is_ptr:
func_type = func_type.base_type func_type = func_type.base_type
if func_type.is_cfunction: if func_type.is_cfunction:
...@@ -4708,7 +4726,9 @@ class BinopNode(ExprNode): ...@@ -4708,7 +4726,9 @@ class BinopNode(ExprNode):
self.operand2.analyse_types(env) self.operand2.analyse_types(env)
if self.is_py_operation(): if self.is_py_operation():
self.coerce_operands_to_pyobjects(env) self.coerce_operands_to_pyobjects(env)
self.type = py_object_type self.type = self.result_type(self.operand1.type,
self.operand2.type)
assert self.type.is_pyobject
self.is_temp = 1 self.is_temp = 1
elif self.is_cpp_operation(): elif self.is_cpp_operation():
self.analyse_cpp_operation(env) self.analyse_cpp_operation(env)
...@@ -4750,6 +4770,30 @@ class BinopNode(ExprNode): ...@@ -4750,6 +4770,30 @@ class BinopNode(ExprNode):
def result_type(self, type1, type2): def result_type(self, type1, type2):
if self.is_py_operation_types(type1, type2): if self.is_py_operation_types(type1, type2):
if type2.is_string:
type2 = Builtin.bytes_type
if type1.is_string:
type1 = Builtin.bytes_type
elif self.operator == '%' \
and type1 in (Builtin.str_type, Builtin.unicode_type):
# note that b'%s' % b'abc' doesn't work in Py3
return type1
if type1.is_builtin_type:
if type1 is type2:
if self.operator in '**%+|&^':
# FIXME: at least these operators should be safe - others?
return type1
elif self.operator == '*':
if type1 in (Builtin.bytes_type, Builtin.str_type, Builtin.unicode_type):
return type1
# multiplication of containers/numbers with an
# integer value always (?) returns the same type
if type2.is_int:
return type1
elif type2.is_builtin_type and type1.is_int and self.operator == '*':
# multiplication of containers/numbers with an
# integer value always (?) returns the same type
return type2
return py_object_type return py_object_type
else: else:
return self.compute_c_result_type(type1, type2) return self.compute_c_result_type(type1, type2)
......
...@@ -88,7 +88,7 @@ class Context(object): ...@@ -88,7 +88,7 @@ class Context(object):
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
from TypeInference import MarkAssignments from TypeInference import MarkAssignments, MarkOverflowingArithmatic
from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck
from AnalysedTreeTransforms import AutoTestDictTransform from AnalysedTreeTransforms import AutoTestDictTransform
from AutoDocTransforms import EmbedSignature from AutoDocTransforms import EmbedSignature
...@@ -135,6 +135,7 @@ class Context(object): ...@@ -135,6 +135,7 @@ class Context(object):
EmbedSignature(self), EmbedSignature(self),
EarlyReplaceBuiltinCalls(self), EarlyReplaceBuiltinCalls(self),
MarkAssignments(self), MarkAssignments(self),
MarkOverflowingArithmatic(self),
TransformBuiltinMethods(self), TransformBuiltinMethods(self),
IntroduceBufferAuxiliaryVars(self), IntroduceBufferAuxiliaryVars(self),
_check_c_declarations, _check_c_declarations,
...@@ -218,8 +219,11 @@ class Context(object): ...@@ -218,8 +219,11 @@ class Context(object):
for phase in pipeline: for phase in pipeline:
if phase is not None: if phase is not None:
if DebugFlags.debug_verbose_pipeline: if DebugFlags.debug_verbose_pipeline:
t = time()
print "Entering pipeline phase %r" % phase print "Entering pipeline phase %r" % phase
data = phase(data) data = phase(data)
if DebugFlags.debug_verbose_pipeline:
print " %.3f seconds" % (time() - t)
except CompileError, err: except CompileError, err:
# err is set # err is set
Errors.report_error(err) Errors.report_error(err)
......
...@@ -22,7 +22,7 @@ from Symtab import ModuleScope, LocalScope, GeneratorLocalScope, \ ...@@ -22,7 +22,7 @@ from Symtab import ModuleScope, LocalScope, GeneratorLocalScope, \
StructOrUnionScope, PyClassScope, CClassScope, CppClassScope StructOrUnionScope, PyClassScope, CClassScope, CppClassScope
from Cython.Utils import open_new_file, replace_suffix from Cython.Utils import open_new_file, replace_suffix
from Code import UtilityCode from Code import UtilityCode
from StringEncoding import EncodedString, escape_byte_string, split_docstring from StringEncoding import EncodedString, escape_byte_string, split_string_literal
import Options import Options
import ControlFlow import ControlFlow
import DebugFlags import DebugFlags
...@@ -661,6 +661,17 @@ class CArgDeclNode(Node): ...@@ -661,6 +661,17 @@ class CArgDeclNode(Node):
base_type = self.base_type.analyse(env, could_be_name = could_be_name) base_type = self.base_type.analyse(env, could_be_name = could_be_name)
if hasattr(self.base_type, 'arg_name') and self.base_type.arg_name: if hasattr(self.base_type, 'arg_name') and self.base_type.arg_name:
self.declarator.name = self.base_type.arg_name self.declarator.name = self.base_type.arg_name
# The parser is unable to resolve the ambiguity of [] as part of the
# type (e.g. in buffers) or empty declarator (as with arrays).
# This is only arises for empty multi-dimensional arrays.
if (base_type.is_array
and isinstance(self.base_type, TemplatedTypeNode)
and isinstance(self.declarator, CArrayDeclaratorNode)):
declarator = self.declarator
while isinstance(declarator.base, CArrayDeclaratorNode):
declarator = declarator.base
declarator.base = self.base_type.array_declarator
base_type = base_type.base_type
return self.declarator.analyse(base_type, env, nonempty = nonempty) return self.declarator.analyse(base_type, env, nonempty = nonempty)
else: else:
return self.name_declarator, self.type return self.name_declarator, self.type
...@@ -821,7 +832,7 @@ class TemplatedTypeNode(CBaseTypeNode): ...@@ -821,7 +832,7 @@ class TemplatedTypeNode(CBaseTypeNode):
else: else:
# Array # Array
empty_declarator = CNameDeclaratorNode(self.pos, name="") empty_declarator = CNameDeclaratorNode(self.pos, name="", cname=None)
if len(self.positional_args) > 1 or self.keyword_args.key_value_pairs: if len(self.positional_args) > 1 or self.keyword_args.key_value_pairs:
error(self.pos, "invalid array declaration") error(self.pos, "invalid array declaration")
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
...@@ -832,9 +843,10 @@ class TemplatedTypeNode(CBaseTypeNode): ...@@ -832,9 +843,10 @@ class TemplatedTypeNode(CBaseTypeNode):
dimension = None dimension = None
else: else:
dimension = self.positional_args[0] dimension = self.positional_args[0]
self.type = CArrayDeclaratorNode(self.pos, self.array_declarator = CArrayDeclaratorNode(self.pos,
base = empty_declarator, base = empty_declarator,
dimension = dimension).analyse(base_type, env)[1] dimension = dimension)
self.type = self.array_declarator.analyse(base_type, env)[1]
return self.type return self.type
...@@ -2052,7 +2064,7 @@ class DefNode(FuncDefNode): ...@@ -2052,7 +2064,7 @@ class DefNode(FuncDefNode):
code.putln( code.putln(
'static char %s[] = "%s";' % ( 'static char %s[] = "%s";' % (
self.entry.doc_cname, self.entry.doc_cname,
split_docstring(escape_byte_string(docstr)))) split_string_literal(escape_byte_string(docstr))))
if with_pymethdef: if with_pymethdef:
code.put( code.put(
"static PyMethodDef %s = " % "static PyMethodDef %s = " %
...@@ -4166,7 +4178,7 @@ class ForFromStatNode(LoopNode, StatNode): ...@@ -4166,7 +4178,7 @@ class ForFromStatNode(LoopNode, StatNode):
target_node = ExprNodes.PyTempNode(self.target.pos, None) target_node = ExprNodes.PyTempNode(self.target.pos, None)
target_node.allocate(code) target_node.allocate(code)
interned_cname = code.intern_identifier(self.target.entry.name) interned_cname = code.intern_identifier(self.target.entry.name)
code.putln("/*here*/") code.globalstate.use_utility_code(ExprNodes.get_name_interned_utility_code)
code.putln("%s = __Pyx_GetName(%s, %s); %s" % ( code.putln("%s = __Pyx_GetName(%s, %s); %s" % (
target_node.result(), target_node.result(),
Naming.module_cname, Naming.module_cname,
......
...@@ -89,10 +89,12 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -89,10 +89,12 @@ class IterationTransform(Visitor.VisitorTransform):
return self._transform_dict_iteration( return self._transform_dict_iteration(
node, dict_obj=iterator, keys=True, values=False) node, dict_obj=iterator, keys=True, values=False)
# C array slice iteration? # C array (slice) iteration?
if isinstance(iterator, ExprNodes.SliceIndexNode) and \ if isinstance(iterator, ExprNodes.SliceIndexNode) and \
(iterator.base.type.is_array or iterator.base.type.is_ptr): (iterator.base.type.is_array or iterator.base.type.is_ptr):
return self._transform_carray_iteration(node, iterator) return self._transform_carray_iteration(node, iterator)
elif iterator.type.is_array:
return self._transform_carray_iteration(node, iterator)
elif not isinstance(iterator, ExprNodes.SimpleCallNode): elif not isinstance(iterator, ExprNodes.SimpleCallNode):
return node return node
...@@ -131,13 +133,26 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -131,13 +133,26 @@ class IterationTransform(Visitor.VisitorTransform):
return node return node
def _transform_carray_iteration(self, node, slice_node): def _transform_carray_iteration(self, node, slice_node):
start = slice_node.start if isinstance(slice_node, ExprNodes.SliceIndexNode):
stop = slice_node.stop slice_base = slice_node.base
step = None start = slice_node.start
if not stop: stop = slice_node.stop
step = None
if not stop:
return node
elif slice_node.type.is_array and slice_node.type.size is not None:
slice_base = slice_node
start = None
stop = ExprNodes.IntNode(
slice_node.pos, value=str(slice_node.type.size))
step = None
else:
return node return node
carray_ptr = slice_node.base.coerce_to_simple(self.current_scope) ptr_type = slice_base.type
if ptr_type.is_array:
ptr_type = ptr_type.element_ptr_type()
carray_ptr = slice_base.coerce_to_simple(self.current_scope)
if start and start.constant_result != 0: if start and start.constant_result != 0:
start_ptr_node = ExprNodes.AddNode( start_ptr_node = ExprNodes.AddNode(
...@@ -145,7 +160,7 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -145,7 +160,7 @@ class IterationTransform(Visitor.VisitorTransform):
operand1=carray_ptr, operand1=carray_ptr,
operator='+', operator='+',
operand2=start, operand2=start,
type=carray_ptr.type) type=ptr_type)
else: else:
start_ptr_node = carray_ptr start_ptr_node = carray_ptr
...@@ -154,13 +169,13 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -154,13 +169,13 @@ class IterationTransform(Visitor.VisitorTransform):
operand1=carray_ptr, operand1=carray_ptr,
operator='+', operator='+',
operand2=stop, operand2=stop,
type=carray_ptr.type type=ptr_type
).coerce_to_simple(self.current_scope) ).coerce_to_simple(self.current_scope)
counter = UtilNodes.TempHandle(carray_ptr.type) counter = UtilNodes.TempHandle(ptr_type)
counter_temp = counter.ref(node.target.pos) counter_temp = counter.ref(node.target.pos)
if slice_node.base.type.is_string and node.target.type.is_pyobject: if slice_base.type.is_string and node.target.type.is_pyobject:
# special case: char* -> bytes # special case: char* -> bytes
target_value = ExprNodes.SliceIndexNode( target_value = ExprNodes.SliceIndexNode(
node.target.pos, node.target.pos,
...@@ -181,7 +196,7 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -181,7 +196,7 @@ class IterationTransform(Visitor.VisitorTransform):
type=PyrexTypes.c_int_type), type=PyrexTypes.c_int_type),
base=counter_temp, base=counter_temp,
is_buffer_access=False, is_buffer_access=False,
type=carray_ptr.type.base_type) type=ptr_type.base_type)
if target_value.type != node.target.type: if target_value.type != node.target.type:
target_value = target_value.coerce_to(node.target.type, target_value = target_value.coerce_to(node.target.type,
...@@ -1606,20 +1621,20 @@ impl = "" ...@@ -1606,20 +1621,20 @@ impl = ""
pop_utility_code = UtilityCode( pop_utility_code = UtilityCode(
proto = """ proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) { static CYTHON_INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
#if PY_VERSION_HEX >= 0x02040000
if (likely(PyList_CheckExact(L)) if (likely(PyList_CheckExact(L))
/* Check that both the size is positive and no reallocation shrinking needs to be done. */ /* Check that both the size is positive and no reallocation shrinking needs to be done. */
&& likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) { && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
Py_SIZE(L) -= 1; Py_SIZE(L) -= 1;
return PyList_GET_ITEM(L, PyList_GET_SIZE(L)); return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
} }
else { #endif
PyObject *r, *m; PyObject *r, *m;
m = __Pyx_GetAttrString(L, "pop"); m = __Pyx_GetAttrString(L, "pop");
if (!m) return NULL; if (!m) return NULL;
r = PyObject_CallObject(m, NULL); r = PyObject_CallObject(m, NULL);
Py_DECREF(m); Py_DECREF(m);
return r; return r;
}
} }
""", """,
impl = "" impl = ""
...@@ -1632,6 +1647,7 @@ static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix); ...@@ -1632,6 +1647,7 @@ static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
impl = """ impl = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) { static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
PyObject *r, *m, *t, *py_ix; PyObject *r, *m, *t, *py_ix;
#if PY_VERSION_HEX >= 0x02040000
if (likely(PyList_CheckExact(L))) { if (likely(PyList_CheckExact(L))) {
Py_ssize_t size = PyList_GET_SIZE(L); Py_ssize_t size = PyList_GET_SIZE(L);
if (likely(size > (((PyListObject*)L)->allocated >> 1))) { if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
...@@ -1650,6 +1666,7 @@ static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) { ...@@ -1650,6 +1666,7 @@ static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
} }
} }
} }
#endif
py_ix = t = NULL; py_ix = t = NULL;
m = __Pyx_GetAttrString(L, "pop"); m = __Pyx_GetAttrString(L, "pop");
if (!m) goto bad; if (!m) goto bad;
......
...@@ -10,6 +10,15 @@ cython.declare(Nodes=object, ExprNodes=object, EncodedString=object) ...@@ -10,6 +10,15 @@ cython.declare(Nodes=object, ExprNodes=object, EncodedString=object)
import os import os
import re import re
import sys import sys
try:
from __builtin__ import set
except ImportError:
try:
from builtins import set
except ImportError:
from sets import Set as set
from Cython.Compiler.Scanning import PyrexScanner, FileSourceDescriptor from Cython.Compiler.Scanning import PyrexScanner, FileSourceDescriptor
import Nodes import Nodes
import ExprNodes import ExprNodes
......
...@@ -1893,6 +1893,8 @@ class CppClassType(CType): ...@@ -1893,6 +1893,8 @@ class CppClassType(CType):
def assignable_from_resolved_type(self, other_type): def assignable_from_resolved_type(self, other_type):
# TODO: handle operator=(...) here? # TODO: handle operator=(...) here?
if other_type is error_type:
return True
return other_type.is_cpp_class and other_type.is_subclass(self) return other_type.is_cpp_class and other_type.is_subclass(self)
def attributes_known(self): def attributes_known(self):
...@@ -2369,6 +2371,17 @@ def spanning_type(type1, type2): ...@@ -2369,6 +2371,17 @@ def spanning_type(type1, type2):
return widest_numeric_type(type1, c_double_type) return widest_numeric_type(type1, c_double_type)
elif type1.is_pyobject ^ type2.is_pyobject: elif type1.is_pyobject ^ type2.is_pyobject:
return py_object_type return py_object_type
elif type1.is_extension_type and type2.is_extension_type:
if type1.typeobj_is_imported() or type2.typeobj_is_imported():
return py_object_type
while True:
if type1.subtype_of(type2):
return type2
elif type2.subtype_of(type1):
return type1
type1, type2 = type1.base_type, type2.base_type
if type1 is None or type2 is None:
return py_object_type
elif type1.assignable_from(type2): elif type1.assignable_from(type2):
if type1.is_extension_type and type1.typeobj_is_imported(): if type1.is_extension_type and type1.typeobj_is_imported():
# external types are unsafe, so we use PyObject instead # external types are unsafe, so we use PyObject instead
......
...@@ -97,7 +97,10 @@ def initial_compile_time_env(): ...@@ -97,7 +97,10 @@ def initial_compile_time_env():
'UNAME_VERSION', 'UNAME_MACHINE') 'UNAME_VERSION', 'UNAME_MACHINE')
for name, value in zip(names, platform.uname()): for name, value in zip(names, platform.uname()):
benv.declare(name, value) benv.declare(name, value)
import __builtin__ as builtins try:
import __builtin__ as builtins
except ImportError:
import builtins
names = ('False', 'True', names = ('False', 'True',
'abs', 'bool', 'chr', 'cmp', 'complex', 'dict', 'divmod', 'enumerate', 'abs', 'bool', 'chr', 'cmp', 'complex', 'dict', 'divmod', 'enumerate',
'float', 'hash', 'hex', 'int', 'len', 'list', 'long', 'map', 'max', 'min', 'float', 'hash', 'hex', 'int', 'len', 'list', 'long', 'map', 'max', 'min',
......
...@@ -185,9 +185,9 @@ def escape_byte_string(s): ...@@ -185,9 +185,9 @@ def escape_byte_string(s):
append(c) append(c)
return join_bytes(l).decode('ISO-8859-1') return join_bytes(l).decode('ISO-8859-1')
def split_docstring(s): def split_string_literal(s):
# MSVC can't handle long string literals. # MSVC can't handle long string literals.
if len(s) < 2047: if len(s) < 2047:
return s return s
else: else:
return '""'.join([s[i:i+2000] for i in range(0, len(s), 2000)]) return '""'.join([s[i:i+2000] for i in range(0, len(s), 2000)]).replace(r'\""', '""\\')
...@@ -119,6 +119,8 @@ class Entry(object): ...@@ -119,6 +119,8 @@ class Entry(object):
# inline_func_in_pxd boolean Hacky special case for inline function in pxd file. # inline_func_in_pxd boolean Hacky special case for inline function in pxd file.
# Ideally this should not be necesarry. # Ideally this should not be necesarry.
# assignments [ExprNode] List of expressions that get assigned to this entry. # assignments [ExprNode] List of expressions that get assigned to this entry.
# might_overflow boolean In an arithmatic expression that could cause
# overflow (used for type inference).
inline_func_in_pxd = False inline_func_in_pxd = False
borrowed = 0 borrowed = 0
...@@ -167,6 +169,7 @@ class Entry(object): ...@@ -167,6 +169,7 @@ class Entry(object):
is_overridable = 0 is_overridable = 0
buffer_aux = None buffer_aux = None
prev_entry = None prev_entry = None
might_overflow = 0
def __init__(self, name, cname, type, pos = None, init = None): def __init__(self, name, cname, type, pos = None, init = None):
self.name = name self.name = name
...@@ -434,7 +437,7 @@ class Scope(object): ...@@ -434,7 +437,7 @@ class Scope(object):
if type.is_cpp_class and visibility != 'extern': if type.is_cpp_class and visibility != 'extern':
constructor = type.scope.lookup(u'<init>') constructor = type.scope.lookup(u'<init>')
if constructor is not None and PyrexTypes.best_match([], constructor.all_alternatives()) is None: if constructor is not None and PyrexTypes.best_match([], constructor.all_alternatives()) is None:
error(pos, "C++ class must have an empty constructor to be stack allocated") error(pos, "C++ class must have a default constructor to be stack allocated")
entry = self.declare(name, cname, type, pos, visibility) entry = self.declare(name, cname, type, pos, visibility)
entry.is_variable = 1 entry.is_variable = 1
self.control_flow.set_state((), (name, 'initalized'), False) self.control_flow.set_state((), (name, 'initalized'), False)
...@@ -1564,6 +1567,7 @@ class CppClassScope(Scope): ...@@ -1564,6 +1567,7 @@ class CppClassScope(Scope):
if name == self.name.split('::')[-1] and cname is None: if name == self.name.split('::')[-1] and cname is None:
self.check_base_default_constructor(pos) self.check_base_default_constructor(pos)
name = '<init>' name = '<init>'
type.return_type = self.lookup(self.name).type
prev_entry = self.lookup_here(name) prev_entry = self.lookup_here(name)
entry = self.declare_var(name, type, pos, cname, visibility) entry = self.declare_var(name, type, pos, cname, visibility)
if prev_entry: if prev_entry:
......
...@@ -112,6 +112,62 @@ class MarkAssignments(CythonTransform): ...@@ -112,6 +112,62 @@ class MarkAssignments(CythonTransform):
self.visitchildren(node) self.visitchildren(node)
return node return node
class MarkOverflowingArithmatic(CythonTransform):
# It may be possible to integrate this with the above for
# performance improvements (though likely not worth it).
might_overflow = False
def __call__(self, root):
self.env_stack = []
self.env = root.scope
return super(MarkOverflowingArithmatic, self).__call__(root)
def visit_safe_node(self, node):
self.might_overflow, saved = False, self.might_overflow
self.visitchildren(node)
self.might_overflow = saved
return node
def visit_neutral_node(self, node):
self.visitchildren(node)
return node
def visit_dangerous_node(self, node):
self.might_overflow, saved = True, self.might_overflow
self.visitchildren(node)
self.might_overflow = saved
return node
def visit_FuncDefNode(self, node):
self.env_stack.append(self.env)
self.env = node.local_scope
self.visit_safe_node(node)
self.env = self.env_stack.pop()
return node
def visit_NameNode(self, node):
if self.might_overflow:
entry = node.entry or self.env.lookup(node.name)
if entry:
entry.might_overflow = True
return node
def visit_BinopNode(self, node):
if node.operator in '&|^':
return self.visit_neutral_node(node)
else:
return self.visit_dangerous_node(node)
visit_UnopNode = visit_neutral_node
visit_UnaryMinusNode = visit_dangerous_node
visit_InPlaceAssignmentNode = visit_dangerous_node
visit_Node = visit_safe_node
class PyObjectTypeInferer: class PyObjectTypeInferer:
""" """
...@@ -175,7 +231,7 @@ class SimpleAssignmentTypeInferer: ...@@ -175,7 +231,7 @@ class SimpleAssignmentTypeInferer:
entry = ready_to_infer.pop() entry = ready_to_infer.pop()
types = [expr.infer_type(scope) for expr in entry.assignments] types = [expr.infer_type(scope) for expr in entry.assignments]
if types: if types:
entry.type = spanning_type(types) entry.type = spanning_type(types, entry.might_overflow)
else: else:
# FIXME: raise a warning? # FIXME: raise a warning?
# print "No assignments", entry.pos, entry # print "No assignments", entry.pos, entry
...@@ -188,9 +244,9 @@ class SimpleAssignmentTypeInferer: ...@@ -188,9 +244,9 @@ class SimpleAssignmentTypeInferer:
if len(deps) == 1 and deps == set([entry]): if len(deps) == 1 and deps == set([entry]):
types = [expr.infer_type(scope) for expr in entry.assignments if expr.type_dependencies(scope) == ()] types = [expr.infer_type(scope) for expr in entry.assignments if expr.type_dependencies(scope) == ()]
if types: if types:
entry.type = spanning_type(types) entry.type = spanning_type(types, entry.might_overflow)
types = [expr.infer_type(scope) for expr in entry.assignments] types = [expr.infer_type(scope) for expr in entry.assignments]
entry.type = spanning_type(types) # might be wider... entry.type = spanning_type(types, entry.might_overflow) # might be wider...
resolve_dependancy(entry) resolve_dependancy(entry)
del dependancies_by_entry[entry] del dependancies_by_entry[entry]
if ready_to_infer: if ready_to_infer:
...@@ -218,11 +274,11 @@ def find_spanning_type(type1, type2): ...@@ -218,11 +274,11 @@ def find_spanning_type(type1, type2):
return PyrexTypes.c_double_type return PyrexTypes.c_double_type
return result_type return result_type
def aggressive_spanning_type(types): def aggressive_spanning_type(types, might_overflow):
result_type = reduce(find_spanning_type, types) result_type = reduce(find_spanning_type, types)
return result_type return result_type
def safe_spanning_type(types): def safe_spanning_type(types, might_overflow):
result_type = reduce(find_spanning_type, types) result_type = reduce(find_spanning_type, types)
if result_type.is_pyobject: if result_type.is_pyobject:
# any specific Python type is always safe to infer # any specific Python type is always safe to infer
...@@ -235,6 +291,22 @@ def safe_spanning_type(types): ...@@ -235,6 +291,22 @@ def safe_spanning_type(types):
# find_spanning_type() only returns 'bint' for clean boolean # find_spanning_type() only returns 'bint' for clean boolean
# operations without other int types, so this is safe, too # operations without other int types, so this is safe, too
return result_type return result_type
elif result_type.is_ptr and not (result_type.is_int and result_type.rank == 0):
# Any pointer except (signed|unsigned|) char* can't implicitly
# become a PyObject.
return result_type
elif result_type.is_cpp_class:
# These can't implicitly become Python objects either.
return result_type
elif result_type.is_struct:
# Though we have struct -> object for some structs, this is uncommonly
# used, won't arise in pure Python, and there shouldn't be side
# effects, so I'm declaring this safe.
return result_type
# TODO: double complex should be OK as well, but we need
# to make sure everything is supported.
elif result_type.is_int and not might_overflow:
return result_type
return py_object_type return py_object_type
......
...@@ -7,5 +7,6 @@ ...@@ -7,5 +7,6 @@
# and keep the old one under the module name _build_ext, # and keep the old one under the module name _build_ext,
# so that *our* build_ext can make use of it. # so that *our* build_ext can make use of it.
from build_ext import build_ext from Cython.Distutils.build_ext import build_ext
# from extension import Extension # from extension import Extension
...@@ -15,16 +15,6 @@ from distutils.sysconfig import customize_compiler, get_python_version ...@@ -15,16 +15,6 @@ from distutils.sysconfig import customize_compiler, get_python_version
from distutils.dep_util import newer, newer_group from distutils.dep_util import newer, newer_group
from distutils import log from distutils import log
from distutils.dir_util import mkpath from distutils.dir_util import mkpath
try:
from Cython.Compiler.Main \
import CompilationOptions, \
default_options as pyrex_default_options, \
compile as cython_compile
from Cython.Compiler.Errors import PyrexError
except ImportError, e:
print "failed to import Cython: %s" % e
PyrexError = None
from distutils.command import build_ext as _build_ext from distutils.command import build_ext as _build_ext
extension_name_re = _build_ext.extension_name_re extension_name_re = _build_ext.extension_name_re
...@@ -83,18 +73,22 @@ class build_ext(_build_ext.build_ext): ...@@ -83,18 +73,22 @@ class build_ext(_build_ext.build_ext):
self.build_extension(ext) self.build_extension(ext)
def cython_sources(self, sources, extension): def cython_sources(self, sources, extension):
""" """
Walk the list of source files in 'sources', looking for Cython Walk the list of source files in 'sources', looking for Cython
source files (.pyx and .py). Run Cython on all that are source files (.pyx and .py). Run Cython on all that are
found, and return a modified 'sources' list with Cython source found, and return a modified 'sources' list with Cython source
files replaced by the generated C (or C++) files. files replaced by the generated C (or C++) files.
""" """
try:
if PyrexError == None: from Cython.Compiler.Main \
raise DistutilsPlatformError, \ import CompilationOptions, \
("Cython does not appear to be installed " default_options as pyrex_default_options, \
"on platform '%s'") % os.name compile as cython_compile
from Cython.Compiler.Errors import PyrexError
except ImportError:
e = sys.exc_info()[1]
print("failed to import Cython: %s" % e)
raise DistutilsPlatformError("Cython does not appear to be installed")
new_sources = [] new_sources = []
pyrex_sources = [] pyrex_sources = []
......
;;;; `Cython' mode. (add-to-list 'auto-mode-alist '("\\.pyx\\'" . cython-mode)) (define-derived-mode cython-mode python-mode "Cython" (font-lock-add-keywords nil `((,(concat "\\<\\(NULL" "\\|c\\(def\\|har\\|typedef\\)" "\\|e\\(num\\|xtern\\)" "\\|float" "\\|in\\(clude\\|t\\)" "\\|object\\|public\\|struct\\|type\\|union\\|void" "\\)\\>") 1 font-lock-keyword-face t)))) ;; Cython mode
\ No newline at end of file
(require 'python-mode)
(add-to-list 'auto-mode-alist '("\\.pyx\\'" . cython-mode))
(add-to-list 'auto-mode-alist '("\\.pxd\\'" . cython-mode))
(add-to-list 'auto-mode-alist '("\\.pxi\\'" . cython-mode))
(defun cython-compile ()
"Compile the file via Cython."
(interactive)
(let ((cy-buffer (current-buffer)))
(with-current-buffer
(compile compile-command)
(set (make-local-variable 'cython-buffer) cy-buffer)
(add-to-list (make-local-variable 'compilation-finish-functions)
'cython-compilation-finish)))
)
(defun cython-compilation-finish (buffer how)
"Called when Cython compilation finishes."
;; XXX could annotate source here
)
(defvar cython-mode-map
(let ((map (make-sparse-keymap)))
;; Will inherit from `python-mode-map' thanks to define-derived-mode.
(define-key map "\C-c\C-c" 'cython-compile)
map)
"Keymap used in `cython-mode'.")
(defvar cython-font-lock-keywords
`(;; new keywords in Cython language
(,(regexp-opt '("by" "cdef" "cimport" "cpdef" "ctypedef" "enum" "except?"
"extern" "gil" "include" "nogil" "property" "public"
"readonly" "struct" "union" "DEF" "IF" "ELIF" "ELSE") 'words)
1 font-lock-keyword-face)
;; C and Python types (highlight as builtins)
(,(regexp-opt '("NULL" "bint" "char" "dict" "double" "float" "int" "list"
"long" "object" "Py_ssize_t" "short" "size_t" "void") 'words)
1 font-lock-builtin-face)
;; cdef is used for more than functions, so simply highlighting the next
;; word is problematic. struct, enum and property work though.
("\\<\\(?:struct\\|enum\\)[ \t]+\\([a-zA-Z_]+[a-zA-Z0-9_]*\\)"
1 py-class-name-face)
("\\<property[ \t]+\\([a-zA-Z_]+[a-zA-Z0-9_]*\\)"
1 font-lock-function-name-face))
"Additional font lock keywords for Cython mode.")
(define-derived-mode cython-mode python-mode "Cython"
"Major mode for Cython development, derived from Python mode.
\\{cython-mode-map}"
(setcar font-lock-defaults
(append python-font-lock-keywords cython-font-lock-keywords))
(set (make-local-variable 'compile-command)
(concat "cython -a " buffer-file-name))
(add-to-list (make-local-variable 'compilation-finish-functions)
'cython-compilation-finish)
)
(provide 'cython-mode)
...@@ -48,8 +48,8 @@ EXT_DEP_INCLUDES = [ ...@@ -48,8 +48,8 @@ EXT_DEP_INCLUDES = [
] ]
VER_DEP_MODULES = { VER_DEP_MODULES = {
# such as: (2,4) : (operator.le, lambda x: x in ['run.extern_builtins_T258'
# (2,4) : (operator.le, lambda x: x in ['run.set']), ]),
(3,): (operator.ge, lambda x: x in ['run.non_future_division', (3,): (operator.ge, lambda x: x in ['run.non_future_division',
'compile.extsetslice', 'compile.extsetslice',
'compile.extdelslice']), 'compile.extdelslice']),
......
...@@ -25,6 +25,11 @@ if sys.platform == "darwin": ...@@ -25,6 +25,11 @@ if sys.platform == "darwin":
setup_args = {} setup_args = {}
def add_command_class(name, cls):
cmdclasses = setup_args.get('cmdclass', {})
cmdclasses[name] = cls
setup_args['cmdclass'] = cmdclasses
if sys.version_info[0] >= 3: if sys.version_info[0] >= 3:
import lib2to3.refactor import lib2to3.refactor
from distutils.command.build_py \ from distutils.command.build_py \
...@@ -34,7 +39,7 @@ if sys.version_info[0] >= 3: ...@@ -34,7 +39,7 @@ if sys.version_info[0] >= 3:
if fix.split('fix_')[-1] not in ('next',) if fix.split('fix_')[-1] not in ('next',)
] ]
build_py.fixer_names = fixers build_py.fixer_names = fixers
setup_args['cmdclass'] = {"build_py" : build_py} add_command_class("build_py", build_py)
if sys.version_info < (2,4): if sys.version_info < (2,4):
...@@ -72,54 +77,84 @@ else: ...@@ -72,54 +77,84 @@ else:
else: else:
scripts = ["cython.py"] scripts = ["cython.py"]
def compile_cython_modules():
source_root = os.path.abspath(os.path.dirname(__file__))
compiled_modules = ["Cython.Plex.Scanners",
"Cython.Compiler.Scanning",
"Cython.Compiler.Parsing",
"Cython.Compiler.Visitor",
"Cython.Runtime.refnanny"]
extensions = []
try:
if sys.version_info[0] >= 3: if sys.version_info[0] >= 3:
raise ValueError from Cython.Distutils import build_ext as build_ext_orig
sys.argv.remove("--no-cython-compile")
except ValueError:
try:
from distutils.command.build_ext import build_ext as build_ext_orig
class build_ext(build_ext_orig):
def build_extension(self, ext, *args, **kargs):
try:
build_ext_orig.build_extension(self, ext, *args, **kargs)
except StandardError:
print("Compilation of '%s' failed" % ext.sources[0])
from Cython.Compiler.Main import compile
from Cython import Utils
source_root = os.path.dirname(__file__)
compiled_modules = ["Cython.Plex.Scanners",
"Cython.Compiler.Scanning",
"Cython.Compiler.Parsing",
"Cython.Compiler.Visitor",
"Cython.Runtime.refnanny"]
extensions = []
for module in compiled_modules: for module in compiled_modules:
source_file = os.path.join(source_root, *module.split('.')) source_file = os.path.join(source_root, *module.split('.'))
if os.path.exists(source_file + ".py"): if os.path.exists(source_file + ".py"):
pyx_source_file = source_file + ".py" pyx_source_file = source_file + ".py"
else: else:
pyx_source_file = source_file + ".pyx" pyx_source_file = source_file + ".pyx"
c_source_file = source_file + ".c" extensions.append(
if not os.path.exists(c_source_file) or \ Extension(module, sources = [pyx_source_file])
Utils.file_newer_than(pyx_source_file, )
Utils.modification_time(c_source_file)):
print("Compiling module %s ..." % module) class build_ext(build_ext_orig):
result = compile(pyx_source_file) def build_extensions(self):
c_source_file = result.c_file # add path where 2to3 installed the transformed sources
if c_source_file: # and make sure Python (re-)imports them from there
extensions.append( already_imported = [ module for module in sys.modules
Extension(module, sources = [c_source_file]) if module == 'Cython' or module.startswith('Cython.') ]
) for module in already_imported:
else: del sys.modules[module]
print("Compilation failed") sys.path.insert(0, os.path.join(source_root, self.build_lib))
if extensions:
setup_args['ext_modules'] = extensions build_ext_orig.build_extensions(self)
setup_args['cmdclass'] = {"build_ext" : build_ext}
except Exception: setup_args['ext_modules'] = extensions
print("ERROR: %s" % sys.exc_info()[1]) add_command_class("build_ext", build_ext)
print("Extension module compilation failed, using plain Python implementation")
else: # Python 2.x
from distutils.command.build_ext import build_ext as build_ext_orig
try:
class build_ext(build_ext_orig):
def build_extension(self, ext, *args, **kargs):
try:
build_ext_orig.build_extension(self, ext, *args, **kargs)
except StandardError:
print("Compilation of '%s' failed" % ext.sources[0])
from Cython.Compiler.Main import compile
from Cython import Utils
source_root = os.path.dirname(__file__)
for module in compiled_modules:
source_file = os.path.join(source_root, *module.split('.'))
if os.path.exists(source_file + ".py"):
pyx_source_file = source_file + ".py"
else:
pyx_source_file = source_file + ".pyx"
c_source_file = source_file + ".c"
if not os.path.exists(c_source_file) or \
Utils.file_newer_than(pyx_source_file,
Utils.modification_time(c_source_file)):
print("Compiling module %s ..." % module)
result = compile(pyx_source_file)
c_source_file = result.c_file
if c_source_file:
extensions.append(
Extension(module, sources = [c_source_file])
)
else:
print("Compilation failed")
if extensions:
setup_args['ext_modules'] = extensions
add_command_class("build_ext", build_ext)
except Exception:
print("ERROR: %s" % sys.exc_info()[1])
print("Extension module compilation failed, using plain Python implementation")
try:
sys.argv.remove("--no-cython-compile")
except ValueError:
compile_cython_modules()
setup_args.update(setuptools_extra_args) setup_args.update(setuptools_extra_args)
......
...@@ -11,6 +11,10 @@ cdef extern int (*iapfn())[5] ...@@ -11,6 +11,10 @@ cdef extern int (*iapfn())[5]
cdef extern char *(*cpapfn())[5] cdef extern char *(*cpapfn())[5]
cdef extern int fnargfn(int ()) cdef extern int fnargfn(int ())
cdef extern int ia[]
cdef extern int iaa[][3]
cdef extern int a(int[][3], int[][3][5])
cdef void f(): cdef void f():
cdef void *p=NULL cdef void *p=NULL
global ifnp, cpa global ifnp, cpa
......
This diff is collapsed.
...@@ -111,21 +111,84 @@ def slice_charptr_for_loop_c_enumerate(): ...@@ -111,21 +111,84 @@ def slice_charptr_for_loop_c_enumerate():
############################################################ ############################################################
# tests for int* slicing # tests for int* slicing
## cdef int cints[6] cdef int cints[6]
## for i in range(6): for i in range(6):
## cints[i] = i cints[i] = i
## @cython.test_assert_path_exists("//ForFromStatNode", @cython.test_assert_path_exists("//ForFromStatNode",
## "//ForFromStatNode//IndexNode") "//ForFromStatNode//IndexNode")
## @cython.test_fail_if_path_exists("//ForInStatNode") @cython.test_fail_if_path_exists("//ForInStatNode")
## def slice_intptr_for_loop_c(): def slice_intarray_for_loop_c():
## """ """
## >>> slice_intptr_for_loop_c() >>> slice_intarray_for_loop_c()
## [0, 1, 2] [0, 1, 2]
## [1, 2, 3, 4] [1, 2, 3, 4]
## [4, 5] [4, 5]
## """ """
## cdef int i cdef int i
## print [ i for i in cints[:3] ] print [ i for i in cints[:3] ]
## print [ i for i in cints[1:5] ] print [ i for i in cints[1:5] ]
## print [ i for i in cints[4:6] ] print [ i for i in cints[4:6] ]
@cython.test_assert_path_exists("//ForFromStatNode",
"//ForFromStatNode//IndexNode")
@cython.test_fail_if_path_exists("//ForInStatNode")
def iter_intarray_for_loop_c():
"""
>>> iter_intarray_for_loop_c()
[0, 1, 2, 3, 4, 5]
"""
cdef int i
print [ i for i in cints ]
@cython.test_assert_path_exists("//ForFromStatNode",
"//ForFromStatNode//IndexNode")
@cython.test_fail_if_path_exists("//ForInStatNode")
def slice_intptr_for_loop_c():
"""
>>> slice_intptr_for_loop_c()
[0, 1, 2]
[1, 2, 3, 4]
[4, 5]
"""
cdef int* nums = cints
cdef int i
print [ i for i in nums[:3] ]
print [ i for i in nums[1:5] ]
print [ i for i in nums[4:6] ]
############################################################
# tests for slicing other arrays
cdef double cdoubles[6]
for i in range(6):
cdoubles[i] = i + 0.5
cdef double* cdoubles_ptr = cdoubles
@cython.test_assert_path_exists("//ForFromStatNode",
"//ForFromStatNode//IndexNode")
@cython.test_fail_if_path_exists("//ForInStatNode")
def slice_doublptr_for_loop_c():
"""
>>> slice_doublptr_for_loop_c()
[0.5, 1.5, 2.5]
[1.5, 2.5, 3.5, 4.5]
[4.5, 5.5]
"""
cdef double d
print [ d for d in cdoubles_ptr[:3] ]
print [ d for d in cdoubles_ptr[1:5] ]
print [ d for d in cdoubles_ptr[4:6] ]
@cython.test_assert_path_exists("//ForFromStatNode",
"//ForFromStatNode//IndexNode")
@cython.test_fail_if_path_exists("//ForInStatNode")
def iter_doublearray_for_loop_c():
"""
>>> iter_doublearray_for_loop_c()
[0.5, 1.5, 2.5, 3.5, 4.5, 5.5]
"""
cdef double d
print [ d for d in cdoubles ]
...@@ -21,7 +21,7 @@ def test_arithmetic(double complex z, double complex w): ...@@ -21,7 +21,7 @@ def test_arithmetic(double complex z, double complex w):
>>> test_arithmetic(5-10j, 3+4j) >>> test_arithmetic(5-10j, 3+4j)
((5-10j), (-5+10j), (8-6j), (2-14j), (55-10j), (-1-2j)) ((5-10j), (-5+10j), (8-6j), (2-14j), (55-10j), (-1-2j))
""" """
return +z, -z, z+w, z-w, z*w, z/w return +z, -z+0, z+w, z-w, z*w, z/w
@cython.cdivision(False) @cython.cdivision(False)
def test_div_by_zero(double complex z): def test_div_by_zero(double complex z):
......
...@@ -21,11 +21,8 @@ def test_wrap_pair(int i, double x): ...@@ -21,11 +21,8 @@ def test_wrap_pair(int i, double x):
>>> test_wrap_pair(2, 2.25) >>> test_wrap_pair(2, 2.25)
(2, 2.25, True) (2, 2.25, True)
""" """
cdef Pair[int, double] *pair
cdef Wrap[Pair[int, double]] *wrap
try: try:
pair = new Pair[int, double](i, x) wrap = new Wrap[Pair[int, double]](Pair[int, double](i, x))
wrap = new Wrap[Pair[int, double]](deref(pair))
return wrap.get().first(), wrap.get().second(), deref(wrap) == deref(wrap) return wrap.get().first(), wrap.get().second(), deref(wrap) == deref(wrap)
finally: finally:
del pair, wrap del wrap
...@@ -21,7 +21,6 @@ def test_int(int x, int y): ...@@ -21,7 +21,6 @@ def test_int(int x, int y):
>>> test_int(100, 100) >>> test_int(100, 100)
(100, 100, True) (100, 100, True)
""" """
cdef Wrap[int] *a, *b
try: try:
a = new Wrap[int](x) a = new Wrap[int](x)
b = new Wrap[int](0) b = new Wrap[int](0)
...@@ -38,7 +37,6 @@ def test_double(double x, double y): ...@@ -38,7 +37,6 @@ def test_double(double x, double y):
>>> test_double(100, 100) >>> test_double(100, 100)
(100.0, 100.0, True) (100.0, 100.0, True)
""" """
cdef Wrap[double] *a, *b
try: try:
a = new Wrap[double](x) a = new Wrap[double](x)
b = new Wrap[double](-1) b = new Wrap[double](-1)
...@@ -54,7 +52,6 @@ def test_pair(int i, double x): ...@@ -54,7 +52,6 @@ def test_pair(int i, double x):
>>> test_pair(2, 2.25) >>> test_pair(2, 2.25)
(2, 2.25, True, False) (2, 2.25, True, False)
""" """
cdef Pair[int, double] *pair
try: try:
pair = new Pair[int, double](i, x) pair = new Pair[int, double](i, x)
return pair.first(), pair.second(), deref(pair) == deref(pair), deref(pair) != deref(pair) return pair.first(), pair.second(), deref(pair) == deref(pair), deref(pair) != deref(pair)
...@@ -68,7 +65,6 @@ def test_ptr(int i): ...@@ -68,7 +65,6 @@ def test_ptr(int i):
>>> test_ptr(5) >>> test_ptr(5)
5 5
""" """
cdef Wrap[int*] *w
try: try:
w = new Wrap[int*](&i) w = new Wrap[int*](&i)
return deref(w.get()) return deref(w.get())
...@@ -85,7 +81,6 @@ def test_func_ptr(double x): ...@@ -85,7 +81,6 @@ def test_func_ptr(double x):
>>> test_func_ptr(-1.5) >>> test_func_ptr(-1.5)
2.25 2.25
""" """
cdef Wrap[double (*)(double)] *w
try: try:
w = new Wrap[double (*)(double)](&f) w = new Wrap[double (*)(double)](&f)
return w.get()(x) return w.get()(x)
......
cdef extern from "Python.h": cdef extern from "Python.h":
ctypedef class __builtin__.str [object PyStringObject]:
cdef long ob_shash
ctypedef class __builtin__.list [object PyListObject]: ctypedef class __builtin__.list [object PyListObject]:
cdef Py_ssize_t ob_size
cdef Py_ssize_t allocated cdef Py_ssize_t allocated
ctypedef class __builtin__.dict [object PyDictObject]: ctypedef class __builtin__.dict [object PyDictObject]:
pass pass
cdef str s = "abc" cdef Py_ssize_t Py_SIZE(object o)
cdef list L = [1,2,4] cdef list L = [1,2,4]
cdef dict d = {'A': 'a'} cdef dict d = {'A': 'a'}
def test_list(list L): def test_list(list L):
""" """
>>> test_list(list(range(10))) >>> test_list(list(range(10)))
...@@ -23,18 +20,7 @@ def test_list(list L): ...@@ -23,18 +20,7 @@ def test_list(list L):
>>> test_list(list_subclass([1,2,3])) >>> test_list(list_subclass([1,2,3]))
True True
""" """
return L.ob_size <= L.allocated return Py_SIZE(L) <= L.allocated
def test_str(str s):
"""
>>> test_str("abc")
True
>>> class str_subclass(str): pass
>>> test_str(str_subclass("xyz"))
True
"""
cdef char* ss = s
return hash(s) == s.ob_shash
def test_tuple(tuple t): def test_tuple(tuple t):
""" """
......
...@@ -93,13 +93,49 @@ def arithmetic(): ...@@ -93,13 +93,49 @@ def arithmetic():
>>> arithmetic() >>> arithmetic()
""" """
a = 1 + 2 a = 1 + 2
assert typeof(a) == "long" assert typeof(a) == "long", typeof(a)
b = 1 + 1.5 b = 1 + 1.5
assert typeof(b) == "double" assert typeof(b) == "double", typeof(b)
c = 1 + <object>2 c = 1 + <object>2
assert typeof(c) == "Python object" assert typeof(c) == "Python object", typeof(c)
d = "abc %s" % "x" d = 1 * 1.5 ** 2
assert typeof(d) == "Python object" assert typeof(d) == "double", typeof(d)
def builtin_type_operations():
"""
>>> builtin_type_operations()
"""
b1 = b'a' * 10
b1 = 10 * b'a'
b1 = 10 * b'a' * 10
assert typeof(b1) == "bytes object", typeof(b1)
b2 = b'a' + b'b'
assert typeof(b2) == "bytes object", typeof(b2)
u1 = u'a' * 10
u1 = 10 * u'a'
assert typeof(u1) == "unicode object", typeof(u1)
u2 = u'a' + u'b'
assert typeof(u2) == "unicode object", typeof(u2)
u3 = u'a%s' % u'b'
u3 = u'a%s' % 10
assert typeof(u3) == "unicode object", typeof(u3)
s1 = "abc %s" % "x"
s1 = "abc %s" % 10
assert typeof(s1) == "str object", typeof(s1)
s2 = "abc %s" + "x"
assert typeof(s2) == "str object", typeof(s2)
s3 = "abc %s" * 10
s3 = "abc %s" * 10 * 10
s3 = 10 * "abc %s" * 10
assert typeof(s3) == "str object", typeof(s3)
L1 = [] + []
assert typeof(L1) == "list object", typeof(L1)
L2 = [] * 2
assert typeof(L2) == "list object", typeof(L2)
T1 = () + ()
assert typeof(T1) == "tuple object", typeof(T1)
T2 = () * 2
assert typeof(T2) == "tuple object", typeof(T2)
def cascade(): def cascade():
""" """
...@@ -215,10 +251,29 @@ def safe_only(): ...@@ -215,10 +251,29 @@ def safe_only():
""" """
a = 1.0 a = 1.0
assert typeof(a) == "double", typeof(c) assert typeof(a) == "double", typeof(c)
b = 1 b = 1;
assert typeof(b) == "Python object", typeof(b) assert typeof(b) == "long", typeof(b)
c = MyType() c = MyType()
assert typeof(c) == "MyType", typeof(c) assert typeof(c) == "MyType", typeof(c)
for i in range(10): pass
assert typeof(i) == "long", typeof(i)
d = 1
res = ~d
assert typeof(d) == "long", typeof(d)
# potentially overflowing arithmatic
e = 1
e += 1
assert typeof(e) == "Python object", typeof(e)
f = 1
res = f * 10
assert typeof(f) == "Python object", typeof(f)
g = 1
res = 10*(~g)
assert typeof(g) == "Python object", typeof(g)
for j in range(10):
res = -j
assert typeof(j) == "Python object", typeof(j)
@infer_types(None) @infer_types(None)
def args_tuple_keywords(*args, **kwargs): def args_tuple_keywords(*args, **kwargs):
...@@ -249,3 +304,36 @@ def args_tuple_keywords_reassign_pyobjects(*args, **kwargs): ...@@ -249,3 +304,36 @@ def args_tuple_keywords_reassign_pyobjects(*args, **kwargs):
args = [] args = []
kwargs = "test" kwargs = "test"
# / A -> AA -> AAA
# Base0 -> Base -
# \ B -> BB
# C -> CC
cdef class Base0: pass
cdef class Base(Base0): pass
cdef class A(Base): pass
cdef class AA(A): pass
cdef class AAA(AA): pass
cdef class B(Base): pass
cdef class BB(B): pass
cdef class C: pass
cdef class CC(C): pass
@infer_types(None)
def common_extension_type_base():
"""
>>> common_extension_type_base()
"""
x = A()
x = AA()
assert typeof(x) == "A", typeof(x)
y = A()
y = B()
assert typeof(y) == "Base", typeof(y)
z = AAA()
z = BB()
assert typeof(z) == "Base", typeof(z)
w = A()
w = CC()
assert typeof(w) == "Python object", typeof(w)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment