Commit dbf10227 authored by Yury Selivanov's avatar Yury Selivanov Committed by GitHub

bpo-33654: Support BufferedProtocol in set_protocol() and start_tls() (GH-7130)

In this commit:

* Support BufferedProtocol in set_protocol() and start_tls()
* Fix proactor to cancel readers reliably
* Update tests to be compatible with OpenSSL 1.1.1
* Clarify BufferedProtocol docs
* Bump TLS tests timeouts to 60 seconds; eliminate possible race from start_serving
* Rewrite test_start_tls_server_1
parent e549c4be
...@@ -463,16 +463,23 @@ The idea of BufferedProtocol is that it allows to manually allocate ...@@ -463,16 +463,23 @@ The idea of BufferedProtocol is that it allows to manually allocate
and control the receive buffer. Event loops can then use the buffer and control the receive buffer. Event loops can then use the buffer
provided by the protocol to avoid unnecessary data copies. This provided by the protocol to avoid unnecessary data copies. This
can result in noticeable performance improvement for protocols that can result in noticeable performance improvement for protocols that
receive big amounts of data. Sophisticated protocols can allocate receive big amounts of data. Sophisticated protocols implementations
the buffer only once at creation time. can allocate the buffer only once at creation time.
The following callbacks are called on :class:`BufferedProtocol` The following callbacks are called on :class:`BufferedProtocol`
instances: instances:
.. method:: BufferedProtocol.get_buffer() .. method:: BufferedProtocol.get_buffer(sizehint)
Called to allocate a new receive buffer. Must return an object Called to allocate a new receive buffer.
that implements the :ref:`buffer protocol <bufferobjects>`.
*sizehint* is a recommended minimal size for the returned
buffer. It is acceptable to return smaller or bigger buffers
than what *sizehint* suggests. When set to -1, the buffer size
can be arbitrary. It is an error to return a zero-sized buffer.
Must return an object that implements the
:ref:`buffer protocol <bufferobjects>`.
.. method:: BufferedProtocol.buffer_updated(nbytes) .. method:: BufferedProtocol.buffer_updated(nbytes)
......
...@@ -157,7 +157,6 @@ def _run_until_complete_cb(fut): ...@@ -157,7 +157,6 @@ def _run_until_complete_cb(fut):
futures._get_loop(fut).stop() futures._get_loop(fut).stop()
class _SendfileFallbackProtocol(protocols.Protocol): class _SendfileFallbackProtocol(protocols.Protocol):
def __init__(self, transp): def __init__(self, transp):
if not isinstance(transp, transports._FlowControlMixin): if not isinstance(transp, transports._FlowControlMixin):
...@@ -304,6 +303,9 @@ class Server(events.AbstractServer): ...@@ -304,6 +303,9 @@ class Server(events.AbstractServer):
async def start_serving(self): async def start_serving(self):
self._start_serving() self._start_serving()
# Skip one loop iteration so that all 'loop.add_reader'
# go through.
await tasks.sleep(0, loop=self._loop)
async def serve_forever(self): async def serve_forever(self):
if self._serving_forever_fut is not None: if self._serving_forever_fut is not None:
...@@ -1363,6 +1365,9 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -1363,6 +1365,9 @@ class BaseEventLoop(events.AbstractEventLoop):
ssl, backlog, ssl_handshake_timeout) ssl, backlog, ssl_handshake_timeout)
if start_serving: if start_serving:
server._start_serving() server._start_serving()
# Skip one loop iteration so that all 'loop.add_reader'
# go through.
await tasks.sleep(0, loop=self)
if self._debug: if self._debug:
logger.info("%r is serving", server) logger.info("%r is serving", server)
......
...@@ -30,7 +30,7 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin, ...@@ -30,7 +30,7 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin,
super().__init__(extra, loop) super().__init__(extra, loop)
self._set_extra(sock) self._set_extra(sock)
self._sock = sock self._sock = sock
self._protocol = protocol self.set_protocol(protocol)
self._server = server self._server = server
self._buffer = None # None or bytearray. self._buffer = None # None or bytearray.
self._read_fut = None self._read_fut = None
...@@ -159,16 +159,26 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport, ...@@ -159,16 +159,26 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
def __init__(self, loop, sock, protocol, waiter=None, def __init__(self, loop, sock, protocol, waiter=None,
extra=None, server=None): extra=None, server=None):
self._loop_reading_cb = None
self._paused = True
super().__init__(loop, sock, protocol, waiter, extra, server) super().__init__(loop, sock, protocol, waiter, extra, server)
self._paused = False
self._reschedule_on_resume = False self._reschedule_on_resume = False
self._loop.call_soon(self._loop_reading)
self._paused = False
if protocols._is_buffered_protocol(protocol): def set_protocol(self, protocol):
self._loop_reading = self._loop_reading__get_buffer if isinstance(protocol, protocols.BufferedProtocol):
self._loop_reading_cb = self._loop_reading__get_buffer
else: else:
self._loop_reading = self._loop_reading__data_received self._loop_reading_cb = self._loop_reading__data_received
self._loop.call_soon(self._loop_reading) super().set_protocol(protocol)
if self.is_reading():
# reset reading callback / buffers / self._read_fut
self.pause_reading()
self.resume_reading()
def is_reading(self): def is_reading(self):
return not self._paused and not self._closing return not self._paused and not self._closing
...@@ -179,6 +189,13 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport, ...@@ -179,6 +189,13 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
self._paused = True self._paused = True
if self._read_fut is not None and not self._read_fut.done(): if self._read_fut is not None and not self._read_fut.done():
# TODO: This is an ugly hack to cancel the current read future
# *and* avoid potential race conditions, as read cancellation
# goes through `future.cancel()` and `loop.call_soon()`.
# We then use this special attribute in the reader callback to
# exit *immediately* without doing any cleanup/rescheduling.
self._read_fut.__asyncio_cancelled_on_pause__ = True
self._read_fut.cancel() self._read_fut.cancel()
self._read_fut = None self._read_fut = None
self._reschedule_on_resume = True self._reschedule_on_resume = True
...@@ -210,7 +227,14 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport, ...@@ -210,7 +227,14 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
if not keep_open: if not keep_open:
self.close() self.close()
def _loop_reading__data_received(self, fut=None): def _loop_reading(self, fut=None):
self._loop_reading_cb(fut)
def _loop_reading__data_received(self, fut):
if (fut is not None and
getattr(fut, '__asyncio_cancelled_on_pause__', False)):
return
if self._paused: if self._paused:
self._reschedule_on_resume = True self._reschedule_on_resume = True
return return
...@@ -253,14 +277,18 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport, ...@@ -253,14 +277,18 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
if not self._closing: if not self._closing:
raise raise
else: else:
self._read_fut.add_done_callback(self._loop_reading) self._read_fut.add_done_callback(self._loop_reading__data_received)
finally: finally:
if data: if data:
self._protocol.data_received(data) self._protocol.data_received(data)
elif data == b'': elif data == b'':
self._loop_reading__on_eof() self._loop_reading__on_eof()
def _loop_reading__get_buffer(self, fut=None): def _loop_reading__get_buffer(self, fut):
if (fut is not None and
getattr(fut, '__asyncio_cancelled_on_pause__', False)):
return
if self._paused: if self._paused:
self._reschedule_on_resume = True self._reschedule_on_resume = True
return return
...@@ -310,7 +338,9 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport, ...@@ -310,7 +338,9 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
return return
try: try:
buf = self._protocol.get_buffer() buf = self._protocol.get_buffer(-1)
if not len(buf):
raise RuntimeError('get_buffer() returned an empty buffer')
except Exception as exc: except Exception as exc:
self._fatal_error( self._fatal_error(
exc, 'Fatal error: protocol.get_buffer() call failed.') exc, 'Fatal error: protocol.get_buffer() call failed.')
...@@ -319,7 +349,7 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport, ...@@ -319,7 +349,7 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
try: try:
# schedule a new read # schedule a new read
self._read_fut = self._loop._proactor.recv_into(self._sock, buf) self._read_fut = self._loop._proactor.recv_into(self._sock, buf)
self._read_fut.add_done_callback(self._loop_reading) self._read_fut.add_done_callback(self._loop_reading__get_buffer)
except ConnectionAbortedError as exc: except ConnectionAbortedError as exc:
if not self._closing: if not self._closing:
self._fatal_error(exc, 'Fatal read error on pipe transport') self._fatal_error(exc, 'Fatal read error on pipe transport')
......
...@@ -130,11 +130,15 @@ class BufferedProtocol(BaseProtocol): ...@@ -130,11 +130,15 @@ class BufferedProtocol(BaseProtocol):
* CL: connection_lost() * CL: connection_lost()
""" """
def get_buffer(self): def get_buffer(self, sizehint):
"""Called to allocate a new receive buffer. """Called to allocate a new receive buffer.
*sizehint* is a recommended minimal size for the returned
buffer. When set to -1, the buffer size can be arbitrary.
Must return an object that implements the Must return an object that implements the
:ref:`buffer protocol <bufferobjects>`. :ref:`buffer protocol <bufferobjects>`.
It is an error to return a zero-sized buffer.
""" """
def buffer_updated(self, nbytes): def buffer_updated(self, nbytes):
...@@ -185,7 +189,3 @@ class SubprocessProtocol(BaseProtocol): ...@@ -185,7 +189,3 @@ class SubprocessProtocol(BaseProtocol):
def process_exited(self): def process_exited(self):
"""Called when subprocess has exited.""" """Called when subprocess has exited."""
def _is_buffered_protocol(proto):
return hasattr(proto, 'get_buffer') and not hasattr(proto, 'data_received')
...@@ -597,8 +597,10 @@ class _SelectorTransport(transports._FlowControlMixin, ...@@ -597,8 +597,10 @@ class _SelectorTransport(transports._FlowControlMixin,
self._extra['peername'] = None self._extra['peername'] = None
self._sock = sock self._sock = sock
self._sock_fd = sock.fileno() self._sock_fd = sock.fileno()
self._protocol = protocol
self._protocol_connected = True self._protocol_connected = False
self.set_protocol(protocol)
self._server = server self._server = server
self._buffer = self._buffer_factory() self._buffer = self._buffer_factory()
self._conn_lost = 0 # Set when call to connection_lost scheduled. self._conn_lost = 0 # Set when call to connection_lost scheduled.
...@@ -640,6 +642,7 @@ class _SelectorTransport(transports._FlowControlMixin, ...@@ -640,6 +642,7 @@ class _SelectorTransport(transports._FlowControlMixin,
def set_protocol(self, protocol): def set_protocol(self, protocol):
self._protocol = protocol self._protocol = protocol
self._protocol_connected = True
def get_protocol(self): def get_protocol(self):
return self._protocol return self._protocol
...@@ -721,11 +724,7 @@ class _SelectorSocketTransport(_SelectorTransport): ...@@ -721,11 +724,7 @@ class _SelectorSocketTransport(_SelectorTransport):
def __init__(self, loop, sock, protocol, waiter=None, def __init__(self, loop, sock, protocol, waiter=None,
extra=None, server=None): extra=None, server=None):
if protocols._is_buffered_protocol(protocol): self._read_ready_cb = None
self._read_ready = self._read_ready__get_buffer
else:
self._read_ready = self._read_ready__data_received
super().__init__(loop, sock, protocol, extra, server) super().__init__(loop, sock, protocol, extra, server)
self._eof = False self._eof = False
self._paused = False self._paused = False
...@@ -745,6 +744,14 @@ class _SelectorSocketTransport(_SelectorTransport): ...@@ -745,6 +744,14 @@ class _SelectorSocketTransport(_SelectorTransport):
self._loop.call_soon(futures._set_result_unless_cancelled, self._loop.call_soon(futures._set_result_unless_cancelled,
waiter, None) waiter, None)
def set_protocol(self, protocol):
if isinstance(protocol, protocols.BufferedProtocol):
self._read_ready_cb = self._read_ready__get_buffer
else:
self._read_ready_cb = self._read_ready__data_received
super().set_protocol(protocol)
def is_reading(self): def is_reading(self):
return not self._paused and not self._closing return not self._paused and not self._closing
...@@ -764,12 +771,17 @@ class _SelectorSocketTransport(_SelectorTransport): ...@@ -764,12 +771,17 @@ class _SelectorSocketTransport(_SelectorTransport):
if self._loop.get_debug(): if self._loop.get_debug():
logger.debug("%r resumes reading", self) logger.debug("%r resumes reading", self)
def _read_ready(self):
self._read_ready_cb()
def _read_ready__get_buffer(self): def _read_ready__get_buffer(self):
if self._conn_lost: if self._conn_lost:
return return
try: try:
buf = self._protocol.get_buffer() buf = self._protocol.get_buffer(-1)
if not len(buf):
raise RuntimeError('get_buffer() returned an empty buffer')
except Exception as exc: except Exception as exc:
self._fatal_error( self._fatal_error(
exc, 'Fatal error: protocol.get_buffer() call failed.') exc, 'Fatal error: protocol.get_buffer() call failed.')
......
...@@ -441,6 +441,8 @@ class SSLProtocol(protocols.Protocol): ...@@ -441,6 +441,8 @@ class SSLProtocol(protocols.Protocol):
self._waiter = waiter self._waiter = waiter
self._loop = loop self._loop = loop
self._app_protocol = app_protocol self._app_protocol = app_protocol
self._app_protocol_is_buffer = \
isinstance(app_protocol, protocols.BufferedProtocol)
self._app_transport = _SSLProtocolTransport(self._loop, self) self._app_transport = _SSLProtocolTransport(self._loop, self)
# _SSLPipe instance (None until the connection is made) # _SSLPipe instance (None until the connection is made)
self._sslpipe = None self._sslpipe = None
...@@ -522,7 +524,16 @@ class SSLProtocol(protocols.Protocol): ...@@ -522,7 +524,16 @@ class SSLProtocol(protocols.Protocol):
for chunk in appdata: for chunk in appdata:
if chunk: if chunk:
self._app_protocol.data_received(chunk) try:
if self._app_protocol_is_buffer:
_feed_data_to_bufferred_proto(
self._app_protocol, chunk)
else:
self._app_protocol.data_received(chunk)
except Exception as ex:
self._fatal_error(
ex, 'application protocol failed to receive SSL data')
return
else: else:
self._start_shutdown() self._start_shutdown()
break break
...@@ -709,3 +720,22 @@ class SSLProtocol(protocols.Protocol): ...@@ -709,3 +720,22 @@ class SSLProtocol(protocols.Protocol):
self._transport.abort() self._transport.abort()
finally: finally:
self._finalize() self._finalize()
def _feed_data_to_bufferred_proto(proto, data):
data_len = len(data)
while data_len:
buf = proto.get_buffer(data_len)
buf_len = len(buf)
if not buf_len:
raise RuntimeError('get_buffer() returned an empty buffer')
if buf_len >= data_len:
buf[:data_len] = data
proto.buffer_updated(data_len)
return
else:
buf[:buf_len] = data[:buf_len]
proto.buffer_updated(buf_len)
data = data[buf_len:]
data_len = len(data)
...@@ -20,6 +20,7 @@ from . import coroutines ...@@ -20,6 +20,7 @@ from . import coroutines
from . import events from . import events
from . import futures from . import futures
from . import selector_events from . import selector_events
from . import tasks
from . import transports from . import transports
from .log import logger from .log import logger
...@@ -308,6 +309,9 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): ...@@ -308,6 +309,9 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
ssl, backlog, ssl_handshake_timeout) ssl, backlog, ssl_handshake_timeout)
if start_serving: if start_serving:
server._start_serving() server._start_serving()
# Skip one loop iteration so that all 'loop.add_reader'
# go through.
await tasks.sleep(0, loop=self)
return server return server
......
...@@ -9,7 +9,7 @@ class ReceiveStuffProto(asyncio.BufferedProtocol): ...@@ -9,7 +9,7 @@ class ReceiveStuffProto(asyncio.BufferedProtocol):
self.cb = cb self.cb = cb
self.con_lost_fut = con_lost_fut self.con_lost_fut = con_lost_fut
def get_buffer(self): def get_buffer(self, sizehint):
self.buffer = bytearray(100) self.buffer = bytearray(100)
return self.buffer return self.buffer
......
...@@ -2095,7 +2095,7 @@ class SubprocessTestsMixin: ...@@ -2095,7 +2095,7 @@ class SubprocessTestsMixin:
class SendfileBase: class SendfileBase:
DATA = b"12345abcde" * 16 * 1024 # 160 KiB DATA = b"12345abcde" * 64 * 1024 # 64 KiB (don't use smaller sizes)
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
...@@ -2452,7 +2452,7 @@ class SendfileMixin(SendfileBase): ...@@ -2452,7 +2452,7 @@ class SendfileMixin(SendfileBase):
self.assertEqual(srv_proto.data, self.DATA) self.assertEqual(srv_proto.data, self.DATA)
self.assertEqual(self.file.tell(), len(self.DATA)) self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_close_peer_in_middle_of_receiving(self): def test_sendfile_close_peer_in_the_middle_of_receiving(self):
srv_proto, cli_proto = self.prepare_sendfile(close_after=1024) srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
with self.assertRaises(ConnectionError): with self.assertRaises(ConnectionError):
self.run_loop( self.run_loop(
...@@ -2465,7 +2465,7 @@ class SendfileMixin(SendfileBase): ...@@ -2465,7 +2465,7 @@ class SendfileMixin(SendfileBase):
self.file.tell()) self.file.tell())
self.assertTrue(cli_proto.transport.is_closing()) self.assertTrue(cli_proto.transport.is_closing())
def test_sendfile_fallback_close_peer_in_middle_of_receiving(self): def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self):
def sendfile_native(transp, file, offset, count): def sendfile_native(transp, file, offset, count):
# to raise SendfileNotAvailableError # to raise SendfileNotAvailableError
......
...@@ -465,8 +465,8 @@ class ProactorSocketTransportBufferedProtoTests(test_utils.TestCase): ...@@ -465,8 +465,8 @@ class ProactorSocketTransportBufferedProtoTests(test_utils.TestCase):
self.loop._proactor = self.proactor self.loop._proactor = self.proactor
self.protocol = test_utils.make_test_protocol(asyncio.BufferedProtocol) self.protocol = test_utils.make_test_protocol(asyncio.BufferedProtocol)
self.buf = mock.Mock() self.buf = bytearray(1)
self.protocol.get_buffer.side_effect = lambda: self.buf self.protocol.get_buffer.side_effect = lambda hint: self.buf
self.sock = mock.Mock(socket.socket) self.sock = mock.Mock(socket.socket)
...@@ -505,6 +505,64 @@ class ProactorSocketTransportBufferedProtoTests(test_utils.TestCase): ...@@ -505,6 +505,64 @@ class ProactorSocketTransportBufferedProtoTests(test_utils.TestCase):
self.assertTrue(self.protocol.get_buffer.called) self.assertTrue(self.protocol.get_buffer.called)
self.assertFalse(self.protocol.buffer_updated.called) self.assertFalse(self.protocol.buffer_updated.called)
def test_get_buffer_zerosized(self):
transport = self.socket_transport()
transport._fatal_error = mock.Mock()
self.loop.call_exception_handler = mock.Mock()
self.protocol.get_buffer.side_effect = lambda hint: bytearray(0)
transport._loop_reading()
self.assertTrue(transport._fatal_error.called)
self.assertTrue(self.protocol.get_buffer.called)
self.assertFalse(self.protocol.buffer_updated.called)
def test_proto_type_switch(self):
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
tr = self.socket_transport()
res = asyncio.Future(loop=self.loop)
res.set_result(b'data')
tr = self.socket_transport()
tr._read_fut = res
tr._loop_reading(res)
self.loop._proactor.recv.assert_called_with(self.sock, 32768)
self.protocol.data_received.assert_called_with(b'data')
# switch protocol to a BufferedProtocol
buf_proto = test_utils.make_test_protocol(asyncio.BufferedProtocol)
buf = bytearray(4)
buf_proto.get_buffer.side_effect = lambda hint: buf
tr.set_protocol(buf_proto)
test_utils.run_briefly(self.loop)
res = asyncio.Future(loop=self.loop)
res.set_result(4)
tr._read_fut = res
tr._loop_reading(res)
self.loop._proactor.recv_into.assert_called_with(self.sock, buf)
buf_proto.buffer_updated.assert_called_with(4)
def test_proto_buf_switch(self):
tr = self.socket_transport()
test_utils.run_briefly(self.loop)
self.protocol.get_buffer.assert_called_with(-1)
# switch protocol to *another* BufferedProtocol
buf_proto = test_utils.make_test_protocol(asyncio.BufferedProtocol)
buf = bytearray(4)
buf_proto.get_buffer.side_effect = lambda hint: buf
tr._read_fut.done.side_effect = lambda: False
tr.set_protocol(buf_proto)
self.assertFalse(buf_proto.get_buffer.called)
test_utils.run_briefly(self.loop)
buf_proto.get_buffer.assert_called_with(-1)
def test_buffer_updated_error(self): def test_buffer_updated_error(self):
transport = self.socket_transport() transport = self.socket_transport()
transport._fatal_error = mock.Mock() transport._fatal_error = mock.Mock()
......
...@@ -772,7 +772,8 @@ class BaseSelectorEventLoopTests(test_utils.TestCase): ...@@ -772,7 +772,8 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
accept2_mock.return_value = None accept2_mock.return_value = None
with mock_obj(self.loop, 'create_task') as task_mock: with mock_obj(self.loop, 'create_task') as task_mock:
task_mock.return_value = None task_mock.return_value = None
self.loop._accept_connection(mock.Mock(), sock, backlog=backlog) self.loop._accept_connection(
mock.Mock(), sock, backlog=backlog)
self.assertEqual(sock.accept.call_count, backlog) self.assertEqual(sock.accept.call_count, backlog)
...@@ -1285,8 +1286,8 @@ class SelectorSocketTransportBufferedProtocolTests(test_utils.TestCase): ...@@ -1285,8 +1286,8 @@ class SelectorSocketTransportBufferedProtocolTests(test_utils.TestCase):
self.loop = self.new_test_loop() self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.BufferedProtocol) self.protocol = test_utils.make_test_protocol(asyncio.BufferedProtocol)
self.buf = mock.Mock() self.buf = bytearray(1)
self.protocol.get_buffer.side_effect = lambda: self.buf self.protocol.get_buffer.side_effect = lambda hint: self.buf
self.sock = mock.Mock(socket.socket) self.sock = mock.Mock(socket.socket)
self.sock_fd = self.sock.fileno.return_value = 7 self.sock_fd = self.sock.fileno.return_value = 7
...@@ -1319,6 +1320,42 @@ class SelectorSocketTransportBufferedProtocolTests(test_utils.TestCase): ...@@ -1319,6 +1320,42 @@ class SelectorSocketTransportBufferedProtocolTests(test_utils.TestCase):
self.assertTrue(self.protocol.get_buffer.called) self.assertTrue(self.protocol.get_buffer.called)
self.assertFalse(self.protocol.buffer_updated.called) self.assertFalse(self.protocol.buffer_updated.called)
def test_get_buffer_zerosized(self):
transport = self.socket_transport()
transport._fatal_error = mock.Mock()
self.loop.call_exception_handler = mock.Mock()
self.protocol.get_buffer.side_effect = lambda hint: bytearray(0)
transport._read_ready()
self.assertTrue(transport._fatal_error.called)
self.assertTrue(self.protocol.get_buffer.called)
self.assertFalse(self.protocol.buffer_updated.called)
def test_proto_type_switch(self):
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
transport = self.socket_transport()
self.sock.recv.return_value = b'data'
transport._read_ready()
self.protocol.data_received.assert_called_with(b'data')
# switch protocol to a BufferedProtocol
buf_proto = test_utils.make_test_protocol(asyncio.BufferedProtocol)
buf = bytearray(4)
buf_proto.get_buffer.side_effect = lambda hint: buf
transport.set_protocol(buf_proto)
self.sock.recv_into.return_value = 10
transport._read_ready()
buf_proto.get_buffer.assert_called_with(-1)
buf_proto.buffer_updated.assert_called_with(10)
def test_buffer_updated_error(self): def test_buffer_updated_error(self):
transport = self.socket_transport() transport = self.socket_transport()
transport._fatal_error = mock.Mock() transport._fatal_error = mock.Mock()
...@@ -1354,7 +1391,7 @@ class SelectorSocketTransportBufferedProtocolTests(test_utils.TestCase): ...@@ -1354,7 +1391,7 @@ class SelectorSocketTransportBufferedProtocolTests(test_utils.TestCase):
self.sock.recv_into.return_value = 10 self.sock.recv_into.return_value = 10
transport._read_ready() transport._read_ready()
self.protocol.get_buffer.assert_called_with() self.protocol.get_buffer.assert_called_with(-1)
self.protocol.buffer_updated.assert_called_with(10) self.protocol.buffer_updated.assert_called_with(10)
def test_read_ready_eof(self): def test_read_ready_eof(self):
......
"""Tests for asyncio/sslproto.py.""" """Tests for asyncio/sslproto.py."""
import os
import logging import logging
import time import socket
import unittest import unittest
from unittest import mock from unittest import mock
try: try:
...@@ -185,17 +184,67 @@ class SslProtoHandshakeTests(test_utils.TestCase): ...@@ -185,17 +184,67 @@ class SslProtoHandshakeTests(test_utils.TestCase):
class BaseStartTLS(func_tests.FunctionalTestCaseMixin): class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
PAYLOAD_SIZE = 1024 * 100
TIMEOUT = 60
def new_loop(self): def new_loop(self):
raise NotImplementedError raise NotImplementedError
def test_start_tls_client_1(self): def test_buf_feed_data(self):
HELLO_MSG = b'1' * 1024 * 1024
class Proto(asyncio.BufferedProtocol):
def __init__(self, bufsize, usemv):
self.buf = bytearray(bufsize)
self.mv = memoryview(self.buf)
self.data = b''
self.usemv = usemv
def get_buffer(self, sizehint):
if self.usemv:
return self.mv
else:
return self.buf
def buffer_updated(self, nsize):
if self.usemv:
self.data += self.mv[:nsize]
else:
self.data += self.buf[:nsize]
for usemv in [False, True]:
proto = Proto(1, usemv)
sslproto._feed_data_to_bufferred_proto(proto, b'12345')
self.assertEqual(proto.data, b'12345')
proto = Proto(2, usemv)
sslproto._feed_data_to_bufferred_proto(proto, b'12345')
self.assertEqual(proto.data, b'12345')
proto = Proto(2, usemv)
sslproto._feed_data_to_bufferred_proto(proto, b'1234')
self.assertEqual(proto.data, b'1234')
proto = Proto(4, usemv)
sslproto._feed_data_to_bufferred_proto(proto, b'1234')
self.assertEqual(proto.data, b'1234')
proto = Proto(100, usemv)
sslproto._feed_data_to_bufferred_proto(proto, b'12345')
self.assertEqual(proto.data, b'12345')
proto = Proto(0, usemv)
with self.assertRaisesRegex(RuntimeError, 'empty buffer'):
sslproto._feed_data_to_bufferred_proto(proto, b'12345')
def test_start_tls_client_reg_proto_1(self):
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
server_context = test_utils.simple_server_sslcontext() server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext() client_context = test_utils.simple_client_sslcontext()
def serve(sock): def serve(sock):
sock.settimeout(5) sock.settimeout(self.TIMEOUT)
data = sock.recv_all(len(HELLO_MSG)) data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG))
...@@ -205,6 +254,8 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin): ...@@ -205,6 +254,8 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
sock.sendall(b'O') sock.sendall(b'O')
data = sock.recv_all(len(HELLO_MSG)) data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG))
sock.shutdown(socket.SHUT_RDWR)
sock.close() sock.close()
class ClientProto(asyncio.Protocol): class ClientProto(asyncio.Protocol):
...@@ -246,17 +297,80 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin): ...@@ -246,17 +297,80 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
self.loop.run_until_complete( self.loop.run_until_complete(
asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10)) asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10))
def test_start_tls_client_buf_proto_1(self):
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext()
def serve(sock):
sock.settimeout(self.TIMEOUT)
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
sock.start_tls(server_context, server_side=True)
sock.sendall(b'O')
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
sock.shutdown(socket.SHUT_RDWR)
sock.close()
class ClientProto(asyncio.BufferedProtocol):
def __init__(self, on_data, on_eof):
self.on_data = on_data
self.on_eof = on_eof
self.con_made_cnt = 0
self.buf = bytearray(1)
def connection_made(proto, tr):
proto.con_made_cnt += 1
# Ensure connection_made gets called only once.
self.assertEqual(proto.con_made_cnt, 1)
def get_buffer(self, sizehint):
return self.buf
def buffer_updated(self, nsize):
assert nsize == 1
self.on_data.set_result(bytes(self.buf[:nsize]))
def eof_received(self):
self.on_eof.set_result(True)
async def client(addr):
await asyncio.sleep(0.5, loop=self.loop)
on_data = self.loop.create_future()
on_eof = self.loop.create_future()
tr, proto = await self.loop.create_connection(
lambda: ClientProto(on_data, on_eof), *addr)
tr.write(HELLO_MSG)
new_tr = await self.loop.start_tls(tr, proto, client_context)
self.assertEqual(await on_data, b'O')
new_tr.write(HELLO_MSG)
await on_eof
new_tr.close()
with self.tcp_server(serve) as srv:
self.loop.run_until_complete(
asyncio.wait_for(client(srv.addr),
loop=self.loop, timeout=self.TIMEOUT))
def test_start_tls_server_1(self): def test_start_tls_server_1(self):
HELLO_MSG = b'1' * 1024 * 1024 HELLO_MSG = b'1' * self.PAYLOAD_SIZE
server_context = test_utils.simple_server_sslcontext() server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext() client_context = test_utils.simple_client_sslcontext()
# TODO: fix TLSv1.3 support
client_context.options |= ssl.OP_NO_TLSv1_3
def client(sock, addr): def client(sock, addr):
time.sleep(0.5) sock.settimeout(self.TIMEOUT)
sock.settimeout(5)
sock.connect(addr) sock.connect(addr)
data = sock.recv_all(len(HELLO_MSG)) data = sock.recv_all(len(HELLO_MSG))
...@@ -264,12 +378,15 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin): ...@@ -264,12 +378,15 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
sock.start_tls(client_context) sock.start_tls(client_context)
sock.sendall(HELLO_MSG) sock.sendall(HELLO_MSG)
sock.shutdown(socket.SHUT_RDWR)
sock.close() sock.close()
class ServerProto(asyncio.Protocol): class ServerProto(asyncio.Protocol):
def __init__(self, on_con, on_eof): def __init__(self, on_con, on_eof, on_con_lost):
self.on_con = on_con self.on_con = on_con
self.on_eof = on_eof self.on_eof = on_eof
self.on_con_lost = on_con_lost
self.data = b'' self.data = b''
def connection_made(self, tr): def connection_made(self, tr):
...@@ -281,7 +398,13 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin): ...@@ -281,7 +398,13 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
def eof_received(self): def eof_received(self):
self.on_eof.set_result(1) self.on_eof.set_result(1)
async def main(): def connection_lost(self, exc):
if exc is None:
self.on_con_lost.set_result(None)
else:
self.on_con_lost.set_exception(exc)
async def main(proto, on_con, on_eof, on_con_lost):
tr = await on_con tr = await on_con
tr.write(HELLO_MSG) tr.write(HELLO_MSG)
...@@ -292,24 +415,29 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin): ...@@ -292,24 +415,29 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
server_side=True) server_side=True)
await on_eof await on_eof
await on_con_lost
self.assertEqual(proto.data, HELLO_MSG) self.assertEqual(proto.data, HELLO_MSG)
new_tr.close() new_tr.close()
server.close() async def run_main():
await server.wait_closed() on_con = self.loop.create_future()
on_eof = self.loop.create_future()
on_con_lost = self.loop.create_future()
proto = ServerProto(on_con, on_eof, on_con_lost)
on_con = self.loop.create_future() server = await self.loop.create_server(
on_eof = self.loop.create_future() lambda: proto, '127.0.0.1', 0)
proto = ServerProto(on_con, on_eof) addr = server.sockets[0].getsockname()
server = self.loop.run_until_complete( with self.tcp_client(lambda sock: client(sock, addr)):
self.loop.create_server( await asyncio.wait_for(
lambda: proto, '127.0.0.1', 0)) main(proto, on_con, on_eof, on_con_lost),
addr = server.sockets[0].getsockname() loop=self.loop, timeout=self.TIMEOUT)
with self.tcp_client(lambda sock: client(sock, addr)): server.close()
self.loop.run_until_complete( await server.wait_closed()
asyncio.wait_for(main(), loop=self.loop, timeout=10))
self.loop.run_until_complete(run_main())
def test_start_tls_wrong_args(self): def test_start_tls_wrong_args(self):
async def main(): async def main():
...@@ -332,7 +460,6 @@ class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase): ...@@ -332,7 +460,6 @@ class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only') @unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
@unittest.skipIf(os.environ.get('APPVEYOR'), 'XXX: issue 32458')
class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase): class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase):
def new_loop(self): def new_loop(self):
......
Fix transport.set_protocol() to support switching between asyncio.Protocol
and asyncio.BufferedProtocol. Fix loop.start_tls() to work with
asyncio.BufferedProtocols.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment