Commit a9acbe82 authored by Victor Stinner's avatar Victor Stinner

Closes #21886, #21447: Fix a race condition in asyncio when setting the result

of a Future with call_soon(). Add an helper, a private method, to set the
result only if the future was not cancelled.
parent 5021cb55
...@@ -64,6 +64,12 @@ class CoroWrapper: ...@@ -64,6 +64,12 @@ class CoroWrapper:
self.gen = gen self.gen = gen
self.func = func self.func = func
self._source_traceback = traceback.extract_stack(sys._getframe(1)) self._source_traceback = traceback.extract_stack(sys._getframe(1))
# __name__, __qualname__, __doc__ attributes are set by the coroutine()
# decorator
def __repr__(self):
return ('<%s %s>'
% (self.__class__.__name__, _format_coroutine(self)))
def __iter__(self): def __iter__(self):
return self return self
......
...@@ -316,6 +316,12 @@ class Future: ...@@ -316,6 +316,12 @@ class Future:
# So-called internal methods (note: no set_running_or_notify_cancel()). # So-called internal methods (note: no set_running_or_notify_cancel()).
def _set_result_unless_cancelled(self, result):
"""Helper setting the result only if the future was not cancelled."""
if self.cancelled():
return
self.set_result(result)
def set_result(self, result): def set_result(self, result):
"""Mark the future done and set its result. """Mark the future done and set its result.
......
...@@ -38,7 +38,7 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin, ...@@ -38,7 +38,7 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin,
self._server.attach(self) self._server.attach(self)
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
if waiter is not None: if waiter is not None:
self._loop.call_soon(waiter.set_result, None) self._loop.call_soon(waiter._set_result_unless_cancelled, None)
def _set_extra(self, sock): def _set_extra(self, sock):
self._extra['pipe'] = sock self._extra['pipe'] = sock
......
...@@ -173,7 +173,7 @@ class Queue: ...@@ -173,7 +173,7 @@ class Queue:
# run, we need to defer the put for a tick to ensure that # run, we need to defer the put for a tick to ensure that
# getters and putters alternate perfectly. See # getters and putters alternate perfectly. See
# ChannelTest.test_wait. # ChannelTest.test_wait.
self._loop.call_soon(putter.set_result, None) self._loop.call_soon(putter._set_result_unless_cancelled, None)
return self._get() return self._get()
......
...@@ -481,7 +481,7 @@ class _SelectorSocketTransport(_SelectorTransport): ...@@ -481,7 +481,7 @@ class _SelectorSocketTransport(_SelectorTransport):
self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.add_reader(self._sock_fd, self._read_ready)
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
if waiter is not None: if waiter is not None:
self._loop.call_soon(waiter.set_result, None) self._loop.call_soon(waiter._set_result_unless_cancelled, None)
def pause_reading(self): def pause_reading(self):
if self._closing: if self._closing:
...@@ -690,7 +690,8 @@ class _SelectorSslTransport(_SelectorTransport): ...@@ -690,7 +690,8 @@ class _SelectorSslTransport(_SelectorTransport):
self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.add_reader(self._sock_fd, self._read_ready)
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
if self._waiter is not None: if self._waiter is not None:
self._loop.call_soon(self._waiter.set_result, None) self._loop.call_soon(self._waiter._set_result_unless_cancelled,
None)
def pause_reading(self): def pause_reading(self):
# XXX This is a bit icky, given the comment at the top of # XXX This is a bit icky, given the comment at the top of
......
...@@ -487,7 +487,8 @@ def as_completed(fs, *, loop=None, timeout=None): ...@@ -487,7 +487,8 @@ def as_completed(fs, *, loop=None, timeout=None):
def sleep(delay, result=None, *, loop=None): def sleep(delay, result=None, *, loop=None):
"""Coroutine that completes after a given time (in seconds).""" """Coroutine that completes after a given time (in seconds)."""
future = futures.Future(loop=loop) future = futures.Future(loop=loop)
h = future._loop.call_later(delay, future.set_result, result) h = future._loop.call_later(delay,
future._set_result_unless_cancelled, result)
try: try:
return (yield from future) return (yield from future)
finally: finally:
......
...@@ -269,7 +269,7 @@ class _UnixReadPipeTransport(transports.ReadTransport): ...@@ -269,7 +269,7 @@ class _UnixReadPipeTransport(transports.ReadTransport):
self._loop.add_reader(self._fileno, self._read_ready) self._loop.add_reader(self._fileno, self._read_ready)
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
if waiter is not None: if waiter is not None:
self._loop.call_soon(waiter.set_result, None) self._loop.call_soon(waiter._set_result_unless_cancelled, None)
def _read_ready(self): def _read_ready(self):
try: try:
...@@ -353,7 +353,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin, ...@@ -353,7 +353,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin,
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
if waiter is not None: if waiter is not None:
self._loop.call_soon(waiter.set_result, None) self._loop.call_soon(waiter._set_result_unless_cancelled, None)
def get_write_buffer_size(self): def get_write_buffer_size(self):
return sum(len(data) for data in self._buffer) return sum(len(data) for data in self._buffer)
......
...@@ -343,6 +343,12 @@ class FutureTests(test_utils.TestCase): ...@@ -343,6 +343,12 @@ class FutureTests(test_utils.TestCase):
message = m_log.error.call_args[0][0] message = m_log.error.call_args[0][0]
self.assertRegex(message, re.compile(regex, re.DOTALL)) self.assertRegex(message, re.compile(regex, re.DOTALL))
def test_set_result_unless_cancelled(self):
fut = asyncio.Future(loop=self.loop)
fut.cancel()
fut._set_result_unless_cancelled(2)
self.assertTrue(fut.cancelled())
class FutureDoneCallbackTests(test_utils.TestCase): class FutureDoneCallbackTests(test_utils.TestCase):
......
...@@ -211,6 +211,10 @@ class TaskTests(test_utils.TestCase): ...@@ -211,6 +211,10 @@ class TaskTests(test_utils.TestCase):
coro = ('%s() at %s:%s' coro = ('%s() at %s:%s'
% (coro_qualname, code.co_filename, code.co_firstlineno)) % (coro_qualname, code.co_filename, code.co_firstlineno))
# test repr(CoroWrapper)
if coroutines._DEBUG:
self.assertEqual(repr(gen), '<CoroWrapper %s>' % coro)
# test pending Task # test pending Task
t = asyncio.Task(gen, loop=self.loop) t = asyncio.Task(gen, loop=self.loop)
t.add_done_callback(Dummy()) t.add_done_callback(Dummy())
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment