Commit fc4a044a authored by Yury Selivanov's avatar Yury Selivanov Committed by GitHub

bpo-30773: Fix ag_running; prohibit running athrow/asend/aclose in parallel (#7468)

parent 6758e6e1
...@@ -79,6 +79,8 @@ typedef struct { ...@@ -79,6 +79,8 @@ typedef struct {
/* Flag is set to 1 when aclose() is called for the first time, or /* Flag is set to 1 when aclose() is called for the first time, or
when a StopAsyncIteration exception is raised. */ when a StopAsyncIteration exception is raised. */
int ag_closed; int ag_closed;
int ag_running_async;
} PyAsyncGenObject; } PyAsyncGenObject;
PyAPI_DATA(PyTypeObject) PyAsyncGen_Type; PyAPI_DATA(PyTypeObject) PyAsyncGen_Type;
......
...@@ -133,24 +133,6 @@ class AsyncGenTest(unittest.TestCase): ...@@ -133,24 +133,6 @@ class AsyncGenTest(unittest.TestCase):
break break
return res return res
def async_iterate(g):
res = []
while True:
try:
g.__anext__().__next__()
except StopAsyncIteration:
res.append('STOP')
break
except StopIteration as ex:
if ex.args:
res.append(ex.args[0])
else:
res.append('EMPTY StopIteration')
break
except Exception as ex:
res.append(str(type(ex)))
return res
sync_gen_result = sync_iterate(sync_gen) sync_gen_result = sync_iterate(sync_gen)
async_gen_result = async_iterate(async_gen) async_gen_result = async_iterate(async_gen)
self.assertEqual(sync_gen_result, async_gen_result) self.assertEqual(sync_gen_result, async_gen_result)
...@@ -176,19 +158,22 @@ class AsyncGenTest(unittest.TestCase): ...@@ -176,19 +158,22 @@ class AsyncGenTest(unittest.TestCase):
g = gen() g = gen()
ai = g.__aiter__() ai = g.__aiter__()
self.assertEqual(ai.__anext__().__next__(), ('result',))
an = ai.__anext__()
self.assertEqual(an.__next__(), ('result',))
try: try:
ai.__anext__().__next__() an.__next__()
except StopIteration as ex: except StopIteration as ex:
self.assertEqual(ex.args[0], 123) self.assertEqual(ex.args[0], 123)
else: else:
self.fail('StopIteration was not raised') self.fail('StopIteration was not raised')
self.assertEqual(ai.__anext__().__next__(), ('result',)) an = ai.__anext__()
self.assertEqual(an.__next__(), ('result',))
try: try:
ai.__anext__().__next__() an.__next__()
except StopAsyncIteration as ex: except StopAsyncIteration as ex:
self.assertFalse(ex.args) self.assertFalse(ex.args)
else: else:
...@@ -212,10 +197,11 @@ class AsyncGenTest(unittest.TestCase): ...@@ -212,10 +197,11 @@ class AsyncGenTest(unittest.TestCase):
g = gen() g = gen()
ai = g.__aiter__() ai = g.__aiter__()
self.assertEqual(ai.__anext__().__next__(), ('result',)) an = ai.__anext__()
self.assertEqual(an.__next__(), ('result',))
try: try:
ai.__anext__().__next__() an.__next__()
except StopIteration as ex: except StopIteration as ex:
self.assertEqual(ex.args[0], 123) self.assertEqual(ex.args[0], 123)
else: else:
...@@ -646,17 +632,13 @@ class AsyncGenAsyncioTest(unittest.TestCase): ...@@ -646,17 +632,13 @@ class AsyncGenAsyncioTest(unittest.TestCase):
gen = foo() gen = foo()
it = gen.__aiter__() it = gen.__aiter__()
self.assertEqual(await it.__anext__(), 1) self.assertEqual(await it.__anext__(), 1)
t = self.loop.create_task(it.__anext__())
await asyncio.sleep(0.01)
await gen.aclose() await gen.aclose()
return t
t = self.loop.run_until_complete(run()) self.loop.run_until_complete(run())
self.assertEqual(DONE, 1) self.assertEqual(DONE, 1)
# Silence ResourceWarnings # Silence ResourceWarnings
fut.cancel() fut.cancel()
t.cancel()
self.loop.run_until_complete(asyncio.sleep(0.01)) self.loop.run_until_complete(asyncio.sleep(0.01))
def test_async_gen_asyncio_gc_aclose_09(self): def test_async_gen_asyncio_gc_aclose_09(self):
...@@ -1053,46 +1035,18 @@ class AsyncGenAsyncioTest(unittest.TestCase): ...@@ -1053,46 +1035,18 @@ class AsyncGenAsyncioTest(unittest.TestCase):
self.loop.run_until_complete(asyncio.sleep(0.1)) self.loop.run_until_complete(asyncio.sleep(0.1))
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
self.assertEqual(finalized, 2)
# Silence warnings # Silence warnings
t1.cancel() t1.cancel()
t2.cancel() t2.cancel()
self.loop.run_until_complete(asyncio.sleep(0.1))
def test_async_gen_asyncio_shutdown_02(self):
logged = 0
def logger(loop, context):
nonlocal logged
self.assertIn('asyncgen', context)
expected = 'an error occurred during closing of asynchronous'
if expected in context['message']:
logged += 1
async def waiter(timeout): with self.assertRaises(asyncio.CancelledError):
try: self.loop.run_until_complete(t1)
await asyncio.sleep(timeout) with self.assertRaises(asyncio.CancelledError):
yield 1 self.loop.run_until_complete(t2)
finally:
1 / 0
async def wait():
async for _ in waiter(1):
pass
t = self.loop.create_task(wait())
self.loop.run_until_complete(asyncio.sleep(0.1))
self.loop.set_exception_handler(logger)
self.loop.run_until_complete(self.loop.shutdown_asyncgens()) self.loop.run_until_complete(self.loop.shutdown_asyncgens())
self.assertEqual(logged, 1) self.assertEqual(finalized, 2)
# Silence warnings
t.cancel()
self.loop.run_until_complete(asyncio.sleep(0.1))
def test_async_gen_expression_01(self): def test_async_gen_expression_01(self):
async def arange(n): async def arange(n):
......
Prohibit parallel running of aclose() / asend() / athrow(). Fix ag_running
to reflect the actual running status of the AG.
...@@ -1326,7 +1326,8 @@ static PyGetSetDef async_gen_getsetlist[] = { ...@@ -1326,7 +1326,8 @@ static PyGetSetDef async_gen_getsetlist[] = {
static PyMemberDef async_gen_memberlist[] = { static PyMemberDef async_gen_memberlist[] = {
{"ag_frame", T_OBJECT, offsetof(PyAsyncGenObject, ag_frame), READONLY}, {"ag_frame", T_OBJECT, offsetof(PyAsyncGenObject, ag_frame), READONLY},
{"ag_running", T_BOOL, offsetof(PyAsyncGenObject, ag_running), READONLY}, {"ag_running", T_BOOL, offsetof(PyAsyncGenObject, ag_running_async),
READONLY},
{"ag_code", T_OBJECT, offsetof(PyAsyncGenObject, ag_code), READONLY}, {"ag_code", T_OBJECT, offsetof(PyAsyncGenObject, ag_code), READONLY},
{NULL} /* Sentinel */ {NULL} /* Sentinel */
}; };
...@@ -1420,6 +1421,7 @@ PyAsyncGen_New(PyFrameObject *f, PyObject *name, PyObject *qualname) ...@@ -1420,6 +1421,7 @@ PyAsyncGen_New(PyFrameObject *f, PyObject *name, PyObject *qualname)
o->ag_finalizer = NULL; o->ag_finalizer = NULL;
o->ag_closed = 0; o->ag_closed = 0;
o->ag_hooks_inited = 0; o->ag_hooks_inited = 0;
o->ag_running_async = 0;
return (PyObject*)o; return (PyObject*)o;
} }
...@@ -1467,6 +1469,7 @@ async_gen_unwrap_value(PyAsyncGenObject *gen, PyObject *result) ...@@ -1467,6 +1469,7 @@ async_gen_unwrap_value(PyAsyncGenObject *gen, PyObject *result)
gen->ag_closed = 1; gen->ag_closed = 1;
} }
gen->ag_running_async = 0;
return NULL; return NULL;
} }
...@@ -1474,6 +1477,7 @@ async_gen_unwrap_value(PyAsyncGenObject *gen, PyObject *result) ...@@ -1474,6 +1477,7 @@ async_gen_unwrap_value(PyAsyncGenObject *gen, PyObject *result)
/* async yield */ /* async yield */
_PyGen_SetStopIterationValue(((_PyAsyncGenWrappedValue*)result)->agw_val); _PyGen_SetStopIterationValue(((_PyAsyncGenWrappedValue*)result)->agw_val);
Py_DECREF(result); Py_DECREF(result);
gen->ag_running_async = 0;
return NULL; return NULL;
} }
...@@ -1518,12 +1522,20 @@ async_gen_asend_send(PyAsyncGenASend *o, PyObject *arg) ...@@ -1518,12 +1522,20 @@ async_gen_asend_send(PyAsyncGenASend *o, PyObject *arg)
} }
if (o->ags_state == AWAITABLE_STATE_INIT) { if (o->ags_state == AWAITABLE_STATE_INIT) {
if (o->ags_gen->ag_running_async) {
PyErr_SetString(
PyExc_RuntimeError,
"anext(): asynchronous generator is already running");
return NULL;
}
if (arg == NULL || arg == Py_None) { if (arg == NULL || arg == Py_None) {
arg = o->ags_sendval; arg = o->ags_sendval;
} }
o->ags_state = AWAITABLE_STATE_ITER; o->ags_state = AWAITABLE_STATE_ITER;
} }
o->ags_gen->ag_running_async = 1;
result = gen_send_ex((PyGenObject*)o->ags_gen, arg, 0, 0); result = gen_send_ex((PyGenObject*)o->ags_gen, arg, 0, 0);
result = async_gen_unwrap_value(o->ags_gen, result); result = async_gen_unwrap_value(o->ags_gen, result);
...@@ -1787,8 +1799,23 @@ async_gen_athrow_send(PyAsyncGenAThrow *o, PyObject *arg) ...@@ -1787,8 +1799,23 @@ async_gen_athrow_send(PyAsyncGenAThrow *o, PyObject *arg)
} }
if (o->agt_state == AWAITABLE_STATE_INIT) { if (o->agt_state == AWAITABLE_STATE_INIT) {
if (o->agt_gen->ag_running_async) {
if (o->agt_args == NULL) {
PyErr_SetString(
PyExc_RuntimeError,
"aclose(): asynchronous generator is already running");
}
else {
PyErr_SetString(
PyExc_RuntimeError,
"athrow(): asynchronous generator is already running");
}
return NULL;
}
if (o->agt_gen->ag_closed) { if (o->agt_gen->ag_closed) {
PyErr_SetNone(PyExc_StopIteration); o->agt_state = AWAITABLE_STATE_CLOSED;
PyErr_SetNone(PyExc_StopAsyncIteration);
return NULL; return NULL;
} }
...@@ -1798,6 +1825,7 @@ async_gen_athrow_send(PyAsyncGenAThrow *o, PyObject *arg) ...@@ -1798,6 +1825,7 @@ async_gen_athrow_send(PyAsyncGenAThrow *o, PyObject *arg)
} }
o->agt_state = AWAITABLE_STATE_ITER; o->agt_state = AWAITABLE_STATE_ITER;
o->agt_gen->ag_running_async = 1;
if (o->agt_args == NULL) { if (o->agt_args == NULL) {
/* aclose() mode */ /* aclose() mode */
...@@ -1843,6 +1871,7 @@ async_gen_athrow_send(PyAsyncGenAThrow *o, PyObject *arg) ...@@ -1843,6 +1871,7 @@ async_gen_athrow_send(PyAsyncGenAThrow *o, PyObject *arg)
/* aclose() mode */ /* aclose() mode */
if (retval) { if (retval) {
if (_PyAsyncGenWrappedValue_CheckExact(retval)) { if (_PyAsyncGenWrappedValue_CheckExact(retval)) {
o->agt_gen->ag_running_async = 0;
Py_DECREF(retval); Py_DECREF(retval);
goto yield_close; goto yield_close;
} }
...@@ -1856,11 +1885,13 @@ async_gen_athrow_send(PyAsyncGenAThrow *o, PyObject *arg) ...@@ -1856,11 +1885,13 @@ async_gen_athrow_send(PyAsyncGenAThrow *o, PyObject *arg)
} }
yield_close: yield_close:
o->agt_gen->ag_running_async = 0;
PyErr_SetString( PyErr_SetString(
PyExc_RuntimeError, ASYNC_GEN_IGNORED_EXIT_MSG); PyExc_RuntimeError, ASYNC_GEN_IGNORED_EXIT_MSG);
return NULL; return NULL;
check_error: check_error:
o->agt_gen->ag_running_async = 0;
if (PyErr_ExceptionMatches(PyExc_StopAsyncIteration) || if (PyErr_ExceptionMatches(PyExc_StopAsyncIteration) ||
PyErr_ExceptionMatches(PyExc_GeneratorExit)) PyErr_ExceptionMatches(PyExc_GeneratorExit))
{ {
...@@ -1895,6 +1926,7 @@ async_gen_athrow_throw(PyAsyncGenAThrow *o, PyObject *args) ...@@ -1895,6 +1926,7 @@ async_gen_athrow_throw(PyAsyncGenAThrow *o, PyObject *args)
} else { } else {
/* aclose() mode */ /* aclose() mode */
if (retval && _PyAsyncGenWrappedValue_CheckExact(retval)) { if (retval && _PyAsyncGenWrappedValue_CheckExact(retval)) {
o->agt_gen->ag_running_async = 0;
Py_DECREF(retval); Py_DECREF(retval);
PyErr_SetString(PyExc_RuntimeError, ASYNC_GEN_IGNORED_EXIT_MSG); PyErr_SetString(PyExc_RuntimeError, ASYNC_GEN_IGNORED_EXIT_MSG);
return NULL; return NULL;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment