Commit 3c4c2af5 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Merge; disabled activation of unfinished closure code

parents 70abf0f1 1788299e
...@@ -32,6 +32,7 @@ class CompileError(PyrexError): ...@@ -32,6 +32,7 @@ class CompileError(PyrexError):
def __init__(self, position = None, message = ""): def __init__(self, position = None, message = ""):
self.position = position self.position = position
self.message_only = message
# Deprecated and withdrawn in 2.6: # Deprecated and withdrawn in 2.6:
# self.message = message # self.message = message
if position: if position:
...@@ -91,6 +92,7 @@ def error(position, message): ...@@ -91,6 +92,7 @@ def error(position, message):
#print "Errors.error:", repr(position), repr(message) ### #print "Errors.error:", repr(position), repr(message) ###
global num_errors global num_errors
err = CompileError(position, message) err = CompileError(position, message)
# if position is not None: raise Exception(err) # debug
line = "%s\n" % err line = "%s\n" % err
if listing_file: if listing_file:
listing_file.write(line) listing_file.write(line)
......
...@@ -334,20 +334,20 @@ def create_generate_code(context, options, result): ...@@ -334,20 +334,20 @@ def create_generate_code(context, options, result):
return generate_code return generate_code
def create_default_pipeline(context, options, result): def create_default_pipeline(context, options, result):
from ParseTreeTransforms import WithTransform, PostParse from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor
from ModuleNode import check_c_classes from ModuleNode import check_c_classes
return [ return [
create_parse(context), create_parse(context),
PostParse(), NormalizeTree(context),
WithTransform(), PostParse(context),
MarkClosureVisitor(), WithTransform(context),
AnalyseDeclarationsTransform(), AnalyseDeclarationsTransform(context),
check_c_classes, check_c_classes,
AnalyseExpressionsTransform(), AnalyseExpressionsTransform(context),
CreateClosureClasses(), # CreateClosureClasses(context),
create_generate_code(context, options, result) create_generate_code(context, options, result)
] ]
......
...@@ -172,7 +172,27 @@ class Node(object): ...@@ -172,7 +172,27 @@ class Node(object):
self._end_pos = pos self._end_pos = pos
return pos return pos
def dump(self, level=0, filter_out=("pos",)):
def dump_child(x, level):
if isinstance(x, Node):
return x.dump(level)
elif isinstance(x, list):
return "[%s]" % ", ".join(dump_child(item, level) for item in x)
else:
return repr(x)
attrs = [(key, value) for key, value in self.__dict__.iteritems() if key not in filter_out]
if len(attrs) == 0:
return "<%s>" % self.__class__.__name__
else:
indent = " " * level
res = "<%s\n" % (self.__class__.__name__)
for key, value in attrs:
res += "%s %s: %s\n" % (indent, key, dump_child(value, level + 1))
res += "%s>" % indent
return res
class BlockNode: class BlockNode:
# Mixin class for nodes representing a declaration block. # Mixin class for nodes representing a declaration block.
...@@ -545,7 +565,6 @@ class CBaseTypeNode(Node): ...@@ -545,7 +565,6 @@ class CBaseTypeNode(Node):
pass pass
class CSimpleBaseTypeNode(CBaseTypeNode): class CSimpleBaseTypeNode(CBaseTypeNode):
# name string # name string
# module_path [string] Qualifying name components # module_path [string] Qualifying name components
...@@ -587,6 +606,16 @@ class CSimpleBaseTypeNode(CBaseTypeNode): ...@@ -587,6 +606,16 @@ class CSimpleBaseTypeNode(CBaseTypeNode):
else: else:
return PyrexTypes.error_type return PyrexTypes.error_type
class CBufferAccessTypeNode(Node):
# base_type_node CBaseTypeNode
# positional_args [ExprNode] List of positional arguments
# keyword_args DictNode Keyword arguments
child_attrs = ["base_type_node", "positional_args", "keyword_args"]
def analyse(self, env):
return self.base_type_node.analyse(env)
class CComplexBaseTypeNode(CBaseTypeNode): class CComplexBaseTypeNode(CBaseTypeNode):
# base_type CBaseTypeNode # base_type CBaseTypeNode
......
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import * from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Utils import EncodedString
from Cython.Compiler.Errors import CompileError
from sets import Set as set
class PostParse(VisitorTransform): class NormalizeTree(CythonTransform):
""" """
This transform fixes up a few things after parsing This transform fixes up a few things after parsing
in order to make the parse tree more suitable for in order to make the parse tree more suitable for
...@@ -25,15 +28,11 @@ class PostParse(VisitorTransform): ...@@ -25,15 +28,11 @@ class PostParse(VisitorTransform):
StatListNode has no children to see if the block is empty). StatListNode has no children to see if the block is empty).
""" """
def __init__(self): def __init__(self, context):
super(PostParse, self).__init__() super(NormalizeTree, self).__init__(context)
self.is_in_statlist = False self.is_in_statlist = False
self.is_in_expr = False self.is_in_expr = False
def visit_Node(self, node):
self.visitchildren(node)
return node
def visit_ExprNode(self, node): def visit_ExprNode(self, node):
stacktmp = self.is_in_expr stacktmp = self.is_in_expr
self.is_in_expr = True self.is_in_expr = True
...@@ -73,7 +72,80 @@ class PostParse(VisitorTransform): ...@@ -73,7 +72,80 @@ class PostParse(VisitorTransform):
return self.visit_StatNode(node, True) return self.visit_StatNode(node, True)
class WithTransform(VisitorTransform): class PostParseError(CompileError): pass
# error strings checked by unit tests, so define them
ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
ERR_BUF_TOO_MANY = 'Too many buffer options'
ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_INT = '"%s" must be an integer'
ERR_BUF_NONNEG = '"%s" must be non-negative'
class PostParse(CythonTransform):
"""
Basic interpretation of the parse tree, as well as validity
checking that can be done on a very basic level on the parse
tree (while still not being a problem with the basic syntax,
as such).
Specifically:
- CBufferAccessTypeNode has its options interpreted:
Any first positional argument goes into the "dtype" attribute,
any "ndim" keyword argument goes into the "ndim" attribute and
so on. Also it is checked that the option combination is valid.
Note: Currently Parsing.py does a lot of interpretation and
reorganization that can be refactored into this transform
if a more pure Abstract Syntax Tree is wanted.
"""
buffer_options = ("dtype", "ndim") # ordered!
def visit_CBufferAccessTypeNode(self, node):
options = {}
# Fetch positional arguments
if len(node.positional_args) > len(self.buffer_options):
self.context.error(ERR_BUF_TOO_MANY)
for arg, unicode_name in zip(node.positional_args, self.buffer_options):
name = str(unicode_name)
options[name] = arg
# Fetch named arguments
for item in node.keyword_args.key_value_pairs:
name = str(item.key.value)
if not name in self.buffer_options:
raise PostParseError(item.key.pos,
ERR_BUF_UNKNOWN % name)
if name in options.keys():
raise PostParseError(item.key.pos,
ERR_BUF_DUP % key)
options[name] = item.value
provided = options.keys()
# get dtype
dtype = options.get("dtype")
if dtype is None: raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype')
node.dtype = dtype
# get ndim
if "ndim" in provided:
ndimnode = options["ndim"]
if not isinstance(ndimnode, IntNode):
# Compile-time values (DEF) are currently resolved by the parser,
# so nothing more to do here
raise PostParseError(ndimnode.pos, ERR_BUF_INT % 'ndim')
ndim_value = int(ndimnode.value)
if ndim_value < 0:
raise PostParseError(ndimnode.pos, ERR_BUF_NONNEG % 'ndim')
node.ndim = int(ndimnode.value)
# We're done with the parse tree args
node.positional_args = None
node.keyword_args = None
return node
class WithTransform(CythonTransform):
# EXCINFO is manually set to a variable that contains # EXCINFO is manually set to a variable that contains
# the exc_info() tuple that can be generated by the enclosing except # the exc_info() tuple that can be generated by the enclosing except
...@@ -94,7 +166,7 @@ class WithTransform(VisitorTransform): ...@@ -94,7 +166,7 @@ class WithTransform(VisitorTransform):
if EXC: if EXC:
EXIT(None, None, None) EXIT(None, None, None)
""", temps=[u'MGR', u'EXC', u"EXIT", u"SYS)"], """, temps=[u'MGR', u'EXC', u"EXIT", u"SYS)"],
pipeline=[PostParse()]) pipeline=[NormalizeTree(None)])
template_with_target = TreeFragment(u""" template_with_target = TreeFragment(u"""
MGR = EXPR MGR = EXPR
...@@ -113,11 +185,7 @@ class WithTransform(VisitorTransform): ...@@ -113,11 +185,7 @@ class WithTransform(VisitorTransform):
if EXC: if EXC:
EXIT(None, None, None) EXIT(None, None, None)
""", temps=[u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"], """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"],
pipeline=[PostParse()]) pipeline=[NormalizeTree(None)])
def visit_Node(self, node):
self.visitchildren(node)
return node
def visit_WithStatNode(self, node): def visit_WithStatNode(self, node):
excinfo_name = temp_name_handle('EXCINFO') excinfo_name = temp_name_handle('EXCINFO')
...@@ -143,7 +211,7 @@ class WithTransform(VisitorTransform): ...@@ -143,7 +211,7 @@ class WithTransform(VisitorTransform):
return result.stats return result.stats
class AnalyseDeclarationsTransform(VisitorTransform): class AnalyseDeclarationsTransform(CythonTransform):
def __call__(self, root): def __call__(self, root):
self.env_stack = [root.scope] self.env_stack = [root.scope]
...@@ -164,12 +232,7 @@ class AnalyseDeclarationsTransform(VisitorTransform): ...@@ -164,12 +232,7 @@ class AnalyseDeclarationsTransform(VisitorTransform):
self.env_stack.pop() self.env_stack.pop()
return node return node
def visit_Node(self, node): class AnalyseExpressionsTransform(CythonTransform):
self.visitchildren(node)
return node
class AnalyseExpressionsTransform(VisitorTransform):
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
node.body.analyse_expressions(node.scope) node.body.analyse_expressions(node.scope)
self.visitchildren(node) self.visitchildren(node)
...@@ -180,11 +243,7 @@ class AnalyseExpressionsTransform(VisitorTransform): ...@@ -180,11 +243,7 @@ class AnalyseExpressionsTransform(VisitorTransform):
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_Node(self, node): class MarkClosureVisitor(CythonTransform):
self.visitchildren(node)
return node
class MarkClosureVisitor(VisitorTransform):
needs_closure = False needs_closure = False
...@@ -203,12 +262,7 @@ class MarkClosureVisitor(VisitorTransform): ...@@ -203,12 +262,7 @@ class MarkClosureVisitor(VisitorTransform):
def visit_YieldNode(self, node): def visit_YieldNode(self, node):
self.needs_closure = True self.needs_closure = True
def visit_Node(self, node): class CreateClosureClasses(CythonTransform):
self.visitchildren(node)
return node
class CreateClosureClasses(VisitorTransform):
# Output closure classes in module scope for all functions # Output closure classes in module scope for all functions
# that need it. # that need it.
...@@ -235,7 +289,4 @@ class CreateClosureClasses(VisitorTransform): ...@@ -235,7 +289,4 @@ class CreateClosureClasses(VisitorTransform):
self.create_class_from_scope(node, self.module_scope) self.create_class_from_scope(node, self.module_scope)
return node return node
def visit_Node(self, node):
self.visitchildren(node)
return node
...@@ -312,6 +312,7 @@ def p_call(s, function): ...@@ -312,6 +312,7 @@ def p_call(s, function):
if s.sy != ',': if s.sy != ',':
break break
s.next() s.next()
if s.sy == '*': if s.sy == '*':
s.next() s.next()
star_arg = p_simple_expr(s) star_arg = p_simple_expr(s)
...@@ -1386,7 +1387,7 @@ def p_statement(s, ctx, first_statement = 0): ...@@ -1386,7 +1387,7 @@ def p_statement(s, ctx, first_statement = 0):
if ctx.api: if ctx.api:
error(s.pos, "'api' not allowed with this statement") error(s.pos, "'api' not allowed with this statement")
elif s.sy == 'def': elif s.sy == 'def':
if ctx.level not in ('module', 'class', 'c_class', 'function', 'property'): if ctx.level not in ('module', 'class', 'c_class', 'property'):
s.error('def statement not allowed here') s.error('def statement not allowed here')
s.level = ctx.level s.level = ctx.level
return p_def_statement(s) return p_def_statement(s)
...@@ -1456,6 +1457,71 @@ def p_suite(s, ctx = Ctx(), with_doc = 0, with_pseudo_doc = 0): ...@@ -1456,6 +1457,71 @@ def p_suite(s, ctx = Ctx(), with_doc = 0, with_pseudo_doc = 0):
else: else:
return body return body
def p_positional_and_keyword_args(s, end_sy_set, type_positions=(), type_keywords=()):
"""
Parses positional and keyword arguments. end_sy_set
should contain any s.sy that terminate the argument list.
Argument expansion (* and **) are not allowed.
type_positions and type_keywords specifies which argument
positions and/or names which should be interpreted as
types. Other arguments will be treated as expressions.
Returns: (positional_args, keyword_args)
"""
positional_args = []
keyword_args = []
pos_idx = 0
while s.sy not in end_sy_set:
if s.sy == '*' or s.sy == '**':
s.error('Argument expansion not allowed here.')
was_keyword = False
parsed_type = False
if s.sy == 'IDENT':
# Since we can have either types or expressions as positional args,
# we use a strategy of looking an extra step forward for a '=' and
# if it is a positional arg we backtrack.
ident = s.systring
s.next()
if s.sy == '=':
s.next()
# Is keyword arg
if ident in type_keywords:
arg = p_c_base_type(s)
parsed_type = True
else:
arg = p_simple_expr(s)
keyword_node = ExprNodes.IdentifierStringNode(arg.pos,
value = Utils.EncodedString(ident))
keyword_args.append((keyword_node, arg))
was_keyword = True
else:
s.put_back('IDENT', ident)
if not was_keyword:
if pos_idx in type_positions:
arg = p_c_base_type(s)
parsed_type = True
else:
arg = p_simple_expr(s)
positional_args.append(arg)
pos_idx += 1
if len(keyword_args) > 0:
s.error("Non-keyword arg following keyword arg",
pos = arg.pos)
if s.sy != ',':
if s.sy not in end_sy_set:
if parsed_type:
s.error("Expected: type")
else:
s.error("Expected: expression")
break
s.next()
return positional_args, keyword_args
def p_c_base_type(s, self_flag = 0, nonempty = 0): def p_c_base_type(s, self_flag = 0, nonempty = 0):
# If self_flag is true, this is the base type for the # If self_flag is true, this is the base type for the
# self argument of a C method of an extension type. # self argument of a C method of an extension type.
...@@ -1528,11 +1594,43 @@ def p_c_simple_base_type(s, self_flag, nonempty): ...@@ -1528,11 +1594,43 @@ def p_c_simple_base_type(s, self_flag, nonempty):
else: else:
#print "p_c_simple_base_type: not looking at type at", s.position() #print "p_c_simple_base_type: not looking at type at", s.position()
name = None name = None
return Nodes.CSimpleBaseTypeNode(pos,
type_node = Nodes.CSimpleBaseTypeNode(pos,
name = name, module_path = module_path, name = name, module_path = module_path,
is_basic_c_type = is_basic, signed = signed, is_basic_c_type = is_basic, signed = signed,
longness = longness, is_self_arg = self_flag) longness = longness, is_self_arg = self_flag)
# Treat trailing [] on type as buffer access
if s.sy == '[':
if is_basic:
p.error("Basic C types do not support buffer access")
return p_buffer_access(s, type_node)
else:
return type_node
def p_buffer_access(s, type_node):
# s.sy == '['
pos = s.position()
s.next()
positional_args, keyword_args = (
p_positional_and_keyword_args(s, (']',), (0,), ('dtype',))
)
s.expect(']')
keyword_dict = ExprNodes.DictNode(pos,
key_value_pairs = [
ExprNodes.DictItemNode(pos=key.pos, key=key, value=value)
for key, value in keyword_args
])
result = Nodes.CBufferAccessTypeNode(pos,
positional_args = positional_args,
keyword_args = keyword_dict,
base_type_node = type_node)
return result
def looking_at_type(s): def looking_at_type(s):
return looking_at_base_type(s) or s.looking_at_type_name() return looking_at_base_type(s) or s.looking_at_type_name()
......
...@@ -2,7 +2,7 @@ from Cython.TestUtils import TransformTest ...@@ -2,7 +2,7 @@ from Cython.TestUtils import TransformTest
from Cython.Compiler.ParseTreeTransforms import * from Cython.Compiler.ParseTreeTransforms import *
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
class TestPostParse(TransformTest): class TestNormalizeTree(TransformTest):
def test_parserbehaviour_is_what_we_coded_for(self): def test_parserbehaviour_is_what_we_coded_for(self):
t = self.fragment(u"if x: y").root t = self.fragment(u"if x: y").root
self.assertLines(u""" self.assertLines(u"""
...@@ -15,7 +15,7 @@ class TestPostParse(TransformTest): ...@@ -15,7 +15,7 @@ class TestPostParse(TransformTest):
""", self.treetypes(t)) """, self.treetypes(t))
def test_wrap_singlestat(self): def test_wrap_singlestat(self):
t = self.run_pipeline([PostParse()], u"if x: y") t = self.run_pipeline([NormalizeTree(None)], u"if x: y")
self.assertLines(u""" self.assertLines(u"""
(root): StatListNode (root): StatListNode
stats[0]: IfStatNode stats[0]: IfStatNode
...@@ -27,7 +27,7 @@ class TestPostParse(TransformTest): ...@@ -27,7 +27,7 @@ class TestPostParse(TransformTest):
""", self.treetypes(t)) """, self.treetypes(t))
def test_wrap_multistat(self): def test_wrap_multistat(self):
t = self.run_pipeline([PostParse()], u""" t = self.run_pipeline([NormalizeTree(None)], u"""
if z: if z:
x x
y y
...@@ -45,7 +45,7 @@ class TestPostParse(TransformTest): ...@@ -45,7 +45,7 @@ class TestPostParse(TransformTest):
""", self.treetypes(t)) """, self.treetypes(t))
def test_statinexpr(self): def test_statinexpr(self):
t = self.run_pipeline([PostParse()], u""" t = self.run_pipeline([NormalizeTree(None)], u"""
a, b = x, y a, b = x, y
""") """)
self.assertLines(u""" self.assertLines(u"""
...@@ -60,7 +60,7 @@ class TestPostParse(TransformTest): ...@@ -60,7 +60,7 @@ class TestPostParse(TransformTest):
""", self.treetypes(t)) """, self.treetypes(t))
def test_wrap_offagain(self): def test_wrap_offagain(self):
t = self.run_pipeline([PostParse()], u""" t = self.run_pipeline([NormalizeTree(None)], u"""
x x
y y
if z: if z:
...@@ -82,13 +82,13 @@ class TestPostParse(TransformTest): ...@@ -82,13 +82,13 @@ class TestPostParse(TransformTest):
def test_pass_eliminated(self): def test_pass_eliminated(self):
t = self.run_pipeline([PostParse()], u"pass") t = self.run_pipeline([NormalizeTree(None)], u"pass")
self.assert_(len(t.stats) == 0) self.assert_(len(t.stats) == 0)
class TestWithTransform(TransformTest): class TestWithTransform(TransformTest):
def test_simplified(self): def test_simplified(self):
t = self.run_pipeline([WithTransform()], u""" t = self.run_pipeline([WithTransform(None)], u"""
with x: with x:
y = z ** 3 y = z ** 3
""") """)
...@@ -113,7 +113,7 @@ class TestWithTransform(TransformTest): ...@@ -113,7 +113,7 @@ class TestWithTransform(TransformTest):
""", t) """, t)
def test_basic(self): def test_basic(self):
t = self.run_pipeline([WithTransform()], u""" t = self.run_pipeline([WithTransform(None)], u"""
with x as y: with x as y:
y = z ** 3 y = z ** 3
""") """)
......
...@@ -4,6 +4,7 @@ from Cython.Compiler.Nodes import * ...@@ -4,6 +4,7 @@ from Cython.Compiler.Nodes import *
import Cython.Compiler.Naming as Naming import Cython.Compiler.Naming as Naming
class TestTreeFragments(CythonTest): class TestTreeFragments(CythonTest):
def test_basic(self): def test_basic(self):
F = self.fragment(u"x = 4") F = self.fragment(u"x = 4")
T = F.copy() T = F.copy()
...@@ -46,13 +47,14 @@ class TestTreeFragments(CythonTest): ...@@ -46,13 +47,14 @@ class TestTreeFragments(CythonTest):
self.assertEquals(v.pos, a.pos) self.assertEquals(v.pos, a.pos)
def test_temps(self): def test_temps(self):
import Cython.Compiler.Visitor as v
v.tmpnamectr = 0
F = self.fragment(u""" F = self.fragment(u"""
TMP TMP
x = TMP x = TMP
""") """)
T = F.substitute(temps=[u"TMP"]) T = F.substitute(temps=[u"TMP"])
s = T.stats s = T.stats
print s[0].expr.name
self.assert_(s[0].expr.name == Naming.temp_prefix + u"1_TMP", s[0].expr.name) self.assert_(s[0].expr.name == Naming.temp_prefix + u"1_TMP", s[0].expr.name)
self.assert_(s[1].rhs.name == Naming.temp_prefix + u"1_TMP") self.assert_(s[1].rhs.name == Naming.temp_prefix + u"1_TMP")
self.assert_(s[0].expr.name != u"TMP") self.assert_(s[0].expr.name != u"TMP")
......
...@@ -131,7 +131,6 @@ class VisitorTransform(TreeVisitor): ...@@ -131,7 +131,6 @@ class VisitorTransform(TreeVisitor):
was not, an exception will be raised. (Typically you want to ensure that you was not, an exception will be raised. (Typically you want to ensure that you
are within a StatListNode or similar before doing this.) are within a StatListNode or similar before doing this.)
""" """
def visitchildren(self, parent, attrs=None): def visitchildren(self, parent, attrs=None):
result = super(VisitorTransform, self).visitchildren(parent, attrs) result = super(VisitorTransform, self).visitchildren(parent, attrs)
for attr, newnode in result.iteritems(): for attr, newnode in result.iteritems():
...@@ -152,6 +151,19 @@ class VisitorTransform(TreeVisitor): ...@@ -152,6 +151,19 @@ class VisitorTransform(TreeVisitor):
def __call__(self, root): def __call__(self, root):
return self.visit(root) return self.visit(root)
class CythonTransform(VisitorTransform):
"""
Certain common conventions and utilitues for Cython transforms.
"""
def __init__(self, context):
super(CythonTransform, self).__init__()
self.context = context
def visit_Node(self, node):
self.visitchildren(node)
return node
# Utils # Utils
def ensure_statlist(node): def ensure_statlist(node):
if not isinstance(node, Nodes.StatListNode): if not isinstance(node, Nodes.StatListNode):
......
...@@ -59,16 +59,36 @@ class CythonTest(unittest.TestCase): ...@@ -59,16 +59,36 @@ class CythonTest(unittest.TestCase):
self.assertEqual(len(result_lines), len(expected_lines), self.assertEqual(len(result_lines), len(expected_lines),
"Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected)) "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))
def fragment(self, code, pxds={}): def fragment(self, code, pxds={}, pipeline=[]):
"Simply create a tree fragment using the name of the test-case in parse errors." "Simply create a tree fragment using the name of the test-case in parse errors."
name = self.id() name = self.id()
if name.startswith("__main__."): name = name[len("__main__."):] if name.startswith("__main__."): name = name[len("__main__."):]
name = name.replace(".", "_") name = name.replace(".", "_")
return TreeFragment(code, name, pxds) return TreeFragment(code, name, pxds, pipeline=pipeline)
def treetypes(self, root): def treetypes(self, root):
return treetypes(root) return treetypes(root)
def should_fail(self, func, exc_type=Exception):
"""Calls "func" and fails if it doesn't raise the right exception
(any exception by default). Also returns the exception in question.
"""
try:
func()
self.fail("Expected an exception of type %r" % exc_type)
except exc_type, e:
self.assert_(isinstance(e, exc_type))
return e
def should_not_fail(self, func):
"""Calls func and succeeds if and only if no exception is raised
(i.e. converts exception raising into a failed testcase). Returns
the return value of func."""
try:
return func()
except:
self.fail()
class TransformTest(CythonTest): class TransformTest(CythonTest):
""" """
Utility base class for transform unit tests. It is based around constructing Utility base class for transform unit tests. It is based around constructing
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment