Commit d78673fc authored by Stefan Behnel's avatar Stefan Behnel

implement "async with" (PEP 492)

parent fd0e47f5
...@@ -2602,12 +2602,16 @@ class WithExitCallNode(ExprNode): ...@@ -2602,12 +2602,16 @@ class WithExitCallNode(ExprNode):
# with_stat WithStatNode the surrounding 'with' statement # with_stat WithStatNode the surrounding 'with' statement
# args TupleNode or ResultStatNode the exception info tuple # args TupleNode or ResultStatNode the exception info tuple
# await AwaitExprNode the await
subexprs = ['args'] subexprs = ['args', 'await']
test_if_run = True test_if_run = True
await = None
def analyse_types(self, env): def analyse_types(self, env):
self.args = self.args.analyse_types(env) self.args = self.args.analyse_types(env)
if self.await:
self.await = self.await.analyse_types(env)
self.type = PyrexTypes.c_bint_type self.type = PyrexTypes.c_bint_type
self.is_temp = True self.is_temp = True
return self return self
...@@ -2633,6 +2637,13 @@ class WithExitCallNode(ExprNode): ...@@ -2633,6 +2637,13 @@ class WithExitCallNode(ExprNode):
code.putln(code.error_goto_if_null(result_var, self.pos)) code.putln(code.error_goto_if_null(result_var, self.pos))
code.put_gotref(result_var) code.put_gotref(result_var)
if self.await:
self.await.generate_evaluation_code(code, source_cname=result_var)
code.putln("%s = %s;" % (result_var, self.await.py_result()))
self.await.generate_post_assignment_code(code)
self.await.free_temps(code)
if self.result_is_used: if self.result_is_used:
self.allocate_temp_result(code) self.allocate_temp_result(code)
code.putln("%s = __Pyx_PyObject_IsTrue(%s);" % (self.result(), result_var)) code.putln("%s = __Pyx_PyObject_IsTrue(%s);" % (self.result(), result_var))
...@@ -8675,15 +8686,19 @@ class YieldFromExprNode(YieldExprNode): ...@@ -8675,15 +8686,19 @@ class YieldFromExprNode(YieldExprNode):
code.globalstate.use_utility_code(UtilityCode.load_cached("GeneratorYieldFrom", "Coroutine.c")) code.globalstate.use_utility_code(UtilityCode.load_cached("GeneratorYieldFrom", "Coroutine.c"))
return "__Pyx_Generator_Yield_From" return "__Pyx_Generator_Yield_From"
def generate_evaluation_code(self, code): def generate_evaluation_code(self, code, source_cname=None):
if source_cname is None:
self.arg.generate_evaluation_code(code) self.arg.generate_evaluation_code(code)
code.putln("%s = %s(%s, %s);" % ( code.putln("%s = %s(%s, %s);" % (
Naming.retval_cname, Naming.retval_cname,
self.yield_from_func(code), self.yield_from_func(code),
Naming.generator_cname, Naming.generator_cname,
self.arg.py_result())) self.arg.py_result() if source_cname is None else source_cname))
if source_cname is None:
self.arg.generate_disposal_code(code) self.arg.generate_disposal_code(code)
self.arg.free_temps(code) self.arg.free_temps(code)
else:
code.put_decref_clear(source_cname, py_object_type)
code.put_xgotref(Naming.retval_cname) code.put_xgotref(Naming.retval_cname)
code.putln("if (likely(%s)) {" % Naming.retval_cname) code.putln("if (likely(%s)) {" % Naming.retval_cname)
...@@ -8715,6 +8730,7 @@ class AwaitExprNode(YieldFromExprNode): ...@@ -8715,6 +8730,7 @@ class AwaitExprNode(YieldFromExprNode):
expr_keyword = 'await' expr_keyword = 'await'
def coerce_yield_argument(self, env): def coerce_yield_argument(self, env):
if self.arg is not None:
# FIXME: use same check as in YieldFromExprNode.coerce_yield_argument() ? # FIXME: use same check as in YieldFromExprNode.coerce_yield_argument() ?
self.arg = self.arg.coerce_to_pyobject(env) self.arg = self.arg.coerce_to_pyobject(env)
......
...@@ -6454,7 +6454,7 @@ class WithStatNode(StatNode): ...@@ -6454,7 +6454,7 @@ class WithStatNode(StatNode):
code.putln("%s = __Pyx_PyObject_LookupSpecial(%s, %s); %s" % ( code.putln("%s = __Pyx_PyObject_LookupSpecial(%s, %s); %s" % (
self.exit_var, self.exit_var,
self.manager.py_result(), self.manager.py_result(),
code.intern_identifier(EncodedString('__exit__')), code.intern_identifier(EncodedString('__aexit__' if self.is_async else '__exit__')),
code.error_goto_if_null(self.exit_var, self.pos), code.error_goto_if_null(self.exit_var, self.pos),
)) ))
code.put_gotref(self.exit_var) code.put_gotref(self.exit_var)
......
...@@ -1219,15 +1219,19 @@ class WithTransform(CythonTransform, SkipDeclarations): ...@@ -1219,15 +1219,19 @@ class WithTransform(CythonTransform, SkipDeclarations):
def visit_WithStatNode(self, node): def visit_WithStatNode(self, node):
self.visitchildren(node, 'body') self.visitchildren(node, 'body')
pos = node.pos pos = node.pos
is_async = node.is_async
body, target, manager = node.body, node.target, node.manager body, target, manager = node.body, node.target, node.manager
node.enter_call = ExprNodes.SimpleCallNode( node.enter_call = ExprNodes.SimpleCallNode(
pos, function=ExprNodes.AttributeNode( pos, function=ExprNodes.AttributeNode(
pos, obj=ExprNodes.CloneNode(manager), pos, obj=ExprNodes.CloneNode(manager),
attribute=EncodedString('__enter__'), attribute=EncodedString('__aenter__' if is_async else '__enter__'),
is_special_lookup=True), is_special_lookup=True),
args=[], args=[],
is_temp=True) is_temp=True)
if is_async:
node.enter_call = ExprNodes.AwaitExprNode(pos, arg=node.enter_call)
if target is not None: if target is not None:
body = Nodes.StatListNode( body = Nodes.StatListNode(
pos, stats=[ pos, stats=[
...@@ -1245,7 +1249,8 @@ class WithTransform(CythonTransform, SkipDeclarations): ...@@ -1245,7 +1249,8 @@ class WithTransform(CythonTransform, SkipDeclarations):
pos, operand=ExprNodes.WithExitCallNode( pos, operand=ExprNodes.WithExitCallNode(
pos, with_stat=node, pos, with_stat=node,
test_if_run=False, test_if_run=False,
args=excinfo_target)), args=excinfo_target,
await=ExprNodes.AwaitExprNode(pos, arg=None) if is_async else None)),
body=Nodes.ReraiseStatNode(pos), body=Nodes.ReraiseStatNode(pos),
), ),
], ],
...@@ -1266,8 +1271,8 @@ class WithTransform(CythonTransform, SkipDeclarations): ...@@ -1266,8 +1271,8 @@ class WithTransform(CythonTransform, SkipDeclarations):
pos, with_stat=node, pos, with_stat=node,
test_if_run=True, test_if_run=True,
args=ExprNodes.TupleNode( args=ExprNodes.TupleNode(
pos, args=[ExprNodes.NoneNode(pos) for _ in range(3)] pos, args=[ExprNodes.NoneNode(pos) for _ in range(3)]),
))), await=ExprNodes.AwaitExprNode(pos, arg=None) if is_async else None)),
handle_error_case=False, handle_error_case=False,
) )
return node return node
......
...@@ -120,7 +120,7 @@ cdef p_try_statement(PyrexScanner s) ...@@ -120,7 +120,7 @@ cdef p_try_statement(PyrexScanner s)
cdef p_except_clause(PyrexScanner s) cdef p_except_clause(PyrexScanner s)
cdef p_include_statement(PyrexScanner s, ctx) cdef p_include_statement(PyrexScanner s, ctx)
cdef p_with_statement(PyrexScanner s) cdef p_with_statement(PyrexScanner s)
cdef p_with_items(PyrexScanner s) cdef p_with_items(PyrexScanner s, bint is_async=*)
cdef p_with_template(PyrexScanner s) cdef p_with_template(PyrexScanner s)
cdef p_simple_statement(PyrexScanner s, bint first_statement = *) cdef p_simple_statement(PyrexScanner s, bint first_statement = *)
cdef p_simple_statement_list(PyrexScanner s, ctx, bint first_statement = *) cdef p_simple_statement_list(PyrexScanner s, ctx, bint first_statement = *)
......
...@@ -393,8 +393,8 @@ def p_async_statement(s, ctx, decorators): ...@@ -393,8 +393,8 @@ def p_async_statement(s, ctx, decorators):
#s.error("'async for' is not currently supported", fatal=False) #s.error("'async for' is not currently supported", fatal=False)
return p_statement(s, ctx) # TODO: implement return p_statement(s, ctx) # TODO: implement
elif s.sy == 'with': elif s.sy == 'with':
#s.error("'async with' is not currently supported", fatal=False) s.next()
return p_statement(s, ctx) # TODO: implement return p_with_items(s, is_async=True)
else: else:
s.error("expected one of 'def', 'for', 'with' after 'async'") s.error("expected one of 'def', 'for', 'with' after 'async'")
...@@ -1781,6 +1781,7 @@ def p_include_statement(s, ctx): ...@@ -1781,6 +1781,7 @@ def p_include_statement(s, ctx):
else: else:
return Nodes.PassStatNode(pos) return Nodes.PassStatNode(pos)
def p_with_statement(s): def p_with_statement(s):
s.next() # 'with' s.next() # 'with'
if s.systring == 'template' and not s.in_python_file: if s.systring == 'template' and not s.in_python_file:
...@@ -1789,9 +1790,12 @@ def p_with_statement(s): ...@@ -1789,9 +1790,12 @@ def p_with_statement(s):
node = p_with_items(s) node = p_with_items(s)
return node return node
def p_with_items(s):
def p_with_items(s, is_async=False):
pos = s.position() pos = s.position()
if not s.in_python_file and s.sy == 'IDENT' and s.systring in ('nogil', 'gil'): if not s.in_python_file and s.sy == 'IDENT' and s.systring in ('nogil', 'gil'):
if is_async:
s.error("with gil/nogil cannot be async")
state = s.systring state = s.systring
s.next() s.next()
if s.sy == ',': if s.sy == ',':
...@@ -1799,7 +1803,7 @@ def p_with_items(s): ...@@ -1799,7 +1803,7 @@ def p_with_items(s):
body = p_with_items(s) body = p_with_items(s)
else: else:
body = p_suite(s) body = p_suite(s)
return Nodes.GILStatNode(pos, state = state, body = body) return Nodes.GILStatNode(pos, state=state, body=body)
else: else:
manager = p_test(s) manager = p_test(s)
target = None target = None
...@@ -1808,11 +1812,11 @@ def p_with_items(s): ...@@ -1808,11 +1812,11 @@ def p_with_items(s):
target = p_starred_expr(s) target = p_starred_expr(s)
if s.sy == ',': if s.sy == ',':
s.next() s.next()
body = p_with_items(s) body = p_with_items(s, is_async=is_async)
else: else:
body = p_suite(s) body = p_suite(s)
return Nodes.WithStatNode(pos, manager = manager, return Nodes.WithStatNode(pos, manager=manager, target=target, body=body, is_async=is_async)
target = target, body = body)
def p_with_template(s): def p_with_template(s):
pos = s.position() pos = s.position()
......
...@@ -172,12 +172,12 @@ def create_pipeline(context, mode, exclude_classes=()): ...@@ -172,12 +172,12 @@ def create_pipeline(context, mode, exclude_classes=()):
InterpretCompilerDirectives(context, context.compiler_directives), InterpretCompilerDirectives(context, context.compiler_directives),
ParallelRangeTransform(context), ParallelRangeTransform(context),
AdjustDefByDirectives(context), AdjustDefByDirectives(context),
WithTransform(context),
MarkClosureVisitor(context), MarkClosureVisitor(context),
_align_function_definitions, _align_function_definitions,
RemoveUnreachableCode(context), RemoveUnreachableCode(context),
ConstantFolding(), ConstantFolding(),
FlattenInListTransform(), FlattenInListTransform(),
WithTransform(context),
DecoratorTransform(context), DecoratorTransform(context),
ForwardDeclareTypes(context), ForwardDeclareTypes(context),
AnalyseDeclarationsTransform(context), AnalyseDeclarationsTransform(context),
......
# cython: language_level=3, binding=True # cython: language_level=3, binding=True
import re
import gc import gc
import sys import sys
import types import types
...@@ -89,6 +90,10 @@ class CoroutineTest(unittest.TestCase): ...@@ -89,6 +90,10 @@ class CoroutineTest(unittest.TestCase):
else: else:
self.assertTrue(False) self.assertTrue(False)
def assertRegex(self, value, regex):
self.assertTrue(re.search(regex, str(value)),
"'%s' did not match '%s'" % (value, regex))
def test_gen_1(self): def test_gen_1(self):
def gen(): yield def gen(): yield
self.assertFalse(hasattr(gen, '__await__')) self.assertFalse(hasattr(gen, '__await__'))
...@@ -508,6 +513,7 @@ class CoroutineTest(unittest.TestCase): ...@@ -508,6 +513,7 @@ class CoroutineTest(unittest.TestCase):
except TypeError as exc: except TypeError as exc:
self.assertRegex( self.assertRegex(
exc.args[0], "object int can't be used in 'await' expression") exc.args[0], "object int can't be used in 'await' expression")
if sys.version_info[0] >= 3:
self.assertTrue(exc.__context__ is not None) self.assertTrue(exc.__context__ is not None)
self.assertTrue(isinstance(exc.__context__, ZeroDivisionError)) self.assertTrue(isinstance(exc.__context__, ZeroDivisionError))
else: else:
...@@ -577,10 +583,10 @@ class CoroutineTest(unittest.TestCase): ...@@ -577,10 +583,10 @@ class CoroutineTest(unittest.TestCase):
try: try:
run_async(foo()) run_async(foo())
except ZeroDivisionError as exc: except ZeroDivisionError as exc:
if sys.version_info[0] >= 3:
self.assertTrue(exc.__context__ is not None) self.assertTrue(exc.__context__ is not None)
self.assertTrue(isinstance(exc.__context__, ZeroDivisionError)) self.assertTrue(isinstance(exc.__context__, ZeroDivisionError))
self.assertTrue(isinstance(exc.__context__.__context__, self.assertTrue(isinstance(exc.__context__.__context__, RuntimeError))
RuntimeError))
else: else:
self.fail('exception from __aexit__ did not propagate') self.fail('exception from __aexit__ did not propagate')
...@@ -602,6 +608,7 @@ class CoroutineTest(unittest.TestCase): ...@@ -602,6 +608,7 @@ class CoroutineTest(unittest.TestCase):
try: try:
run_async(foo()) run_async(foo())
except NotImplementedError as exc: except NotImplementedError as exc:
if sys.version_info[0] >= 3:
self.assertTrue(exc.__context__ is None) self.assertTrue(exc.__context__ is None)
else: else:
self.fail('exception from __aenter__ did not propagate') self.fail('exception from __aenter__ did not propagate')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment