Commit f9c385e0 authored by Stefan Behnel's avatar Stefan Behnel

refactor analyse_types() and friends to work more like a transform by...

refactor analyse_types() and friends to work more like a transform by returning the node or a replacement
parent 927e1a4d
This diff is collapsed.
......@@ -673,22 +673,22 @@ class FusedCFuncDefNode(StatListNode):
specialization_type.create_declaration_utility_code(env)
if self.py_func:
self.__signatures__.analyse_expressions(env)
self.py_func.analyse_expressions(env)
self.resulting_fused_function.analyse_expressions(env)
self.fused_func_assignment.analyse_expressions(env)
self.__signatures__ = self.__signatures__.analyse_expressions(env)
self.py_func = self.py_func.analyse_expressions(env)
self.resulting_fused_function = self.resulting_fused_function.analyse_expressions(env)
self.fused_func_assignment = self.fused_func_assignment.analyse_expressions(env)
self.defaults = defaults = []
for arg in self.node.args:
if arg.default:
arg.default.analyse_expressions(env)
arg.default = arg.default.analyse_expressions(env)
defaults.append(ProxyNode(arg.default))
else:
defaults.append(None)
for stat in self.stats:
stat.analyse_expressions(env)
for i, stat in enumerate(self.stats):
stat = self.stats[i] = stat.analyse_expressions(env)
if isinstance(stat, FuncDefNode):
for arg, default in zip(stat.args, defaults):
if default is not None:
......@@ -697,7 +697,7 @@ class FusedCFuncDefNode(StatListNode):
if self.py_func:
args = [CloneNode(default) for default in defaults if default]
self.defaults_tuple = TupleNode(self.pos, args=args)
self.defaults_tuple.analyse_types(env, skip_children=True)
self.defaults_tuple = self.defaults_tuple.analyse_types(env, skip_children=True)
self.defaults_tuple = ProxyNode(self.defaults_tuple)
self.code_object = ProxyNode(self.specialized_pycfuncs[0].code_object)
......@@ -705,10 +705,11 @@ class FusedCFuncDefNode(StatListNode):
fused_func.defaults_tuple = CloneNode(self.defaults_tuple)
fused_func.code_object = CloneNode(self.code_object)
for pycfunc in self.specialized_pycfuncs:
for i, pycfunc in enumerate(self.specialized_pycfuncs):
pycfunc.code_object = CloneNode(self.code_object)
pycfunc.analyse_types(env)
pycfunc = self.specialized_pycfuncs[i] = pycfunc.analyse_types(env)
pycfunc.defaults_tuple = CloneNode(self.defaults_tuple)
return self
def synthesize_defnodes(self):
"""
......
This diff is collapsed.
......@@ -100,7 +100,7 @@ class IterationTransform(Visitor.EnvTransform):
iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
body=if_node,
else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
for_loop.analyse_expressions(self.current_env())
for_loop = for_loop.analyse_expressions(self.current_env())
for_loop = self.visit(for_loop)
new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)
......@@ -704,7 +704,7 @@ class IterationTransform(Visitor.EnvTransform):
dict_temp, dict_len_temp.ref(dict_obj.pos), pos_temp,
key_target, value_target, tuple_target,
is_dict_temp)
iter_next_node.analyse_expressions(self.current_env())
iter_next_node = iter_next_node.analyse_expressions(self.current_env())
body.stats[0:0] = [iter_next_node]
if method:
......@@ -1187,7 +1187,7 @@ class SimplifyCalls(Visitor.EnvTransform):
node.pos,
function=node.function,
args=args)
call_node.analyse_types(self.current_env())
call_node = call_node.analyse_types(self.current_env())
if node.type != call_node.type:
call_node = call_node.coerce_to(
node.type, self.current_env())
......
......@@ -1819,20 +1819,20 @@ class AnalyseExpressionsTransform(CythonTransform):
def visit_ModuleNode(self, node):
node.scope.infer_types()
node.body.analyse_expressions(node.scope)
node.body = node.body.analyse_expressions(node.scope)
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
node.local_scope.infer_types()
node.body.analyse_expressions(node.local_scope)
node.body = node.body.analyse_expressions(node.local_scope)
self.visitchildren(node)
return node
def visit_ScopedExprNode(self, node):
if node.has_local_scope:
node.expr_scope.infer_types()
node.analyse_scoped_expressions(node.expr_scope)
node = node.analyse_scoped_expressions(node.expr_scope)
self.visitchildren(node)
return node
......
......@@ -33,9 +33,11 @@ class TempRefNode(AtomicExprNode):
def analyse_types(self, env):
assert self.type == self.handle.type
return self
def analyse_target_types(self, env):
assert self.type == self.handle.type
return self
def analyse_target_declaration(self, env):
pass
......@@ -104,7 +106,8 @@ class TempsBlockNode(Node):
self.body.analyse_declarations(env)
def analyse_expressions(self, env):
self.body.analyse_expressions(env)
self.body = self.body.analyse_expressions(env)
return self
def generate_function_definitions(self, env, code):
self.body.generate_function_definitions(env, code)
......@@ -149,6 +152,7 @@ class ResultRefNode(AtomicExprNode):
def analyse_types(self, env):
if self.expression is not None:
self.type = self.expression.type
return self
def infer_type(self, env):
if self.type is not None:
......@@ -263,9 +267,10 @@ class EvalWithTempExprNode(ExprNodes.ExprNode, LetNodeMixin):
return self.subexpression.result()
def analyse_types(self, env):
self.temp_expression.analyse_types(env)
self.subexpression.analyse_types(env)
self.temp_expression = self.temp_expression.analyse_types(env)
self.subexpression = self.subexpression.analyse_types(env)
self.type = self.subexpression.type
return self
def free_subexpr_temps(self, code):
self.subexpression.free_temps(code)
......@@ -302,8 +307,9 @@ class LetNode(Nodes.StatNode, LetNodeMixin):
self.body.analyse_declarations(env)
def analyse_expressions(self, env):
self.temp_expression.analyse_expressions(env)
self.body.analyse_expressions(env)
self.temp_expression = self.temp_expression.analyse_expressions(env)
self.body = self.body.analyse_expressions(env)
return self
def generate_execution_code(self, code):
self.setup_temp_expr(code)
......@@ -335,7 +341,8 @@ class TempResultFromStatNode(ExprNodes.ExprNode):
self.body.analyse_declarations(env)
def analyse_types(self, env):
self.body.analyse_expressions(env)
self.body = self.body.analyse_expressions(env)
return self
def generate_result_code(self, code):
self.result_ref.result_code = self.result()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment