Commit 79790bc3 authored by Victor Stinner, committed by GitHub

bpo-33694: Fix race condition in asyncio proactor (GH-7498)

Cancelling an overlapped WSARecv() has a race condition that causes
data loss with the current proactor implementation in asyncio.

No longer cancel overlapped WSARecv() in _ProactorReadPipeTransport
to work around the race condition.

Remove the optimized recv_into() implementation to get a simple
implementation of pause_reading() that uses a single _pending_data
attribute.

Move _feed_data_to_bufferred_proto() to protocols.py.

Remove the set_protocol() method, which became useless.
parent d3ed67d1
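
In short, the transport no longer touches the in-flight overlapped read when the protocol pauses; any data that arrives while paused is parked in _pending_data and replayed on resume. A simplified, illustrative sketch of that pattern follows (the class name and the _deliver() helper are made up for this example, not the real transport API):

# Illustrative sketch of the pause/resume pattern adopted by the patch.
# The class and _deliver() are invented for this example; the real code
# lives in _ProactorReadPipeTransport.
class _PausableReader:
    def __init__(self):
        self._paused = False
        self._pending_data = None

    def pause_reading(self):
        # The in-flight overlapped WSARecv() is deliberately left alone;
        # cancelling it is what loses data (bpo-33694).
        self._paused = True

    def resume_reading(self):
        self._paused = False
        data, self._pending_data = self._pending_data, None
        if data is not None:
            # Replay the chunk that arrived while reading was paused.
            self._deliver(data)

    def _data_received(self, data):
        if self._paused:
            # Reading is paused: keep the chunk instead of dropping it.
            assert self._pending_data is None
            self._pending_data = data
            return
        self._deliver(data)

    def _deliver(self, data):
        # Stands in for protocol.data_received() / buffer_updated().
        print('protocol got', data)


reader = _PausableReader()
reader.pause_reading()
reader._data_received(b'hello')   # arrives while paused: buffered, not lost
reader.resume_reading()           # prints: protocol got b'hello'
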
Lib/asyncio/proactor_events.py
@@ -159,27 +159,13 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
     def __init__(self, loop, sock, protocol, waiter=None,
                  extra=None, server=None):
-        self._loop_reading_cb = None
+        self._pending_data = None
         self._paused = True
         super().__init__(loop, sock, protocol, waiter, extra, server)
 
-        self._reschedule_on_resume = False
         self._loop.call_soon(self._loop_reading)
         self._paused = False
 
-    def set_protocol(self, protocol):
-        if isinstance(protocol, protocols.BufferedProtocol):
-            self._loop_reading_cb = self._loop_reading__get_buffer
-        else:
-            self._loop_reading_cb = self._loop_reading__data_received
-
-        super().set_protocol(protocol)
-
-        if self.is_reading():
-            # reset reading callback / buffers / self._read_fut
-            self.pause_reading()
-            self.resume_reading()
-
     def is_reading(self):
         return not self._paused and not self._closing
@@ -188,17 +174,16 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
             return
         self._paused = True
 
-        if self._read_fut is not None and not self._read_fut.done():
-            # TODO: This is an ugly hack to cancel the current read future
-            # *and* avoid potential race conditions, as read cancellation
-            # goes through `future.cancel()` and `loop.call_soon()`.
-            # We then use this special attribute in the reader callback to
-            # exit *immediately* without doing any cleanup/rescheduling.
-            self._read_fut.__asyncio_cancelled_on_pause__ = True
-
-            self._read_fut.cancel()
-            self._read_fut = None
-            self._reschedule_on_resume = True
+        # bpo-33694: Don't cancel self._read_fut because cancelling an
+        # overlapped WSARecv() silently loses data with the current proactor
+        # implementation.
+        #
+        # If CancelIoEx() fails with ERROR_NOT_FOUND, it means that WSARecv()
+        # completed (even if HasOverlappedIoCompleted() returns 0), but
+        # Overlapped.cancel() currently silently ignores the ERROR_NOT_FOUND
+        # error. Once the overlapped is ignored, the IOCP loop ignores the
+        # completion I/O event and so never reads the result of the overlapped
+        # WSARecv().
 
         if self._loop.get_debug():
             logger.debug("%r pauses reading", self)
@@ -206,14 +191,22 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
     def resume_reading(self):
         if self._closing or not self._paused:
             return
 
         self._paused = False
-        if self._reschedule_on_resume:
-            self._loop.call_soon(self._loop_reading, self._read_fut)
-            self._reschedule_on_resume = False
+        if self._read_fut is None:
+            self._loop.call_soon(self._loop_reading, None)
+
+        data = self._pending_data
+        self._pending_data = None
+        if data is not None:
+            # Call the protocol method after calling _loop_reading(),
+            # since the protocol can decide to pause reading again.
+            self._loop.call_soon(self._data_received, data)
 
         if self._loop.get_debug():
             logger.debug("%r resumes reading", self)
 
-    def _loop_reading__on_eof(self):
+    def _eof_received(self):
         if self._loop.get_debug():
             logger.debug("%r received EOF", self)
@@ -227,18 +220,30 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
         if not keep_open:
             self.close()
 
-    def _loop_reading(self, fut=None):
-        self._loop_reading_cb(fut)
-
-    def _loop_reading__data_received(self, fut):
-        if (fut is not None and
-                getattr(fut, '__asyncio_cancelled_on_pause__', False)):
-            return
-
-        if self._paused:
-            self._reschedule_on_resume = True
-            return
-
+    def _data_received(self, data):
+        if self._paused:
+            # Don't call any protocol method while reading is paused.
+            # The protocol will be called on resume_reading().
+            assert self._pending_data is None
+            self._pending_data = data
+            return
+
+        if not data:
+            self._eof_received()
+            return
+
+        if isinstance(self._protocol, protocols.BufferedProtocol):
+            try:
+                protocols._feed_data_to_bufferred_proto(self._protocol, data)
+            except Exception as exc:
+                self._fatal_error(exc,
+                                  'Fatal error: protocol.buffer_updated() '
+                                  'call failed.')
+                return
+        else:
+            self._protocol.data_received(data)
+
+    def _loop_reading(self, fut=None):
         data = None
         try:
             if fut is not None:
@@ -261,8 +266,12 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
                 # we got end-of-file so no need to reschedule a new read
                 return
 
-            # reschedule a new read
-            self._read_fut = self._loop._proactor.recv(self._sock, 32768)
+            # bpo-33694: buffer_updated() has currently no fast path because of
+            # a data loss issue caused by overlapped WSARecv() cancellation.
+            if not self._paused:
+                # reschedule a new read
+                self._read_fut = self._loop._proactor.recv(self._sock, 32768)
         except ConnectionAbortedError as exc:
             if not self._closing:
                 self._fatal_error(exc, 'Fatal read error on pipe transport')
@@ -277,92 +286,11 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
             if not self._closing:
                 raise
         else:
-            self._read_fut.add_done_callback(self._loop_reading__data_received)
+            if not self._paused:
+                self._read_fut.add_done_callback(self._loop_reading)
         finally:
-            if data:
-                self._protocol.data_received(data)
-            elif data == b'':
-                self._loop_reading__on_eof()
+            if data is not None:
+                self._data_received(data)
 
-    def _loop_reading__get_buffer(self, fut):
-        if (fut is not None and
-                getattr(fut, '__asyncio_cancelled_on_pause__', False)):
-            return
-
-        if self._paused:
-            self._reschedule_on_resume = True
-            return
-
-        nbytes = None
-        if fut is not None:
-            assert self._read_fut is fut or (self._read_fut is None and
-                                             self._closing)
-            self._read_fut = None
-
-            try:
-                if fut.done():
-                    nbytes = fut.result()
-                else:
-                    # the future will be replaced by next proactor.recv call
-                    fut.cancel()
-            except ConnectionAbortedError as exc:
-                if not self._closing:
-                    self._fatal_error(
-                        exc, 'Fatal read error on pipe transport')
-                elif self._loop.get_debug():
-                    logger.debug("Read error on pipe transport while closing",
-                                 exc_info=True)
-            except ConnectionResetError as exc:
-                self._force_close(exc)
-            except OSError as exc:
-                self._fatal_error(exc, 'Fatal read error on pipe transport')
-            except futures.CancelledError:
-                if not self._closing:
-                    raise
-
-        if nbytes is not None:
-            if nbytes == 0:
-                # we got end-of-file so no need to reschedule a new read
-                self._loop_reading__on_eof()
-            else:
-                try:
-                    self._protocol.buffer_updated(nbytes)
-                except Exception as exc:
-                    self._fatal_error(
-                        exc,
-                        'Fatal error: '
-                        'protocol.buffer_updated() call failed.')
-                    return
-
-        if self._closing or nbytes == 0:
-            # since close() has been called we ignore any read data
-            return
-
-        try:
-            buf = self._protocol.get_buffer(-1)
-            if not len(buf):
-                raise RuntimeError('get_buffer() returned an empty buffer')
-        except Exception as exc:
-            self._fatal_error(
-                exc, 'Fatal error: protocol.get_buffer() call failed.')
-            return
-
-        try:
-            # schedule a new read
-            self._read_fut = self._loop._proactor.recv_into(self._sock, buf)
-            self._read_fut.add_done_callback(self._loop_reading__get_buffer)
-        except ConnectionAbortedError as exc:
-            if not self._closing:
-                self._fatal_error(exc, 'Fatal read error on pipe transport')
-            elif self._loop.get_debug():
-                logger.debug("Read error on pipe transport while closing",
-                             exc_info=True)
-        except ConnectionResetError as exc:
-            self._force_close(exc)
-        except OSError as exc:
-            self._fatal_error(exc, 'Fatal read error on pipe transport')
-        except futures.CancelledError:
-            if not self._closing:
-                raise
 
 
 class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport,
Lib/asyncio/protocols.py
@@ -189,3 +189,22 @@ class SubprocessProtocol(BaseProtocol):
 
     def process_exited(self):
        """Called when subprocess has exited."""
+
+
+def _feed_data_to_bufferred_proto(proto, data):
+    data_len = len(data)
+    while data_len:
+        buf = proto.get_buffer(data_len)
+        buf_len = len(buf)
+        if not buf_len:
+            raise RuntimeError('get_buffer() returned an empty buffer')
+
+        if buf_len >= data_len:
+            buf[:data_len] = data
+            proto.buffer_updated(data_len)
+            return
+        else:
+            buf[:buf_len] = data[:buf_len]
+            proto.buffer_updated(buf_len)
+            data = data[buf_len:]
+            data_len = len(data)
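
For reference, a small self-contained sketch of how this helper drives a BufferedProtocol: get_buffer() may hand back a buffer smaller than the remaining data, and the helper keeps looping until everything is delivered. The helper body is repeated locally so the snippet runs on its own; TinyBufferedProto is a made-up example protocol, not part of asyncio:

from asyncio import BufferedProtocol


def _feed_data_to_bufferred_proto(proto, data):
    # Same logic as the helper added to protocols.py above.
    data_len = len(data)
    while data_len:
        buf = proto.get_buffer(data_len)
        buf_len = len(buf)
        if not buf_len:
            raise RuntimeError('get_buffer() returned an empty buffer')
        if buf_len >= data_len:
            buf[:data_len] = data
            proto.buffer_updated(data_len)
            return
        else:
            buf[:buf_len] = data[:buf_len]
            proto.buffer_updated(buf_len)
            data = data[buf_len:]
            data_len = len(data)


class TinyBufferedProto(BufferedProtocol):
    """Collects incoming bytes through a deliberately tiny 2-byte buffer."""

    def __init__(self):
        self._scratch = bytearray(2)
        self.received = bytearray()

    def get_buffer(self, sizehint):
        # Returning a buffer smaller than sizehint is fine; the helper
        # keeps looping until the whole chunk has been delivered.
        return self._scratch

    def buffer_updated(self, nbytes):
        self.received += self._scratch[:nbytes]


proto = TinyBufferedProto()
_feed_data_to_bufferred_proto(proto, b'12345')
assert bytes(proto.received) == b'12345'
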
Lib/asyncio/sslproto.py
@@ -535,7 +535,7 @@ class SSLProtocol(protocols.Protocol):
             if chunk:
                 try:
                     if self._app_protocol_is_buffer:
-                        _feed_data_to_bufferred_proto(
+                        protocols._feed_data_to_bufferred_proto(
                             self._app_protocol, chunk)
                     else:
                         self._app_protocol.data_received(chunk)
@@ -721,22 +721,3 @@ class SSLProtocol(protocols.Protocol):
                 self._transport.abort()
             finally:
                 self._finalize()
-
-
-def _feed_data_to_bufferred_proto(proto, data):
-    data_len = len(data)
-    while data_len:
-        buf = proto.get_buffer(data_len)
-        buf_len = len(buf)
-        if not buf_len:
-            raise RuntimeError('get_buffer() returned an empty buffer')
-
-        if buf_len >= data_len:
-            buf[:data_len] = data
-            proto.buffer_updated(data_len)
-            return
-        else:
-            buf[:buf_len] = data[:buf_len]
-            proto.buffer_updated(buf_len)
-            data = data[buf_len:]
-            data_len = len(data)
Lib/test/test_asyncio/test_proactor_events.py
@@ -459,6 +459,8 @@ class ProactorSocketTransportTests(test_utils.TestCase):
         self.assertFalse(self.protocol.pause_writing.called)
 
 
+@unittest.skip('FIXME: bpo-33694: these tests are too close '
+               'to the implementation and should be refactored or removed')
 class ProactorSocketTransportBufferedProtoTests(test_utils.TestCase):
 
     def setUp(self):
@@ -551,6 +553,8 @@ class ProactorSocketTransportBufferedProtoTests(test_utils.TestCase):
         self.loop._proactor.recv_into.assert_called_with(self.sock, buf)
         buf_proto.buffer_updated.assert_called_with(4)
 
+    @unittest.skip('FIXME: bpo-33694: this test is too close to the '
+                   'implementation and should be refactored or removed')
     def test_proto_buf_switch(self):
         tr = self.socket_transport()
         test_utils.run_briefly(self.loop)
Lib/test/test_asyncio/test_sslproto.py
@@ -11,6 +11,7 @@ except ImportError:
 
 import asyncio
 from asyncio import log
+from asyncio import protocols
 from asyncio import sslproto
 from asyncio import tasks
 from test.test_asyncio import utils as test_utils
@@ -189,28 +190,28 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
         for usemv in [False, True]:
             proto = Proto(1, usemv)
-            sslproto._feed_data_to_bufferred_proto(proto, b'12345')
+            protocols._feed_data_to_bufferred_proto(proto, b'12345')
             self.assertEqual(proto.data, b'12345')
 
             proto = Proto(2, usemv)
-            sslproto._feed_data_to_bufferred_proto(proto, b'12345')
+            protocols._feed_data_to_bufferred_proto(proto, b'12345')
             self.assertEqual(proto.data, b'12345')
 
             proto = Proto(2, usemv)
-            sslproto._feed_data_to_bufferred_proto(proto, b'1234')
+            protocols._feed_data_to_bufferred_proto(proto, b'1234')
             self.assertEqual(proto.data, b'1234')
 
             proto = Proto(4, usemv)
-            sslproto._feed_data_to_bufferred_proto(proto, b'1234')
+            protocols._feed_data_to_bufferred_proto(proto, b'1234')
             self.assertEqual(proto.data, b'1234')
 
             proto = Proto(100, usemv)
-            sslproto._feed_data_to_bufferred_proto(proto, b'12345')
+            protocols._feed_data_to_bufferred_proto(proto, b'12345')
             self.assertEqual(proto.data, b'12345')
 
             proto = Proto(0, usemv)
             with self.assertRaisesRegex(RuntimeError, 'empty buffer'):
-                sslproto._feed_data_to_bufferred_proto(proto, b'12345')
+                protocols._feed_data_to_bufferred_proto(proto, b'12345')
 
     def test_start_tls_client_reg_proto_1(self):
         HELLO_MSG = b'1' * self.PAYLOAD_SIZE
NEWS entry:
asyncio: Fix a race condition causing data loss on
pause_reading()/resume_reading() when using the ProactorEventLoop.