Commit f9c385e0 authored by Stefan Behnel's avatar Stefan Behnel

refactor analyse_types() and friends to work more like a transform by...

refactor analyse_types() and friends to work more like a transform by returning the node or a replacement
parent 927e1a4d
...@@ -346,18 +346,18 @@ class ExprNode(Node): ...@@ -346,18 +346,18 @@ class ExprNode(Node):
# Convenience routine performing both the Type # Convenience routine performing both the Type
# Analysis and Temp Allocation phases for a whole # Analysis and Temp Allocation phases for a whole
# expression. # expression.
self.analyse_types(env) return self.analyse_types(env)
def analyse_target_expression(self, env, rhs): def analyse_target_expression(self, env, rhs):
# Convenience routine performing both the Type # Convenience routine performing both the Type
# Analysis and Temp Allocation phases for the LHS of # Analysis and Temp Allocation phases for the LHS of
# an assignment. # an assignment.
self.analyse_target_types(env) return self.analyse_target_types(env)
def analyse_boolean_expression(self, env): def analyse_boolean_expression(self, env):
# Analyse expression and coerce to a boolean. # Analyse expression and coerce to a boolean.
self.analyse_types(env) node = self.analyse_types(env)
bool = self.coerce_to_boolean(env) bool = node.coerce_to_boolean(env)
return bool return bool
def analyse_temp_boolean_expression(self, env): def analyse_temp_boolean_expression(self, env):
...@@ -368,8 +368,8 @@ class ExprNode(Node): ...@@ -368,8 +368,8 @@ class ExprNode(Node):
# afterwards. By forcing the result into a temporary, # afterwards. By forcing the result into a temporary,
# we ensure that all disposal has been done by the # we ensure that all disposal has been done by the
# time we get the result. # time we get the result.
self.analyse_types(env) node = self.analyse_types(env)
return self.coerce_to_boolean(env).coerce_to_simple(env) return node.coerce_to_boolean(env).coerce_to_simple(env)
# --------------- Type Inference ----------------- # --------------- Type Inference -----------------
...@@ -418,7 +418,7 @@ class ExprNode(Node): ...@@ -418,7 +418,7 @@ class ExprNode(Node):
self.not_implemented("analyse_types") self.not_implemented("analyse_types")
def analyse_target_types(self, env): def analyse_target_types(self, env):
self.analyse_types(env) return self.analyse_types(env)
def nogil_check(self, env): def nogil_check(self, env):
# By default, any expression based on Python objects is # By default, any expression based on Python objects is
...@@ -809,7 +809,7 @@ class PyConstNode(AtomicExprNode): ...@@ -809,7 +809,7 @@ class PyConstNode(AtomicExprNode):
return False return False
def analyse_types(self, env): def analyse_types(self, env):
pass return self
def calculate_result_code(self): def calculate_result_code(self):
return self.value return self.value
...@@ -864,7 +864,7 @@ class ConstNode(AtomicExprNode): ...@@ -864,7 +864,7 @@ class ConstNode(AtomicExprNode):
return False return False
def analyse_types(self, env): def analyse_types(self, env):
pass # Types are held in class variables return self # Types are held in class variables
def check_const(self): def check_const(self):
return True return True
...@@ -1071,7 +1071,7 @@ class BytesNode(ConstNode): ...@@ -1071,7 +1071,7 @@ class BytesNode(ConstNode):
pos = (self.pos[0], self.pos[1], self.pos[2]-7) pos = (self.pos[0], self.pos[1], self.pos[2]-7)
declaration = TreeFragment(u"sizeof(%s)" % self.value, name=pos[0].filename, initial_pos=pos) declaration = TreeFragment(u"sizeof(%s)" % self.value, name=pos[0].filename, initial_pos=pos)
sizeof_node = declaration.root.stats[0].expr sizeof_node = declaration.root.stats[0].expr
sizeof_node.analyse_types(env) sizeof_node = sizeof_node.analyse_types(env)
if isinstance(sizeof_node, SizeofTypeNode): if isinstance(sizeof_node, SizeofTypeNode):
return sizeof_node.arg_type return sizeof_node.arg_type
...@@ -1255,6 +1255,7 @@ class LongNode(AtomicExprNode): ...@@ -1255,6 +1255,7 @@ class LongNode(AtomicExprNode):
def analyse_types(self, env): def analyse_types(self, env):
self.is_temp = 1 self.is_temp = 1
return self
def may_be_none(self): def may_be_none(self):
return False return False
...@@ -1285,6 +1286,7 @@ class ImagNode(AtomicExprNode): ...@@ -1285,6 +1286,7 @@ class ImagNode(AtomicExprNode):
def analyse_types(self, env): def analyse_types(self, env):
self.type.create_declaration_utility_code(env) self.type.create_declaration_utility_code(env)
return self
def may_be_none(self): def may_be_none(self):
return False return False
...@@ -1347,6 +1349,7 @@ class NewExprNode(AtomicExprNode): ...@@ -1347,6 +1349,7 @@ class NewExprNode(AtomicExprNode):
def analyse_types(self, env): def analyse_types(self, env):
if self.type is None: if self.type is None:
self.infer_type(env) self.infer_type(env)
return self
def may_be_none(self): def may_be_none(self):
return False return False
...@@ -1500,7 +1503,7 @@ class NameNode(AtomicExprNode): ...@@ -1500,7 +1503,7 @@ class NameNode(AtomicExprNode):
self.entry = env.declare_builtin(self.name, self.pos) self.entry = env.declare_builtin(self.name, self.pos)
if not self.entry: if not self.entry:
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
return return self
entry = self.entry entry = self.entry
if entry: if entry:
entry.used = 1 entry.used = 1
...@@ -1508,6 +1511,7 @@ class NameNode(AtomicExprNode): ...@@ -1508,6 +1511,7 @@ class NameNode(AtomicExprNode):
import Buffer import Buffer
Buffer.used_buffer_aux_vars(entry) Buffer.used_buffer_aux_vars(entry)
self.analyse_rvalue_entry(env) self.analyse_rvalue_entry(env)
return self
def analyse_target_types(self, env): def analyse_target_types(self, env):
self.analyse_entry(env) self.analyse_entry(env)
...@@ -1530,6 +1534,7 @@ class NameNode(AtomicExprNode): ...@@ -1530,6 +1534,7 @@ class NameNode(AtomicExprNode):
if self.entry.type.is_buffer: if self.entry.type.is_buffer:
import Buffer import Buffer
Buffer.used_buffer_aux_vars(self.entry) Buffer.used_buffer_aux_vars(self.entry)
return self
def analyse_rvalue_entry(self, env): def analyse_rvalue_entry(self, env):
#print "NameNode.analyse_rvalue_entry:", self.name ### #print "NameNode.analyse_rvalue_entry:", self.name ###
...@@ -1917,9 +1922,10 @@ class BackquoteNode(ExprNode): ...@@ -1917,9 +1922,10 @@ class BackquoteNode(ExprNode):
subexprs = ['arg'] subexprs = ['arg']
def analyse_types(self, env): def analyse_types(self, env):
self.arg.analyse_types(env) self.arg = self.arg.analyse_types(env)
self.arg = self.arg.coerce_to_pyobject(env) self.arg = self.arg.coerce_to_pyobject(env)
self.is_temp = 1 self.is_temp = 1
return self
gil_message = "Backquote expression" gil_message = "Backquote expression"
...@@ -1963,13 +1969,14 @@ class ImportNode(ExprNode): ...@@ -1963,13 +1969,14 @@ class ImportNode(ExprNode):
self.level = -1 self.level = -1
else: else:
self.level = 0 self.level = 0
self.module_name.analyse_types(env) module_name = self.module_name.analyse_types(env)
self.module_name = self.module_name.coerce_to_pyobject(env) self.module_name = module_name.coerce_to_pyobject(env)
if self.name_list: if self.name_list:
self.name_list.analyse_types(env) name_list = self.name_list.analyse_types(env)
self.name_list.coerce_to_pyobject(env) self.name_list = name_list.coerce_to_pyobject(env)
self.is_temp = 1 self.is_temp = 1
env.use_utility_code(UtilityCode.load_cached("Import", "ImportExport.c")) env.use_utility_code(UtilityCode.load_cached("Import", "ImportExport.c"))
return self
gil_message = "Python import" gil_message = "Python import"
...@@ -2004,7 +2011,7 @@ class IteratorNode(ExprNode): ...@@ -2004,7 +2011,7 @@ class IteratorNode(ExprNode):
subexprs = ['sequence'] subexprs = ['sequence']
def analyse_types(self, env): def analyse_types(self, env):
self.sequence.analyse_types(env) self.sequence = self.sequence.analyse_types(env)
if (self.sequence.type.is_array or self.sequence.type.is_ptr) and \ if (self.sequence.type.is_array or self.sequence.type.is_ptr) and \
not self.sequence.type.is_string: not self.sequence.type.is_string:
# C array iteration will be transformed later on # C array iteration will be transformed later on
...@@ -2017,6 +2024,7 @@ class IteratorNode(ExprNode): ...@@ -2017,6 +2024,7 @@ class IteratorNode(ExprNode):
self.sequence.type is tuple_type: self.sequence.type is tuple_type:
self.sequence = self.sequence.as_none_safe_node("'NoneType' object is not iterable") self.sequence = self.sequence.as_none_safe_node("'NoneType' object is not iterable")
self.is_temp = 1 self.is_temp = 1
return self
gil_message = "Iterating over Python object" gil_message = "Iterating over Python object"
...@@ -2276,6 +2284,7 @@ class NextNode(AtomicExprNode): ...@@ -2276,6 +2284,7 @@ class NextNode(AtomicExprNode):
def analyse_types(self, env): def analyse_types(self, env):
self.type = self.infer_type(env, self.iterator.type) self.type = self.infer_type(env, self.iterator.type)
self.is_temp = 1 self.is_temp = 1
return self
def generate_result_code(self, code): def generate_result_code(self, code):
self.iterator.generate_iter_next_result_code(self.result(), code) self.iterator.generate_iter_next_result_code(self.result(), code)
...@@ -2291,9 +2300,10 @@ class WithExitCallNode(ExprNode): ...@@ -2291,9 +2300,10 @@ class WithExitCallNode(ExprNode):
subexprs = ['args'] subexprs = ['args']
def analyse_types(self, env): def analyse_types(self, env):
self.args.analyse_types(env) self.args = self.args.analyse_types(env)
self.type = PyrexTypes.c_bint_type self.type = PyrexTypes.c_bint_type
self.is_temp = True self.is_temp = True
return self
def generate_result_code(self, code): def generate_result_code(self, code):
if isinstance(self.args, TupleNode): if isinstance(self.args, TupleNode):
...@@ -2335,7 +2345,7 @@ class ExcValueNode(AtomicExprNode): ...@@ -2335,7 +2345,7 @@ class ExcValueNode(AtomicExprNode):
pass pass
def analyse_types(self, env): def analyse_types(self, env):
pass return self
class TempNode(ExprNode): class TempNode(ExprNode):
...@@ -2357,7 +2367,7 @@ class TempNode(ExprNode): ...@@ -2357,7 +2367,7 @@ class TempNode(ExprNode):
self.is_temp = 1 self.is_temp = 1
def analyse_types(self, env): def analyse_types(self, env):
return self.type return self
def analyse_target_declaration(self, env): def analyse_target_declaration(self, env):
pass pass
...@@ -2401,7 +2411,7 @@ class RawCNameExprNode(ExprNode): ...@@ -2401,7 +2411,7 @@ class RawCNameExprNode(ExprNode):
self.cname = cname self.cname = cname
def analyse_types(self, env): def analyse_types(self, env):
return self.type return self
def set_cname(self, cname): def set_cname(self, cname):
self.cname = cname self.cname = cname
...@@ -2433,7 +2443,7 @@ class ParallelThreadsAvailableNode(AtomicExprNode): ...@@ -2433,7 +2443,7 @@ class ParallelThreadsAvailableNode(AtomicExprNode):
def analyse_types(self, env): def analyse_types(self, env):
self.is_temp = True self.is_temp = True
# env.add_include_file("omp.h") # env.add_include_file("omp.h")
return self.type return self
def generate_result_code(self, code): def generate_result_code(self, code):
code.putln("#ifdef _OPENMP") code.putln("#ifdef _OPENMP")
...@@ -2458,7 +2468,7 @@ class ParallelThreadIdNode(AtomicExprNode): #, Nodes.ParallelNode): ...@@ -2458,7 +2468,7 @@ class ParallelThreadIdNode(AtomicExprNode): #, Nodes.ParallelNode):
def analyse_types(self, env): def analyse_types(self, env):
self.is_temp = True self.is_temp = True
# env.add_include_file("omp.h") # env.add_include_file("omp.h")
return self.type return self
def generate_result_code(self, code): def generate_result_code(self, code):
code.putln("#ifdef _OPENMP") code.putln("#ifdef _OPENMP")
...@@ -2623,6 +2633,7 @@ class IndexNode(ExprNode): ...@@ -2623,6 +2633,7 @@ class IndexNode(ExprNode):
def analyse_types(self, env): def analyse_types(self, env):
self.analyse_base_and_index_types(env, getting = 1) self.analyse_base_and_index_types(env, getting = 1)
return self
def analyse_target_types(self, env): def analyse_target_types(self, env):
self.analyse_base_and_index_types(env, setting = 1) self.analyse_base_and_index_types(env, setting = 1)
...@@ -2630,6 +2641,7 @@ class IndexNode(ExprNode): ...@@ -2630,6 +2641,7 @@ class IndexNode(ExprNode):
error(self.pos, "Assignment to const dereference") error(self.pos, "Assignment to const dereference")
if not self.is_lvalue(): if not self.is_lvalue():
error(self.pos, "Assignment to non-lvalue of type '%s'" % self.type) error(self.pos, "Assignment to non-lvalue of type '%s'" % self.type)
return self
def analyse_base_and_index_types(self, env, getting = 0, setting = 0, analyse_base = True): def analyse_base_and_index_types(self, env, getting = 0, setting = 0, analyse_base = True):
# Note: This might be cleaned up by having IndexNode # Note: This might be cleaned up by having IndexNode
...@@ -2649,7 +2661,7 @@ class IndexNode(ExprNode): ...@@ -2649,7 +2661,7 @@ class IndexNode(ExprNode):
self.memslice_index = False self.memslice_index = False
if analyse_base: if analyse_base:
self.base.analyse_types(env) self.base = self.base.analyse_types(env)
if self.base.type.is_error: if self.base.type.is_error:
# Do not visit child tree if base is undeclared to avoid confusing # Do not visit child tree if base is undeclared to avoid confusing
...@@ -2710,7 +2722,7 @@ class IndexNode(ExprNode): ...@@ -2710,7 +2722,7 @@ class IndexNode(ExprNode):
axis_idx = 0 axis_idx = 0
for i, index in enumerate(indices[:]): for i, index in enumerate(indices[:]):
index.analyse_types(env) index = index.analyse_types(env)
if not index.is_none: if not index.is_none:
access, packing = self.base.type.axes[axis_idx] access, packing = self.base.type.axes[axis_idx]
axis_idx += 1 axis_idx += 1
...@@ -2764,7 +2776,7 @@ class IndexNode(ExprNode): ...@@ -2764,7 +2776,7 @@ class IndexNode(ExprNode):
buffer_access = True buffer_access = True
skip_child_analysis = True skip_child_analysis = True
for x in indices: for x in indices:
x.analyse_types(env) x = x.analyse_types(env)
if not x.type.is_int: if not x.type.is_int:
buffer_access = False buffer_access = False
...@@ -2833,9 +2845,10 @@ class IndexNode(ExprNode): ...@@ -2833,9 +2845,10 @@ class IndexNode(ExprNode):
fused_index_operation = base_type.is_cfunction and base_type.is_fused fused_index_operation = base_type.is_cfunction and base_type.is_fused
if not fused_index_operation: if not fused_index_operation:
if isinstance(self.index, TupleNode): if isinstance(self.index, TupleNode):
self.index.analyse_types(env, skip_children=skip_child_analysis) self.index = self.index.analyse_types(
env, skip_children=skip_child_analysis)
elif not skip_child_analysis: elif not skip_child_analysis:
self.index.analyse_types(env) self.index = self.index.analyse_types(env)
self.original_index_type = self.index.type self.original_index_type = self.index.type
if base_type.is_unicode_char: if base_type.is_unicode_char:
...@@ -2955,7 +2968,7 @@ class IndexNode(ExprNode): ...@@ -2955,7 +2968,7 @@ class IndexNode(ExprNode):
specific_types = [False] specific_types = [False]
if not Utils.all(specific_types): if not Utils.all(specific_types):
self.index.analyse_types(env) self.index = self.index.analyse_types(env)
if not self.base.entry.as_variable: if not self.base.entry.as_variable:
error(self.pos, "Can only index fused functions with types") error(self.pos, "Can only index fused functions with types")
...@@ -3402,13 +3415,14 @@ class SliceIndexNode(ExprNode): ...@@ -3402,13 +3415,14 @@ class SliceIndexNode(ExprNode):
pass pass
def analyse_target_types(self, env): def analyse_target_types(self, env):
self.analyse_types(env, getting=False) node = self.analyse_types(env, getting=False)
# when assigning, we must accept any Python type # when assigning, we must accept any Python type
if self.type.is_pyobject: if node.type.is_pyobject:
self.type = py_object_type node.type = py_object_type
return node
def analyse_types(self, env, getting=True): def analyse_types(self, env, getting=True):
self.base.analyse_types(env) self.base = self.base.analyse_types(env)
if self.base.type.is_memoryviewslice: if self.base.type.is_memoryviewslice:
# Gross hack here! But we do not know the type until this point, # Gross hack here! But we do not know the type until this point,
...@@ -3427,12 +3441,12 @@ class SliceIndexNode(ExprNode): ...@@ -3427,12 +3441,12 @@ class SliceIndexNode(ExprNode):
getting=getting, getting=getting,
setting=not getting, setting=not getting,
analyse_base=False) analyse_base=False)
return return self
if self.start: if self.start:
self.start.analyse_types(env) self.start = self.start.analyse_types(env)
if self.stop: if self.stop:
self.stop.analyse_types(env) self.stop = self.stop.analyse_types(env)
base_type = self.base.type base_type = self.base.type
if base_type.is_string or base_type.is_cpp_string: if base_type.is_string or base_type.is_cpp_string:
self.type = bytes_type self.type = bytes_type
...@@ -3455,6 +3469,7 @@ class SliceIndexNode(ExprNode): ...@@ -3455,6 +3469,7 @@ class SliceIndexNode(ExprNode):
if self.stop: if self.stop:
self.stop = self.stop.coerce_to(c_int, env) self.stop = self.stop.coerce_to(c_int, env)
self.is_temp = 1 self.is_temp = 1
return self
nogil_check = Node.gil_error nogil_check = Node.gil_error
gil_message = "Slicing Python object" gil_message = "Slicing Python object"
...@@ -3637,15 +3652,16 @@ class SliceNode(ExprNode): ...@@ -3637,15 +3652,16 @@ class SliceNode(ExprNode):
self.compile_time_value_error(e) self.compile_time_value_error(e)
def analyse_types(self, env): def analyse_types(self, env):
self.start.analyse_types(env) start = self.start.analyse_types(env)
self.stop.analyse_types(env) stop = self.stop.analyse_types(env)
self.step.analyse_types(env) step = self.step.analyse_types(env)
self.start = self.start.coerce_to_pyobject(env) self.start = start.coerce_to_pyobject(env)
self.stop = self.stop.coerce_to_pyobject(env) self.stop = stop.coerce_to_pyobject(env)
self.step = self.step.coerce_to_pyobject(env) self.step = step.coerce_to_pyobject(env)
if self.start.is_literal and self.stop.is_literal and self.step.is_literal: if self.start.is_literal and self.stop.is_literal and self.step.is_literal:
self.is_literal = True self.is_literal = True
self.is_temp = False self.is_temp = False
return self
gil_message = "Constructing Python slice object" gil_message = "Constructing Python slice object"
...@@ -3737,12 +3753,11 @@ class CallNode(ExprNode): ...@@ -3737,12 +3753,11 @@ class CallNode(ExprNode):
items += kwds.key_value_pairs items += kwds.key_value_pairs
self.key_value_pairs = items self.key_value_pairs = items
self.__class__ = DictNode self.__class__ = DictNode
self.analyse_types(env) self.analyse_types(env) # FIXME
self.coerce_to(type, env) self.coerce_to(type, env)
return True return True
elif type and type.is_cpp_class: elif type and type.is_cpp_class:
for arg in self.args: self.args = [ arg.analyse_types(env) for arg in self.args ]
arg.analyse_types(env)
constructor = type.scope.lookup("<init>") constructor = type.scope.lookup("<init>")
self.function = RawCNameExprNode(self.function.pos, constructor.type) self.function = RawCNameExprNode(self.function.pos, constructor.type)
self.function.entry = constructor self.function.entry = constructor
...@@ -3811,13 +3826,13 @@ class SimpleCallNode(CallNode): ...@@ -3811,13 +3826,13 @@ class SimpleCallNode(CallNode):
def analyse_types(self, env): def analyse_types(self, env):
if self.analyse_as_type_constructor(env): if self.analyse_as_type_constructor(env):
return return self
if self.analysed: if self.analysed:
return return self
self.analysed = True self.analysed = True
self.function.is_called = 1
self.function = self.function.analyse_types(env)
function = self.function function = self.function
function.is_called = 1
self.function.analyse_types(env)
if function.is_attribute and function.entry and function.entry.is_cmethod: if function.is_attribute and function.entry and function.entry.is_cmethod:
# Take ownership of the object from which the attribute # Take ownership of the object from which the attribute
...@@ -3828,7 +3843,7 @@ class SimpleCallNode(CallNode): ...@@ -3828,7 +3843,7 @@ class SimpleCallNode(CallNode):
func_type = self.function_type() func_type = self.function_type()
if func_type.is_pyobject: if func_type.is_pyobject:
self.arg_tuple = TupleNode(self.pos, args = self.args) self.arg_tuple = TupleNode(self.pos, args = self.args)
self.arg_tuple.analyse_types(env) self.arg_tuple = self.arg_tuple.analyse_types(env)
self.args = None self.args = None
if func_type is Builtin.type_type and function.is_name and \ if func_type is Builtin.type_type and function.is_name and \
function.entry and \ function.entry and \
...@@ -3854,8 +3869,7 @@ class SimpleCallNode(CallNode): ...@@ -3854,8 +3869,7 @@ class SimpleCallNode(CallNode):
self.type = py_object_type self.type = py_object_type
self.is_temp = 1 self.is_temp = 1
else: else:
for arg in self.args: self.args = [ arg.analyse_types(env) for arg in self.args ]
arg.analyse_types(env)
if self.self and func_type.args: if self.self and func_type.args:
# Coerce 'self' to the type expected by the method. # Coerce 'self' to the type expected by the method.
...@@ -3874,6 +3888,7 @@ class SimpleCallNode(CallNode): ...@@ -3874,6 +3888,7 @@ class SimpleCallNode(CallNode):
# Insert coerced 'self' argument into argument list. # Insert coerced 'self' argument into argument list.
self.args.insert(0, self.coerced_self) self.args.insert(0, self.coerced_self)
self.analyse_c_function_call(env) self.analyse_c_function_call(env)
return self
def function_type(self): def function_type(self):
# Return the type of the function being called, coercing a function # Return the type of the function being called, coercing a function
...@@ -4176,11 +4191,9 @@ class InlinedDefNodeCallNode(CallNode): ...@@ -4176,11 +4191,9 @@ class InlinedDefNodeCallNode(CallNode):
return True return True
def analyse_types(self, env): def analyse_types(self, env):
self.function_name.analyse_types(env) self.function_name = self.function_name.analyse_types(env)
for arg in self.args:
arg.analyse_types(env)
self.args = [ arg.analyse_types(env) for arg in self.args ]
func_type = self.function.def_node func_type = self.function.def_node
actual_nargs = len(self.args) actual_nargs = len(self.args)
...@@ -4232,6 +4245,7 @@ class InlinedDefNodeCallNode(CallNode): ...@@ -4232,6 +4245,7 @@ class InlinedDefNodeCallNode(CallNode):
if i > 0: if i > 0:
warning(arg.pos, "Argument evaluation order in C function call is undefined and may not be as expected", 0) warning(arg.pos, "Argument evaluation order in C function call is undefined and may not be as expected", 0)
break break
return self
def generate_result_code(self, code): def generate_result_code(self, code):
arg_code = [self.function_name.py_result()] arg_code = [self.function_name.py_result()]
...@@ -4259,7 +4273,7 @@ class PythonCapiFunctionNode(ExprNode): ...@@ -4259,7 +4273,7 @@ class PythonCapiFunctionNode(ExprNode):
type=func_type, utility_code=utility_code) type=func_type, utility_code=utility_code)
def analyse_types(self, env): def analyse_types(self, env):
pass return self
def generate_result_code(self, code): def generate_result_code(self, code):
if self.utility_code: if self.utility_code:
...@@ -4323,12 +4337,12 @@ class GeneralCallNode(CallNode): ...@@ -4323,12 +4337,12 @@ class GeneralCallNode(CallNode):
def analyse_types(self, env): def analyse_types(self, env):
if self.analyse_as_type_constructor(env): if self.analyse_as_type_constructor(env):
return return self
self.function.analyse_types(env) self.function = self.function.analyse_types(env)
if not self.function.type.is_pyobject: if not self.function.type.is_pyobject:
if self.function.type.is_error: if self.function.type.is_error:
self.type = error_type self.type = error_type
return return self
if hasattr(self.function, 'entry'): if hasattr(self.function, 'entry'):
self.map_keywords_to_posargs() self.map_keywords_to_posargs()
if not self.is_simple_call: if not self.is_simple_call:
...@@ -4340,8 +4354,8 @@ class GeneralCallNode(CallNode): ...@@ -4340,8 +4354,8 @@ class GeneralCallNode(CallNode):
else: else:
self.function = self.function.coerce_to_pyobject(env) self.function = self.function.coerce_to_pyobject(env)
if self.keyword_args: if self.keyword_args:
self.keyword_args.analyse_types(env) self.keyword_args = self.keyword_args.analyse_types(env)
self.positional_args.analyse_types(env) self.positional_args = self.positional_args.analyse_types(env)
self.positional_args = \ self.positional_args = \
self.positional_args.coerce_to_pyobject(env) self.positional_args.coerce_to_pyobject(env)
function = self.function function = self.function
...@@ -4354,6 +4368,7 @@ class GeneralCallNode(CallNode): ...@@ -4354,6 +4368,7 @@ class GeneralCallNode(CallNode):
else: else:
self.type = py_object_type self.type = py_object_type
self.is_temp = 1 self.is_temp = 1
return self
def map_keywords_to_posargs(self): def map_keywords_to_posargs(self):
if not isinstance(self.positional_args, TupleNode): if not isinstance(self.positional_args, TupleNode):
...@@ -4457,10 +4472,11 @@ class AsTupleNode(ExprNode): ...@@ -4457,10 +4472,11 @@ class AsTupleNode(ExprNode):
self.compile_time_value_error(e) self.compile_time_value_error(e)
def analyse_types(self, env): def analyse_types(self, env):
self.arg.analyse_types(env) self.arg = self.arg.analyse_types(env)
self.arg = self.arg.coerce_to_pyobject(env) self.arg = self.arg.coerce_to_pyobject(env)
self.type = tuple_type self.type = tuple_type
self.is_temp = 1 self.is_temp = 1
return self
def may_be_none(self): def may_be_none(self):
return False return False
...@@ -4565,11 +4581,12 @@ class AttributeNode(ExprNode): ...@@ -4565,11 +4581,12 @@ class AttributeNode(ExprNode):
pass pass
def analyse_target_types(self, env): def analyse_target_types(self, env):
self.analyse_types(env, target = 1) node = self.analyse_types(env, target = 1)
if self.type.is_const: if node.type.is_const:
error(self.pos, "Assignment to const attribute '%s'" % self.attribute) error(self.pos, "Assignment to const attribute '%s'" % self.attribute)
if not self.is_lvalue(): if not node.is_lvalue():
error(self.pos, "Assignment to non-lvalue of type '%s'" % self.type) error(self.pos, "Assignment to non-lvalue of type '%s'" % self.type)
return node
def analyse_types(self, env, target = 0): def analyse_types(self, env, target = 0):
self.initialized_check = env.directives['initializedcheck'] self.initialized_check = env.directives['initializedcheck']
...@@ -4585,6 +4602,7 @@ class AttributeNode(ExprNode): ...@@ -4585,6 +4602,7 @@ class AttributeNode(ExprNode):
# may be mutated in a namenode now :) # may be mutated in a namenode now :)
if self.is_attribute: if self.is_attribute:
self.wrap_obj_in_nonecheck(env) self.wrap_obj_in_nonecheck(env)
return self
def analyse_as_cimported_attribute(self, env, target): def analyse_as_cimported_attribute(self, env, target):
# Try to interpret this as a reference to an imported # Try to interpret this as a reference to an imported
...@@ -4597,7 +4615,7 @@ class AttributeNode(ExprNode): ...@@ -4597,7 +4615,7 @@ class AttributeNode(ExprNode):
if entry and ( if entry and (
entry.is_cglobal or entry.is_cfunction entry.is_cglobal or entry.is_cfunction
or entry.is_type or entry.is_const): or entry.is_type or entry.is_const):
self.mutate_into_name_node(env, entry, target) self.mutate_into_name_node(env, entry, target) # FIXME
entry.used = 1 entry.used = 1
return 1 return 1
return 0 return 0
...@@ -4619,7 +4637,7 @@ class AttributeNode(ExprNode): ...@@ -4619,7 +4637,7 @@ class AttributeNode(ExprNode):
ubcm_entry.is_cfunction = 1 ubcm_entry.is_cfunction = 1
ubcm_entry.func_cname = entry.func_cname ubcm_entry.func_cname = entry.func_cname
ubcm_entry.is_unbound_cmethod = 1 ubcm_entry.is_unbound_cmethod = 1
self.mutate_into_name_node(env, ubcm_entry, None) self.mutate_into_name_node(env, ubcm_entry, None) # FIXME
return 1 return 1
return 0 return 0
...@@ -4662,12 +4680,12 @@ class AttributeNode(ExprNode): ...@@ -4662,12 +4680,12 @@ class AttributeNode(ExprNode):
del self.obj del self.obj
del self.attribute del self.attribute
if target: if target:
NameNode.analyse_target_types(self, env) NameNode.analyse_target_types(self, env) # FIXME
else: else:
NameNode.analyse_rvalue_entry(self, env) NameNode.analyse_rvalue_entry(self, env)
def analyse_as_ordinary_attribute(self, env, target): def analyse_as_ordinary_attribute(self, env, target):
self.obj.analyse_types(env) self.obj = self.obj.analyse_types(env)
self.analyse_attribute(env) self.analyse_attribute(env)
if self.entry and self.entry.is_cmethod and not self.is_called: if self.entry and self.entry.is_cmethod and not self.is_called:
# error(self.pos, "C method can only be called") # error(self.pos, "C method can only be called")
...@@ -4988,15 +5006,17 @@ class StarredTargetNode(ExprNode): ...@@ -4988,15 +5006,17 @@ class StarredTargetNode(ExprNode):
def analyse_types(self, env): def analyse_types(self, env):
error(self.pos, "can use starred expression only as assignment target") error(self.pos, "can use starred expression only as assignment target")
self.target.analyse_types(env) self.target = self.target.analyse_types(env)
self.type = self.target.type self.type = self.target.type
return self
def analyse_target_declaration(self, env): def analyse_target_declaration(self, env):
self.target.analyse_target_declaration(env) self.target.analyse_target_declaration(env)
def analyse_target_types(self, env): def analyse_target_types(self, env):
self.target.analyse_target_types(env) self.target = self.target.analyse_target_types(env)
self.type = self.target.type self.type = self.target.type
return self
def calculate_result_code(self): def calculate_result_code(self):
return "" return ""
...@@ -5045,14 +5065,15 @@ class SequenceNode(ExprNode): ...@@ -5045,14 +5065,15 @@ class SequenceNode(ExprNode):
def analyse_types(self, env, skip_children=False): def analyse_types(self, env, skip_children=False):
for i in range(len(self.args)): for i in range(len(self.args)):
arg = self.args[i] arg = self.args[i]
if not skip_children: arg.analyse_types(env) if not skip_children: arg = arg.analyse_types(env)
self.args[i] = arg.coerce_to_pyobject(env) self.args[i] = arg.coerce_to_pyobject(env)
if self.mult_factor: if self.mult_factor:
self.mult_factor.analyse_types(env) self.mult_factor = self.mult_factor.analyse_types(env)
if not self.mult_factor.type.is_int: if not self.mult_factor.type.is_int:
self.mult_factor = self.mult_factor.coerce_to_pyobject(env) self.mult_factor = self.mult_factor.coerce_to_pyobject(env)
self.is_temp = 1 self.is_temp = 1
# not setting self.type here, subtypes do this # not setting self.type here, subtypes do this
return self
def may_be_none(self): def may_be_none(self):
return False return False
...@@ -5063,8 +5084,8 @@ class SequenceNode(ExprNode): ...@@ -5063,8 +5084,8 @@ class SequenceNode(ExprNode):
self.unpacked_items = [] self.unpacked_items = []
self.coerced_unpacked_items = [] self.coerced_unpacked_items = []
self.any_coerced_items = False self.any_coerced_items = False
for arg in self.args: for i, arg in enumerate(self.args):
arg.analyse_target_types(env) arg = self.args[i] = arg.analyse_target_types(env)
if arg.is_starred: if arg.is_starred:
if not arg.type.assignable_from(Builtin.list_type): if not arg.type.assignable_from(Builtin.list_type):
error(arg.pos, error(arg.pos,
...@@ -5078,6 +5099,7 @@ class SequenceNode(ExprNode): ...@@ -5078,6 +5099,7 @@ class SequenceNode(ExprNode):
self.unpacked_items.append(unpacked_item) self.unpacked_items.append(unpacked_item)
self.coerced_unpacked_items.append(coerced_unpacked_item) self.coerced_unpacked_items.append(coerced_unpacked_item)
self.type = py_object_type self.type = py_object_type
return self
def generate_result_code(self, code): def generate_result_code(self, code):
self.generate_operation_code(code) self.generate_operation_code(code)
...@@ -5468,23 +5490,25 @@ class TupleNode(SequenceNode): ...@@ -5468,23 +5490,25 @@ class TupleNode(SequenceNode):
def analyse_types(self, env, skip_children=False): def analyse_types(self, env, skip_children=False):
if len(self.args) == 0: if len(self.args) == 0:
self.is_temp = False node = self
self.is_literal = True node.is_temp = False
node.is_literal = True
else: else:
SequenceNode.analyse_types(self, env, skip_children) node = SequenceNode.analyse_types(self, env, skip_children)
for child in self.args: for child in node.args:
if not child.is_literal: if not child.is_literal:
break break
else: else:
if not self.mult_factor or self.mult_factor.is_literal and \ if not node.mult_factor or node.mult_factor.is_literal and \
isinstance(self.mult_factor.constant_result, (int, long)): isinstance(node.mult_factor.constant_result, (int, long)):
self.is_temp = False node.is_temp = False
self.is_literal = True node.is_literal = True
else: else:
if not self.mult_factor.type.is_pyobject: if not node.mult_factor.type.is_pyobject:
self.mult_factor = self.mult_factor.coerce_to_pyobject(env) node.mult_factor = node.mult_factor.coerce_to_pyobject(env)
self.is_temp = True node.is_temp = True
self.is_partly_literal = True node.is_partly_literal = True
return node
def is_simple(self): def is_simple(self):
# either temp or constant => always simple # either temp or constant => always simple
...@@ -5558,15 +5582,16 @@ class ListNode(SequenceNode): ...@@ -5558,15 +5582,16 @@ class ListNode(SequenceNode):
return list_type return list_type
def analyse_expressions(self, env): def analyse_expressions(self, env):
SequenceNode.analyse_expressions(self, env) node = SequenceNode.analyse_expressions(self, env)
self.coerce_to_pyobject(env) return node.coerce_to_pyobject(env)
def analyse_types(self, env): def analyse_types(self, env):
hold_errors() hold_errors()
self.original_args = list(self.args) self.original_args = list(self.args)
SequenceNode.analyse_types(self, env) node = SequenceNode.analyse_types(self, env)
self.obj_conversion_errors = held_errors() node.obj_conversion_errors = held_errors()
release_errors(ignore=True) release_errors(ignore=True)
return node
def coerce_to(self, dst_type, env): def coerce_to(self, dst_type, env):
if dst_type.is_pyobject: if dst_type.is_pyobject:
...@@ -5680,11 +5705,11 @@ class ScopedExprNode(ExprNode): ...@@ -5680,11 +5705,11 @@ class ScopedExprNode(ExprNode):
def analyse_types(self, env): def analyse_types(self, env):
# no recursion here, the children will be analysed separately below # no recursion here, the children will be analysed separately below
pass return self
def analyse_scoped_expressions(self, env): def analyse_scoped_expressions(self, env):
# this is called with the expr_scope as env # this is called with the expr_scope as env
pass return self
def generate_evaluation_code(self, code): def generate_evaluation_code(self, code):
# set up local variables and free their references on exit # set up local variables and free their references on exit
...@@ -5749,14 +5774,16 @@ class ComprehensionNode(ScopedExprNode): ...@@ -5749,14 +5774,16 @@ class ComprehensionNode(ScopedExprNode):
self.loop.analyse_declarations(env) self.loop.analyse_declarations(env)
def analyse_types(self, env): def analyse_types(self, env):
self.target.analyse_expressions(env) self.target = self.target.analyse_expressions(env)
self.type = self.target.type self.type = self.target.type
if not self.has_local_scope: if not self.has_local_scope:
self.loop.analyse_expressions(env) self.loop = self.loop.analyse_expressions(env)
return self
def analyse_scoped_expressions(self, env): def analyse_scoped_expressions(self, env):
if self.has_local_scope: if self.has_local_scope:
self.loop.analyse_expressions(env) self.loop = self.loop.analyse_expressions(env)
return self
def may_be_none(self): def may_be_none(self):
return False return False
...@@ -5783,9 +5810,10 @@ class ComprehensionAppendNode(Node): ...@@ -5783,9 +5810,10 @@ class ComprehensionAppendNode(Node):
type = PyrexTypes.c_int_type type = PyrexTypes.c_int_type
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.expr.analyse_expressions(env) self.expr = self.expr.analyse_expressions(env)
if not self.expr.type.is_pyobject: if not self.expr.type.is_pyobject:
self.expr = self.expr.coerce_to_pyobject(env) self.expr = self.expr.coerce_to_pyobject(env)
return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
if self.target.type is list_type: if self.target.type is list_type:
...@@ -5816,12 +5844,13 @@ class DictComprehensionAppendNode(ComprehensionAppendNode): ...@@ -5816,12 +5844,13 @@ class DictComprehensionAppendNode(ComprehensionAppendNode):
child_attrs = ['key_expr', 'value_expr'] child_attrs = ['key_expr', 'value_expr']
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.key_expr.analyse_expressions(env) self.key_expr = self.key_expr.analyse_expressions(env)
if not self.key_expr.type.is_pyobject: if not self.key_expr.type.is_pyobject:
self.key_expr = self.key_expr.coerce_to_pyobject(env) self.key_expr = self.key_expr.coerce_to_pyobject(env)
self.value_expr.analyse_expressions(env) self.value_expr = self.value_expr.analyse_expressions(env)
if not self.value_expr.type.is_pyobject: if not self.value_expr.type.is_pyobject:
self.value_expr = self.value_expr.coerce_to_pyobject(env) self.value_expr = self.value_expr.coerce_to_pyobject(env)
return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
self.key_expr.generate_evaluation_code(code) self.key_expr.generate_evaluation_code(code)
...@@ -5874,14 +5903,16 @@ class InlinedGeneratorExpressionNode(ScopedExprNode): ...@@ -5874,14 +5903,16 @@ class InlinedGeneratorExpressionNode(ScopedExprNode):
def analyse_types(self, env): def analyse_types(self, env):
if not self.has_local_scope: if not self.has_local_scope:
self.loop_analysed = True self.loop_analysed = True
self.loop.analyse_expressions(env) self.loop = self.loop.analyse_expressions(env)
self.type = self.result_node.type self.type = self.result_node.type
self.is_temp = True self.is_temp = True
return self
def analyse_scoped_expressions(self, env): def analyse_scoped_expressions(self, env):
self.loop_analysed = True self.loop_analysed = True
if self.has_local_scope: if self.has_local_scope:
self.loop.analyse_expressions(env) self.loop = self.loop.analyse_expressions(env)
return self
def coerce_to(self, dst_type, env): def coerce_to(self, dst_type, env):
if self.orig_func == 'sum' and dst_type.is_numeric and not self.loop_analysed: if self.orig_func == 'sum' and dst_type.is_numeric and not self.loop_analysed:
...@@ -5911,10 +5942,11 @@ class SetNode(ExprNode): ...@@ -5911,10 +5942,11 @@ class SetNode(ExprNode):
def analyse_types(self, env): def analyse_types(self, env):
for i in range(len(self.args)): for i in range(len(self.args)):
arg = self.args[i] arg = self.args[i]
arg.analyse_types(env) arg = arg.analyse_types(env)
self.args[i] = arg.coerce_to_pyobject(env) self.args[i] = arg.coerce_to_pyobject(env)
self.type = set_type self.type = set_type
self.is_temp = 1 self.is_temp = 1
return self
def may_be_none(self): def may_be_none(self):
return False return False
...@@ -5989,10 +6021,11 @@ class DictNode(ExprNode): ...@@ -5989,10 +6021,11 @@ class DictNode(ExprNode):
def analyse_types(self, env): def analyse_types(self, env):
hold_errors() hold_errors()
for item in self.key_value_pairs: self.key_value_pairs = [ item.analyse_types(env)
item.analyse_types(env) for item in self.key_value_pairs ]
self.obj_conversion_errors = held_errors() self.obj_conversion_errors = held_errors()
release_errors(ignore=True) release_errors(ignore=True)
return self
def may_be_none(self): def may_be_none(self):
return False return False
...@@ -6086,10 +6119,11 @@ class DictItemNode(ExprNode): ...@@ -6086,10 +6119,11 @@ class DictItemNode(ExprNode):
self.key.constant_result, self.value.constant_result) self.key.constant_result, self.value.constant_result)
def analyse_types(self, env): def analyse_types(self, env):
self.key.analyse_types(env) self.key = self.key.analyse_types(env)
self.value.analyse_types(env) self.value = self.value.analyse_types(env)
self.key = self.key.coerce_to_pyobject(env) self.key = self.key.coerce_to_pyobject(env)
self.value = self.value.coerce_to_pyobject(env) self.value = self.value.coerce_to_pyobject(env)
return self
def generate_evaluation_code(self, code): def generate_evaluation_code(self, code):
self.key.generate_evaluation_code(code) self.key.generate_evaluation_code(code)
...@@ -6141,15 +6175,16 @@ class ClassNode(ExprNode, ModuleNameMixin): ...@@ -6141,15 +6175,16 @@ class ClassNode(ExprNode, ModuleNameMixin):
subexprs = ['bases', 'doc'] subexprs = ['bases', 'doc']
def analyse_types(self, env): def analyse_types(self, env):
self.bases.analyse_types(env) self.bases = self.bases.analyse_types(env)
if self.doc: if self.doc:
self.doc.analyse_types(env) self.doc = self.doc.analyse_types(env)
self.doc = self.doc.coerce_to_pyobject(env) self.doc = self.doc.coerce_to_pyobject(env)
self.type = py_object_type self.type = py_object_type
self.is_temp = 1 self.is_temp = 1
env.use_utility_code(UtilityCode.load_cached("CreateClass", "ObjectHandling.c")) env.use_utility_code(UtilityCode.load_cached("CreateClass", "ObjectHandling.c"))
#TODO(craig,haoyu) This should be moved to a better place #TODO(craig,haoyu) This should be moved to a better place
self.set_qualified_name(env, self.name) self.set_qualified_name(env, self.name)
return self
def may_be_none(self): def may_be_none(self):
return True return True
...@@ -6192,6 +6227,7 @@ class Py3ClassNode(ExprNode): ...@@ -6192,6 +6227,7 @@ class Py3ClassNode(ExprNode):
def analyse_types(self, env): def analyse_types(self, env):
self.type = py_object_type self.type = py_object_type
self.is_temp = 1 self.is_temp = 1
return self
def may_be_none(self): def may_be_none(self):
return True return True
...@@ -6252,12 +6288,14 @@ class KeywordArgsNode(ExprNode): ...@@ -6252,12 +6288,14 @@ class KeywordArgsNode(ExprNode):
return dict_type return dict_type
def analyse_types(self, env): def analyse_types(self, env):
self.starstar_arg.analyse_types(env) arg = self.starstar_arg.analyse_types(env)
self.starstar_arg = self.starstar_arg.coerce_to_pyobject(env).as_none_safe_node( arg = arg.coerce_to_pyobject(env)
self.starstar_arg = arg.as_none_safe_node(
# FIXME: CPython's error message starts with the runtime function name # FIXME: CPython's error message starts with the runtime function name
'argument after ** must be a mapping, not NoneType') 'argument after ** must be a mapping, not NoneType')
for item in self.keyword_args: self.keyword_args = [ item.analyse_types(env)
item.analyse_types(env) for item in self.keyword_args ]
return self
def may_be_none(self): def may_be_none(self):
return False return False
...@@ -6336,6 +6374,7 @@ class PyClassMetaclassNode(ExprNode): ...@@ -6336,6 +6374,7 @@ class PyClassMetaclassNode(ExprNode):
def analyse_types(self, env): def analyse_types(self, env):
self.type = py_object_type self.type = py_object_type
self.is_temp = True self.is_temp = True
return self
def may_be_none(self): def may_be_none(self):
return True return True
...@@ -6362,14 +6401,15 @@ class PyClassNamespaceNode(ExprNode, ModuleNameMixin): ...@@ -6362,14 +6401,15 @@ class PyClassNamespaceNode(ExprNode, ModuleNameMixin):
subexprs = ['doc'] subexprs = ['doc']
def analyse_types(self, env): def analyse_types(self, env):
self.bases.analyse_types(env) self.bases = self.bases.analyse_types(env)
if self.doc: if self.doc:
self.doc.analyse_types(env) self.doc = self.doc.analyse_types(env)
self.doc = self.doc.coerce_to_pyobject(env) self.doc = self.doc.coerce_to_pyobject(env)
self.type = py_object_type self.type = py_object_type
self.is_temp = 1 self.is_temp = 1
#TODO(craig,haoyu) This should be moved to a better place #TODO(craig,haoyu) This should be moved to a better place
self.set_qualified_name(env, self.name) self.set_qualified_name(env, self.name)
return self
def may_be_none(self): def may_be_none(self):
return True return True
...@@ -6407,6 +6447,7 @@ class ClassCellInjectorNode(ExprNode): ...@@ -6407,6 +6447,7 @@ class ClassCellInjectorNode(ExprNode):
if self.is_active: if self.is_active:
env.use_utility_code( env.use_utility_code(
UtilityCode.load_cached("CyFunctionClassCell", "CythonFunction.c")) UtilityCode.load_cached("CyFunctionClassCell", "CythonFunction.c"))
return self
def generate_evaluation_code(self, code): def generate_evaluation_code(self, code):
if self.is_active: if self.is_active:
...@@ -6431,7 +6472,7 @@ class ClassCellNode(ExprNode): ...@@ -6431,7 +6472,7 @@ class ClassCellNode(ExprNode):
type = py_object_type type = py_object_type
def analyse_types(self, env): def analyse_types(self, env):
pass return self
def generate_result_code(self, code): def generate_result_code(self, code):
if not self.is_generator: if not self.is_generator:
...@@ -6460,9 +6501,10 @@ class BoundMethodNode(ExprNode): ...@@ -6460,9 +6501,10 @@ class BoundMethodNode(ExprNode):
subexprs = ['function'] subexprs = ['function']
def analyse_types(self, env): def analyse_types(self, env):
self.function.analyse_types(env) self.function = self.function.analyse_types(env)
self.type = py_object_type self.type = py_object_type
self.is_temp = 1 self.is_temp = 1
return self
gil_message = "Constructing an bound method" gil_message = "Constructing an bound method"
...@@ -6489,7 +6531,8 @@ class UnboundMethodNode(ExprNode): ...@@ -6489,7 +6531,8 @@ class UnboundMethodNode(ExprNode):
subexprs = ['function'] subexprs = ['function']
def analyse_types(self, env): def analyse_types(self, env):
self.function.analyse_types(env) self.function = self.function.analyse_types(env)
return self
def may_be_none(self): def may_be_none(self):
return False return False
...@@ -6557,6 +6600,7 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin): ...@@ -6557,6 +6600,7 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
#TODO(craig,haoyu) This should be moved to a better place #TODO(craig,haoyu) This should be moved to a better place
self.set_qualified_name(env, self.def_node.name) self.set_qualified_name(env, self.def_node.name)
return self
def analyse_default_args(self, env): def analyse_default_args(self, env):
""" """
...@@ -6602,9 +6646,9 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin): ...@@ -6602,9 +6646,9 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
if default_args: if default_args:
if self.defaults_struct is None: if self.defaults_struct is None:
self.defaults_tuple = TupleNode(self.pos, args=[ defaults_tuple = TupleNode(self.pos, args=[
arg.default for arg in default_args]) arg.default for arg in default_args])
self.defaults_tuple.analyse_types(env) self.defaults_tuple = defaults_tuple.analyse_types(env)
else: else:
defaults_getter = Nodes.DefNode( defaults_getter = Nodes.DefNode(
self.pos, args=[], star_arg=None, starstar_arg=None, self.pos, args=[], star_arg=None, starstar_arg=None,
...@@ -6615,8 +6659,8 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin): ...@@ -6615,8 +6659,8 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
self.defaults_struct)), self.defaults_struct)),
decorators=None, name=StringEncoding.EncodedString("__defaults__")) decorators=None, name=StringEncoding.EncodedString("__defaults__"))
defaults_getter.analyse_declarations(env) defaults_getter.analyse_declarations(env)
defaults_getter.analyse_expressions(env) defaults_getter = defaults_getter.analyse_expressions(env)
defaults_getter.body.analyse_expressions( defaults_getter.body = defaults_getter.body.analyse_expressions(
defaults_getter.local_scope) defaults_getter.local_scope)
defaults_getter.py_wrapper_required = False defaults_getter.py_wrapper_required = False
defaults_getter.pymethdef_required = False defaults_getter.pymethdef_required = False
...@@ -6813,7 +6857,7 @@ class DefaultLiteralArgNode(ExprNode): ...@@ -6813,7 +6857,7 @@ class DefaultLiteralArgNode(ExprNode):
self.evaluated = False self.evaluated = False
def analyse_types(self, env): def analyse_types(self, env):
pass return self
def generate_result_code(self, code): def generate_result_code(self, code):
pass pass
...@@ -6840,6 +6884,7 @@ class DefaultNonLiteralArgNode(ExprNode): ...@@ -6840,6 +6884,7 @@ class DefaultNonLiteralArgNode(ExprNode):
def analyse_types(self, env): def analyse_types(self, env):
self.type = self.arg.type self.type = self.arg.type
self.is_temp = False self.is_temp = False
return self
def generate_result_code(self, code): def generate_result_code(self, code):
pass pass
...@@ -6887,8 +6932,8 @@ class LambdaNode(InnerFunctionNode): ...@@ -6887,8 +6932,8 @@ class LambdaNode(InnerFunctionNode):
env.add_lambda_def(self.def_node) env.add_lambda_def(self.def_node)
def analyse_types(self, env): def analyse_types(self, env):
self.def_node.analyse_expressions(env) self.def_node = self.def_node.analyse_expressions(env)
super(LambdaNode, self).analyse_types(env) return super(LambdaNode, self).analyse_types(env)
def generate_result_code(self, code): def generate_result_code(self, code):
self.def_node.generate_execution_code(code) self.def_node.generate_execution_code(code)
...@@ -6943,9 +6988,10 @@ class YieldExprNode(ExprNode): ...@@ -6943,9 +6988,10 @@ class YieldExprNode(ExprNode):
error(self.pos, "'yield' not supported here") error(self.pos, "'yield' not supported here")
self.is_temp = 1 self.is_temp = 1
if self.arg is not None: if self.arg is not None:
self.arg.analyse_types(env) self.arg = self.arg.analyse_types(env)
if not self.arg.type.is_pyobject: if not self.arg.type.is_pyobject:
self.coerce_yield_argument(env) self.coerce_yield_argument(env)
return self
def coerce_yield_argument(self, env): def coerce_yield_argument(self, env):
self.arg = self.arg.coerce_to_pyobject(env) self.arg = self.arg.coerce_to_pyobject(env)
...@@ -7051,6 +7097,7 @@ class GlobalsExprNode(AtomicExprNode): ...@@ -7051,6 +7097,7 @@ class GlobalsExprNode(AtomicExprNode):
def analyse_types(self, env): def analyse_types(self, env):
env.use_utility_code(Builtin.globals_utility_code) env.use_utility_code(Builtin.globals_utility_code)
return self
gil_message = "Constructing globals dict" gil_message = "Constructing globals dict"
...@@ -7063,13 +7110,14 @@ class GlobalsExprNode(AtomicExprNode): ...@@ -7063,13 +7110,14 @@ class GlobalsExprNode(AtomicExprNode):
class LocalsDictItemNode(DictItemNode): class LocalsDictItemNode(DictItemNode):
def analyse_types(self, env): def analyse_types(self, env):
self.key.analyse_types(env) self.key = self.key.analyse_types(env)
self.value.analyse_types(env) self.value = self.value.analyse_types(env)
self.key = self.key.coerce_to_pyobject(env) self.key = self.key.coerce_to_pyobject(env)
if self.value.type.can_coerce_to_pyobject(env): if self.value.type.can_coerce_to_pyobject(env):
self.value = self.value.coerce_to_pyobject(env) self.value = self.value.coerce_to_pyobject(env)
else: else:
self.value = None self.value = None
return self
class FuncLocalsExprNode(DictNode): class FuncLocalsExprNode(DictNode):
...@@ -7084,9 +7132,10 @@ class FuncLocalsExprNode(DictNode): ...@@ -7084,9 +7132,10 @@ class FuncLocalsExprNode(DictNode):
exclude_null_values=True) exclude_null_values=True)
def analyse_types(self, env): def analyse_types(self, env):
super(FuncLocalsExprNode, self).analyse_types(env) node = super(FuncLocalsExprNode, self).analyse_types(env)
self.key_value_pairs = [i for i in self.key_value_pairs node.key_value_pairs = [ i for i in node.key_value_pairs
if i.value is not None] if i.value is not None ]
return node
class PyClassLocalsExprNode(AtomicExprNode): class PyClassLocalsExprNode(AtomicExprNode):
...@@ -7097,6 +7146,7 @@ class PyClassLocalsExprNode(AtomicExprNode): ...@@ -7097,6 +7146,7 @@ class PyClassLocalsExprNode(AtomicExprNode):
def analyse_types(self, env): def analyse_types(self, env):
self.type = self.pyclass_dict.type self.type = self.pyclass_dict.type
self.is_temp = False self.is_temp = False
return self
def result(self): def result(self):
return self.pyclass_dict.result() return self.pyclass_dict.result()
...@@ -7172,7 +7222,7 @@ class UnopNode(ExprNode): ...@@ -7172,7 +7222,7 @@ class UnopNode(ExprNode):
return operand_type return operand_type
def analyse_types(self, env): def analyse_types(self, env):
self.operand.analyse_types(env) self.operand = self.operand.analyse_types(env)
if self.is_py_operation(): if self.is_py_operation():
self.coerce_operand_to_pyobject(env) self.coerce_operand_to_pyobject(env)
self.type = py_object_type self.type = py_object_type
...@@ -7181,6 +7231,7 @@ class UnopNode(ExprNode): ...@@ -7181,6 +7231,7 @@ class UnopNode(ExprNode):
self.analyse_cpp_operation(env) self.analyse_cpp_operation(env)
else: else:
self.analyse_c_operation(env) self.analyse_c_operation(env)
return self
def check_const(self): def check_const(self):
return self.operand.check_const() return self.operand.check_const()
...@@ -7251,7 +7302,7 @@ class NotNode(UnopNode): ...@@ -7251,7 +7302,7 @@ class NotNode(UnopNode):
return PyrexTypes.c_bint_type return PyrexTypes.c_bint_type
def analyse_types(self, env): def analyse_types(self, env):
self.operand.analyse_types(env) self.operand = self.operand.analyse_types(env)
operand_type = self.operand.type operand_type = self.operand.type
if operand_type.is_cpp_class: if operand_type.is_cpp_class:
cpp_type = operand_type.find_cpp_operation_type(self.operator) cpp_type = operand_type.find_cpp_operation_type(self.operator)
...@@ -7262,6 +7313,7 @@ class NotNode(UnopNode): ...@@ -7262,6 +7313,7 @@ class NotNode(UnopNode):
self.type = cpp_type self.type = cpp_type
else: else:
self.operand = self.operand.coerce_to_boolean(env) self.operand = self.operand.coerce_to_boolean(env)
return self
def calculate_result_code(self): def calculate_result_code(self):
return "(!%s)" % self.operand.result() return "(!%s)" % self.operand.result()
...@@ -7396,23 +7448,24 @@ class AmpersandNode(CUnopNode): ...@@ -7396,23 +7448,24 @@ class AmpersandNode(CUnopNode):
return PyrexTypes.c_ptr_type(operand_type) return PyrexTypes.c_ptr_type(operand_type)
def analyse_types(self, env): def analyse_types(self, env):
self.operand.analyse_types(env) self.operand = self.operand.analyse_types(env)
argtype = self.operand.type argtype = self.operand.type
if argtype.is_cpp_class: if argtype.is_cpp_class:
cpp_type = argtype.find_cpp_operation_type(self.operator) cpp_type = argtype.find_cpp_operation_type(self.operator)
if cpp_type is not None: if cpp_type is not None:
self.type = cpp_type self.type = cpp_type
return return self
if not (argtype.is_cfunction or argtype.is_reference or self.operand.is_addressable()): if not (argtype.is_cfunction or argtype.is_reference or self.operand.is_addressable()):
if argtype.is_memoryviewslice: if argtype.is_memoryviewslice:
self.error("Cannot take address of memoryview slice") self.error("Cannot take address of memoryview slice")
else: else:
self.error("Taking address of non-lvalue") self.error("Taking address of non-lvalue")
return return self
if argtype.is_pyobject: if argtype.is_pyobject:
self.error("Cannot take address of Python variable") self.error("Cannot take address of Python variable")
return return self
self.type = PyrexTypes.c_ptr_type(argtype) self.type = PyrexTypes.c_ptr_type(argtype)
return self
def check_const(self): def check_const(self):
return self.operand.check_const_addr() return self.operand.check_const_addr()
...@@ -7481,7 +7534,7 @@ class TypecastNode(ExprNode): ...@@ -7481,7 +7534,7 @@ class TypecastNode(ExprNode):
error(self.pos, error(self.pos,
"Cannot cast to a function type") "Cannot cast to a function type")
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
self.operand.analyse_types(env) self.operand = self.operand.analyse_types(env)
to_py = self.type.is_pyobject to_py = self.type.is_pyobject
from_py = self.operand.type.is_pyobject from_py = self.operand.type.is_pyobject
if from_py and not to_py and self.operand.is_ephemeral(): if from_py and not to_py and self.operand.is_ephemeral():
...@@ -7520,6 +7573,7 @@ class TypecastNode(ExprNode): ...@@ -7520,6 +7573,7 @@ class TypecastNode(ExprNode):
elif self.operand.type.is_fused: elif self.operand.type.is_fused:
self.operand = self.operand.coerce_to(self.type, env) self.operand = self.operand.coerce_to(self.type, env)
#self.type = self.operand.type #self.type = self.operand.type
return self
def is_simple(self): def is_simple(self):
# either temp or a C cast => no side effects other than the operand's # either temp or a C cast => no side effects other than the operand's
...@@ -7617,7 +7671,7 @@ class CythonArrayNode(ExprNode): ...@@ -7617,7 +7671,7 @@ class CythonArrayNode(ExprNode):
def analyse_types(self, env): def analyse_types(self, env):
import MemoryView import MemoryView
self.operand.analyse_types(env) self.operand = self.operand.analyse_types(env)
if self.array_dtype: if self.array_dtype:
array_dtype = self.array_dtype array_dtype = self.array_dtype
else: else:
...@@ -7635,7 +7689,7 @@ class CythonArrayNode(ExprNode): ...@@ -7635,7 +7689,7 @@ class CythonArrayNode(ExprNode):
if not self.operand.type.is_ptr and not self.operand.type.is_array: if not self.operand.type.is_ptr and not self.operand.type.is_array:
error(self.operand.pos, ERR_NOT_POINTER) error(self.operand.pos, ERR_NOT_POINTER)
return return self
# Dimension sizes of C array # Dimension sizes of C array
array_dimension_sizes = [] array_dimension_sizes = []
...@@ -7647,16 +7701,16 @@ class CythonArrayNode(ExprNode): ...@@ -7647,16 +7701,16 @@ class CythonArrayNode(ExprNode):
base_type = base_type.base_type base_type = base_type.base_type
else: else:
error(self.pos, "unexpected base type %s found" % base_type) error(self.pos, "unexpected base type %s found" % base_type)
return return self
if not (base_type.same_as(array_dtype) or base_type.is_void): if not (base_type.same_as(array_dtype) or base_type.is_void):
error(self.operand.pos, ERR_BASE_TYPE) error(self.operand.pos, ERR_BASE_TYPE)
return return self
elif self.operand.type.is_array and len(array_dimension_sizes) != ndim: elif self.operand.type.is_array and len(array_dimension_sizes) != ndim:
error(self.operand.pos, error(self.operand.pos,
"Expected %d dimensions, array has %d dimensions" % "Expected %d dimensions, array has %d dimensions" %
(ndim, len(array_dimension_sizes))) (ndim, len(array_dimension_sizes)))
return return self
# Verify the start, stop and step values # Verify the start, stop and step values
# In case of a C array, use the size of C array in each dimension to # In case of a C array, use the size of C array in each dimension to
...@@ -7664,7 +7718,7 @@ class CythonArrayNode(ExprNode): ...@@ -7664,7 +7718,7 @@ class CythonArrayNode(ExprNode):
for axis_no, axis in enumerate(axes): for axis_no, axis in enumerate(axes):
if not axis.start.is_none: if not axis.start.is_none:
error(axis.start.pos, ERR_START) error(axis.start.pos, ERR_START)
return return self
if axis.stop.is_none: if axis.stop.is_none:
if array_dimension_sizes: if array_dimension_sizes:
...@@ -7674,9 +7728,9 @@ class CythonArrayNode(ExprNode): ...@@ -7674,9 +7728,9 @@ class CythonArrayNode(ExprNode):
type=PyrexTypes.c_int_type) type=PyrexTypes.c_int_type)
else: else:
error(axis.pos, ERR_NOT_STOP) error(axis.pos, ERR_NOT_STOP)
return return self
axis.stop.analyse_types(env) axis.stop = axis.stop.analyse_types(env)
shape = axis.stop.coerce_to(self.shape_type, env) shape = axis.stop.coerce_to(self.shape_type, env)
if not shape.is_literal: if not shape.is_literal:
shape.coerce_to_temp(env) shape.coerce_to_temp(env)
...@@ -7686,15 +7740,15 @@ class CythonArrayNode(ExprNode): ...@@ -7686,15 +7740,15 @@ class CythonArrayNode(ExprNode):
first_or_last = axis_no in (0, ndim - 1) first_or_last = axis_no in (0, ndim - 1)
if not axis.step.is_none and first_or_last: if not axis.step.is_none and first_or_last:
# '1' in the first or last dimension denotes F or C contiguity # '1' in the first or last dimension denotes F or C contiguity
axis.step.analyse_types(env) axis.step = axis.step.analyse_types(env)
if (not axis.step.type.is_int and axis.step.is_literal and not if (not axis.step.type.is_int and axis.step.is_literal and not
axis.step.type.is_error): axis.step.type.is_error):
error(axis.step.pos, "Expected an integer literal") error(axis.step.pos, "Expected an integer literal")
return return self
if axis.step.compile_time_value(env) != 1: if axis.step.compile_time_value(env) != 1:
error(axis.step.pos, ERR_STEPS) error(axis.step.pos, ERR_STEPS)
return return self
if axis_no == 0: if axis_no == 0:
self.mode = "fortran" self.mode = "fortran"
...@@ -7702,7 +7756,7 @@ class CythonArrayNode(ExprNode): ...@@ -7702,7 +7756,7 @@ class CythonArrayNode(ExprNode):
elif not axis.step.is_none and not first_or_last: elif not axis.step.is_none and not first_or_last:
# step provided in some other dimension # step provided in some other dimension
error(axis.step.pos, ERR_STEPS) error(axis.step.pos, ERR_STEPS)
return return self
if not self.operand.is_name: if not self.operand.is_name:
self.operand = self.operand.coerce_to_temp(env) self.operand = self.operand.coerce_to_temp(env)
...@@ -7717,6 +7771,7 @@ class CythonArrayNode(ExprNode): ...@@ -7717,6 +7771,7 @@ class CythonArrayNode(ExprNode):
self.type = self.get_cython_array_type(env) self.type = self.get_cython_array_type(env)
MemoryView.use_cython_array_utility_code(env) MemoryView.use_cython_array_utility_code(env)
env.use_utility_code(MemoryView.typeinfo_to_format_code) env.use_utility_code(MemoryView.typeinfo_to_format_code)
return self
def allocate_temp_result(self, code): def allocate_temp_result(self, code):
if self.temp_code: if self.temp_code:
...@@ -7800,7 +7855,7 @@ class CythonArrayNode(ExprNode): ...@@ -7800,7 +7855,7 @@ class CythonArrayNode(ExprNode):
base_type_node=base_type) base_type_node=base_type)
result = CythonArrayNode(pos, base_type_node=memslicenode, result = CythonArrayNode(pos, base_type_node=memslicenode,
operand=src_node, array_dtype=base_type) operand=src_node, array_dtype=base_type)
result.analyse_types(env) result = result.analyse_types(env)
return result return result
class SizeofNode(ExprNode): class SizeofNode(ExprNode):
...@@ -7837,13 +7892,14 @@ class SizeofTypeNode(SizeofNode): ...@@ -7837,13 +7892,14 @@ class SizeofTypeNode(SizeofNode):
operand = AttributeNode(pos=self.pos, obj=operand, attribute=self.base_type.name) operand = AttributeNode(pos=self.pos, obj=operand, attribute=self.base_type.name)
self.operand = operand self.operand = operand
self.__class__ = SizeofVarNode self.__class__ = SizeofVarNode
self.analyse_types(env) node = self.analyse_types(env)
return return node
if self.arg_type is None: if self.arg_type is None:
base_type = self.base_type.analyse(env) base_type = self.base_type.analyse(env)
_, arg_type = self.declarator.analyse(base_type, env) _, arg_type = self.declarator.analyse(base_type, env)
self.arg_type = arg_type self.arg_type = arg_type
self.check_type() self.check_type()
return self
def check_type(self): def check_type(self):
arg_type = self.arg_type arg_type = self.arg_type
...@@ -7882,7 +7938,8 @@ class SizeofVarNode(SizeofNode): ...@@ -7882,7 +7938,8 @@ class SizeofVarNode(SizeofNode):
self.__class__ = SizeofTypeNode self.__class__ = SizeofTypeNode
self.check_type() self.check_type()
else: else:
self.operand.analyse_types(env) self.operand = self.operand.analyse_types(env)
return self
def calculate_result_code(self): def calculate_result_code(self):
return "(sizeof(%s))" % self.operand.result() return "(sizeof(%s))" % self.operand.result()
...@@ -7902,11 +7959,12 @@ class TypeofNode(ExprNode): ...@@ -7902,11 +7959,12 @@ class TypeofNode(ExprNode):
subexprs = ['literal'] # 'operand' will be ignored after type analysis! subexprs = ['literal'] # 'operand' will be ignored after type analysis!
def analyse_types(self, env): def analyse_types(self, env):
self.operand.analyse_types(env) self.operand = self.operand.analyse_types(env)
value = StringEncoding.EncodedString(str(self.operand.type)) #self.operand.type.typeof_name()) value = StringEncoding.EncodedString(str(self.operand.type)) #self.operand.type.typeof_name())
self.literal = StringNode(self.pos, value=value) literal = StringNode(self.pos, value=value)
self.literal.analyse_types(env) literal = literal.analyse_types(env)
self.literal = self.literal.coerce_to_pyobject(env) self.literal = literal.coerce_to_pyobject(env)
return self
def may_be_none(self): def may_be_none(self):
return False return False
...@@ -7995,9 +8053,10 @@ class BinopNode(ExprNode): ...@@ -7995,9 +8053,10 @@ class BinopNode(ExprNode):
self.operand2.infer_type(env)) self.operand2.infer_type(env))
def analyse_types(self, env): def analyse_types(self, env):
self.operand1.analyse_types(env) self.operand1 = self.operand1.analyse_types(env)
self.operand2.analyse_types(env) self.operand2 = self.operand2.analyse_types(env)
self.analyse_operation(env) self.analyse_operation(env)
return self
def analyse_operation(self, env): def analyse_operation(self, env):
if self.is_py_operation(): if self.is_py_operation():
...@@ -8109,9 +8168,10 @@ class BinopNode(ExprNode): ...@@ -8109,9 +8168,10 @@ class BinopNode(ExprNode):
class CBinopNode(BinopNode): class CBinopNode(BinopNode):
def analyse_types(self, env): def analyse_types(self, env):
BinopNode.analyse_types(self, env) node = BinopNode.analyse_types(self, env)
if self.is_py_operation(): if node.is_py_operation():
self.type = PyrexTypes.error_type node.type = PyrexTypes.error_type
return node
def py_operation_function(self): def py_operation_function(self):
return "" return ""
...@@ -8604,8 +8664,8 @@ class BoolBinopNode(ExprNode): ...@@ -8604,8 +8664,8 @@ class BoolBinopNode(ExprNode):
is_temp = self.is_temp) is_temp = self.is_temp)
def analyse_types(self, env): def analyse_types(self, env):
self.operand1.analyse_types(env) self.operand1 = self.operand1.analyse_types(env)
self.operand2.analyse_types(env) self.operand2 = self.operand2.analyse_types(env)
self.type = PyrexTypes.independent_spanning_type(self.operand1.type, self.operand2.type) self.type = PyrexTypes.independent_spanning_type(self.operand1.type, self.operand2.type)
self.operand1 = self.operand1.coerce_to(self.type, env) self.operand1 = self.operand1.coerce_to(self.type, env)
self.operand2 = self.operand2.coerce_to(self.type, env) self.operand2 = self.operand2.coerce_to(self.type, env)
...@@ -8615,6 +8675,7 @@ class BoolBinopNode(ExprNode): ...@@ -8615,6 +8675,7 @@ class BoolBinopNode(ExprNode):
self.operand1 = self.operand1.coerce_to_simple(env) self.operand1 = self.operand1.coerce_to_simple(env)
self.operand2 = self.operand2.coerce_to_simple(env) self.operand2 = self.operand2.coerce_to_simple(env)
self.is_temp = 1 self.is_temp = 1
return self
gil_message = "Truth-testing Python object" gil_message = "Truth-testing Python object"
...@@ -8690,10 +8751,9 @@ class CondExprNode(ExprNode): ...@@ -8690,10 +8751,9 @@ class CondExprNode(ExprNode):
self.constant_result = self.false_val.constant_result self.constant_result = self.false_val.constant_result
def analyse_types(self, env): def analyse_types(self, env):
self.test.analyse_types(env) self.test = self.test.analyse_types(env).coerce_to_boolean(env)
self.test = self.test.coerce_to_boolean(env) self.true_val = self.true_val.analyse_types(env)
self.true_val.analyse_types(env) self.false_val = self.false_val.analyse_types(env)
self.false_val.analyse_types(env)
self.type = PyrexTypes.independent_spanning_type(self.true_val.type, self.false_val.type) self.type = PyrexTypes.independent_spanning_type(self.true_val.type, self.false_val.type)
if self.true_val.type.is_pyobject or self.false_val.type.is_pyobject: if self.true_val.type.is_pyobject or self.false_val.type.is_pyobject:
self.true_val = self.true_val.coerce_to(self.type, env) self.true_val = self.true_val.coerce_to(self.type, env)
...@@ -8701,6 +8761,7 @@ class CondExprNode(ExprNode): ...@@ -8701,6 +8761,7 @@ class CondExprNode(ExprNode):
self.is_temp = 1 self.is_temp = 1
if self.type == PyrexTypes.error_type: if self.type == PyrexTypes.error_type:
self.type_error() self.type_error()
return self
def type_error(self): def type_error(self):
if not (self.true_val.type.is_error or self.false_val.type.is_error): if not (self.true_val.type.is_error or self.false_val.type.is_error):
...@@ -9083,19 +9144,19 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -9083,19 +9144,19 @@ class PrimaryCmpNode(ExprNode, CmpNode):
return self.cascaded_compile_time_value(operand1, denv) return self.cascaded_compile_time_value(operand1, denv)
def analyse_types(self, env): def analyse_types(self, env):
self.operand1.analyse_types(env) self.operand1 = self.operand1.analyse_types(env)
self.operand2.analyse_types(env) self.operand2 = self.operand2.analyse_types(env)
if self.is_cpp_comparison(): if self.is_cpp_comparison():
self.analyse_cpp_comparison(env) self.analyse_cpp_comparison(env)
if self.cascade: if self.cascade:
error(self.pos, "Cascading comparison not yet supported for cpp types.") error(self.pos, "Cascading comparison not yet supported for cpp types.")
return return self
if self.analyse_memoryviewslice_comparison(env): if self.analyse_memoryviewslice_comparison(env):
return return self
if self.cascade: if self.cascade:
self.cascade.analyse_types(env) self.cascade = self.cascade.analyse_types(env)
if self.operator in ('in', 'not_in'): if self.operator in ('in', 'not_in'):
if self.is_c_string_contains(): if self.is_c_string_contains():
...@@ -9103,7 +9164,7 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -9103,7 +9164,7 @@ class PrimaryCmpNode(ExprNode, CmpNode):
common_type = None common_type = None
if self.cascade: if self.cascade:
error(self.pos, "Cascading comparison not yet supported for 'int_val in string'.") error(self.pos, "Cascading comparison not yet supported for 'int_val in string'.")
return return self
if self.operand2.type is unicode_type: if self.operand2.type is unicode_type:
env.use_utility_code(UtilityCode.load_cached("PyUCS4InUnicode", "StringTools.c")) env.use_utility_code(UtilityCode.load_cached("PyUCS4InUnicode", "StringTools.c"))
else: else:
...@@ -9119,7 +9180,7 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -9119,7 +9180,7 @@ class PrimaryCmpNode(ExprNode, CmpNode):
error(self.pos, "Cascading comparison not supported for 'val in sliced pointer'.") error(self.pos, "Cascading comparison not supported for 'val in sliced pointer'.")
self.type = PyrexTypes.c_bint_type self.type = PyrexTypes.c_bint_type
# Will be transformed by IterationTransform # Will be transformed by IterationTransform
return return self
elif self.find_special_bool_compare_function(env, self.operand1): elif self.find_special_bool_compare_function(env, self.operand1):
if not self.operand1.type.is_pyobject: if not self.operand1.type.is_pyobject:
self.operand1 = self.operand1.coerce_to_pyobject(env) self.operand1 = self.operand1.coerce_to_pyobject(env)
...@@ -9159,6 +9220,7 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -9159,6 +9220,7 @@ class PrimaryCmpNode(ExprNode, CmpNode):
if self.is_pycmp or self.cascade or self.special_bool_cmp_function: if self.is_pycmp or self.cascade or self.special_bool_cmp_function:
# 1) owned reference, 2) reused value, 3) potential function error return value # 1) owned reference, 2) reused value, 3) potential function error return value
self.is_temp = 1 self.is_temp = 1
return self
def analyse_cpp_comparison(self, env): def analyse_cpp_comparison(self, env):
type1 = self.operand1.type type1 = self.operand1.type
...@@ -9306,9 +9368,10 @@ class CascadedCmpNode(Node, CmpNode): ...@@ -9306,9 +9368,10 @@ class CascadedCmpNode(Node, CmpNode):
self.constant_result is not not_a_constant self.constant_result is not not_a_constant
def analyse_types(self, env): def analyse_types(self, env):
self.operand2.analyse_types(env) self.operand2 = self.operand2.analyse_types(env)
if self.cascade: if self.cascade:
self.cascade.analyse_types(env) self.cascade = self.cascade.analyse_types(env)
return self
def has_python_operands(self): def has_python_operands(self):
return self.operand2.type.is_pyobject return self.operand2.type.is_pyobject
...@@ -9487,7 +9550,7 @@ class PyTypeTestNode(CoercionNode): ...@@ -9487,7 +9550,7 @@ class PyTypeTestNode(CoercionNode):
gil_message = "Python type test" gil_message = "Python type test"
def analyse_types(self, env): def analyse_types(self, env):
pass return self
def may_be_none(self): def may_be_none(self):
if self.notnone: if self.notnone:
...@@ -9551,7 +9614,7 @@ class NoneCheckNode(CoercionNode): ...@@ -9551,7 +9614,7 @@ class NoneCheckNode(CoercionNode):
nogil_check = None # this node only guards an operation that would fail already nogil_check = None # this node only guards an operation that would fail already
def analyse_types(self, env): def analyse_types(self, env):
pass return self
def may_be_none(self): def may_be_none(self):
return False return False
...@@ -9669,7 +9732,7 @@ class CoerceToPyTypeNode(CoercionNode): ...@@ -9669,7 +9732,7 @@ class CoerceToPyTypeNode(CoercionNode):
def analyse_types(self, env): def analyse_types(self, env):
# The arg is always already analysed # The arg is always already analysed
pass return self
def generate_result_code(self, code): def generate_result_code(self, code):
if self.arg.type.is_memoryviewslice: if self.arg.type.is_memoryviewslice:
...@@ -9748,7 +9811,7 @@ class CoerceFromPyTypeNode(CoercionNode): ...@@ -9748,7 +9811,7 @@ class CoerceFromPyTypeNode(CoercionNode):
def analyse_types(self, env): def analyse_types(self, env):
# The arg is always already analysed # The arg is always already analysed
pass return self
def generate_result_code(self, code): def generate_result_code(self, code):
function = self.type.from_py_function function = self.type.from_py_function
...@@ -9858,7 +9921,7 @@ class CoerceToTempNode(CoercionNode): ...@@ -9858,7 +9921,7 @@ class CoerceToTempNode(CoercionNode):
def analyse_types(self, env): def analyse_types(self, env):
# The arg is always already analysed # The arg is always already analysed
pass return self
def coerce_to_boolean(self, env): def coerce_to_boolean(self, env):
self.arg = self.arg.coerce_to_boolean(env) self.arg = self.arg.coerce_to_boolean(env)
...@@ -9896,8 +9959,9 @@ class ProxyNode(CoercionNode): ...@@ -9896,8 +9959,9 @@ class ProxyNode(CoercionNode):
self._proxy_type() self._proxy_type()
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.arg.analyse_expressions(env) self.arg = self.arg.analyse_expressions(env)
self._proxy_type() self._proxy_type()
return self
def _proxy_type(self): def _proxy_type(self):
if hasattr(self.arg, 'type'): if hasattr(self.arg, 'type'):
...@@ -9967,6 +10031,7 @@ class CloneNode(CoercionNode): ...@@ -9967,6 +10031,7 @@ class CloneNode(CoercionNode):
self.is_temp = 1 self.is_temp = 1
if hasattr(self.arg, 'entry'): if hasattr(self.arg, 'entry'):
self.entry = self.arg.entry self.entry = self.arg.entry
return self
def is_simple(self): def is_simple(self):
return True # result is always in a temp (or a name) return True # result is always in a temp (or a name)
...@@ -10004,7 +10069,7 @@ class ModuleRefNode(ExprNode): ...@@ -10004,7 +10069,7 @@ class ModuleRefNode(ExprNode):
subexprs = [] subexprs = []
def analyse_types(self, env): def analyse_types(self, env):
pass return self
def may_be_none(self): def may_be_none(self):
return False return False
...@@ -10028,7 +10093,7 @@ class DocstringRefNode(ExprNode): ...@@ -10028,7 +10093,7 @@ class DocstringRefNode(ExprNode):
self.body = body self.body = body
def analyse_types(self, env): def analyse_types(self, env):
pass return self
def generate_result_code(self, code): def generate_result_code(self, code):
code.putln('%s = __Pyx_GetAttr(%s, %s); %s' % ( code.putln('%s = __Pyx_GetAttr(%s, %s); %s' % (
......
...@@ -673,22 +673,22 @@ class FusedCFuncDefNode(StatListNode): ...@@ -673,22 +673,22 @@ class FusedCFuncDefNode(StatListNode):
specialization_type.create_declaration_utility_code(env) specialization_type.create_declaration_utility_code(env)
if self.py_func: if self.py_func:
self.__signatures__.analyse_expressions(env) self.__signatures__ = self.__signatures__.analyse_expressions(env)
self.py_func.analyse_expressions(env) self.py_func = self.py_func.analyse_expressions(env)
self.resulting_fused_function.analyse_expressions(env) self.resulting_fused_function = self.resulting_fused_function.analyse_expressions(env)
self.fused_func_assignment.analyse_expressions(env) self.fused_func_assignment = self.fused_func_assignment.analyse_expressions(env)
self.defaults = defaults = [] self.defaults = defaults = []
for arg in self.node.args: for arg in self.node.args:
if arg.default: if arg.default:
arg.default.analyse_expressions(env) arg.default = arg.default.analyse_expressions(env)
defaults.append(ProxyNode(arg.default)) defaults.append(ProxyNode(arg.default))
else: else:
defaults.append(None) defaults.append(None)
for stat in self.stats: for i, stat in enumerate(self.stats):
stat.analyse_expressions(env) stat = self.stats[i] = stat.analyse_expressions(env)
if isinstance(stat, FuncDefNode): if isinstance(stat, FuncDefNode):
for arg, default in zip(stat.args, defaults): for arg, default in zip(stat.args, defaults):
if default is not None: if default is not None:
...@@ -697,7 +697,7 @@ class FusedCFuncDefNode(StatListNode): ...@@ -697,7 +697,7 @@ class FusedCFuncDefNode(StatListNode):
if self.py_func: if self.py_func:
args = [CloneNode(default) for default in defaults if default] args = [CloneNode(default) for default in defaults if default]
self.defaults_tuple = TupleNode(self.pos, args=args) self.defaults_tuple = TupleNode(self.pos, args=args)
self.defaults_tuple.analyse_types(env, skip_children=True) self.defaults_tuple = self.defaults_tuple.analyse_types(env, skip_children=True)
self.defaults_tuple = ProxyNode(self.defaults_tuple) self.defaults_tuple = ProxyNode(self.defaults_tuple)
self.code_object = ProxyNode(self.specialized_pycfuncs[0].code_object) self.code_object = ProxyNode(self.specialized_pycfuncs[0].code_object)
...@@ -705,10 +705,11 @@ class FusedCFuncDefNode(StatListNode): ...@@ -705,10 +705,11 @@ class FusedCFuncDefNode(StatListNode):
fused_func.defaults_tuple = CloneNode(self.defaults_tuple) fused_func.defaults_tuple = CloneNode(self.defaults_tuple)
fused_func.code_object = CloneNode(self.code_object) fused_func.code_object = CloneNode(self.code_object)
for pycfunc in self.specialized_pycfuncs: for i, pycfunc in enumerate(self.specialized_pycfuncs):
pycfunc.code_object = CloneNode(self.code_object) pycfunc.code_object = CloneNode(self.code_object)
pycfunc.analyse_types(env) pycfunc = self.specialized_pycfuncs[i] = pycfunc.analyse_types(env)
pycfunc.defaults_tuple = CloneNode(self.defaults_tuple) pycfunc.defaults_tuple = CloneNode(self.defaults_tuple)
return self
def synthesize_defnodes(self): def synthesize_defnodes(self):
""" """
......
...@@ -75,12 +75,9 @@ def embed_position(pos, docstring): ...@@ -75,12 +75,9 @@ def embed_position(pos, docstring):
return doc return doc
from Code import CCodeWriter def write_func_call(func, codewriter_class):
from types import FunctionType
def write_func_call(func):
def f(*args, **kwds): def f(*args, **kwds):
if len(args) > 1 and isinstance(args[1], CCodeWriter): if len(args) > 1 and isinstance(args[1], codewriter_class):
# here we annotate the code with this function call # here we annotate the code with this function call
# but only if new code is generated # but only if new code is generated
node, code = args[:2] node, code = args[:2]
...@@ -109,18 +106,45 @@ class VerboseCodeWriter(type): ...@@ -109,18 +106,45 @@ class VerboseCodeWriter(type):
# Set this as a metaclass to trace function calls in code. # Set this as a metaclass to trace function calls in code.
# This slows down code generation and makes much larger files. # This slows down code generation and makes much larger files.
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
from types import FunctionType
from Code import CCodeWriter
attrs = dict(attrs) attrs = dict(attrs)
for mname, m in attrs.items(): for mname, m in attrs.items():
if isinstance(m, FunctionType): if isinstance(m, FunctionType):
attrs[mname] = write_func_call(m) attrs[mname] = write_func_call(m, CCodeWriter)
return super(VerboseCodeWriter, cls).__new__(cls, name, bases, attrs) return super(VerboseCodeWriter, cls).__new__(cls, name, bases, attrs)
class CheckAnalysers(type):
"""Metaclass to check that type analysis functions return a node.
"""
methods = set(['analyse_types',
'analyse_expressions',
'analyse_target_types'])
def __new__(cls, name, bases, attrs):
from types import FunctionType
def check(name, func):
def call(*args, **kwargs):
retval = func(*args, **kwargs)
if retval is None:
print name, args, kwargs
return retval
return call
attrs = dict(attrs)
for mname, m in attrs.items():
if isinstance(m, FunctionType) and mname in cls.methods:
attrs[mname] = check(mname, m)
return super(CheckAnalysers, cls).__new__(cls, name, bases, attrs)
class Node(object): class Node(object):
# pos (string, int, int) Source file position # pos (string, int, int) Source file position
# is_name boolean Is a NameNode # is_name boolean Is a NameNode
# is_literal boolean Is a ConstNode # is_literal boolean Is a ConstNode
#__metaclass__ = CheckAnalysers
if DebugFlags.debug_trace_code_generation: if DebugFlags.debug_trace_code_generation:
__metaclass__ = VerboseCodeWriter __metaclass__ = VerboseCodeWriter
...@@ -305,8 +329,9 @@ class CompilerDirectivesNode(Node): ...@@ -305,8 +329,9 @@ class CompilerDirectivesNode(Node):
def analyse_expressions(self, env): def analyse_expressions(self, env):
old = env.directives old = env.directives
env.directives = self.directives env.directives = self.directives
self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
env.directives = old env.directives = old
return self
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
env_old = env.directives env_old = env.directives
...@@ -358,8 +383,9 @@ class StatListNode(Node): ...@@ -358,8 +383,9 @@ class StatListNode(Node):
def analyse_expressions(self, env): def analyse_expressions(self, env):
#print "StatListNode.analyse_expressions" ### #print "StatListNode.analyse_expressions" ###
for stat in self.stats: self.stats = [ stat.analyse_expressions(env)
stat.analyse_expressions(env) for stat in self.stats ]
return self
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
#print "StatListNode.generate_function_definitions" ### #print "StatListNode.generate_function_definitions" ###
...@@ -413,7 +439,7 @@ class CDefExternNode(StatNode): ...@@ -413,7 +439,7 @@ class CDefExternNode(StatNode):
env.in_cinclude = old_cinclude_flag env.in_cinclude = old_cinclude_flag
def analyse_expressions(self, env): def analyse_expressions(self, env):
pass return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
pass pass
...@@ -604,7 +630,7 @@ class CFuncDeclaratorNode(CDeclaratorNode): ...@@ -604,7 +630,7 @@ class CFuncDeclaratorNode(CDeclaratorNode):
if self.exception_value: if self.exception_value:
self.exception_value.analyse_const_expression(env) self.exception_value.analyse_const_expression(env)
if self.exception_check == '+': if self.exception_check == '+':
self.exception_value.analyse_types(env) self.exception_value = self.exception_value.analyse_types(env)
exc_val_type = self.exception_value.type exc_val_type = self.exception_value.type
if not exc_val_type.is_error and \ if not exc_val_type.is_error and \
not exc_val_type.is_pyobject and \ not exc_val_type.is_pyobject and \
...@@ -1206,7 +1232,7 @@ class CStructOrUnionDefNode(StatNode): ...@@ -1206,7 +1232,7 @@ class CStructOrUnionDefNode(StatNode):
error(attr.pos, "Struct cannot contain itself as a member.") error(attr.pos, "Struct cannot contain itself as a member.")
def analyse_expressions(self, env): def analyse_expressions(self, env):
pass return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
pass pass
...@@ -1270,7 +1296,8 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode): ...@@ -1270,7 +1296,8 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode):
self.scope = scope self.scope = scope
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.body.analyse_expressions(self.entry.type.scope) self.body = self.body.analyse_expressions(self.entry.type.scope)
return self
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
self.body.generate_function_definitions(self.entry.type.scope, code) self.body.generate_function_definitions(self.entry.type.scope, code)
...@@ -1307,7 +1334,7 @@ class CEnumDefNode(StatNode): ...@@ -1307,7 +1334,7 @@ class CEnumDefNode(StatNode):
item.analyse_declarations(env, self.entry) item.analyse_declarations(env, self.entry)
def analyse_expressions(self, env): def analyse_expressions(self, env):
pass return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
if self.visibility == 'public' or self.api: if self.visibility == 'public' or self.api:
...@@ -1371,7 +1398,8 @@ class CTypeDefNode(StatNode): ...@@ -1371,7 +1398,8 @@ class CTypeDefNode(StatNode):
entry.defined_in_pxd = 1 entry.defined_in_pxd = 1
def analyse_expressions(self, env): def analyse_expressions(self, env):
pass return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
pass pass
...@@ -1413,7 +1441,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1413,7 +1441,7 @@ class FuncDefNode(StatNode, BlockNode):
if arg.default: if arg.default:
default_seen = 1 default_seen = 1
if arg.is_generic: if arg.is_generic:
arg.default.analyse_types(env) arg.default = arg.default.analyse_types(env)
arg.default = arg.default.coerce_to(arg.type, env) arg.default = arg.default.coerce_to(arg.type, env)
else: else:
error(arg.pos, error(arg.pos,
...@@ -2173,10 +2201,11 @@ class CFuncDefNode(FuncDefNode): ...@@ -2173,10 +2201,11 @@ class CFuncDefNode(FuncDefNode):
self.local_scope.directives = env.directives self.local_scope.directives = env.directives
if self.py_func is not None: if self.py_func is not None:
# this will also analyse the default values # this will also analyse the default values
self.py_func.analyse_expressions(env) self.py_func = self.py_func.analyse_expressions(env)
else: else:
self.analyse_default_values(env) self.analyse_default_values(env)
self.acquire_gil = self.need_gil_acquisition(self.local_scope) self.acquire_gil = self.need_gil_acquisition(self.local_scope)
return self
def needs_assignment_synthesis(self, env, code=None): def needs_assignment_synthesis(self, env, code=None):
return False return False
...@@ -2719,9 +2748,10 @@ class DefNode(FuncDefNode): ...@@ -2719,9 +2748,10 @@ class DefNode(FuncDefNode):
if not self.needs_assignment_synthesis(env) and self.decorators: if not self.needs_assignment_synthesis(env) and self.decorators:
for decorator in self.decorators[::-1]: for decorator in self.decorators[::-1]:
decorator.decorator.analyse_expressions(env) decorator.decorator = decorator.decorator.analyse_expressions(env)
self.py_wrapper.prepare_argument_coercion(env) self.py_wrapper.prepare_argument_coercion(env)
return self
def needs_assignment_synthesis(self, env, code=None): def needs_assignment_synthesis(self, env, code=None):
if self.is_wrapper or self.specialized_cpdefs or self.entry.is_fused_specialized: if self.is_wrapper or self.specialized_cpdefs or self.entry.is_fused_specialized:
...@@ -3802,9 +3832,11 @@ class OverrideCheckNode(StatNode): ...@@ -3802,9 +3832,11 @@ class OverrideCheckNode(StatNode):
self.func_node = ExprNodes.RawCNameExprNode(self.pos, py_object_type) self.func_node = ExprNodes.RawCNameExprNode(self.pos, py_object_type)
call_node = ExprNodes.SimpleCallNode( call_node = ExprNodes.SimpleCallNode(
self.pos, function=self.func_node, self.pos, function=self.func_node,
args=[ExprNodes.NameNode(self.pos, name=arg.name) for arg in self.args[first_arg:]]) args=[ ExprNodes.NameNode(self.pos, name=arg.name)
for arg in self.args[first_arg:] ])
self.body = ReturnStatNode(self.pos, value=call_node) self.body = ReturnStatNode(self.pos, value=call_node)
self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
interned_attr_cname = code.intern_identifier(self.py_func.entry.name) interned_attr_cname = code.intern_identifier(self.py_func.entry.name)
...@@ -3985,16 +4017,17 @@ class PyClassDefNode(ClassDefNode): ...@@ -3985,16 +4017,17 @@ class PyClassDefNode(ClassDefNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
if self.py3_style_class: if self.py3_style_class:
self.bases.analyse_expressions(env) self.bases = self.bases.analyse_expressions(env)
self.metaclass.analyse_expressions(env) self.metaclass = self.metaclass.analyse_expressions(env)
self.mkw.analyse_expressions(env) self.mkw = self.mkw.analyse_expressions(env)
self.dict.analyse_expressions(env) self.dict = self.dict.analyse_expressions(env)
self.class_result.analyse_expressions(env) self.class_result = self.class_result.analyse_expressions(env)
genv = env.global_scope() genv = env.global_scope()
cenv = self.scope cenv = self.scope
self.body.analyse_expressions(cenv) self.body = self.body.analyse_expressions(cenv)
self.target.analyse_target_expression(env, self.classobj) self.target.analyse_target_expression(env, self.classobj)
self.class_cell.analyse_expressions(cenv) self.class_cell = self.class_cell.analyse_expressions(cenv)
return self
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
self.generate_lambda_definitions(self.scope, code) self.generate_lambda_definitions(self.scope, code)
...@@ -4205,7 +4238,8 @@ class CClassDefNode(ClassDefNode): ...@@ -4205,7 +4238,8 @@ class CClassDefNode(ClassDefNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
if self.body: if self.body:
scope = self.entry.type.scope scope = self.entry.type.scope
self.body.analyse_expressions(scope) self.body = self.body.analyse_expressions(scope)
return self
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
if self.body: if self.body:
...@@ -4239,7 +4273,8 @@ class PropertyNode(StatNode): ...@@ -4239,7 +4273,8 @@ class PropertyNode(StatNode):
self.body.analyse_declarations(entry.scope) self.body.analyse_declarations(entry.scope)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
return self
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
self.body.generate_function_definitions(env, code) self.body.generate_function_definitions(env, code)
...@@ -4263,7 +4298,7 @@ class GlobalNode(StatNode): ...@@ -4263,7 +4298,7 @@ class GlobalNode(StatNode):
env.declare_global(name, self.pos) env.declare_global(name, self.pos)
def analyse_expressions(self, env): def analyse_expressions(self, env):
pass return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
pass pass
...@@ -4281,7 +4316,7 @@ class NonlocalNode(StatNode): ...@@ -4281,7 +4316,7 @@ class NonlocalNode(StatNode):
env.declare_nonlocal(name, self.pos) env.declare_nonlocal(name, self.pos)
def analyse_expressions(self, env): def analyse_expressions(self, env):
pass return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
pass pass
...@@ -4312,7 +4347,8 @@ class ExprStatNode(StatNode): ...@@ -4312,7 +4347,8 @@ class ExprStatNode(StatNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.expr.result_is_used = False # hint that .result() may safely be left empty self.expr.result_is_used = False # hint that .result() may safely be left empty
self.expr.analyse_expressions(env) self.expr = self.expr.analyse_expressions(env)
return self
def nogil_check(self, env): def nogil_check(self, env):
if self.expr.type.is_pyobject and self.expr.is_temp: if self.expr.type.is_pyobject and self.expr.is_temp:
...@@ -4344,7 +4380,7 @@ class AssignmentNode(StatNode): ...@@ -4344,7 +4380,7 @@ class AssignmentNode(StatNode):
# to any of the left hand sides. # to any of the left hand sides.
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.analyse_types(env) return self.analyse_types(env)
# def analyse_expressions(self, env): # def analyse_expressions(self, env):
# self.analyse_expressions_1(env) # self.analyse_expressions_1(env)
...@@ -4449,8 +4485,8 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4449,8 +4485,8 @@ class SingleAssignmentNode(AssignmentNode):
def analyse_types(self, env, use_temp = 0): def analyse_types(self, env, use_temp = 0):
import ExprNodes import ExprNodes
self.rhs.analyse_types(env) self.rhs = self.rhs.analyse_types(env)
self.lhs.analyse_target_types(env) self.lhs = self.lhs.analyse_target_types(env)
self.lhs.gil_assignment_check(env) self.lhs.gil_assignment_check(env)
if self.lhs.memslice_broadcast or self.rhs.memslice_broadcast: if self.lhs.memslice_broadcast or self.rhs.memslice_broadcast:
...@@ -4474,6 +4510,7 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4474,6 +4510,7 @@ class SingleAssignmentNode(AssignmentNode):
self.rhs = self.rhs.coerce_to_temp(env) self.rhs = self.rhs.coerce_to_temp(env)
elif self.rhs.type.is_pyobject: elif self.rhs.type.is_pyobject:
self.rhs = self.rhs.coerce_to_simple(env) self.rhs = self.rhs.coerce_to_simple(env)
return self
def generate_rhs_evaluation_code(self, code): def generate_rhs_evaluation_code(self, code):
self.rhs.generate_evaluation_code(code) self.rhs.generate_evaluation_code(code)
...@@ -4511,7 +4548,7 @@ class CascadedAssignmentNode(AssignmentNode): ...@@ -4511,7 +4548,7 @@ class CascadedAssignmentNode(AssignmentNode):
def analyse_types(self, env, use_temp = 0): def analyse_types(self, env, use_temp = 0):
from ExprNodes import CloneNode, ProxyNode from ExprNodes import CloneNode, ProxyNode
self.rhs.analyse_types(env) self.rhs = self.rhs.analyse_types(env)
if use_temp or self.rhs.is_attribute: if use_temp or self.rhs.is_attribute:
# (cdef) attribute access is not safe as it traverses pointers # (cdef) attribute access is not safe as it traverses pointers
self.rhs = self.rhs.coerce_to_temp(env) self.rhs = self.rhs.coerce_to_temp(env)
...@@ -4526,6 +4563,7 @@ class CascadedAssignmentNode(AssignmentNode): ...@@ -4526,6 +4563,7 @@ class CascadedAssignmentNode(AssignmentNode):
rhs = CloneNode(self.rhs) rhs = CloneNode(self.rhs)
rhs = rhs.coerce_to(lhs.type, env) rhs = rhs.coerce_to(lhs.type, env)
self.coerced_rhs_list.append(rhs) self.coerced_rhs_list.append(rhs)
return self
def generate_rhs_evaluation_code(self, code): def generate_rhs_evaluation_code(self, code):
self.rhs.generate_evaluation_code(code) self.rhs.generate_evaluation_code(code)
...@@ -4571,8 +4609,9 @@ class ParallelAssignmentNode(AssignmentNode): ...@@ -4571,8 +4609,9 @@ class ParallelAssignmentNode(AssignmentNode):
stat.analyse_declarations(env) stat.analyse_declarations(env)
def analyse_expressions(self, env): def analyse_expressions(self, env):
for stat in self.stats: self.stats = [ stat.analyse_types(env, use_temp = 1)
stat.analyse_types(env, use_temp = 1) for stat in self.stats ]
return self
# def analyse_expressions(self, env): # def analyse_expressions(self, env):
# for stat in self.stats: # for stat in self.stats:
...@@ -4621,13 +4660,14 @@ class InPlaceAssignmentNode(AssignmentNode): ...@@ -4621,13 +4660,14 @@ class InPlaceAssignmentNode(AssignmentNode):
def analyse_types(self, env): def analyse_types(self, env):
import ExprNodes import ExprNodes
self.rhs.analyse_types(env) self.rhs = self.rhs.analyse_types(env)
self.lhs.analyse_target_types(env) self.lhs = self.lhs.analyse_target_types(env)
# When assigning to a fully indexed buffer or memoryview, coerce the rhs # When assigning to a fully indexed buffer or memoryview, coerce the rhs
if (isinstance(self.lhs, ExprNodes.IndexNode) and if (isinstance(self.lhs, ExprNodes.IndexNode) and
(self.lhs.memslice_index or self.lhs.is_buffer_access)): (self.lhs.memslice_index or self.lhs.is_buffer_access)):
self.rhs = self.rhs.coerce_to(self.lhs.type, env) self.rhs = self.rhs.coerce_to(self.lhs.type, env)
return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
import ExprNodes import ExprNodes
...@@ -4674,13 +4714,14 @@ class PrintStatNode(StatNode): ...@@ -4674,13 +4714,14 @@ class PrintStatNode(StatNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
if self.stream: if self.stream:
self.stream.analyse_expressions(env) stream = self.stream.analyse_expressions(env)
self.stream = self.stream.coerce_to_pyobject(env) self.stream = stream.coerce_to_pyobject(env)
self.arg_tuple.analyse_expressions(env) arg_tuple = self.arg_tuple.analyse_expressions(env)
self.arg_tuple = self.arg_tuple.coerce_to_pyobject(env) self.arg_tuple = arg_tuple.coerce_to_pyobject(env)
env.use_utility_code(printing_utility_code) env.use_utility_code(printing_utility_code)
if len(self.arg_tuple.args) == 1 and self.append_newline: if len(self.arg_tuple.args) == 1 and self.append_newline:
env.use_utility_code(printing_one_utility_code) env.use_utility_code(printing_one_utility_code)
return self
nogil_check = Node.gil_error nogil_check = Node.gil_error
gil_message = "Python print statement" gil_message = "Python print statement"
...@@ -4737,10 +4778,11 @@ class ExecStatNode(StatNode): ...@@ -4737,10 +4778,11 @@ class ExecStatNode(StatNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
for i, arg in enumerate(self.args): for i, arg in enumerate(self.args):
arg.analyse_expressions(env) arg = arg.analyse_expressions(env)
arg = arg.coerce_to_pyobject(env) arg = arg.coerce_to_pyobject(env)
self.args[i] = arg self.args[i] = arg
env.use_utility_code(Builtin.pyexec_utility_code) env.use_utility_code(Builtin.pyexec_utility_code)
return self
nogil_check = Node.gil_error nogil_check = Node.gil_error
gil_message = "Python exec statement" gil_message = "Python exec statement"
...@@ -4781,8 +4823,8 @@ class DelStatNode(StatNode): ...@@ -4781,8 +4823,8 @@ class DelStatNode(StatNode):
arg.analyse_target_declaration(env) arg.analyse_target_declaration(env)
def analyse_expressions(self, env): def analyse_expressions(self, env):
for arg in self.args: for i, arg in enumerate(self.args):
arg.analyse_target_expression(env, None) arg = self.args[i] = arg.analyse_target_expression(env, None)
if arg.type.is_pyobject or (arg.is_name and if arg.type.is_pyobject or (arg.is_name and
arg.type.is_memoryviewslice): arg.type.is_memoryviewslice):
pass pass
...@@ -4793,6 +4835,7 @@ class DelStatNode(StatNode): ...@@ -4793,6 +4835,7 @@ class DelStatNode(StatNode):
else: else:
error(arg.pos, "Deletion of non-Python, non-C++ object") error(arg.pos, "Deletion of non-Python, non-C++ object")
#arg.release_target_temp(env) #arg.release_target_temp(env)
return self
def nogil_check(self, env): def nogil_check(self, env):
for arg in self.args: for arg in self.args:
...@@ -4822,7 +4865,7 @@ class PassStatNode(StatNode): ...@@ -4822,7 +4865,7 @@ class PassStatNode(StatNode):
child_attrs = [] child_attrs = []
def analyse_expressions(self, env): def analyse_expressions(self, env):
pass return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
pass pass
...@@ -4843,7 +4886,7 @@ class BreakStatNode(StatNode): ...@@ -4843,7 +4886,7 @@ class BreakStatNode(StatNode):
is_terminator = True is_terminator = True
def analyse_expressions(self, env): def analyse_expressions(self, env):
pass return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
if not code.break_label: if not code.break_label:
...@@ -4858,7 +4901,7 @@ class ContinueStatNode(StatNode): ...@@ -4858,7 +4901,7 @@ class ContinueStatNode(StatNode):
is_terminator = True is_terminator = True
def analyse_expressions(self, env): def analyse_expressions(self, env):
pass return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
if code.funcstate.in_try_finally: if code.funcstate.in_try_finally:
...@@ -4888,9 +4931,9 @@ class ReturnStatNode(StatNode): ...@@ -4888,9 +4931,9 @@ class ReturnStatNode(StatNode):
self.return_type = return_type self.return_type = return_type
if not return_type: if not return_type:
error(self.pos, "Return not inside a function body") error(self.pos, "Return not inside a function body")
return return self
if self.value: if self.value:
self.value.analyse_types(env) self.value = self.value.analyse_types(env)
if return_type.is_void or return_type.is_returncode: if return_type.is_void or return_type.is_returncode:
error(self.value.pos, error(self.value.pos,
"Return with value in void function") "Return with value in void function")
...@@ -4901,6 +4944,7 @@ class ReturnStatNode(StatNode): ...@@ -4901,6 +4944,7 @@ class ReturnStatNode(StatNode):
and not return_type.is_pyobject and not return_type.is_pyobject
and not return_type.is_returncode): and not return_type.is_returncode):
error(self.pos, "Return value required") error(self.pos, "Return value required")
return self
def nogil_check(self, env): def nogil_check(self, env):
if self.return_type.is_pyobject: if self.return_type.is_pyobject:
...@@ -4981,17 +5025,17 @@ class RaiseStatNode(StatNode): ...@@ -4981,17 +5025,17 @@ class RaiseStatNode(StatNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
if self.exc_type: if self.exc_type:
self.exc_type.analyse_types(env) exc_type = self.exc_type.analyse_types(env)
self.exc_type = self.exc_type.coerce_to_pyobject(env) self.exc_type = exc_type.coerce_to_pyobject(env)
if self.exc_value: if self.exc_value:
self.exc_value.analyse_types(env) exc_value = self.exc_value.analyse_types(env)
self.exc_value = self.exc_value.coerce_to_pyobject(env) self.exc_value = exc_value.coerce_to_pyobject(env)
if self.exc_tb: if self.exc_tb:
self.exc_tb.analyse_types(env) exc_tb = self.exc_tb.analyse_types(env)
self.exc_tb = self.exc_tb.coerce_to_pyobject(env) self.exc_tb = exc_tb.coerce_to_pyobject(env)
if self.cause: if self.cause:
self.cause.analyse_types(env) cause = self.cause.analyse_types(env)
self.cause = self.cause.coerce_to_pyobject(env) self.cause = cause.coerce_to_pyobject(env)
# special cases for builtin exceptions # special cases for builtin exceptions
self.builtin_exc_name = None self.builtin_exc_name = None
if self.exc_type and not self.exc_value and not self.exc_tb: if self.exc_type and not self.exc_value and not self.exc_tb:
...@@ -5005,6 +5049,7 @@ class RaiseStatNode(StatNode): ...@@ -5005,6 +5049,7 @@ class RaiseStatNode(StatNode):
self.builtin_exc_name = exc.name self.builtin_exc_name = exc.name
if self.builtin_exc_name == 'MemoryError': if self.builtin_exc_name == 'MemoryError':
self.exc_type = None # has a separate implementation self.exc_type = None # has a separate implementation
return self
nogil_check = Node.gil_error nogil_check = Node.gil_error
gil_message = "Raising exception" gil_message = "Raising exception"
...@@ -5075,7 +5120,7 @@ class ReraiseStatNode(StatNode): ...@@ -5075,7 +5120,7 @@ class ReraiseStatNode(StatNode):
is_terminator = True is_terminator = True
def analyse_expressions(self, env): def analyse_expressions(self, env):
pass return self
nogil_check = Node.gil_error nogil_check = Node.gil_error
gil_message = "Raising exception" gil_message = "Raising exception"
...@@ -5107,8 +5152,9 @@ class AssertStatNode(StatNode): ...@@ -5107,8 +5152,9 @@ class AssertStatNode(StatNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.cond = self.cond.analyse_boolean_expression(env) self.cond = self.cond.analyse_boolean_expression(env)
if self.value: if self.value:
self.value.analyse_types(env) value = self.value.analyse_types(env)
self.value = self.value.coerce_to_pyobject(env) self.value = value.coerce_to_pyobject(env)
return self
nogil_check = Node.gil_error nogil_check = Node.gil_error
gil_message = "Raising exception" gil_message = "Raising exception"
...@@ -5163,10 +5209,11 @@ class IfStatNode(StatNode): ...@@ -5163,10 +5209,11 @@ class IfStatNode(StatNode):
self.else_clause.analyse_declarations(env) self.else_clause.analyse_declarations(env)
def analyse_expressions(self, env): def analyse_expressions(self, env):
for if_clause in self.if_clauses: self.if_clauses = [ if_clause.analyse_expressions(env)
if_clause.analyse_expressions(env) for if_clause in self.if_clauses ]
if self.else_clause: if self.else_clause:
self.else_clause.analyse_expressions(env) self.else_clause = self.else_clause.analyse_expressions(env)
return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
code.mark_pos(self.pos) code.mark_pos(self.pos)
...@@ -5206,7 +5253,8 @@ class IfClauseNode(Node): ...@@ -5206,7 +5253,8 @@ class IfClauseNode(Node):
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.condition = \ self.condition = \
self.condition.analyse_temp_boolean_expression(env) self.condition.analyse_temp_boolean_expression(env)
self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
return self
def get_constant_condition_result(self): def get_constant_condition_result(self):
if self.condition.has_constant_result(): if self.condition.has_constant_result():
...@@ -5315,9 +5363,10 @@ class WhileStatNode(LoopNode, StatNode): ...@@ -5315,9 +5363,10 @@ class WhileStatNode(LoopNode, StatNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
if self.condition: if self.condition:
self.condition = self.condition.analyse_temp_boolean_expression(env) self.condition = self.condition.analyse_temp_boolean_expression(env)
self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
if self.else_clause: if self.else_clause:
self.else_clause.analyse_expressions(env) self.else_clause = self.else_clause.analyse_expressions(env)
return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
old_loop_labels = code.new_loop_labels() old_loop_labels = code.new_loop_labels()
...@@ -5384,22 +5433,24 @@ class DictIterationNextNode(Node): ...@@ -5384,22 +5433,24 @@ class DictIterationNextNode(Node):
def analyse_expressions(self, env): def analyse_expressions(self, env):
import ExprNodes import ExprNodes
self.dict_obj.analyse_types(env) self.dict_obj = self.dict_obj.analyse_types(env)
self.expected_size.analyse_types(env) self.expected_size = self.expected_size.analyse_types(env)
if self.pos_index_var: self.pos_index_var.analyse_types(env) if self.pos_index_var:
self.pos_index_var = self.pos_index_var.analyse_types(env)
if self.key_target: if self.key_target:
self.key_target.analyse_target_types(env) self.key_target = self.key_target.analyse_target_types(env)
self.key_ref = ExprNodes.TempNode(self.key_target.pos, PyrexTypes.py_object_type) self.key_ref = ExprNodes.TempNode(self.key_target.pos, PyrexTypes.py_object_type)
self.coerced_key_var = self.key_ref.coerce_to(self.key_target.type, env) self.coerced_key_var = self.key_ref.coerce_to(self.key_target.type, env)
if self.value_target: if self.value_target:
self.value_target.analyse_target_types(env) self.value_target = self.value_target.analyse_target_types(env)
self.value_ref = ExprNodes.TempNode(self.value_target.pos, type=PyrexTypes.py_object_type) self.value_ref = ExprNodes.TempNode(self.value_target.pos, type=PyrexTypes.py_object_type)
self.coerced_value_var = self.value_ref.coerce_to(self.value_target.type, env) self.coerced_value_var = self.value_ref.coerce_to(self.value_target.type, env)
if self.tuple_target: if self.tuple_target:
self.tuple_target.analyse_target_types(env) self.tuple_target = self.tuple_target.analyse_target_types(env)
self.tuple_ref = ExprNodes.TempNode(self.tuple_target.pos, PyrexTypes.py_object_type) self.tuple_ref = ExprNodes.TempNode(self.tuple_target.pos, PyrexTypes.py_object_type)
self.coerced_tuple_var = self.tuple_ref.coerce_to(self.tuple_target.type, env) self.coerced_tuple_var = self.tuple_ref.coerce_to(self.tuple_target.type, env)
self.is_dict_flag.analyse_types(env) self.is_dict_flag = self.is_dict_flag.analyse_types(env)
return self
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
self.dict_obj.generate_function_definitions(env, code) self.dict_obj.generate_function_definitions(env, code)
...@@ -5472,22 +5523,23 @@ class ForInStatNode(LoopNode, StatNode): ...@@ -5472,22 +5523,23 @@ class ForInStatNode(LoopNode, StatNode):
self.item = ExprNodes.NextNode(self.iterator) self.item = ExprNodes.NextNode(self.iterator)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.target.analyse_target_types(env) self.target = self.target.analyse_target_types(env)
self.iterator.analyse_expressions(env) self.iterator = self.iterator.analyse_expressions(env)
if self.item is None: if self.item is None:
# Hack. Sometimes analyse_declarations not called. # Hack. Sometimes analyse_declarations not called.
import ExprNodes import ExprNodes
self.item = ExprNodes.NextNode(self.iterator) self.item = ExprNodes.NextNode(self.iterator)
self.item.analyse_expressions(env) self.item = self.item.analyse_expressions(env)
if (self.iterator.type.is_ptr or self.iterator.type.is_array) and \ if (self.iterator.type.is_ptr or self.iterator.type.is_array) and \
self.target.type.assignable_from(self.iterator.type): self.target.type.assignable_from(self.iterator.type):
# C array slice optimization. # C array slice optimization.
pass pass
else: else:
self.item = self.item.coerce_to(self.target.type, env) self.item = self.item.coerce_to(self.target.type, env)
self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
if self.else_clause: if self.else_clause:
self.else_clause.analyse_expressions(env) self.else_clause = self.else_clause.analyse_expressions(env)
return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
old_loop_labels = code.new_loop_labels() old_loop_labels = code.new_loop_labels()
...@@ -5581,13 +5633,13 @@ class ForFromStatNode(LoopNode, StatNode): ...@@ -5581,13 +5633,13 @@ class ForFromStatNode(LoopNode, StatNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
import ExprNodes import ExprNodes
self.target.analyse_target_types(env) self.target = self.target.analyse_target_types(env)
self.bound1.analyse_types(env) self.bound1 = self.bound1.analyse_types(env)
self.bound2.analyse_types(env) self.bound2 = self.bound2.analyse_types(env)
if self.step is not None: if self.step is not None:
if isinstance(self.step, ExprNodes.UnaryMinusNode): if isinstance(self.step, ExprNodes.UnaryMinusNode):
warning(self.step.pos, "Probable infinite loop in for-from-by statement. Consider switching the directions of the relations.", 2) warning(self.step.pos, "Probable infinite loop in for-from-by statement. Consider switching the directions of the relations.", 2)
self.step.analyse_types(env) self.step = self.step.analyse_types(env)
if self.target.type.is_numeric: if self.target.type.is_numeric:
loop_type = self.target.type loop_type = self.target.type
...@@ -5624,9 +5676,10 @@ class ForFromStatNode(LoopNode, StatNode): ...@@ -5624,9 +5676,10 @@ class ForFromStatNode(LoopNode, StatNode):
self.loopvar_node = c_loopvar_node self.loopvar_node = c_loopvar_node
self.py_loopvar_node = \ self.py_loopvar_node = \
ExprNodes.CloneNode(c_loopvar_node).coerce_to_pyobject(env) ExprNodes.CloneNode(c_loopvar_node).coerce_to_pyobject(env)
self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
if self.else_clause: if self.else_clause:
self.else_clause.analyse_expressions(env) self.else_clause = self.else_clause.analyse_expressions(env)
return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
old_loop_labels = code.new_loop_labels() old_loop_labels = code.new_loop_labels()
...@@ -5781,9 +5834,10 @@ class WithStatNode(StatNode): ...@@ -5781,9 +5834,10 @@ class WithStatNode(StatNode):
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.manager.analyse_types(env) self.manager = self.manager.analyse_types(env)
self.enter_call.analyse_types(env) self.enter_call = self.enter_call.analyse_types(env)
self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
return self
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
self.manager.generate_function_definitions(env, code) self.manager.generate_function_definitions(env, code)
...@@ -5851,10 +5905,11 @@ class WithTargetAssignmentStatNode(AssignmentNode): ...@@ -5851,10 +5905,11 @@ class WithTargetAssignmentStatNode(AssignmentNode):
self.lhs.analyse_target_declaration(env) self.lhs.analyse_target_declaration(env)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.rhs.analyse_types(env) self.rhs = self.rhs.analyse_types(env)
self.lhs.analyse_target_types(env) self.lhs = self.lhs.analyse_target_types(env)
self.lhs.gil_assignment_check(env) self.lhs.gil_assignment_check(env)
self.rhs = self.rhs.coerce_to(self.lhs.type, env) self.rhs = self.rhs.coerce_to(self.lhs.type, env)
return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
if self.orig_rhs.type.is_pyobject: if self.orig_rhs.type.is_pyobject:
...@@ -5901,17 +5956,18 @@ class TryExceptStatNode(StatNode): ...@@ -5901,17 +5956,18 @@ class TryExceptStatNode(StatNode):
env.use_utility_code(reset_exception_utility_code) env.use_utility_code(reset_exception_utility_code)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
default_clause_seen = 0 default_clause_seen = 0
for except_clause in self.except_clauses: for i, except_clause in enumerate(self.except_clauses):
except_clause.analyse_expressions(env) except_clause = self.except_clauses[i] = except_clause.analyse_expressions(env)
if default_clause_seen: if default_clause_seen:
error(except_clause.pos, "default 'except:' must be last") error(except_clause.pos, "default 'except:' must be last")
if not except_clause.pattern: if not except_clause.pattern:
default_clause_seen = 1 default_clause_seen = 1
self.has_default_clause = default_clause_seen self.has_default_clause = default_clause_seen
if self.else_clause: if self.else_clause:
self.else_clause.analyse_expressions(env) self.else_clause = self.else_clause.analyse_expressions(env)
return self
nogil_check = Node.gil_error nogil_check = Node.gil_error
gil_message = "Try-except statement" gil_message = "Try-except statement"
...@@ -6057,20 +6113,21 @@ class ExceptClauseNode(Node): ...@@ -6057,20 +6113,21 @@ class ExceptClauseNode(Node):
if self.pattern: if self.pattern:
# normalise/unpack self.pattern into a list # normalise/unpack self.pattern into a list
for i, pattern in enumerate(self.pattern): for i, pattern in enumerate(self.pattern):
pattern.analyse_expressions(env) pattern = pattern.analyse_expressions(env)
self.pattern[i] = pattern.coerce_to_pyobject(env) self.pattern[i] = pattern.coerce_to_pyobject(env)
if self.target: if self.target:
import ExprNodes import ExprNodes
self.exc_value = ExprNodes.ExcValueNode(self.pos, env) self.exc_value = ExprNodes.ExcValueNode(self.pos, env)
self.target.analyse_target_expression(env, self.exc_value) self.target = self.target.analyse_target_expression(env, self.exc_value)
if self.excinfo_target is not None: if self.excinfo_target is not None:
import ExprNodes import ExprNodes
self.excinfo_tuple = ExprNodes.TupleNode(pos=self.pos, args=[ excinfo_tuple = ExprNodes.TupleNode(pos=self.pos, args=[
ExprNodes.ExcValueNode(pos=self.pos, env=env) for _ in range(3)]) ExprNodes.ExcValueNode(pos=self.pos, env=env) for _ in range(3)])
self.excinfo_tuple.analyse_expressions(env) self.excinfo_tuple = excinfo_tuple.analyse_expressions(env)
self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
return self
def generate_handling_code(self, code, end_label): def generate_handling_code(self, code, end_label):
code.mark_pos(self.pos) code.mark_pos(self.pos)
...@@ -6213,8 +6270,9 @@ class TryFinallyStatNode(StatNode): ...@@ -6213,8 +6270,9 @@ class TryFinallyStatNode(StatNode):
self.finally_clause.analyse_declarations(env) self.finally_clause.analyse_declarations(env)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
self.finally_clause.analyse_expressions(env) self.finally_clause = self.finally_clause.analyse_expressions(env)
return self
nogil_check = Node.gil_error nogil_check = Node.gil_error
gil_message = "Try-finally statement" gil_message = "Try-finally statement"
...@@ -6415,8 +6473,9 @@ class GILStatNode(NogilTryFinallyStatNode): ...@@ -6415,8 +6473,9 @@ class GILStatNode(NogilTryFinallyStatNode):
UtilityCode.load_cached("ForceInitThreads", "ModuleSetupCode.c")) UtilityCode.load_cached("ForceInitThreads", "ModuleSetupCode.c"))
was_nogil = env.nogil was_nogil = env.nogil
env.nogil = self.state == 'nogil' env.nogil = self.state == 'nogil'
TryFinallyStatNode.analyse_expressions(self, env) node = TryFinallyStatNode.analyse_expressions(self, env)
env.nogil = was_nogil env.nogil = was_nogil
return node
def generate_execution_code(self, code): def generate_execution_code(self, code):
code.mark_pos(self.pos) code.mark_pos(self.pos)
...@@ -6441,7 +6500,7 @@ class GILExitNode(StatNode): ...@@ -6441,7 +6500,7 @@ class GILExitNode(StatNode):
child_attrs = [] child_attrs = []
def analyse_expressions(self, env): def analyse_expressions(self, env):
pass return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
if self.state == 'gil': if self.state == 'gil':
...@@ -6500,7 +6559,7 @@ class CImportStatNode(StatNode): ...@@ -6500,7 +6559,7 @@ class CImportStatNode(StatNode):
*utility_code_for_cimports[self.module_name])) *utility_code_for_cimports[self.module_name]))
def analyse_expressions(self, env): def analyse_expressions(self, env):
pass return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
pass pass
...@@ -6574,7 +6633,7 @@ class FromCImportStatNode(StatNode): ...@@ -6574,7 +6633,7 @@ class FromCImportStatNode(StatNode):
return 1 return 1
def analyse_expressions(self, env): def analyse_expressions(self, env):
pass return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
pass pass
...@@ -6605,7 +6664,7 @@ class FromImportStatNode(StatNode): ...@@ -6605,7 +6664,7 @@ class FromImportStatNode(StatNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
import ExprNodes import ExprNodes
self.module.analyse_expressions(env) self.module = self.module.analyse_expressions(env)
self.item = ExprNodes.RawCNameExprNode(self.pos, py_object_type) self.item = ExprNodes.RawCNameExprNode(self.pos, py_object_type)
self.interned_items = [] self.interned_items = []
for name, target in self.items: for name, target in self.items:
...@@ -6630,7 +6689,7 @@ class FromImportStatNode(StatNode): ...@@ -6630,7 +6689,7 @@ class FromImportStatNode(StatNode):
continue continue
except AttributeError: except AttributeError:
pass pass
target.analyse_target_expression(env, None) target = target.analyse_target_expression(env, None) # FIXME?
if target.type is py_object_type: if target.type is py_object_type:
coerced_item = None coerced_item = None
else: else:
...@@ -6638,6 +6697,7 @@ class FromImportStatNode(StatNode): ...@@ -6638,6 +6697,7 @@ class FromImportStatNode(StatNode):
self.interned_items.append((name, target, coerced_item)) self.interned_items.append((name, target, coerced_item))
if self.interned_items: if self.interned_items:
env.use_utility_code(raise_import_error_utility_code) env.use_utility_code(raise_import_error_utility_code)
return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
self.module.generate_evaluation_code(code) self.module.generate_evaluation_code(code)
...@@ -6797,12 +6857,12 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6797,12 +6857,12 @@ class ParallelStatNode(StatNode, ParallelNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
if self.num_threads: if self.num_threads:
self.num_threads.analyse_expressions(env) self.num_threads = self.num_threads.analyse_expressions(env)
if self.chunksize: if self.chunksize:
self.chunksize.analyse_expressions(env) self.chunksize = self.chunksize.analyse_expressions(env)
self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
self.analyse_sharing_attributes(env) self.analyse_sharing_attributes(env)
if self.num_threads is not None: if self.num_threads is not None:
...@@ -6821,7 +6881,8 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6821,7 +6881,8 @@ class ParallelStatNode(StatNode, ParallelNode):
if not self.num_threads.is_simple(): if not self.num_threads.is_simple():
self.num_threads = self.num_threads.coerce_to( self.num_threads = self.num_threads.coerce_to(
PyrexTypes.c_int_type, env).coerce_to_temp(env) PyrexTypes.c_int_type, env).coerce_to_temp(env)
return self
def analyse_sharing_attributes(self, env): def analyse_sharing_attributes(self, env):
""" """
...@@ -7504,9 +7565,9 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -7504,9 +7565,9 @@ class ParallelRangeNode(ParallelStatNode):
if self.target is None: if self.target is None:
error(self.pos, "prange() can only be used as part of a for loop") error(self.pos, "prange() can only be used as part of a for loop")
return return self
self.target.analyse_target_types(env) self.target = self.target.analyse_target_types(env)
if not self.target.type.is_numeric: if not self.target.type.is_numeric:
# Not a valid type, assume one for now anyway # Not a valid type, assume one for now anyway
...@@ -7545,7 +7606,7 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -7545,7 +7606,7 @@ class ParallelRangeNode(ParallelStatNode):
self.index_type, node.type) self.index_type, node.type)
if self.else_clause is not None: if self.else_clause is not None:
self.else_clause.analyse_expressions(env) self.else_clause = self.else_clause.analyse_expressions(env)
# Although not actually an assignment in this scope, it should be # Although not actually an assignment in this scope, it should be
# treated as such to ensure it is unpacked if a closure temp, and to # treated as such to ensure it is unpacked if a closure temp, and to
...@@ -7555,35 +7616,36 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -7555,35 +7616,36 @@ class ParallelRangeNode(ParallelStatNode):
if hasattr(self.target, 'entry'): if hasattr(self.target, 'entry'):
self.assignments[self.target.entry] = self.target.pos, None self.assignments[self.target.entry] = self.target.pos, None
super(ParallelRangeNode, self).analyse_expressions(env) node = super(ParallelRangeNode, self).analyse_expressions(env)
if self.chunksize: if node.chunksize:
if not self.schedule: if not node.schedule:
error(self.chunksize.pos, error(node.chunksize.pos,
"Must provide schedule with chunksize") "Must provide schedule with chunksize")
elif self.schedule == 'runtime': elif node.schedule == 'runtime':
error(self.chunksize.pos, error(node.chunksize.pos,
"Chunksize not valid for the schedule runtime") "Chunksize not valid for the schedule runtime")
elif (self.chunksize.type.is_int and elif (node.chunksize.type.is_int and
self.chunksize.is_literal and node.chunksize.is_literal and
self.chunksize.compile_time_value(env) <= 0): node.chunksize.compile_time_value(env) <= 0):
error(self.chunksize.pos, "Chunksize must not be negative") error(node.chunksize.pos, "Chunksize must not be negative")
self.chunksize = self.chunksize.coerce_to( node.chunksize = node.chunksize.coerce_to(
PyrexTypes.c_int_type, env).coerce_to_temp(env) PyrexTypes.c_int_type, env).coerce_to_temp(env)
if self.nogil: if node.nogil:
env.nogil = was_nogil env.nogil = was_nogil
self.is_nested_prange = self.parent and self.parent.is_prange node.is_nested_prange = node.parent and node.parent.is_prange
if self.is_nested_prange: if node.is_nested_prange:
parent = self parent = node
while parent.parent and parent.parent.is_prange: while parent.parent and parent.parent.is_prange:
parent = parent.parent parent = parent.parent
parent.assignments.update(self.assignments) parent.assignments.update(node.assignments)
parent.privates.update(self.privates) parent.privates.update(node.privates)
parent.assigned_nodes.extend(self.assigned_nodes) parent.assigned_nodes.extend(node.assigned_nodes)
return node
def nogil_check(self, env): def nogil_check(self, env):
names = 'start', 'stop', 'step', 'target' names = 'start', 'stop', 'step', 'target'
...@@ -7866,7 +7928,8 @@ class CnameDecoratorNode(StatNode): ...@@ -7866,7 +7928,8 @@ class CnameDecoratorNode(StatNode):
return '%s_%s' % (self.cname, cname) return '%s_%s' % (self.cname, cname)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.node.analyse_expressions(env) self.node = self.node.analyse_expressions(env)
return self
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
"Ensure a prototype for every @cname method in the right place" "Ensure a prototype for every @cname method in the right place"
......
...@@ -100,7 +100,7 @@ class IterationTransform(Visitor.EnvTransform): ...@@ -100,7 +100,7 @@ class IterationTransform(Visitor.EnvTransform):
iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2), iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
body=if_node, body=if_node,
else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0)))) else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
for_loop.analyse_expressions(self.current_env()) for_loop = for_loop.analyse_expressions(self.current_env())
for_loop = self.visit(for_loop) for_loop = self.visit(for_loop)
new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop) new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)
...@@ -704,7 +704,7 @@ class IterationTransform(Visitor.EnvTransform): ...@@ -704,7 +704,7 @@ class IterationTransform(Visitor.EnvTransform):
dict_temp, dict_len_temp.ref(dict_obj.pos), pos_temp, dict_temp, dict_len_temp.ref(dict_obj.pos), pos_temp,
key_target, value_target, tuple_target, key_target, value_target, tuple_target,
is_dict_temp) is_dict_temp)
iter_next_node.analyse_expressions(self.current_env()) iter_next_node = iter_next_node.analyse_expressions(self.current_env())
body.stats[0:0] = [iter_next_node] body.stats[0:0] = [iter_next_node]
if method: if method:
...@@ -1187,7 +1187,7 @@ class SimplifyCalls(Visitor.EnvTransform): ...@@ -1187,7 +1187,7 @@ class SimplifyCalls(Visitor.EnvTransform):
node.pos, node.pos,
function=node.function, function=node.function,
args=args) args=args)
call_node.analyse_types(self.current_env()) call_node = call_node.analyse_types(self.current_env())
if node.type != call_node.type: if node.type != call_node.type:
call_node = call_node.coerce_to( call_node = call_node.coerce_to(
node.type, self.current_env()) node.type, self.current_env())
......
...@@ -1819,20 +1819,20 @@ class AnalyseExpressionsTransform(CythonTransform): ...@@ -1819,20 +1819,20 @@ class AnalyseExpressionsTransform(CythonTransform):
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
node.scope.infer_types() node.scope.infer_types()
node.body.analyse_expressions(node.scope) node.body = node.body.analyse_expressions(node.scope)
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
node.local_scope.infer_types() node.local_scope.infer_types()
node.body.analyse_expressions(node.local_scope) node.body = node.body.analyse_expressions(node.local_scope)
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_ScopedExprNode(self, node): def visit_ScopedExprNode(self, node):
if node.has_local_scope: if node.has_local_scope:
node.expr_scope.infer_types() node.expr_scope.infer_types()
node.analyse_scoped_expressions(node.expr_scope) node = node.analyse_scoped_expressions(node.expr_scope)
self.visitchildren(node) self.visitchildren(node)
return node return node
......
...@@ -33,9 +33,11 @@ class TempRefNode(AtomicExprNode): ...@@ -33,9 +33,11 @@ class TempRefNode(AtomicExprNode):
def analyse_types(self, env): def analyse_types(self, env):
assert self.type == self.handle.type assert self.type == self.handle.type
return self
def analyse_target_types(self, env): def analyse_target_types(self, env):
assert self.type == self.handle.type assert self.type == self.handle.type
return self
def analyse_target_declaration(self, env): def analyse_target_declaration(self, env):
pass pass
...@@ -104,7 +106,8 @@ class TempsBlockNode(Node): ...@@ -104,7 +106,8 @@ class TempsBlockNode(Node):
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
return self
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
self.body.generate_function_definitions(env, code) self.body.generate_function_definitions(env, code)
...@@ -149,6 +152,7 @@ class ResultRefNode(AtomicExprNode): ...@@ -149,6 +152,7 @@ class ResultRefNode(AtomicExprNode):
def analyse_types(self, env): def analyse_types(self, env):
if self.expression is not None: if self.expression is not None:
self.type = self.expression.type self.type = self.expression.type
return self
def infer_type(self, env): def infer_type(self, env):
if self.type is not None: if self.type is not None:
...@@ -263,9 +267,10 @@ class EvalWithTempExprNode(ExprNodes.ExprNode, LetNodeMixin): ...@@ -263,9 +267,10 @@ class EvalWithTempExprNode(ExprNodes.ExprNode, LetNodeMixin):
return self.subexpression.result() return self.subexpression.result()
def analyse_types(self, env): def analyse_types(self, env):
self.temp_expression.analyse_types(env) self.temp_expression = self.temp_expression.analyse_types(env)
self.subexpression.analyse_types(env) self.subexpression = self.subexpression.analyse_types(env)
self.type = self.subexpression.type self.type = self.subexpression.type
return self
def free_subexpr_temps(self, code): def free_subexpr_temps(self, code):
self.subexpression.free_temps(code) self.subexpression.free_temps(code)
...@@ -302,8 +307,9 @@ class LetNode(Nodes.StatNode, LetNodeMixin): ...@@ -302,8 +307,9 @@ class LetNode(Nodes.StatNode, LetNodeMixin):
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.temp_expression.analyse_expressions(env) self.temp_expression = self.temp_expression.analyse_expressions(env)
self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
self.setup_temp_expr(code) self.setup_temp_expr(code)
...@@ -335,7 +341,8 @@ class TempResultFromStatNode(ExprNodes.ExprNode): ...@@ -335,7 +341,8 @@ class TempResultFromStatNode(ExprNodes.ExprNode):
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
def analyse_types(self, env): def analyse_types(self, env):
self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
return self
def generate_result_code(self, code): def generate_result_code(self, code):
self.result_ref.result_code = self.result() self.result_ref.result_code = self.result()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment