Commit 003b2cfd authored by Stefan Behnel's avatar Stefan Behnel

integrate new Py3.6+ couroutine tests and repair misbehaviour in error cases...

integrate new Py3.6+ couroutine tests and repair misbehaviour in error cases and when async-iterating over awaitable tuple/StopIteration subtypes (however rare that is)
parent 6c5fab7c
...@@ -56,14 +56,14 @@ static int __Pyx_WarnAIterDeprecation(CYTHON_UNUSED PyObject *aiter) { ...@@ -56,14 +56,14 @@ static int __Pyx_WarnAIterDeprecation(CYTHON_UNUSED PyObject *aiter) {
int result; int result;
#if PY_MAJOR_VERSION >= 3 #if PY_MAJOR_VERSION >= 3
result = PyErr_WarnFormat( result = PyErr_WarnFormat(
PyExc_PendingDeprecationWarning, 1, PyExc_DeprecationWarning, 1,
"'%.100s' implements legacy __aiter__ protocol; " "'%.100s' implements legacy __aiter__ protocol; "
"__aiter__ should return an asynchronous " "__aiter__ should return an asynchronous "
"iterator, not awaitable", "iterator, not awaitable",
Py_TYPE(aiter)->tp_name); Py_TYPE(aiter)->tp_name);
#else #else
result = PyErr_WarnEx( result = PyErr_WarnEx(
PyExc_PendingDeprecationWarning, PyExc_DeprecationWarning,
"object implements legacy __aiter__ protocol; " "object implements legacy __aiter__ protocol; "
"__aiter__ should return an asynchronous " "__aiter__ should return an asynchronous "
"iterator, not awaitable", "iterator, not awaitable",
...@@ -72,6 +72,69 @@ static int __Pyx_WarnAIterDeprecation(CYTHON_UNUSED PyObject *aiter) { ...@@ -72,6 +72,69 @@ static int __Pyx_WarnAIterDeprecation(CYTHON_UNUSED PyObject *aiter) {
return result != 0; return result != 0;
} }
static void __Pyx__Coroutine_Yield_From_Error(PyObject *source) {
#if PY_VERSION_HEX >= 0x03060000 || defined(_PyErr_FormatFromCause)
_PyErr_FormatFromCause(
PyExc_TypeError,
"'async for' received an invalid object "
"from __anext__: %.100s",
Py_TYPE(source)->tp_name);
#elif PY_MAJOR_VERSION >= 3
PyObject *exc, *val, *val2, *tb;
assert(PyErr_Occurred());
PyErr_Fetch(&exc, &val, &tb);
PyErr_NormalizeException(&exc, &val, &tb);
if (tb != NULL) {
PyException_SetTraceback(val, tb);
Py_DECREF(tb);
}
Py_DECREF(exc);
assert(!PyErr_Occurred());
PyErr_Format(
PyExc_TypeError,
"'async for' received an invalid object "
"from __anext__: %.100s",
Py_TYPE(source)->tp_name);
PyErr_Fetch(&exc, &val2, &tb);
PyErr_NormalizeException(&exc, &val2, &tb);
Py_INCREF(val);
PyException_SetCause(val2, val);
PyException_SetContext(val2, val);
PyErr_Restore(exc, val2, tb);
#else
// since Py2 does not have exception chaining, it's better to avoid shadowing exceptions there
source++;
#endif
}
static PyObject* __Pyx__Coroutine_Yield_From_Generic(__pyx_CoroutineObject *gen, PyObject *source, int warn) {
PyObject *retval;
PyObject *source_gen = __Pyx__Coroutine_GetAwaitableIter(source);
if (unlikely(!source_gen)) {
// surprisingly, CPython replaces the exception here...
__Pyx__Coroutine_Yield_From_Error(source);
return NULL;
}
if (warn && unlikely(__Pyx_WarnAIterDeprecation(source))) {
/* Warning was converted to an error. */
Py_DECREF(source_gen);
return NULL;
}
// source_gen is now the iterator, make the first next() call
if (__Pyx_Coroutine_CheckExact(source_gen)) {
retval = __Pyx_Generator_Next(source_gen);
} else {
retval = Py_TYPE(source_gen)->tp_iternext(source_gen);
}
if (retval) {
gen->yieldfrom = source_gen;
return retval;
}
Py_DECREF(source_gen);
return NULL;
}
static CYTHON_INLINE PyObject* __Pyx__Coroutine_Yield_From(__pyx_CoroutineObject *gen, PyObject *source, int warn) { static CYTHON_INLINE PyObject* __Pyx__Coroutine_Yield_From(__pyx_CoroutineObject *gen, PyObject *source, int warn) {
PyObject *retval; PyObject *retval;
if (__Pyx_Coroutine_CheckExact(source)) { if (__Pyx_Coroutine_CheckExact(source)) {
...@@ -96,25 +159,7 @@ static CYTHON_INLINE PyObject* __Pyx__Coroutine_Yield_From(__pyx_CoroutineObject ...@@ -96,25 +159,7 @@ static CYTHON_INLINE PyObject* __Pyx__Coroutine_Yield_From(__pyx_CoroutineObject
} }
#endif #endif
} else { } else {
PyObject *source_gen = __Pyx__Coroutine_GetAwaitableIter(source); return __Pyx__Coroutine_Yield_From_Generic(gen, source, warn);
if (unlikely(!source_gen))
return NULL;
if (warn && unlikely(__Pyx_WarnAIterDeprecation(source))) {
/* Warning was converted to an error. */
Py_DECREF(source_gen);
return NULL;
}
// source_gen is now the iterator, make the first next() call
if (__Pyx_Coroutine_CheckExact(source_gen)) {
retval = __Pyx_Generator_Next(source_gen);
} else {
retval = Py_TYPE(source_gen)->tp_iternext(source_gen);
}
if (retval) {
gen->yieldfrom = source_gen;
return retval;
}
Py_DECREF(source_gen);
} }
return NULL; return NULL;
} }
...@@ -142,8 +187,7 @@ static CYTHON_INLINE PyObject* __Pyx_Coroutine_AIter_Yield_From(__pyx_CoroutineO ...@@ -142,8 +187,7 @@ static CYTHON_INLINE PyObject* __Pyx_Coroutine_AIter_Yield_From(__pyx_CoroutineO
// //
// See http://bugs.python.org/issue27243 for more // See http://bugs.python.org/issue27243 for more
// details. // details.
PyErr_SetObject(PyExc_StopIteration, source); goto store_result;
return NULL;
} }
#endif #endif
#if PY_VERSION_HEX < 0x030500B2 #if PY_VERSION_HEX < 0x030500B2
...@@ -156,14 +200,24 @@ static CYTHON_INLINE PyObject* __Pyx_Coroutine_AIter_Yield_From(__pyx_CoroutineO ...@@ -156,14 +200,24 @@ static CYTHON_INLINE PyObject* __Pyx_Coroutine_AIter_Yield_From(__pyx_CoroutineO
PyObject *method = __Pyx_PyObject_GetAttrStr(source, PYIDENT("__anext__")); PyObject *method = __Pyx_PyObject_GetAttrStr(source, PYIDENT("__anext__"));
if (method) { if (method) {
Py_DECREF(method); Py_DECREF(method);
PyErr_SetObject(PyExc_StopIteration, source); goto store_result;
return NULL;
} }
PyErr_Clear(); PyErr_Clear();
} }
} }
#endif #endif
return __Pyx__Coroutine_Yield_From(gen, source, 1); return __Pyx__Coroutine_Yield_From(gen, source, 1);
store_result:
if (unlikely(PyTuple_Check(source) || __Pyx_TypeCheck(source, (PyTypeObject*)PyExc_StopIteration))) {
PyObject *t = PyTuple_Pack(1, source);
if (unlikely(!t)) return NULL;
PyErr_SetObject(PyExc_StopIteration, t);
Py_DECREF(t);
} else {
PyErr_SetObject(PyExc_StopIteration, source);
}
return NULL;
} }
......
# cython: language_level=3, binding=True # cython: language_level=3, binding=True
# mode: run # mode: run
# tag: pep492, asyncfor, await # tag: pep492, pep530, asyncfor, await
import re import re
import gc import gc
...@@ -20,7 +20,7 @@ from Cython.Compiler import Errors ...@@ -20,7 +20,7 @@ from Cython.Compiler import Errors
try: try:
from types import coroutine as types_coroutine from types import coroutine as types_coroutine
except ImportError: except ImportError:
# duck typed types.coroutine() decorator copied from types.py in Py3.5 # duck typed types_coroutine() decorator copied from types.py in Py3.5
class types_coroutine(object): class types_coroutine(object):
def __init__(self, gen): def __init__(self, gen):
self._gen = gen self._gen = gen
...@@ -141,27 +141,375 @@ def ignore_py26(manager): ...@@ -141,27 +141,375 @@ def ignore_py26(manager):
return dummy() if sys.version_info < (2, 7) else manager return dummy() if sys.version_info < (2, 7) else manager
@contextlib.contextmanager
def captured_stderr():
try:
# StringIO.StringIO() also accepts str in Py2, io.StringIO() does not
from StringIO import StringIO
except ImportError:
from io import StringIO
orig_stderr = sys.stderr
try:
sys.stderr = StringIO()
yield sys.stderr
finally:
sys.stderr = orig_stderr
class AsyncBadSyntaxTest(unittest.TestCase): class AsyncBadSyntaxTest(unittest.TestCase):
@contextlib.contextmanager @contextlib.contextmanager
def assertRaisesRegex(self, exc_type, regex): def assertRaisesRegex(self, exc_type, regex):
class Holder(object):
exception = None
holder = Holder()
# the error messages usually don't match, so we just ignore them # the error messages usually don't match, so we just ignore them
try: try:
yield yield holder
except exc_type: except exc_type as exc:
holder.exception = exc
self.assertTrue(True) self.assertTrue(True)
else: else:
self.assertTrue(False) self.assertTrue(False)
def test_badsyntax_9(self): def test_badsyntax_1(self):
ns = {} samples = [
for comp in {'(await a for a in b)', """def foo():
'[await a for a in b]', await something()
'{await a for a in b}', """,
'{await a: a for a in b}'}:
"""await something()""",
"""async def foo():
yield from []
""",
"""async def foo():
await await fut
""",
"""async def foo(a=await something()):
pass
""",
"""async def foo(a:await something()):
pass
""",
"""async def foo():
def bar():
[i async for i in els]
""",
"""async def foo():
def bar():
[await i for i in els]
""",
"""async def foo():
def bar():
[i for i in els
async for b in els]
""",
"""async def foo():
def bar():
[i for i in els
for c in b
async for b in els]
""",
"""async def foo():
def bar():
[i for i in els
async for b in els
for c in b]
""",
"""async def foo():
def bar():
[i for i in els
for b in await els]
""",
"""async def foo():
def bar():
[i for i in els
for b in els
if await b]
""",
"""async def foo():
def bar():
[i for i in await els]
""",
"""async def foo():
def bar():
[i for i in els if await i]
""",
"""def bar():
[i async for i in els]
""",
"""def bar():
[await i for i in els]
""",
"""def bar():
[i for i in els
async for b in els]
""",
"""def bar():
[i for i in els
for c in b
async for b in els]
""",
"""def bar():
[i for i in els
async for b in els
for c in b]
""",
"""def bar():
[i for i in els
for b in await els]
""",
"""def bar():
[i for i in els
for b in els
if await b]
""",
"""def bar():
[i for i in await els]
""",
"""def bar():
[i for i in els if await i]
""",
"""async def foo():
await
""",
"""async def foo():
def bar(): pass
await = 1
""",
"""async def foo():
def bar(): pass
await = 1
""",
"""async def foo():
def bar(): pass
if 1:
await = 1
""",
"""def foo():
async def bar(): pass
if 1:
await a
""",
"""def foo():
async def bar(): pass
await a
""",
"""def foo():
def baz(): pass
async def bar(): pass
await a
""",
"""def foo():
def baz(): pass
# 456
async def bar(): pass
# 123
await a
""",
"""async def foo():
def baz(): pass
# 456
async def bar(): pass
# 123
await = 2
""",
"""def foo():
def baz(): pass
async def bar(): pass
await a
""",
"""async def foo():
def baz(): pass
async def bar(): pass
with self.assertRaisesRegex(Errors.CompileError, 'await.*in comprehen'): await = 2
exec('async def f():\n\t{0}'.format(comp), ns, ns) """,
"""async def foo():
def async(): pass
""",
"""async def foo():
def await(): pass
""",
"""async def foo():
def bar():
await
""",
"""async def foo():
return lambda async: await
""",
"""async def foo():
return lambda a: await
""",
"""await a()""",
"""async def foo(a=await b):
pass
""",
"""async def foo(a:await b):
pass
""",
"""def baz():
async def foo(a=await b):
pass
""",
"""async def foo(async):
pass
""",
"""async def foo():
def bar():
def baz():
async = 1
""",
"""async def foo():
def bar():
def baz():
pass
async = 1
""",
"""def foo():
async def bar():
async def baz():
pass
def baz():
42
async = 1
""",
"""async def foo():
def bar():
def baz():
pass\nawait foo()
""",
"""def foo():
def bar():
async def baz():
pass\nawait foo()
""",
"""async def foo(await):
pass
""",
"""def foo():
async def bar(): pass
await a
""",
"""def foo():
async def bar():
pass\nawait a
"""]
for code in samples:
with self.subTest(code=code), self.assertRaises(SyntaxError):
compile(code, "<test>", "exec")
def test_badsyntax_2(self):
samples = [
"""def foo():
await = 1
""",
"""class Bar:
def async(): pass
""",
"""class Bar:
async = 1
""",
"""class async:
pass
""",
"""class await:
pass
""",
"""import math as await""",
"""def async():
pass""",
"""def foo(*, await=1):
pass"""
"""async = 1""",
"""print(await=1)"""
]
for code in samples:
with self.subTest(code=code), self.assertWarnsRegex(
DeprecationWarning,
"'await' will become reserved keywords"):
compile(code, "<test>", "exec")
def test_badsyntax_3(self):
with self.assertRaises(DeprecationWarning):
with warnings.catch_warnings():
warnings.simplefilter("error")
compile("async = 1", "<test>", "exec")
def test_badsyntax_10(self): def test_badsyntax_10(self):
# Tests for issue 24619 # Tests for issue 24619
...@@ -437,10 +785,15 @@ class CoroutineTest(unittest.TestCase): ...@@ -437,10 +785,15 @@ class CoroutineTest(unittest.TestCase):
@contextlib.contextmanager @contextlib.contextmanager
def assertRaisesRegex(self, exc_type, regex): def assertRaisesRegex(self, exc_type, regex):
class Holder(object):
exception = None
holder = Holder()
# the error messages usually don't match, so we just ignore them # the error messages usually don't match, so we just ignore them
try: try:
yield yield holder
except exc_type: except exc_type as exc:
holder.exception = exc
self.assertTrue(True) self.assertTrue(True)
else: else:
self.assertTrue(False) self.assertTrue(False)
...@@ -513,7 +866,7 @@ class CoroutineTest(unittest.TestCase): ...@@ -513,7 +866,7 @@ class CoroutineTest(unittest.TestCase):
self.assertEqual(run_async__await__(foo()), ([], 10)) self.assertEqual(run_async__await__(foo()), ([], 10))
def bar(): pass def bar(): pass
self.assertFalse(bool(bar.__code__.co_flags & 0x80)) self.assertFalse(bool(bar.__code__.co_flags & 0x80)) # inspect.CO_COROUTINE
# TODO # TODO
def __test_func_2(self): def __test_func_2(self):
...@@ -653,7 +1006,7 @@ class CoroutineTest(unittest.TestCase): ...@@ -653,7 +1006,7 @@ class CoroutineTest(unittest.TestCase):
self.assertTrue(aw is iter(aw)) self.assertTrue(aw is iter(aw))
next(aw) next(aw)
self.assertEqual(aw.send(10), 100) self.assertEqual(aw.send(10), 100)
with self.assertRaises(TypeError): with self.assertRaises(TypeError): # removed from CPython test suite?
type(aw).send(None, None) type(aw).send(None, None)
self.assertEqual(N, 0) self.assertEqual(N, 0)
...@@ -714,6 +1067,162 @@ class CoroutineTest(unittest.TestCase): ...@@ -714,6 +1067,162 @@ class CoroutineTest(unittest.TestCase):
"coroutine ignored GeneratorExit"): "coroutine ignored GeneratorExit"):
c.close() c.close()
def test_func_15(self):
# See http://bugs.python.org/issue25887 for details
async def spammer():
return 'spam'
async def reader(coro):
return await coro
spammer_coro = spammer()
with self.assertRaisesRegex(StopIteration, 'spam'):
reader(spammer_coro).send(None)
with self.assertRaisesRegex(RuntimeError,
'cannot reuse already awaited coroutine'):
reader(spammer_coro).send(None)
def test_func_16(self):
# See http://bugs.python.org/issue25887 for details
@types_coroutine
def nop():
yield
async def send():
await nop()
return 'spam'
async def read(coro):
await nop()
return await coro
spammer = send()
reader = read(spammer)
reader.send(None)
reader.send(None)
with self.assertRaisesRegex(Exception, 'ham'):
reader.throw(Exception('ham'))
reader = read(spammer)
reader.send(None)
with self.assertRaisesRegex(RuntimeError,
'cannot reuse already awaited coroutine'):
reader.send(None)
with self.assertRaisesRegex(RuntimeError,
'cannot reuse already awaited coroutine'):
reader.throw(Exception('wat'))
def test_func_17(self):
# See http://bugs.python.org/issue25887 for details
async def coroutine():
return 'spam'
coro = coroutine()
with self.assertRaisesRegex(StopIteration, 'spam'):
coro.send(None)
with self.assertRaisesRegex(RuntimeError,
'cannot reuse already awaited coroutine'):
coro.send(None)
with self.assertRaisesRegex(RuntimeError,
'cannot reuse already awaited coroutine'):
coro.throw(Exception('wat'))
# Closing a coroutine shouldn't raise any exception even if it's
# already closed/exhausted (similar to generators)
coro.close()
coro.close()
def test_func_18(self):
# See http://bugs.python.org/issue25887 for details
async def coroutine():
return 'spam'
coro = coroutine()
await_iter = coro.__await__()
it = iter(await_iter)
with self.assertRaisesRegex(StopIteration, 'spam'):
it.send(None)
with self.assertRaisesRegex(RuntimeError,
'cannot reuse already awaited coroutine'):
it.send(None)
with self.assertRaisesRegex(RuntimeError,
'cannot reuse already awaited coroutine'):
# Although the iterator protocol requires iterators to
# raise another StopIteration here, we don't want to do
# that. In this particular case, the iterator will raise
# a RuntimeError, so that 'yield from' and 'await'
# expressions will trigger the error, instead of silently
# ignoring the call.
next(it)
with self.assertRaisesRegex(RuntimeError,
'cannot reuse already awaited coroutine'):
it.throw(Exception('wat'))
with self.assertRaisesRegex(RuntimeError,
'cannot reuse already awaited coroutine'):
it.throw(Exception('wat'))
# Closing a coroutine shouldn't raise any exception even if it's
# already closed/exhausted (similar to generators)
it.close()
it.close()
def test_func_19(self):
CHK = 0
@types_coroutine
def foo():
nonlocal CHK
yield
try:
yield
except GeneratorExit:
CHK += 1
async def coroutine():
await foo()
coro = coroutine()
coro.send(None)
coro.send(None)
self.assertEqual(CHK, 0)
coro.close()
self.assertEqual(CHK, 1)
for _ in range(3):
# Closing a coroutine shouldn't raise any exception even if it's
# already closed/exhausted (similar to generators)
coro.close()
self.assertEqual(CHK, 1)
def test_coro_wrapper_send_tuple(self):
async def foo():
return (10,)
result = run_async__await__(foo())
self.assertEqual(result, ([], (10,)))
def test_coro_wrapper_send_stop_iterator(self):
async def foo():
return StopIteration(10)
result = run_async__await__(foo())
self.assertIsInstance(result[1], StopIteration)
self.assertEqual(result[1].value, 10)
def test_cr_await(self): def test_cr_await(self):
@types_coroutine @types_coroutine
def a(): def a():
...@@ -829,7 +1338,7 @@ class CoroutineTest(unittest.TestCase): ...@@ -829,7 +1338,7 @@ class CoroutineTest(unittest.TestCase):
class Awaitable(object): class Awaitable(object):
pass pass
async def foo(): return (await Awaitable()) async def foo(): return await Awaitable()
with self.assertRaisesRegex( with self.assertRaisesRegex(
TypeError, "object Awaitable can't be used in 'await' expression"): TypeError, "object Awaitable can't be used in 'await' expression"):
...@@ -899,7 +1408,7 @@ class CoroutineTest(unittest.TestCase): ...@@ -899,7 +1408,7 @@ class CoroutineTest(unittest.TestCase):
return await Awaitable() return await Awaitable()
with self.assertRaisesRegex( with self.assertRaisesRegex(
TypeError, "__await__\(\) returned a coroutine"): TypeError, r"__await__\(\) returned a coroutine"):
run_async(foo()) run_async(foo())
...@@ -949,7 +1458,41 @@ class CoroutineTest(unittest.TestCase): ...@@ -949,7 +1458,41 @@ class CoroutineTest(unittest.TestCase):
with self.assertRaises(Marker): with self.assertRaises(Marker):
c.throw(ZeroDivisionError) c.throw(ZeroDivisionError)
def test_await_iterator(self): def test_await_15(self):
@types_coroutine
def nop():
yield
async def coroutine():
await nop()
async def waiter(coro):
await coro
coro = coroutine()
coro.send(None)
with self.assertRaisesRegex(RuntimeError,
"coroutine is being awaited already"):
waiter(coro).send(None)
def test_await_16(self):
# See https://bugs.python.org/issue29600 for details.
async def f():
return ValueError()
async def g():
try:
raise KeyError
except:
return await f()
_, result = run_async(g())
self.assertIsNone(result.__context__)
# removed from CPython ?
def __test_await_iterator(self):
async def foo(): async def foo():
return 123 return 123
...@@ -1260,7 +1803,7 @@ class CoroutineTest(unittest.TestCase): ...@@ -1260,7 +1803,7 @@ class CoroutineTest(unittest.TestCase):
buffer = [] buffer = []
async def test1(): async def test1():
with ignore_py26(self.assertWarnsRegex(PendingDeprecationWarning, "legacy")): with ignore_py26(self.assertWarnsRegex(DeprecationWarning, "legacy")):
async for i1, i2 in AsyncIter(): async for i1, i2 in AsyncIter():
buffer.append(i1 + i2) buffer.append(i1 + i2)
...@@ -1274,7 +1817,7 @@ class CoroutineTest(unittest.TestCase): ...@@ -1274,7 +1817,7 @@ class CoroutineTest(unittest.TestCase):
buffer = [] buffer = []
async def test2(): async def test2():
nonlocal buffer nonlocal buffer
with ignore_py26(self.assertWarnsRegex(PendingDeprecationWarning, "legacy")): with ignore_py26(self.assertWarnsRegex(DeprecationWarning, "legacy")):
async for i in AsyncIter(): async for i in AsyncIter():
buffer.append(i[0]) buffer.append(i[0])
if i[0] == 20: if i[0] == 20:
...@@ -1293,7 +1836,7 @@ class CoroutineTest(unittest.TestCase): ...@@ -1293,7 +1836,7 @@ class CoroutineTest(unittest.TestCase):
buffer = [] buffer = []
async def test3(): async def test3():
nonlocal buffer nonlocal buffer
with ignore_py26(self.assertWarnsRegex(PendingDeprecationWarning, "legacy")): with ignore_py26(self.assertWarnsRegex(DeprecationWarning, "legacy")):
async for i in AsyncIter(): async for i in AsyncIter():
if i[0] > 20: if i[0] > 20:
continue continue
...@@ -1376,7 +1919,7 @@ class CoroutineTest(unittest.TestCase): ...@@ -1376,7 +1919,7 @@ class CoroutineTest(unittest.TestCase):
return 123 return 123
async def foo(): async def foo():
with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"): with self.assertWarnsRegex(DeprecationWarning, "legacy"):
async for i in I(): async for i in I():
print('never going to happen') print('never going to happen')
...@@ -1481,7 +2024,7 @@ class CoroutineTest(unittest.TestCase): ...@@ -1481,7 +2024,7 @@ class CoroutineTest(unittest.TestCase):
1/0 1/0
async def foo(): async def foo():
nonlocal CNT nonlocal CNT
with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"): with self.assertWarnsRegex(DeprecationWarning, "legacy"):
async for i in AI(): async for i in AI():
CNT += 1 CNT += 1
CNT += 10 CNT += 10
...@@ -1491,7 +2034,7 @@ class CoroutineTest(unittest.TestCase): ...@@ -1491,7 +2034,7 @@ class CoroutineTest(unittest.TestCase):
def test_for_8(self): def test_for_8(self):
CNT = 0 CNT = 0
class AI: class AI(object):
def __aiter__(self): def __aiter__(self):
1/0 1/0
async def foo(): async def foo():
...@@ -1500,7 +2043,7 @@ class CoroutineTest(unittest.TestCase): ...@@ -1500,7 +2043,7 @@ class CoroutineTest(unittest.TestCase):
CNT += 1 CNT += 1
CNT += 10 CNT += 10
with self.assertRaises(ZeroDivisionError): with self.assertRaises(ZeroDivisionError):
run_async(foo()) #run_async(foo())
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("error") warnings.simplefilter("error")
# Test that if __aiter__ raises an exception it propagates # Test that if __aiter__ raises an exception it propagates
...@@ -1510,37 +2053,308 @@ class CoroutineTest(unittest.TestCase): ...@@ -1510,37 +2053,308 @@ class CoroutineTest(unittest.TestCase):
@min_py27 @min_py27
def test_for_9(self): def test_for_9(self):
# Test that PendingDeprecationWarning can safely be converted into # Test that DeprecationWarning can safely be converted into
# an exception (__aiter__ should not have a chance to raise # an exception (__aiter__ should not have a chance to raise
# a ZeroDivisionError.) # a ZeroDivisionError.)
class AI: class AI(object):
async def __aiter__(self): async def __aiter__(self):
1/0 1/0
async def foo(): async def foo():
async for i in AI(): async for i in AI():
pass pass
with self.assertRaises(PendingDeprecationWarning): with self.assertRaises(DeprecationWarning):
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("error") warnings.simplefilter("error")
run_async(foo()) run_async(foo())
@min_py27 @min_py27
def test_for_10(self): def test_for_10(self):
# Test that PendingDeprecationWarning can safely be converted into # Test that DeprecationWarning can safely be converted into
# an exception. # an exception.
class AI: class AI(object):
async def __aiter__(self): async def __aiter__(self):
pass pass
async def foo(): async def foo():
async for i in AI(): async for i in AI():
pass pass
with self.assertRaises(PendingDeprecationWarning): with self.assertRaises(DeprecationWarning):
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("error") warnings.simplefilter("error")
run_async(foo()) run_async(foo())
def test_for_11(self):
class F(object):
def __aiter__(self):
return self
def __anext__(self):
return self
def __await__(self):
1 / 0
async def main():
async for _ in F():
pass
if sys.version_info[0] < 3:
with self.assertRaises(ZeroDivisionError) as c:
main().send(None)
else:
with self.assertRaisesRegex(TypeError,
'an invalid object from __anext__') as c:
main().send(None)
err = c.exception
self.assertIsInstance(err.__cause__, ZeroDivisionError)
def test_for_12(self):
class F(object):
def __aiter__(self):
return self
def __await__(self):
1 / 0
async def main():
async for _ in F():
pass
if sys.version_info[0] < 3:
with self.assertRaises(ZeroDivisionError) as c:
main().send(None)
else:
with self.assertRaisesRegex(TypeError,
'an invalid object from __aiter__') as c:
main().send(None)
err = c.exception
self.assertIsInstance(err.__cause__, ZeroDivisionError)
def test_for_tuple(self):
class Done(Exception): pass
class AIter(tuple):
i = 0
def __aiter__(self):
return self
async def __anext__(self):
if self.i >= len(self):
raise StopAsyncIteration
self.i += 1
return self[self.i - 1]
result = []
async def foo():
async for i in AIter([42]):
result.append(i)
raise Done
with self.assertRaises(Done):
foo().send(None)
self.assertEqual(result, [42])
def test_for_stop_iteration(self):
class Done(Exception): pass
class AIter(StopIteration):
i = 0
def __aiter__(self):
return self
async def __anext__(self):
if self.i:
raise StopAsyncIteration
self.i += 1
return self.value
result = []
async def foo():
async for i in AIter(42):
result.append(i)
raise Done
with self.assertRaises(Done):
foo().send(None)
self.assertEqual(result, [42])
def test_comp_1(self):
async def f(i):
return i
async def run_list():
return [await c for c in [f(1), f(41)]]
async def run_set():
return {await c for c in [f(1), f(41)]}
async def run_dict1():
return {await c: 'a' for c in [f(1), f(41)]}
async def run_dict2():
return {i: await c for i, c in enumerate([f(1), f(41)])}
self.assertEqual(run_async(run_list()), ([], [1, 41]))
self.assertEqual(run_async(run_set()), ([], {1, 41}))
self.assertEqual(run_async(run_dict1()), ([], {1: 'a', 41: 'a'}))
self.assertEqual(run_async(run_dict2()), ([], {0: 1, 1: 41}))
def test_comp_2(self):
async def f(i):
return i
async def run_list():
return [s for c in [f(''), f('abc'), f(''), f(['de', 'fg'])]
for s in await c]
self.assertEqual(
run_async(run_list()),
([], ['a', 'b', 'c', 'de', 'fg']))
async def run_set():
return {d
for c in [f([f([10, 30]),
f([20])])]
for s in await c
for d in await s}
self.assertEqual(
run_async(run_set()),
([], {10, 20, 30}))
async def run_set2():
return {await s
for c in [f([f(10), f(20)])]
for s in await c}
self.assertEqual(
run_async(run_set2()),
([], {10, 20}))
def test_comp_3(self):
async def f(it):
for i in it:
yield i
async def run_list():
return [i + 1 async for i in f([10, 20])]
self.assertEqual(
run_async(run_list()),
([], [11, 21]))
async def run_set():
return {i + 1 async for i in f([10, 20])}
self.assertEqual(
run_async(run_set()),
([], {11, 21}))
async def run_dict():
return {i + 1: i + 2 async for i in f([10, 20])}
self.assertEqual(
run_async(run_dict()),
([], {11: 12, 21: 22}))
async def run_gen():
gen = (i + 1 async for i in f([10, 20]))
return [g + 100 async for g in gen]
self.assertEqual(
run_async(run_gen()),
([], [111, 121]))
def test_comp_4(self):
async def f(it):
for i in it:
yield i
async def run_list():
return [i + 1 async for i in f([10, 20]) if i > 10]
self.assertEqual(
run_async(run_list()),
([], [21]))
async def run_set():
return {i + 1 async for i in f([10, 20]) if i > 10}
self.assertEqual(
run_async(run_set()),
([], {21}))
async def run_dict():
return {i + 1: i + 2 async for i in f([10, 20]) if i > 10}
self.assertEqual(
run_async(run_dict()),
([], {21: 22}))
async def run_gen():
gen = (i + 1 async for i in f([10, 20]) if i > 10)
return [g + 100 async for g in gen]
self.assertEqual(
run_async(run_gen()),
([], [121]))
def test_comp_5(self):
async def f(it):
for i in it:
yield i
async def run_list():
return [i + 1 for pair in ([10, 20], [30, 40]) if pair[0] > 10
async for i in f(pair) if i > 30]
self.assertEqual(
run_async(run_list()),
([], [41]))
def test_comp_6(self):
async def f(it):
for i in it:
yield i
async def run_list():
return [i + 1 async for seq in f([(10, 20), (30,)])
for i in seq]
self.assertEqual(
run_async(run_list()),
([], [11, 21, 31]))
def test_comp_7(self):
async def f():
yield 1
yield 2
raise Exception('aaa')
async def run_list():
return [i async for i in f()]
with self.assertRaisesRegex(Exception, 'aaa'):
run_async(run_list())
def test_comp_8(self):
async def f():
return [i for i in [1, 2, 3]]
self.assertEqual(
run_async(f()),
([], [1, 2, 3]))
def test_comp_9(self):
async def gen():
yield 1
yield 2
async def f():
l = [i async for i in gen()]
return [i for i in l]
self.assertEqual(
run_async(f()),
([], [1, 2]))
def test_comp_10(self):
async def f():
xx = {i for i in [1, 2, 3]}
return {x: x for x in xx}
self.assertEqual(
run_async(f()),
([], {1: 1, 2: 2, 3: 3}))
def test_copy(self): def test_copy(self):
async def func(): pass async def func(): pass
coro = func() coro = func()
...@@ -1569,6 +2383,15 @@ class CoroutineTest(unittest.TestCase): ...@@ -1569,6 +2383,15 @@ class CoroutineTest(unittest.TestCase):
finally: finally:
aw.close() aw.close()
def test_fatal_coro_warning(self):
# Issue 27811
async def func(): pass
with warnings.catch_warnings(), captured_stderr() as stderr:
warnings.filterwarnings("error")
func()
gc.collect()
self.assertIn("was never awaited", stderr.getvalue())
class CoroAsyncIOCompatTest(unittest.TestCase): class CoroAsyncIOCompatTest(unittest.TestCase):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment