Commit 7696a805 authored by Stefan Behnel's avatar Stefan Behnel

adapt async/await implementation to CPython ticket 24619

http://bugs.python.org/issue24619
parent d8be3ad8
...@@ -3057,6 +3057,9 @@ def p_decorators(s): ...@@ -3057,6 +3057,9 @@ def p_decorators(s):
def p_def_statement(s, decorators=None, is_async_def=False): def p_def_statement(s, decorators=None, is_async_def=False):
# s.sy == 'def' # s.sy == 'def'
pos = s.position() pos = s.position()
# PEP 492 switches the async/await keywords on in "async def" functions
if is_async_def:
s.enter_async()
s.next() s.next()
name = p_ident(s) name = p_ident(s)
s.expect('(') s.expect('(')
...@@ -3069,23 +3072,9 @@ def p_def_statement(s, decorators=None, is_async_def=False): ...@@ -3069,23 +3072,9 @@ def p_def_statement(s, decorators=None, is_async_def=False):
s.next() s.next()
return_type_annotation = p_test(s) return_type_annotation = p_test(s)
# PEP 492 switches the async/await keywords off in simple "def" functions
# and on in "async def" functions
await_was_enabled = s.enable_keyword('await') if is_async_def else s.disable_keyword('await')
async_was_enabled = s.enable_keyword('async') if is_async_def else s.disable_keyword('async')
doc, body = p_suite_with_docstring(s, Ctx(level='function')) doc, body = p_suite_with_docstring(s, Ctx(level='function'))
if is_async_def: if is_async_def:
if not async_was_enabled: s.exit_async()
s.disable_keyword('async')
if not await_was_enabled:
s.disable_keyword('await')
else:
if async_was_enabled:
s.enable_keyword('async')
if await_was_enabled:
s.enable_keyword('await')
return Nodes.DefNode( return Nodes.DefNode(
pos, name=name, args=args, star_arg=star_arg, starstar_arg=starstar_arg, pos, name=name, args=args, star_arg=star_arg, starstar_arg=starstar_arg,
......
...@@ -33,6 +33,7 @@ cdef class PyrexScanner(Scanner): ...@@ -33,6 +33,7 @@ cdef class PyrexScanner(Scanner):
cdef public list indentation_stack cdef public list indentation_stack
cdef public indentation_char cdef public indentation_char
cdef public int bracket_nesting_level cdef public int bracket_nesting_level
cdef bint async_enabled
cdef public sy cdef public sy
cdef public systring cdef public systring
...@@ -57,5 +58,5 @@ cdef class PyrexScanner(Scanner): ...@@ -57,5 +58,5 @@ cdef class PyrexScanner(Scanner):
cdef expect_indent(self) cdef expect_indent(self)
cdef expect_dedent(self) cdef expect_dedent(self)
cdef expect_newline(self, message=*, bint ignore_semicolon=*) cdef expect_newline(self, message=*, bint ignore_semicolon=*)
cdef bint enable_keyword(self, name) except -1 cdef int enter_async(self) except -1
cdef bint disable_keyword(self, name) except -1 cdef int exit_async(self) except -1
...@@ -317,6 +317,7 @@ class PyrexScanner(Scanner): ...@@ -317,6 +317,7 @@ class PyrexScanner(Scanner):
self.indentation_stack = [0] self.indentation_stack = [0]
self.indentation_char = None self.indentation_char = None
self.bracket_nesting_level = 0 self.bracket_nesting_level = 0
self.async_enabled = False
self.begin('INDENT') self.begin('INDENT')
self.sy = '' self.sy = ''
self.next() self.next()
...@@ -493,14 +494,15 @@ class PyrexScanner(Scanner): ...@@ -493,14 +494,15 @@ class PyrexScanner(Scanner):
if useless_trailing_semicolon is not None: if useless_trailing_semicolon is not None:
warning(useless_trailing_semicolon, "useless trailing semicolon") warning(useless_trailing_semicolon, "useless trailing semicolon")
def enable_keyword(self, name): def enter_async(self):
if name in self.keywords: self.async_enabled += 1
return True # was enabled before if self.async_enabled == 1:
self.keywords.add(name) self.keywords.add('async')
return False # was not enabled before self.keywords.add('await')
def disable_keyword(self, name): def exit_async(self):
if name not in self.keywords: assert self.async_enabled > 0
return False # was not enabled before self.async_enabled -= 1
self.keywords.remove(name) if not self.async_enabled:
return True # was enabled before self.keywords.discard('await')
self.keywords.discard('async')
...@@ -12,6 +12,8 @@ import unittest ...@@ -12,6 +12,8 @@ import unittest
import warnings import warnings
import contextlib import contextlib
from Cython.Compiler import Errors
try: try:
from types import coroutine as types_coroutine from types import coroutine as types_coroutine
...@@ -116,6 +118,258 @@ def silence_coro_gc(): ...@@ -116,6 +118,258 @@ def silence_coro_gc():
gc.collect() gc.collect()
class AsyncBadSyntaxTest(unittest.TestCase):
@contextlib.contextmanager
def assertRaisesRegex(self, exc_type, regex):
# the error messages usually don't match, so we just ignore them
try:
yield
except exc_type:
self.assertTrue(True)
else:
self.assertTrue(False)
def test_badsyntax_9(self):
ns = {}
for comp in {'(await a for a in b)',
'[await a for a in b]',
'{await a for a in b}',
'{await a: a for a in b}'}:
with self.assertRaisesRegex(Errors.CompileError, 'await.*in comprehen'):
exec('async def f():\n\t{}'.format(comp), ns, ns)
def test_badsyntax_10(self):
# Tests for issue 24619
samples = [
"""async def foo():
def bar(): pass
await = 1
""",
"""async def foo():
def bar(): pass
await = 1
""",
"""async def foo():
def bar(): pass
if 1:
await = 1
""",
"""def foo():
async def bar(): pass
if 1:
await a
""",
"""def foo():
async def bar(): pass
await a
""",
"""def foo():
def baz(): pass
async def bar(): pass
await a
""",
"""def foo():
def baz(): pass
# 456
async def bar(): pass
# 123
await a
""",
"""async def foo():
def baz(): pass
# 456
async def bar(): pass
# 123
await = 2
""",
"""def foo():
def baz(): pass
async def bar(): pass
await a
""",
"""async def foo():
def baz(): pass
async def bar(): pass
await = 2
""",
"""async def foo():
def async(): pass
""",
"""async def foo():
def await(): pass
""",
"""async def foo():
def bar():
await
""",
"""async def foo():
return lambda async: await
""",
"""async def foo():
return lambda a: await
""",
"""await a()""",
"""async def foo(a=await b):
pass
""",
"""async def foo(a:await b):
pass
""",
"""def baz():
async def foo(a=await b):
pass
""",
"""async def foo(async):
pass
""",
"""async def foo():
def bar():
def baz():
async = 1
""",
"""async def foo():
def bar():
def baz():
pass
async = 1
""",
"""def foo():
async def bar():
async def baz():
pass
def baz():
42
async = 1
""",
"""async def foo():
def bar():
def baz():
pass\nawait foo()
""",
"""def foo():
def bar():
async def baz():
pass\nawait foo()
""",
"""async def foo(await):
pass
""",
"""def foo():
async def bar(): pass
await a
""",
"""def foo():
async def bar():
pass\nawait a
"""]
for code in samples:
with self.subTest(code=code), self.assertRaises(Errors.CompileError):
exec(code, {}, {})
if not hasattr(unittest.TestCase, 'subTest'):
@contextlib.contextmanager
def subTest(self, code, **kwargs):
try:
yield
except Exception:
print(code)
raise
def test_goodsyntax_1(self):
# Tests for issue 24619
def foo(await):
async def foo(): pass
async def foo():
pass
return await + 1
self.assertEqual(foo(10), 11)
def foo(await):
async def foo(): pass
async def foo(): pass
return await + 2
self.assertEqual(foo(20), 22)
def foo(await):
async def foo(): pass
async def foo(): pass
return await + 2
self.assertEqual(foo(20), 22)
def foo(await):
"""spam"""
async def foo(): \
pass
# 123
async def foo(): pass
# 456
return await + 2
self.assertEqual(foo(20), 22)
def foo(await):
def foo(): pass
def foo(): pass
async def bar(): return await_
await_ = await
try:
bar().send(None)
except StopIteration as ex:
return ex.args[0]
self.assertEqual(foo(42), 42)
async def f(z):
async def g(): pass
await z
#self.assertTrue(inspect.iscoroutinefunction(f))
class TokenizerRegrTest(unittest.TestCase): class TokenizerRegrTest(unittest.TestCase):
def test_oneline_defs(self): def test_oneline_defs(self):
...@@ -138,17 +392,6 @@ class TokenizerRegrTest(unittest.TestCase): ...@@ -138,17 +392,6 @@ class TokenizerRegrTest(unittest.TestCase):
self.assertEqual(type(ns['foo']()).__name__, 'coroutine') self.assertEqual(type(ns['foo']()).__name__, 'coroutine')
#self.assertTrue(inspect.iscoroutinefunction(ns['foo'])) #self.assertTrue(inspect.iscoroutinefunction(ns['foo']))
def test_syntax_async_await_as_names(self):
async def enable():
await 123
def disable():
await = 123
async = 'abc'
async def reenable():
await 432
class CoroutineTest(unittest.TestCase): class CoroutineTest(unittest.TestCase):
...@@ -511,8 +754,7 @@ class CoroutineTest(unittest.TestCase): ...@@ -511,8 +754,7 @@ class CoroutineTest(unittest.TestCase):
class Awaitable: class Awaitable:
pass pass
async def foo(): async def foo(): return (await Awaitable())
return (await Awaitable())
with self.assertRaisesRegex( with self.assertRaisesRegex(
TypeError, "object Awaitable can't be used in 'await' expression"): TypeError, "object Awaitable can't be used in 'await' expression"):
...@@ -599,6 +841,39 @@ class CoroutineTest(unittest.TestCase): ...@@ -599,6 +841,39 @@ class CoroutineTest(unittest.TestCase):
run_async(foo()) run_async(foo())
def test_await_14(self):
class Wrapper:
# Forces the interpreter to use CoroutineType.__await__
def __init__(self, coro):
self.coro = coro
def __await__(self):
return self.coro.__await__()
class FutureLike:
def __await__(self):
return (yield)
class Marker(Exception):
pass
async def coro1():
try:
return await FutureLike()
except ZeroDivisionError:
raise Marker
async def coro2():
return await Wrapper(coro1())
c = coro2()
c.send(None)
with self.assertRaisesRegex(StopIteration, 'spam'):
c.send('spam')
c = coro2()
c.send(None)
with self.assertRaises(Marker):
c.throw(ZeroDivisionError)
def test_await_iterator(self): def test_await_iterator(self):
async def foo(): async def foo():
return 123 return 123
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment