Commit d78673fc authored by Stefan Behnel's avatar Stefan Behnel

implement "async with" (PEP 492)

parent fd0e47f5
......@@ -2602,12 +2602,16 @@ class WithExitCallNode(ExprNode):
# with_stat WithStatNode the surrounding 'with' statement
# args TupleNode or ResultStatNode the exception info tuple
# await AwaitExprNode the await
subexprs = ['args']
subexprs = ['args', 'await']
test_if_run = True
await = None
def analyse_types(self, env):
self.args = self.args.analyse_types(env)
if self.await:
self.await = self.await.analyse_types(env)
self.type = PyrexTypes.c_bint_type
self.is_temp = True
return self
......@@ -2633,6 +2637,13 @@ class WithExitCallNode(ExprNode):
code.putln(code.error_goto_if_null(result_var, self.pos))
code.put_gotref(result_var)
if self.await:
self.await.generate_evaluation_code(code, source_cname=result_var)
code.putln("%s = %s;" % (result_var, self.await.py_result()))
self.await.generate_post_assignment_code(code)
self.await.free_temps(code)
if self.result_is_used:
self.allocate_temp_result(code)
code.putln("%s = __Pyx_PyObject_IsTrue(%s);" % (self.result(), result_var))
......@@ -8675,15 +8686,19 @@ class YieldFromExprNode(YieldExprNode):
code.globalstate.use_utility_code(UtilityCode.load_cached("GeneratorYieldFrom", "Coroutine.c"))
return "__Pyx_Generator_Yield_From"
def generate_evaluation_code(self, code):
self.arg.generate_evaluation_code(code)
def generate_evaluation_code(self, code, source_cname=None):
if source_cname is None:
self.arg.generate_evaluation_code(code)
code.putln("%s = %s(%s, %s);" % (
Naming.retval_cname,
self.yield_from_func(code),
Naming.generator_cname,
self.arg.py_result()))
self.arg.generate_disposal_code(code)
self.arg.free_temps(code)
self.arg.py_result() if source_cname is None else source_cname))
if source_cname is None:
self.arg.generate_disposal_code(code)
self.arg.free_temps(code)
else:
code.put_decref_clear(source_cname, py_object_type)
code.put_xgotref(Naming.retval_cname)
code.putln("if (likely(%s)) {" % Naming.retval_cname)
......@@ -8715,8 +8730,9 @@ class AwaitExprNode(YieldFromExprNode):
expr_keyword = 'await'
def coerce_yield_argument(self, env):
# FIXME: use same check as in YieldFromExprNode.coerce_yield_argument() ?
self.arg = self.arg.coerce_to_pyobject(env)
if self.arg is not None:
# FIXME: use same check as in YieldFromExprNode.coerce_yield_argument() ?
self.arg = self.arg.coerce_to_pyobject(env)
def yield_from_func(self, code):
code.globalstate.use_utility_code(UtilityCode.load_cached("CoroutineYieldFrom", "Coroutine.c"))
......
......@@ -6454,7 +6454,7 @@ class WithStatNode(StatNode):
code.putln("%s = __Pyx_PyObject_LookupSpecial(%s, %s); %s" % (
self.exit_var,
self.manager.py_result(),
code.intern_identifier(EncodedString('__exit__')),
code.intern_identifier(EncodedString('__aexit__' if self.is_async else '__exit__')),
code.error_goto_if_null(self.exit_var, self.pos),
))
code.put_gotref(self.exit_var)
......
......@@ -1219,15 +1219,19 @@ class WithTransform(CythonTransform, SkipDeclarations):
def visit_WithStatNode(self, node):
self.visitchildren(node, 'body')
pos = node.pos
is_async = node.is_async
body, target, manager = node.body, node.target, node.manager
node.enter_call = ExprNodes.SimpleCallNode(
pos, function=ExprNodes.AttributeNode(
pos, obj=ExprNodes.CloneNode(manager),
attribute=EncodedString('__enter__'),
attribute=EncodedString('__aenter__' if is_async else '__enter__'),
is_special_lookup=True),
args=[],
is_temp=True)
if is_async:
node.enter_call = ExprNodes.AwaitExprNode(pos, arg=node.enter_call)
if target is not None:
body = Nodes.StatListNode(
pos, stats=[
......@@ -1245,7 +1249,8 @@ class WithTransform(CythonTransform, SkipDeclarations):
pos, operand=ExprNodes.WithExitCallNode(
pos, with_stat=node,
test_if_run=False,
args=excinfo_target)),
args=excinfo_target,
await=ExprNodes.AwaitExprNode(pos, arg=None) if is_async else None)),
body=Nodes.ReraiseStatNode(pos),
),
],
......@@ -1266,8 +1271,8 @@ class WithTransform(CythonTransform, SkipDeclarations):
pos, with_stat=node,
test_if_run=True,
args=ExprNodes.TupleNode(
pos, args=[ExprNodes.NoneNode(pos) for _ in range(3)]
))),
pos, args=[ExprNodes.NoneNode(pos) for _ in range(3)]),
await=ExprNodes.AwaitExprNode(pos, arg=None) if is_async else None)),
handle_error_case=False,
)
return node
......
......@@ -120,7 +120,7 @@ cdef p_try_statement(PyrexScanner s)
cdef p_except_clause(PyrexScanner s)
cdef p_include_statement(PyrexScanner s, ctx)
cdef p_with_statement(PyrexScanner s)
cdef p_with_items(PyrexScanner s)
cdef p_with_items(PyrexScanner s, bint is_async=*)
cdef p_with_template(PyrexScanner s)
cdef p_simple_statement(PyrexScanner s, bint first_statement = *)
cdef p_simple_statement_list(PyrexScanner s, ctx, bint first_statement = *)
......
......@@ -393,8 +393,8 @@ def p_async_statement(s, ctx, decorators):
#s.error("'async for' is not currently supported", fatal=False)
return p_statement(s, ctx) # TODO: implement
elif s.sy == 'with':
#s.error("'async with' is not currently supported", fatal=False)
return p_statement(s, ctx) # TODO: implement
s.next()
return p_with_items(s, is_async=True)
else:
s.error("expected one of 'def', 'for', 'with' after 'async'")
......@@ -1781,17 +1781,21 @@ def p_include_statement(s, ctx):
else:
return Nodes.PassStatNode(pos)
def p_with_statement(s):
s.next() # 'with'
s.next() # 'with'
if s.systring == 'template' and not s.in_python_file:
node = p_with_template(s)
else:
node = p_with_items(s)
return node
def p_with_items(s):
def p_with_items(s, is_async=False):
pos = s.position()
if not s.in_python_file and s.sy == 'IDENT' and s.systring in ('nogil', 'gil'):
if is_async:
s.error("with gil/nogil cannot be async")
state = s.systring
s.next()
if s.sy == ',':
......@@ -1799,7 +1803,7 @@ def p_with_items(s):
body = p_with_items(s)
else:
body = p_suite(s)
return Nodes.GILStatNode(pos, state = state, body = body)
return Nodes.GILStatNode(pos, state=state, body=body)
else:
manager = p_test(s)
target = None
......@@ -1808,11 +1812,11 @@ def p_with_items(s):
target = p_starred_expr(s)
if s.sy == ',':
s.next()
body = p_with_items(s)
body = p_with_items(s, is_async=is_async)
else:
body = p_suite(s)
return Nodes.WithStatNode(pos, manager = manager,
target = target, body = body)
return Nodes.WithStatNode(pos, manager=manager, target=target, body=body, is_async=is_async)
def p_with_template(s):
pos = s.position()
......
......@@ -172,12 +172,12 @@ def create_pipeline(context, mode, exclude_classes=()):
InterpretCompilerDirectives(context, context.compiler_directives),
ParallelRangeTransform(context),
AdjustDefByDirectives(context),
WithTransform(context),
MarkClosureVisitor(context),
_align_function_definitions,
RemoveUnreachableCode(context),
ConstantFolding(),
FlattenInListTransform(),
WithTransform(context),
DecoratorTransform(context),
ForwardDeclareTypes(context),
AnalyseDeclarationsTransform(context),
......
# cython: language_level=3, binding=True
import re
import gc
import sys
import types
......@@ -89,6 +90,10 @@ class CoroutineTest(unittest.TestCase):
else:
self.assertTrue(False)
def assertRegex(self, value, regex):
self.assertTrue(re.search(regex, str(value)),
"'%s' did not match '%s'" % (value, regex))
def test_gen_1(self):
def gen(): yield
self.assertFalse(hasattr(gen, '__await__'))
......@@ -508,8 +513,9 @@ class CoroutineTest(unittest.TestCase):
except TypeError as exc:
self.assertRegex(
exc.args[0], "object int can't be used in 'await' expression")
self.assertTrue(exc.__context__ is not None)
self.assertTrue(isinstance(exc.__context__, ZeroDivisionError))
if sys.version_info[0] >= 3:
self.assertTrue(exc.__context__ is not None)
self.assertTrue(isinstance(exc.__context__, ZeroDivisionError))
else:
self.fail('invalid asynchronous context manager did not fail')
......@@ -577,10 +583,10 @@ class CoroutineTest(unittest.TestCase):
try:
run_async(foo())
except ZeroDivisionError as exc:
self.assertTrue(exc.__context__ is not None)
self.assertTrue(isinstance(exc.__context__, ZeroDivisionError))
self.assertTrue(isinstance(exc.__context__.__context__,
RuntimeError))
if sys.version_info[0] >= 3:
self.assertTrue(exc.__context__ is not None)
self.assertTrue(isinstance(exc.__context__, ZeroDivisionError))
self.assertTrue(isinstance(exc.__context__.__context__, RuntimeError))
else:
self.fail('exception from __aexit__ did not propagate')
......@@ -602,7 +608,8 @@ class CoroutineTest(unittest.TestCase):
try:
run_async(foo())
except NotImplementedError as exc:
self.assertTrue(exc.__context__ is None)
if sys.version_info[0] >= 3:
self.assertTrue(exc.__context__ is None)
else:
self.fail('exception from __aenter__ did not propagate')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment