Commit 2673725c authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

New features: CodeWriter, TreeFragment, and a transform unit test framework.

See the documentation of each class for details.

It is a rather big commit, however seperating it is non-trivial. The tests
for all of these features all rely on using each other, so there's a
circular dependency in the tests and I wanted to commit the tests and
features at the same time. (However, the non-test-code does not have a circular
dependency.)
parent 6f5104eb
from Cython.Compiler.Transform import ReadonlyVisitor
from Cython.Compiler.Nodes import *
"""
Serializes a Cython code tree to Cython code. This is primarily useful for
debugging and testing purposes.
The output is in a strict format, no whitespace or comments from the input
is preserved (and it could not be as it is not present in the code tree).
"""
class LinesResult(object):
def __init__(self):
self.lines = []
self.s = u""
def put(self, s):
self.s += s
def newline(self):
self.lines.append(self.s)
self.s = u""
def putline(self, s):
self.put(s)
self.newline()
class CodeWriter(ReadonlyVisitor):
indent_string = u" "
def __init__(self, result = None):
super(CodeWriter, self).__init__()
if result is None:
result = LinesResult()
self.result = result
self.numindents = 0
def indent(self):
self.numindents += 1
def dedent(self):
self.numindents -= 1
def startline(self, s = u""):
self.result.put(self.indent_string * self.numindents + s)
def put(self, s):
self.result.put(s)
def endline(self, s = u""):
self.result.putline(s)
def line(self, s):
self.startline(s)
self.endline()
def comma_seperated_list(self, items, output_rhs=False):
if len(items) > 0:
for item in items[:-1]:
self.process_node(item)
if output_rhs and item.rhs is not None:
self.put(u" = ")
self.process_node(item.rhs)
self.put(u", ")
self.process_node(items[-1])
def process_Node(self, node):
raise AssertionError("Node not handled by serializer: %r" % node)
def process_ModuleNode(self, node):
self.process_children(node)
def process_StatListNode(self, node):
self.process_children(node)
def process_FuncDefNode(self, node):
self.startline(u"def %s(" % node.name)
self.comma_seperated_list(node.args)
self.endline(u"):")
self.indent()
self.process_node(node.body)
self.dedent()
def process_CArgDeclNode(self, node):
if node.base_type.name is not None:
self.process_node(node.base_type)
self.put(u" ")
self.process_node(node.declarator)
if node.default is not None:
self.put(u" = ")
self.process_node(node.default)
def process_CNameDeclaratorNode(self, node):
self.put(node.name)
def process_CSimpleBaseTypeNode(self, node):
# See Parsing.p_sign_and_longness
if node.is_basic_c_type:
self.put(("unsigned ", "", "signed ")[node.signed])
if node.longness < 0:
self.put("short " * -node.longness)
elif node.longness > 0:
self.put("long " * node.longness)
self.put(node.name)
def process_SingleAssignmentNode(self, node):
self.startline()
self.process_node(node.lhs)
self.put(u" = ")
self.process_node(node.rhs)
self.endline()
def process_NameNode(self, node):
self.put(node.name)
def process_IntNode(self, node):
self.put(node.value)
def process_IfStatNode(self, node):
# The IfClauseNode is handled directly without a seperate match
# for clariy.
self.startline(u"if ")
self.process_node(node.if_clauses[0].condition)
self.endline(":")
self.indent()
self.process_node(node.if_clauses[0].body)
self.dedent()
for clause in node.if_clauses[1:]:
self.startline("elif ")
self.process_node(clause.condition)
self.endline(":")
self.indent()
self.process_node(clause.body)
self.dedent()
if node.else_clause is not None:
self.line("else:")
self.indent()
self.process_node(node.else_clause)
self.dedent()
def process_PassStatNode(self, node):
self.startline(u"pass")
self.endline()
def process_PrintStatNode(self, node):
self.startline(u"print ")
self.comma_seperated_list(node.args)
if node.ends_with_comma:
self.put(u",")
self.endline()
def process_BinopNode(self, node):
self.process_node(node.operand1)
self.put(u" %s " % node.operator)
self.process_node(node.operand2)
def process_CVarDefNode(self, node):
self.startline(u"cdef ")
self.process_node(node.base_type)
self.put(u" ")
self.comma_seperated_list(node.declarators, output_rhs=True)
self.endline()
def process_ForInStatNode(self, node):
self.startline(u"for ")
self.process_node(node.target)
self.put(u" in ")
self.process_node(node.iterator.sequence)
self.endline(u":")
self.indent()
self.process_node(node.body)
self.dedent()
if node.else_clause is not None:
self.line(u"else:")
self.indent()
self.process_node(node.else_clause)
self.dedent()
def process_SequenceNode(self, node):
self.comma_seperated_list(node.args) # Might need to discover whether we need () around tuples...hmm...
def process_SimpleCallNode(self, node):
self.put(node.function.name + u"(")
self.comma_seperated_list(node.args)
self.put(")")
def process_ExprStatNode(self, node):
self.startline()
self.process_node(node.expr)
self.endline()
def process_InPlaceAssignmentNode(self, node):
self.startline()
self.process_node(node.lhs)
self.put(" %s= " % node.operator)
self.process_node(node.rhs)
self.endline()
from Cython.TestUtils import CythonTest
from Cython.Compiler.TreeFragment import *
class TestTreeFragments(CythonTest):
def test_basic(self):
F = self.fragment(u"x = 4")
T = F.copy()
self.assertCode(u"x = 4", T)
def test_copy_is_independent(self):
F = self.fragment(u"if True: x = 4")
T1 = F.root
T2 = F.copy()
self.assertEqual("x", T2.body.if_clauses[0].body.lhs.name)
T2.body.if_clauses[0].body.lhs.name = "other"
self.assertEqual("x", T1.body.if_clauses[0].body.lhs.name)
def test_substitution(self):
F = self.fragment(u"x = 4")
y = NameNode(pos=None, name=u"y")
T = F.substitute({"x" : y})
self.assertCode(u"y = 4", T)
if __name__ == "__main__":
import unittest
unittest.main()
......@@ -109,7 +109,7 @@ class VisitorTransform(Transform):
if node is None:
return None
result = self.get_visitfunc("process_", node.__class__)(node)
return node
return result
def process_Node(self, node):
descend = self.get_visitfunc("pre_", node.__class__)(node)
......
#
# TreeFragments - parsing of strings to trees
#
import re
from cStringIO import StringIO
from Scanning import PyrexScanner, StringSourceDescriptor
from Symtab import BuiltinScope, ModuleScope
from Transform import Transform, VisitorTransform
from Nodes import Node
from ExprNodes import NameNode
import Parsing
import Main
"""
Support for parsing strings into code trees.
"""
class StringParseContext(Main.Context):
def __init__(self, include_directories, name):
Main.Context.__init__(self, include_directories)
self.module_name = name
def find_module(self, module_name, relative_to = None, pos = None, need_pxd = 1):
if module_name != self.module_name:
raise AssertionError("Not yet supporting any cimports/includes from string code snippets")
return ModuleScope(module_name, parent_module = None, context = self)
def parse_from_strings(name, code, pxds={}):
"""
Utility method to parse a (unicode) string of code. This is mostly
used for internal Cython compiler purposes (creating code snippets
that transforms should emit, as well as unit testing).
code - a unicode string containing Cython (module-level) code
name - a descriptive name for the code source (to use in error messages etc.)
"""
# Since source files carry an encoding, it makes sense in this context
# to use a unicode string so that code fragments don't have to bother
# with encoding. This means that test code passed in should not have an
# encoding header.
assert isinstance(code, unicode), "unicode code snippets only please"
encoding = "UTF-8"
module_name = name
initial_pos = (name, 1, 0)
code_source = StringSourceDescriptor(name, code)
context = StringParseContext([], name)
scope = context.find_module(module_name, pos = initial_pos, need_pxd = 0)
buf = StringIO(code.encode(encoding))
scanner = PyrexScanner(buf, code_source, source_encoding = encoding,
type_names = scope.type_names, context = context)
tree = Parsing.p_module(scanner, 0, module_name)
return tree
class TreeCopier(Transform):
def process_node(self, node):
if node is None:
return node
else:
c = node.clone_node()
self.process_children(c)
return c
class SubstitutionTransform(VisitorTransform):
def process_Node(self, node):
if node is None:
return node
else:
c = node.clone_node()
self.process_children(c)
return c
def process_NameNode(self, node):
if node.name in self.substitute:
# Name matched, substitute node
return self.substitute[node.name]
else:
# Clone
return self.process_Node(node)
def copy_code_tree(node):
return TreeCopier()(node)
INDENT_RE = re.compile(ur"^ *")
def strip_common_indent(lines):
"Strips empty lines and common indentation from the list of strings given in lines"
lines = [x for x in lines if x.strip() != u""]
minindent = min(len(INDENT_RE.match(x).group(0)) for x in lines)
lines = [x[minindent:] for x in lines]
return lines
class TreeFragment(object):
def __init__(self, code, name, pxds={}):
if isinstance(code, unicode):
def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n")))
fmt_code = fmt(code)
fmt_pxds = {}
for key, value in pxds.iteritems():
fmt_pxds[key] = fmt(value)
self.root = parse_from_strings(name, fmt_code, fmt_pxds)
elif isinstance(code, Node):
if pxds != {}: raise NotImplementedError()
self.root = code
else:
raise ValueError("Unrecognized code format (accepts unicode and Node)")
def copy(self):
return copy_code_tree(self.root)
def substitute(self, nodes={}):
return SubstitutionTransform()(self.root, substitute = nodes)
import Cython.Compiler.Errors as Errors
from Cython.CodeWriter import CodeWriter
import unittest
from Cython.Compiler.ModuleNode import ModuleNode
import Cython.Compiler.Main as Main
from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
class CythonTest(unittest.TestCase):
def assertCode(self, expected, result_tree):
writer = CodeWriter()
writer(result_tree)
result_lines = writer.result.lines
expected_lines = strip_common_indent(expected.split("\n"))
for idx, (line, expected_line) in enumerate(zip(result_lines, expected_lines)):
self.assertEqual(expected_line, line, "Line %d:\nGot: %s\nExp: %s" % (idx, line, expected_line))
self.assertEqual(len(result_lines), len(expected_lines),
"Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))
def fragment(self, code, pxds={}):
"Simply create a tree fragment using the name of the test-case in parse errors."
name = self.id()
if name.startswith("__main__."): name = name[len("__main__."):]
name = name.replace(".", "_")
return TreeFragment(code, name, pxds)
class TransformTest(CythonTest):
"""
Utility base class for transform unit tests. It is based around constructing
test trees (either explicitly or by parsing a Cython code string); running
the transform, serialize it using a customized Cython serializer (with
special markup for nodes that cannot be represented in Cython),
and do a string-comparison line-by-line of the result.
To create a test case:
- Call run_pipeline. The pipeline should at least contain the transform you
are testing; pyx should be either a string (passed to the parser to
create a post-parse tree) or a ModuleNode representing input to pipeline.
The result will be a transformed result (usually a ModuleNode).
- Check that the tree is correct. If wanted, assertCode can be used, which
takes a code string as expected, and a ModuleNode in result_tree
(it serializes the ModuleNode to a string and compares line-by-line).
All code strings are first stripped for whitespace lines and then common
indentation.
Plans: One could have a pxd dictionary parameter to run_pipeline.
"""
def run_pipeline(self, pipeline, pyx, pxds={}):
tree = self.fragment(pyx, pxds).root
assert isinstance(tree, ModuleNode)
# Run pipeline
for T in pipeline:
tree = T(tree)
return tree
from Cython.TestUtils import CythonTest
class TestCodeWriter(CythonTest):
# CythonTest uses the CodeWriter heavily, so do some checking by
# roundtripping Cython code through the test framework.
# Note that this test is dependant upon the normal Cython parser
# to generate the input trees to the CodeWriter. This save *a lot*
# of time; better to spend that time writing other tests than perfecting
# this one...
# Whitespace is very significant in this process:
# - always newline on new block (!)
# - indent 4 spaces
# - 1 space around every operator
def t(self, codestr):
self.assertCode(codestr, self.fragment(codestr).root)
def test_print(self):
self.t(u"""
print x, y
print x + y ** 2
print x, y, z,
""")
def test_if(self):
self.t(u"if x:\n pass")
def test_ifelifelse(self):
self.t(u"""
if x:
pass
elif y:
pass
elif z + 34 ** 34 - 2:
pass
else:
pass
""")
def test_def(self):
self.t(u"""
def f(x, y, z):
pass
def f(x = 34, y = 54, z):
pass
""")
def test_longness_and_signedness(self):
self.t(u"def f(unsigned long long long long long int y):\n pass")
def test_signed_short(self):
self.t(u"def f(signed short int y):\n pass")
def test_typed_args(self):
self.t(u"def f(int x, unsigned long int y):\n pass")
def test_cdef_var(self):
self.t(u"""
cdef int hello
cdef int hello = 4, x = 3, y, z
""")
def test_for_loop(self):
self.t(u"""
for x, y, z in f(g(h(34) * 2) + 23):
print x, y, z
else:
print 43
""")
def test_inplace_assignment(self):
self.t(u"x += 43")
if __name__ == "__main__":
import unittest
unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment