Commit 60056ff4 authored by Stefan Behnel's avatar Stefan Behnel

convert IterationTransform to inherit from EnvTransform for better scope tracking

parent fc691c4c
...@@ -51,28 +51,13 @@ def is_common_value(a, b): ...@@ -51,28 +51,13 @@ def is_common_value(a, b):
return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
return False return False
class IterationTransform(Visitor.VisitorTransform): class IterationTransform(Visitor.EnvTransform):
"""Transform some common for-in loop patterns into efficient C loops: """Transform some common for-in loop patterns into efficient C loops:
- for-in-dict loop becomes a while loop calling PyDict_Next() - for-in-dict loop becomes a while loop calling PyDict_Next()
- for-in-enumerate is replaced by an external counter variable - for-in-enumerate is replaced by an external counter variable
- for-in-range loop becomes a plain C for loop - for-in-range loop becomes a plain C for loop
""" """
visit_Node = Visitor.VisitorTransform.recurse_to_children
def visit_ModuleNode(self, node):
self.current_scope = node.scope
self.module_scope = node.scope
self.visitchildren(node)
return node
def visit_DefNode(self, node):
oldscope = self.current_scope
self.current_scope = node.entry.scope
self.visitchildren(node)
self.current_scope = oldscope
return node
def visit_PrimaryCmpNode(self, node): def visit_PrimaryCmpNode(self, node):
if node.is_ptr_contains(): if node.is_ptr_contains():
...@@ -110,7 +95,7 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -110,7 +95,7 @@ class IterationTransform(Visitor.VisitorTransform):
iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2), iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
body=if_node, body=if_node,
else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0)))) else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
for_loop.analyse_expressions(self.current_scope) for_loop.analyse_expressions(self.current_env())
for_loop = self(for_loop) for_loop = self(for_loop)
new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop) new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)
...@@ -160,7 +145,7 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -160,7 +145,7 @@ class IterationTransform(Visitor.VisitorTransform):
base_obj = iterator.self or function.obj base_obj = iterator.self or function.obj
method = function.attribute method = function.attribute
# in Py3, items() is equivalent to Py2's iteritems() # in Py3, items() is equivalent to Py2's iteritems()
is_safe_iter = self.module_scope.context.language_level >= 3 is_safe_iter = self.global_scope().context.language_level >= 3
if not is_safe_iter and method in ('keys', 'values', 'items'): if not is_safe_iter and method in ('keys', 'values', 'items'):
# try to reduce this to the corresponding .iter*() methods # try to reduce this to the corresponding .iter*() methods
...@@ -319,7 +304,7 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -319,7 +304,7 @@ class IterationTransform(Visitor.VisitorTransform):
) )
if target_value.type != node.target.type: if target_value.type != node.target.type:
target_value = target_value.coerce_to(node.target.type, target_value = target_value.coerce_to(node.target.type,
self.current_scope) self.current_env())
target_assign = Nodes.SingleAssignmentNode( target_assign = Nodes.SingleAssignmentNode(
pos = node.target.pos, pos = node.target.pos,
lhs = node.target, lhs = node.target,
...@@ -419,12 +404,12 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -419,12 +404,12 @@ class IterationTransform(Visitor.VisitorTransform):
if start.constant_result is None: if start.constant_result is None:
start = None start = None
else: else:
start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope) start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
if stop: if stop:
if stop.constant_result is None: if stop.constant_result is None:
stop = None stop = None
else: else:
stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope) stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
if stop is None: if stop is None:
if neg_step: if neg_step:
stop = ExprNodes.IntNode( stop = ExprNodes.IntNode(
...@@ -443,7 +428,7 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -443,7 +428,7 @@ class IterationTransform(Visitor.VisitorTransform):
ptr_type = slice_base.type ptr_type = slice_base.type
if ptr_type.is_array: if ptr_type.is_array:
ptr_type = ptr_type.element_ptr_type() ptr_type = ptr_type.element_ptr_type()
carray_ptr = slice_base.coerce_to_simple(self.current_scope) carray_ptr = slice_base.coerce_to_simple(self.current_env())
if start and start.constant_result != 0: if start and start.constant_result != 0:
start_ptr_node = ExprNodes.AddNode( start_ptr_node = ExprNodes.AddNode(
...@@ -462,7 +447,7 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -462,7 +447,7 @@ class IterationTransform(Visitor.VisitorTransform):
operator='+', operator='+',
operand2=stop, operand2=stop,
type=ptr_type type=ptr_type
).coerce_to_simple(self.current_scope) ).coerce_to_simple(self.current_env())
else: else:
stop_ptr_node = ExprNodes.CloneNode(carray_ptr) stop_ptr_node = ExprNodes.CloneNode(carray_ptr)
...@@ -497,7 +482,7 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -497,7 +482,7 @@ class IterationTransform(Visitor.VisitorTransform):
if target_value.type != node.target.type: if target_value.type != node.target.type:
target_value = target_value.coerce_to(node.target.type, target_value = target_value.coerce_to(node.target.type,
self.current_scope) self.current_env())
target_assign = Nodes.SingleAssignmentNode( target_assign = Nodes.SingleAssignmentNode(
pos = node.target.pos, pos = node.target.pos,
...@@ -553,7 +538,7 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -553,7 +538,7 @@ class IterationTransform(Visitor.VisitorTransform):
return node return node
if len(args) == 2: if len(args) == 2:
start = unwrap_coerced_node(args[1]).coerce_to(counter_type, self.current_scope) start = unwrap_coerced_node(args[1]).coerce_to(counter_type, self.current_env())
else: else:
start = ExprNodes.IntNode(enumerate_function.pos, start = ExprNodes.IntNode(enumerate_function.pos,
value='0', value='0',
...@@ -593,7 +578,7 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -593,7 +578,7 @@ class IterationTransform(Visitor.VisitorTransform):
stats = loop_body) stats = loop_body)
node.target = iterable_target node.target = iterable_target
node.item = node.item.coerce_to(iterable_target.type, self.current_scope) node.item = node.item.coerce_to(iterable_target.type, self.current_env())
node.iterator.sequence = args[0] node.iterator.sequence = args[0]
# recurse into loop to check for further optimisations # recurse into loop to check for further optimisations
...@@ -638,10 +623,10 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -638,10 +623,10 @@ class IterationTransform(Visitor.VisitorTransform):
if len(args) == 1: if len(args) == 1:
bound1 = ExprNodes.IntNode(range_function.pos, value='0', bound1 = ExprNodes.IntNode(range_function.pos, value='0',
constant_result=0) constant_result=0)
bound2 = args[0].coerce_to_integer(self.current_scope) bound2 = args[0].coerce_to_integer(self.current_env())
else: else:
bound1 = args[0].coerce_to_integer(self.current_scope) bound1 = args[0].coerce_to_integer(self.current_env())
bound2 = args[1].coerce_to_integer(self.current_scope) bound2 = args[1].coerce_to_integer(self.current_env())
relation1, relation2 = self._find_for_from_node_relations(step_value < 0, reversed) relation1, relation2 = self._find_for_from_node_relations(step_value < 0, reversed)
...@@ -655,7 +640,7 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -655,7 +640,7 @@ class IterationTransform(Visitor.VisitorTransform):
step.value = str(step_value) step.value = str(step_value)
step.constant_result = step_value step.constant_result = step_value
step = step.coerce_to_integer(self.current_scope) step = step.coerce_to_integer(self.current_env())
if not bound2.is_literal: if not bound2.is_literal:
# stop bound must be immutable => keep it in a temp var # stop bound must be immutable => keep it in a temp var
...@@ -725,7 +710,7 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -725,7 +710,7 @@ class IterationTransform(Visitor.VisitorTransform):
dict_temp, dict_len_temp.ref(dict_obj.pos), pos_temp, dict_temp, dict_len_temp.ref(dict_obj.pos), pos_temp,
key_target, value_target, tuple_target, key_target, value_target, tuple_target,
is_dict_temp) is_dict_temp)
iter_next_node.analyse_expressions(self.current_scope) iter_next_node.analyse_expressions(self.current_env())
body.stats[0:0] = [iter_next_node] body.stats[0:0] = [iter_next_node]
if method: if method:
......
...@@ -198,7 +198,7 @@ def create_pipeline(context, mode, exclude_classes=()): ...@@ -198,7 +198,7 @@ def create_pipeline(context, mode, exclude_classes=()):
ExpandInplaceOperators(context), ExpandInplaceOperators(context),
OptimizeBuiltinCalls(context), ## Necessary? OptimizeBuiltinCalls(context), ## Necessary?
ConsolidateOverflowCheck(context), ConsolidateOverflowCheck(context),
IterationTransform(), IterationTransform(context),
SwitchTransform(), SwitchTransform(),
DropRefcountingTransform(), DropRefcountingTransform(),
FinalOptimizePhase(context), FinalOptimizePhase(context),
......
...@@ -330,6 +330,9 @@ class EnvTransform(CythonTransform): ...@@ -330,6 +330,9 @@ class EnvTransform(CythonTransform):
def current_scope_node(self): def current_scope_node(self):
return self.env_stack[-1][0] return self.env_stack[-1][0]
def global_scope(self):
return self.current_env().global_scope()
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
self.env_stack.append((node, node.local_scope)) self.env_stack.append((node, node.local_scope))
self.visitchildren(node) self.visitchildren(node)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment