Commit b1febf5b authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Make TreeFragment.py more readable; copy substitution nodes and copy over pos...

Make TreeFragment.py more readable; copy substitution nodes and copy over pos attributes on substitutions
parent 489d5c4b
......@@ -8,7 +8,7 @@ class TestTreeFragments(CythonTest):
T = F.copy()
self.assertCode(u"x = 4", T)
def test_copy_is_independent(self):
def test_copy_is_taken(self):
F = self.fragment(u"if True: x = 4")
T1 = F.root
T2 = F.copy()
......@@ -16,6 +16,12 @@ class TestTreeFragments(CythonTest):
T2.body.if_clauses[0].body.lhs.name = "other"
self.assertEqual("x", T1.body.if_clauses[0].body.lhs.name)
def test_substitutions_are_copied(self):
T = self.fragment(u"y + y").substitute({"y": NameNode(pos=None, name="x")})
self.assertEqual("x", T.body.expr.operand1.name)
self.assertEqual("x", T.body.expr.operand2.name)
self.assert_(T.body.expr.operand1 is not T.body.expr.operand2)
def test_substitution(self):
F = self.fragment(u"x = 4")
y = NameNode(pos=None, name=u"y")
......@@ -26,7 +32,19 @@ class TestTreeFragments(CythonTest):
F = self.fragment(u"PASS")
pass_stat = PassStatNode(pos=None)
T = F.substitute({"PASS" : pass_stat})
self.assert_(T.body is pass_stat, T.body)
self.assert_(isinstance(T.body, PassStatNode), T.body)
def test_pos_is_transferred(self):
F = self.fragment(u"""
x = y
x = u * v ** w
""")
T = F.substitute({"v" : NameNode(pos=None, name="a")})
v = F.root.body.stats[1].rhs.operand2.operand1
a = T.body.stats[1].rhs.operand2.operand1
self.assertEquals(v.pos, a.pos)
if __name__ == "__main__":
import unittest
......
......@@ -66,7 +66,36 @@ class TreeCopier(VisitorTransform):
self.visitchildren(c)
return c
class SubstitutionTransform(VisitorTransform):
class ApplyPositionAndCopy(TreeCopier):
def __init__(self, pos):
super(ApplyPositionAndCopy, self).__init__()
self.pos = pos
def visit_Node(self, node):
copy = super(ApplyPositionAndCopy, self).visit_Node(node)
copy.pos = self.pos
return copy
class TemplateTransform(VisitorTransform):
"""
Makes a copy of a template tree while doing substitutions.
A dictionary "substitutions" should be passed in when calling
the transform; mapping names to replacement nodes. Then replacement
happens like this:
- If an ExprStatNode contains a single NameNode, whose name is
a key in the substitutions dictionary, the ExprStatNode is
replaced with a copy of the tree given in the dictionary.
It is the responsibility of the caller that the replacement
node is a valid statement.
- If a single NameNode is otherwise encountered, it is replaced
if its name is listed in the substitutions dictionary in the
same way. It is the responsibility of the caller to make sure
that the replacement nodes is a valid expression.
Each replacement node gets the position of the substituted node
recursively applied to every member node.
"""
def visit_Node(self, node):
if node is None:
return node
......@@ -75,25 +104,27 @@ class SubstitutionTransform(VisitorTransform):
self.visitchildren(c)
return c
def visit_NameNode(self, node):
if node.name in self.substitute:
# Name matched, substitute node
return self.substitute[node.name]
def try_substitution(self, node, key):
sub = self.substitutions.get(key)
if sub is None:
return self.visit_Node(node) # make copy as usual
else:
# Clone
return self.visit_Node(node)
return ApplyPositionAndCopy(node.pos)(sub)
def visit_NameNode(self, node):
return self.try_substitution(node, node.name)
def visit_ExprStatNode(self, node):
# If an expression-as-statement consists of only a replaceable
# NameNode, we replace the entire statement, not only the NameNode
if isinstance(node.expr, NameNode) and node.expr.name in self.substitute:
return self.substitute[node.expr.name]
if isinstance(node.expr, NameNode):
return self.try_substitution(node, node.expr.name)
else:
return self.visit_Node(node)
def __call__(self, node, substitute):
self.substitute = substitute
return super(SubstitutionTransform, self).__call__(node)
def __call__(self, node, substitutions):
self.substitutions = substitutions
return super(TemplateTransform, self).__call__(node)
def copy_code_tree(node):
return TreeCopier()(node)
......@@ -127,7 +158,7 @@ class TreeFragment(object):
return copy_code_tree(self.root)
def substitute(self, nodes={}):
return SubstitutionTransform()(self.root, substitute = nodes)
return TemplateTransform()(self.root, substitutions = nodes)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment