Commit f7686c1f authored by Neil Aspinall's avatar Neil Aspinall Committed by Andrew Svetlov

bpo-29970: Add timeout for SSL handshake in asyncio

10 seconds by default.
parent 4b965930
......@@ -261,7 +261,7 @@ Tasks
Creating connections
--------------------
.. coroutinemethod:: AbstractEventLoop.create_connection(protocol_factory, host=None, port=None, \*, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None)
.. coroutinemethod:: AbstractEventLoop.create_connection(protocol_factory, host=None, port=None, \*, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None, ssl_handshake_timeout=10.0)
Create a streaming transport connection to a given Internet *host* and
*port*: socket family :py:data:`~socket.AF_INET` or
......@@ -325,6 +325,13 @@ Creating connections
to bind the socket to locally. The *local_host* and *local_port*
are looked up using getaddrinfo(), similarly to *host* and *port*.
* *ssl_handshake_timeout* is (for an SSL connection) the time in seconds
to wait for the SSL handshake to complete before aborting the connection.
.. versionadded:: 3.7
The *ssl_handshake_timeout* parameter.
.. versionchanged:: 3.5
On Windows with :class:`ProactorEventLoop`, SSL/TLS is now supported.
......@@ -386,7 +393,7 @@ Creating connections
:ref:`UDP echo server protocol <asyncio-udp-echo-server-protocol>` examples.
.. coroutinemethod:: AbstractEventLoop.create_unix_connection(protocol_factory, path=None, \*, ssl=None, sock=None, server_hostname=None)
.. coroutinemethod:: AbstractEventLoop.create_unix_connection(protocol_factory, path=None, \*, ssl=None, sock=None, server_hostname=None, ssl_handshake_timeout=10.0)
Create UNIX connection: socket family :py:data:`~socket.AF_UNIX`, socket
type :py:data:`~socket.SOCK_STREAM`. The :py:data:`~socket.AF_UNIX` socket
......@@ -404,6 +411,10 @@ Creating connections
Availability: UNIX.
.. versionadded:: 3.7
The *ssl_handshake_timeout* parameter.
.. versionchanged:: 3.7
The *path* parameter can now be a :class:`~pathlib.Path` object.
......@@ -412,7 +423,7 @@ Creating connections
Creating listening connections
------------------------------
.. coroutinemethod:: AbstractEventLoop.create_server(protocol_factory, host=None, port=None, \*, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None, reuse_port=None)
.. coroutinemethod:: AbstractEventLoop.create_server(protocol_factory, host=None, port=None, \*, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None, reuse_port=None, ssl_handshake_timeout=10.0)
Create a TCP server (socket type :data:`~socket.SOCK_STREAM`) bound to
*host* and *port*.
......@@ -456,6 +467,13 @@ Creating listening connections
set this flag when being created. This option is not supported on
Windows.
* *ssl_handshake_timeout* is (for an SSL server) the time in seconds to wait
for the SSL handshake to complete before aborting the connection.
.. versionadded:: 3.7
The *ssl_handshake_timeout* parameter.
.. versionchanged:: 3.5
On Windows with :class:`ProactorEventLoop`, SSL/TLS is now supported.
......@@ -470,7 +488,7 @@ Creating listening connections
The *host* parameter can now be a sequence of strings.
.. coroutinemethod:: AbstractEventLoop.create_unix_server(protocol_factory, path=None, \*, sock=None, backlog=100, ssl=None)
.. coroutinemethod:: AbstractEventLoop.create_unix_server(protocol_factory, path=None, \*, sock=None, backlog=100, ssl=None, ssl_handshake_timeout=10.0)
Similar to :meth:`AbstractEventLoop.create_server`, but specific to the
socket family :py:data:`~socket.AF_UNIX`.
......@@ -481,11 +499,15 @@ Creating listening connections
Availability: UNIX.
.. versionadded:: 3.7
The *ssl_handshake_timeout* parameter.
.. versionchanged:: 3.7
The *path* parameter can now be a :class:`~pathlib.Path` object.
.. coroutinemethod:: BaseEventLoop.connect_accepted_socket(protocol_factory, sock, \*, ssl=None)
.. coroutinemethod:: BaseEventLoop.connect_accepted_socket(protocol_factory, sock, \*, ssl=None, ssl_handshake_timeout=10.0)
Handle an accepted connection.
......@@ -500,8 +522,15 @@ Creating listening connections
* *ssl* can be set to an :class:`~ssl.SSLContext` to enable SSL over the
accepted connections.
* *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to
wait for the SSL handshake to complete before aborting the connection.
When completed it returns a ``(transport, protocol)`` pair.
.. versionadded:: 3.7
The *ssl_handshake_timeout* parameter.
.. versionadded:: 3.5.3
......
......@@ -29,6 +29,7 @@ import sys
import warnings
import weakref
from . import constants
from . import coroutines
from . import events
from . import futures
......@@ -275,9 +276,11 @@ class BaseEventLoop(events.AbstractEventLoop):
"""Create socket transport."""
raise NotImplementedError
def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None):
def _make_ssl_transport(
self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
"""Create SSL transport."""
raise NotImplementedError
......@@ -635,10 +638,12 @@ class BaseEventLoop(events.AbstractEventLoop):
return await self.run_in_executor(
None, socket.getnameinfo, sockaddr, flags)
async def create_connection(self, protocol_factory, host=None, port=None,
*, ssl=None, family=0,
proto=0, flags=0, sock=None,
local_addr=None, server_hostname=None):
async def create_connection(
self, protocol_factory, host=None, port=None,
*, ssl=None, family=0,
proto=0, flags=0, sock=None,
local_addr=None, server_hostname=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
"""Connect to a TCP server.
Create a streaming transport connection to a given Internet host and
......@@ -751,7 +756,8 @@ class BaseEventLoop(events.AbstractEventLoop):
f'A Stream Socket was expected, got {sock!r}')
transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname)
sock, protocol_factory, ssl, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
if self._debug:
# Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket
......@@ -760,8 +766,10 @@ class BaseEventLoop(events.AbstractEventLoop):
sock, host, port, transport, protocol)
return transport, protocol
async def _create_connection_transport(self, sock, protocol_factory, ssl,
server_hostname, server_side=False):
async def _create_connection_transport(
self, sock, protocol_factory, ssl,
server_hostname, server_side=False,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
sock.setblocking(False)
......@@ -771,7 +779,8 @@ class BaseEventLoop(events.AbstractEventLoop):
sslcontext = None if isinstance(ssl, bool) else ssl
transport = self._make_ssl_transport(
sock, protocol, sslcontext, waiter,
server_side=server_side, server_hostname=server_hostname)
server_side=server_side, server_hostname=server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
else:
transport = self._make_socket_transport(sock, protocol, waiter)
......@@ -929,15 +938,17 @@ class BaseEventLoop(events.AbstractEventLoop):
raise OSError(f'getaddrinfo({host!r}) returned empty list')
return infos
async def create_server(self, protocol_factory, host=None, port=None,
*,
family=socket.AF_UNSPEC,
flags=socket.AI_PASSIVE,
sock=None,
backlog=100,
ssl=None,
reuse_address=None,
reuse_port=None):
async def create_server(
self, protocol_factory, host=None, port=None,
*,
family=socket.AF_UNSPEC,
flags=socket.AI_PASSIVE,
sock=None,
backlog=100,
ssl=None,
reuse_address=None,
reuse_port=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
"""Create a TCP server.
The host parameter can be a string, in that case the TCP server is
......@@ -1026,13 +1037,16 @@ class BaseEventLoop(events.AbstractEventLoop):
for sock in sockets:
sock.listen(backlog)
sock.setblocking(False)
self._start_serving(protocol_factory, sock, ssl, server, backlog)
self._start_serving(protocol_factory, sock, ssl, server, backlog,
ssl_handshake_timeout)
if self._debug:
logger.info("%r is serving", server)
return server
async def connect_accepted_socket(self, protocol_factory, sock,
*, ssl=None):
async def connect_accepted_socket(
self, protocol_factory, sock,
*, ssl=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
"""Handle an accepted connection.
This is used by servers that accept connections outside of
......@@ -1045,7 +1059,8 @@ class BaseEventLoop(events.AbstractEventLoop):
raise ValueError(f'A Stream Socket was expected, got {sock!r}')
transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, '', server_side=True)
sock, protocol_factory, ssl, '', server_side=True,
ssl_handshake_timeout=ssl_handshake_timeout)
if self._debug:
# Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket
......
......@@ -8,3 +8,6 @@ ACCEPT_RETRY_DELAY = 1
# The larger the number, the slower the operation in debug mode
# (see extract_stack() in format_helpers.py).
DEBUG_STACK_DEPTH = 10
# Number of seconds to wait for SSL handshake to complete
SSL_HANDSHAKE_TIMEOUT = 10.0
......@@ -250,16 +250,20 @@ class AbstractEventLoop:
async def getnameinfo(self, sockaddr, flags=0):
raise NotImplementedError
async def create_connection(self, protocol_factory, host=None, port=None,
*, ssl=None, family=0, proto=0,
flags=0, sock=None, local_addr=None,
server_hostname=None):
raise NotImplementedError
async def create_server(self, protocol_factory, host=None, port=None,
*, family=socket.AF_UNSPEC,
flags=socket.AI_PASSIVE, sock=None, backlog=100,
ssl=None, reuse_address=None, reuse_port=None):
async def create_connection(
self, protocol_factory, host=None, port=None,
*, ssl=None, family=0, proto=0,
flags=0, sock=None, local_addr=None,
server_hostname=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
raise NotImplementedError
async def create_server(
self, protocol_factory, host=None, port=None,
*, family=socket.AF_UNSPEC,
flags=socket.AI_PASSIVE, sock=None, backlog=100,
ssl=None, reuse_address=None, reuse_port=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
"""A coroutine which creates a TCP server bound to host and port.
The return value is a Server object which can be used to stop
......@@ -294,16 +298,25 @@ class AbstractEventLoop:
the same port as other existing endpoints are bound to, so long as
they all set this flag when being created. This option is not
supported on Windows.
ssl_handshake_timeout is the time in seconds that an SSL server
will wait for completion of the SSL handshake before aborting the
connection. Default is 10s, longer timeouts may increase vulnerability
to DoS attacks (see https://support.f5.com/csp/article/K13834)
"""
raise NotImplementedError
async def create_unix_connection(self, protocol_factory, path=None, *,
ssl=None, sock=None,
server_hostname=None):
async def create_unix_connection(
self, protocol_factory, path=None, *,
ssl=None, sock=None,
server_hostname=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
raise NotImplementedError
async def create_unix_server(self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None):
async def create_unix_server(
self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
"""A coroutine which creates a UNIX Domain Socket server.
The return value is a Server object, which can be used to stop
......@@ -320,6 +333,9 @@ class AbstractEventLoop:
ssl can be set to an SSLContext to enable SSL over the
accepted connections.
ssl_handshake_timeout is the time in seconds that an SSL server
will wait for the SSL handshake to complete (defaults to 10s).
"""
raise NotImplementedError
......
......@@ -389,11 +389,15 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
return _ProactorSocketTransport(self, sock, protocol, waiter,
extra, server)
def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None):
ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter,
server_side, server_hostname)
def _make_ssl_transport(
self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
_ProactorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server)
return ssl_protocol._app_transport
......@@ -486,7 +490,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
self._csock.send(b'\0')
def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100):
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
def loop(f=None):
try:
......@@ -499,7 +504,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
if sslcontext is not None:
self._make_ssl_transport(
conn, protocol, sslcontext, server_side=True,
extra={'peername': addr}, server=server)
extra={'peername': addr}, server=server,
ssl_handshake_timeout=ssl_handshake_timeout)
else:
self._make_socket_transport(
conn, protocol,
......
......@@ -70,11 +70,15 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
return _SelectorSocketTransport(self, sock, protocol, waiter,
extra, server)
def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None):
ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter,
server_side, server_hostname)
def _make_ssl_transport(
self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
_SelectorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server)
return ssl_protocol._app_transport
......@@ -143,12 +147,16 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
exc_info=True)
def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100):
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
self._add_reader(sock.fileno(), self._accept_connection,
protocol_factory, sock, sslcontext, server, backlog)
protocol_factory, sock, sslcontext, server, backlog,
ssl_handshake_timeout)
def _accept_connection(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100):
def _accept_connection(
self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
# This method is only called once for each event loop tick where the
# listening socket has triggered an EVENT_READ. There may be multiple
# connections waiting for an .accept() so it is called in a loop.
......@@ -179,17 +187,20 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
self.call_later(constants.ACCEPT_RETRY_DELAY,
self._start_serving,
protocol_factory, sock, sslcontext, server,
backlog)
backlog, ssl_handshake_timeout)
else:
raise # The event loop will catch, log and ignore it.
else:
extra = {'peername': addr}
accept = self._accept_connection2(
protocol_factory, conn, extra, sslcontext, server)
protocol_factory, conn, extra, sslcontext, server,
ssl_handshake_timeout)
self.create_task(accept)
async def _accept_connection2(self, protocol_factory, conn, extra,
sslcontext=None, server=None):
async def _accept_connection2(
self, protocol_factory, conn, extra,
sslcontext=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
protocol = None
transport = None
try:
......@@ -198,7 +209,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
if sslcontext:
transport = self._make_ssl_transport(
conn, protocol, sslcontext, waiter=waiter,
server_side=True, extra=extra, server=server)
server_side=True, extra=extra, server=server,
ssl_handshake_timeout=ssl_handshake_timeout)
else:
transport = self._make_socket_transport(
conn, protocol, waiter=waiter, extra=extra,
......
......@@ -6,6 +6,7 @@ except ImportError: # pragma: no cover
ssl = None
from . import base_events
from . import constants
from . import protocols
from . import transports
from .log import logger
......@@ -400,7 +401,8 @@ class SSLProtocol(protocols.Protocol):
def __init__(self, loop, app_protocol, sslcontext, waiter,
server_side=False, server_hostname=None,
call_connection_made=True):
call_connection_made=True,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
if ssl is None:
raise RuntimeError('stdlib ssl module not available')
......@@ -434,6 +436,7 @@ class SSLProtocol(protocols.Protocol):
# transport, ex: SelectorSocketTransport
self._transport = None
self._call_connection_made = call_connection_made
self._ssl_handshake_timeout = ssl_handshake_timeout
def _wakeup_waiter(self, exc=None):
if self._waiter is None:
......@@ -561,9 +564,18 @@ class SSLProtocol(protocols.Protocol):
# the SSL handshake
self._write_backlog.append((b'', 1))
self._loop.call_soon(self._process_write_backlog)
self._handshake_timeout_handle = \
self._loop.call_later(self._ssl_handshake_timeout,
self._check_handshake_timeout)
def _check_handshake_timeout(self):
if self._in_handshake is True:
logger.warning("%r stalled during handshake", self)
self._abort()
def _on_handshake_complete(self, handshake_exc):
self._in_handshake = False
self._handshake_timeout_handle.cancel()
sslobj = self._sslpipe.ssl_object
try:
......
......@@ -192,9 +192,11 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
def _child_watcher_callback(self, pid, returncode, transp):
self.call_soon_threadsafe(transp._process_exited, returncode)
async def create_unix_connection(self, protocol_factory, path=None, *,
ssl=None, sock=None,
server_hostname=None):
async def create_unix_connection(
self, protocol_factory, path=None, *,
ssl=None, sock=None,
server_hostname=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
assert server_hostname is None or isinstance(server_hostname, str)
if ssl:
if server_hostname is None:
......@@ -228,11 +230,14 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
sock.setblocking(False)
transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname)
sock, protocol_factory, ssl, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
return transport, protocol
async def create_unix_server(self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None):
async def create_unix_server(
self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
if isinstance(ssl, bool):
raise TypeError('ssl argument must be an SSLContext or None')
......@@ -283,7 +288,8 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
server = base_events.Server(self, [sock])
sock.listen(backlog)
sock.setblocking(False)
self._start_serving(protocol_factory, sock, ssl, server)
self._start_serving(protocol_factory, sock, ssl, server,
ssl_handshake_timeout=ssl_handshake_timeout)
return server
......
......@@ -1301,34 +1301,45 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport
ANY = mock.ANY
handshake_timeout = object()
# First try the default server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True)
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
ssl_handshake_timeout=handshake_timeout)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='python.org')
server_hostname='python.org',
ssl_handshake_timeout=handshake_timeout)
# Next try an explicit server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True,
server_hostname='perl.com')
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
server_hostname='perl.com',
ssl_handshake_timeout=handshake_timeout)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='perl.com')
server_hostname='perl.com',
ssl_handshake_timeout=handshake_timeout)
# Finally try an explicit empty server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True,
server_hostname='')
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
server_hostname='',
ssl_handshake_timeout=handshake_timeout)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='')
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='',
ssl_handshake_timeout=handshake_timeout)
def test_create_connection_no_ssl_server_hostname_errors(self):
# When not using ssl, server_hostname must be None.
......@@ -1687,7 +1698,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
constants.ACCEPT_RETRY_DELAY,
# self.loop._start_serving
mock.ANY,
MyProto, sock, None, None, mock.ANY)
MyProto, sock, None, None, mock.ANY, mock.ANY)
def test_call_coroutine(self):
@asyncio.coroutine
......
......@@ -11,6 +11,7 @@ except ImportError:
import asyncio
from asyncio import log
from asyncio import sslproto
from asyncio import tasks
from test.test_asyncio import utils as test_utils
......@@ -25,7 +26,8 @@ class SslProtoHandshakeTests(test_utils.TestCase):
def ssl_protocol(self, waiter=None):
sslcontext = test_utils.dummy_ssl_context()
app_proto = asyncio.Protocol()
proto = sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter)
proto = sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
ssl_handshake_timeout=0.1)
self.assertIs(proto._app_transport.get_protocol(), app_proto)
self.addCleanup(proto._app_transport.close)
return proto
......@@ -63,6 +65,16 @@ class SslProtoHandshakeTests(test_utils.TestCase):
with test_utils.disable_logger():
self.loop.run_until_complete(handshake_fut)
def test_handshake_timeout(self):
# bpo-29970: Check that a connection is aborted if handshake is not
# completed in timeout period, instead of remaining open indefinitely
ssl_proto = self.ssl_protocol()
transport = self.connection_made(ssl_proto)
with test_utils.disable_logger():
self.loop.run_until_complete(tasks.sleep(0.2, loop=self.loop))
self.assertTrue(transport.abort.called)
def test_eof_received_waiter(self):
waiter = asyncio.Future(loop=self.loop)
ssl_proto = self.ssl_protocol(waiter)
......
......@@ -63,6 +63,7 @@ Jeffrey Armstrong
Jason Asbahr
David Ascher
Ammar Askar
Neil Aspinall
Chris AtLee
Aymeric Augustin
Cathy Avery
......
Abort asyncio SSLProtocol connection if handshake not complete within 10s
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment