Commit d2a99890 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Better exception info reading for with statement

parent 21c0a027
from Cython.Compiler.Visitor import TreeVisitor
from Cython.Compiler.Visitor import TreeVisitor, get_temp_name_handle_desc
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
from Cython.Compiler.Symtab import TempName
"""
Serializes a Cython code tree to Cython code. This is primarily useful for
......@@ -62,8 +61,9 @@ class CodeWriter(TreeVisitor):
self.endline()
def putname(self, name):
if isinstance(name, TempName):
name = self.tempnames.setdefault(name, u"$" + name.description)
tmpdesc = get_temp_name_handle_desc(name)
if tmpdesc is not None:
name = self.tempnames.setdefault(tmpdesc, u"$" +tmpdesc)
self.put(name)
def comma_seperated_list(self, items, output_rhs=False):
......
......@@ -1060,7 +1060,6 @@ class NameNode(AtomicExprNode):
else:
code.annotate(pos, AnnotationItem('c_call', 'c function', size=len(self.name)))
class BackquoteNode(ExprNode):
# `expr`
#
......@@ -1212,6 +1211,9 @@ class ExcValueNode(AtomicExprNode):
def generate_result_code(self, code):
pass
def analyse_types(self, env):
pass
class TempNode(AtomicExprNode):
# Node created during analyse_types phase
......
......@@ -3329,18 +3329,24 @@ class ExceptClauseNode(Node):
# pattern ExprNode
# target ExprNode or None
# body StatNode
# excinfo_target NameNode or None optional target for exception info
# excinfo_target NameNode or None used internally
# match_flag string result of exception match
# exc_value ExcValueNode used internally
# function_name string qualified name of enclosing function
# exc_vars (string * 3) local exception variables
child_attrs = ["pattern", "target", "body", "exc_value"]
child_attrs = ["pattern", "target", "body", "exc_value", "excinfo_target"]
exc_value = None
excinfo_target = None
excinfo_assignment = None
def analyse_declarations(self, env):
if self.target:
self.target.analyse_target_declaration(env)
if self.excinfo_target is not None:
self.excinfo_target.analyse_target_declaration(env)
self.body.analyse_declarations(env)
def analyse_expressions(self, env):
......@@ -3358,6 +3364,17 @@ class ExceptClauseNode(Node):
self.exc_value = ExprNodes.ExcValueNode(self.pos, env, self.exc_vars[1])
self.exc_value.allocate_temps(env)
self.target.analyse_target_expression(env, self.exc_value)
if self.excinfo_target is not None:
import ExprNodes
self.excinfo_tuple = ExprNodes.TupleNode(pos=self.pos, args=[
ExprNodes.ExcValueNode(pos=self.pos, env=env, var=self.exc_vars[0]),
ExprNodes.ExcValueNode(pos=self.pos, env=env, var=self.exc_vars[1]),
ExprNodes.ExcValueNode(pos=self.pos, env=env, var=self.exc_vars[2])
])
self.excinfo_tuple.analyse_expressions(env)
self.excinfo_tuple.allocate_temps(env)
self.excinfo_target.analyse_target_expression(env, self.excinfo_tuple)
self.body.analyse_expressions(env)
for var in self.exc_vars:
env.release_temp(var)
......@@ -3387,6 +3404,10 @@ class ExceptClauseNode(Node):
if self.target:
self.exc_value.generate_evaluation_code(code)
self.target.generate_assignment_code(self.exc_value, code)
if self.excinfo_target is not None:
self.excinfo_tuple.generate_evaluation_code(code)
self.excinfo_target.generate_assignment_code(self.excinfo_tuple, code)
old_exc_vars = code.exc_vars
code.exc_vars = self.exc_vars
self.body.generate_execution_code(code)
......@@ -4497,6 +4518,7 @@ bad:
Py_XDECREF(*tb);
return -1;
}
"""]
#------------------------------------------------------------------------------------
from Cython.Compiler.Visitor import VisitorTransform
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment
......@@ -71,10 +72,13 @@ class PostParse(VisitorTransform):
def visit_CStructOrUnionDefNode(self, node):
return self.visit_StatNode(node, True)
class WithTransform(VisitorTransform):
# EXCINFO is manually set to a variable that contains
# the exc_info() tuple that can be generated by the enclosing except
# statement.
template_without_target = TreeFragment(u"""
import sys as SYS
MGR = EXPR
EXIT = MGR.__exit__
MGR.__enter__()
......@@ -84,15 +88,15 @@ class WithTransform(VisitorTransform):
BODY
except:
EXC = False
if not EXIT(*SYS.exc_info()):
if not EXIT(*EXCINFO):
raise
finally:
if EXC:
EXIT(None, None, None)
""", u"WithTransformFragment")
""", temps=[u'MGR', u'EXC', u"EXIT", u"SYS)"],
pipeline=[PostParse()])
template_with_target = TreeFragment(u"""
import sys as SYS
MGR = EXPR
EXIT = MGR.__exit__
VALUE = MGR.__enter__()
......@@ -103,47 +107,38 @@ class WithTransform(VisitorTransform):
BODY
except:
EXC = False
if not EXIT(*SYS.exc_info()):
if not EXIT(*EXCINFO):
raise
finally:
if EXC:
EXIT(None, None, None)
""", u"WithTransformFragment")
""", temps=[u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"],
pipeline=[PostParse()])
def visit_Node(self, node):
self.visitchildren(node)
return node
def visit_WithStatNode(self, node):
excinfo_name = temp_name_handle('EXCINFO')
excinfo_namenode = NameNode(pos=node.pos, name=excinfo_name)
excinfo_target = NameNode(pos=node.pos, name=excinfo_name)
if node.target is not None:
result = self.template_with_target.substitute({
u'EXPR' : node.manager,
u'BODY' : node.body,
u'TARGET' : node.target
}, temps=(u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"),
pos = node.pos)
u'TARGET' : node.target,
u'EXCINFO' : excinfo_namenode
}, pos = node.pos)
# Set except excinfo target to EXCINFO
result.stats[4].body.stats[0].except_clauses[0].excinfo_target = excinfo_target
else:
result = self.template_without_target.substitute({
u'EXPR' : node.manager,
u'BODY' : node.body,
}, temps=(u'MGR', u'EXC', u"EXIT", u"SYS"),
pos = node.pos)
return result.body.stats
class CallExitFuncNode(Node):
def analyse_types(self, env):
pass
def analyse_expressions(self, env):
self.exc_vars = [
env.allocate_temp(PyrexTypes.py_object_type)
for x in xrange(3)
]
u'EXCINFO' : excinfo_namenode
}, pos = node.pos)
# Set except excinfo target to EXCINFO
result.stats[4].body.stats[0].except_clauses[0].excinfo_target = excinfo_target
def generate_result(self, code):
code.putln("""{
PyObject* type; PyObject* value; PyObject* tb;
__Pyx_GetException(
}""")
return result.stats
......@@ -16,29 +16,6 @@ from TypeSlots import \
import ControlFlow
import __builtin__
class TempName(object):
"""
Use instances of this class in order to provide a name for
anonymous, temporary functions. Each instance is considered
a seperate name, which are guaranteed not to clash with one
another or with names explicitly given as strings.
The argument to the constructor is simply a describing string
for debugging purposes and does not affect name clashes at all.
NOTE: Support for these TempNames are introduced on an as-needed
basis and will not "just work" everywhere. Places where they work:
- (none)
"""
def __init__(self, description):
self.description = description
# Spoon-feed operators for documentation purposes
def __hash__(self):
return id(self)
def __cmp__(self, other):
return cmp(id(self), id(other))
possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
......@@ -1098,19 +1075,12 @@ class ModuleScope(Scope):
var_entry.is_readonly = 1
entry.as_variable = var_entry
tempctr = 0
class LocalScope(Scope):
def __init__(self, name, outer_scope):
Scope.__init__(self, name, outer_scope, outer_scope)
def mangle(self, prefix, name):
if isinstance(name, TempName):
global tempctr
tempctr += 1
return u"%s%s%d" % (Naming.temp_prefix, name.description, tempctr)
else:
return prefix + name
def declare_arg(self, name, type, pos):
......
......@@ -6,8 +6,8 @@ class TestPostParse(TransformTest):
def test_parserbehaviour_is_what_we_coded_for(self):
t = self.fragment(u"if x: y").root
self.assertLines(u"""
(root): ModuleNode
body: IfStatNode
(root): StatListNode
stats[0]: IfStatNode
if_clauses[0]: IfClauseNode
condition: NameNode
body: ExprStatNode
......@@ -17,8 +17,7 @@ class TestPostParse(TransformTest):
def test_wrap_singlestat(self):
t = self.run_pipeline([PostParse()], u"if x: y")
self.assertLines(u"""
(root): ModuleNode
body: StatListNode
(root): StatListNode
stats[0]: IfStatNode
if_clauses[0]: IfClauseNode
condition: NameNode
......@@ -34,8 +33,7 @@ class TestPostParse(TransformTest):
y
""")
self.assertLines(u"""
(root): ModuleNode
body: StatListNode
(root): StatListNode
stats[0]: IfStatNode
if_clauses[0]: IfClauseNode
condition: NameNode
......@@ -51,8 +49,7 @@ class TestPostParse(TransformTest):
a, b = x, y
""")
self.assertLines(u"""
(root): ModuleNode
body: StatListNode
(root): StatListNode
stats[0]: ParallelAssignmentNode
stats[0]: SingleAssignmentNode
lhs: NameNode
......@@ -70,8 +67,7 @@ class TestPostParse(TransformTest):
x
""")
self.assertLines(u"""
(root): ModuleNode
body: StatListNode
(root): StatListNode
stats[0]: ExprStatNode
expr: NameNode
stats[1]: ExprStatNode
......@@ -87,7 +83,7 @@ class TestPostParse(TransformTest):
def test_pass_eliminated(self):
t = self.run_pipeline([PostParse()], u"pass")
self.assert_(len(t.body.stats) == 0)
self.assert_(len(t.stats) == 0)
class TestWithTransform(TransformTest):
......@@ -99,7 +95,6 @@ class TestWithTransform(TransformTest):
self.assertCode(u"""
$SYS = (import sys)
$MGR = x
$EXIT = $MGR.__exit__
$MGR.__enter__()
......@@ -109,7 +104,7 @@ class TestWithTransform(TransformTest):
y = z ** 3
except:
$EXC = False
if (not $EXIT($SYS.exc_info())):
if (not $EXIT($EXCINFO)):
raise
finally:
if $EXC:
......@@ -124,7 +119,6 @@ class TestWithTransform(TransformTest):
""")
self.assertCode(u"""
$SYS = (import sys)
$MGR = x
$EXIT = $MGR.__exit__
$VALUE = $MGR.__enter__()
......@@ -135,7 +129,7 @@ class TestWithTransform(TransformTest):
y = z ** 3
except:
$EXC = False
if (not $EXIT($SYS.exc_info())):
if (not $EXIT($EXCINFO)):
raise
finally:
if $EXC:
......
......@@ -6,9 +6,8 @@ import re
from cStringIO import StringIO
from Scanning import PyrexScanner, StringSourceDescriptor
from Symtab import BuiltinScope, ModuleScope
from Visitor import VisitorTransform
from Nodes import Node
from Symtab import TempName
from Visitor import VisitorTransform, temp_name_handle
from Nodes import Node, StatListNode
from ExprNodes import NameNode
import Parsing
import Main
......@@ -109,7 +108,7 @@ class TemplateTransform(VisitorTransform):
self.substitutions = substitutions
tempdict = {}
for key in temps:
tempdict[key] = TempName(key)
tempdict[key] = temp_name_handle(key)
self.temps = tempdict
self.pos = pos
return super(TemplateTransform, self).__call__(node)
......@@ -164,7 +163,7 @@ def strip_common_indent(lines):
return lines
class TreeFragment(object):
def __init__(self, code, name, pxds={}):
def __init__(self, code, name="(tree fragment)", pxds={}, temps=[], pipeline=[]):
if isinstance(code, unicode):
def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n")))
......@@ -173,12 +172,20 @@ class TreeFragment(object):
for key, value in pxds.iteritems():
fmt_pxds[key] = fmt(value)
self.root = parse_from_strings(name, fmt_code, fmt_pxds)
t = parse_from_strings(name, fmt_code, fmt_pxds)
mod = t
t = t.body # Make sure a StatListNode is at the top
if not isinstance(t, StatListNode):
t = StatListNode(pos=mod.pos, stats=[t])
for transform in pipeline:
t = transform(t)
self.root = t
elif isinstance(code, Node):
if pxds != {}: raise NotImplementedError()
self.root = code
else:
raise ValueError("Unrecognized code format (accepts unicode and Node)")
self.temps = temps
def copy(self):
return copy_code_tree(self.root)
......@@ -186,7 +193,7 @@ class TreeFragment(object):
def substitute(self, nodes={}, temps=[], pos = None):
return TemplateTransform()(self.root,
substitutions = nodes,
temps = temps, pos = pos)
temps = self.temps + temps, pos = pos)
......
......@@ -166,6 +166,19 @@ def replace_node(ptr, value):
else:
getattr(parent, attrname)[listidx] = value
tmpnamectr = 0
def temp_name_handle(description):
global tmpnamectr
tmpnamectr += 1
return u"__cyt_%d_%s" % (tmpnamectr, description)
def get_temp_name_handle_desc(handle):
if not handle.startswith(u"__cyt_"):
return None
else:
idx = handle.find(u"_", 6)
return handle[idx+1:]
class PrintTree(TreeVisitor):
"""Prints a representation of the tree to standard output.
Subclass and override repr_of to provide more information
......
......@@ -77,8 +77,8 @@ class TransformTest(CythonTest):
To create a test case:
- Call run_pipeline. The pipeline should at least contain the transform you
are testing; pyx should be either a string (passed to the parser to
create a post-parse tree) or a ModuleNode representing input to pipeline.
The result will be a transformed result (usually a ModuleNode).
create a post-parse tree) or a node representing input to pipeline.
The result will be a transformed result.
- Check that the tree is correct. If wanted, assertCode can be used, which
takes a code string as expected, and a ModuleNode in result_tree
......@@ -93,7 +93,6 @@ class TransformTest(CythonTest):
def run_pipeline(self, pipeline, pyx, pxds={}):
tree = self.fragment(pyx, pxds).root
assert isinstance(tree, ModuleNode)
# Run pipeline
for T in pipeline:
tree = T(tree)
......
from __future__ import with_statement
__doc__ = u"""
>>> no_as()
enter
hello
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
>>> basic()
enter
value
......@@ -8,12 +12,12 @@ exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
>>> with_exception(None)
enter
value
exit <type 'type'> <type 'exceptions.Exception'> <type 'traceback'>
exit <type 'type'> <class 'withstat.MyException'> <type 'traceback'>
outer except
>>> with_exception(True)
enter
value
exit <type 'type'> <type 'exceptions.Exception'> <type 'traceback'>
exit <type 'type'> <class 'withstat.MyException'> <type 'traceback'>
>>> multitarget()
enter
1 2 3 4 5
......@@ -24,19 +28,26 @@ enter
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
"""
class MyException(Exception):
pass
class ContextManager:
def __init__(self, value, exit_ret = None):
self.value = value
self.exit_ret = exit_ret
def __exit__(self, a, b, c):
print "exit", type(a), type(b), type(c)
def __exit__(self, a, b, tb):
print "exit", type(a), type(b), type(tb)
return self.exit_ret
def __enter__(self):
print "enter"
return self.value
def no_as():
with ContextManager("value"):
print "hello"
def basic():
with ContextManager("value") as x:
print x
......@@ -45,7 +56,7 @@ def with_exception(exit_ret):
try:
with ContextManager("value", exit_ret=exit_ret) as value:
print value
raise Exception()
raise MyException()
except:
print "outer except"
......@@ -56,3 +67,4 @@ def multitarget():
def tupletarget():
with ContextManager((1, 2, (3, (4, 5)))) as t:
print t
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment