Commit b1febf5b authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Make TreeFragment.py more readable; copy substitution nodes and copy over pos...

Make TreeFragment.py more readable; copy substitution nodes and copy over pos attributes on substitutions
parent 489d5c4b
...@@ -8,7 +8,7 @@ class TestTreeFragments(CythonTest): ...@@ -8,7 +8,7 @@ class TestTreeFragments(CythonTest):
T = F.copy() T = F.copy()
self.assertCode(u"x = 4", T) self.assertCode(u"x = 4", T)
def test_copy_is_independent(self): def test_copy_is_taken(self):
F = self.fragment(u"if True: x = 4") F = self.fragment(u"if True: x = 4")
T1 = F.root T1 = F.root
T2 = F.copy() T2 = F.copy()
...@@ -16,6 +16,12 @@ class TestTreeFragments(CythonTest): ...@@ -16,6 +16,12 @@ class TestTreeFragments(CythonTest):
T2.body.if_clauses[0].body.lhs.name = "other" T2.body.if_clauses[0].body.lhs.name = "other"
self.assertEqual("x", T1.body.if_clauses[0].body.lhs.name) self.assertEqual("x", T1.body.if_clauses[0].body.lhs.name)
def test_substitutions_are_copied(self):
T = self.fragment(u"y + y").substitute({"y": NameNode(pos=None, name="x")})
self.assertEqual("x", T.body.expr.operand1.name)
self.assertEqual("x", T.body.expr.operand2.name)
self.assert_(T.body.expr.operand1 is not T.body.expr.operand2)
def test_substitution(self): def test_substitution(self):
F = self.fragment(u"x = 4") F = self.fragment(u"x = 4")
y = NameNode(pos=None, name=u"y") y = NameNode(pos=None, name=u"y")
...@@ -26,7 +32,19 @@ class TestTreeFragments(CythonTest): ...@@ -26,7 +32,19 @@ class TestTreeFragments(CythonTest):
F = self.fragment(u"PASS") F = self.fragment(u"PASS")
pass_stat = PassStatNode(pos=None) pass_stat = PassStatNode(pos=None)
T = F.substitute({"PASS" : pass_stat}) T = F.substitute({"PASS" : pass_stat})
self.assert_(T.body is pass_stat, T.body) self.assert_(isinstance(T.body, PassStatNode), T.body)
def test_pos_is_transferred(self):
F = self.fragment(u"""
x = y
x = u * v ** w
""")
T = F.substitute({"v" : NameNode(pos=None, name="a")})
v = F.root.body.stats[1].rhs.operand2.operand1
a = T.body.stats[1].rhs.operand2.operand1
self.assertEquals(v.pos, a.pos)
if __name__ == "__main__": if __name__ == "__main__":
import unittest import unittest
......
...@@ -66,7 +66,36 @@ class TreeCopier(VisitorTransform): ...@@ -66,7 +66,36 @@ class TreeCopier(VisitorTransform):
self.visitchildren(c) self.visitchildren(c)
return c return c
class SubstitutionTransform(VisitorTransform): class ApplyPositionAndCopy(TreeCopier):
def __init__(self, pos):
super(ApplyPositionAndCopy, self).__init__()
self.pos = pos
def visit_Node(self, node):
copy = super(ApplyPositionAndCopy, self).visit_Node(node)
copy.pos = self.pos
return copy
class TemplateTransform(VisitorTransform):
"""
Makes a copy of a template tree while doing substitutions.
A dictionary "substitutions" should be passed in when calling
the transform; mapping names to replacement nodes. Then replacement
happens like this:
- If an ExprStatNode contains a single NameNode, whose name is
a key in the substitutions dictionary, the ExprStatNode is
replaced with a copy of the tree given in the dictionary.
It is the responsibility of the caller that the replacement
node is a valid statement.
- If a single NameNode is otherwise encountered, it is replaced
if its name is listed in the substitutions dictionary in the
same way. It is the responsibility of the caller to make sure
that the replacement nodes is a valid expression.
Each replacement node gets the position of the substituted node
recursively applied to every member node.
"""
def visit_Node(self, node): def visit_Node(self, node):
if node is None: if node is None:
return node return node
...@@ -75,25 +104,27 @@ class SubstitutionTransform(VisitorTransform): ...@@ -75,25 +104,27 @@ class SubstitutionTransform(VisitorTransform):
self.visitchildren(c) self.visitchildren(c)
return c return c
def visit_NameNode(self, node): def try_substitution(self, node, key):
if node.name in self.substitute: sub = self.substitutions.get(key)
# Name matched, substitute node if sub is None:
return self.substitute[node.name] return self.visit_Node(node) # make copy as usual
else: else:
# Clone return ApplyPositionAndCopy(node.pos)(sub)
return self.visit_Node(node)
def visit_NameNode(self, node):
return self.try_substitution(node, node.name)
def visit_ExprStatNode(self, node): def visit_ExprStatNode(self, node):
# If an expression-as-statement consists of only a replaceable # If an expression-as-statement consists of only a replaceable
# NameNode, we replace the entire statement, not only the NameNode # NameNode, we replace the entire statement, not only the NameNode
if isinstance(node.expr, NameNode) and node.expr.name in self.substitute: if isinstance(node.expr, NameNode):
return self.substitute[node.expr.name] return self.try_substitution(node, node.expr.name)
else: else:
return self.visit_Node(node) return self.visit_Node(node)
def __call__(self, node, substitute): def __call__(self, node, substitutions):
self.substitute = substitute self.substitutions = substitutions
return super(SubstitutionTransform, self).__call__(node) return super(TemplateTransform, self).__call__(node)
def copy_code_tree(node): def copy_code_tree(node):
return TreeCopier()(node) return TreeCopier()(node)
...@@ -127,7 +158,7 @@ class TreeFragment(object): ...@@ -127,7 +158,7 @@ class TreeFragment(object):
return copy_code_tree(self.root) return copy_code_tree(self.root)
def substitute(self, nodes={}): def substitute(self, nodes={}):
return SubstitutionTransform()(self.root, substitute = nodes) return TemplateTransform()(self.root, substitutions = nodes)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment