Commit fae8606e authored by Mark Florisson's avatar Mark Florisson

branch merge

parents 5e9d7562 16055821
...@@ -57,6 +57,11 @@ class AutoTestDictTransform(ScopeTrackingTransform): ...@@ -57,6 +57,11 @@ class AutoTestDictTransform(ScopeTrackingTransform):
value = UnicodeNode(pos, value=doctest) value = UnicodeNode(pos, value=doctest)
self.tests.append(DictItemNode(pos, key=key, value=value)) self.tests.append(DictItemNode(pos, key=key, value=value))
def visit_ExprNode(self, node):
# expressions cannot contain functions and lambda expressions
# do not have a docstring
return node
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
if not node.doc: if not node.doc:
return node return node
......
...@@ -7,6 +7,8 @@ from Errors import CompileError ...@@ -7,6 +7,8 @@ from Errors import CompileError
from Code import UtilityCode from Code import UtilityCode
import Interpreter import Interpreter
import PyrexTypes import PyrexTypes
import Naming
import Symtab
try: try:
set set
......
...@@ -1228,7 +1228,7 @@ class CCodeWriter(object): ...@@ -1228,7 +1228,7 @@ class CCodeWriter(object):
def put_var_decref(self, entry): def put_var_decref(self, entry):
if entry.type.is_pyobject: if entry.type.is_pyobject:
if entry.init_to_none is False: if entry.init_to_none is False: # FIXME: 0 and False are treated differently???
self.putln("__Pyx_XDECREF(%s);" % self.entry_as_pyobject(entry)) self.putln("__Pyx_XDECREF(%s);" % self.entry_as_pyobject(entry))
else: else:
self.putln("__Pyx_DECREF(%s);" % self.entry_as_pyobject(entry)) self.putln("__Pyx_DECREF(%s);" % self.entry_as_pyobject(entry))
......
...@@ -205,3 +205,11 @@ def release_errors(ignore=False): ...@@ -205,3 +205,11 @@ def release_errors(ignore=False):
def held_errors(): def held_errors():
return error_stack[-1] return error_stack[-1]
# this module needs a redesign to support parallel cythonisation, but
# for now, the following works at least in sequential compiler runs
def reset():
_warn_once_seen.clear()
del error_stack[:]
...@@ -2,9 +2,19 @@ ...@@ -2,9 +2,19 @@
# Pyrex - Parse tree nodes for expressions # Pyrex - Parse tree nodes for expressions
# #
import cython
from cython import set
cython.declare(error=object, warning=object, warn_once=object, InternalError=object,
CompileError=object, UtilityCode=object, StringEncoding=object, operator=object,
Naming=object, Nodes=object, PyrexTypes=object, py_object_type=object,
list_type=object, tuple_type=object, set_type=object, dict_type=object, \
unicode_type=object, str_type=object, bytes_type=object, type_type=object,
Builtin=object, Symtab=object, Utils=object, find_coercion_error=object,
debug_disposal_code=object, debug_temp_alloc=object, debug_coercion=object)
import operator import operator
from Errors import error, warning, warn_once, InternalError from Errors import error, warning, warn_once, InternalError, CompileError
from Errors import hold_errors, release_errors, held_errors, report_error from Errors import hold_errors, release_errors, held_errors, report_error
from Code import UtilityCode from Code import UtilityCode
import StringEncoding import StringEncoding
...@@ -21,16 +31,15 @@ import Symtab ...@@ -21,16 +31,15 @@ import Symtab
import Options import Options
from Cython import Utils from Cython import Utils
from Annotate import AnnotationItem from Annotate import AnnotationItem
from Cython import Utils
from Cython.Debugging import print_call_chain from Cython.Debugging import print_call_chain
from DebugFlags import debug_disposal_code, debug_temp_alloc, \ from DebugFlags import debug_disposal_code, debug_temp_alloc, \
debug_coercion debug_coercion
try: try:
set from __builtin__ import basestring
except NameError: except ImportError:
from sets import Set as set basestring = str # Python 3
class NotConstant(object): class NotConstant(object):
def __repr__(self): def __repr__(self):
...@@ -202,8 +211,9 @@ class ExprNode(Node): ...@@ -202,8 +211,9 @@ class ExprNode(Node):
_get_child_attrs = operator.attrgetter('subexprs') _get_child_attrs = operator.attrgetter('subexprs')
except AttributeError: except AttributeError:
# Python 2.3 # Python 2.3
def _get_child_attrs(self): def __get_child_attrs(self):
return self.subexprs return self.subexprs
_get_child_attrs = __get_child_attrs
child_attrs = property(fget=_get_child_attrs) child_attrs = property(fget=_get_child_attrs)
def not_implemented(self, method_name): def not_implemented(self, method_name):
...@@ -1760,6 +1770,9 @@ class IteratorNode(ExprNode): ...@@ -1760,6 +1770,9 @@ class IteratorNode(ExprNode):
self.type = self.sequence.type self.type = self.sequence.type
else: else:
self.sequence = self.sequence.coerce_to_pyobject(env) self.sequence = self.sequence.coerce_to_pyobject(env)
if self.sequence.type is list_type or \
self.sequence.type is tuple_type:
self.sequence = self.sequence.as_none_safe_node("'NoneType' object is not iterable")
self.is_temp = 1 self.is_temp = 1
gil_message = "Iterating over Python object" gil_message = "Iterating over Python object"
...@@ -1776,36 +1789,30 @@ class IteratorNode(ExprNode): ...@@ -1776,36 +1789,30 @@ class IteratorNode(ExprNode):
raise InternalError("for in carray slice not transformed") raise InternalError("for in carray slice not transformed")
is_builtin_sequence = self.sequence.type is list_type or \ is_builtin_sequence = self.sequence.type is list_type or \
self.sequence.type is tuple_type self.sequence.type is tuple_type
may_be_a_sequence = is_builtin_sequence or not self.sequence.type.is_builtin_type may_be_a_sequence = not self.sequence.type.is_builtin_type
if is_builtin_sequence: if may_be_a_sequence:
code.putln(
"if (likely(%s != Py_None)) {" % self.sequence.py_result())
elif may_be_a_sequence:
code.putln( code.putln(
"if (PyList_CheckExact(%s) || PyTuple_CheckExact(%s)) {" % ( "if (PyList_CheckExact(%s) || PyTuple_CheckExact(%s)) {" % (
self.sequence.py_result(), self.sequence.py_result(),
self.sequence.py_result())) self.sequence.py_result()))
if may_be_a_sequence: if is_builtin_sequence or may_be_a_sequence:
code.putln( code.putln(
"%s = 0; %s = %s; __Pyx_INCREF(%s);" % ( "%s = 0; %s = %s; __Pyx_INCREF(%s);" % (
self.counter_cname, self.counter_cname,
self.result(), self.result(),
self.sequence.py_result(), self.sequence.py_result(),
self.result())) self.result()))
code.putln("} else {") if not is_builtin_sequence:
if is_builtin_sequence: if may_be_a_sequence:
code.putln( code.putln("} else {")
'PyErr_SetString(PyExc_TypeError, "\'NoneType\' object is not iterable"); %s' %
code.error_goto(self.pos))
else:
code.putln("%s = -1; %s = PyObject_GetIter(%s); %s" % ( code.putln("%s = -1; %s = PyObject_GetIter(%s); %s" % (
self.counter_cname, self.counter_cname,
self.result(), self.result(),
self.sequence.py_result(), self.sequence.py_result(),
code.error_goto_if_null(self.result(), self.pos))) code.error_goto_if_null(self.result(), self.pos)))
code.put_gotref(self.py_result()) code.put_gotref(self.py_result())
if may_be_a_sequence: if may_be_a_sequence:
code.putln("}") code.putln("}")
class NextNode(AtomicExprNode): class NextNode(AtomicExprNode):
...@@ -2981,7 +2988,7 @@ class SimpleCallNode(CallNode): ...@@ -2981,7 +2988,7 @@ class SimpleCallNode(CallNode):
return "<error>" return "<error>"
formal_args = func_type.args formal_args = func_type.args
arg_list_code = [] arg_list_code = []
args = zip(formal_args, self.args) args = list(zip(formal_args, self.args))
max_nargs = len(func_type.args) max_nargs = len(func_type.args)
expected_nargs = max_nargs - func_type.optional_arg_count expected_nargs = max_nargs - func_type.optional_arg_count
actual_nargs = len(self.args) actual_nargs = len(self.args)
...@@ -3026,7 +3033,7 @@ class SimpleCallNode(CallNode): ...@@ -3026,7 +3033,7 @@ class SimpleCallNode(CallNode):
self.opt_arg_struct, self.opt_arg_struct,
Naming.pyrex_prefix + "n", Naming.pyrex_prefix + "n",
len(self.args) - expected_nargs)) len(self.args) - expected_nargs))
args = zip(func_type.args, self.args) args = list(zip(func_type.args, self.args))
for formal_arg, actual_arg in args[expected_nargs:actual_nargs]: for formal_arg, actual_arg in args[expected_nargs:actual_nargs]:
code.putln("%s.%s = %s;" % ( code.putln("%s.%s = %s;" % (
self.opt_arg_struct, self.opt_arg_struct,
...@@ -3153,7 +3160,7 @@ class GeneralCallNode(CallNode): ...@@ -3153,7 +3160,7 @@ class GeneralCallNode(CallNode):
def explicit_args_kwds(self): def explicit_args_kwds(self):
if self.starstar_arg or not isinstance(self.positional_args, TupleNode): if self.starstar_arg or not isinstance(self.positional_args, TupleNode):
raise PostParseError(self.pos, raise CompileError(self.pos,
'Compile-time keyword arguments must be explicit.') 'Compile-time keyword arguments must be explicit.')
return self.positional_args.args, self.keyword_args return self.positional_args.args, self.keyword_args
...@@ -3613,7 +3620,7 @@ class AttributeNode(ExprNode): ...@@ -3613,7 +3620,7 @@ class AttributeNode(ExprNode):
interned_attr_cname = code.intern_identifier(self.attribute) interned_attr_cname = code.intern_identifier(self.attribute)
self.obj.generate_evaluation_code(code) self.obj.generate_evaluation_code(code)
if self.is_py_attr or (isinstance(self.entry.scope, Symtab.PropertyScope) if self.is_py_attr or (isinstance(self.entry.scope, Symtab.PropertyScope)
and self.entry.scope.entries.has_key(u'__del__')): and u'__del__' in self.entry.scope.entries):
code.put_error_if_neg(self.pos, code.put_error_if_neg(self.pos,
'PyObject_DelAttr(%s, %s)' % ( 'PyObject_DelAttr(%s, %s)' % (
self.obj.py_result(), self.obj.py_result(),
...@@ -4125,46 +4132,49 @@ class ScopedExprNode(ExprNode): ...@@ -4125,46 +4132,49 @@ class ScopedExprNode(ExprNode):
subexprs = [] subexprs = []
expr_scope = None expr_scope = None
def analyse_types(self, env): # does this node really have a local scope, e.g. does it leak loop
# nothing to do here, the children will be analysed separately # variables or not? non-leaking Py3 behaviour is default, except
# for list comprehensions where the behaviour differs in Py2 and
# Py3 (set in Parsing.py based on parser context)
has_local_scope = True
def init_scope(self, outer_scope, expr_scope=None):
if expr_scope is not None:
self.expr_scope = expr_scope
elif self.has_local_scope:
self.expr_scope = Symtab.GeneratorExpressionScope(outer_scope)
else:
self.expr_scope = None
def analyse_declarations(self, env):
self.init_scope(env)
def analyse_scoped_declarations(self, env):
# this is called with the expr_scope as env
pass pass
def analyse_expressions(self, env): def analyse_types(self, env):
# nothing to do here, the children will be analysed separately # no recursion here, the children will be analysed separately below
pass pass
def analyse_scoped_expressions(self, env): def analyse_scoped_expressions(self, env):
# this is called with the expr_scope as env # this is called with the expr_scope as env
pass pass
def init_scope(self, outer_scope, expr_scope=None):
self.expr_scope = expr_scope
class ComprehensionNode(ScopedExprNode): class ComprehensionNode(ScopedExprNode):
subexprs = ["target"] subexprs = ["target"]
child_attrs = ["loop", "append"] child_attrs = ["loop", "append"]
# leak loop variables or not? non-leaking Py3 behaviour is
# default, except for list comprehensions where the behaviour
# differs in Py2 and Py3 (see Parsing.py)
has_local_scope = True
def infer_type(self, env): def infer_type(self, env):
return self.target.infer_type(env) return self.target.infer_type(env)
def analyse_declarations(self, env): def analyse_declarations(self, env):
self.append.target = self # this is used in the PyList_Append of the inner loop self.append.target = self # this is used in the PyList_Append of the inner loop
self.init_scope(env) self.init_scope(env)
self.loop.analyse_declarations(self.expr_scope or env)
def init_scope(self, outer_scope, expr_scope=None): def analyse_scoped_declarations(self, env):
if expr_scope is not None: self.loop.analyse_declarations(env)
self.expr_scope = expr_scope
elif self.has_local_scope:
self.expr_scope = Symtab.GeneratorExpressionScope(outer_scope)
else:
self.expr_scope = None
def analyse_types(self, env): def analyse_types(self, env):
self.target.analyse_expressions(env) self.target.analyse_expressions(env)
...@@ -4172,9 +4182,6 @@ class ComprehensionNode(ScopedExprNode): ...@@ -4172,9 +4182,6 @@ class ComprehensionNode(ScopedExprNode):
if not self.has_local_scope: if not self.has_local_scope:
self.loop.analyse_expressions(env) self.loop.analyse_expressions(env)
def analyse_expressions(self, env):
self.analyse_types(env)
def analyse_scoped_expressions(self, env): def analyse_scoped_expressions(self, env):
if self.has_local_scope: if self.has_local_scope:
self.loop.analyse_expressions(env) self.loop.analyse_expressions(env)
...@@ -4276,21 +4283,17 @@ class GeneratorExpressionNode(ScopedExprNode): ...@@ -4276,21 +4283,17 @@ class GeneratorExpressionNode(ScopedExprNode):
type = py_object_type type = py_object_type
def analyse_declarations(self, env): def analyse_scoped_declarations(self, env):
self.init_scope(env) self.loop.analyse_declarations(env)
self.loop.analyse_declarations(self.expr_scope)
def init_scope(self, outer_scope, expr_scope=None):
if expr_scope is not None:
self.expr_scope = expr_scope
else:
self.expr_scope = Symtab.GeneratorExpressionScope(outer_scope)
def analyse_types(self, env): def analyse_types(self, env):
if not self.has_local_scope:
self.loop.analyse_expressions(env)
self.is_temp = True self.is_temp = True
def analyse_scoped_expressions(self, env): def analyse_scoped_expressions(self, env):
self.loop.analyse_expressions(env) if self.has_local_scope:
self.loop.analyse_expressions(env)
def may_be_none(self): def may_be_none(self):
return False return False
...@@ -4310,14 +4313,29 @@ class InlinedGeneratorExpressionNode(GeneratorExpressionNode): ...@@ -4310,14 +4313,29 @@ class InlinedGeneratorExpressionNode(GeneratorExpressionNode):
# orig_func String the name of the builtin function this node replaces # orig_func String the name of the builtin function this node replaces
child_attrs = ["loop"] child_attrs = ["loop"]
loop_analysed = False
def infer_type(self, env):
return self.result_node.infer_type(env)
def analyse_types(self, env): def analyse_types(self, env):
if not self.has_local_scope:
self.loop_analysed = True
self.loop.analyse_expressions(env)
self.type = self.result_node.type self.type = self.result_node.type
self.is_temp = True self.is_temp = True
def analyse_scoped_expressions(self, env):
self.loop_analysed = True
GeneratorExpressionNode.analyse_scoped_expressions(self, env)
def coerce_to(self, dst_type, env): def coerce_to(self, dst_type, env):
if self.orig_func == 'sum' and dst_type.is_numeric: if self.orig_func == 'sum' and dst_type.is_numeric and not self.loop_analysed:
# we can optimise by dropping the aggregation variable into C # We can optimise by dropping the aggregation variable and
# the add operations into C. This can only be done safely
# before analysing the loop body, after that, the result
# reference type will have infected expressions and
# assignments.
self.result_node.type = self.type = dst_type self.result_node.type = self.type = dst_type
return self return self
return GeneratorExpressionNode.coerce_to(self, dst_type, env) return GeneratorExpressionNode.coerce_to(self, dst_type, env)
...@@ -4838,10 +4856,14 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin): ...@@ -4838,10 +4856,14 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
class InnerFunctionNode(PyCFunctionNode): class InnerFunctionNode(PyCFunctionNode):
# Special PyCFunctionNode that depends on a closure class # Special PyCFunctionNode that depends on a closure class
# #
binding = True binding = True
needs_self_code = True
def self_result_code(self): def self_result_code(self):
return "((PyObject*)%s)" % Naming.cur_scope_cname if self.needs_self_code:
return "((PyObject*)%s)" % (Naming.cur_scope_cname)
return "NULL"
class LambdaNode(InnerFunctionNode): class LambdaNode(InnerFunctionNode):
# Lambda expression node (only used as a function reference) # Lambda expression node (only used as a function reference)
...@@ -4859,7 +4881,6 @@ class LambdaNode(InnerFunctionNode): ...@@ -4859,7 +4881,6 @@ class LambdaNode(InnerFunctionNode):
name = StringEncoding.EncodedString('<lambda>') name = StringEncoding.EncodedString('<lambda>')
def analyse_declarations(self, env): def analyse_declarations(self, env):
#self.def_node.needs_closure = self.needs_closure
self.def_node.analyse_declarations(env) self.def_node.analyse_declarations(env)
self.pymethdef_cname = self.def_node.entry.pymethdef_cname self.pymethdef_cname = self.def_node.entry.pymethdef_cname
env.add_lambda_def(self.def_node) env.add_lambda_def(self.def_node)
...@@ -5651,7 +5672,12 @@ class NumBinopNode(BinopNode): ...@@ -5651,7 +5672,12 @@ class NumBinopNode(BinopNode):
def compute_c_result_type(self, type1, type2): def compute_c_result_type(self, type1, type2):
if self.c_types_okay(type1, type2): if self.c_types_okay(type1, type2):
return PyrexTypes.widest_numeric_type(type1, type2) widest_type = PyrexTypes.widest_numeric_type(type1, type2)
if widest_type is PyrexTypes.c_bint_type:
if self.operator not in '|^&':
# False + False == 0 # not False!
widest_type = PyrexTypes.c_int_type
return widest_type
else: else:
return None return None
......
...@@ -10,5 +10,6 @@ unicode_literals = _get_feature("unicode_literals") ...@@ -10,5 +10,6 @@ unicode_literals = _get_feature("unicode_literals")
with_statement = _get_feature("with_statement") with_statement = _get_feature("with_statement")
division = _get_feature("division") division = _get_feature("division")
print_function = _get_feature("print_function") print_function = _get_feature("print_function")
nested_scopes = _get_feature("nested_scopes") # dummy
del _get_feature del _get_feature
...@@ -138,7 +138,6 @@ class Context(object): ...@@ -138,7 +138,6 @@ class Context(object):
WithTransform(self), WithTransform(self),
DecoratorTransform(self), DecoratorTransform(self),
AnalyseDeclarationsTransform(self), AnalyseDeclarationsTransform(self),
CreateClosureClasses(self),
AutoTestDictTransform(self), AutoTestDictTransform(self),
EmbedSignature(self), EmbedSignature(self),
EarlyReplaceBuiltinCalls(self), ## Necessary? EarlyReplaceBuiltinCalls(self), ## Necessary?
...@@ -148,6 +147,7 @@ class Context(object): ...@@ -148,6 +147,7 @@ class Context(object):
IntroduceBufferAuxiliaryVars(self), IntroduceBufferAuxiliaryVars(self),
_check_c_declarations, _check_c_declarations,
AnalyseExpressionsTransform(self), AnalyseExpressionsTransform(self),
CreateClosureClasses(self), ## After all lookups and type inference
ExpandInplaceOperators(self), ExpandInplaceOperators(self),
OptimizeBuiltinCalls(self), ## Necessary? OptimizeBuiltinCalls(self), ## Necessary?
IterationTransform(), IterationTransform(),
...@@ -523,6 +523,7 @@ class Context(object): ...@@ -523,6 +523,7 @@ class Context(object):
return ".".join(names) return ".".join(names)
def setup_errors(self, options, result): def setup_errors(self, options, result):
Errors.reset() # clear any remaining error state
if options.use_listing_file: if options.use_listing_file:
result.listing_file = Utils.replace_suffix(source, ".lis") result.listing_file = Utils.replace_suffix(source, ".lis")
path = result.listing_file path = result.listing_file
......
...@@ -2,15 +2,16 @@ ...@@ -2,15 +2,16 @@
# Pyrex - Module parse tree node # Pyrex - Module parse tree node
# #
import cython
from cython import set
cython.declare(Naming=object, Options=object, PyrexTypes=object, TypeSlots=object,
error=object, warning=object, py_object_type=object, UtilityCode=object,
escape_byte_string=object, EncodedString=object)
import os, time import os, time
from PyrexTypes import CPtrType from PyrexTypes import CPtrType
import Future import Future
try:
set
except NameError: # Python 2.3
from sets import Set as set
import Annotate import Annotate
import Code import Code
import Naming import Naming
...@@ -666,6 +667,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -666,6 +667,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("#include <math.h>") code.putln("#include <math.h>")
code.putln("#define %s" % Naming.api_guard_prefix + self.api_name(env)) code.putln("#define %s" % Naming.api_guard_prefix + self.api_name(env))
self.generate_includes(env, cimported_modules, code) self.generate_includes(env, cimported_modules, code)
code.putln("")
code.putln("#ifdef PYREX_WITHOUT_ASSERTIONS")
code.putln("#define CYTHON_WITHOUT_ASSERTIONS")
code.putln("#endif")
code.putln("")
if env.directives['ccomplex']: if env.directives['ccomplex']:
code.putln("") code.putln("")
code.putln("#if !defined(CYTHON_CCOMPLEX)") code.putln("#if !defined(CYTHON_CCOMPLEX)")
...@@ -1673,20 +1679,20 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1673,20 +1679,20 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("if (!(%s)) %s;" % ( code.putln("if (!(%s)) %s;" % (
entry.type.type_test_code("o"), entry.type.type_test_code("o"),
code.error_goto(entry.pos))) code.error_goto(entry.pos)))
code.put_var_decref(entry) code.putln("Py_INCREF(o);")
code.put_decref(entry.cname, entry.type, nanny=False)
code.putln("%s = %s;" % ( code.putln("%s = %s;" % (
entry.cname, entry.cname,
PyrexTypes.typecast(entry.type, py_object_type, "o"))) PyrexTypes.typecast(entry.type, py_object_type, "o")))
elif entry.type.from_py_function: elif entry.type.from_py_function:
rhs = "%s(o)" % entry.type.from_py_function rhs = "%s(o)" % entry.type.from_py_function
if entry.type.is_enum: if entry.type.is_enum:
rhs = typecast(entry.type, c_long_type, rhs) rhs = PyrexTypes.typecast(entry.type, PyrexTypes.c_long_type, rhs)
code.putln("%s = %s; if (%s) %s;" % ( code.putln("%s = %s; if (%s) %s;" % (
entry.cname, entry.cname,
rhs, rhs,
entry.type.error_condition(entry.cname), entry.type.error_condition(entry.cname),
code.error_goto(entry.pos))) code.error_goto(entry.pos)))
code.putln("Py_DECREF(o);")
else: else:
code.putln('PyErr_Format(PyExc_TypeError, "Cannot convert Python object %s to %s");' % (name, entry.type)) code.putln('PyErr_Format(PyExc_TypeError, "Cannot convert Python object %s to %s");' % (name, entry.type))
code.putln(code.error_goto(entry.pos)) code.putln(code.error_goto(entry.pos))
...@@ -1695,12 +1701,12 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1695,12 +1701,12 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("if (PyObject_SetAttr(%s, py_name, o) < 0) goto bad;" % Naming.module_cname) code.putln("if (PyObject_SetAttr(%s, py_name, o) < 0) goto bad;" % Naming.module_cname)
code.putln("}") code.putln("}")
code.putln("return 0;") code.putln("return 0;")
code.put_label(code.error_label) if code.label_used(code.error_label):
# This helps locate the offending name. code.put_label(code.error_label)
code.putln('__Pyx_AddTraceback("%s");' % self.full_module_name); # This helps locate the offending name.
code.putln('__Pyx_AddTraceback("%s");' % self.full_module_name);
code.error_label = old_error_label code.error_label = old_error_label
code.putln("bad:") code.putln("bad:")
code.putln("Py_DECREF(o);")
code.putln("return -1;") code.putln("return -1;")
code.putln("}") code.putln("}")
code.putln(import_star_utility_code) code.putln(import_star_utility_code)
......
...@@ -3,15 +3,17 @@ ...@@ -3,15 +3,17 @@
# Pyrex - Parse tree nodes # Pyrex - Parse tree nodes
# #
import sys, os, time, copy import cython
from cython import set
cython.declare(sys=object, os=object, time=object, copy=object,
Builtin=object, error=object, warning=object, Naming=object, PyrexTypes=object,
py_object_type=object, ModuleScope=object, LocalScope=object, ClosureScope=object, \
StructOrUnionScope=object, PyClassScope=object, CClassScope=object,
CppClassScope=object, UtilityCode=object, EncodedString=object,
absolute_path_length=cython.Py_ssize_t)
try: import sys, os, time, copy
set
except NameError:
# Python 2.3
from sets import Set as set
import Code
import Builtin import Builtin
from Errors import error, warning, InternalError from Errors import error, warning, InternalError
import Naming import Naming
...@@ -241,7 +243,7 @@ class Node(object): ...@@ -241,7 +243,7 @@ class Node(object):
if encountered is None: if encountered is None:
encountered = set() encountered = set()
if id(self) in encountered: if id(self) in encountered:
return "<%s (%d) -- already output>" % (self.__class__.__name__, id(self)) return "<%s (0x%x) -- already output>" % (self.__class__.__name__, id(self))
encountered.add(id(self)) encountered.add(id(self))
def dump_child(x, level): def dump_child(x, level):
...@@ -253,12 +255,12 @@ class Node(object): ...@@ -253,12 +255,12 @@ class Node(object):
return repr(x) return repr(x)
attrs = [(key, value) for key, value in self.__dict__.iteritems() if key not in filter_out] attrs = [(key, value) for key, value in self.__dict__.items() if key not in filter_out]
if len(attrs) == 0: if len(attrs) == 0:
return "<%s (%d)>" % (self.__class__.__name__, id(self)) return "<%s (0x%x)>" % (self.__class__.__name__, id(self))
else: else:
indent = " " * level indent = " " * level
res = "<%s (%d)\n" % (self.__class__.__name__, id(self)) res = "<%s (0x%x)\n" % (self.__class__.__name__, id(self))
for key, value in attrs: for key, value in attrs:
res += "%s %s: %s\n" % (indent, key, dump_child(value, level + 1)) res += "%s %s: %s\n" % (indent, key, dump_child(value, level + 1))
res += "%s>" % indent res += "%s>" % indent
...@@ -858,7 +860,7 @@ class TemplatedTypeNode(CBaseTypeNode): ...@@ -858,7 +860,7 @@ class TemplatedTypeNode(CBaseTypeNode):
if sys.version_info[0] < 3: if sys.version_info[0] < 3:
# Py 2.x enforces byte strings as keyword arguments ... # Py 2.x enforces byte strings as keyword arguments ...
options = dict([ (name.encode('ASCII'), value) options = dict([ (name.encode('ASCII'), value)
for name, value in options.iteritems() ]) for name, value in options.items() ])
self.type = PyrexTypes.BufferType(base_type, **options) self.type = PyrexTypes.BufferType(base_type, **options)
...@@ -949,7 +951,7 @@ class CVarDefNode(StatNode): ...@@ -949,7 +951,7 @@ class CVarDefNode(StatNode):
entry.directive_locals = self.directive_locals entry.directive_locals = self.directive_locals
else: else:
if self.directive_locals: if self.directive_locals:
s.error("Decorators can only be followed by functions") error(self.pos, "Decorators can only be followed by functions")
if self.in_pxd and self.visibility != 'extern': if self.in_pxd and self.visibility != 'extern':
error(self.pos, error(self.pos,
"Only 'extern' C variable declaration allowed in .pxd file") "Only 'extern' C variable declaration allowed in .pxd file")
...@@ -1146,11 +1148,13 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1146,11 +1148,13 @@ class FuncDefNode(StatNode, BlockNode):
# #filename string C name of filename string const # #filename string C name of filename string const
# entry Symtab.Entry # entry Symtab.Entry
# needs_closure boolean Whether or not this function has inner functions/classes/yield # needs_closure boolean Whether or not this function has inner functions/classes/yield
# needs_outer_scope boolean Whether or not this function requires outer scope
# directive_locals { string : NameNode } locals defined by cython.locals(...) # directive_locals { string : NameNode } locals defined by cython.locals(...)
py_func = None py_func = None
assmt = None assmt = None
needs_closure = False needs_closure = False
needs_outer_scope = False
modifiers = [] modifiers = []
def analyse_default_values(self, env): def analyse_default_values(self, env):
...@@ -1198,7 +1202,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1198,7 +1202,7 @@ class FuncDefNode(StatNode, BlockNode):
import Buffer import Buffer
lenv = self.local_scope lenv = self.local_scope
if lenv.is_closure_scope: if lenv.is_closure_scope and not lenv.is_passthrough:
outer_scope_cname = "%s->%s" % (Naming.cur_scope_cname, outer_scope_cname = "%s->%s" % (Naming.cur_scope_cname,
Naming.outer_scope_cname) Naming.outer_scope_cname)
else: else:
...@@ -1259,10 +1263,13 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1259,10 +1263,13 @@ class FuncDefNode(StatNode, BlockNode):
cenv = env cenv = env
while cenv.is_py_class_scope or cenv.is_c_class_scope: while cenv.is_py_class_scope or cenv.is_c_class_scope:
cenv = cenv.outer_scope cenv = cenv.outer_scope
if lenv.is_closure_scope: if self.needs_closure:
code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname)) code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname))
code.putln(";") code.putln(";")
elif cenv.is_closure_scope: elif self.needs_outer_scope:
if lenv.is_passthrough:
code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname))
code.putln(";")
code.put(cenv.scope_class.type.declaration_code(Naming.outer_scope_cname)) code.put(cenv.scope_class.type.declaration_code(Naming.outer_scope_cname))
code.putln(";") code.putln(";")
self.generate_argument_declarations(lenv, code) self.generate_argument_declarations(lenv, code)
...@@ -1314,12 +1321,14 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1314,12 +1321,14 @@ class FuncDefNode(StatNode, BlockNode):
code.putln("}") code.putln("}")
code.put_gotref(Naming.cur_scope_cname) code.put_gotref(Naming.cur_scope_cname)
# Note that it is unsafe to decref the scope at this point. # Note that it is unsafe to decref the scope at this point.
if cenv.is_closure_scope: if self.needs_outer_scope:
code.putln("%s = (%s)%s;" % ( code.putln("%s = (%s)%s;" % (
outer_scope_cname, outer_scope_cname,
cenv.scope_class.type.declaration_code(''), cenv.scope_class.type.declaration_code(''),
Naming.self_cname)) Naming.self_cname))
if self.needs_closure: if lenv.is_passthrough:
code.putln("%s = %s;" % (Naming.cur_scope_cname, outer_scope_cname));
elif self.needs_closure:
# inner closures own a reference to their outer parent # inner closures own a reference to their outer parent
code.put_incref(outer_scope_cname, cenv.scope_class.type) code.put_incref(outer_scope_cname, cenv.scope_class.type)
code.put_giveref(outer_scope_cname) code.put_giveref(outer_scope_cname)
...@@ -2206,6 +2215,8 @@ class DefNode(FuncDefNode): ...@@ -2206,6 +2215,8 @@ class DefNode(FuncDefNode):
def needs_assignment_synthesis(self, env, code=None): def needs_assignment_synthesis(self, env, code=None):
# Should enable for module level as well, that will require more testing... # Should enable for module level as well, that will require more testing...
if self.entry.is_lambda:
return True
if env.is_module_scope: if env.is_module_scope:
if code is None: if code is None:
return env.directives['binding'] return env.directives['binding']
...@@ -3208,7 +3219,7 @@ class CClassDefNode(ClassDefNode): ...@@ -3208,7 +3219,7 @@ class CClassDefNode(ClassDefNode):
api = self.api, api = self.api,
buffer_defaults = buffer_defaults) buffer_defaults = buffer_defaults)
if home_scope is not env and self.visibility == 'extern': if home_scope is not env and self.visibility == 'extern':
env.add_imported_entry(self.class_name, self.entry, pos) env.add_imported_entry(self.class_name, self.entry, self.pos)
self.scope = scope = self.entry.type.scope self.scope = scope = self.entry.type.scope
if scope is not None: if scope is not None:
scope.directives = env.directives scope.directives = env.directives
...@@ -3376,7 +3387,7 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -3376,7 +3387,7 @@ class SingleAssignmentNode(AssignmentNode):
if func_name in ['declare', 'typedef']: if func_name in ['declare', 'typedef']:
if len(args) > 2 or kwds is not None: if len(args) > 2 or kwds is not None:
error(rhs.pos, "Can only declare one type at a time.") error(self.rhs.pos, "Can only declare one type at a time.")
return return
type = args[0].analyse_as_type(env) type = args[0].analyse_as_type(env)
if type is None: if type is None:
...@@ -3407,7 +3418,7 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -3407,7 +3418,7 @@ class SingleAssignmentNode(AssignmentNode):
elif func_name in ['struct', 'union']: elif func_name in ['struct', 'union']:
self.declaration_only = True self.declaration_only = True
if len(args) > 0 or kwds is None: if len(args) > 0 or kwds is None:
error(rhs.pos, "Struct or union members must be given by name.") error(self.rhs.pos, "Struct or union members must be given by name.")
return return
members = [] members = []
for member, type_node in kwds.key_value_pairs: for member, type_node in kwds.key_value_pairs:
...@@ -3991,7 +4002,7 @@ class AssertStatNode(StatNode): ...@@ -3991,7 +4002,7 @@ class AssertStatNode(StatNode):
gil_message = "Raising exception" gil_message = "Raising exception"
def generate_execution_code(self, code): def generate_execution_code(self, code):
code.putln("#ifndef PYREX_WITHOUT_ASSERTIONS") code.putln("#ifndef CYTHON_WITHOUT_ASSERTIONS")
self.cond.generate_evaluation_code(code) self.cond.generate_evaluation_code(code)
code.putln( code.putln(
"if (unlikely(!%s)) {" % "if (unlikely(!%s)) {" %
......
import cython
from cython import set
cython.declare(UtilityCode=object, EncodedString=object, BytesLiteral=object,
Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object,
UtilNodes=object, Naming=object)
import Nodes import Nodes
import ExprNodes import ExprNodes
import PyrexTypes import PyrexTypes
...@@ -17,14 +24,14 @@ from ParseTreeTransforms import SkipDeclarations ...@@ -17,14 +24,14 @@ from ParseTreeTransforms import SkipDeclarations
import codecs import codecs
try: try:
reduce from __builtin__ import reduce
except NameError: except ImportError:
from functools import reduce from functools import reduce
try: try:
set from __builtin__ import basestring
except NameError: except ImportError:
from sets import Set as set basestring = str # Python 3
class FakePythonEnv(object): class FakePythonEnv(object):
"A fake environment for creating type test nodes etc." "A fake environment for creating type test nodes etc."
...@@ -749,7 +756,7 @@ class SwitchTransform(Visitor.VisitorTransform): ...@@ -749,7 +756,7 @@ class SwitchTransform(Visitor.VisitorTransform):
def extract_in_string_conditions(self, string_literal): def extract_in_string_conditions(self, string_literal):
if isinstance(string_literal, ExprNodes.UnicodeNode): if isinstance(string_literal, ExprNodes.UnicodeNode):
charvals = map(ord, set(string_literal.value)) charvals = list(map(ord, set(string_literal.value)))
charvals.sort() charvals.sort()
return [ ExprNodes.IntNode(string_literal.pos, value=str(charval), return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
constant_result=charval) constant_result=charval)
...@@ -1332,14 +1339,26 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1332,14 +1339,26 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
""" """
if len(pos_args) not in (1,2): if len(pos_args) not in (1,2):
return node return node
if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode,
ExprNodes.ComprehensionNode)):
return node return node
gen_expr_node = pos_args[0] gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node) if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
if yield_expression is None: yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
return node if yield_expression is None:
return node
else: # ComprehensionNode
yield_stat_node = gen_expr_node.append
yield_expression = yield_stat_node.expr
try:
if not yield_expression.is_literal or not yield_expression.type.is_int:
return node
except AttributeError:
return node # in case we don't have a type yet
# special case: old Py2 backwards compatible "sum([int_const for ...])"
# can safely be unpacked into a genexpr
if len(pos_args) == 1: if len(pos_args) == 1:
start = ExprNodes.IntNode(node.pos, value='0', constant_result=0) start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
...@@ -1368,7 +1387,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1368,7 +1387,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
return ExprNodes.InlinedGeneratorExpressionNode( return ExprNodes.InlinedGeneratorExpressionNode(
gen_expr_node.pos, loop = exec_code, result_node = result_ref, gen_expr_node.pos, loop = exec_code, result_node = result_ref,
expr_scope = gen_expr_node.expr_scope, orig_func = 'sum') expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
has_local_scope = gen_expr_node.has_local_scope)
def _handle_simple_function_min(self, node, pos_args): def _handle_simple_function_min(self, node, pos_args):
return self._optimise_min_max(node, pos_args, '<') return self._optimise_min_max(node, pos_args, '<')
...@@ -1383,7 +1403,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1383,7 +1403,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
# leave this to Python # leave this to Python
return node return node
cascaded_nodes = map(UtilNodes.ResultRefNode, args[1:]) cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:]))
last_result = args[0] last_result = args[0]
for arg_node in cascaded_nodes: for arg_node in cascaded_nodes:
...@@ -1827,7 +1847,7 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform): ...@@ -1827,7 +1847,7 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
# Note: this requires the float() function to be typed as # Note: this requires the float() function to be typed as
# returning a C 'double' # returning a C 'double'
if len(pos_args) == 0: if len(pos_args) == 0:
return ExprNode.FloatNode( return ExprNodes.FloatNode(
node, value="0.0", constant_result=0.0 node, value="0.0", constant_result=0.0
).coerce_to(Builtin.float_type, self.current_env()) ).coerce_to(Builtin.float_type, self.current_env())
elif len(pos_args) != 1: elif len(pos_args) != 1:
...@@ -1860,8 +1880,12 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform): ...@@ -1860,8 +1880,12 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
self._error_wrong_arg_count('bool', node, pos_args, '0 or 1') self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
return node return node
else: else:
return pos_args[0].coerce_to_boolean( # => !!<bint>(x) to make sure it's exactly 0 or 1
self.current_env()).coerce_to_pyobject(self.current_env()) operand = pos_args[0].coerce_to_boolean(self.current_env())
operand = ExprNodes.NotNode(node.pos, operand = operand)
operand = ExprNodes.NotNode(node.pos, operand = operand)
# coerce back to Python object as that's the result we are expecting
return operand.coerce_to_pyobject(self.current_env())
### builtin functions ### builtin functions
...@@ -2931,7 +2955,7 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -2931,7 +2955,7 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
# check if all children are constant # check if all children are constant
children = self.visitchildren(node) children = self.visitchildren(node)
for child_result in children.itervalues(): for child_result in children.values():
if type(child_result) is list: if type(child_result) is list:
for child in child_result: for child in child_result:
if getattr(child, 'constant_result', not_a_constant) is not_a_constant: if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
...@@ -2966,12 +2990,23 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -2966,12 +2990,23 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
self._calculate_const(node) self._calculate_const(node)
return node return node
def visit_UnaryMinusNode(self, node): def visit_UnopNode(self, node):
self._calculate_const(node) self._calculate_const(node)
if node.constant_result is ExprNodes.not_a_constant: if node.constant_result is ExprNodes.not_a_constant:
return node return node
if not node.operand.is_literal: if not node.operand.is_literal:
return node return node
if isinstance(node.operand, ExprNodes.BoolNode):
return ExprNodes.IntNode(node.pos, value = str(node.constant_result),
type = PyrexTypes.c_int_type,
constant_result = node.constant_result)
if node.operator == '+':
return self._handle_UnaryPlusNode(node)
elif node.operator == '-':
return self._handle_UnaryMinusNode(node)
return node
def _handle_UnaryMinusNode(self, node):
if isinstance(node.operand, ExprNodes.LongNode): if isinstance(node.operand, ExprNodes.LongNode):
return ExprNodes.LongNode(node.pos, value = '-' + node.operand.value, return ExprNodes.LongNode(node.pos, value = '-' + node.operand.value,
constant_result = node.constant_result) constant_result = node.constant_result)
...@@ -2988,10 +3023,7 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -2988,10 +3023,7 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
constant_result = node.constant_result) constant_result = node.constant_result)
return node return node
def visit_UnaryPlusNode(self, node): def _handle_UnaryPlusNode(self, node):
self._calculate_const(node)
if node.constant_result is ExprNodes.not_a_constant:
return node
if node.constant_result == node.operand.constant_result: if node.constant_result == node.operand.constant_result:
return node.operand return node.operand
return node return node
...@@ -3017,12 +3049,13 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -3017,12 +3049,13 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
return node return node
if isinstance(node.constant_result, float): if isinstance(node.constant_result, float):
return node return node
if not node.operand1.is_literal or not node.operand2.is_literal: operand1, operand2 = node.operand1, node.operand2
if not operand1.is_literal or not operand2.is_literal:
return node return node
# now inject a new constant node with the calculated value # now inject a new constant node with the calculated value
try: try:
type1, type2 = node.operand1.type, node.operand2.type type1, type2 = operand1.type, operand2.type
if type1 is None or type2 is None: if type1 is None or type2 is None:
return node return node
except AttributeError: except AttributeError:
...@@ -3032,14 +3065,14 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -3032,14 +3065,14 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
widest_type = PyrexTypes.widest_numeric_type(type1, type2) widest_type = PyrexTypes.widest_numeric_type(type1, type2)
else: else:
widest_type = PyrexTypes.py_object_type widest_type = PyrexTypes.py_object_type
target_class = self._widest_node_class(node.operand1, node.operand2) target_class = self._widest_node_class(operand1, operand2)
if target_class is None: if target_class is None:
return node return node
elif target_class is ExprNodes.IntNode: elif target_class is ExprNodes.IntNode:
unsigned = getattr(node.operand1, 'unsigned', '') and \ unsigned = getattr(operand1, 'unsigned', '') and \
getattr(node.operand2, 'unsigned', '') getattr(operand2, 'unsigned', '')
longness = "LL"[:max(len(getattr(node.operand1, 'longness', '')), longness = "LL"[:max(len(getattr(operand1, 'longness', '')),
len(getattr(node.operand2, 'longness', '')))] len(getattr(operand2, 'longness', '')))]
new_node = ExprNodes.IntNode(pos=node.pos, new_node = ExprNodes.IntNode(pos=node.pos,
unsigned = unsigned, longness = longness, unsigned = unsigned, longness = longness,
value = str(node.constant_result), value = str(node.constant_result),
......
cimport cython
from Cython.Compiler.Visitor cimport (
CythonTransform, VisitorTransform, TreeVisitor,
ScopeTrackingTransform, EnvTransform)
cdef class NameNodeCollector(TreeVisitor):
cdef list name_nodes
cdef class SkipDeclarations: # (object):
pass
cdef class NormalizeTree(CythonTransform):
cdef bint is_in_statlist
cdef bint is_in_expr
cpdef visit_StatNode(self, node, is_listcontainer=*)
cdef class PostParse(ScopeTrackingTransform):
cdef dict specialattribute_handlers
cdef size_t lambda_counter
cdef _visit_assignment_node(self, node, list expr_list)
#def eliminate_rhs_duplicates(list expr_list_list, list ref_node_sequence)
#def sort_common_subsequences(list items)
@cython.locals(starred_targets=Py_ssize_t, lhs_size=Py_ssize_t, rhs_size=Py_ssize_t)
cdef flatten_parallel_assignments(list input, list output)
cdef map_starred_assignment(list lhs_targets, list starred_assignments, list lhs_args, list rhs_args)
#class PxdPostParse(CythonTransform, SkipDeclarations):
#class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
#class WithTransform(CythonTransform, SkipDeclarations):
#class DecoratorTransform(CythonTransform, SkipDeclarations):
#class AnalyseDeclarationsTransform(CythonTransform):
cdef class AnalyseExpressionsTransform(CythonTransform):
pass
cdef class ExpandInplaceOperators(EnvTransform):
pass
cdef class AlignFunctionDefinitions(CythonTransform):
cdef dict directives
cdef scope
cdef class MarkClosureVisitor(CythonTransform):
cdef bint needs_closure
cdef class CreateClosureClasses(CythonTransform):
cdef list path
cdef bint in_lambda
cdef module_scope
cdef class GilCheck(VisitorTransform):
cdef list env_stack
cdef bint nogil
cdef class TransformBuiltinMethods(EnvTransform):
cdef visit_cython_attribute(self, node)
import cython
cython.declare(copy=object, ModuleNode=object, TreeFragment=object, TemplateTransform=object,
EncodedString=object, error=object, warning=object, PyrexTypes=object, Naming=object)
from Cython.Compiler.Visitor import VisitorTransform, TreeVisitor from Cython.Compiler.Visitor import VisitorTransform, TreeVisitor
from Cython.Compiler.Visitor import CythonTransform, EnvTransform, ScopeTrackingTransform from Cython.Compiler.Visitor import CythonTransform, EnvTransform, ScopeTrackingTransform
from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.ModuleNode import ModuleNode
...@@ -6,8 +11,9 @@ from Cython.Compiler.ExprNodes import * ...@@ -6,8 +11,9 @@ from Cython.Compiler.ExprNodes import *
from Cython.Compiler.UtilNodes import * from Cython.Compiler.UtilNodes import *
from Cython.Compiler.TreeFragment import TreeFragment, TemplateTransform from Cython.Compiler.TreeFragment import TreeFragment, TemplateTransform
from Cython.Compiler.StringEncoding import EncodedString from Cython.Compiler.StringEncoding import EncodedString
from Cython.Compiler.Errors import error, CompileError from Cython.Compiler.Errors import error, warning, CompileError
from Cython.Compiler import PyrexTypes from Cython.Compiler import PyrexTypes, Naming
try: try:
set set
...@@ -25,11 +31,12 @@ class NameNodeCollector(TreeVisitor): ...@@ -25,11 +31,12 @@ class NameNodeCollector(TreeVisitor):
super(NameNodeCollector, self).__init__() super(NameNodeCollector, self).__init__()
self.name_nodes = [] self.name_nodes = []
visit_Node = TreeVisitor.visitchildren
def visit_NameNode(self, node): def visit_NameNode(self, node):
self.name_nodes.append(node) self.name_nodes.append(node)
def visit_Node(self, node):
self._visitchildren(node, None)
class SkipDeclarations(object): class SkipDeclarations(object):
""" """
...@@ -180,9 +187,6 @@ class PostParse(ScopeTrackingTransform): ...@@ -180,9 +187,6 @@ class PostParse(ScopeTrackingTransform):
def visit_LambdaNode(self, node): def visit_LambdaNode(self, node):
# unpack a lambda expression into the corresponding DefNode # unpack a lambda expression into the corresponding DefNode
if self.scope_type != 'function':
error(node.pos,
"lambda functions are currently only supported in functions")
lambda_id = self.lambda_counter lambda_id = self.lambda_counter
self.lambda_counter += 1 self.lambda_counter += 1
node.lambda_name = EncodedString(u'lambda%d' % lambda_id) node.lambda_name = EncodedString(u'lambda%d' % lambda_id)
...@@ -244,8 +248,10 @@ class PostParse(ScopeTrackingTransform): ...@@ -244,8 +248,10 @@ class PostParse(ScopeTrackingTransform):
# Split parallel assignments (a,b = b,a) into separate partial # Split parallel assignments (a,b = b,a) into separate partial
# assignments that are executed rhs-first using temps. This # assignments that are executed rhs-first using temps. This
# optimisation is best applied before type analysis so that known # restructuring must be applied before type analysis so that known
# types on rhs and lhs can be matched directly. # types on rhs and lhs can be matched directly. It is required in
# the case that the types cannot be coerced to a Python type in
# order to assign from a tuple.
def visit_SingleAssignmentNode(self, node): def visit_SingleAssignmentNode(self, node):
self.visitchildren(node) self.visitchildren(node)
...@@ -300,7 +306,7 @@ def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence): ...@@ -300,7 +306,7 @@ def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
and appends them to ref_node_sequence. The input list is modified and appends them to ref_node_sequence. The input list is modified
in-place. in-place.
""" """
seen_nodes = set() seen_nodes = cython.set()
ref_nodes = {} ref_nodes = {}
def find_duplicates(node): def find_duplicates(node):
if node.is_literal or node.is_name: if node.is_literal or node.is_name:
...@@ -328,13 +334,13 @@ def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence): ...@@ -328,13 +334,13 @@ def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
if node in ref_nodes: if node in ref_nodes:
return ref_nodes[node] return ref_nodes[node]
elif node.is_sequence_constructor: elif node.is_sequence_constructor:
node.args = map(substitute_nodes, node.args) node.args = list(map(substitute_nodes, node.args))
return node return node
# replace nodes inside of the common subexpressions # replace nodes inside of the common subexpressions
for node in ref_nodes: for node in ref_nodes:
if node.is_sequence_constructor: if node.is_sequence_constructor:
node.args = map(substitute_nodes, node.args) node.args = list(map(substitute_nodes, node.args))
# replace common subexpressions on all rhs items # replace common subexpressions on all rhs items
for expr_list in expr_list_list: for expr_list in expr_list_list:
...@@ -342,10 +348,15 @@ def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence): ...@@ -342,10 +348,15 @@ def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
def sort_common_subsequences(items): def sort_common_subsequences(items):
"""Sort items/subsequences so that all items and subsequences that """Sort items/subsequences so that all items and subsequences that
an item contains appear before the item itself. This implies a an item contains appear before the item itself. This is needed
partial order, and the sort must be stable to preserve the because each rhs item must only be evaluated once, so its value
original order as much as possible, so we use a simple insertion must be evaluated first and then reused when packing sequences
sort. that contain it.
This implies a partial order, and the sort must be stable to
preserve the original order as much as possible, so we use a
simple insertion sort (which is very fast for short sequences, the
normal case in practice).
""" """
def contains(seq, x): def contains(seq, x):
for item in seq: for item in seq:
...@@ -358,8 +369,8 @@ def sort_common_subsequences(items): ...@@ -358,8 +369,8 @@ def sort_common_subsequences(items):
return b.is_sequence_constructor and contains(b.args, a) return b.is_sequence_constructor and contains(b.args, a)
for pos, item in enumerate(items): for pos, item in enumerate(items):
key = item[1] # the ResultRefNode which has already been injected into the sequences
new_pos = pos new_pos = pos
key = item[0]
for i in xrange(pos-1, -1, -1): for i in xrange(pos-1, -1, -1):
if lower_than(key, items[i][0]): if lower_than(key, items[i][0]):
new_pos = i new_pos = i
...@@ -566,16 +577,16 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -566,16 +577,16 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
'operator.comma' : c_binop_constructor(','), 'operator.comma' : c_binop_constructor(','),
} }
special_methods = set(['declare', 'union', 'struct', 'typedef', 'sizeof', special_methods = cython.set(['declare', 'union', 'struct', 'typedef', 'sizeof',
'cast', 'pointer', 'compiled', 'NULL'] 'cast', 'pointer', 'compiled', 'NULL'])
+ unop_method_nodes.keys()) special_methods.update(unop_method_nodes.keys())
def __init__(self, context, compilation_directive_defaults): def __init__(self, context, compilation_directive_defaults):
super(InterpretCompilerDirectives, self).__init__(context) super(InterpretCompilerDirectives, self).__init__(context)
self.compilation_directive_defaults = {} self.compilation_directive_defaults = {}
for key, value in compilation_directive_defaults.iteritems(): for key, value in compilation_directive_defaults.items():
self.compilation_directive_defaults[unicode(key)] = value self.compilation_directive_defaults[unicode(key)] = value
self.cython_module_names = set() self.cython_module_names = cython.set()
self.directive_names = {} self.directive_names = {}
def check_directive_scope(self, pos, directive, scope): def check_directive_scope(self, pos, directive, scope):
...@@ -589,7 +600,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -589,7 +600,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
# Set up processing and handle the cython: comments. # Set up processing and handle the cython: comments.
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
for key, value in node.directive_comments.iteritems(): for key, value in node.directive_comments.items():
if not self.check_directive_scope(node.pos, key, 'module'): if not self.check_directive_scope(node.pos, key, 'module'):
self.wrong_scope_error(node.pos, key, 'module') self.wrong_scope_error(node.pos, key, 'module')
del node.directive_comments[key] del node.directive_comments[key]
...@@ -1017,7 +1028,7 @@ property NAME: ...@@ -1017,7 +1028,7 @@ property NAME:
return node return node
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
self.seen_vars_stack.append(set()) self.seen_vars_stack.append(cython.set())
node.analyse_declarations(self.env_stack[-1]) node.analyse_declarations(self.env_stack[-1])
self.visitchildren(node) self.visitchildren(node)
self.seen_vars_stack.pop() self.seen_vars_stack.pop()
...@@ -1049,7 +1060,7 @@ property NAME: ...@@ -1049,7 +1060,7 @@ property NAME:
return node return node
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
self.seen_vars_stack.append(set()) self.seen_vars_stack.append(cython.set())
lenv = node.local_scope lenv = node.local_scope
node.body.analyse_control_flow(lenv) # this will be totally refactored node.body.analyse_control_flow(lenv) # this will be totally refactored
node.declare_arguments(lenv) node.declare_arguments(lenv)
...@@ -1068,15 +1079,18 @@ property NAME: ...@@ -1068,15 +1079,18 @@ property NAME:
return node return node
def visit_ScopedExprNode(self, node): def visit_ScopedExprNode(self, node):
node.analyse_declarations(self.env_stack[-1]) env = self.env_stack[-1]
node.analyse_declarations(env)
# the node may or may not have a local scope # the node may or may not have a local scope
if node.expr_scope: if node.has_local_scope:
self.seen_vars_stack.append(set(self.seen_vars_stack[-1])) self.seen_vars_stack.append(cython.set(self.seen_vars_stack[-1]))
self.env_stack.append(node.expr_scope) self.env_stack.append(node.expr_scope)
node.analyse_scoped_declarations(node.expr_scope)
self.visitchildren(node) self.visitchildren(node)
self.env_stack.pop() self.env_stack.pop()
self.seen_vars_stack.pop() self.seen_vars_stack.pop()
else: else:
node.analyse_scoped_declarations(env)
self.visitchildren(node) self.visitchildren(node)
return node return node
...@@ -1172,7 +1186,7 @@ class AnalyseExpressionsTransform(CythonTransform): ...@@ -1172,7 +1186,7 @@ class AnalyseExpressionsTransform(CythonTransform):
return node return node
def visit_ScopedExprNode(self, node): def visit_ScopedExprNode(self, node):
if node.expr_scope is not None: if node.has_local_scope:
node.expr_scope.infer_types() node.expr_scope.infer_types()
node.analyse_scoped_expressions(node.expr_scope) node.analyse_scoped_expressions(node.expr_scope)
self.visitchildren(node) self.visitchildren(node)
...@@ -1289,9 +1303,12 @@ class AlignFunctionDefinitions(CythonTransform): ...@@ -1289,9 +1303,12 @@ class AlignFunctionDefinitions(CythonTransform):
class MarkClosureVisitor(CythonTransform): class MarkClosureVisitor(CythonTransform):
needs_closure = False def visit_ModuleNode(self, node):
self.needs_closure = False
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
self.needs_closure = False self.needs_closure = False
self.visitchildren(node) self.visitchildren(node)
...@@ -1320,16 +1337,58 @@ class MarkClosureVisitor(CythonTransform): ...@@ -1320,16 +1337,58 @@ class MarkClosureVisitor(CythonTransform):
class CreateClosureClasses(CythonTransform): class CreateClosureClasses(CythonTransform):
# Output closure classes in module scope for all functions # Output closure classes in module scope for all functions
# that need it. # that really need it.
def __init__(self, context):
super(CreateClosureClasses, self).__init__(context)
self.path = []
self.in_lambda = False
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
self.module_scope = node.scope self.module_scope = node.scope
self.visitchildren(node) self.visitchildren(node)
return node return node
def create_class_from_scope(self, node, target_module_scope): def get_scope_use(self, node):
as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname) from_closure = []
in_closure = []
for name, entry in node.local_scope.entries.items():
if entry.from_closure:
from_closure.append((name, entry))
elif entry.in_closure and not entry.from_closure:
in_closure.append((name, entry))
return from_closure, in_closure
def create_class_from_scope(self, node, target_module_scope, inner_node=None):
from_closure, in_closure = self.get_scope_use(node)
in_closure.sort()
# Now from the begining
node.needs_closure = False
node.needs_outer_scope = False
func_scope = node.local_scope func_scope = node.local_scope
cscope = node.entry.scope
while cscope.is_py_class_scope or cscope.is_c_class_scope:
cscope = cscope.outer_scope
if not from_closure and (self.path or inner_node):
if not inner_node:
if not node.assmt:
raise InternalError, "DefNode does not have assignment node"
inner_node = node.assmt.rhs
inner_node.needs_self_code = False
node.needs_outer_scope = False
# Simple cases
if not in_closure and not from_closure:
return
elif not in_closure:
func_scope.is_passthrough = True
func_scope.scope_class = cscope.scope_class
node.needs_outer_scope = True
return
as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname)
entry = target_module_scope.declare_c_class(name = as_name, entry = target_module_scope.declare_c_class(name = as_name,
pos = node.pos, defining = True, implementing = True) pos = node.pos, defining = True, implementing = True)
...@@ -1338,34 +1397,41 @@ class CreateClosureClasses(CythonTransform): ...@@ -1338,34 +1397,41 @@ class CreateClosureClasses(CythonTransform):
class_scope.is_internal = True class_scope.is_internal = True
class_scope.directives = {'final': True} class_scope.directives = {'final': True}
cscope = node.entry.scope if from_closure:
while cscope.is_py_class_scope or cscope.is_c_class_scope: assert cscope.is_closure_scope
cscope = cscope.outer_scope
if cscope.is_closure_scope:
class_scope.declare_var(pos=node.pos, class_scope.declare_var(pos=node.pos,
name=Naming.outer_scope_cname, # this could conflict? name=Naming.outer_scope_cname,
cname=Naming.outer_scope_cname, cname=Naming.outer_scope_cname,
type=cscope.scope_class.type, type=cscope.scope_class.type,
is_cdef=True) is_cdef=True)
entries = func_scope.entries.items() node.needs_outer_scope = True
entries.sort() for name, entry in in_closure:
for name, entry in entries:
# This is wasteful--we should do this later when we know
# which vars are actually being used inside...
#
# Also, this happens before type inference and type
# analysis, so the entries created here may end up having
# incorrect or at least unspecified types.
class_scope.declare_var(pos=entry.pos, class_scope.declare_var(pos=entry.pos,
name=entry.name, name=entry.name,
cname=entry.cname, cname=entry.cname,
type=entry.type, type=entry.type,
is_cdef=True) is_cdef=True)
node.needs_closure = True
# Do it here because other classes are already checked
target_module_scope.check_c_class(func_scope.scope_class)
def visit_LambdaNode(self, node):
was_in_lambda = self.in_lambda
self.in_lambda = True
self.create_class_from_scope(node.def_node, self.module_scope, node)
self.visitchildren(node)
self.in_lambda = was_in_lambda
return node
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
if node.needs_closure: if self.in_lambda:
self.visitchildren(node)
return node
if node.needs_closure or self.path:
self.create_class_from_scope(node, self.module_scope) self.create_class_from_scope(node, self.module_scope)
self.path.append(node)
self.visitchildren(node) self.visitchildren(node)
self.path.pop()
return node return node
......
...@@ -75,6 +75,7 @@ class Entry(object): ...@@ -75,6 +75,7 @@ class Entry(object):
# is_cfunction boolean Is a C function # is_cfunction boolean Is a C function
# is_cmethod boolean Is a C method of an extension type # is_cmethod boolean Is a C method of an extension type
# is_unbound_cmethod boolean Is an unbound C method of an extension type # is_unbound_cmethod boolean Is an unbound C method of an extension type
# is_lambda boolean Is a lambda function
# is_type boolean Is a type definition # is_type boolean Is a type definition
# is_cclass boolean Is an extension class # is_cclass boolean Is an extension class
# is_cpp_class boolean Is a C++ class # is_cpp_class boolean Is a C++ class
...@@ -137,6 +138,7 @@ class Entry(object): ...@@ -137,6 +138,7 @@ class Entry(object):
is_cfunction = 0 is_cfunction = 0
is_cmethod = 0 is_cmethod = 0
is_unbound_cmethod = 0 is_unbound_cmethod = 0
is_lambda = 0
is_type = 0 is_type = 0
is_cclass = 0 is_cclass = 0
is_cpp_class = 0 is_cpp_class = 0
...@@ -211,7 +213,8 @@ class Scope(object): ...@@ -211,7 +213,8 @@ class Scope(object):
# return_type PyrexType or None Return type of function owning scope # return_type PyrexType or None Return type of function owning scope
# is_py_class_scope boolean Is a Python class scope # is_py_class_scope boolean Is a Python class scope
# is_c_class_scope boolean Is an extension type scope # is_c_class_scope boolean Is an extension type scope
# is_closure_scope boolean # is_closure_scope boolean Is a closure scope
# is_passthrough boolean Outer scope is passed directly
# is_cpp_class_scope boolean Is a C++ class scope # is_cpp_class_scope boolean Is a C++ class scope
# is_property_scope boolean Is a extension type property scope # is_property_scope boolean Is a extension type property scope
# scope_prefix string Disambiguator for C names # scope_prefix string Disambiguator for C names
...@@ -228,6 +231,7 @@ class Scope(object): ...@@ -228,6 +231,7 @@ class Scope(object):
is_py_class_scope = 0 is_py_class_scope = 0
is_c_class_scope = 0 is_c_class_scope = 0
is_closure_scope = 0 is_closure_scope = 0
is_passthrough = 0
is_cpp_class_scope = 0 is_cpp_class_scope = 0
is_property_scope = 0 is_property_scope = 0
is_module_scope = 0 is_module_scope = 0
...@@ -528,7 +532,7 @@ class Scope(object): ...@@ -528,7 +532,7 @@ class Scope(object):
entry.name = EncodedString(func_cname) entry.name = EncodedString(func_cname)
entry.func_cname = func_cname entry.func_cname = func_cname
entry.signature = pyfunction_signature entry.signature = pyfunction_signature
self.pyfunc_entries.append(entry) entry.is_lambda = True
return entry return entry
def add_lambda_def(self, def_node): def add_lambda_def(self, def_node):
...@@ -1121,7 +1125,30 @@ class ModuleScope(Scope): ...@@ -1121,7 +1125,30 @@ class ModuleScope(Scope):
# Check defined # Check defined
if not entry.type.scope: if not entry.type.scope:
error(entry.pos, "C class '%s' is declared but not defined" % entry.name) error(entry.pos, "C class '%s' is declared but not defined" % entry.name)
def check_c_class(self, entry):
type = entry.type
name = entry.name
visibility = entry.visibility
# Check defined
if not type.scope:
error(entry.pos, "C class '%s' is declared but not defined" % name)
# Generate typeobj_cname
if visibility != 'extern' and not type.typeobj_cname:
type.typeobj_cname = self.mangle(Naming.typeobj_prefix, name)
## Generate typeptr_cname
#type.typeptr_cname = self.mangle(Naming.typeptr_prefix, name)
# Check C methods defined
if type.scope:
for method_entry in type.scope.cfunc_entries:
if not method_entry.is_inherited and not method_entry.func_cname:
error(method_entry.pos, "C method '%s' is declared but not defined" %
method_entry.name)
# Allocate vtable name if necessary
if type.vtabslot_cname:
#print "ModuleScope.check_c_classes: allocating vtable cname for", self ###
type.vtable_cname = self.mangle(Naming.vtable_prefix, entry.name)
def check_c_classes(self): def check_c_classes(self):
# Performs post-analysis checking and finishing up of extension types # Performs post-analysis checking and finishing up of extension types
# being implemented in this module. This is called only for the main # being implemented in this module. This is called only for the main
...@@ -1144,28 +1171,8 @@ class ModuleScope(Scope): ...@@ -1144,28 +1171,8 @@ class ModuleScope(Scope):
print("...entry %s %s" % (entry.name, entry)) print("...entry %s %s" % (entry.name, entry))
print("......type = ", entry.type) print("......type = ", entry.type)
print("......visibility = ", entry.visibility) print("......visibility = ", entry.visibility)
type = entry.type self.check_c_class(entry)
name = entry.name
visibility = entry.visibility
# Check defined
if not type.scope:
error(entry.pos, "C class '%s' is declared but not defined" % name)
# Generate typeobj_cname
if visibility != 'extern' and not type.typeobj_cname:
type.typeobj_cname = self.mangle(Naming.typeobj_prefix, name)
## Generate typeptr_cname
#type.typeptr_cname = self.mangle(Naming.typeptr_prefix, name)
# Check C methods defined
if type.scope:
for method_entry in type.scope.cfunc_entries:
if not method_entry.is_inherited and not method_entry.func_cname:
error(method_entry.pos, "C method '%s' is declared but not defined" %
method_entry.name)
# Allocate vtable name if necessary
if type.vtabslot_cname:
#print "ModuleScope.check_c_classes: allocating vtable cname for", self ###
type.vtable_cname = self.mangle(Naming.vtable_prefix, entry.name)
def check_c_functions(self): def check_c_functions(self):
# Performs post-analysis checking making sure all # Performs post-analysis checking making sure all
# defined c functions are actually implemented. # defined c functions are actually implemented.
...@@ -1253,6 +1260,8 @@ class LocalScope(Scope): ...@@ -1253,6 +1260,8 @@ class LocalScope(Scope):
entry = Scope.lookup(self, name) entry = Scope.lookup(self, name)
if entry is not None: if entry is not None:
if entry.scope is not self and entry.scope.is_closure_scope: if entry.scope is not self and entry.scope.is_closure_scope:
if hasattr(entry.scope, "scope_class"):
raise InternalError, "lookup() after scope class created."
# The actual c fragment for the different scopes differs # The actual c fragment for the different scopes differs
# on the outside and inside, so we make a new entry # on the outside and inside, so we make a new entry
entry.in_closure = True entry.in_closure = True
...@@ -1270,27 +1279,29 @@ class LocalScope(Scope): ...@@ -1270,27 +1279,29 @@ class LocalScope(Scope):
for entry in self.entries.values(): for entry in self.entries.values():
if entry.from_closure: if entry.from_closure:
cname = entry.outer_entry.cname cname = entry.outer_entry.cname
if cname.startswith(Naming.cur_scope_cname): if self.is_passthrough:
cname = cname[len(Naming.cur_scope_cname)+2:] entry.cname = cname
entry.cname = "%s->%s" % (outer_scope_cname, cname) else:
if cname.startswith(Naming.cur_scope_cname):
cname = cname[len(Naming.cur_scope_cname)+2:]
entry.cname = "%s->%s" % (outer_scope_cname, cname)
elif entry.in_closure: elif entry.in_closure:
entry.original_cname = entry.cname entry.original_cname = entry.cname
entry.cname = "%s->%s" % (Naming.cur_scope_cname, entry.cname) entry.cname = "%s->%s" % (Naming.cur_scope_cname, entry.cname)
class GeneratorExpressionScope(Scope):
class GeneratorExpressionScope(LocalScope):
"""Scope for generator expressions and comprehensions. As opposed """Scope for generator expressions and comprehensions. As opposed
to generators, these can be easily inlined in some cases, so all to generators, these can be easily inlined in some cases, so all
we really need is a scope that holds the loop variable(s). we really need is a scope that holds the loop variable(s).
""" """
def __init__(self, outer_scope): def __init__(self, outer_scope):
name = outer_scope.global_scope().next_id(Naming.genexpr_id_ref) name = outer_scope.global_scope().next_id(Naming.genexpr_id_ref)
LocalScope.__init__(self, name, outer_scope) Scope.__init__(self, name, outer_scope, outer_scope)
self.directives = outer_scope.directives self.directives = outer_scope.directives
self.genexp_prefix = "%s%d%s" % (Naming.pyrex_prefix, len(name), name) self.genexp_prefix = "%s%d%s" % (Naming.pyrex_prefix, len(name), name)
def mangle(self, prefix, name): def mangle(self, prefix, name):
return '%s%s' % (self.genexp_prefix, self.outer_scope.mangle(self, prefix, name)) return '%s%s' % (self.genexp_prefix, self.parent_scope.mangle(self, prefix, name))
def declare_var(self, name, type, pos, def declare_var(self, name, type, pos,
cname = None, visibility = 'private', is_cdef = True): cname = None, visibility = 'private', is_cdef = True):
...@@ -1299,10 +1310,10 @@ class GeneratorExpressionScope(LocalScope): ...@@ -1299,10 +1310,10 @@ class GeneratorExpressionScope(LocalScope):
outer_entry = self.outer_scope.lookup(name) outer_entry = self.outer_scope.lookup(name)
if outer_entry and outer_entry.is_variable: if outer_entry and outer_entry.is_variable:
type = outer_entry.type # may still be 'unspecified_type' ! type = outer_entry.type # may still be 'unspecified_type' !
# the outer scope needs to generate code for the variable, but # the parent scope needs to generate code for the variable, but
# this scope must hold its name exclusively # this scope must hold its name exclusively
cname = '%s%s' % (self.genexp_prefix, self.outer_scope.mangle(Naming.var_prefix, name)) cname = '%s%s' % (self.genexp_prefix, self.parent_scope.mangle(Naming.var_prefix, name))
entry = self.outer_scope.declare_var(None, type, pos, cname, visibility, is_cdef = True) entry = self.parent_scope.declare_var(None, type, pos, cname, visibility, is_cdef = True)
self.entries[name] = entry self.entries[name] = entry
return entry return entry
......
...@@ -225,8 +225,6 @@ class SimpleAssignmentTypeInferer(object): ...@@ -225,8 +225,6 @@ class SimpleAssignmentTypeInferer(object):
for entry in scope.entries.values(): for entry in scope.entries.values():
if entry.type is unspecified_type: if entry.type is unspecified_type:
entry.type = py_object_type entry.type = py_object_type
if scope.is_closure_scope:
fix_closure_entries(scope)
return return
dependancies_by_entry = {} # entry -> dependancies dependancies_by_entry = {} # entry -> dependancies
...@@ -288,19 +286,6 @@ class SimpleAssignmentTypeInferer(object): ...@@ -288,19 +286,6 @@ class SimpleAssignmentTypeInferer(object):
entry.type = py_object_type entry.type = py_object_type
if verbose: if verbose:
message(entry.pos, "inferred '%s' to be of type '%s' (default)" % (entry.name, entry.type)) message(entry.pos, "inferred '%s' to be of type '%s' (default)" % (entry.name, entry.type))
#if scope.is_closure_scope:
# fix_closure_entries(scope)
def fix_closure_entries(scope):
"""Temporary work-around to fix field types in the closure class
that were unknown at the time of creation and only determined
during type inference.
"""
closure_entries = scope.scope_class.type.scope.entries
for name, entry in scope.entries.iteritems():
if name in closure_entries:
closure_entry = closure_entries[name]
closure_entry.type = entry.type
def find_spanning_type(type1, type2): def find_spanning_type(type1, type2):
if type1 is type2: if type1 is type2:
......
...@@ -141,6 +141,9 @@ class ResultRefNode(AtomicExprNode): ...@@ -141,6 +141,9 @@ class ResultRefNode(AtomicExprNode):
def infer_type(self, env): def infer_type(self, env):
if self.expression is not None: if self.expression is not None:
return self.expression.infer_type(env) return self.expression.infer_type(env)
if self.type is not None:
return self.type
assert False, "cannot infer type of ResultRefNode"
def may_be_none(self): def may_be_none(self):
if not self.type.is_pyobject: if not self.type.is_pyobject:
......
__version__ = "0.13" __version__ = "0.13+"
# Void cython.* directives (for case insensitive operating systems). # Void cython.* directives (for case insensitive operating systems).
from Cython.Shadow import * from Cython.Shadow import *
...@@ -56,14 +56,13 @@ EXT_DEP_INCLUDES = [ ...@@ -56,14 +56,13 @@ EXT_DEP_INCLUDES = [
VER_DEP_MODULES = { VER_DEP_MODULES = {
# tests are excluded if 'CurrentPythonVersion OP VersionTuple', i.e. # tests are excluded if 'CurrentPythonVersion OP VersionTuple', i.e.
# (2,4) : (operator.le, ...) excludes ... when PyVer <= 2.4.x # (2,4) : (operator.lt, ...) excludes ... when PyVer < 2.4.x
(2,4) : (operator.lt, lambda x: x in ['run.extern_builtins_T258',
'run.builtin_sorted'
]),
(2,5) : (operator.lt, lambda x: x in ['run.any', (2,5) : (operator.lt, lambda x: x in ['run.any',
'run.all', 'run.all',
]), ]),
(2,4) : (operator.le, lambda x: x in ['run.extern_builtins_T258'
]),
(2,4) : (operator.lt, lambda x: x in ['run.builtin_sorted'
]),
(2,6) : (operator.lt, lambda x: x in ['run.print_function', (2,6) : (operator.lt, lambda x: x in ['run.print_function',
'run.cython3', 'run.cython3',
]), ]),
......
...@@ -84,7 +84,7 @@ else: ...@@ -84,7 +84,7 @@ else:
else: else:
scripts = ["cython.py", "cygdb.py"] scripts = ["cython.py", "cygdb.py"]
def compile_cython_modules(profile=False): def compile_cython_modules(profile=False, compile_more=False, cython_with_refnanny=False):
source_root = os.path.abspath(os.path.dirname(__file__)) source_root = os.path.abspath(os.path.dirname(__file__))
compiled_modules = ["Cython.Plex.Scanners", compiled_modules = ["Cython.Plex.Scanners",
"Cython.Plex.Actions", "Cython.Plex.Actions",
...@@ -92,8 +92,20 @@ def compile_cython_modules(profile=False): ...@@ -92,8 +92,20 @@ def compile_cython_modules(profile=False):
"Cython.Compiler.Parsing", "Cython.Compiler.Parsing",
"Cython.Compiler.Visitor", "Cython.Compiler.Visitor",
"Cython.Runtime.refnanny"] "Cython.Runtime.refnanny"]
extensions = [] if compile_more:
compiled_modules.extend([
"Cython.Compiler.ParseTreeTransforms",
"Cython.Compiler.Nodes",
"Cython.Compiler.ExprNodes",
"Cython.Compiler.ModuleNode",
"Cython.Compiler.Optimize",
])
defines = []
if cython_with_refnanny:
defines.append(('CYTHON_REFNANNY', '1'))
extensions = []
if sys.version_info[0] >= 3: if sys.version_info[0] >= 3:
from Cython.Distutils import build_ext as build_ext_orig from Cython.Distutils import build_ext as build_ext_orig
for module in compiled_modules: for module in compiled_modules:
...@@ -105,8 +117,13 @@ def compile_cython_modules(profile=False): ...@@ -105,8 +117,13 @@ def compile_cython_modules(profile=False):
dep_files = [] dep_files = []
if os.path.exists(source_file + '.pxd'): if os.path.exists(source_file + '.pxd'):
dep_files.append(source_file + '.pxd') dep_files.append(source_file + '.pxd')
if '.refnanny' in module:
defines_for_module = []
else:
defines_for_module = defines
extensions.append( extensions.append(
Extension(module, sources = [pyx_source_file], Extension(module, sources = [pyx_source_file],
define_macros = defines_for_module,
depends = dep_files) depends = dep_files)
) )
...@@ -181,8 +198,13 @@ def compile_cython_modules(profile=False): ...@@ -181,8 +198,13 @@ def compile_cython_modules(profile=False):
if filename_encoding is None: if filename_encoding is None:
filename_encoding = sys.getdefaultencoding() filename_encoding = sys.getdefaultencoding()
c_source_file = c_source_file.encode(filename_encoding) c_source_file = c_source_file.encode(filename_encoding)
if '.refnanny' in module:
defines_for_module = []
else:
defines_for_module = defines
extensions.append( extensions.append(
Extension(module, sources = [c_source_file]) Extension(module, sources = [c_source_file],
define_macros = defines_for_module)
) )
else: else:
print("Compilation failed") print("Compilation failed")
...@@ -204,10 +226,22 @@ cython_profile = '--cython-profile' in sys.argv ...@@ -204,10 +226,22 @@ cython_profile = '--cython-profile' in sys.argv
if cython_profile: if cython_profile:
sys.argv.remove('--cython-profile') sys.argv.remove('--cython-profile')
try:
sys.argv.remove("--cython-compile-all")
cython_compile_more = True
except ValueError:
cython_compile_more = False
try:
sys.argv.remove("--cython-with-refnanny")
cython_with_refnanny = True
except ValueError:
cython_with_refnanny = False
try: try:
sys.argv.remove("--no-cython-compile") sys.argv.remove("--no-cython-compile")
except ValueError: except ValueError:
compile_cython_modules(cython_profile) compile_cython_modules(cython_profile, cython_compile_more, cython_with_refnanny)
setup_args.update(setuptools_extra_args) setup_args.update(setuptools_extra_args)
......
...@@ -7,7 +7,6 @@ numpy_ValueError_T172 ...@@ -7,7 +7,6 @@ numpy_ValueError_T172
unsignedbehaviour_T184 unsignedbehaviour_T184
missing_baseclass_in_predecl_T262 missing_baseclass_in_predecl_T262
cfunc_call_tuple_args_T408 cfunc_call_tuple_args_T408
cascaded_list_unpacking_T467
compile.cpp_operators compile.cpp_operators
cpp_templated_ctypedef cpp_templated_ctypedef
cpp_structs cpp_structs
...@@ -17,6 +16,8 @@ function_as_method_T494 ...@@ -17,6 +16,8 @@ function_as_method_T494
closure_inside_cdef_T554 closure_inside_cdef_T554
ipow_crash_T562 ipow_crash_T562
pure_mode_cmethod_inheritance_T583 pure_mode_cmethod_inheritance_T583
genexpr_iterable_lookup_T600
for_from_pyvar_loop_T601
# CPython regression tests that don't current work: # CPython regression tests that don't current work:
pyregr.test_threadsignals pyregr.test_threadsignals
......
...@@ -10,10 +10,10 @@ all_tests_run() is executed which does final validation. ...@@ -10,10 +10,10 @@ all_tests_run() is executed which does final validation.
>>> items.sort() >>> items.sort()
>>> for key, value in items: >>> for key, value in items:
... print('%s ; %s' % (key, value)) ... print('%s ; %s' % (key, value))
MyCdefClass.cpdef_method (line 76) ; >>> add_log("cpdef class method") MyCdefClass.cpdef_method (line 77) ; >>> add_log("cpdef class method")
MyCdefClass.method (line 73) ; >>> add_log("cdef class method") MyCdefClass.method (line 74) ; >>> add_log("cdef class method")
MyClass.method (line 62) ; >>> add_log("class method") MyClass.method (line 63) ; >>> add_log("class method")
mycpdeffunc (line 49) ; >>> add_log("cpdef") mycpdeffunc (line 50) ; >>> add_log("cpdef")
myfunc (line 40) ; >>> add_log("def") myfunc (line 40) ; >>> add_log("def")
""" """
...@@ -39,6 +39,7 @@ def add_log(s): ...@@ -39,6 +39,7 @@ def add_log(s):
def myfunc(): def myfunc():
""">>> add_log("def")""" """>>> add_log("def")"""
x = lambda a:1 # no docstring here ...
def doc_without_test(): def doc_without_test():
"""Some docs""" """Some docs"""
......
...@@ -5,4 +5,13 @@ def test(): ...@@ -5,4 +5,13 @@ def test():
True True
""" """
cdef int x = 5 cdef int x = 5
print bool(x) return bool(x)
def test_bool_and_int():
"""
>>> test_bool_and_int()
1
"""
cdef int x = 5
cdef int b = bool(x)
return b
...@@ -7,26 +7,65 @@ def simple_parallel_assignment_from_call(): ...@@ -7,26 +7,65 @@ def simple_parallel_assignment_from_call():
cdef int ai, bi cdef int ai, bi
cdef long al, bl cdef long al, bl
cdef object ao, bo cdef object ao, bo
cdef int side_effect_count = call_count reset()
ai, bi = al, bl = ao, bo = c = d = [intval(1), intval(2)] ai, bi = al, bl = ao, bo = c = d = [intval(1), intval(2)]
side_effect_count = call_count - side_effect_count return call_count, ao, bo, ai, bi, al, bl, c, d
return side_effect_count, ao, bo, ai, bi, al, bl, c, d
def recursive_parallel_assignment_from_call(): def recursive_parallel_assignment_from_call_left():
""" """
>>> recursive_parallel_assignment_from_call() >>> recursive_parallel_assignment_from_call_left()
(3, 1, 2, 3, 1, 2, 3, (1, 2), 3, [(1, 2), 3]) (3, 1, 2, 3, 1, 2, 3, (1, 2), 3, [(1, 2), 3])
""" """
cdef int ai, bi, ci cdef int ai, bi, ci
cdef object ao, bo, co cdef object ao, bo, co
cdef int side_effect_count = call_count reset()
(ai, bi), ci = (ao, bo), co = t,o = d = [(intval(1), intval(2)), intval(3)] (ai, bi), ci = (ao, bo), co = t,o = d = [(intval(1), intval(2)), intval(3)]
side_effect_count = call_count - side_effect_count return call_count, ao, bo, co, ai, bi, ci, t, o, d
return side_effect_count, ao, bo, co, ai, bi, ci, t, o, d
def recursive_parallel_assignment_from_call_right():
"""
>>> recursive_parallel_assignment_from_call_right()
(3, 1, 2, 3, 1, 2, 3, 1, (2, 3), [1, (2, 3)])
"""
cdef int ai, bi, ci
cdef object ao, bo, co
reset()
ai, (bi, ci) = ao, (bo, co) = o,t = d = [intval(1), (intval(2), intval(3))]
return call_count, ao, bo, co, ai, bi, ci, o, t, d
def recursive_parallel_assignment_from_call_left_reversed():
"""
>>> recursive_parallel_assignment_from_call_left_reversed()
(3, 1, 2, 3, 1, 2, 3, (1, 2), 3, [(1, 2), 3])
"""
cdef int ai, bi, ci
cdef object ao, bo, co
reset()
d = t,o = (ao, bo), co = (ai, bi), ci = [(intval(1), intval(2)), intval(3)]
return call_count, ao, bo, co, ai, bi, ci, t, o, d
def recursive_parallel_assignment_from_call_right_reversed():
"""
>>> recursive_parallel_assignment_from_call_right_reversed()
(3, 1, 2, 3, 1, 2, 3, 1, (2, 3), [1, (2, 3)])
"""
cdef int ai, bi, ci
cdef object ao, bo, co
reset()
d = o,t = ao, (bo, co) = ai, (bi, ci) = [intval(1), (intval(2), intval(3))]
return call_count, ao, bo, co, ai, bi, ci, o, t, d
cdef int call_count = 0 cdef int call_count = 0
cdef int next_expected_arg = 1
cdef reset():
global call_count, next_expected_arg
call_count = 0
next_expected_arg = 1
cdef int intval(int x): cdef int intval(int x) except -1:
global call_count global call_count, next_expected_arg
call_count += 1 call_count += 1
assert next_expected_arg == x, "calls not in source code order: expected %d, found %d" % (next_expected_arg, x)
next_expected_arg += 1
return x return x
__doc__ = u"""
>>> f = add_n(3)
>>> f(2)
5
>>> f = add_n(1000000) cimport cython
>>> f(1000000), f(-1000000)
(2000000, 0)
>>> a(5)()
8
>>> local_x(1)(2)(4)
4 2 1
15
# this currently crashes Cython due to redefinition
#>>> x(1)(2)(4)
#15
>>> x2(1)(2)(4)
4 2 1
15
>>> inner_override(2,4)()
5
>>> reassign(4)(2)
3
>>> reassign_int(4)(2)
3
>>> reassign_int_int(4)(2)
3
>>> def py_twofuncs(x):
... def f(a):
... return g(x) + a
... def g(b):
... return x + b
... return f
>>> py_twofuncs(1)(2) == cy_twofuncs(1)(2)
True
>>> py_twofuncs(3)(5) == cy_twofuncs(3)(5)
True
>>> inner_funcs = more_inner_funcs(1)(2,4,8)
>>> inner_funcs[0](16), inner_funcs[1](32), inner_funcs[2](64)
(19, 37, 73)
>>> switch_funcs([1,2,3], [4,5,6], 0)([10])
[1, 2, 3, 10]
>>> switch_funcs([1,2,3], [4,5,6], 1)([10])
[4, 5, 6, 10]
>>> switch_funcs([1,2,3], [4,5,6], 2) is None
True
>>> call_ignore_func()
"""
def add_n(int n): def add_n(int n):
"""
>>> f = add_n(3)
>>> f(2)
5
>>> f = add_n(1000000)
>>> f(1000000), f(-1000000)
(2000000, 0)
"""
def f(int x): def f(int x):
return x+n return x+n
return f return f
def a(int x): def a(int x):
"""
>>> a(5)()
8
"""
def b(): def b():
def c(): def c():
return 3+x return 3+x
...@@ -74,6 +27,11 @@ def a(int x): ...@@ -74,6 +27,11 @@ def a(int x):
return b return b
def local_x(int arg_x): def local_x(int arg_x):
"""
>>> local_x(1)(2)(4)
4 2 1
15
"""
cdef int local_x = arg_x cdef int local_x = arg_x
def y(arg_y): def y(arg_y):
y = arg_y y = arg_y
...@@ -84,15 +42,23 @@ def local_x(int arg_x): ...@@ -84,15 +42,23 @@ def local_x(int arg_x):
return z return z
return y return y
# currently crashes Cython due to name redefinitions (see local_x()) def x(int x):
## def x(int x): """
## def y(y): >>> x(1)(2)(4)
## def z(long z): 15
## return 8+z+y+x """
## return z def y(y):
## return y def z(long z):
return 8+z+y+x
return z
return y
def x2(int x2): def x2(int x2):
"""
>>> x2(1)(2)(4)
4 2 1
15
"""
def y2(y2): def y2(y2):
def z2(long z2): def z2(long z2):
print z2, y2, x2 print z2, y2, x2
...@@ -102,6 +68,10 @@ def x2(int x2): ...@@ -102,6 +68,10 @@ def x2(int x2):
def inner_override(a,b): def inner_override(a,b):
"""
>>> inner_override(2,4)()
5
"""
def f(): def f():
a = 1 a = 1
return a+b return a+b
...@@ -109,18 +79,30 @@ def inner_override(a,b): ...@@ -109,18 +79,30 @@ def inner_override(a,b):
def reassign(x): def reassign(x):
"""
>>> reassign(4)(2)
3
"""
def f(a): def f(a):
return a+x return a+x
x = 1 x = 1
return f return f
def reassign_int(x): def reassign_int(x):
"""
>>> reassign_int(4)(2)
3
"""
def f(int a): def f(int a):
return a+x return a+x
x = 1 x = 1
return f return f
def reassign_int_int(int x): def reassign_int_int(int x):
"""
>>> reassign_int_int(4)(2)
3
"""
def f(int a): def f(int a):
return a+x return a+x
x = 1 x = 1
...@@ -128,6 +110,19 @@ def reassign_int_int(int x): ...@@ -128,6 +110,19 @@ def reassign_int_int(int x):
def cy_twofuncs(x): def cy_twofuncs(x):
"""
>>> def py_twofuncs(x):
... def f(a):
... return g(x) + a
... def g(b):
... return x + b
... return f
>>> py_twofuncs(1)(2) == cy_twofuncs(1)(2)
True
>>> py_twofuncs(3)(5) == cy_twofuncs(3)(5)
True
"""
def f(a): def f(a):
return g(x) + a return g(x) + a
def g(b): def g(b):
...@@ -135,6 +130,14 @@ def cy_twofuncs(x): ...@@ -135,6 +130,14 @@ def cy_twofuncs(x):
return f return f
def switch_funcs(a, b, int ix): def switch_funcs(a, b, int ix):
"""
>>> switch_funcs([1,2,3], [4,5,6], 0)([10])
[1, 2, 3, 10]
>>> switch_funcs([1,2,3], [4,5,6], 1)([10])
[4, 5, 6, 10]
>>> switch_funcs([1,2,3], [4,5,6], 2) is None
True
"""
def f(x): def f(x):
return a + x return a + x
def g(x): def g(x):
...@@ -152,9 +155,17 @@ def ignore_func(x): ...@@ -152,9 +155,17 @@ def ignore_func(x):
return None return None
def call_ignore_func(): def call_ignore_func():
"""
>>> call_ignore_func()
"""
ignore_func((1,2,3)) ignore_func((1,2,3))
def more_inner_funcs(x): def more_inner_funcs(x):
"""
>>> inner_funcs = more_inner_funcs(1)(2,4,8)
>>> inner_funcs[0](16), inner_funcs[1](32), inner_funcs[2](64)
(19, 37, 73)
"""
# called with x==1 # called with x==1
def f(a): def f(a):
def g(b): def g(b):
...@@ -175,3 +186,45 @@ def more_inner_funcs(x): ...@@ -175,3 +186,45 @@ def more_inner_funcs(x):
# called with (2,4,8) # called with (2,4,8)
return f(a_f), g(b_g), h(b_h) return f(a_f), g(b_g), h(b_h)
return resolve return resolve
@cython.test_assert_path_exists("//DefNode//DefNode//DefNode//DefNode",
"//DefNode[@needs_outer_scope = False]", # deep_inner()
"//DefNode//DefNode//DefNode//DefNode[@needs_closure = False]", # h()
)
@cython.test_fail_if_path_exists("//DefNode//DefNode[@needs_outer_scope = False]")
def deep_inner():
"""
>>> deep_inner()()
2
"""
cdef int x = 1
def f():
def g():
def h():
return x+1
return h
return g()
return f()
@cython.test_assert_path_exists("//DefNode//DefNode//DefNode",
"//DefNode//DefNode//DefNode[@needs_outer_scope = False]", # a()
"//DefNode//DefNode//DefNode[@needs_closure = False]", # a(), g(), h()
)
@cython.test_fail_if_path_exists("//DefNode//DefNode//DefNode[@needs_closure = True]") # a(), g(), h()
def deep_inner_sibling():
"""
>>> deep_inner_sibling()()
2
"""
cdef int x = 1
def f():
def a():
return 1
def g():
return x+a()
def h():
return g()
return h
return f()
cdef extern from "cpp_namespaces_helper.h" namespace "A": cdef extern from "cpp_namespaces_helper.h" namespace "A":
ctypedef int A_t ctypedef int A_t
A_t A_func(A_t first, A_t) A_t A_func(A_t first, A_t)
cdef void f(A_t)
cdef extern from "cpp_namespaces_helper.h" namespace "outer": cdef extern from "cpp_namespaces_helper.h" namespace "outer":
int outer_value int outer_value
...@@ -26,3 +27,9 @@ def test_nested(): ...@@ -26,3 +27,9 @@ def test_nested():
print outer_value print outer_value
print inner_value print inner_value
def test_typedef(A_t a):
"""
>>> test_typedef(3)
3
"""
return a
...@@ -76,6 +76,26 @@ def list_comp(): ...@@ -76,6 +76,26 @@ def list_comp():
assert x == 'abc' # don't leak in Py3 code assert x == 'abc' # don't leak in Py3 code
return result return result
module_level_lc = [ module_level_loopvar*2 for module_level_loopvar in range(4) ]
def list_comp_module_level():
"""
>>> module_level_lc
[0, 2, 4, 6]
>>> module_level_loopvar
Traceback (most recent call last):
NameError: name 'module_level_loopvar' is not defined
"""
module_level_list_genexp = list(module_level_genexp_loopvar*2 for module_level_genexp_loopvar in range(4))
def genexpr_module_level():
"""
>>> module_level_list_genexp
[0, 2, 4, 6]
>>> module_level_genexp_loopvar
Traceback (most recent call last):
NameError: name 'module_level_genexp_loopvar' is not defined
"""
def list_comp_unknown_type(l): def list_comp_unknown_type(l):
""" """
>>> list_comp_unknown_type(range(5)) >>> list_comp_unknown_type(range(5))
......
cdef unsigned long size2():
return 3
def for_from_plain_ulong():
"""
>>> for_from_plain_ulong()
0
1
2
"""
cdef object j = 0
for j from 0 <= j < size2():
print j
def for_in_plain_ulong():
"""
>>> for_in_plain_ulong()
0
1
2
"""
cdef object j = 0
for j in range(size2()):
print j
cdef extern from "for_from_pyvar_loop_T601_extern_def.h":
ctypedef unsigned long Ulong
cdef Ulong size():
return 3
def for_from_ctypedef_ulong():
"""
>>> for_from_ctypedef_ulong()
0
1
2
"""
cdef object j = 0
for j from 0 <= j < size():
print j
def for_in_ctypedef_ulong():
"""
>>> for_in_ctypedef_ulong()
0
1
2
"""
cdef object j = 0
for j in range(size()):
print j
cimport cython
@cython.test_assert_path_exists('//ComprehensionNode')
@cython.test_fail_if_path_exists('//SimpleCallNode')
def list_genexpr_iterable_lookup():
"""
>>> x = (0,1,2,3,4,5)
>>> [ x*2 for x in x if x % 2 == 0 ] # leaks in Py2 but finds the right 'x'
[0, 4, 8]
>>> list_genexpr_iterable_lookup()
[0, 4, 8]
"""
x = (0,1,2,3,4,5)
result = list( x*2 for x in x if x % 2 == 0 )
assert x == (0,1,2,3,4,5)
return result
@cython.test_assert_path_exists('//ComprehensionNode')
@cython.test_fail_if_path_exists('//SingleAssignmentNode//SimpleCallNode')
def genexpr_iterable_in_closure():
"""
>>> genexpr_iterable_in_closure()
[0, 4, 8]
"""
x = 'abc'
def f():
return x
result = list( x*2 for x in x if x % 2 == 0 )
assert x == 'abc' # don't leak in Py3 code
assert f() == 'abc' # don't leak in Py3 code
return result
cdef object executable, version_info
cdef long hexversion
from sys import *
def test_cdefed_objects():
"""
>>> ex, vi = test_cdefed_objects()
>>> assert ex is not None
>>> assert vi is not None
"""
return executable, version_info
def test_cdefed_cvalues():
"""
>>> hexver = test_cdefed_cvalues()
>>> assert hexver is not None
>>> assert hexver > 0x02020000
"""
return hexversion
def test_non_cdefed_names():
"""
>>> mod, pth = test_non_cdefed_names()
>>> assert mod is not None
>>> assert pth is not None
"""
return modules, path
...@@ -149,6 +149,46 @@ def return_typed_sum_squares_start(seq, int start): ...@@ -149,6 +149,46 @@ def return_typed_sum_squares_start(seq, int start):
return <int>sum((i*i for i in seq), start) return <int>sum((i*i for i in seq), start)
@cython.test_assert_path_exists('//ForInStatNode',
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists('//SimpleCallNode')
def return_sum_of_listcomp_consts_start(seq, int start):
"""
>>> sum([1 for i in range(10) if i > 3], -1)
5
>>> return_sum_of_listcomp_consts_start(range(10), -1)
5
>>> print(sum([1 for i in range(10000) if i > 3], 9))
10005
>>> print(return_sum_of_listcomp_consts_start(range(10000), 9))
10005
"""
return sum([1 for i in seq if i > 3], start)
@cython.test_assert_path_exists('//ForInStatNode',
"//InlinedGeneratorExpressionNode",
# the next test is for a deficiency
# (see InlinedGeneratorExpressionNode.coerce_to()),
# hope this breaks one day
"//CoerceFromPyTypeNode//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists('//SimpleCallNode')
def return_typed_sum_of_listcomp_consts_start(seq, int start):
"""
>>> sum([1 for i in range(10) if i > 3], -1)
5
>>> return_typed_sum_of_listcomp_consts_start(range(10), -1)
5
>>> print(sum([1 for i in range(10000) if i > 3], 9))
10005
>>> print(return_typed_sum_of_listcomp_consts_start(range(10000), 9))
10005
"""
return <int>sum([1 for i in seq if i > 3], start)
@cython.test_assert_path_exists( @cython.test_assert_path_exists(
'//ForInStatNode', '//ForInStatNode',
"//InlinedGeneratorExpressionNode") "//InlinedGeneratorExpressionNode")
......
# Module scope lambda functions
__doc__ = """
>>> pow2(16)
256
>>> with_closure(0)
0
>>> typed_lambda(1)(2)
3
>>> typed_lambda(1.5)(1.5)
2
>>> cdef_const_lambda()
123
>>> const_lambda()
321
"""
pow2 = lambda x: x * x
with_closure = lambda x:(lambda: x)()
typed_lambda = lambda int x : (lambda int y: x + y)
cdef int xxx = 123
cdef_const_lambda = lambda: xxx
yyy = 321
const_lambda = lambda: yyy
# cython: language_level=3
def list_comp_in_closure():
"""
>>> list_comp_in_closure()
[0, 4, 8]
"""
x = 'abc'
def f():
return x
result = [x*2 for x in range(5) if x % 2 == 0]
assert x == 'abc' # don't leak in Py3 code
assert f() == 'abc' # don't leak in Py3 code
return result
def pytyped_list_comp_in_closure():
"""
>>> pytyped_list_comp_in_closure()
[0, 4, 8]
"""
cdef object x
x = 'abc'
def f():
return x
result = [x*2 for x in range(5) if x % 2 == 0]
assert x == 'abc' # don't leak in Py3 code
assert f() == 'abc' # don't leak in Py3 code
return result
def pytyped_list_comp_in_closure_repeated():
"""
>>> pytyped_list_comp_in_closure_repeated()
[0, 4, 8]
"""
cdef object x
x = 'abc'
def f():
return x
for i in range(3):
result = [x*2 for x in range(5) if x % 2 == 0]
assert x == 'abc' # don't leak in Py3 code
assert f() == 'abc' # don't leak in Py3 code
return result
def genexpr_in_closure():
"""
>>> genexpr_in_closure()
[0, 4, 8]
"""
x = 'abc'
def f():
return x
result = list( x*2 for x in range(5) if x % 2 == 0 )
assert x == 'abc' # don't leak in Py3 code
assert f() == 'abc' # don't leak in Py3 code
return result
def pytyped_genexpr_in_closure():
"""
>>> pytyped_genexpr_in_closure()
[0, 4, 8]
"""
cdef object x
x = 'abc'
def f():
return x
result = list( x*2 for x in range(5) if x % 2 == 0 )
assert x == 'abc' # don't leak in Py3 code
assert f() == 'abc' # don't leak in Py3 code
return result
def pytyped_genexpr_in_closure_repeated():
"""
>>> pytyped_genexpr_in_closure_repeated()
[0, 4, 8]
"""
cdef object x
x = 'abc'
def f():
return x
for i in range(3):
result = list( x*2 for x in range(5) if x % 2 == 0 )
assert x == 'abc' # don't leak in Py3 code
assert f() == 'abc' # don't leak in Py3 code
return result
def genexpr_scope_in_closure():
"""
>>> genexpr_scope_in_closure()
[0, 4, 8]
"""
i = 2
x = 'abc'
def f():
return i, x
result = list( x*i for x in range(5) if x % 2 == 0 )
assert x == 'abc' # don't leak in Py3 code
assert f() == (2,'abc') # don't leak in Py3 code
return result
# tests copied from test/test_bool.py in Py2.7
cdef assertEqual(a,b):
assert a == b, '%r != %r' % (a,b)
cdef assertIs(a,b):
assert a is b, '%r is not %r' % (a,b)
cdef assertIsNot(a,b):
assert a is not b, '%r is %r' % (a,b)
cdef assertNotIsInstance(a,b):
assert not isinstance(a,b), 'isinstance(%r, %s)' % (a,b)
def test_int():
"""
>>> test_int()
"""
assertEqual(int(False), 0)
assertIsNot(int(False), False)
assertEqual(int(True), 1)
assertIsNot(int(True), True)
def test_float():
"""
>>> test_float()
"""
assertEqual(float(False), 0.0)
assertIsNot(float(False), False)
assertEqual(float(True), 1.0)
assertIsNot(float(True), True)
def test_repr():
"""
>>> test_repr()
"""
assertEqual(repr(False), 'False')
assertEqual(repr(True), 'True')
assertEqual(eval(repr(False)), False)
assertEqual(eval(repr(True)), True)
def test_str():
"""
>>> test_str()
"""
assertEqual(str(False), 'False')
assertEqual(str(True), 'True')
def test_math():
"""
>>> test_math()
"""
assertEqual(+False, 0)
assertIsNot(+False, False)
assertEqual(-False, 0)
assertIsNot(-False, False)
assertEqual(abs(False), 0)
assertIsNot(abs(False), False)
assertEqual(+True, 1)
assertIsNot(+True, True)
assertEqual(-True, -1)
assertEqual(abs(True), 1)
assertIsNot(abs(True), True)
assertEqual(~False, -1)
assertEqual(~True, -2)
assertEqual(False+2, 2)
assertEqual(True+2, 3)
assertEqual(2+False, 2)
assertEqual(2+True, 3)
assertEqual(False+False, 0)
assertIsNot(False+False, False)
assertEqual(False+True, 1)
assertIsNot(False+True, True)
assertEqual(True+False, 1)
assertIsNot(True+False, True)
assertEqual(True+True, 2)
assertEqual(True-True, 0)
assertIsNot(True-True, False)
assertEqual(False-False, 0)
assertIsNot(False-False, False)
assertEqual(True-False, 1)
assertIsNot(True-False, True)
assertEqual(False-True, -1)
assertEqual(True*1, 1)
assertEqual(False*1, 0)
assertIsNot(False*1, False)
assertEqual(True/1, 1)
assertIsNot(True/1, True)
assertEqual(False/1, 0)
assertIsNot(False/1, False)
for b in False, True:
for i in 0, 1, 2:
assertEqual(b**i, int(b)**i)
assertIsNot(b**i, bool(int(b)**i))
for a in False, True:
for b in False, True:
assertIs(a&b, bool(int(a)&int(b)))
assertIs(a|b, bool(int(a)|int(b)))
assertIs(a^b, bool(int(a)^int(b)))
assertEqual(a&int(b), int(a)&int(b))
assertIsNot(a&int(b), bool(int(a)&int(b)))
assertEqual(a|int(b), int(a)|int(b))
assertIsNot(a|int(b), bool(int(a)|int(b)))
assertEqual(a^int(b), int(a)^int(b))
assertIsNot(a^int(b), bool(int(a)^int(b)))
assertEqual(int(a)&b, int(a)&int(b))
assertIsNot(int(a)&b, bool(int(a)&int(b)))
assertEqual(int(a)|b, int(a)|int(b))
assertIsNot(int(a)|b, bool(int(a)|int(b)))
assertEqual(int(a)^b, int(a)^int(b))
assertIsNot(int(a)^b, bool(int(a)^int(b)))
assertIs(1==1, True)
assertIs(1==0, False)
assertIs(0<1, True)
assertIs(1<0, False)
assertIs(0<=0, True)
assertIs(1<=0, False)
assertIs(1>0, True)
assertIs(1>1, False)
assertIs(1>=1, True)
assertIs(0>=1, False)
assertIs(0!=1, True)
assertIs(0!=0, False)
x = [1]
assertIs(x is x, True)
assertIs(x is not x, False)
assertIs(1 in x, True)
assertIs(0 in x, False)
assertIs(1 not in x, False)
assertIs(0 not in x, True)
x = {1: 2}
assertIs(x is x, True)
assertIs(x is not x, False)
assertIs(1 in x, True)
assertIs(0 in x, False)
assertIs(1 not in x, False)
assertIs(0 not in x, True)
assertIs(not True, False)
assertIs(not False, True)
def test_convert():
"""
>>> test_convert()
"""
assertIs(bool(10), True)
assertIs(bool(1), True)
assertIs(bool(-1), True)
assertIs(bool(0), False)
assertIs(bool("hello"), True)
assertIs(bool(""), False)
assertIs(bool(), False)
def test_isinstance():
"""
>>> test_isinstance()
"""
assertIs(isinstance(True, bool), True)
assertIs(isinstance(False, bool), True)
assertIs(isinstance(True, int), True)
assertIs(isinstance(False, int), True)
assertIs(isinstance(1, bool), False)
assertIs(isinstance(0, bool), False)
def test_issubclass():
"""
>>> test_issubclass()
"""
assertIs(issubclass(bool, int), True)
assertIs(issubclass(int, bool), False)
def test_boolean():
"""
>>> test_boolean()
"""
assertEqual(True & 1, 1)
assertNotIsInstance(True & 1, bool)
assertIs(True & True, True)
assertEqual(True | 1, 1)
assertNotIsInstance(True | 1, bool)
assertIs(True | True, True)
assertEqual(True ^ 1, 0)
assertNotIsInstance(True ^ 1, bool)
assertIs(True ^ True, False)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment