Commit fcccb15f authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Focus on visitors rather than transforms; Transform.py renamed to Visitor.py

Some changes in class hierarchies etc.; transforms no longer has a common
base class and VisitorTransform is a subclass of TreeVisitor rather than
the reverse. Also removed visitor use of get_child_accessors;
child_attrs is accessed directly (because of claims of overengineering :-) ).

--HG--
rename : Cython/Compiler/Transform.py => Cython/Compiler/Visitor.py
parent dd000240
from Cython.Compiler.Transform import ReadonlyVisitor from Cython.Compiler.Visitor import TreeVisitor
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
""" """
...@@ -25,7 +25,7 @@ class LinesResult(object): ...@@ -25,7 +25,7 @@ class LinesResult(object):
self.put(s) self.put(s)
self.newline() self.newline()
class CodeWriter(ReadonlyVisitor): class CodeWriter(TreeVisitor):
indent_string = u" " indent_string = u" "
...@@ -36,6 +36,9 @@ class CodeWriter(ReadonlyVisitor): ...@@ -36,6 +36,9 @@ class CodeWriter(ReadonlyVisitor):
self.result = result self.result = result
self.numindents = 0 self.numindents = 0
def write(self, tree):
self.visit(tree)
def indent(self): def indent(self):
self.numindents += 1 self.numindents += 1
...@@ -58,43 +61,43 @@ class CodeWriter(ReadonlyVisitor): ...@@ -58,43 +61,43 @@ class CodeWriter(ReadonlyVisitor):
def comma_seperated_list(self, items, output_rhs=False): def comma_seperated_list(self, items, output_rhs=False):
if len(items) > 0: if len(items) > 0:
for item in items[:-1]: for item in items[:-1]:
self.process_node(item) self.visit(item)
if output_rhs and item.rhs is not None: if output_rhs and item.rhs is not None:
self.put(u" = ") self.put(u" = ")
self.process_node(item.rhs) self.visit(item.rhs)
self.put(u", ") self.put(u", ")
self.process_node(items[-1]) self.visit(items[-1])
def process_Node(self, node): def visit_Node(self, node):
raise AssertionError("Node not handled by serializer: %r" % node) raise AssertionError("Node not handled by serializer: %r" % node)
def process_ModuleNode(self, node): def visit_ModuleNode(self, node):
self.process_children(node) self.visitchildren(node)
def process_StatListNode(self, node): def visit_StatListNode(self, node):
self.process_children(node) self.visitchildren(node)
def process_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
self.startline(u"def %s(" % node.name) self.startline(u"def %s(" % node.name)
self.comma_seperated_list(node.args) self.comma_seperated_list(node.args)
self.endline(u"):") self.endline(u"):")
self.indent() self.indent()
self.process_node(node.body) self.visit(node.body)
self.dedent() self.dedent()
def process_CArgDeclNode(self, node): def visit_CArgDeclNode(self, node):
if node.base_type.name is not None: if node.base_type.name is not None:
self.process_node(node.base_type) self.visit(node.base_type)
self.put(u" ") self.put(u" ")
self.process_node(node.declarator) self.visit(node.declarator)
if node.default is not None: if node.default is not None:
self.put(u" = ") self.put(u" = ")
self.process_node(node.default) self.visit(node.default)
def process_CNameDeclaratorNode(self, node): def visit_CNameDeclaratorNode(self, node):
self.put(node.name) self.put(node.name)
def process_CSimpleBaseTypeNode(self, node): def visit_CSimpleBaseTypeNode(self, node):
# See Parsing.p_sign_and_longness # See Parsing.p_sign_and_longness
if node.is_basic_c_type: if node.is_basic_c_type:
self.put(("unsigned ", "", "signed ")[node.signed]) self.put(("unsigned ", "", "signed ")[node.signed])
...@@ -105,97 +108,97 @@ class CodeWriter(ReadonlyVisitor): ...@@ -105,97 +108,97 @@ class CodeWriter(ReadonlyVisitor):
self.put(node.name) self.put(node.name)
def process_SingleAssignmentNode(self, node): def visit_SingleAssignmentNode(self, node):
self.startline() self.startline()
self.process_node(node.lhs) self.visit(node.lhs)
self.put(u" = ") self.put(u" = ")
self.process_node(node.rhs) self.visit(node.rhs)
self.endline() self.endline()
def process_NameNode(self, node): def visit_NameNode(self, node):
self.put(node.name) self.put(node.name)
def process_IntNode(self, node): def visit_IntNode(self, node):
self.put(node.value) self.put(node.value)
def process_IfStatNode(self, node): def visit_IfStatNode(self, node):
# The IfClauseNode is handled directly without a seperate match # The IfClauseNode is handled directly without a seperate match
# for clariy. # for clariy.
self.startline(u"if ") self.startline(u"if ")
self.process_node(node.if_clauses[0].condition) self.visit(node.if_clauses[0].condition)
self.endline(":") self.endline(":")
self.indent() self.indent()
self.process_node(node.if_clauses[0].body) self.visit(node.if_clauses[0].body)
self.dedent() self.dedent()
for clause in node.if_clauses[1:]: for clause in node.if_clauses[1:]:
self.startline("elif ") self.startline("elif ")
self.process_node(clause.condition) self.visit(clause.condition)
self.endline(":") self.endline(":")
self.indent() self.indent()
self.process_node(clause.body) self.visit(clause.body)
self.dedent() self.dedent()
if node.else_clause is not None: if node.else_clause is not None:
self.line("else:") self.line("else:")
self.indent() self.indent()
self.process_node(node.else_clause) self.visit(node.else_clause)
self.dedent() self.dedent()
def process_PassStatNode(self, node): def visit_PassStatNode(self, node):
self.startline(u"pass") self.startline(u"pass")
self.endline() self.endline()
def process_PrintStatNode(self, node): def visit_PrintStatNode(self, node):
self.startline(u"print ") self.startline(u"print ")
self.comma_seperated_list(node.args) self.comma_seperated_list(node.args)
if node.ends_with_comma: if node.ends_with_comma:
self.put(u",") self.put(u",")
self.endline() self.endline()
def process_BinopNode(self, node): def visit_BinopNode(self, node):
self.process_node(node.operand1) self.visit(node.operand1)
self.put(u" %s " % node.operator) self.put(u" %s " % node.operator)
self.process_node(node.operand2) self.visit(node.operand2)
def process_CVarDefNode(self, node): def visit_CVarDefNode(self, node):
self.startline(u"cdef ") self.startline(u"cdef ")
self.process_node(node.base_type) self.visit(node.base_type)
self.put(u" ") self.put(u" ")
self.comma_seperated_list(node.declarators, output_rhs=True) self.comma_seperated_list(node.declarators, output_rhs=True)
self.endline() self.endline()
def process_ForInStatNode(self, node): def visit_ForInStatNode(self, node):
self.startline(u"for ") self.startline(u"for ")
self.process_node(node.target) self.visit(node.target)
self.put(u" in ") self.put(u" in ")
self.process_node(node.iterator.sequence) self.visit(node.iterator.sequence)
self.endline(u":") self.endline(u":")
self.indent() self.indent()
self.process_node(node.body) self.visit(node.body)
self.dedent() self.dedent()
if node.else_clause is not None: if node.else_clause is not None:
self.line(u"else:") self.line(u"else:")
self.indent() self.indent()
self.process_node(node.else_clause) self.visit(node.else_clause)
self.dedent() self.dedent()
def process_SequenceNode(self, node): def visit_SequenceNode(self, node):
self.comma_seperated_list(node.args) # Might need to discover whether we need () around tuples...hmm... self.comma_seperated_list(node.args) # Might need to discover whether we need () around tuples...hmm...
def process_SimpleCallNode(self, node): def visit_SimpleCallNode(self, node):
self.put(node.function.name + u"(") self.put(node.function.name + u"(")
self.comma_seperated_list(node.args) self.comma_seperated_list(node.args)
self.put(")") self.put(")")
def process_ExprStatNode(self, node): def visit_ExprStatNode(self, node):
self.startline() self.startline()
self.process_node(node.expr) self.visit(node.expr)
self.endline() self.endline()
def process_InPlaceAssignmentNode(self, node): def visit_InPlaceAssignmentNode(self, node):
self.startline() self.startline()
self.process_node(node.lhs) self.visit(node.lhs)
self.put(" %s= " % node.operator) self.put(" %s= " % node.operator)
self.process_node(node.rhs) self.visit(node.rhs)
self.endline() self.endline()
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
import sys import sys
import Options import Options
import Transform
usage = """\ usage = """\
Cython (http://cython.org) is a compiler for code written in the Cython (http://cython.org) is a compiler for code written in the
...@@ -56,6 +55,7 @@ def bad_usage(): ...@@ -56,6 +55,7 @@ def bad_usage():
def parse_command_line(args): def parse_command_line(args):
def parse_add_transform(transforms, param): def parse_add_transform(transforms, param):
from Main import PHASES
def import_symbol(fqn): def import_symbol(fqn):
modsplitpt = fqn.rfind(".") modsplitpt = fqn.rfind(".")
if modsplitpt == -1: bad_usage() if modsplitpt == -1: bad_usage()
...@@ -65,7 +65,7 @@ def parse_command_line(args): ...@@ -65,7 +65,7 @@ def parse_command_line(args):
return getattr(module, symbolname) return getattr(module, symbolname)
stagename, factoryname = param.split(":") stagename, factoryname = param.split(":")
if not stagename in Transform.PHASES: if not stagename in PHASES:
bad_usage() bad_usage()
factory = import_symbol(factoryname) factory = import_symbol(factoryname)
transform = factory() transform = factory()
......
...@@ -168,6 +168,9 @@ class ExprNode(Node): ...@@ -168,6 +168,9 @@ class ExprNode(Node):
saved_subexpr_nodes = None saved_subexpr_nodes = None
is_temp = 0 is_temp = 0
def get_child_attrs(self): return self.subexprs
child_attrs = property(fget=get_child_attrs)
def get_child_attrs(self): def get_child_attrs(self):
"""Automatically provide the contents of subexprs as children, unless child_attr """Automatically provide the contents of subexprs as children, unless child_attr
has been declared. See Nodes.Node.get_child_accessors.""" has been declared. See Nodes.Node.get_child_accessors."""
......
...@@ -17,7 +17,22 @@ from Symtab import BuiltinScope, ModuleScope ...@@ -17,7 +17,22 @@ from Symtab import BuiltinScope, ModuleScope
import Code import Code
from Cython.Utils import replace_suffix from Cython.Utils import replace_suffix
from Cython import Utils from Cython import Utils
import Transform
# Note: PHASES and TransformSet should be removed soon; but that's for
# another day and another commit.
PHASES = [
'before_analyse_function', # run in FuncDefNode.generate_function_definitions
'after_analyse_function' # run in FuncDefNode.generate_function_definitions
]
class TransformSet(dict):
def __init__(self):
for name in PHASES:
self[name] = []
def run(self, name, node, **options):
assert name in self, "Transform phase %s not defined" % name
for transform in self[name]:
transform(node, phase=name, **options)
verbose = 0 verbose = 0
...@@ -364,7 +379,7 @@ default_options = dict( ...@@ -364,7 +379,7 @@ default_options = dict(
output_file = None, output_file = None,
annotate = False, annotate = False,
generate_pxi = 0, generate_pxi = 0,
transforms = Transform.TransformSet(), transforms = TransformSet(),
working_path = "") working_path = "")
if sys.platform == "mac": if sys.platform == "mac":
......
...@@ -100,7 +100,9 @@ class Node(object): ...@@ -100,7 +100,9 @@ class Node(object):
is_name = 0 is_name = 0
is_literal = 0 is_literal = 0
# All descandants should set child_attrs (see get_child_accessors) # All descandants should set child_attrs to a list of the attributes
# containing nodes considered "children" in the tree. Each such attribute
# can either contain a single node or a list of nodes. See Visitor.py.
child_attrs = None child_attrs = None
def __init__(self, pos, **kw): def __init__(self, pos, **kw):
...@@ -156,12 +158,12 @@ class Node(object): ...@@ -156,12 +158,12 @@ class Node(object):
copied. Lists containing child nodes are thus seen as a way for the node copied. Lists containing child nodes are thus seen as a way for the node
to hold multiple children directly; the list is not treated as a seperate to hold multiple children directly; the list is not treated as a seperate
level in the tree.""" level in the tree."""
c = copy.copy(self) result = copy.copy(self)
for acc in c.get_child_accessors(): for attrname in result.child_attrs:
value = acc.get() value = getattr(result, attrname)
if isinstance(value, list): if isinstance(value, list):
acc.set([x for x in value]) setattr(result, attrname, value)
return c return result
# #
......
#
# Tree transform framework
#
import Nodes
import ExprNodes
import inspect
class Transform(object):
# parent The parent node of the currently processed node.
# access_path [(Node, str, int|None)]
# A stack providing information about where in the tree
# we are located.
# The first tuple item is the a node in the tree (parent nodes).
# The second tuple item is the attribute name followed, while
# the third is the index if the attribute is a list, or
# None otherwise.
#
# Additionally, any keyword arguments to __call__ will be set as fields while in
# a transformation.
# Transforms for the parse tree should usually extend this class for convenience.
# The caller of a transform will only first call initialize and then process_node on
# the root node, the rest are utility functions and conventions.
# Transformations usually happens by recursively filtering through the stream.
# process_node is always expected to return a new node, however it is ok to simply
# return the input node untouched. Returning None will remove the node from the
# parent.
def process_children(self, node, attrnames=None):
"""For all children of node, either process_list (if isinstance(node, list))
or process_node (otherwise) is called."""
if node == None: return
oldparent = self.parent
self.parent = node
for childacc in node.get_child_accessors():
attrname = childacc.name()
if attrnames is not None and attrname not in attrnames:
continue
child = childacc.get()
if isinstance(child, list):
newchild = self.process_list(child, attrname)
if not isinstance(newchild, list): raise Exception("Cannot replace list with non-list!")
else:
self.access_path.append((node, attrname, None))
newchild = self.process_node(child)
if newchild is not None and not isinstance(newchild, Nodes.Node):
raise Exception("Cannot replace Node with non-Node!")
self.access_path.pop()
childacc.set(newchild)
self.parent = oldparent
def process_list(self, l, attrname):
"""Calls process_node on all the items in l. Each item in l is transformed
in-place by the item process_node returns, then l is returned. If process_node
returns None, the item is removed from the list."""
for idx in xrange(len(l)):
self.access_path.append((self.parent, attrname, idx))
l[idx] = self.process_node(l[idx])
self.access_path.pop()
return [x for x in l if x is not None]
def process_node(self, node):
"""Override this method to process nodes. name specifies which kind of relation the
parent has with child. This method should always return the node which the parent
should use for this relation, which can either be the same node, None to remove
the node, or a different node."""
raise NotImplementedError("Not implemented")
def __call__(self, root, **params):
self.parent = None
self.access_path = []
for key, value in params.iteritems():
setattr(self, key, value)
root = self.process_node(root)
for key, value in params.iteritems():
delattr(self, key)
del self.parent
del self.access_path
return root
class VisitorTransform(Transform):
# Note: If needed, this can be replaced with a more efficient metaclass
# approach, resolving the jump table at module load time.
def __init__(self, **kw):
"""readonly - If this is set to True, the results of process_node
will be discarded (so that one can return None without changing
the tree)."""
super(VisitorTransform, self).__init__(**kw)
self.visitmethods = {'process_' : {}, 'pre_' : {}, 'post_' : {}}
def get_visitfunc(self, prefix, cls):
mname = prefix + cls.__name__
m = self.visitmethods[prefix].get(mname)
if m is None:
# Must resolve, try entire hierarchy
for cls in inspect.getmro(cls):
m = getattr(self, prefix + cls.__name__, None)
if m is not None:
break
if m is None: raise RuntimeError("Not a Node descendant: " + cls.__name__)
self.visitmethods[prefix][mname] = m
return m
def process_node(self, node):
# Pass on to calls registered in self.visitmethods
if node is None:
return None
result = self.get_visitfunc("process_", node.__class__)(node)
return result
def process_Node(self, node):
descend = self.get_visitfunc("pre_", node.__class__)(node)
if descend:
self.process_children(node)
self.get_visitfunc("post_", node.__class__)(node)
return node
def pre_Node(self, node):
return True
def post_Node(self, node):
pass
class ReadonlyVisitor(VisitorTransform):
"""
Like VisitorTransform, however process_X methods do not have to return
the result node -- the result of process_X is always discarded and the
structure of the original tree is not changed.
"""
def process_node(self, node):
super(ReadonlyVisitor, self).process_node(node) # discard result
return node
# Utils
def ensure_statlist(node):
if not isinstance(node, Nodes.StatListNode):
node = Nodes.StatListNode(pos=node.pos, stats=[node])
return node
def replace_node(ptr, value):
"""Replaces a node. ptr is of the form used on the access path stack
(parent, attrname, listidx|None)
"""
parent, attrname, listidx = ptr
if listidx is None:
setattr(parent, attrname, value)
else:
getattr(parent, attrname)[listidx] = value
class PrintTree(Transform):
"""Prints a representation of the tree to standard output.
Subclass and override repr_of to provide more information
about nodes. """
def __init__(self):
Transform.__init__(self)
self._indent = ""
def indent(self):
self._indent += " "
def unindent(self):
self._indent = self._indent[:-2]
def __call__(self, tree, phase=None, **params):
print("Parse tree dump at phase '%s'" % phase)
super(PrintTree, self).__call__(tree, phase=phase, **params)
# Don't do anything about process_list, the defaults gives
# nice-looking name[idx] nodes which will visually appear
# under the parent-node, not displaying the list itself in
# the hierarchy.
def process_node(self, node):
if len(self.access_path) == 0:
name = "(root)"
else:
parent, attr, idx = self.access_path[-1]
if idx is not None:
name = "%s[%d]" % (attr, idx)
else:
name = attr
print("%s- %s: %s" % (self._indent, name, self.repr_of(node)))
self.indent()
self.process_children(node)
self.unindent()
return node
def repr_of(self, node):
if node is None:
return "(none)"
else:
result = node.__class__.__name__
if isinstance(node, ExprNodes.NameNode):
result += "(type=%s, name=\"%s\")" % (repr(node.type), node.name)
elif isinstance(node, Nodes.DefNode):
result += "(name=\"%s\")" % node.name
elif isinstance(node, ExprNodes.ExprNode):
t = node.type
result += "(type=%s)" % repr(t)
return result
PHASES = [
'before_analyse_function', # run in FuncDefNode.generate_function_definitions
'after_analyse_function' # run in FuncDefNode.generate_function_definitions
]
class TransformSet(dict):
def __init__(self):
for name in PHASES:
self[name] = []
def run(self, name, node, **options):
assert name in self, "Transform phase %s not defined" % name
for transform in self[name]:
transform(node, phase=name, **options)
...@@ -6,7 +6,7 @@ import re ...@@ -6,7 +6,7 @@ import re
from cStringIO import StringIO from cStringIO import StringIO
from Scanning import PyrexScanner, StringSourceDescriptor from Scanning import PyrexScanner, StringSourceDescriptor
from Symtab import BuiltinScope, ModuleScope from Symtab import BuiltinScope, ModuleScope
from Transform import Transform, VisitorTransform from Visitor import VisitorTransform
from Nodes import Node from Nodes import Node
from ExprNodes import NameNode from ExprNodes import NameNode
import Parsing import Parsing
...@@ -57,31 +57,35 @@ def parse_from_strings(name, code, pxds={}): ...@@ -57,31 +57,35 @@ def parse_from_strings(name, code, pxds={}):
tree = Parsing.p_module(scanner, 0, module_name) tree = Parsing.p_module(scanner, 0, module_name)
return tree return tree
class TreeCopier(Transform): class TreeCopier(VisitorTransform):
def process_node(self, node): def visit_Node(self, node):
if node is None: if node is None:
return node return node
else: else:
c = node.clone_node() c = node.clone_node()
self.process_children(c) self.visitchildren(c)
return c return c
class SubstitutionTransform(VisitorTransform): class SubstitutionTransform(VisitorTransform):
def process_Node(self, node): def visit_Node(self, node):
if node is None: if node is None:
return node return node
else: else:
c = node.clone_node() c = node.clone_node()
self.process_children(c) self.visitchildren(c)
return c return c
def process_NameNode(self, node): def visit_NameNode(self, node):
if node.name in self.substitute: if node.name in self.substitute:
# Name matched, substitute node # Name matched, substitute node
return self.substitute[node.name] return self.substitute[node.name]
else: else:
# Clone # Clone
return self.process_Node(node) return self.visit_Node(node)
def __call__(self, node, substitute):
self.substitute = substitute
return super(SubstitutionTransform, self).__call__(node)
def copy_code_tree(node): def copy_code_tree(node):
return TreeCopier()(node) return TreeCopier()(node)
......
#
# Tree visitor and transform framework
#
import Nodes
import ExprNodes
import inspect
class BasicVisitor(object):
"""A generic visitor base class which can be used for visiting any kind of object."""
# Note: If needed, this can be replaced with a more efficient metaclass
# approach, resolving the jump table at module load time rather than per visitor
# instance.
def __init__(self):
self.dispatch_table = {}
def visit(self, obj):
pattern = "visit_%s"
cls = obj.__class__
mname = pattern % cls.__name__
m = self.dispatch_table.get(mname)
if m is None:
# Must resolve, try entire hierarchy
mro = inspect.getmro(cls)
for cls in mro:
m = getattr(self, pattern % cls.__name__, None)
if m is not None:
break
else:
raise RuntimeError("Visitor does not accept object: %s" % obj)
self.dispatch_table[mname] = m
return m(obj)
class TreeVisitor(BasicVisitor):
"""
Base class for writing visitors for a Cython tree, contains utilities for
recursing such trees using visitors. Each node is
expected to have a child_attrs iterable containing the names of attributes
containing child nodes or lists of child nodes. Lists are not considered
part of the tree structure (i.e. contained nodes are considered direct
children of the parent node).
visit_children visits each of the children of a given node (see the visit_children
documentation). When recursing the tree using visit_children, an attribute
access_path is maintained which gives information about the current location
in the tree as a stack of tuples: (parent_node, attrname, index), representing
the node, attribute and optional list index that was taken in each step in the path to
the current node.
Example:
>>> class SampleNode:
... child_attrs = ["head", "body"]
... def __init__(self, value, head=None, body=None):
... self.value = value
... self.head = head
... self.body = body
... def __repr__(self): return "SampleNode(%s)" % self.value
...
>>> tree = SampleNode(0, SampleNode(1), [SampleNode(2), SampleNode(3)])
>>> class MyVisitor(TreeVisitor):
... def visit_SampleNode(self, node):
... print "in", node.value, self.access_path
... self.visitchildren(node)
... print "out", node.value
...
>>> MyVisitor().visit(tree)
in 0 []
in 1 [(SampleNode(0), 'head', None)]
out 1
in 2 [(SampleNode(0), 'body', 0)]
out 2
in 3 [(SampleNode(0), 'body', 1)]
out 3
out 0
"""
def __init__(self):
super(TreeVisitor, self).__init__()
self.access_path = []
def visitchild(self, child, parent, attrname, idx):
self.access_path.append((parent, attrname, idx))
result = self.visit(child)
self.access_path.pop()
return result
def visitchildren(self, parent, attrs=None):
"""
Visits the children of the given parent. If parent is None, returns
immediately (returning None).
The return value is a dictionary giving the results for each
child (mapping the attribute name to either the return value
or a list of return values (in the case of multiple children
in an attribute)).
"""
if parent is None: return None
result = {}
for attr in parent.child_attrs:
if attrs is not None and attr not in attrs: continue
child = getattr(parent, attr)
if child is not None:
if isinstance(child, list):
childretval = [self.visitchild(x, parent, attr, idx) for idx, x in enumerate(child)]
else:
childretval = self.visitchild(child, parent, attr, None)
result[attr] = childretval
return result
class VisitorTransform(TreeVisitor):
"""
A tree transform is a base class for visitors that wants to do stream
processing of the structure (rather than attributes etc.) of a tree.
It implements __call__ to simply visit the argument node.
It requires the visitor methods to return the nodes which should take
the place of the visited node in the result tree (which can be the same
or one or more replacement). Specifically, if the return value from
a visitor method is:
- [] or None; the visited node will be removed (set to None if an attribute and
removed if in a list)
- A single node; the visited node will be replaced by the returned node.
- A list of nodes; the visited nodes will be replaced by all the nodes in the
list. This will only work if the node was already a member of a list; if it
was not, an exception will be raised. (Typically you want to ensure that you
are within a StatListNode or similar before doing this.)
"""
def visitchildren(self, parent, attrs=None):
result = super(VisitorTransform, self).visitchildren(parent, attrs)
for attr, newnode in result.iteritems():
if not isinstance(newnode, list):
setattr(parent, attr, newnode)
else:
# Flatten the list one level and remove any None
newlist = []
for x in newnode:
if x is not None:
if isinstance(x, list):
newlist += x
else:
newlist.append(x)
setattr(parent, attr, newlist)
return result
def __call__(self, root):
return self.visit(root)
# Utils
def ensure_statlist(node):
if not isinstance(node, Nodes.StatListNode):
node = Nodes.StatListNode(pos=node.pos, stats=[node])
return node
def replace_node(ptr, value):
"""Replaces a node. ptr is of the form used on the access path stack
(parent, attrname, listidx|None)
"""
parent, attrname, listidx = ptr
if listidx is None:
setattr(parent, attrname, value)
else:
getattr(parent, attrname)[listidx] = value
class PrintTree(TreeVisitor):
"""Prints a representation of the tree to standard output.
Subclass and override repr_of to provide more information
about nodes. """
def __init__(self):
Transform.__init__(self)
self._indent = ""
def indent(self):
self._indent += " "
def unindent(self):
self._indent = self._indent[:-2]
def __call__(self, tree, phase=None):
print("Parse tree dump at phase '%s'" % phase)
# Don't do anything about process_list, the defaults gives
# nice-looking name[idx] nodes which will visually appear
# under the parent-node, not displaying the list itself in
# the hierarchy.
def visit_Node(self, node):
if len(self.access_path) == 0:
name = "(root)"
else:
parent, attr, idx = self.access_path[-1]
if idx is not None:
name = "%s[%d]" % (attr, idx)
else:
name = attr
print("%s- %s: %s" % (self._indent, name, self.repr_of(node)))
self.indent()
self.visitchildren(node)
self.unindent()
return node
def repr_of(self, node):
if node is None:
return "(none)"
else:
result = node.__class__.__name__
if isinstance(node, ExprNodes.NameNode):
result += "(type=%s, name=\"%s\")" % (repr(node.type), node.name)
elif isinstance(node, Nodes.DefNode):
result += "(name=\"%s\")" % node.name
elif isinstance(node, ExprNodes.ExprNode):
t = node.type
result += "(type=%s)" % repr(t)
return result
if __name__ == "__main__":
import doctest
doctest.testmod()
...@@ -8,7 +8,7 @@ from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent ...@@ -8,7 +8,7 @@ from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
class CythonTest(unittest.TestCase): class CythonTest(unittest.TestCase):
def assertCode(self, expected, result_tree): def assertCode(self, expected, result_tree):
writer = CodeWriter() writer = CodeWriter()
writer(result_tree) writer.write(result_tree)
result_lines = writer.result.lines result_lines = writer.result.lines
expected_lines = strip_common_indent(expected.split("\n")) expected_lines = strip_common_indent(expected.split("\n"))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment