Commit fae8606e authored by Mark Florisson's avatar Mark Florisson

branch merge

parents 5e9d7562 16055821
......@@ -57,6 +57,11 @@ class AutoTestDictTransform(ScopeTrackingTransform):
value = UnicodeNode(pos, value=doctest)
self.tests.append(DictItemNode(pos, key=key, value=value))
def visit_ExprNode(self, node):
# expressions cannot contain functions and lambda expressions
# do not have a docstring
return node
def visit_FuncDefNode(self, node):
if not node.doc:
return node
......
......@@ -7,6 +7,8 @@ from Errors import CompileError
from Code import UtilityCode
import Interpreter
import PyrexTypes
import Naming
import Symtab
try:
set
......
......@@ -1228,7 +1228,7 @@ class CCodeWriter(object):
def put_var_decref(self, entry):
if entry.type.is_pyobject:
if entry.init_to_none is False:
if entry.init_to_none is False: # FIXME: 0 and False are treated differently???
self.putln("__Pyx_XDECREF(%s);" % self.entry_as_pyobject(entry))
else:
self.putln("__Pyx_DECREF(%s);" % self.entry_as_pyobject(entry))
......
......@@ -205,3 +205,11 @@ def release_errors(ignore=False):
def held_errors():
return error_stack[-1]
# this module needs a redesign to support parallel cythonisation, but
# for now, the following works at least in sequential compiler runs
def reset():
_warn_once_seen.clear()
del error_stack[:]
......@@ -2,9 +2,19 @@
# Pyrex - Parse tree nodes for expressions
#
import cython
from cython import set
cython.declare(error=object, warning=object, warn_once=object, InternalError=object,
CompileError=object, UtilityCode=object, StringEncoding=object, operator=object,
Naming=object, Nodes=object, PyrexTypes=object, py_object_type=object,
list_type=object, tuple_type=object, set_type=object, dict_type=object, \
unicode_type=object, str_type=object, bytes_type=object, type_type=object,
Builtin=object, Symtab=object, Utils=object, find_coercion_error=object,
debug_disposal_code=object, debug_temp_alloc=object, debug_coercion=object)
import operator
from Errors import error, warning, warn_once, InternalError
from Errors import error, warning, warn_once, InternalError, CompileError
from Errors import hold_errors, release_errors, held_errors, report_error
from Code import UtilityCode
import StringEncoding
......@@ -21,16 +31,15 @@ import Symtab
import Options
from Cython import Utils
from Annotate import AnnotationItem
from Cython import Utils
from Cython.Debugging import print_call_chain
from DebugFlags import debug_disposal_code, debug_temp_alloc, \
debug_coercion
try:
set
except NameError:
from sets import Set as set
from __builtin__ import basestring
except ImportError:
basestring = str # Python 3
class NotConstant(object):
def __repr__(self):
......@@ -202,8 +211,9 @@ class ExprNode(Node):
_get_child_attrs = operator.attrgetter('subexprs')
except AttributeError:
# Python 2.3
def _get_child_attrs(self):
def __get_child_attrs(self):
return self.subexprs
_get_child_attrs = __get_child_attrs
child_attrs = property(fget=_get_child_attrs)
def not_implemented(self, method_name):
......@@ -1760,6 +1770,9 @@ class IteratorNode(ExprNode):
self.type = self.sequence.type
else:
self.sequence = self.sequence.coerce_to_pyobject(env)
if self.sequence.type is list_type or \
self.sequence.type is tuple_type:
self.sequence = self.sequence.as_none_safe_node("'NoneType' object is not iterable")
self.is_temp = 1
gil_message = "Iterating over Python object"
......@@ -1776,28 +1789,22 @@ class IteratorNode(ExprNode):
raise InternalError("for in carray slice not transformed")
is_builtin_sequence = self.sequence.type is list_type or \
self.sequence.type is tuple_type
may_be_a_sequence = is_builtin_sequence or not self.sequence.type.is_builtin_type
if is_builtin_sequence:
code.putln(
"if (likely(%s != Py_None)) {" % self.sequence.py_result())
elif may_be_a_sequence:
may_be_a_sequence = not self.sequence.type.is_builtin_type
if may_be_a_sequence:
code.putln(
"if (PyList_CheckExact(%s) || PyTuple_CheckExact(%s)) {" % (
self.sequence.py_result(),
self.sequence.py_result()))
if may_be_a_sequence:
if is_builtin_sequence or may_be_a_sequence:
code.putln(
"%s = 0; %s = %s; __Pyx_INCREF(%s);" % (
self.counter_cname,
self.result(),
self.sequence.py_result(),
self.result()))
if not is_builtin_sequence:
if may_be_a_sequence:
code.putln("} else {")
if is_builtin_sequence:
code.putln(
'PyErr_SetString(PyExc_TypeError, "\'NoneType\' object is not iterable"); %s' %
code.error_goto(self.pos))
else:
code.putln("%s = -1; %s = PyObject_GetIter(%s); %s" % (
self.counter_cname,
self.result(),
......@@ -2981,7 +2988,7 @@ class SimpleCallNode(CallNode):
return "<error>"
formal_args = func_type.args
arg_list_code = []
args = zip(formal_args, self.args)
args = list(zip(formal_args, self.args))
max_nargs = len(func_type.args)
expected_nargs = max_nargs - func_type.optional_arg_count
actual_nargs = len(self.args)
......@@ -3026,7 +3033,7 @@ class SimpleCallNode(CallNode):
self.opt_arg_struct,
Naming.pyrex_prefix + "n",
len(self.args) - expected_nargs))
args = zip(func_type.args, self.args)
args = list(zip(func_type.args, self.args))
for formal_arg, actual_arg in args[expected_nargs:actual_nargs]:
code.putln("%s.%s = %s;" % (
self.opt_arg_struct,
......@@ -3153,7 +3160,7 @@ class GeneralCallNode(CallNode):
def explicit_args_kwds(self):
if self.starstar_arg or not isinstance(self.positional_args, TupleNode):
raise PostParseError(self.pos,
raise CompileError(self.pos,
'Compile-time keyword arguments must be explicit.')
return self.positional_args.args, self.keyword_args
......@@ -3613,7 +3620,7 @@ class AttributeNode(ExprNode):
interned_attr_cname = code.intern_identifier(self.attribute)
self.obj.generate_evaluation_code(code)
if self.is_py_attr or (isinstance(self.entry.scope, Symtab.PropertyScope)
and self.entry.scope.entries.has_key(u'__del__')):
and u'__del__' in self.entry.scope.entries):
code.put_error_if_neg(self.pos,
'PyObject_DelAttr(%s, %s)' % (
self.obj.py_result(),
......@@ -4125,46 +4132,49 @@ class ScopedExprNode(ExprNode):
subexprs = []
expr_scope = None
def analyse_types(self, env):
# nothing to do here, the children will be analysed separately
# does this node really have a local scope, e.g. does it leak loop
# variables or not? non-leaking Py3 behaviour is default, except
# for list comprehensions where the behaviour differs in Py2 and
# Py3 (set in Parsing.py based on parser context)
has_local_scope = True
def init_scope(self, outer_scope, expr_scope=None):
if expr_scope is not None:
self.expr_scope = expr_scope
elif self.has_local_scope:
self.expr_scope = Symtab.GeneratorExpressionScope(outer_scope)
else:
self.expr_scope = None
def analyse_declarations(self, env):
self.init_scope(env)
def analyse_scoped_declarations(self, env):
# this is called with the expr_scope as env
pass
def analyse_expressions(self, env):
# nothing to do here, the children will be analysed separately
def analyse_types(self, env):
# no recursion here, the children will be analysed separately below
pass
def analyse_scoped_expressions(self, env):
# this is called with the expr_scope as env
pass
def init_scope(self, outer_scope, expr_scope=None):
self.expr_scope = expr_scope
class ComprehensionNode(ScopedExprNode):
subexprs = ["target"]
child_attrs = ["loop", "append"]
# leak loop variables or not? non-leaking Py3 behaviour is
# default, except for list comprehensions where the behaviour
# differs in Py2 and Py3 (see Parsing.py)
has_local_scope = True
def infer_type(self, env):
return self.target.infer_type(env)
def analyse_declarations(self, env):
self.append.target = self # this is used in the PyList_Append of the inner loop
self.init_scope(env)
self.loop.analyse_declarations(self.expr_scope or env)
def init_scope(self, outer_scope, expr_scope=None):
if expr_scope is not None:
self.expr_scope = expr_scope
elif self.has_local_scope:
self.expr_scope = Symtab.GeneratorExpressionScope(outer_scope)
else:
self.expr_scope = None
def analyse_scoped_declarations(self, env):
self.loop.analyse_declarations(env)
def analyse_types(self, env):
self.target.analyse_expressions(env)
......@@ -4172,9 +4182,6 @@ class ComprehensionNode(ScopedExprNode):
if not self.has_local_scope:
self.loop.analyse_expressions(env)
def analyse_expressions(self, env):
self.analyse_types(env)
def analyse_scoped_expressions(self, env):
if self.has_local_scope:
self.loop.analyse_expressions(env)
......@@ -4276,20 +4283,16 @@ class GeneratorExpressionNode(ScopedExprNode):
type = py_object_type
def analyse_declarations(self, env):
self.init_scope(env)
self.loop.analyse_declarations(self.expr_scope)
def init_scope(self, outer_scope, expr_scope=None):
if expr_scope is not None:
self.expr_scope = expr_scope
else:
self.expr_scope = Symtab.GeneratorExpressionScope(outer_scope)
def analyse_scoped_declarations(self, env):
self.loop.analyse_declarations(env)
def analyse_types(self, env):
if not self.has_local_scope:
self.loop.analyse_expressions(env)
self.is_temp = True
def analyse_scoped_expressions(self, env):
if self.has_local_scope:
self.loop.analyse_expressions(env)
def may_be_none(self):
......@@ -4310,14 +4313,29 @@ class InlinedGeneratorExpressionNode(GeneratorExpressionNode):
# orig_func String the name of the builtin function this node replaces
child_attrs = ["loop"]
loop_analysed = False
def infer_type(self, env):
return self.result_node.infer_type(env)
def analyse_types(self, env):
if not self.has_local_scope:
self.loop_analysed = True
self.loop.analyse_expressions(env)
self.type = self.result_node.type
self.is_temp = True
def analyse_scoped_expressions(self, env):
self.loop_analysed = True
GeneratorExpressionNode.analyse_scoped_expressions(self, env)
def coerce_to(self, dst_type, env):
if self.orig_func == 'sum' and dst_type.is_numeric:
# we can optimise by dropping the aggregation variable into C
if self.orig_func == 'sum' and dst_type.is_numeric and not self.loop_analysed:
# We can optimise by dropping the aggregation variable and
# the add operations into C. This can only be done safely
# before analysing the loop body, after that, the result
# reference type will have infected expressions and
# assignments.
self.result_node.type = self.type = dst_type
return self
return GeneratorExpressionNode.coerce_to(self, dst_type, env)
......@@ -4838,10 +4856,14 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
class InnerFunctionNode(PyCFunctionNode):
# Special PyCFunctionNode that depends on a closure class
#
binding = True
needs_self_code = True
def self_result_code(self):
return "((PyObject*)%s)" % Naming.cur_scope_cname
if self.needs_self_code:
return "((PyObject*)%s)" % (Naming.cur_scope_cname)
return "NULL"
class LambdaNode(InnerFunctionNode):
# Lambda expression node (only used as a function reference)
......@@ -4859,7 +4881,6 @@ class LambdaNode(InnerFunctionNode):
name = StringEncoding.EncodedString('<lambda>')
def analyse_declarations(self, env):
#self.def_node.needs_closure = self.needs_closure
self.def_node.analyse_declarations(env)
self.pymethdef_cname = self.def_node.entry.pymethdef_cname
env.add_lambda_def(self.def_node)
......@@ -5651,7 +5672,12 @@ class NumBinopNode(BinopNode):
def compute_c_result_type(self, type1, type2):
if self.c_types_okay(type1, type2):
return PyrexTypes.widest_numeric_type(type1, type2)
widest_type = PyrexTypes.widest_numeric_type(type1, type2)
if widest_type is PyrexTypes.c_bint_type:
if self.operator not in '|^&':
# False + False == 0 # not False!
widest_type = PyrexTypes.c_int_type
return widest_type
else:
return None
......
......@@ -10,5 +10,6 @@ unicode_literals = _get_feature("unicode_literals")
with_statement = _get_feature("with_statement")
division = _get_feature("division")
print_function = _get_feature("print_function")
nested_scopes = _get_feature("nested_scopes") # dummy
del _get_feature
......@@ -138,7 +138,6 @@ class Context(object):
WithTransform(self),
DecoratorTransform(self),
AnalyseDeclarationsTransform(self),
CreateClosureClasses(self),
AutoTestDictTransform(self),
EmbedSignature(self),
EarlyReplaceBuiltinCalls(self), ## Necessary?
......@@ -148,6 +147,7 @@ class Context(object):
IntroduceBufferAuxiliaryVars(self),
_check_c_declarations,
AnalyseExpressionsTransform(self),
CreateClosureClasses(self), ## After all lookups and type inference
ExpandInplaceOperators(self),
OptimizeBuiltinCalls(self), ## Necessary?
IterationTransform(),
......@@ -523,6 +523,7 @@ class Context(object):
return ".".join(names)
def setup_errors(self, options, result):
Errors.reset() # clear any remaining error state
if options.use_listing_file:
result.listing_file = Utils.replace_suffix(source, ".lis")
path = result.listing_file
......
......@@ -2,15 +2,16 @@
# Pyrex - Module parse tree node
#
import cython
from cython import set
cython.declare(Naming=object, Options=object, PyrexTypes=object, TypeSlots=object,
error=object, warning=object, py_object_type=object, UtilityCode=object,
escape_byte_string=object, EncodedString=object)
import os, time
from PyrexTypes import CPtrType
import Future
try:
set
except NameError: # Python 2.3
from sets import Set as set
import Annotate
import Code
import Naming
......@@ -666,6 +667,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("#include <math.h>")
code.putln("#define %s" % Naming.api_guard_prefix + self.api_name(env))
self.generate_includes(env, cimported_modules, code)
code.putln("")
code.putln("#ifdef PYREX_WITHOUT_ASSERTIONS")
code.putln("#define CYTHON_WITHOUT_ASSERTIONS")
code.putln("#endif")
code.putln("")
if env.directives['ccomplex']:
code.putln("")
code.putln("#if !defined(CYTHON_CCOMPLEX)")
......@@ -1673,20 +1679,20 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("if (!(%s)) %s;" % (
entry.type.type_test_code("o"),
code.error_goto(entry.pos)))
code.put_var_decref(entry)
code.putln("Py_INCREF(o);")
code.put_decref(entry.cname, entry.type, nanny=False)
code.putln("%s = %s;" % (
entry.cname,
PyrexTypes.typecast(entry.type, py_object_type, "o")))
elif entry.type.from_py_function:
rhs = "%s(o)" % entry.type.from_py_function
if entry.type.is_enum:
rhs = typecast(entry.type, c_long_type, rhs)
rhs = PyrexTypes.typecast(entry.type, PyrexTypes.c_long_type, rhs)
code.putln("%s = %s; if (%s) %s;" % (
entry.cname,
rhs,
entry.type.error_condition(entry.cname),
code.error_goto(entry.pos)))
code.putln("Py_DECREF(o);")
else:
code.putln('PyErr_Format(PyExc_TypeError, "Cannot convert Python object %s to %s");' % (name, entry.type))
code.putln(code.error_goto(entry.pos))
......@@ -1695,12 +1701,12 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("if (PyObject_SetAttr(%s, py_name, o) < 0) goto bad;" % Naming.module_cname)
code.putln("}")
code.putln("return 0;")
if code.label_used(code.error_label):
code.put_label(code.error_label)
# This helps locate the offending name.
code.putln('__Pyx_AddTraceback("%s");' % self.full_module_name);
code.error_label = old_error_label
code.putln("bad:")
code.putln("Py_DECREF(o);")
code.putln("return -1;")
code.putln("}")
code.putln(import_star_utility_code)
......
......@@ -3,15 +3,17 @@
# Pyrex - Parse tree nodes
#
import sys, os, time, copy
import cython
from cython import set
cython.declare(sys=object, os=object, time=object, copy=object,
Builtin=object, error=object, warning=object, Naming=object, PyrexTypes=object,
py_object_type=object, ModuleScope=object, LocalScope=object, ClosureScope=object, \
StructOrUnionScope=object, PyClassScope=object, CClassScope=object,
CppClassScope=object, UtilityCode=object, EncodedString=object,
absolute_path_length=cython.Py_ssize_t)
try:
set
except NameError:
# Python 2.3
from sets import Set as set
import sys, os, time, copy
import Code
import Builtin
from Errors import error, warning, InternalError
import Naming
......@@ -241,7 +243,7 @@ class Node(object):
if encountered is None:
encountered = set()
if id(self) in encountered:
return "<%s (%d) -- already output>" % (self.__class__.__name__, id(self))
return "<%s (0x%x) -- already output>" % (self.__class__.__name__, id(self))
encountered.add(id(self))
def dump_child(x, level):
......@@ -253,12 +255,12 @@ class Node(object):
return repr(x)
attrs = [(key, value) for key, value in self.__dict__.iteritems() if key not in filter_out]
attrs = [(key, value) for key, value in self.__dict__.items() if key not in filter_out]
if len(attrs) == 0:
return "<%s (%d)>" % (self.__class__.__name__, id(self))
return "<%s (0x%x)>" % (self.__class__.__name__, id(self))
else:
indent = " " * level
res = "<%s (%d)\n" % (self.__class__.__name__, id(self))
res = "<%s (0x%x)\n" % (self.__class__.__name__, id(self))
for key, value in attrs:
res += "%s %s: %s\n" % (indent, key, dump_child(value, level + 1))
res += "%s>" % indent
......@@ -858,7 +860,7 @@ class TemplatedTypeNode(CBaseTypeNode):
if sys.version_info[0] < 3:
# Py 2.x enforces byte strings as keyword arguments ...
options = dict([ (name.encode('ASCII'), value)
for name, value in options.iteritems() ])
for name, value in options.items() ])
self.type = PyrexTypes.BufferType(base_type, **options)
......@@ -949,7 +951,7 @@ class CVarDefNode(StatNode):
entry.directive_locals = self.directive_locals
else:
if self.directive_locals:
s.error("Decorators can only be followed by functions")
error(self.pos, "Decorators can only be followed by functions")
if self.in_pxd and self.visibility != 'extern':
error(self.pos,
"Only 'extern' C variable declaration allowed in .pxd file")
......@@ -1146,11 +1148,13 @@ class FuncDefNode(StatNode, BlockNode):
# #filename string C name of filename string const
# entry Symtab.Entry
# needs_closure boolean Whether or not this function has inner functions/classes/yield
# needs_outer_scope boolean Whether or not this function requires outer scope
# directive_locals { string : NameNode } locals defined by cython.locals(...)
py_func = None
assmt = None
needs_closure = False
needs_outer_scope = False
modifiers = []
def analyse_default_values(self, env):
......@@ -1198,7 +1202,7 @@ class FuncDefNode(StatNode, BlockNode):
import Buffer
lenv = self.local_scope
if lenv.is_closure_scope:
if lenv.is_closure_scope and not lenv.is_passthrough:
outer_scope_cname = "%s->%s" % (Naming.cur_scope_cname,
Naming.outer_scope_cname)
else:
......@@ -1259,10 +1263,13 @@ class FuncDefNode(StatNode, BlockNode):
cenv = env
while cenv.is_py_class_scope or cenv.is_c_class_scope:
cenv = cenv.outer_scope
if lenv.is_closure_scope:
if self.needs_closure:
code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname))
code.putln(";")
elif self.needs_outer_scope:
if lenv.is_passthrough:
code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname))
code.putln(";")
elif cenv.is_closure_scope:
code.put(cenv.scope_class.type.declaration_code(Naming.outer_scope_cname))
code.putln(";")
self.generate_argument_declarations(lenv, code)
......@@ -1314,12 +1321,14 @@ class FuncDefNode(StatNode, BlockNode):
code.putln("}")
code.put_gotref(Naming.cur_scope_cname)
# Note that it is unsafe to decref the scope at this point.
if cenv.is_closure_scope:
if self.needs_outer_scope:
code.putln("%s = (%s)%s;" % (
outer_scope_cname,
cenv.scope_class.type.declaration_code(''),
Naming.self_cname))
if self.needs_closure:
if lenv.is_passthrough:
code.putln("%s = %s;" % (Naming.cur_scope_cname, outer_scope_cname));
elif self.needs_closure:
# inner closures own a reference to their outer parent
code.put_incref(outer_scope_cname, cenv.scope_class.type)
code.put_giveref(outer_scope_cname)
......@@ -2206,6 +2215,8 @@ class DefNode(FuncDefNode):
def needs_assignment_synthesis(self, env, code=None):
# Should enable for module level as well, that will require more testing...
if self.entry.is_lambda:
return True
if env.is_module_scope:
if code is None:
return env.directives['binding']
......@@ -3208,7 +3219,7 @@ class CClassDefNode(ClassDefNode):
api = self.api,
buffer_defaults = buffer_defaults)
if home_scope is not env and self.visibility == 'extern':
env.add_imported_entry(self.class_name, self.entry, pos)
env.add_imported_entry(self.class_name, self.entry, self.pos)
self.scope = scope = self.entry.type.scope
if scope is not None:
scope.directives = env.directives
......@@ -3376,7 +3387,7 @@ class SingleAssignmentNode(AssignmentNode):
if func_name in ['declare', 'typedef']:
if len(args) > 2 or kwds is not None:
error(rhs.pos, "Can only declare one type at a time.")
error(self.rhs.pos, "Can only declare one type at a time.")
return
type = args[0].analyse_as_type(env)
if type is None:
......@@ -3407,7 +3418,7 @@ class SingleAssignmentNode(AssignmentNode):
elif func_name in ['struct', 'union']:
self.declaration_only = True
if len(args) > 0 or kwds is None:
error(rhs.pos, "Struct or union members must be given by name.")
error(self.rhs.pos, "Struct or union members must be given by name.")
return
members = []
for member, type_node in kwds.key_value_pairs:
......@@ -3991,7 +4002,7 @@ class AssertStatNode(StatNode):
gil_message = "Raising exception"
def generate_execution_code(self, code):
code.putln("#ifndef PYREX_WITHOUT_ASSERTIONS")
code.putln("#ifndef CYTHON_WITHOUT_ASSERTIONS")
self.cond.generate_evaluation_code(code)
code.putln(
"if (unlikely(!%s)) {" %
......
import cython
from cython import set
cython.declare(UtilityCode=object, EncodedString=object, BytesLiteral=object,
Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object,
UtilNodes=object, Naming=object)
import Nodes
import ExprNodes
import PyrexTypes
......@@ -17,14 +24,14 @@ from ParseTreeTransforms import SkipDeclarations
import codecs
try:
reduce
except NameError:
from __builtin__ import reduce
except ImportError:
from functools import reduce
try:
set
except NameError:
from sets import Set as set
from __builtin__ import basestring
except ImportError:
basestring = str # Python 3
class FakePythonEnv(object):
"A fake environment for creating type test nodes etc."
......@@ -749,7 +756,7 @@ class SwitchTransform(Visitor.VisitorTransform):
def extract_in_string_conditions(self, string_literal):
if isinstance(string_literal, ExprNodes.UnicodeNode):
charvals = map(ord, set(string_literal.value))
charvals = list(map(ord, set(string_literal.value)))
charvals.sort()
return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
constant_result=charval)
......@@ -1332,14 +1339,26 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
"""
if len(pos_args) not in (1,2):
return node
if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode,
ExprNodes.ComprehensionNode)):
return node
gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop
if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
if yield_expression is None:
return node
else: # ComprehensionNode
yield_stat_node = gen_expr_node.append
yield_expression = yield_stat_node.expr
try:
if not yield_expression.is_literal or not yield_expression.type.is_int:
return node
except AttributeError:
return node # in case we don't have a type yet
# special case: old Py2 backwards compatible "sum([int_const for ...])"
# can safely be unpacked into a genexpr
if len(pos_args) == 1:
start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
......@@ -1368,7 +1387,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
return ExprNodes.InlinedGeneratorExpressionNode(
gen_expr_node.pos, loop = exec_code, result_node = result_ref,
expr_scope = gen_expr_node.expr_scope, orig_func = 'sum')
expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
has_local_scope = gen_expr_node.has_local_scope)
def _handle_simple_function_min(self, node, pos_args):
return self._optimise_min_max(node, pos_args, '<')
......@@ -1383,7 +1403,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
# leave this to Python
return node
cascaded_nodes = map(UtilNodes.ResultRefNode, args[1:])
cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:]))
last_result = args[0]
for arg_node in cascaded_nodes:
......@@ -1827,7 +1847,7 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
# Note: this requires the float() function to be typed as
# returning a C 'double'
if len(pos_args) == 0:
return ExprNode.FloatNode(
return ExprNodes.FloatNode(
node, value="0.0", constant_result=0.0
).coerce_to(Builtin.float_type, self.current_env())
elif len(pos_args) != 1:
......@@ -1860,8 +1880,12 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
return node
else:
return pos_args[0].coerce_to_boolean(
self.current_env()).coerce_to_pyobject(self.current_env())
# => !!<bint>(x) to make sure it's exactly 0 or 1
operand = pos_args[0].coerce_to_boolean(self.current_env())
operand = ExprNodes.NotNode(node.pos, operand = operand)
operand = ExprNodes.NotNode(node.pos, operand = operand)
# coerce back to Python object as that's the result we are expecting
return operand.coerce_to_pyobject(self.current_env())
### builtin functions
......@@ -2931,7 +2955,7 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
# check if all children are constant
children = self.visitchildren(node)
for child_result in children.itervalues():
for child_result in children.values():
if type(child_result) is list:
for child in child_result:
if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
......@@ -2966,12 +2990,23 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
self._calculate_const(node)
return node
def visit_UnaryMinusNode(self, node):
def visit_UnopNode(self, node):
self._calculate_const(node)
if node.constant_result is ExprNodes.not_a_constant:
return node
if not node.operand.is_literal:
return node
if isinstance(node.operand, ExprNodes.BoolNode):
return ExprNodes.IntNode(node.pos, value = str(node.constant_result),
type = PyrexTypes.c_int_type,
constant_result = node.constant_result)
if node.operator == '+':
return self._handle_UnaryPlusNode(node)
elif node.operator == '-':
return self._handle_UnaryMinusNode(node)
return node
def _handle_UnaryMinusNode(self, node):
if isinstance(node.operand, ExprNodes.LongNode):
return ExprNodes.LongNode(node.pos, value = '-' + node.operand.value,
constant_result = node.constant_result)
......@@ -2988,10 +3023,7 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
constant_result = node.constant_result)
return node
def visit_UnaryPlusNode(self, node):
self._calculate_const(node)
if node.constant_result is ExprNodes.not_a_constant:
return node
def _handle_UnaryPlusNode(self, node):
if node.constant_result == node.operand.constant_result:
return node.operand
return node
......@@ -3017,12 +3049,13 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
return node
if isinstance(node.constant_result, float):
return node
if not node.operand1.is_literal or not node.operand2.is_literal:
operand1, operand2 = node.operand1, node.operand2
if not operand1.is_literal or not operand2.is_literal:
return node
# now inject a new constant node with the calculated value
try:
type1, type2 = node.operand1.type, node.operand2.type
type1, type2 = operand1.type, operand2.type
if type1 is None or type2 is None:
return node
except AttributeError:
......@@ -3032,14 +3065,14 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
widest_type = PyrexTypes.widest_numeric_type(type1, type2)
else:
widest_type = PyrexTypes.py_object_type
target_class = self._widest_node_class(node.operand1, node.operand2)
target_class = self._widest_node_class(operand1, operand2)
if target_class is None:
return node
elif target_class is ExprNodes.IntNode:
unsigned = getattr(node.operand1, 'unsigned', '') and \
getattr(node.operand2, 'unsigned', '')
longness = "LL"[:max(len(getattr(node.operand1, 'longness', '')),
len(getattr(node.operand2, 'longness', '')))]
unsigned = getattr(operand1, 'unsigned', '') and \
getattr(operand2, 'unsigned', '')
longness = "LL"[:max(len(getattr(operand1, 'longness', '')),
len(getattr(operand2, 'longness', '')))]
new_node = ExprNodes.IntNode(pos=node.pos,
unsigned = unsigned, longness = longness,
value = str(node.constant_result),
......
cimport cython
from Cython.Compiler.Visitor cimport (
CythonTransform, VisitorTransform, TreeVisitor,
ScopeTrackingTransform, EnvTransform)
cdef class NameNodeCollector(TreeVisitor):
cdef list name_nodes
cdef class SkipDeclarations: # (object):
pass
cdef class NormalizeTree(CythonTransform):
cdef bint is_in_statlist
cdef bint is_in_expr
cpdef visit_StatNode(self, node, is_listcontainer=*)
cdef class PostParse(ScopeTrackingTransform):
cdef dict specialattribute_handlers
cdef size_t lambda_counter
cdef _visit_assignment_node(self, node, list expr_list)
#def eliminate_rhs_duplicates(list expr_list_list, list ref_node_sequence)
#def sort_common_subsequences(list items)
@cython.locals(starred_targets=Py_ssize_t, lhs_size=Py_ssize_t, rhs_size=Py_ssize_t)
cdef flatten_parallel_assignments(list input, list output)
cdef map_starred_assignment(list lhs_targets, list starred_assignments, list lhs_args, list rhs_args)
#class PxdPostParse(CythonTransform, SkipDeclarations):
#class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
#class WithTransform(CythonTransform, SkipDeclarations):
#class DecoratorTransform(CythonTransform, SkipDeclarations):
#class AnalyseDeclarationsTransform(CythonTransform):
cdef class AnalyseExpressionsTransform(CythonTransform):
pass
cdef class ExpandInplaceOperators(EnvTransform):
pass
cdef class AlignFunctionDefinitions(CythonTransform):
cdef dict directives
cdef scope
cdef class MarkClosureVisitor(CythonTransform):
cdef bint needs_closure
cdef class CreateClosureClasses(CythonTransform):
cdef list path
cdef bint in_lambda
cdef module_scope
cdef class GilCheck(VisitorTransform):
cdef list env_stack
cdef bint nogil
cdef class TransformBuiltinMethods(EnvTransform):
cdef visit_cython_attribute(self, node)
import cython
cython.declare(copy=object, ModuleNode=object, TreeFragment=object, TemplateTransform=object,
EncodedString=object, error=object, warning=object, PyrexTypes=object, Naming=object)
from Cython.Compiler.Visitor import VisitorTransform, TreeVisitor
from Cython.Compiler.Visitor import CythonTransform, EnvTransform, ScopeTrackingTransform
from Cython.Compiler.ModuleNode import ModuleNode
......@@ -6,8 +11,9 @@ from Cython.Compiler.ExprNodes import *
from Cython.Compiler.UtilNodes import *
from Cython.Compiler.TreeFragment import TreeFragment, TemplateTransform
from Cython.Compiler.StringEncoding import EncodedString
from Cython.Compiler.Errors import error, CompileError
from Cython.Compiler import PyrexTypes
from Cython.Compiler.Errors import error, warning, CompileError
from Cython.Compiler import PyrexTypes, Naming
try:
set
......@@ -25,11 +31,12 @@ class NameNodeCollector(TreeVisitor):
super(NameNodeCollector, self).__init__()
self.name_nodes = []
visit_Node = TreeVisitor.visitchildren
def visit_NameNode(self, node):
self.name_nodes.append(node)
def visit_Node(self, node):
self._visitchildren(node, None)
class SkipDeclarations(object):
"""
......@@ -180,9 +187,6 @@ class PostParse(ScopeTrackingTransform):
def visit_LambdaNode(self, node):
# unpack a lambda expression into the corresponding DefNode
if self.scope_type != 'function':
error(node.pos,
"lambda functions are currently only supported in functions")
lambda_id = self.lambda_counter
self.lambda_counter += 1
node.lambda_name = EncodedString(u'lambda%d' % lambda_id)
......@@ -244,8 +248,10 @@ class PostParse(ScopeTrackingTransform):
# Split parallel assignments (a,b = b,a) into separate partial
# assignments that are executed rhs-first using temps. This
# optimisation is best applied before type analysis so that known
# types on rhs and lhs can be matched directly.
# restructuring must be applied before type analysis so that known
# types on rhs and lhs can be matched directly. It is required in
# the case that the types cannot be coerced to a Python type in
# order to assign from a tuple.
def visit_SingleAssignmentNode(self, node):
self.visitchildren(node)
......@@ -300,7 +306,7 @@ def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
and appends them to ref_node_sequence. The input list is modified
in-place.
"""
seen_nodes = set()
seen_nodes = cython.set()
ref_nodes = {}
def find_duplicates(node):
if node.is_literal or node.is_name:
......@@ -328,13 +334,13 @@ def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
if node in ref_nodes:
return ref_nodes[node]
elif node.is_sequence_constructor:
node.args = map(substitute_nodes, node.args)
node.args = list(map(substitute_nodes, node.args))
return node
# replace nodes inside of the common subexpressions
for node in ref_nodes:
if node.is_sequence_constructor:
node.args = map(substitute_nodes, node.args)
node.args = list(map(substitute_nodes, node.args))
# replace common subexpressions on all rhs items
for expr_list in expr_list_list:
......@@ -342,10 +348,15 @@ def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
def sort_common_subsequences(items):
"""Sort items/subsequences so that all items and subsequences that
an item contains appear before the item itself. This implies a
partial order, and the sort must be stable to preserve the
original order as much as possible, so we use a simple insertion
sort.
an item contains appear before the item itself. This is needed
because each rhs item must only be evaluated once, so its value
must be evaluated first and then reused when packing sequences
that contain it.
This implies a partial order, and the sort must be stable to
preserve the original order as much as possible, so we use a
simple insertion sort (which is very fast for short sequences, the
normal case in practice).
"""
def contains(seq, x):
for item in seq:
......@@ -358,8 +369,8 @@ def sort_common_subsequences(items):
return b.is_sequence_constructor and contains(b.args, a)
for pos, item in enumerate(items):
key = item[1] # the ResultRefNode which has already been injected into the sequences
new_pos = pos
key = item[0]
for i in xrange(pos-1, -1, -1):
if lower_than(key, items[i][0]):
new_pos = i
......@@ -566,16 +577,16 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
'operator.comma' : c_binop_constructor(','),
}
special_methods = set(['declare', 'union', 'struct', 'typedef', 'sizeof',
'cast', 'pointer', 'compiled', 'NULL']
+ unop_method_nodes.keys())
special_methods = cython.set(['declare', 'union', 'struct', 'typedef', 'sizeof',
'cast', 'pointer', 'compiled', 'NULL'])
special_methods.update(unop_method_nodes.keys())
def __init__(self, context, compilation_directive_defaults):
super(InterpretCompilerDirectives, self).__init__(context)
self.compilation_directive_defaults = {}
for key, value in compilation_directive_defaults.iteritems():
for key, value in compilation_directive_defaults.items():
self.compilation_directive_defaults[unicode(key)] = value
self.cython_module_names = set()
self.cython_module_names = cython.set()
self.directive_names = {}
def check_directive_scope(self, pos, directive, scope):
......@@ -589,7 +600,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
# Set up processing and handle the cython: comments.
def visit_ModuleNode(self, node):
for key, value in node.directive_comments.iteritems():
for key, value in node.directive_comments.items():
if not self.check_directive_scope(node.pos, key, 'module'):
self.wrong_scope_error(node.pos, key, 'module')
del node.directive_comments[key]
......@@ -1017,7 +1028,7 @@ property NAME:
return node
def visit_ModuleNode(self, node):
self.seen_vars_stack.append(set())
self.seen_vars_stack.append(cython.set())
node.analyse_declarations(self.env_stack[-1])
self.visitchildren(node)
self.seen_vars_stack.pop()
......@@ -1049,7 +1060,7 @@ property NAME:
return node
def visit_FuncDefNode(self, node):
self.seen_vars_stack.append(set())
self.seen_vars_stack.append(cython.set())
lenv = node.local_scope
node.body.analyse_control_flow(lenv) # this will be totally refactored
node.declare_arguments(lenv)
......@@ -1068,15 +1079,18 @@ property NAME:
return node
def visit_ScopedExprNode(self, node):
node.analyse_declarations(self.env_stack[-1])
env = self.env_stack[-1]
node.analyse_declarations(env)
# the node may or may not have a local scope
if node.expr_scope:
self.seen_vars_stack.append(set(self.seen_vars_stack[-1]))
if node.has_local_scope:
self.seen_vars_stack.append(cython.set(self.seen_vars_stack[-1]))
self.env_stack.append(node.expr_scope)
node.analyse_scoped_declarations(node.expr_scope)
self.visitchildren(node)
self.env_stack.pop()
self.seen_vars_stack.pop()
else:
node.analyse_scoped_declarations(env)
self.visitchildren(node)
return node
......@@ -1172,7 +1186,7 @@ class AnalyseExpressionsTransform(CythonTransform):
return node
def visit_ScopedExprNode(self, node):
if node.expr_scope is not None:
if node.has_local_scope:
node.expr_scope.infer_types()
node.analyse_scoped_expressions(node.expr_scope)
self.visitchildren(node)
......@@ -1290,7 +1304,10 @@ class AlignFunctionDefinitions(CythonTransform):
class MarkClosureVisitor(CythonTransform):
needs_closure = False
def visit_ModuleNode(self, node):
self.needs_closure = False
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
self.needs_closure = False
......@@ -1320,16 +1337,58 @@ class MarkClosureVisitor(CythonTransform):
class CreateClosureClasses(CythonTransform):
# Output closure classes in module scope for all functions
# that need it.
# that really need it.
def __init__(self, context):
super(CreateClosureClasses, self).__init__(context)
self.path = []
self.in_lambda = False
def visit_ModuleNode(self, node):
self.module_scope = node.scope
self.visitchildren(node)
return node
def create_class_from_scope(self, node, target_module_scope):
as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname)
def get_scope_use(self, node):
from_closure = []
in_closure = []
for name, entry in node.local_scope.entries.items():
if entry.from_closure:
from_closure.append((name, entry))
elif entry.in_closure and not entry.from_closure:
in_closure.append((name, entry))
return from_closure, in_closure
def create_class_from_scope(self, node, target_module_scope, inner_node=None):
from_closure, in_closure = self.get_scope_use(node)
in_closure.sort()
# Now from the begining
node.needs_closure = False
node.needs_outer_scope = False
func_scope = node.local_scope
cscope = node.entry.scope
while cscope.is_py_class_scope or cscope.is_c_class_scope:
cscope = cscope.outer_scope
if not from_closure and (self.path or inner_node):
if not inner_node:
if not node.assmt:
raise InternalError, "DefNode does not have assignment node"
inner_node = node.assmt.rhs
inner_node.needs_self_code = False
node.needs_outer_scope = False
# Simple cases
if not in_closure and not from_closure:
return
elif not in_closure:
func_scope.is_passthrough = True
func_scope.scope_class = cscope.scope_class
node.needs_outer_scope = True
return
as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname)
entry = target_module_scope.declare_c_class(name = as_name,
pos = node.pos, defining = True, implementing = True)
......@@ -1338,34 +1397,41 @@ class CreateClosureClasses(CythonTransform):
class_scope.is_internal = True
class_scope.directives = {'final': True}
cscope = node.entry.scope
while cscope.is_py_class_scope or cscope.is_c_class_scope:
cscope = cscope.outer_scope
if cscope.is_closure_scope:
if from_closure:
assert cscope.is_closure_scope
class_scope.declare_var(pos=node.pos,
name=Naming.outer_scope_cname, # this could conflict?
name=Naming.outer_scope_cname,
cname=Naming.outer_scope_cname,
type=cscope.scope_class.type,
is_cdef=True)
entries = func_scope.entries.items()
entries.sort()
for name, entry in entries:
# This is wasteful--we should do this later when we know
# which vars are actually being used inside...
#
# Also, this happens before type inference and type
# analysis, so the entries created here may end up having
# incorrect or at least unspecified types.
node.needs_outer_scope = True
for name, entry in in_closure:
class_scope.declare_var(pos=entry.pos,
name=entry.name,
cname=entry.cname,
type=entry.type,
is_cdef=True)
node.needs_closure = True
# Do it here because other classes are already checked
target_module_scope.check_c_class(func_scope.scope_class)
def visit_LambdaNode(self, node):
was_in_lambda = self.in_lambda
self.in_lambda = True
self.create_class_from_scope(node.def_node, self.module_scope, node)
self.visitchildren(node)
self.in_lambda = was_in_lambda
return node
def visit_FuncDefNode(self, node):
if node.needs_closure:
if self.in_lambda:
self.visitchildren(node)
return node
if node.needs_closure or self.path:
self.create_class_from_scope(node, self.module_scope)
self.path.append(node)
self.visitchildren(node)
self.path.pop()
return node
......
......@@ -75,6 +75,7 @@ class Entry(object):
# is_cfunction boolean Is a C function
# is_cmethod boolean Is a C method of an extension type
# is_unbound_cmethod boolean Is an unbound C method of an extension type
# is_lambda boolean Is a lambda function
# is_type boolean Is a type definition
# is_cclass boolean Is an extension class
# is_cpp_class boolean Is a C++ class
......@@ -137,6 +138,7 @@ class Entry(object):
is_cfunction = 0
is_cmethod = 0
is_unbound_cmethod = 0
is_lambda = 0
is_type = 0
is_cclass = 0
is_cpp_class = 0
......@@ -211,7 +213,8 @@ class Scope(object):
# return_type PyrexType or None Return type of function owning scope
# is_py_class_scope boolean Is a Python class scope
# is_c_class_scope boolean Is an extension type scope
# is_closure_scope boolean
# is_closure_scope boolean Is a closure scope
# is_passthrough boolean Outer scope is passed directly
# is_cpp_class_scope boolean Is a C++ class scope
# is_property_scope boolean Is a extension type property scope
# scope_prefix string Disambiguator for C names
......@@ -228,6 +231,7 @@ class Scope(object):
is_py_class_scope = 0
is_c_class_scope = 0
is_closure_scope = 0
is_passthrough = 0
is_cpp_class_scope = 0
is_property_scope = 0
is_module_scope = 0
......@@ -528,7 +532,7 @@ class Scope(object):
entry.name = EncodedString(func_cname)
entry.func_cname = func_cname
entry.signature = pyfunction_signature
self.pyfunc_entries.append(entry)
entry.is_lambda = True
return entry
def add_lambda_def(self, def_node):
......@@ -1122,6 +1126,29 @@ class ModuleScope(Scope):
if not entry.type.scope:
error(entry.pos, "C class '%s' is declared but not defined" % entry.name)
def check_c_class(self, entry):
type = entry.type
name = entry.name
visibility = entry.visibility
# Check defined
if not type.scope:
error(entry.pos, "C class '%s' is declared but not defined" % name)
# Generate typeobj_cname
if visibility != 'extern' and not type.typeobj_cname:
type.typeobj_cname = self.mangle(Naming.typeobj_prefix, name)
## Generate typeptr_cname
#type.typeptr_cname = self.mangle(Naming.typeptr_prefix, name)
# Check C methods defined
if type.scope:
for method_entry in type.scope.cfunc_entries:
if not method_entry.is_inherited and not method_entry.func_cname:
error(method_entry.pos, "C method '%s' is declared but not defined" %
method_entry.name)
# Allocate vtable name if necessary
if type.vtabslot_cname:
#print "ModuleScope.check_c_classes: allocating vtable cname for", self ###
type.vtable_cname = self.mangle(Naming.vtable_prefix, entry.name)
def check_c_classes(self):
# Performs post-analysis checking and finishing up of extension types
# being implemented in this module. This is called only for the main
......@@ -1144,27 +1171,7 @@ class ModuleScope(Scope):
print("...entry %s %s" % (entry.name, entry))
print("......type = ", entry.type)
print("......visibility = ", entry.visibility)
type = entry.type
name = entry.name
visibility = entry.visibility
# Check defined
if not type.scope:
error(entry.pos, "C class '%s' is declared but not defined" % name)
# Generate typeobj_cname
if visibility != 'extern' and not type.typeobj_cname:
type.typeobj_cname = self.mangle(Naming.typeobj_prefix, name)
## Generate typeptr_cname
#type.typeptr_cname = self.mangle(Naming.typeptr_prefix, name)
# Check C methods defined
if type.scope:
for method_entry in type.scope.cfunc_entries:
if not method_entry.is_inherited and not method_entry.func_cname:
error(method_entry.pos, "C method '%s' is declared but not defined" %
method_entry.name)
# Allocate vtable name if necessary
if type.vtabslot_cname:
#print "ModuleScope.check_c_classes: allocating vtable cname for", self ###
type.vtable_cname = self.mangle(Naming.vtable_prefix, entry.name)
self.check_c_class(entry)
def check_c_functions(self):
# Performs post-analysis checking making sure all
......@@ -1253,6 +1260,8 @@ class LocalScope(Scope):
entry = Scope.lookup(self, name)
if entry is not None:
if entry.scope is not self and entry.scope.is_closure_scope:
if hasattr(entry.scope, "scope_class"):
raise InternalError, "lookup() after scope class created."
# The actual c fragment for the different scopes differs
# on the outside and inside, so we make a new entry
entry.in_closure = True
......@@ -1270,6 +1279,9 @@ class LocalScope(Scope):
for entry in self.entries.values():
if entry.from_closure:
cname = entry.outer_entry.cname
if self.is_passthrough:
entry.cname = cname
else:
if cname.startswith(Naming.cur_scope_cname):
cname = cname[len(Naming.cur_scope_cname)+2:]
entry.cname = "%s->%s" % (outer_scope_cname, cname)
......@@ -1277,20 +1289,19 @@ class LocalScope(Scope):
entry.original_cname = entry.cname
entry.cname = "%s->%s" % (Naming.cur_scope_cname, entry.cname)
class GeneratorExpressionScope(LocalScope):
class GeneratorExpressionScope(Scope):
"""Scope for generator expressions and comprehensions. As opposed
to generators, these can be easily inlined in some cases, so all
we really need is a scope that holds the loop variable(s).
"""
def __init__(self, outer_scope):
name = outer_scope.global_scope().next_id(Naming.genexpr_id_ref)
LocalScope.__init__(self, name, outer_scope)
Scope.__init__(self, name, outer_scope, outer_scope)
self.directives = outer_scope.directives
self.genexp_prefix = "%s%d%s" % (Naming.pyrex_prefix, len(name), name)
def mangle(self, prefix, name):
return '%s%s' % (self.genexp_prefix, self.outer_scope.mangle(self, prefix, name))
return '%s%s' % (self.genexp_prefix, self.parent_scope.mangle(self, prefix, name))
def declare_var(self, name, type, pos,
cname = None, visibility = 'private', is_cdef = True):
......@@ -1299,10 +1310,10 @@ class GeneratorExpressionScope(LocalScope):
outer_entry = self.outer_scope.lookup(name)
if outer_entry and outer_entry.is_variable:
type = outer_entry.type # may still be 'unspecified_type' !
# the outer scope needs to generate code for the variable, but
# the parent scope needs to generate code for the variable, but
# this scope must hold its name exclusively
cname = '%s%s' % (self.genexp_prefix, self.outer_scope.mangle(Naming.var_prefix, name))
entry = self.outer_scope.declare_var(None, type, pos, cname, visibility, is_cdef = True)
cname = '%s%s' % (self.genexp_prefix, self.parent_scope.mangle(Naming.var_prefix, name))
entry = self.parent_scope.declare_var(None, type, pos, cname, visibility, is_cdef = True)
self.entries[name] = entry
return entry
......
......@@ -225,8 +225,6 @@ class SimpleAssignmentTypeInferer(object):
for entry in scope.entries.values():
if entry.type is unspecified_type:
entry.type = py_object_type
if scope.is_closure_scope:
fix_closure_entries(scope)
return
dependancies_by_entry = {} # entry -> dependancies
......@@ -288,19 +286,6 @@ class SimpleAssignmentTypeInferer(object):
entry.type = py_object_type
if verbose:
message(entry.pos, "inferred '%s' to be of type '%s' (default)" % (entry.name, entry.type))
#if scope.is_closure_scope:
# fix_closure_entries(scope)
def fix_closure_entries(scope):
"""Temporary work-around to fix field types in the closure class
that were unknown at the time of creation and only determined
during type inference.
"""
closure_entries = scope.scope_class.type.scope.entries
for name, entry in scope.entries.iteritems():
if name in closure_entries:
closure_entry = closure_entries[name]
closure_entry.type = entry.type
def find_spanning_type(type1, type2):
if type1 is type2:
......
......@@ -141,6 +141,9 @@ class ResultRefNode(AtomicExprNode):
def infer_type(self, env):
if self.expression is not None:
return self.expression.infer_type(env)
if self.type is not None:
return self.type
assert False, "cannot infer type of ResultRefNode"
def may_be_none(self):
if not self.type.is_pyobject:
......
__version__ = "0.13"
__version__ = "0.13+"
# Void cython.* directives (for case insensitive operating systems).
from Cython.Shadow import *
......@@ -56,14 +56,13 @@ EXT_DEP_INCLUDES = [
VER_DEP_MODULES = {
# tests are excluded if 'CurrentPythonVersion OP VersionTuple', i.e.
# (2,4) : (operator.le, ...) excludes ... when PyVer <= 2.4.x
# (2,4) : (operator.lt, ...) excludes ... when PyVer < 2.4.x
(2,4) : (operator.lt, lambda x: x in ['run.extern_builtins_T258',
'run.builtin_sorted'
]),
(2,5) : (operator.lt, lambda x: x in ['run.any',
'run.all',
]),
(2,4) : (operator.le, lambda x: x in ['run.extern_builtins_T258'
]),
(2,4) : (operator.lt, lambda x: x in ['run.builtin_sorted'
]),
(2,6) : (operator.lt, lambda x: x in ['run.print_function',
'run.cython3',
]),
......
......@@ -84,7 +84,7 @@ else:
else:
scripts = ["cython.py", "cygdb.py"]
def compile_cython_modules(profile=False):
def compile_cython_modules(profile=False, compile_more=False, cython_with_refnanny=False):
source_root = os.path.abspath(os.path.dirname(__file__))
compiled_modules = ["Cython.Plex.Scanners",
"Cython.Plex.Actions",
......@@ -92,8 +92,20 @@ def compile_cython_modules(profile=False):
"Cython.Compiler.Parsing",
"Cython.Compiler.Visitor",
"Cython.Runtime.refnanny"]
extensions = []
if compile_more:
compiled_modules.extend([
"Cython.Compiler.ParseTreeTransforms",
"Cython.Compiler.Nodes",
"Cython.Compiler.ExprNodes",
"Cython.Compiler.ModuleNode",
"Cython.Compiler.Optimize",
])
defines = []
if cython_with_refnanny:
defines.append(('CYTHON_REFNANNY', '1'))
extensions = []
if sys.version_info[0] >= 3:
from Cython.Distutils import build_ext as build_ext_orig
for module in compiled_modules:
......@@ -105,8 +117,13 @@ def compile_cython_modules(profile=False):
dep_files = []
if os.path.exists(source_file + '.pxd'):
dep_files.append(source_file + '.pxd')
if '.refnanny' in module:
defines_for_module = []
else:
defines_for_module = defines
extensions.append(
Extension(module, sources = [pyx_source_file],
define_macros = defines_for_module,
depends = dep_files)
)
......@@ -181,8 +198,13 @@ def compile_cython_modules(profile=False):
if filename_encoding is None:
filename_encoding = sys.getdefaultencoding()
c_source_file = c_source_file.encode(filename_encoding)
if '.refnanny' in module:
defines_for_module = []
else:
defines_for_module = defines
extensions.append(
Extension(module, sources = [c_source_file])
Extension(module, sources = [c_source_file],
define_macros = defines_for_module)
)
else:
print("Compilation failed")
......@@ -204,10 +226,22 @@ cython_profile = '--cython-profile' in sys.argv
if cython_profile:
sys.argv.remove('--cython-profile')
try:
sys.argv.remove("--cython-compile-all")
cython_compile_more = True
except ValueError:
cython_compile_more = False
try:
sys.argv.remove("--cython-with-refnanny")
cython_with_refnanny = True
except ValueError:
cython_with_refnanny = False
try:
sys.argv.remove("--no-cython-compile")
except ValueError:
compile_cython_modules(cython_profile)
compile_cython_modules(cython_profile, cython_compile_more, cython_with_refnanny)
setup_args.update(setuptools_extra_args)
......
......@@ -7,7 +7,6 @@ numpy_ValueError_T172
unsignedbehaviour_T184
missing_baseclass_in_predecl_T262
cfunc_call_tuple_args_T408
cascaded_list_unpacking_T467
compile.cpp_operators
cpp_templated_ctypedef
cpp_structs
......@@ -17,6 +16,8 @@ function_as_method_T494
closure_inside_cdef_T554
ipow_crash_T562
pure_mode_cmethod_inheritance_T583
genexpr_iterable_lookup_T600
for_from_pyvar_loop_T601
# CPython regression tests that don't current work:
pyregr.test_threadsignals
......
......@@ -10,10 +10,10 @@ all_tests_run() is executed which does final validation.
>>> items.sort()
>>> for key, value in items:
... print('%s ; %s' % (key, value))
MyCdefClass.cpdef_method (line 76) ; >>> add_log("cpdef class method")
MyCdefClass.method (line 73) ; >>> add_log("cdef class method")
MyClass.method (line 62) ; >>> add_log("class method")
mycpdeffunc (line 49) ; >>> add_log("cpdef")
MyCdefClass.cpdef_method (line 77) ; >>> add_log("cpdef class method")
MyCdefClass.method (line 74) ; >>> add_log("cdef class method")
MyClass.method (line 63) ; >>> add_log("class method")
mycpdeffunc (line 50) ; >>> add_log("cpdef")
myfunc (line 40) ; >>> add_log("def")
"""
......@@ -39,6 +39,7 @@ def add_log(s):
def myfunc():
""">>> add_log("def")"""
x = lambda a:1 # no docstring here ...
def doc_without_test():
"""Some docs"""
......
......@@ -5,4 +5,13 @@ def test():
True
"""
cdef int x = 5
print bool(x)
return bool(x)
def test_bool_and_int():
"""
>>> test_bool_and_int()
1
"""
cdef int x = 5
cdef int b = bool(x)
return b
......@@ -7,26 +7,65 @@ def simple_parallel_assignment_from_call():
cdef int ai, bi
cdef long al, bl
cdef object ao, bo
cdef int side_effect_count = call_count
reset()
ai, bi = al, bl = ao, bo = c = d = [intval(1), intval(2)]
side_effect_count = call_count - side_effect_count
return side_effect_count, ao, bo, ai, bi, al, bl, c, d
return call_count, ao, bo, ai, bi, al, bl, c, d
def recursive_parallel_assignment_from_call():
def recursive_parallel_assignment_from_call_left():
"""
>>> recursive_parallel_assignment_from_call()
>>> recursive_parallel_assignment_from_call_left()
(3, 1, 2, 3, 1, 2, 3, (1, 2), 3, [(1, 2), 3])
"""
cdef int ai, bi, ci
cdef object ao, bo, co
cdef int side_effect_count = call_count
reset()
(ai, bi), ci = (ao, bo), co = t,o = d = [(intval(1), intval(2)), intval(3)]
side_effect_count = call_count - side_effect_count
return side_effect_count, ao, bo, co, ai, bi, ci, t, o, d
return call_count, ao, bo, co, ai, bi, ci, t, o, d
def recursive_parallel_assignment_from_call_right():
"""
>>> recursive_parallel_assignment_from_call_right()
(3, 1, 2, 3, 1, 2, 3, 1, (2, 3), [1, (2, 3)])
"""
cdef int ai, bi, ci
cdef object ao, bo, co
reset()
ai, (bi, ci) = ao, (bo, co) = o,t = d = [intval(1), (intval(2), intval(3))]
return call_count, ao, bo, co, ai, bi, ci, o, t, d
def recursive_parallel_assignment_from_call_left_reversed():
"""
>>> recursive_parallel_assignment_from_call_left_reversed()
(3, 1, 2, 3, 1, 2, 3, (1, 2), 3, [(1, 2), 3])
"""
cdef int ai, bi, ci
cdef object ao, bo, co
reset()
d = t,o = (ao, bo), co = (ai, bi), ci = [(intval(1), intval(2)), intval(3)]
return call_count, ao, bo, co, ai, bi, ci, t, o, d
def recursive_parallel_assignment_from_call_right_reversed():
"""
>>> recursive_parallel_assignment_from_call_right_reversed()
(3, 1, 2, 3, 1, 2, 3, 1, (2, 3), [1, (2, 3)])
"""
cdef int ai, bi, ci
cdef object ao, bo, co
reset()
d = o,t = ao, (bo, co) = ai, (bi, ci) = [intval(1), (intval(2), intval(3))]
return call_count, ao, bo, co, ai, bi, ci, o, t, d
cdef int call_count = 0
cdef int next_expected_arg = 1
cdef reset():
global call_count, next_expected_arg
call_count = 0
next_expected_arg = 1
cdef int intval(int x):
global call_count
cdef int intval(int x) except -1:
global call_count, next_expected_arg
call_count += 1
assert next_expected_arg == x, "calls not in source code order: expected %d, found %d" % (next_expected_arg, x)
next_expected_arg += 1
return x
__doc__ = u"""
>>> f = add_n(3)
>>> f(2)
5
>>> f = add_n(1000000)
>>> f(1000000), f(-1000000)
(2000000, 0)
>>> a(5)()
8
>>> local_x(1)(2)(4)
4 2 1
15
# this currently crashes Cython due to redefinition
#>>> x(1)(2)(4)
#15
>>> x2(1)(2)(4)
4 2 1
15
>>> inner_override(2,4)()
5
>>> reassign(4)(2)
3
>>> reassign_int(4)(2)
3
>>> reassign_int_int(4)(2)
3
>>> def py_twofuncs(x):
... def f(a):
... return g(x) + a
... def g(b):
... return x + b
... return f
>>> py_twofuncs(1)(2) == cy_twofuncs(1)(2)
True
>>> py_twofuncs(3)(5) == cy_twofuncs(3)(5)
True
>>> inner_funcs = more_inner_funcs(1)(2,4,8)
>>> inner_funcs[0](16), inner_funcs[1](32), inner_funcs[2](64)
(19, 37, 73)
>>> switch_funcs([1,2,3], [4,5,6], 0)([10])
[1, 2, 3, 10]
>>> switch_funcs([1,2,3], [4,5,6], 1)([10])
[4, 5, 6, 10]
>>> switch_funcs([1,2,3], [4,5,6], 2) is None
True
>>> call_ignore_func()
"""
cimport cython
def add_n(int n):
"""
>>> f = add_n(3)
>>> f(2)
5
>>> f = add_n(1000000)
>>> f(1000000), f(-1000000)
(2000000, 0)
"""
def f(int x):
return x+n
return f
def a(int x):
"""
>>> a(5)()
8
"""
def b():
def c():
return 3+x
......@@ -74,6 +27,11 @@ def a(int x):
return b
def local_x(int arg_x):
"""
>>> local_x(1)(2)(4)
4 2 1
15
"""
cdef int local_x = arg_x
def y(arg_y):
y = arg_y
......@@ -84,15 +42,23 @@ def local_x(int arg_x):
return z
return y
# currently crashes Cython due to name redefinitions (see local_x())
## def x(int x):
## def y(y):
## def z(long z):
## return 8+z+y+x
## return z
## return y
def x(int x):
"""
>>> x(1)(2)(4)
15
"""
def y(y):
def z(long z):
return 8+z+y+x
return z
return y
def x2(int x2):
"""
>>> x2(1)(2)(4)
4 2 1
15
"""
def y2(y2):
def z2(long z2):
print z2, y2, x2
......@@ -102,6 +68,10 @@ def x2(int x2):
def inner_override(a,b):
"""
>>> inner_override(2,4)()
5
"""
def f():
a = 1
return a+b
......@@ -109,18 +79,30 @@ def inner_override(a,b):
def reassign(x):
"""
>>> reassign(4)(2)
3
"""
def f(a):
return a+x
x = 1
return f
def reassign_int(x):
"""
>>> reassign_int(4)(2)
3
"""
def f(int a):
return a+x
x = 1
return f
def reassign_int_int(int x):
"""
>>> reassign_int_int(4)(2)
3
"""
def f(int a):
return a+x
x = 1
......@@ -128,6 +110,19 @@ def reassign_int_int(int x):
def cy_twofuncs(x):
"""
>>> def py_twofuncs(x):
... def f(a):
... return g(x) + a
... def g(b):
... return x + b
... return f
>>> py_twofuncs(1)(2) == cy_twofuncs(1)(2)
True
>>> py_twofuncs(3)(5) == cy_twofuncs(3)(5)
True
"""
def f(a):
return g(x) + a
def g(b):
......@@ -135,6 +130,14 @@ def cy_twofuncs(x):
return f
def switch_funcs(a, b, int ix):
"""
>>> switch_funcs([1,2,3], [4,5,6], 0)([10])
[1, 2, 3, 10]
>>> switch_funcs([1,2,3], [4,5,6], 1)([10])
[4, 5, 6, 10]
>>> switch_funcs([1,2,3], [4,5,6], 2) is None
True
"""
def f(x):
return a + x
def g(x):
......@@ -152,9 +155,17 @@ def ignore_func(x):
return None
def call_ignore_func():
"""
>>> call_ignore_func()
"""
ignore_func((1,2,3))
def more_inner_funcs(x):
"""
>>> inner_funcs = more_inner_funcs(1)(2,4,8)
>>> inner_funcs[0](16), inner_funcs[1](32), inner_funcs[2](64)
(19, 37, 73)
"""
# called with x==1
def f(a):
def g(b):
......@@ -175,3 +186,45 @@ def more_inner_funcs(x):
# called with (2,4,8)
return f(a_f), g(b_g), h(b_h)
return resolve
@cython.test_assert_path_exists("//DefNode//DefNode//DefNode//DefNode",
"//DefNode[@needs_outer_scope = False]", # deep_inner()
"//DefNode//DefNode//DefNode//DefNode[@needs_closure = False]", # h()
)
@cython.test_fail_if_path_exists("//DefNode//DefNode[@needs_outer_scope = False]")
def deep_inner():
"""
>>> deep_inner()()
2
"""
cdef int x = 1
def f():
def g():
def h():
return x+1
return h
return g()
return f()
@cython.test_assert_path_exists("//DefNode//DefNode//DefNode",
"//DefNode//DefNode//DefNode[@needs_outer_scope = False]", # a()
"//DefNode//DefNode//DefNode[@needs_closure = False]", # a(), g(), h()
)
@cython.test_fail_if_path_exists("//DefNode//DefNode//DefNode[@needs_closure = True]") # a(), g(), h()
def deep_inner_sibling():
"""
>>> deep_inner_sibling()()
2
"""
cdef int x = 1
def f():
def a():
return 1
def g():
return x+a()
def h():
return g()
return h
return f()
cdef extern from "cpp_namespaces_helper.h" namespace "A":
ctypedef int A_t
A_t A_func(A_t first, A_t)
cdef void f(A_t)
cdef extern from "cpp_namespaces_helper.h" namespace "outer":
int outer_value
......@@ -26,3 +27,9 @@ def test_nested():
print outer_value
print inner_value
def test_typedef(A_t a):
"""
>>> test_typedef(3)
3
"""
return a
......@@ -76,6 +76,26 @@ def list_comp():
assert x == 'abc' # don't leak in Py3 code
return result
module_level_lc = [ module_level_loopvar*2 for module_level_loopvar in range(4) ]
def list_comp_module_level():
"""
>>> module_level_lc
[0, 2, 4, 6]
>>> module_level_loopvar
Traceback (most recent call last):
NameError: name 'module_level_loopvar' is not defined
"""
module_level_list_genexp = list(module_level_genexp_loopvar*2 for module_level_genexp_loopvar in range(4))
def genexpr_module_level():
"""
>>> module_level_list_genexp
[0, 2, 4, 6]
>>> module_level_genexp_loopvar
Traceback (most recent call last):
NameError: name 'module_level_genexp_loopvar' is not defined
"""
def list_comp_unknown_type(l):
"""
>>> list_comp_unknown_type(range(5))
......
cdef unsigned long size2():
return 3
def for_from_plain_ulong():
"""
>>> for_from_plain_ulong()
0
1
2
"""
cdef object j = 0
for j from 0 <= j < size2():
print j
def for_in_plain_ulong():
"""
>>> for_in_plain_ulong()
0
1
2
"""
cdef object j = 0
for j in range(size2()):
print j
cdef extern from "for_from_pyvar_loop_T601_extern_def.h":
ctypedef unsigned long Ulong
cdef Ulong size():
return 3
def for_from_ctypedef_ulong():
"""
>>> for_from_ctypedef_ulong()
0
1
2
"""
cdef object j = 0
for j from 0 <= j < size():
print j
def for_in_ctypedef_ulong():
"""
>>> for_in_ctypedef_ulong()
0
1
2
"""
cdef object j = 0
for j in range(size()):
print j
cimport cython
@cython.test_assert_path_exists('//ComprehensionNode')
@cython.test_fail_if_path_exists('//SimpleCallNode')
def list_genexpr_iterable_lookup():
"""
>>> x = (0,1,2,3,4,5)
>>> [ x*2 for x in x if x % 2 == 0 ] # leaks in Py2 but finds the right 'x'
[0, 4, 8]
>>> list_genexpr_iterable_lookup()
[0, 4, 8]
"""
x = (0,1,2,3,4,5)
result = list( x*2 for x in x if x % 2 == 0 )
assert x == (0,1,2,3,4,5)
return result
@cython.test_assert_path_exists('//ComprehensionNode')
@cython.test_fail_if_path_exists('//SingleAssignmentNode//SimpleCallNode')
def genexpr_iterable_in_closure():
"""
>>> genexpr_iterable_in_closure()
[0, 4, 8]
"""
x = 'abc'
def f():
return x
result = list( x*2 for x in x if x % 2 == 0 )
assert x == 'abc' # don't leak in Py3 code
assert f() == 'abc' # don't leak in Py3 code
return result
cdef object executable, version_info
cdef long hexversion
from sys import *
def test_cdefed_objects():
"""
>>> ex, vi = test_cdefed_objects()
>>> assert ex is not None
>>> assert vi is not None
"""
return executable, version_info
def test_cdefed_cvalues():
"""
>>> hexver = test_cdefed_cvalues()
>>> assert hexver is not None
>>> assert hexver > 0x02020000
"""
return hexversion
def test_non_cdefed_names():
"""
>>> mod, pth = test_non_cdefed_names()
>>> assert mod is not None
>>> assert pth is not None
"""
return modules, path
......@@ -149,6 +149,46 @@ def return_typed_sum_squares_start(seq, int start):
return <int>sum((i*i for i in seq), start)
@cython.test_assert_path_exists('//ForInStatNode',
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists('//SimpleCallNode')
def return_sum_of_listcomp_consts_start(seq, int start):
"""
>>> sum([1 for i in range(10) if i > 3], -1)
5
>>> return_sum_of_listcomp_consts_start(range(10), -1)
5
>>> print(sum([1 for i in range(10000) if i > 3], 9))
10005
>>> print(return_sum_of_listcomp_consts_start(range(10000), 9))
10005
"""
return sum([1 for i in seq if i > 3], start)
@cython.test_assert_path_exists('//ForInStatNode',
"//InlinedGeneratorExpressionNode",
# the next test is for a deficiency
# (see InlinedGeneratorExpressionNode.coerce_to()),
# hope this breaks one day
"//CoerceFromPyTypeNode//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists('//SimpleCallNode')
def return_typed_sum_of_listcomp_consts_start(seq, int start):
"""
>>> sum([1 for i in range(10) if i > 3], -1)
5
>>> return_typed_sum_of_listcomp_consts_start(range(10), -1)
5
>>> print(sum([1 for i in range(10000) if i > 3], 9))
10005
>>> print(return_typed_sum_of_listcomp_consts_start(range(10000), 9))
10005
"""
return <int>sum([1 for i in seq if i > 3], start)
@cython.test_assert_path_exists(
'//ForInStatNode',
"//InlinedGeneratorExpressionNode")
......
# Module scope lambda functions
__doc__ = """
>>> pow2(16)
256
>>> with_closure(0)
0
>>> typed_lambda(1)(2)
3
>>> typed_lambda(1.5)(1.5)
2
>>> cdef_const_lambda()
123
>>> const_lambda()
321
"""
pow2 = lambda x: x * x
with_closure = lambda x:(lambda: x)()
typed_lambda = lambda int x : (lambda int y: x + y)
cdef int xxx = 123
cdef_const_lambda = lambda: xxx
yyy = 321
const_lambda = lambda: yyy
# cython: language_level=3
def list_comp_in_closure():
"""
>>> list_comp_in_closure()
[0, 4, 8]
"""
x = 'abc'
def f():
return x
result = [x*2 for x in range(5) if x % 2 == 0]
assert x == 'abc' # don't leak in Py3 code
assert f() == 'abc' # don't leak in Py3 code
return result
def pytyped_list_comp_in_closure():
"""
>>> pytyped_list_comp_in_closure()
[0, 4, 8]
"""
cdef object x
x = 'abc'
def f():
return x
result = [x*2 for x in range(5) if x % 2 == 0]
assert x == 'abc' # don't leak in Py3 code
assert f() == 'abc' # don't leak in Py3 code
return result
def pytyped_list_comp_in_closure_repeated():
"""
>>> pytyped_list_comp_in_closure_repeated()
[0, 4, 8]
"""
cdef object x
x = 'abc'
def f():
return x
for i in range(3):
result = [x*2 for x in range(5) if x % 2 == 0]
assert x == 'abc' # don't leak in Py3 code
assert f() == 'abc' # don't leak in Py3 code
return result
def genexpr_in_closure():
"""
>>> genexpr_in_closure()
[0, 4, 8]
"""
x = 'abc'
def f():
return x
result = list( x*2 for x in range(5) if x % 2 == 0 )
assert x == 'abc' # don't leak in Py3 code
assert f() == 'abc' # don't leak in Py3 code
return result
def pytyped_genexpr_in_closure():
"""
>>> pytyped_genexpr_in_closure()
[0, 4, 8]
"""
cdef object x
x = 'abc'
def f():
return x
result = list( x*2 for x in range(5) if x % 2 == 0 )
assert x == 'abc' # don't leak in Py3 code
assert f() == 'abc' # don't leak in Py3 code
return result
def pytyped_genexpr_in_closure_repeated():
"""
>>> pytyped_genexpr_in_closure_repeated()
[0, 4, 8]
"""
cdef object x
x = 'abc'
def f():
return x
for i in range(3):
result = list( x*2 for x in range(5) if x % 2 == 0 )
assert x == 'abc' # don't leak in Py3 code
assert f() == 'abc' # don't leak in Py3 code
return result
def genexpr_scope_in_closure():
"""
>>> genexpr_scope_in_closure()
[0, 4, 8]
"""
i = 2
x = 'abc'
def f():
return i, x
result = list( x*i for x in range(5) if x % 2 == 0 )
assert x == 'abc' # don't leak in Py3 code
assert f() == (2,'abc') # don't leak in Py3 code
return result
# tests copied from test/test_bool.py in Py2.7
cdef assertEqual(a,b):
assert a == b, '%r != %r' % (a,b)
cdef assertIs(a,b):
assert a is b, '%r is not %r' % (a,b)
cdef assertIsNot(a,b):
assert a is not b, '%r is %r' % (a,b)
cdef assertNotIsInstance(a,b):
assert not isinstance(a,b), 'isinstance(%r, %s)' % (a,b)
def test_int():
"""
>>> test_int()
"""
assertEqual(int(False), 0)
assertIsNot(int(False), False)
assertEqual(int(True), 1)
assertIsNot(int(True), True)
def test_float():
"""
>>> test_float()
"""
assertEqual(float(False), 0.0)
assertIsNot(float(False), False)
assertEqual(float(True), 1.0)
assertIsNot(float(True), True)
def test_repr():
"""
>>> test_repr()
"""
assertEqual(repr(False), 'False')
assertEqual(repr(True), 'True')
assertEqual(eval(repr(False)), False)
assertEqual(eval(repr(True)), True)
def test_str():
"""
>>> test_str()
"""
assertEqual(str(False), 'False')
assertEqual(str(True), 'True')
def test_math():
"""
>>> test_math()
"""
assertEqual(+False, 0)
assertIsNot(+False, False)
assertEqual(-False, 0)
assertIsNot(-False, False)
assertEqual(abs(False), 0)
assertIsNot(abs(False), False)
assertEqual(+True, 1)
assertIsNot(+True, True)
assertEqual(-True, -1)
assertEqual(abs(True), 1)
assertIsNot(abs(True), True)
assertEqual(~False, -1)
assertEqual(~True, -2)
assertEqual(False+2, 2)
assertEqual(True+2, 3)
assertEqual(2+False, 2)
assertEqual(2+True, 3)
assertEqual(False+False, 0)
assertIsNot(False+False, False)
assertEqual(False+True, 1)
assertIsNot(False+True, True)
assertEqual(True+False, 1)
assertIsNot(True+False, True)
assertEqual(True+True, 2)
assertEqual(True-True, 0)
assertIsNot(True-True, False)
assertEqual(False-False, 0)
assertIsNot(False-False, False)
assertEqual(True-False, 1)
assertIsNot(True-False, True)
assertEqual(False-True, -1)
assertEqual(True*1, 1)
assertEqual(False*1, 0)
assertIsNot(False*1, False)
assertEqual(True/1, 1)
assertIsNot(True/1, True)
assertEqual(False/1, 0)
assertIsNot(False/1, False)
for b in False, True:
for i in 0, 1, 2:
assertEqual(b**i, int(b)**i)
assertIsNot(b**i, bool(int(b)**i))
for a in False, True:
for b in False, True:
assertIs(a&b, bool(int(a)&int(b)))
assertIs(a|b, bool(int(a)|int(b)))
assertIs(a^b, bool(int(a)^int(b)))
assertEqual(a&int(b), int(a)&int(b))
assertIsNot(a&int(b), bool(int(a)&int(b)))
assertEqual(a|int(b), int(a)|int(b))
assertIsNot(a|int(b), bool(int(a)|int(b)))
assertEqual(a^int(b), int(a)^int(b))
assertIsNot(a^int(b), bool(int(a)^int(b)))
assertEqual(int(a)&b, int(a)&int(b))
assertIsNot(int(a)&b, bool(int(a)&int(b)))
assertEqual(int(a)|b, int(a)|int(b))
assertIsNot(int(a)|b, bool(int(a)|int(b)))
assertEqual(int(a)^b, int(a)^int(b))
assertIsNot(int(a)^b, bool(int(a)^int(b)))
assertIs(1==1, True)
assertIs(1==0, False)
assertIs(0<1, True)
assertIs(1<0, False)
assertIs(0<=0, True)
assertIs(1<=0, False)
assertIs(1>0, True)
assertIs(1>1, False)
assertIs(1>=1, True)
assertIs(0>=1, False)
assertIs(0!=1, True)
assertIs(0!=0, False)
x = [1]
assertIs(x is x, True)
assertIs(x is not x, False)
assertIs(1 in x, True)
assertIs(0 in x, False)
assertIs(1 not in x, False)
assertIs(0 not in x, True)
x = {1: 2}
assertIs(x is x, True)
assertIs(x is not x, False)
assertIs(1 in x, True)
assertIs(0 in x, False)
assertIs(1 not in x, False)
assertIs(0 not in x, True)
assertIs(not True, False)
assertIs(not False, True)
def test_convert():
"""
>>> test_convert()
"""
assertIs(bool(10), True)
assertIs(bool(1), True)
assertIs(bool(-1), True)
assertIs(bool(0), False)
assertIs(bool("hello"), True)
assertIs(bool(""), False)
assertIs(bool(), False)
def test_isinstance():
"""
>>> test_isinstance()
"""
assertIs(isinstance(True, bool), True)
assertIs(isinstance(False, bool), True)
assertIs(isinstance(True, int), True)
assertIs(isinstance(False, int), True)
assertIs(isinstance(1, bool), False)
assertIs(isinstance(0, bool), False)
def test_issubclass():
"""
>>> test_issubclass()
"""
assertIs(issubclass(bool, int), True)
assertIs(issubclass(int, bool), False)
def test_boolean():
"""
>>> test_boolean()
"""
assertEqual(True & 1, 1)
assertNotIsInstance(True & 1, bool)
assertIs(True & True, True)
assertEqual(True | 1, 1)
assertNotIsInstance(True | 1, bool)
assertIs(True | True, True)
assertEqual(True ^ 1, 0)
assertNotIsInstance(True ^ 1, bool)
assertIs(True ^ True, False)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment