Commit 8a365013 authored by Stefan Behnel's avatar Stefan Behnel

clean up some test code and apply some safety fixes

parent df86b715
......@@ -8,13 +8,15 @@ import unittest
import os, sys
import tempfile
class NodeTypeWriter(TreeVisitor):
def __init__(self):
super(NodeTypeWriter, self).__init__()
self._indents = 0
self.result = []
def visit_Node(self, node):
if len(self.access_path) == 0:
if not self.access_path:
name = u"(root)"
else:
tip = self.access_path[-1]
......@@ -29,6 +31,7 @@ class NodeTypeWriter(TreeVisitor):
self.visitchildren(node)
self._indents -= 1
def treetypes(root):
"""Returns a string representing the tree by class names.
There's a leading and trailing whitespace so that it can be
......@@ -38,6 +41,7 @@ def treetypes(root):
w.visit(root)
return u"\n".join([u""] + w.result + [u""])
class CythonTest(unittest.TestCase):
def setUp(self):
......@@ -110,6 +114,7 @@ class CythonTest(unittest.TestCase):
except:
self.fail(str(sys.exc_info()[1]))
class TransformTest(CythonTest):
"""
Utility base class for transform unit tests. It is based around constructing
......@@ -134,7 +139,6 @@ class TransformTest(CythonTest):
Plans: One could have a pxd dictionary parameter to run_pipeline.
"""
def run_pipeline(self, pipeline, pyx, pxds={}):
tree = self.fragment(pyx, pxds).root
# Run pipeline
......@@ -166,6 +170,7 @@ class TreeAssertVisitor(VisitorTransform):
visit_Node = VisitorTransform.recurse_to_children
def unpack_source_tree(tree_file, dir=None):
if dir is None:
dir = tempfile.mkdtemp()
......@@ -176,21 +181,24 @@ def unpack_source_tree(tree_file, dir=None):
lines = f.readlines()
finally:
f.close()
f = None
for line in lines:
if line[:5] == '#####':
filename = line.strip().strip('#').strip().replace('/', os.path.sep)
path = os.path.join(dir, filename)
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
if cur_file is not None:
cur_file.close()
cur_file = open(path, 'w')
elif cur_file is not None:
cur_file.write(line)
elif line.strip() and not line.lstrip().startswith('#'):
if line.strip() not in ('"""', "'''"):
header.append(line)
if cur_file is not None:
cur_file.close()
del f
try:
for line in lines:
if line[:5] == '#####':
filename = line.strip().strip('#').strip().replace('/', os.path.sep)
path = os.path.join(dir, filename)
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
if cur_file is not None:
f, cur_file = cur_file, None
f.close()
cur_file = open(path, 'w')
elif cur_file is not None:
cur_file.write(line)
elif line.strip() and not line.lstrip().startswith('#'):
if line.strip() not in ('"""', "'''"):
header.append(line)
finally:
if cur_file is not None:
cur_file.close()
return dir, ''.join(header)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment