Commit 231b404c authored by Victor Stinner's avatar Victor Stinner

Issue #22560: New SSL implementation based on ssl.MemoryBIO

The new SSL implementation is based on the new ssl.MemoryBIO which is only
available on Python 3.5. On Python 3.4 and older, the legacy SSL implementation
(using SSL_write, SSL_read, etc.) is used. The proactor event loop only
supports the new implementation.

The new asyncio.sslproto module adds _SSLPipe, SSLProtocol and
_SSLProtocolTransport classes. _SSLPipe allows to "wrap" or "unwrap" a socket
(switch between cleartext and SSL/TLS).

Patch written by Antoine Pitrou. sslproto.py is based on gruvi/ssl.py of the
gruvi project written by Geert Jansen.

This change adds SSL support to ProactorEventLoop on Python 3.5 and newer!

It becomes also possible to implement STARTTTLS: switch a cleartext socket to
SSL.
parent 9036e49b
...@@ -11,6 +11,7 @@ import socket ...@@ -11,6 +11,7 @@ import socket
from . import base_events from . import base_events
from . import constants from . import constants
from . import futures from . import futures
from . import sslproto
from . import transports from . import transports
from .log import logger from .log import logger
...@@ -367,6 +368,20 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): ...@@ -367,6 +368,20 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
return _ProactorSocketTransport(self, sock, protocol, waiter, return _ProactorSocketTransport(self, sock, protocol, waiter,
extra, server) extra, server)
def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None):
if not sslproto._is_sslproto_available():
raise NotImplementedError("Proactor event loop requires Python 3.5"
" or newer (ssl.MemoryBIO) to support "
"SSL")
ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter,
server_side, server_hostname)
_ProactorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server)
return ssl_protocol._app_transport
def _make_duplex_pipe_transport(self, sock, protocol, waiter=None, def _make_duplex_pipe_transport(self, sock, protocol, waiter=None,
extra=None): extra=None):
return _ProactorDuplexPipeTransport(self, return _ProactorDuplexPipeTransport(self,
...@@ -455,9 +470,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): ...@@ -455,9 +470,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
def _write_to_self(self): def _write_to_self(self):
self._csock.send(b'\0') self._csock.send(b'\0')
def _start_serving(self, protocol_factory, sock, ssl=None, server=None): def _start_serving(self, protocol_factory, sock,
if ssl: sslcontext=None, server=None):
raise ValueError('IocpEventLoop is incompatible with SSL.')
def loop(f=None): def loop(f=None):
try: try:
...@@ -467,6 +481,11 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): ...@@ -467,6 +481,11 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
logger.debug("%r got a new connection from %r: %r", logger.debug("%r got a new connection from %r: %r",
server, addr, conn) server, addr, conn)
protocol = protocol_factory() protocol = protocol_factory()
if sslcontext is not None:
self._make_ssl_transport(
conn, protocol, sslcontext, server_side=True,
extra={'peername': addr}, server=server)
else:
self._make_socket_transport( self._make_socket_transport(
conn, protocol, conn, protocol,
extra={'peername': addr}, server=server) extra={'peername': addr}, server=server)
......
...@@ -10,6 +10,7 @@ import collections ...@@ -10,6 +10,7 @@ import collections
import errno import errno
import functools import functools
import socket import socket
import sys
try: try:
import ssl import ssl
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
...@@ -21,6 +22,7 @@ from . import events ...@@ -21,6 +22,7 @@ from . import events
from . import futures from . import futures
from . import selectors from . import selectors
from . import transports from . import transports
from . import sslproto
from .log import logger from .log import logger
...@@ -58,6 +60,24 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): ...@@ -58,6 +60,24 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None, *, server_side=False, server_hostname=None,
extra=None, server=None): extra=None, server=None):
if not sslproto._is_sslproto_available():
return self._make_legacy_ssl_transport(
rawsock, protocol, sslcontext, waiter,
server_side=server_side, server_hostname=server_hostname,
extra=extra, server=server)
ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter,
server_side, server_hostname)
_SelectorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server)
return ssl_protocol._app_transport
def _make_legacy_ssl_transport(self, rawsock, protocol, sslcontext,
waiter, *,
server_side=False, server_hostname=None,
extra=None, server=None):
# Use the legacy API: SSL_write, SSL_read, etc. The legacy API is used
# on Python 3.4 and older, when ssl.MemoryBIO is not available.
return _SelectorSslTransport( return _SelectorSslTransport(
self, rawsock, protocol, sslcontext, waiter, self, rawsock, protocol, sslcontext, waiter,
server_side, server_hostname, extra, server) server_side, server_hostname, extra, server)
...@@ -508,7 +528,8 @@ class _SelectorTransport(transports._FlowControlMixin, ...@@ -508,7 +528,8 @@ class _SelectorTransport(transports._FlowControlMixin,
def _fatal_error(self, exc, message='Fatal error on transport'): def _fatal_error(self, exc, message='Fatal error on transport'):
# Should be called from exception handler only. # Should be called from exception handler only.
if isinstance(exc, (BrokenPipeError, ConnectionResetError)): if isinstance(exc, (BrokenPipeError,
ConnectionResetError, ConnectionAbortedError)):
if self._loop.get_debug(): if self._loop.get_debug():
logger.debug("%r: %s", self, message, exc_info=True) logger.debug("%r: %s", self, message, exc_info=True)
else: else:
...@@ -683,26 +704,8 @@ class _SelectorSslTransport(_SelectorTransport): ...@@ -683,26 +704,8 @@ class _SelectorSslTransport(_SelectorTransport):
if ssl is None: if ssl is None:
raise RuntimeError('stdlib ssl module not available') raise RuntimeError('stdlib ssl module not available')
if server_side:
if not sslcontext:
raise ValueError('Server side ssl needs a valid SSLContext')
else:
if not sslcontext: if not sslcontext:
# Client side may pass ssl=True to use a default sslcontext = sslproto._create_transport_context(server_side, server_hostname)
# context; in that case the sslcontext passed is None.
# The default is secure for client connections.
if hasattr(ssl, 'create_default_context'):
# Python 3.4+: use up-to-date strong settings.
sslcontext = ssl.create_default_context()
if not server_hostname:
sslcontext.check_hostname = False
else:
# Fallback for Python 3.3.
sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext.options |= ssl.OP_NO_SSLv2
sslcontext.options |= ssl.OP_NO_SSLv3
sslcontext.set_default_verify_paths()
sslcontext.verify_mode = ssl.CERT_REQUIRED
wrap_kwargs = { wrap_kwargs = {
'server_side': server_side, 'server_side': server_side,
......
This diff is collapsed.
...@@ -434,3 +434,8 @@ def mock_nonblocking_socket(): ...@@ -434,3 +434,8 @@ def mock_nonblocking_socket():
sock = mock.Mock(socket.socket) sock = mock.Mock(socket.socket)
sock.gettimeout.return_value = 0.0 sock.gettimeout.return_value = 0.0
return sock return sock
def force_legacy_ssl_support():
return mock.patch('asyncio.sslproto._is_sslproto_available',
return_value=False)
...@@ -650,6 +650,10 @@ class EventLoopTestsMixin: ...@@ -650,6 +650,10 @@ class EventLoopTestsMixin:
*httpd.address) *httpd.address)
self._test_create_ssl_connection(httpd, create_connection) self._test_create_ssl_connection(httpd, create_connection)
def test_legacy_create_ssl_connection(self):
with test_utils.force_legacy_ssl_support():
self.test_create_ssl_connection()
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_create_ssl_unix_connection(self): def test_create_ssl_unix_connection(self):
...@@ -666,6 +670,10 @@ class EventLoopTestsMixin: ...@@ -666,6 +670,10 @@ class EventLoopTestsMixin:
self._test_create_ssl_connection(httpd, create_connection, self._test_create_ssl_connection(httpd, create_connection,
check_sockname) check_sockname)
def test_legacy_create_ssl_unix_connection(self):
with test_utils.force_legacy_ssl_support():
self.test_create_ssl_unix_connection()
def test_create_connection_local_addr(self): def test_create_connection_local_addr(self):
with test_utils.run_test_server() as httpd: with test_utils.run_test_server() as httpd:
port = support.find_unused_port() port = support.find_unused_port()
...@@ -826,6 +834,10 @@ class EventLoopTestsMixin: ...@@ -826,6 +834,10 @@ class EventLoopTestsMixin:
# stop serving # stop serving
server.close() server.close()
def test_legacy_create_server_ssl(self):
with test_utils.force_legacy_ssl_support():
self.test_create_server_ssl()
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_create_unix_server_ssl(self): def test_create_unix_server_ssl(self):
...@@ -857,6 +869,10 @@ class EventLoopTestsMixin: ...@@ -857,6 +869,10 @@ class EventLoopTestsMixin:
# stop serving # stop serving
server.close() server.close()
def test_legacy_create_unix_server_ssl(self):
with test_utils.force_legacy_ssl_support():
self.test_create_unix_server_ssl()
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
def test_create_server_ssl_verify_failed(self): def test_create_server_ssl_verify_failed(self):
proto = MyProto(loop=self.loop) proto = MyProto(loop=self.loop)
...@@ -881,6 +897,10 @@ class EventLoopTestsMixin: ...@@ -881,6 +897,10 @@ class EventLoopTestsMixin:
self.assertIsNone(proto.transport) self.assertIsNone(proto.transport)
server.close() server.close()
def test_legacy_create_server_ssl_verify_failed(self):
with test_utils.force_legacy_ssl_support():
self.test_create_server_ssl_verify_failed()
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_create_unix_server_ssl_verify_failed(self): def test_create_unix_server_ssl_verify_failed(self):
...@@ -907,6 +927,10 @@ class EventLoopTestsMixin: ...@@ -907,6 +927,10 @@ class EventLoopTestsMixin:
self.assertIsNone(proto.transport) self.assertIsNone(proto.transport)
server.close() server.close()
def test_legacy_create_unix_server_ssl_verify_failed(self):
with test_utils.force_legacy_ssl_support():
self.test_create_unix_server_ssl_verify_failed()
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
def test_create_server_ssl_match_failed(self): def test_create_server_ssl_match_failed(self):
proto = MyProto(loop=self.loop) proto = MyProto(loop=self.loop)
...@@ -934,6 +958,10 @@ class EventLoopTestsMixin: ...@@ -934,6 +958,10 @@ class EventLoopTestsMixin:
proto.transport.close() proto.transport.close()
server.close() server.close()
def test_legacy_create_server_ssl_match_failed(self):
with test_utils.force_legacy_ssl_support():
self.test_create_server_ssl_match_failed()
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_create_unix_server_ssl_verified(self): def test_create_unix_server_ssl_verified(self):
...@@ -958,6 +986,11 @@ class EventLoopTestsMixin: ...@@ -958,6 +986,11 @@ class EventLoopTestsMixin:
proto.transport.close() proto.transport.close()
client.close() client.close()
server.close() server.close()
self.loop.run_until_complete(proto.done)
def test_legacy_create_unix_server_ssl_verified(self):
with test_utils.force_legacy_ssl_support():
self.test_create_unix_server_ssl_verified()
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
def test_create_server_ssl_verified(self): def test_create_server_ssl_verified(self):
...@@ -982,6 +1015,11 @@ class EventLoopTestsMixin: ...@@ -982,6 +1015,11 @@ class EventLoopTestsMixin:
proto.transport.close() proto.transport.close()
client.close() client.close()
server.close() server.close()
self.loop.run_until_complete(proto.done)
def test_legacy_create_server_ssl_verified(self):
with test_utils.force_legacy_ssl_support():
self.test_create_server_ssl_verified()
def test_create_server_sock(self): def test_create_server_sock(self):
proto = asyncio.Future(loop=self.loop) proto = asyncio.Future(loop=self.loop)
...@@ -1746,20 +1784,20 @@ if sys.platform == 'win32': ...@@ -1746,20 +1784,20 @@ if sys.platform == 'win32':
def create_event_loop(self): def create_event_loop(self):
return asyncio.ProactorEventLoop() return asyncio.ProactorEventLoop()
def test_create_ssl_connection(self): def test_legacy_create_ssl_connection(self):
raise unittest.SkipTest("IocpEventLoop incompatible with SSL") raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL")
def test_create_server_ssl(self): def test_legacy_create_server_ssl(self):
raise unittest.SkipTest("IocpEventLoop incompatible with SSL") raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL")
def test_create_server_ssl_verify_failed(self): def test_legacy_create_server_ssl_verify_failed(self):
raise unittest.SkipTest("IocpEventLoop incompatible with SSL") raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL")
def test_create_server_ssl_match_failed(self): def test_legacy_create_server_ssl_match_failed(self):
raise unittest.SkipTest("IocpEventLoop incompatible with SSL") raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL")
def test_create_server_ssl_verified(self): def test_legacy_create_server_ssl_verified(self):
raise unittest.SkipTest("IocpEventLoop incompatible with SSL") raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL")
def test_reader_callback(self): def test_reader_callback(self):
raise unittest.SkipTest("IocpEventLoop does not have add_reader()") raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
......
...@@ -59,9 +59,13 @@ class BaseSelectorEventLoopTests(test_utils.TestCase): ...@@ -59,9 +59,13 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
with test_utils.disable_logger(): with test_utils.disable_logger():
transport = self.loop._make_ssl_transport( transport = self.loop._make_ssl_transport(
m, asyncio.Protocol(), m, waiter) m, asyncio.Protocol(), m, waiter)
self.assertIsInstance(transport, _SelectorSslTransport) # Sanity check
class_name = transport.__class__.__name__
self.assertIn("ssl", class_name.lower())
self.assertIn("transport", class_name.lower())
@mock.patch('asyncio.selector_events.ssl', None) @mock.patch('asyncio.selector_events.ssl', None)
@mock.patch('asyncio.sslproto.ssl', None)
def test_make_ssl_transport_without_ssl_error(self): def test_make_ssl_transport_without_ssl_error(self):
m = mock.Mock() m = mock.Mock()
self.loop.add_reader = mock.Mock() self.loop.add_reader = mock.Mock()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment