Commit d46fa3e8 authored by Stefan Behnel's avatar Stefan Behnel

new transform that hides the loop variable in a comprehension

parent 5270e035
...@@ -80,7 +80,7 @@ class Context: ...@@ -80,7 +80,7 @@ class Context:
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
from ParseTreeTransforms import AlignFunctionDefinitions from ParseTreeTransforms import ComprehensionTransform, AlignFunctionDefinitions
from AutoDocTransforms import EmbedSignature from AutoDocTransforms import EmbedSignature
from Optimize import FlattenInListTransform, SwitchTransform, IterationTransform from Optimize import FlattenInListTransform, SwitchTransform, IterationTransform
from Optimize import FlattenBuiltinTypeCreation, ConstantFolding, FinalOptimizePhase from Optimize import FlattenBuiltinTypeCreation, ConstantFolding, FinalOptimizePhase
...@@ -125,6 +125,7 @@ class Context: ...@@ -125,6 +125,7 @@ class Context:
AnalyseExpressionsTransform(self), AnalyseExpressionsTransform(self),
FlattenBuiltinTypeCreation(), FlattenBuiltinTypeCreation(),
ConstantFolding(), ConstantFolding(),
ComprehensionTransform(),
IterationTransform(), IterationTransform(),
SwitchTransform(), SwitchTransform(),
FinalOptimizePhase(self), FinalOptimizePhase(self),
......
from Cython.Compiler.Visitor import VisitorTransform, CythonTransform from Cython.Compiler.Visitor import VisitorTransform, CythonTransform, TreeVisitor
from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import * from Cython.Compiler.ExprNodes import *
...@@ -12,6 +12,23 @@ except NameError: ...@@ -12,6 +12,23 @@ except NameError:
from sets import Set as set from sets import Set as set
import copy import copy
class NameNodeCollector(TreeVisitor):
"""Collect all NameNodes of a (sub-)tree in the ``name_nodes``
attribute.
"""
def __init__(self):
super(NameNodeCollector, self).__init__()
self.name_nodes = []
def visit_Node(self, node):
self.visitchildren(node)
return node
def visit_NameNode(self, node):
self.name_nodes.append(node)
class SkipDeclarations: class SkipDeclarations:
""" """
Variable and function declarations can often have a deep tree structure, Variable and function declarations can often have a deep tree structure,
...@@ -565,6 +582,60 @@ class WithTransform(CythonTransform, SkipDeclarations): ...@@ -565,6 +582,60 @@ class WithTransform(CythonTransform, SkipDeclarations):
return node return node
class ComprehensionTransform(VisitorTransform):
"""Prevent the target of list/set/dict comprehensions from leaking by
moving it into a temp variable. This mimics the behaviour of all
comprehensions in Py3 and of generator expressions in Py2.x.
This must run before the IterationTransform, which might replace
for-loops with while-loops. We only handle for-loops here.
"""
def visit_ModuleNode(self, node):
self.comprehension_targets = {}
self.visitchildren(node)
return node
def visit_Node(self, node):
# descend into statements (loops) and nodes (comprehensions)
self.visitchildren(node)
return node
def visit_ComprehensionNode(self, node):
if type(node.loop) not in (Nodes.ForInStatNode,
Nodes.ForFromStatNode):
# this should not happen!
self.visitchildren(node)
return node
outer_comprehension_targets = self.comprehension_targets
self.comprehension_targets = outer_comprehension_targets.copy()
# find all NameNodes in the loop target
target_name_collector = NameNodeCollector()
target_name_collector.visit(node.loop.target)
targets = target_name_collector.name_nodes
# create a temp variable for each target name
temps = []
for target in targets:
handle = TempHandle(target.type)
temps.append(handle)
self.comprehension_targets[target.entry.cname] = handle.ref(node.pos)
# replace name references in the loop code by their temp node
self.visitchildren(node, ['loop'])
self.comprehension_targets = outer_comprehension_targets
node.loop = TempsBlockNode(node.pos, body=node.loop, temps=temps)
return node
def visit_NameNode(self, node):
replacement = self.comprehension_targets.get(node.entry.cname)
if replacement is not None:
return replacement
return node
class DecoratorTransform(CythonTransform, SkipDeclarations): class DecoratorTransform(CythonTransform, SkipDeclarations):
def visit_DefNode(self, func_node): def visit_DefNode(self, func_node):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment