Commit 10a5972a authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Started on TempName support, more CodeWriter

parent d30d8fb7
......@@ -35,6 +35,7 @@ class CodeWriter(TreeVisitor):
result = LinesResult()
self.result = result
self.numindents = 0
self.tempnames = {}
def write(self, tree):
self.visit(tree)
......@@ -57,6 +58,11 @@ class CodeWriter(TreeVisitor):
def line(self, s):
self.startline(s)
self.endline()
def putname(self, name):
if isinstance(name, TempName):
name = self.tempnames.setdefault(name, u"$" + name.description)
self.put(name)
def comma_seperated_list(self, items, output_rhs=False):
if len(items) > 0:
......@@ -116,7 +122,7 @@ class CodeWriter(TreeVisitor):
self.endline()
def visit_NameNode(self, node):
self.put(node.name)
self.putname(node.name)
def visit_IntNode(self, node):
self.put(node.value)
......@@ -185,7 +191,8 @@ class CodeWriter(TreeVisitor):
self.comma_seperated_list(node.args) # Might need to discover whether we need () around tuples...hmm...
def visit_SimpleCallNode(self, node):
self.put(node.function.name + u"(")
self.visit(node.function)
self.put(u"(")
self.comma_seperated_list(node.args)
self.put(")")
......@@ -197,9 +204,62 @@ class CodeWriter(TreeVisitor):
def visit_InPlaceAssignmentNode(self, node):
self.startline()
self.visit(node.lhs)
self.put(" %s= " % node.operator)
self.put(u" %s= " % node.operator)
self.visit(node.rhs)
self.endline()
def visit_WithStatNode(self, node):
self.startline()
self.put(u"with ")
self.visit(node.manager)
if node.target is not None:
self.put(u" as ")
self.visit(node.target)
self.endline(u":")
self.indent()
self.visit(node.body)
self.dedent()
def visit_AttributeNode(self, node):
self.visit(node.obj)
self.put(u".%s" % node.attribute)
def visit_BoolNode(self, node):
self.put(str(node.value))
def visit_TryFinallyStatNode(self, node):
self.line(u"try:")
self.indent()
self.visit(node.body)
self.dedent()
self.line(u"finally:")
self.indent()
self.visit(node.finally_clause)
self.dedent()
def visit_TryExceptStatNode(self, node):
self.line(u"try:")
self.indent()
self.visit(node.body)
self.dedent()
for x in node.except_clauses:
self.visit(x)
if node.else_clause is not None:
self.visit(node.else_clause)
def visit_ExceptClauseNode(self, node):
self.startline(u"except")
if node.pattern is not None:
self.put(u" ")
self.visit(node.pattern)
if node.target is not None:
self.put(u", ")
self.visit(node.target)
self.endline(":")
self.indent()
self.visit(node.body)
self.dedent()
def visit_NoneNode(self, node):
self.put(u"None")
......@@ -8,6 +8,8 @@
pyrex_prefix = "__pyx_"
temp_prefix = "__pyxtmp_"
builtin_prefix = pyrex_prefix + "builtin_"
arg_prefix = pyrex_prefix + "arg_"
funcdoc_prefix = pyrex_prefix + "doc_"
......
......@@ -16,6 +16,29 @@ from TypeSlots import \
import ControlFlow
import __builtin__
class TempName(object):
"""
Use instances of this class in order to provide a name for
anonymous, temporary functions. Each instance is considered
a seperate name, which are guaranteed not to clash with one
another or with names explicitly given as strings.
The argument to the constructor is simply a describing string
for debugging purposes and does not affect name clashes at all.
NOTE: Support for these TempNames are introduced on an as-needed
basis and will not "just work" everywhere. Places where they work:
- (none)
"""
def __init__(self, description):
self.description = description
# Spoon-feed operators for documentation purposes
def __hash__(self):
return id(self)
def __cmp__(self, other):
return cmp(id(self), id(other))
possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
class Entry:
......@@ -1036,14 +1059,20 @@ class ModuleScope(Scope):
var_entry.is_readonly = 1
entry.as_variable = var_entry
tempctr = 0
class LocalScope(Scope):
class LocalScope(Scope):
def __init__(self, name, outer_scope):
Scope.__init__(self, name, outer_scope, outer_scope)
def mangle(self, prefix, name):
return prefix + name
if isinstance(name, TempName):
global tempctr
tempctr += 1
return u"%s%s%d" % (Naming.temp_prefix, name.description, tempctr)
else:
return prefix + name
def declare_arg(self, name, type, pos):
# Add an entry for an argument of a function.
......
......@@ -44,7 +44,21 @@ class TestTreeFragments(CythonTest):
a = T.body.stats[1].rhs.operand2.operand1
self.assertEquals(v.pos, a.pos)
def test_temps(self):
F = self.fragment(u"""
TMP
x = TMP
""")
T = F.substitute(temps=[u"TMP"])
s = T.body.stats
print s[0].expr.name
self.assert_(s[0].expr.name.__class__ is TempName)
self.assert_(s[1].rhs.name.__class__ is TempName)
self.assert_(s[0].expr.name == s[1].rhs.name)
self.assert_(s[0].expr.name != u"TMP")
self.assert_(s[0].expr.name != TempName(u"TMP"))
self.assert_(s[0].expr.name.description == u"TMP")
if __name__ == "__main__":
import unittest
......
......@@ -8,6 +8,7 @@ from Scanning import PyrexScanner, StringSourceDescriptor
from Symtab import BuiltinScope, ModuleScope
from Visitor import VisitorTransform
from Nodes import Node
from Symtab import TempName
from ExprNodes import NameNode
import Parsing
import Main
......@@ -92,27 +93,56 @@ class TemplateTransform(VisitorTransform):
if its name is listed in the substitutions dictionary in the
same way. It is the responsibility of the caller to make sure
that the replacement nodes is a valid expression.
Also a list "temps" should be passed. Any names listed will
be transformed into anonymous, temporary names.
Currently supported for tempnames is:
NameNode
(various function and class definition nodes etc. should be added to this)
Each replacement node gets the position of the substituted node
recursively applied to every member node.
"""
def __call__(self, node, substitutions, temps, pos):
self.substitutions = substitutions
tempdict = {}
for key in temps:
tempdict[key] = TempName(key)
self.temps = tempdict
self.pos = pos
return super(TemplateTransform, self).__call__(node)
def visit_Node(self, node):
if node is None:
return node
return None
else:
c = node.clone_node()
if self.pos is not None:
c.pos = self.pos
self.visitchildren(c)
return c
def try_substitution(self, node, key):
sub = self.substitutions.get(key)
if sub is None:
return self.visit_Node(node) # make copy as usual
if sub is not None:
pos = self.pos
if pos is None: pos = node.pos
return ApplyPositionAndCopy(pos)(sub)
else:
return ApplyPositionAndCopy(node.pos)(sub)
return self.visit_Node(node) # make copy as usual
def visit_NameNode(self, node):
return self.try_substitution(node, node.name)
tempname = self.temps.get(node.name)
if tempname is not None:
# Replace name with temporary
node.name = tempname
return self.visit_Node(node)
else:
return self.try_substitution(node, node.name)
def visit_ExprStatNode(self, node):
# If an expression-as-statement consists of only a replaceable
......@@ -122,10 +152,6 @@ class TemplateTransform(VisitorTransform):
else:
return self.visit_Node(node)
def __call__(self, node, substitutions):
self.substitutions = substitutions
return super(TemplateTransform, self).__call__(node)
def copy_code_tree(node):
return TreeCopier()(node)
......@@ -157,8 +183,10 @@ class TreeFragment(object):
def copy(self):
return copy_code_tree(self.root)
def substitute(self, nodes={}):
return TemplateTransform()(self.root, substitutions = nodes)
def substitute(self, nodes={}, temps=[], pos = None):
return TemplateTransform()(self.root,
substitutions = nodes,
temps = temps, pos = pos)
......
......@@ -4,6 +4,28 @@ import unittest
from Cython.Compiler.ModuleNode import ModuleNode
import Cython.Compiler.Main as Main
from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
from Cython.Compiler.Visitor import TreeVisitor
class NodeTypeWriter(TreeVisitor):
def __init__(self):
super(NodeTypeWriter, self).__init__()
self._indents = 0
self.result = []
def visit_Node(self, node):
if len(self.access_path) == 0:
name = u"(root)"
else:
tip = self.access_path[-1]
if tip[2] is not None:
name = u"%s[%d]" % tip[1:3]
else:
name = tip[1]
self.result.append(u" " * self._indents +
u"%s: %s" % (name, node.__class__.__name__))
self._indents += 1
self.visitchildren(node)
self._indents -= 1
class CythonTest(unittest.TestCase):
def assertCode(self, expected, result_tree):
......@@ -24,7 +46,15 @@ class CythonTest(unittest.TestCase):
if name.startswith("__main__."): name = name[len("__main__."):]
name = name.replace(".", "_")
return TreeFragment(code, name, pxds)
def treetypes(self, root):
"""Returns a string representing the tree by class names.
There's a leading and trailing whitespace so that it can be
compared by simple string comparison while still making test
cases look ok."""
w = NodeTypeWriter()
w.visit(root)
return u"\n".join([u""] + w.result + [u""])
class TransformTest(CythonTest):
"""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment