Commit 10a5972a authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Started on TempName support, more CodeWriter

parent d30d8fb7
...@@ -35,6 +35,7 @@ class CodeWriter(TreeVisitor): ...@@ -35,6 +35,7 @@ class CodeWriter(TreeVisitor):
result = LinesResult() result = LinesResult()
self.result = result self.result = result
self.numindents = 0 self.numindents = 0
self.tempnames = {}
def write(self, tree): def write(self, tree):
self.visit(tree) self.visit(tree)
...@@ -58,6 +59,11 @@ class CodeWriter(TreeVisitor): ...@@ -58,6 +59,11 @@ class CodeWriter(TreeVisitor):
self.startline(s) self.startline(s)
self.endline() self.endline()
def putname(self, name):
if isinstance(name, TempName):
name = self.tempnames.setdefault(name, u"$" + name.description)
self.put(name)
def comma_seperated_list(self, items, output_rhs=False): def comma_seperated_list(self, items, output_rhs=False):
if len(items) > 0: if len(items) > 0:
for item in items[:-1]: for item in items[:-1]:
...@@ -116,7 +122,7 @@ class CodeWriter(TreeVisitor): ...@@ -116,7 +122,7 @@ class CodeWriter(TreeVisitor):
self.endline() self.endline()
def visit_NameNode(self, node): def visit_NameNode(self, node):
self.put(node.name) self.putname(node.name)
def visit_IntNode(self, node): def visit_IntNode(self, node):
self.put(node.value) self.put(node.value)
...@@ -185,7 +191,8 @@ class CodeWriter(TreeVisitor): ...@@ -185,7 +191,8 @@ class CodeWriter(TreeVisitor):
self.comma_seperated_list(node.args) # Might need to discover whether we need () around tuples...hmm... self.comma_seperated_list(node.args) # Might need to discover whether we need () around tuples...hmm...
def visit_SimpleCallNode(self, node): def visit_SimpleCallNode(self, node):
self.put(node.function.name + u"(") self.visit(node.function)
self.put(u"(")
self.comma_seperated_list(node.args) self.comma_seperated_list(node.args)
self.put(")") self.put(")")
...@@ -197,9 +204,62 @@ class CodeWriter(TreeVisitor): ...@@ -197,9 +204,62 @@ class CodeWriter(TreeVisitor):
def visit_InPlaceAssignmentNode(self, node): def visit_InPlaceAssignmentNode(self, node):
self.startline() self.startline()
self.visit(node.lhs) self.visit(node.lhs)
self.put(" %s= " % node.operator) self.put(u" %s= " % node.operator)
self.visit(node.rhs) self.visit(node.rhs)
self.endline() self.endline()
def visit_WithStatNode(self, node):
self.startline()
self.put(u"with ")
self.visit(node.manager)
if node.target is not None:
self.put(u" as ")
self.visit(node.target)
self.endline(u":")
self.indent()
self.visit(node.body)
self.dedent()
def visit_AttributeNode(self, node):
self.visit(node.obj)
self.put(u".%s" % node.attribute)
def visit_BoolNode(self, node):
self.put(str(node.value))
def visit_TryFinallyStatNode(self, node):
self.line(u"try:")
self.indent()
self.visit(node.body)
self.dedent()
self.line(u"finally:")
self.indent()
self.visit(node.finally_clause)
self.dedent()
def visit_TryExceptStatNode(self, node):
self.line(u"try:")
self.indent()
self.visit(node.body)
self.dedent()
for x in node.except_clauses:
self.visit(x)
if node.else_clause is not None:
self.visit(node.else_clause)
def visit_ExceptClauseNode(self, node):
self.startline(u"except")
if node.pattern is not None:
self.put(u" ")
self.visit(node.pattern)
if node.target is not None:
self.put(u", ")
self.visit(node.target)
self.endline(":")
self.indent()
self.visit(node.body)
self.dedent()
def visit_NoneNode(self, node):
self.put(u"None")
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
pyrex_prefix = "__pyx_" pyrex_prefix = "__pyx_"
temp_prefix = "__pyxtmp_"
builtin_prefix = pyrex_prefix + "builtin_" builtin_prefix = pyrex_prefix + "builtin_"
arg_prefix = pyrex_prefix + "arg_" arg_prefix = pyrex_prefix + "arg_"
funcdoc_prefix = pyrex_prefix + "doc_" funcdoc_prefix = pyrex_prefix + "doc_"
......
...@@ -16,6 +16,29 @@ from TypeSlots import \ ...@@ -16,6 +16,29 @@ from TypeSlots import \
import ControlFlow import ControlFlow
import __builtin__ import __builtin__
class TempName(object):
"""
Use instances of this class in order to provide a name for
anonymous, temporary functions. Each instance is considered
a seperate name, which are guaranteed not to clash with one
another or with names explicitly given as strings.
The argument to the constructor is simply a describing string
for debugging purposes and does not affect name clashes at all.
NOTE: Support for these TempNames are introduced on an as-needed
basis and will not "just work" everywhere. Places where they work:
- (none)
"""
def __init__(self, description):
self.description = description
# Spoon-feed operators for documentation purposes
def __hash__(self):
return id(self)
def __cmp__(self, other):
return cmp(id(self), id(other))
possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
class Entry: class Entry:
...@@ -1036,6 +1059,7 @@ class ModuleScope(Scope): ...@@ -1036,6 +1059,7 @@ class ModuleScope(Scope):
var_entry.is_readonly = 1 var_entry.is_readonly = 1
entry.as_variable = var_entry entry.as_variable = var_entry
tempctr = 0
class LocalScope(Scope): class LocalScope(Scope):
...@@ -1043,6 +1067,11 @@ class LocalScope(Scope): ...@@ -1043,6 +1067,11 @@ class LocalScope(Scope):
Scope.__init__(self, name, outer_scope, outer_scope) Scope.__init__(self, name, outer_scope, outer_scope)
def mangle(self, prefix, name): def mangle(self, prefix, name):
if isinstance(name, TempName):
global tempctr
tempctr += 1
return u"%s%s%d" % (Naming.temp_prefix, name.description, tempctr)
else:
return prefix + name return prefix + name
def declare_arg(self, name, type, pos): def declare_arg(self, name, type, pos):
......
...@@ -44,7 +44,21 @@ class TestTreeFragments(CythonTest): ...@@ -44,7 +44,21 @@ class TestTreeFragments(CythonTest):
a = T.body.stats[1].rhs.operand2.operand1 a = T.body.stats[1].rhs.operand2.operand1
self.assertEquals(v.pos, a.pos) self.assertEquals(v.pos, a.pos)
def test_temps(self):
F = self.fragment(u"""
TMP
x = TMP
""")
T = F.substitute(temps=[u"TMP"])
s = T.body.stats
print s[0].expr.name
self.assert_(s[0].expr.name.__class__ is TempName)
self.assert_(s[1].rhs.name.__class__ is TempName)
self.assert_(s[0].expr.name == s[1].rhs.name)
self.assert_(s[0].expr.name != u"TMP")
self.assert_(s[0].expr.name != TempName(u"TMP"))
self.assert_(s[0].expr.name.description == u"TMP")
if __name__ == "__main__": if __name__ == "__main__":
import unittest import unittest
......
...@@ -8,6 +8,7 @@ from Scanning import PyrexScanner, StringSourceDescriptor ...@@ -8,6 +8,7 @@ from Scanning import PyrexScanner, StringSourceDescriptor
from Symtab import BuiltinScope, ModuleScope from Symtab import BuiltinScope, ModuleScope
from Visitor import VisitorTransform from Visitor import VisitorTransform
from Nodes import Node from Nodes import Node
from Symtab import TempName
from ExprNodes import NameNode from ExprNodes import NameNode
import Parsing import Parsing
import Main import Main
...@@ -93,25 +94,54 @@ class TemplateTransform(VisitorTransform): ...@@ -93,25 +94,54 @@ class TemplateTransform(VisitorTransform):
same way. It is the responsibility of the caller to make sure same way. It is the responsibility of the caller to make sure
that the replacement nodes is a valid expression. that the replacement nodes is a valid expression.
Also a list "temps" should be passed. Any names listed will
be transformed into anonymous, temporary names.
Currently supported for tempnames is:
NameNode
(various function and class definition nodes etc. should be added to this)
Each replacement node gets the position of the substituted node Each replacement node gets the position of the substituted node
recursively applied to every member node. recursively applied to every member node.
""" """
def __call__(self, node, substitutions, temps, pos):
self.substitutions = substitutions
tempdict = {}
for key in temps:
tempdict[key] = TempName(key)
self.temps = tempdict
self.pos = pos
return super(TemplateTransform, self).__call__(node)
def visit_Node(self, node): def visit_Node(self, node):
if node is None: if node is None:
return node return None
else: else:
c = node.clone_node() c = node.clone_node()
if self.pos is not None:
c.pos = self.pos
self.visitchildren(c) self.visitchildren(c)
return c return c
def try_substitution(self, node, key): def try_substitution(self, node, key):
sub = self.substitutions.get(key) sub = self.substitutions.get(key)
if sub is None: if sub is not None:
return self.visit_Node(node) # make copy as usual pos = self.pos
if pos is None: pos = node.pos
return ApplyPositionAndCopy(pos)(sub)
else: else:
return ApplyPositionAndCopy(node.pos)(sub) return self.visit_Node(node) # make copy as usual
def visit_NameNode(self, node): def visit_NameNode(self, node):
tempname = self.temps.get(node.name)
if tempname is not None:
# Replace name with temporary
node.name = tempname
return self.visit_Node(node)
else:
return self.try_substitution(node, node.name) return self.try_substitution(node, node.name)
def visit_ExprStatNode(self, node): def visit_ExprStatNode(self, node):
...@@ -122,10 +152,6 @@ class TemplateTransform(VisitorTransform): ...@@ -122,10 +152,6 @@ class TemplateTransform(VisitorTransform):
else: else:
return self.visit_Node(node) return self.visit_Node(node)
def __call__(self, node, substitutions):
self.substitutions = substitutions
return super(TemplateTransform, self).__call__(node)
def copy_code_tree(node): def copy_code_tree(node):
return TreeCopier()(node) return TreeCopier()(node)
...@@ -157,8 +183,10 @@ class TreeFragment(object): ...@@ -157,8 +183,10 @@ class TreeFragment(object):
def copy(self): def copy(self):
return copy_code_tree(self.root) return copy_code_tree(self.root)
def substitute(self, nodes={}): def substitute(self, nodes={}, temps=[], pos = None):
return TemplateTransform()(self.root, substitutions = nodes) return TemplateTransform()(self.root,
substitutions = nodes,
temps = temps, pos = pos)
......
...@@ -4,6 +4,28 @@ import unittest ...@@ -4,6 +4,28 @@ import unittest
from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.ModuleNode import ModuleNode
import Cython.Compiler.Main as Main import Cython.Compiler.Main as Main
from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
from Cython.Compiler.Visitor import TreeVisitor
class NodeTypeWriter(TreeVisitor):
def __init__(self):
super(NodeTypeWriter, self).__init__()
self._indents = 0
self.result = []
def visit_Node(self, node):
if len(self.access_path) == 0:
name = u"(root)"
else:
tip = self.access_path[-1]
if tip[2] is not None:
name = u"%s[%d]" % tip[1:3]
else:
name = tip[1]
self.result.append(u" " * self._indents +
u"%s: %s" % (name, node.__class__.__name__))
self._indents += 1
self.visitchildren(node)
self._indents -= 1
class CythonTest(unittest.TestCase): class CythonTest(unittest.TestCase):
def assertCode(self, expected, result_tree): def assertCode(self, expected, result_tree):
...@@ -25,6 +47,14 @@ class CythonTest(unittest.TestCase): ...@@ -25,6 +47,14 @@ class CythonTest(unittest.TestCase):
name = name.replace(".", "_") name = name.replace(".", "_")
return TreeFragment(code, name, pxds) return TreeFragment(code, name, pxds)
def treetypes(self, root):
"""Returns a string representing the tree by class names.
There's a leading and trailing whitespace so that it can be
compared by simple string comparison while still making test
cases look ok."""
w = NodeTypeWriter()
w.visit(root)
return u"\n".join([u""] + w.result + [u""])
class TransformTest(CythonTest): class TransformTest(CythonTest):
""" """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment