Commit 6c1aa761 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Support for with statement

parent a053e3e1
......@@ -7,5 +7,6 @@ def _get_feature(name):
return object()
unicode_literals = _get_feature("unicode_literals")
with_statement = _get_feature("with_statement")
del _get_feature
......@@ -232,6 +232,10 @@ class Context:
errors_occurred = False
try:
tree = self.parse(source, scope.type_names, pxd = 0, full_module_name = full_module_name)
# This is of course going to change and be refactored real soon
from ParseTreeTransforms import WithTransform, PostParse
tree = PostParse()(tree)
tree = WithTransform()(tree)
tree.process_implementation(scope, options, result)
except CompileError:
errors_occurred = True
......
......@@ -2414,7 +2414,7 @@ class InPlaceAssignmentNode(AssignmentNode):
# Fortunately, the type of the lhs node is fairly constrained
# (it must be a NameNode, AttributeNode, or IndexNode).
child_attrs = ["lhs", "rhs", "dup"]
child_attrs = ["lhs", "rhs"]
def analyse_declarations(self, env):
self.lhs.analyse_target_declaration(env)
......@@ -2998,7 +2998,7 @@ class ForInStatNode(LoopNode, StatNode):
# else_clause StatNode
# item NextNode used internally
child_attrs = ["target", "iterator", "body", "else_clause", "item"]
child_attrs = ["target", "iterator", "body", "else_clause"]
def analyse_declarations(self, env):
self.target.analyse_target_declaration(env)
......@@ -3115,7 +3115,7 @@ class ForFromStatNode(LoopNode, StatNode):
# is_py_target bool
# loopvar_name string
# py_loopvar_node PyTempNode or None
child_attrs = ["target", "bound1", "bound2", "step", "body", "else_clause", "py_loopvar_node"]
child_attrs = ["target", "bound1", "bound2", "step", "body", "else_clause"]
def analyse_declarations(self, env):
self.target.analyse_target_declaration(env)
......@@ -3231,6 +3231,18 @@ class ForFromStatNode(LoopNode, StatNode):
self.else_clause.annotate(code)
class WithStatNode(StatNode):
"""
Represents a Python with statement.
This is only used at parse tree level; and is not present in
analysis or generation phases.
"""
# manager The with statement manager object
# target Node (lhs expression)
# body StatNode
child_attrs = ["manager", "target", "body"]
class TryExceptStatNode(StatNode):
# try .. except statement
#
......@@ -3326,6 +3338,8 @@ class ExceptClauseNode(Node):
child_attrs = ["pattern", "target", "body", "exc_value"]
exc_value = None
def analyse_declarations(self, env):
if self.target:
self.target.analyse_target_declaration(env)
......
from Cython.Compiler.Visitor import VisitorTransform
from Cython.Compiler.Nodes import *
from Cython.Compiler.TreeFragment import TreeFragment
class PostParse(VisitorTransform):
"""
This transform fixes up a few things after parsing
in order to make the parse tree more suitable for
transforms.
a) After parsing, blocks with only one statement will
be represented by that statement, not by a StatListNode.
When doing transforms this is annoying and inconsistent,
as one cannot in general remove a statement in a consistent
way and so on. This transform wraps any single statements
in a StatListNode containing a single statement.
b) The PassStatNode is a noop and serves no purpose beyond
plugging such one-statement blocks; i.e., once parsed a
` "pass" can just as well be represented using an empty
StatListNode. This means less special cases to worry about
in subsequent transforms (one always checks to see if a
StatListNode has no children to see if the block is empty).
"""
def __init__(self):
super(PostParse, self).__init__()
self.is_in_statlist = False
self.is_in_expr = False
def visit_Node(self, node):
self.visitchildren(node)
return node
def visit_ExprNode(self, node):
stacktmp = self.is_in_expr
self.is_in_expr = True
self.visitchildren(node)
self.is_in_expr = stacktmp
return node
def visit_StatNode(self, node, is_listcontainer=False):
stacktmp = self.is_in_statlist
self.is_in_statlist = is_listcontainer
self.visitchildren(node)
self.is_in_statlist = stacktmp
if not self.is_in_statlist and not self.is_in_expr:
return StatListNode(pos=node.pos, stats=[node])
else:
return node
def visit_PassStatNode(self, node):
if not self.is_in_statlist:
return StatListNode(pos=node.pos, stats=[])
else:
return []
def visit_StatListNode(self, node):
self.is_in_statlist = True
self.visitchildren(node)
self.is_in_statlist = False
return node
def visit_ParallelAssignmentNode(self, node):
return self.visit_StatNode(node, True)
def visit_CEnumDefNode(self, node):
return self.visit_StatNode(node, True)
def visit_CStructOrUnionDefNode(self, node):
return self.visit_StatNode(node, True)
class WithTransform(VisitorTransform):
template_without_target = TreeFragment(u"""
import sys as SYS
MGR = EXPR
EXIT = MGR.__exit__
MGR.__enter__()
EXC = True
try:
try:
BODY
except:
EXC = False
if not EXIT(*SYS.exc_info()):
raise
finally:
if EXC:
EXIT(None, None, None)
""", u"WithTransformFragment")
template_with_target = TreeFragment(u"""
import sys as SYS
MGR = EXPR
EXIT = MGR.__exit__
VALUE = MGR.__enter__()
EXC = True
try:
try:
TARGET = VALUE
BODY
except:
EXC = False
if not EXIT(*SYS.exc_info()):
raise
finally:
if EXC:
EXIT(None, None, None)
""", u"WithTransformFragment")
def visit_Node(self, node):
self.visitchildren(node)
return node
def visit_WithStatNode(self, node):
if node.target is not None:
result = self.template_with_target.substitute({
u'EXPR' : node.manager,
u'BODY' : node.body,
u'TARGET' : node.target
}, temps=(u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"),
pos = node.pos)
else:
result = self.template_without_target.substitute({
u'EXPR' : node.manager,
u'BODY' : node.body,
}, temps=(u'MGR', u'EXC', u"EXIT", u"SYS"),
pos = node.pos)
return result.body.stats
class CallExitFuncNode(Node):
def analyse_types(self, env):
pass
def analyse_expressions(self, env):
self.exc_vars = [
env.allocate_temp(PyrexTypes.py_object_type)
for x in xrange(3)
]
def generate_result(self, code):
code.putln("""{
PyObject* type; PyObject* value; PyObject* tb;
__Pyx_GetException(
}""")
......@@ -1134,13 +1134,13 @@ def p_for_from_step(s):
inequality_relations = ('<', '<=', '>', '>=')
def p_for_target(s):
def p_target(s, terminator):
pos = s.position()
expr = p_bit_expr(s)
if s.sy == ',':
s.next()
exprs = [expr]
while s.sy != 'in':
while s.sy != terminator:
exprs.append(p_bit_expr(s))
if s.sy != ',':
break
......@@ -1149,6 +1149,9 @@ def p_for_target(s):
else:
return expr
def p_for_target(s):
return p_target(s, 'in')
def p_for_iterator(s):
pos = s.position()
expr = p_testlist(s)
......@@ -1227,8 +1230,17 @@ def p_with_statement(s):
body = p_suite(s)
return Nodes.GILStatNode(pos, state = state, body = body)
else:
s.error("Only 'with gil' and 'with nogil' implemented",
pos = pos)
manager = p_expr(s)
target = None
if s.sy == 'IDENT' and s.systring == 'as':
s.next()
allow_multi = (s.sy == '(')
target = p_target(s, ':')
if not allow_multi and isinstance(target, ExprNodes.TupleNode):
s.error("Multiple with statement target values not allowed without paranthesis")
body = p_suite(s)
return Nodes.WithStatNode(pos, manager = manager,
target = target, body = body)
def p_simple_statement(s, first_statement = 0):
#print "p_simple_statement:", s.sy, s.systring ###
......
from Cython.TestUtils import TransformTest
from Cython.Compiler.ParseTreeTransforms import *
from Cython.Compiler.Nodes import *
class TestPostParse(TransformTest):
def test_parserbehaviour_is_what_we_coded_for(self):
t = self.fragment(u"if x: y").root
self.assertLines(u"""
(root): ModuleNode
body: IfStatNode
if_clauses[0]: IfClauseNode
condition: NameNode
body: ExprStatNode
expr: NameNode
""", self.treetypes(t))
def test_wrap_singlestat(self):
t = self.run_pipeline([PostParse()], u"if x: y")
self.assertLines(u"""
(root): ModuleNode
body: StatListNode
stats[0]: IfStatNode
if_clauses[0]: IfClauseNode
condition: NameNode
body: StatListNode
stats[0]: ExprStatNode
expr: NameNode
""", self.treetypes(t))
def test_wrap_multistat(self):
t = self.run_pipeline([PostParse()], u"""
if z:
x
y
""")
self.assertLines(u"""
(root): ModuleNode
body: StatListNode
stats[0]: IfStatNode
if_clauses[0]: IfClauseNode
condition: NameNode
body: StatListNode
stats[0]: ExprStatNode
expr: NameNode
stats[1]: ExprStatNode
expr: NameNode
""", self.treetypes(t))
def test_statinexpr(self):
t = self.run_pipeline([PostParse()], u"""
a, b = x, y
""")
self.assertLines(u"""
(root): ModuleNode
body: StatListNode
stats[0]: ParallelAssignmentNode
stats[0]: SingleAssignmentNode
lhs: NameNode
rhs: NameNode
stats[1]: SingleAssignmentNode
lhs: NameNode
rhs: NameNode
""", self.treetypes(t))
def test_wrap_offagain(self):
t = self.run_pipeline([PostParse()], u"""
x
y
if z:
x
""")
self.assertLines(u"""
(root): ModuleNode
body: StatListNode
stats[0]: ExprStatNode
expr: NameNode
stats[1]: ExprStatNode
expr: NameNode
stats[2]: IfStatNode
if_clauses[0]: IfClauseNode
condition: NameNode
body: StatListNode
stats[0]: ExprStatNode
expr: NameNode
""", self.treetypes(t))
def test_pass_eliminated(self):
t = self.run_pipeline([PostParse()], u"pass")
self.assert_(len(t.body.stats) == 0)
class TestWithTransform(TransformTest):
def test_simplified(self):
t = self.run_pipeline([WithTransform()], u"""
with x:
y = z ** 3
""")
self.assertCode(u"""
$SYS = (import sys)
$MGR = x
$EXIT = $MGR.__exit__
$MGR.__enter__()
$EXC = True
try:
try:
y = z ** 3
except:
$EXC = False
if (not $EXIT($SYS.exc_info())):
raise
finally:
if $EXC:
$EXIT(None, None, None)
""", t)
def test_basic(self):
t = self.run_pipeline([WithTransform()], u"""
with x as y:
y = z ** 3
""")
self.assertCode(u"""
$SYS = (import sys)
$MGR = x
$EXIT = $MGR.__exit__
$VALUE = $MGR.__enter__()
$EXC = True
try:
try:
y = $VALUE
y = z ** 3
except:
$EXC = False
if (not $EXIT($SYS.exc_info())):
raise
finally:
if $EXC:
$EXIT(None, None, None)
""", t)
if __name__ == "__main__":
import unittest
unittest.main()
......@@ -28,6 +28,16 @@ class NodeTypeWriter(TreeVisitor):
self._indents -= 1
class CythonTest(unittest.TestCase):
def assertLines(self, expected, result):
"Checks that the given strings or lists of strings are equal line by line"
if not isinstance(expected, list): expected = expected.split(u"\n")
if not isinstance(result, list): result = result.split(u"\n")
for idx, (expected_line, result_line) in enumerate(zip(expected, result)):
self.assertEqual(expected_line, result_line, "Line %d:\nExp: %s\nGot: %s" % (idx, expected_line, result_line))
self.assertEqual(len(expected), len(result),
"Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected), u"\n".join(result)))
def assertCode(self, expected, result_tree):
writer = CodeWriter()
writer.write(result_tree)
......
......@@ -72,6 +72,9 @@ class TestCodeWriter(CythonTest):
def test_inplace_assignment(self):
self.t(u"x += 43")
def test_attribute(self):
self.t(u"a.x")
if __name__ == "__main__":
import unittest
......
from __future__ import with_statement
__doc__ = u"""
>>> basic()
enter
value
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
>>> with_exception(None)
enter
value
exit <type 'type'> <type 'exceptions.Exception'> <type 'traceback'>
outer except
>>> with_exception(True)
enter
value
exit <type 'type'> <type 'exceptions.Exception'> <type 'traceback'>
>>> multitarget()
enter
1 2 3 4 5
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
>>> tupletarget()
enter
(1, 2, (3, (4, 5)))
exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
"""
class ContextManager:
def __init__(self, value, exit_ret = None):
self.value = value
self.exit_ret = exit_ret
def __exit__(self, a, b, c):
print "exit", type(a), type(b), type(c)
return self.exit_ret
def __enter__(self):
print "enter"
return self.value
def basic():
with ContextManager("value") as x:
print x
def with_exception(exit_ret):
try:
with ContextManager("value", exit_ret=exit_ret) as value:
print value
raise Exception()
except:
print "outer except"
def multitarget():
with ContextManager((1, 2, (3, (4, 5)))) as (a, b, (c, (d, e))):
print a, b, c, d, e
def tupletarget():
with ContextManager((1, 2, (3, (4, 5)))) as t:
print t
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment