Commit 60056ff4 authored by Stefan Behnel's avatar Stefan Behnel

convert IterationTransform to inherit from EnvTransform for better scope tracking

parent fc691c4c
......@@ -51,28 +51,13 @@ def is_common_value(a, b):
return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
return False
class IterationTransform(Visitor.VisitorTransform):
class IterationTransform(Visitor.EnvTransform):
"""Transform some common for-in loop patterns into efficient C loops:
- for-in-dict loop becomes a while loop calling PyDict_Next()
- for-in-enumerate is replaced by an external counter variable
- for-in-range loop becomes a plain C for loop
"""
visit_Node = Visitor.VisitorTransform.recurse_to_children
def visit_ModuleNode(self, node):
self.current_scope = node.scope
self.module_scope = node.scope
self.visitchildren(node)
return node
def visit_DefNode(self, node):
oldscope = self.current_scope
self.current_scope = node.entry.scope
self.visitchildren(node)
self.current_scope = oldscope
return node
def visit_PrimaryCmpNode(self, node):
if node.is_ptr_contains():
......@@ -110,7 +95,7 @@ class IterationTransform(Visitor.VisitorTransform):
iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
body=if_node,
else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
for_loop.analyse_expressions(self.current_scope)
for_loop.analyse_expressions(self.current_env())
for_loop = self(for_loop)
new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)
......@@ -160,7 +145,7 @@ class IterationTransform(Visitor.VisitorTransform):
base_obj = iterator.self or function.obj
method = function.attribute
# in Py3, items() is equivalent to Py2's iteritems()
is_safe_iter = self.module_scope.context.language_level >= 3
is_safe_iter = self.global_scope().context.language_level >= 3
if not is_safe_iter and method in ('keys', 'values', 'items'):
# try to reduce this to the corresponding .iter*() methods
......@@ -319,7 +304,7 @@ class IterationTransform(Visitor.VisitorTransform):
)
if target_value.type != node.target.type:
target_value = target_value.coerce_to(node.target.type,
self.current_scope)
self.current_env())
target_assign = Nodes.SingleAssignmentNode(
pos = node.target.pos,
lhs = node.target,
......@@ -419,12 +404,12 @@ class IterationTransform(Visitor.VisitorTransform):
if start.constant_result is None:
start = None
else:
start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
if stop:
if stop.constant_result is None:
stop = None
else:
stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
if stop is None:
if neg_step:
stop = ExprNodes.IntNode(
......@@ -443,7 +428,7 @@ class IterationTransform(Visitor.VisitorTransform):
ptr_type = slice_base.type
if ptr_type.is_array:
ptr_type = ptr_type.element_ptr_type()
carray_ptr = slice_base.coerce_to_simple(self.current_scope)
carray_ptr = slice_base.coerce_to_simple(self.current_env())
if start and start.constant_result != 0:
start_ptr_node = ExprNodes.AddNode(
......@@ -462,7 +447,7 @@ class IterationTransform(Visitor.VisitorTransform):
operator='+',
operand2=stop,
type=ptr_type
).coerce_to_simple(self.current_scope)
).coerce_to_simple(self.current_env())
else:
stop_ptr_node = ExprNodes.CloneNode(carray_ptr)
......@@ -497,7 +482,7 @@ class IterationTransform(Visitor.VisitorTransform):
if target_value.type != node.target.type:
target_value = target_value.coerce_to(node.target.type,
self.current_scope)
self.current_env())
target_assign = Nodes.SingleAssignmentNode(
pos = node.target.pos,
......@@ -553,7 +538,7 @@ class IterationTransform(Visitor.VisitorTransform):
return node
if len(args) == 2:
start = unwrap_coerced_node(args[1]).coerce_to(counter_type, self.current_scope)
start = unwrap_coerced_node(args[1]).coerce_to(counter_type, self.current_env())
else:
start = ExprNodes.IntNode(enumerate_function.pos,
value='0',
......@@ -593,7 +578,7 @@ class IterationTransform(Visitor.VisitorTransform):
stats = loop_body)
node.target = iterable_target
node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
node.item = node.item.coerce_to(iterable_target.type, self.current_env())
node.iterator.sequence = args[0]
# recurse into loop to check for further optimisations
......@@ -638,10 +623,10 @@ class IterationTransform(Visitor.VisitorTransform):
if len(args) == 1:
bound1 = ExprNodes.IntNode(range_function.pos, value='0',
constant_result=0)
bound2 = args[0].coerce_to_integer(self.current_scope)
bound2 = args[0].coerce_to_integer(self.current_env())
else:
bound1 = args[0].coerce_to_integer(self.current_scope)
bound2 = args[1].coerce_to_integer(self.current_scope)
bound1 = args[0].coerce_to_integer(self.current_env())
bound2 = args[1].coerce_to_integer(self.current_env())
relation1, relation2 = self._find_for_from_node_relations(step_value < 0, reversed)
......@@ -655,7 +640,7 @@ class IterationTransform(Visitor.VisitorTransform):
step.value = str(step_value)
step.constant_result = step_value
step = step.coerce_to_integer(self.current_scope)
step = step.coerce_to_integer(self.current_env())
if not bound2.is_literal:
# stop bound must be immutable => keep it in a temp var
......@@ -725,7 +710,7 @@ class IterationTransform(Visitor.VisitorTransform):
dict_temp, dict_len_temp.ref(dict_obj.pos), pos_temp,
key_target, value_target, tuple_target,
is_dict_temp)
iter_next_node.analyse_expressions(self.current_scope)
iter_next_node.analyse_expressions(self.current_env())
body.stats[0:0] = [iter_next_node]
if method:
......
......@@ -198,7 +198,7 @@ def create_pipeline(context, mode, exclude_classes=()):
ExpandInplaceOperators(context),
OptimizeBuiltinCalls(context), ## Necessary?
ConsolidateOverflowCheck(context),
IterationTransform(),
IterationTransform(context),
SwitchTransform(),
DropRefcountingTransform(),
FinalOptimizePhase(context),
......
......@@ -330,6 +330,9 @@ class EnvTransform(CythonTransform):
def current_scope_node(self):
return self.env_stack[-1][0]
def global_scope(self):
return self.current_env().global_scope()
def visit_FuncDefNode(self, node):
self.env_stack.append((node, node.local_scope))
self.visitchildren(node)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment