Commit b1fdf47f authored by Antoine Pitrou's avatar Antoine Pitrou

Issue #21965: Add support for in-memory SSL to the ssl module.

Patch by Geert Jansen.
parent 414e15a8
...@@ -803,6 +803,29 @@ the specification of normal, OS-level sockets. See especially the ...@@ -803,6 +803,29 @@ the specification of normal, OS-level sockets. See especially the
SSL sockets also have the following additional methods and attributes: SSL sockets also have the following additional methods and attributes:
.. method:: SSLSocket.read(len=0, buffer=None)
Read up to *len* bytes of data from the SSL socket and return the result as
a ``bytes`` instance. If *buffer* is specified, then read into the buffer
instead, and return the number of bytes read.
.. method:: SSLSocket.write(buf)
Write *buf* to the SSL socket and return the number of bytes written. The
*buf* argument must be an object supporting the buffer interface.
.. note::
The :meth:`~SSLSocket.read` and :meth:`~SSLSocket.write` methods are the
low-level methods that read and write unencrypted, application-level data
and and decrypt/encrypt it to encrypted, wire-level data. These methods
require an active SSL connection, i.e. the handshake was completed and
:meth:`SSLSocket.unwrap` was not called.
Normally you should use the socket API methods like
:meth:`~socket.socket.recv` and :meth:`~socket.socket.send` instead of these
methods.
.. method:: SSLSocket.do_handshake() .. method:: SSLSocket.do_handshake()
Perform the SSL setup handshake. Perform the SSL setup handshake.
...@@ -935,6 +958,11 @@ SSL sockets also have the following additional methods and attributes: ...@@ -935,6 +958,11 @@ SSL sockets also have the following additional methods and attributes:
.. versionadded:: 3.5 .. versionadded:: 3.5
.. method:: SSLSocket.pending()
Returns the number of already decrypted bytes available for read, pending on
the connection.
.. attribute:: SSLSocket.context .. attribute:: SSLSocket.context
The :class:`SSLContext` object this SSL socket is tied to. If the SSL The :class:`SSLContext` object this SSL socket is tied to. If the SSL
...@@ -944,6 +972,22 @@ SSL sockets also have the following additional methods and attributes: ...@@ -944,6 +972,22 @@ SSL sockets also have the following additional methods and attributes:
.. versionadded:: 3.2 .. versionadded:: 3.2
.. attribute:: SSLSocket.server_side
A boolean which is ``True`` for server-side sockets and ``False`` for
client-side sockets.
.. versionadded:: 3.5
.. attribute:: SSLSocket.server_hostname
A ``bytes`` instance containing the ``'idna'`` encoded version of the
hostname specified in the *server_hostname* argument in
:meth:`SSLContext.wrap_socket`. If no *server_hostname* was specified, this
attribute will be ``None``.
.. versionadded:: 3.5
SSL Contexts SSL Contexts
------------ ------------
...@@ -1670,6 +1714,130 @@ thus several things you need to be aware of: ...@@ -1670,6 +1714,130 @@ thus several things you need to be aware of:
select.select([], [sock], []) select.select([], [sock], [])
Memory BIO Support
------------------
.. versionadded:: 3.5
Ever since the SSL module was introduced in Python 2.6, the :class:`SSLSocket`
class has provided two related but distinct areas of functionality:
- SSL protocol handling
- Network IO
The network IO API is identical to that provided by :class:`socket.socket`,
from which :class:`SSLSocket` also inherits. This allows an SSL socket to be
used as a drop-in replacement for a regular socket, making it very easy to add
SSL support to an existing application.
Combining SSL protocol handling and network IO usually works well, but there
are some cases where it doesn't. An example is async IO frameworks that want to
use a different IO multiplexing model than the "select/poll on a file
descriptor" (readiness based) model that is assumed by :class:`socket.socket`
and by the internal OpenSSL socket IO routines. This is mostly relevant for
platforms like Windows where this model is not efficient. For this purpose, a
reduced scope variant of :class:`SSLSocket` called :class:`SSLObject` is
provided.
.. class:: SSLObject
A reduced-scope variant of :class:`SSLSocket` representing an SSL protocol
instance that does not contain any network IO methods.
The following methods are available from :class:`SSLSocket`:
- :attr:`~SSLSocket.context`
- :attr:`~SSLSocket.server_side`
- :attr:`~SSLSocket.server_hostname`
- :meth:`~SSLSocket.read`
- :meth:`~SSLSocket.write`
- :meth:`~SSLSocket.getpeercert`
- :meth:`~SSLSocket.selected_npn_protocol`
- :meth:`~SSLSocket.cipher`
- :meth:`~SSLSocket.compression`
- :meth:`~SSLSocket.pending`
- :meth:`~SSLSocket.do_handshake`
- :meth:`~SSLSocket.unwrap`
- :meth:`~SSLSocket.get_channel_binding`
An SSLObject communicates with the outside world using memory buffers. The
class :class:`MemoryBIO` provides a memory buffer that can be used for this
purpose. It wraps an OpenSSL memory BIO (Basic IO) object:
.. class:: MemoryBIO
A memory buffer that can be used to pass data between Python and an SSL
protocol instance.
.. attribute:: MemoryBIO.pending
Return the number of bytes currently in the memory buffer.
.. attribute:: MemoryBIO.eof
A boolean indicating whether the memory BIO is current at the end-of-file
position.
.. method:: MemoryBIO.read(n=-1)
Read up to *n* bytes from the memory buffer. If *n* is not specified or
negative, all bytes are returned.
.. method:: MemoryBIO.write(buf)
Write the bytes from *buf* to the memory BIO. The *buf* argument must be an
object supporting the buffer protocol.
The return value is the number of bytes written, which is always equal to
the length of *buf*.
.. method:: MemoryBIO.write_eof()
Write an EOF marker to the memory BIO. After this method has been called, it
is illegal to call :meth:`~MemoryBIO.write`. The attribute :attr:`eof` will
become true after all data currently in the buffer has been read.
An :class:`SSLObject` instance can be created using the
:meth:`~SSLContext.wrap_bio` method. This method will create the
:class:`SSLObject` instance and bind it to a pair of BIOs. The *incoming* BIO
is used to pass data from Python to the SSL protocol instance, while the
*outgoing* BIO is used to pass data the other way around.
.. method:: SSLContext.wrap_bio(incoming, outgoing, server_side=False, \
server_hostname=None)
Create a new :class:`SSLObject` instance by wrapping the BIO objects
*incoming* and *outgoing*. The SSL routines will read input data from the
incoming BIO and write data to the outgoing BIO.
The *server_side* and *server_hostname* parameters have the same meaning as
in :meth:`SSLContext.wrap_socket`.
Some notes related to the use of :class:`SSLObject`:
- All IO on an :class:`SSLObject` is non-blocking. This means that for example
:meth:`~SSLSocket.read` will raise an :exc:`SSLWantReadError` if it needs
more data than the incoming BIO has available.
- There is no module-level ``wrap_bio`` call like there is for
:meth:`~SSLContext.wrap_socket`. An :class:`SSLObject` is always created via
an :class:`SSLContext`.
- There is no *do_handshake_on_connect* machinery. You must always manually
call :meth:`~SSLSocket.do_handshake` to start the handshake.
- There is no handling of *suppress_ragged_eofs*. All end-of-file conditions
that are in violation of the protocol are reported via the :exc:`SSLEOFError`
exception.
- The method :meth:`~SSLSocket.unwrap` call does not return anything, unlike
for an SSL socket where it returns the underlying socket.
- The *server_name_callback* callback passed to
:meth:`SSLContext.set_servername_callback` will get an :class:`SSLObject`
instance instead of a :class:`SSLSocket` instance as its first parameter.
.. _ssl-security: .. _ssl-security:
Security considerations Security considerations
......
...@@ -97,7 +97,7 @@ from enum import Enum as _Enum, IntEnum as _IntEnum ...@@ -97,7 +97,7 @@ from enum import Enum as _Enum, IntEnum as _IntEnum
import _ssl # if we can't import it, let the error propagate import _ssl # if we can't import it, let the error propagate
from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
from _ssl import _SSLContext from _ssl import _SSLContext, MemoryBIO
from _ssl import ( from _ssl import (
SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError, SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError,
SSLSyscallError, SSLEOFError, SSLSyscallError, SSLEOFError,
...@@ -352,6 +352,12 @@ class SSLContext(_SSLContext): ...@@ -352,6 +352,12 @@ class SSLContext(_SSLContext):
server_hostname=server_hostname, server_hostname=server_hostname,
_context=self) _context=self)
def wrap_bio(self, incoming, outgoing, server_side=False,
server_hostname=None):
sslobj = self._wrap_bio(incoming, outgoing, server_side=server_side,
server_hostname=server_hostname)
return SSLObject(sslobj)
def set_npn_protocols(self, npn_protocols): def set_npn_protocols(self, npn_protocols):
protos = bytearray() protos = bytearray()
for protocol in npn_protocols: for protocol in npn_protocols:
...@@ -469,6 +475,129 @@ def _create_stdlib_context(protocol=PROTOCOL_SSLv23, *, cert_reqs=None, ...@@ -469,6 +475,129 @@ def _create_stdlib_context(protocol=PROTOCOL_SSLv23, *, cert_reqs=None,
return context return context
class SSLObject:
"""This class implements an interface on top of a low-level SSL object as
implemented by OpenSSL. This object captures the state of an SSL connection
but does not provide any network IO itself. IO needs to be performed
through separate "BIO" objects which are OpenSSL's IO abstraction layer.
This class does not have a public constructor. Instances are returned by
``SSLContext.wrap_bio``. This class is typically used by framework authors
that want to implement asynchronous IO for SSL through memory buffers.
When compared to ``SSLSocket``, this object lacks the following features:
* Any form of network IO incluging methods such as ``recv`` and ``send``.
* The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
"""
def __init__(self, sslobj, owner=None):
self._sslobj = sslobj
# Note: _sslobj takes a weak reference to owner
self._sslobj.owner = owner or self
@property
def context(self):
"""The SSLContext that is currently in use."""
return self._sslobj.context
@context.setter
def context(self, ctx):
self._sslobj.context = ctx
@property
def server_side(self):
"""Whether this is a server-side socket."""
return self._sslobj.server_side
@property
def server_hostname(self):
"""The currently set server hostname (for SNI), or ``None`` if no
server hostame is set."""
return self._sslobj.server_hostname
def read(self, len=0, buffer=None):
"""Read up to 'len' bytes from the SSL object and return them.
If 'buffer' is provided, read into this buffer and return the number of
bytes read.
"""
if buffer is not None:
v = self._sslobj.read(len, buffer)
else:
v = self._sslobj.read(len or 1024)
return v
def write(self, data):
"""Write 'data' to the SSL object and return the number of bytes
written.
The 'data' argument must support the buffer interface.
"""
return self._sslobj.write(data)
def getpeercert(self, binary_form=False):
"""Returns a formatted version of the data in the certificate provided
by the other end of the SSL channel.
Return None if no certificate was provided, {} if a certificate was
provided, but not validated.
"""
return self._sslobj.peer_certificate(binary_form)
def selected_npn_protocol(self):
"""Return the currently selected NPN protocol as a string, or ``None``
if a next protocol was not negotiated or if NPN is not supported by one
of the peers."""
if _ssl.HAS_NPN:
return self._sslobj.selected_npn_protocol()
def cipher(self):
"""Return the currently selected cipher as a 3-tuple ``(name,
ssl_version, secret_bits)``."""
return self._sslobj.cipher()
def compression(self):
"""Return the current compression algorithm in use, or ``None`` if
compression was not negotiated or not supported by one of the peers."""
return self._sslobj.compression()
def pending(self):
"""Return the number of bytes that can be read immediately."""
return self._sslobj.pending()
def do_handshake(self, block=False):
"""Start the SSL/TLS handshake."""
self._sslobj.do_handshake()
if self.context.check_hostname:
if not self.server_hostname:
raise ValueError("check_hostname needs server_hostname "
"argument")
match_hostname(self.getpeercert(), self.server_hostname)
def unwrap(self):
"""Start the SSL shutdown handshake."""
return self._sslobj.shutdown()
def get_channel_binding(self, cb_type="tls-unique"):
"""Get channel binding data for current connection. Raise ValueError
if the requested `cb_type` is not supported. Return bytes of the data
or None if the data is not available (e.g. before the handshake)."""
if cb_type not in CHANNEL_BINDING_TYPES:
raise ValueError("Unsupported channel binding type")
if cb_type != "tls-unique":
raise NotImplementedError(
"{0} channel binding type not implemented"
.format(cb_type))
return self._sslobj.tls_unique_cb()
def version(self):
"""Return a string identifying the protocol version used by the
current SSL channel. """
return self._sslobj.version()
class SSLSocket(socket): class SSLSocket(socket):
"""This class implements a subtype of socket.socket that wraps """This class implements a subtype of socket.socket that wraps
the underlying OS socket in an SSL context when necessary, and the underlying OS socket in an SSL context when necessary, and
...@@ -556,8 +685,9 @@ class SSLSocket(socket): ...@@ -556,8 +685,9 @@ class SSLSocket(socket):
if connected: if connected:
# create the SSL object # create the SSL object
try: try:
self._sslobj = self._context._wrap_socket(self, server_side, sslobj = self._context._wrap_socket(self, server_side,
server_hostname) server_hostname)
self._sslobj = SSLObject(sslobj, owner=self)
if do_handshake_on_connect: if do_handshake_on_connect:
timeout = self.gettimeout() timeout = self.gettimeout()
if timeout == 0.0: if timeout == 0.0:
...@@ -602,11 +732,7 @@ class SSLSocket(socket): ...@@ -602,11 +732,7 @@ class SSLSocket(socket):
if not self._sslobj: if not self._sslobj:
raise ValueError("Read on closed or unwrapped SSL socket.") raise ValueError("Read on closed or unwrapped SSL socket.")
try: try:
if buffer is not None: return self._sslobj.read(len, buffer)
v = self._sslobj.read(len, buffer)
else:
v = self._sslobj.read(len or 1024)
return v
except SSLError as x: except SSLError as x:
if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs: if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
if buffer is not None: if buffer is not None:
...@@ -633,7 +759,7 @@ class SSLSocket(socket): ...@@ -633,7 +759,7 @@ class SSLSocket(socket):
self._checkClosed() self._checkClosed()
self._check_connected() self._check_connected()
return self._sslobj.peer_certificate(binary_form) return self._sslobj.getpeercert(binary_form)
def selected_npn_protocol(self): def selected_npn_protocol(self):
self._checkClosed() self._checkClosed()
...@@ -773,7 +899,7 @@ class SSLSocket(socket): ...@@ -773,7 +899,7 @@ class SSLSocket(socket):
def unwrap(self): def unwrap(self):
if self._sslobj: if self._sslobj:
s = self._sslobj.shutdown() s = self._sslobj.unwrap()
self._sslobj = None self._sslobj = None
return s return s
else: else:
...@@ -794,12 +920,6 @@ class SSLSocket(socket): ...@@ -794,12 +920,6 @@ class SSLSocket(socket):
finally: finally:
self.settimeout(timeout) self.settimeout(timeout)
if self.context.check_hostname:
if not self.server_hostname:
raise ValueError("check_hostname needs server_hostname "
"argument")
match_hostname(self.getpeercert(), self.server_hostname)
def _real_connect(self, addr, connect_ex): def _real_connect(self, addr, connect_ex):
if self.server_side: if self.server_side:
raise ValueError("can't connect in server-side mode") raise ValueError("can't connect in server-side mode")
...@@ -807,7 +927,8 @@ class SSLSocket(socket): ...@@ -807,7 +927,8 @@ class SSLSocket(socket):
# connected at the time of the call. We connect it, then wrap it. # connected at the time of the call. We connect it, then wrap it.
if self._connected: if self._connected:
raise ValueError("attempt to connect already-connected SSLSocket!") raise ValueError("attempt to connect already-connected SSLSocket!")
self._sslobj = self.context._wrap_socket(self, False, self.server_hostname) sslobj = self.context._wrap_socket(self, False, self.server_hostname)
self._sslobj = SSLObject(sslobj, owner=self)
try: try:
if connect_ex: if connect_ex:
rc = socket.connect_ex(self, addr) rc = socket.connect_ex(self, addr)
...@@ -850,15 +971,9 @@ class SSLSocket(socket): ...@@ -850,15 +971,9 @@ class SSLSocket(socket):
if the requested `cb_type` is not supported. Return bytes of the data if the requested `cb_type` is not supported. Return bytes of the data
or None if the data is not available (e.g. before the handshake). or None if the data is not available (e.g. before the handshake).
""" """
if cb_type not in CHANNEL_BINDING_TYPES:
raise ValueError("Unsupported channel binding type")
if cb_type != "tls-unique":
raise NotImplementedError(
"{0} channel binding type not implemented"
.format(cb_type))
if self._sslobj is None: if self._sslobj is None:
return None return None
return self._sslobj.tls_unique_cb() return self._sslobj.get_channel_binding(cb_type)
def version(self): def version(self):
""" """
......
...@@ -518,9 +518,14 @@ class BasicSocketTests(unittest.TestCase): ...@@ -518,9 +518,14 @@ class BasicSocketTests(unittest.TestCase):
def test_unknown_channel_binding(self): def test_unknown_channel_binding(self):
# should raise ValueError for unknown type # should raise ValueError for unknown type
s = socket.socket(socket.AF_INET) s = socket.socket(socket.AF_INET)
with ssl.wrap_socket(s) as ss: s.bind(('127.0.0.1', 0))
s.listen()
c = socket.socket(socket.AF_INET)
c.connect(s.getsockname())
with ssl.wrap_socket(c, do_handshake_on_connect=False) as ss:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
ss.get_channel_binding("unknown-type") ss.get_channel_binding("unknown-type")
s.close()
@unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES, @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
"'tls-unique' channel binding not available") "'tls-unique' channel binding not available")
...@@ -1247,6 +1252,69 @@ class SSLErrorTests(unittest.TestCase): ...@@ -1247,6 +1252,69 @@ class SSLErrorTests(unittest.TestCase):
self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ) self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ)
class MemoryBIOTests(unittest.TestCase):
def test_read_write(self):
bio = ssl.MemoryBIO()
bio.write(b'foo')
self.assertEqual(bio.read(), b'foo')
self.assertEqual(bio.read(), b'')
bio.write(b'foo')
bio.write(b'bar')
self.assertEqual(bio.read(), b'foobar')
self.assertEqual(bio.read(), b'')
bio.write(b'baz')
self.assertEqual(bio.read(2), b'ba')
self.assertEqual(bio.read(1), b'z')
self.assertEqual(bio.read(1), b'')
def test_eof(self):
bio = ssl.MemoryBIO()
self.assertFalse(bio.eof)
self.assertEqual(bio.read(), b'')
self.assertFalse(bio.eof)
bio.write(b'foo')
self.assertFalse(bio.eof)
bio.write_eof()
self.assertFalse(bio.eof)
self.assertEqual(bio.read(2), b'fo')
self.assertFalse(bio.eof)
self.assertEqual(bio.read(1), b'o')
self.assertTrue(bio.eof)
self.assertEqual(bio.read(), b'')
self.assertTrue(bio.eof)
def test_pending(self):
bio = ssl.MemoryBIO()
self.assertEqual(bio.pending, 0)
bio.write(b'foo')
self.assertEqual(bio.pending, 3)
for i in range(3):
bio.read(1)
self.assertEqual(bio.pending, 3-i-1)
for i in range(3):
bio.write(b'x')
self.assertEqual(bio.pending, i+1)
bio.read()
self.assertEqual(bio.pending, 0)
def test_buffer_types(self):
bio = ssl.MemoryBIO()
bio.write(b'foo')
self.assertEqual(bio.read(), b'foo')
bio.write(bytearray(b'bar'))
self.assertEqual(bio.read(), b'bar')
bio.write(memoryview(b'baz'))
self.assertEqual(bio.read(), b'baz')
def test_error_types(self):
bio = ssl.MemoryBIO()
self.assertRaises(TypeError, bio.write, 'foo')
self.assertRaises(TypeError, bio.write, None)
self.assertRaises(TypeError, bio.write, True)
self.assertRaises(TypeError, bio.write, 1)
class NetworkedTests(unittest.TestCase): class NetworkedTests(unittest.TestCase):
def test_connect(self): def test_connect(self):
...@@ -1577,6 +1645,95 @@ class NetworkedTests(unittest.TestCase): ...@@ -1577,6 +1645,95 @@ class NetworkedTests(unittest.TestCase):
self.assertIs(ss.context, ctx2) self.assertIs(ss.context, ctx2)
self.assertIs(ss._sslobj.context, ctx2) self.assertIs(ss._sslobj.context, ctx2)
class NetworkedBIOTests(unittest.TestCase):
def ssl_io_loop(self, sock, incoming, outgoing, func, *args, **kwargs):
# A simple IO loop. Call func(*args) depending on the error we get
# (WANT_READ or WANT_WRITE) move data between the socket and the BIOs.
timeout = kwargs.get('timeout', 10)
count = 0
while True:
errno = None
count += 1
try:
ret = func(*args)
except ssl.SSLError as e:
# Note that we get a spurious -1/SSL_ERROR_SYSCALL for
# non-blocking IO. The SSL_shutdown manpage hints at this.
# It *should* be safe to just ignore SYS_ERROR_SYSCALL because
# with a Memory BIO there's no syscalls (for IO at least).
if e.errno not in (ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE,
ssl.SSL_ERROR_SYSCALL):
raise
errno = e.errno
# Get any data from the outgoing BIO irrespective of any error, and
# send it to the socket.
buf = outgoing.read()
sock.sendall(buf)
# If there's no error, we're done. For WANT_READ, we need to get
# data from the socket and put it in the incoming BIO.
if errno is None:
break
elif errno == ssl.SSL_ERROR_WANT_READ:
buf = sock.recv(32768)
if buf:
incoming.write(buf)
else:
incoming.write_eof()
if support.verbose:
sys.stdout.write("Needed %d calls to complete %s().\n"
% (count, func.__name__))
return ret
def test_handshake(self):
with support.transient_internet("svn.python.org"):
sock = socket.socket(socket.AF_INET)
sock.connect(("svn.python.org", 443))
incoming = ssl.MemoryBIO()
outgoing = ssl.MemoryBIO()
ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
ctx.verify_mode = ssl.CERT_REQUIRED
ctx.load_verify_locations(SVN_PYTHON_ORG_ROOT_CERT)
if ssl.HAS_SNI:
ctx.check_hostname = True
sslobj = ctx.wrap_bio(incoming, outgoing, False, 'svn.python.org')
else:
ctx.check_hostname = False
sslobj = ctx.wrap_bio(incoming, outgoing, False)
self.assertIs(sslobj._sslobj.owner, sslobj)
self.assertIsNone(sslobj.cipher())
self.assertRaises(ValueError, sslobj.getpeercert)
if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
self.assertIsNone(sslobj.get_channel_binding('tls-unique'))
self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
self.assertTrue(sslobj.cipher())
self.assertTrue(sslobj.getpeercert())
if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
self.assertTrue(sslobj.get_channel_binding('tls-unique'))
self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
self.assertRaises(ssl.SSLError, sslobj.write, b'foo')
sock.close()
def test_read_write_data(self):
with support.transient_internet("svn.python.org"):
sock = socket.socket(socket.AF_INET)
sock.connect(("svn.python.org", 443))
incoming = ssl.MemoryBIO()
outgoing = ssl.MemoryBIO()
ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
ctx.verify_mode = ssl.CERT_NONE
sslobj = ctx.wrap_bio(incoming, outgoing, False)
self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
req = b'GET / HTTP/1.0\r\n\r\n'
self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req)
buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024)
self.assertEqual(buf[:5], b'HTTP/')
self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
sock.close()
try: try:
import threading import threading
except ImportError: except ImportError:
...@@ -3061,10 +3218,11 @@ def test_main(verbose=False): ...@@ -3061,10 +3218,11 @@ def test_main(verbose=False):
if not os.path.exists(filename): if not os.path.exists(filename):
raise support.TestFailed("Can't read certificate file %r" % filename) raise support.TestFailed("Can't read certificate file %r" % filename)
tests = [ContextTests, BasicSocketTests, SSLErrorTests] tests = [ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests]
if support.is_resource_enabled('network'): if support.is_resource_enabled('network'):
tests.append(NetworkedTests) tests.append(NetworkedTests)
tests.append(NetworkedBIOTests)
if _have_threads: if _have_threads:
thread_info = support.threading_setup() thread_info = support.threading_setup()
......
...@@ -166,6 +166,9 @@ Core and Builtins ...@@ -166,6 +166,9 @@ Core and Builtins
Library Library
------- -------
- Issue #21965: Add support for in-memory SSL to the ssl module. Patch
by Geert Jansen.
- Issue #21173: Fix len() on a WeakKeyDictionary when .clear() was called - Issue #21173: Fix len() on a WeakKeyDictionary when .clear() was called
with an iterator alive. with an iterator alive.
......
...@@ -64,6 +64,7 @@ static PySocketModule_APIObject PySocketModule; ...@@ -64,6 +64,7 @@ static PySocketModule_APIObject PySocketModule;
#include "openssl/ssl.h" #include "openssl/ssl.h"
#include "openssl/err.h" #include "openssl/err.h"
#include "openssl/rand.h" #include "openssl/rand.h"
#include "openssl/bio.h"
/* SSL error object */ /* SSL error object */
static PyObject *PySSLErrorObject; static PyObject *PySSLErrorObject;
...@@ -226,10 +227,19 @@ typedef struct { ...@@ -226,10 +227,19 @@ typedef struct {
char shutdown_seen_zero; char shutdown_seen_zero;
char handshake_done; char handshake_done;
enum py_ssl_server_or_client socket_type; enum py_ssl_server_or_client socket_type;
PyObject *owner; /* Python level "owner" passed to servername callback */
PyObject *server_hostname;
} PySSLSocket; } PySSLSocket;
typedef struct {
PyObject_HEAD
BIO *bio;
int eof_written;
} PySSLMemoryBIO;
static PyTypeObject PySSLContext_Type; static PyTypeObject PySSLContext_Type;
static PyTypeObject PySSLSocket_Type; static PyTypeObject PySSLSocket_Type;
static PyTypeObject PySSLMemoryBIO_Type;
static PyObject *PySSL_SSLwrite(PySSLSocket *self, PyObject *args); static PyObject *PySSL_SSLwrite(PySSLSocket *self, PyObject *args);
static PyObject *PySSL_SSLread(PySSLSocket *self, PyObject *args); static PyObject *PySSL_SSLread(PySSLSocket *self, PyObject *args);
...@@ -240,6 +250,7 @@ static PyObject *PySSL_cipher(PySSLSocket *self); ...@@ -240,6 +250,7 @@ static PyObject *PySSL_cipher(PySSLSocket *self);
#define PySSLContext_Check(v) (Py_TYPE(v) == &PySSLContext_Type) #define PySSLContext_Check(v) (Py_TYPE(v) == &PySSLContext_Type)
#define PySSLSocket_Check(v) (Py_TYPE(v) == &PySSLSocket_Type) #define PySSLSocket_Check(v) (Py_TYPE(v) == &PySSLSocket_Type)
#define PySSLMemoryBIO_Check(v) (Py_TYPE(v) == &PySSLMemoryBIO_Type)
typedef enum { typedef enum {
SOCKET_IS_NONBLOCKING, SOCKET_IS_NONBLOCKING,
...@@ -254,6 +265,9 @@ typedef enum { ...@@ -254,6 +265,9 @@ typedef enum {
#define ERRSTR1(x,y,z) (x ":" y ": " z) #define ERRSTR1(x,y,z) (x ":" y ": " z)
#define ERRSTR(x) ERRSTR1("_ssl.c", Py_STRINGIFY(__LINE__), x) #define ERRSTR(x) ERRSTR1("_ssl.c", Py_STRINGIFY(__LINE__), x)
/* Get the socket from a PySSLSocket, if it has one */
#define GET_SOCKET(obj) ((obj)->Socket ? \
(PySocketSockObject *) PyWeakref_GetObject((obj)->Socket) : NULL)
/* /*
* SSL errors. * SSL errors.
...@@ -417,13 +431,12 @@ PySSL_SetError(PySSLSocket *obj, int ret, char *filename, int lineno) ...@@ -417,13 +431,12 @@ PySSL_SetError(PySSLSocket *obj, int ret, char *filename, int lineno)
case SSL_ERROR_SYSCALL: case SSL_ERROR_SYSCALL:
{ {
if (e == 0) { if (e == 0) {
PySocketSockObject *s PySocketSockObject *s = GET_SOCKET(obj);
= (PySocketSockObject *) PyWeakref_GetObject(obj->Socket);
if (ret == 0 || (((PyObject *)s) == Py_None)) { if (ret == 0 || (((PyObject *)s) == Py_None)) {
p = PY_SSL_ERROR_EOF; p = PY_SSL_ERROR_EOF;
type = PySSLEOFErrorObject; type = PySSLEOFErrorObject;
errstr = "EOF occurred in violation of protocol"; errstr = "EOF occurred in violation of protocol";
} else if (ret == -1) { } else if (s && ret == -1) {
/* underlying BIO reported an I/O error */ /* underlying BIO reported an I/O error */
Py_INCREF(s); Py_INCREF(s);
ERR_clear_error(); ERR_clear_error();
...@@ -477,10 +490,12 @@ _setSSLError (char *errstr, int errcode, char *filename, int lineno) { ...@@ -477,10 +490,12 @@ _setSSLError (char *errstr, int errcode, char *filename, int lineno) {
static PySSLSocket * static PySSLSocket *
newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock, newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock,
enum py_ssl_server_or_client socket_type, enum py_ssl_server_or_client socket_type,
char *server_hostname) char *server_hostname,
PySSLMemoryBIO *inbio, PySSLMemoryBIO *outbio)
{ {
PySSLSocket *self; PySSLSocket *self;
SSL_CTX *ctx = sslctx->ctx; SSL_CTX *ctx = sslctx->ctx;
PyObject *hostname;
long mode; long mode;
self = PyObject_New(PySSLSocket, &PySSLSocket_Type); self = PyObject_New(PySSLSocket, &PySSLSocket_Type);
...@@ -493,6 +508,18 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock, ...@@ -493,6 +508,18 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock,
self->ctx = sslctx; self->ctx = sslctx;
self->shutdown_seen_zero = 0; self->shutdown_seen_zero = 0;
self->handshake_done = 0; self->handshake_done = 0;
self->owner = NULL;
if (server_hostname != NULL) {
hostname = PyUnicode_Decode(server_hostname, strlen(server_hostname),
"idna", "strict");
if (hostname == NULL) {
Py_DECREF(self);
return NULL;
}
self->server_hostname = hostname;
} else
self->server_hostname = NULL;
Py_INCREF(sslctx); Py_INCREF(sslctx);
/* Make sure the SSL error state is initialized */ /* Make sure the SSL error state is initialized */
...@@ -502,8 +529,17 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock, ...@@ -502,8 +529,17 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock,
PySSL_BEGIN_ALLOW_THREADS PySSL_BEGIN_ALLOW_THREADS
self->ssl = SSL_new(ctx); self->ssl = SSL_new(ctx);
PySSL_END_ALLOW_THREADS PySSL_END_ALLOW_THREADS
SSL_set_app_data(self->ssl,self); SSL_set_app_data(self->ssl, self);
if (sock) {
SSL_set_fd(self->ssl, Py_SAFE_DOWNCAST(sock->sock_fd, SOCKET_T, int)); SSL_set_fd(self->ssl, Py_SAFE_DOWNCAST(sock->sock_fd, SOCKET_T, int));
} else {
/* BIOs are reference counted and SSL_set_bio borrows our reference.
* To prevent a double free in memory_bio_dealloc() we need to take an
* extra reference here. */
CRYPTO_add(&inbio->bio->references, 1, CRYPTO_LOCK_BIO);
CRYPTO_add(&outbio->bio->references, 1, CRYPTO_LOCK_BIO);
SSL_set_bio(self->ssl, inbio->bio, outbio->bio);
}
mode = SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER; mode = SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER;
#ifdef SSL_MODE_AUTO_RETRY #ifdef SSL_MODE_AUTO_RETRY
mode |= SSL_MODE_AUTO_RETRY; mode |= SSL_MODE_AUTO_RETRY;
...@@ -518,7 +554,7 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock, ...@@ -518,7 +554,7 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock,
/* If the socket is in non-blocking mode or timeout mode, set the BIO /* If the socket is in non-blocking mode or timeout mode, set the BIO
* to non-blocking mode (blocking is the default) * to non-blocking mode (blocking is the default)
*/ */
if (sock->sock_timeout >= 0.0) { if (sock && sock->sock_timeout >= 0.0) {
BIO_set_nbio(SSL_get_rbio(self->ssl), 1); BIO_set_nbio(SSL_get_rbio(self->ssl), 1);
BIO_set_nbio(SSL_get_wbio(self->ssl), 1); BIO_set_nbio(SSL_get_wbio(self->ssl), 1);
} }
...@@ -531,11 +567,14 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock, ...@@ -531,11 +567,14 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock,
PySSL_END_ALLOW_THREADS PySSL_END_ALLOW_THREADS
self->socket_type = socket_type; self->socket_type = socket_type;
if (sock != NULL) {
self->Socket = PyWeakref_NewRef((PyObject *) sock, NULL); self->Socket = PyWeakref_NewRef((PyObject *) sock, NULL);
if (self->Socket == NULL) { if (self->Socket == NULL) {
Py_DECREF(self); Py_DECREF(self);
Py_XDECREF(self->server_hostname);
return NULL; return NULL;
} }
}
return self; return self;
} }
...@@ -546,9 +585,9 @@ static PyObject *PySSL_SSLdo_handshake(PySSLSocket *self) ...@@ -546,9 +585,9 @@ static PyObject *PySSL_SSLdo_handshake(PySSLSocket *self)
int ret; int ret;
int err; int err;
int sockstate, nonblocking; int sockstate, nonblocking;
PySocketSockObject *sock PySocketSockObject *sock = GET_SOCKET(self);
= (PySocketSockObject *) PyWeakref_GetObject(self->Socket);
if (sock) {
if (((PyObject*)sock) == Py_None) { if (((PyObject*)sock) == Py_None) {
_setSSLError("Underlying socket connection gone", _setSSLError("Underlying socket connection gone",
PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__); PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
...@@ -560,6 +599,7 @@ static PyObject *PySSL_SSLdo_handshake(PySSLSocket *self) ...@@ -560,6 +599,7 @@ static PyObject *PySSL_SSLdo_handshake(PySSLSocket *self)
nonblocking = (sock->sock_timeout >= 0.0); nonblocking = (sock->sock_timeout >= 0.0);
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
}
/* Actually negotiate SSL connection */ /* Actually negotiate SSL connection */
/* XXX If SSL_do_handshake() returns 0, it's also a failure. */ /* XXX If SSL_do_handshake() returns 0, it's also a failure. */
...@@ -593,7 +633,7 @@ static PyObject *PySSL_SSLdo_handshake(PySSLSocket *self) ...@@ -593,7 +633,7 @@ static PyObject *PySSL_SSLdo_handshake(PySSLSocket *self)
break; break;
} }
} while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE); } while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE);
Py_DECREF(sock); Py_XDECREF(sock);
if (ret < 1) if (ret < 1)
return PySSL_SetError(self, ret, __FILE__, __LINE__); return PySSL_SetError(self, ret, __FILE__, __LINE__);
...@@ -608,7 +648,7 @@ static PyObject *PySSL_SSLdo_handshake(PySSLSocket *self) ...@@ -608,7 +648,7 @@ static PyObject *PySSL_SSLdo_handshake(PySSLSocket *self)
return Py_None; return Py_None;
error: error:
Py_DECREF(sock); Py_XDECREF(sock);
return NULL; return NULL;
} }
...@@ -1483,6 +1523,54 @@ on the SSLContext to change the certificate information associated with the\n\ ...@@ -1483,6 +1523,54 @@ on the SSLContext to change the certificate information associated with the\n\
SSLSocket before the cryptographic exchange handshake messages\n"); SSLSocket before the cryptographic exchange handshake messages\n");
static PyObject *
PySSL_get_server_side(PySSLSocket *self, void *c)
{
return PyBool_FromLong(self->socket_type == PY_SSL_SERVER);
}
PyDoc_STRVAR(PySSL_get_server_side_doc,
"Whether this is a server-side socket.");
static PyObject *
PySSL_get_server_hostname(PySSLSocket *self, void *c)
{
if (self->server_hostname == NULL)
Py_RETURN_NONE;
Py_INCREF(self->server_hostname);
return self->server_hostname;
}
PyDoc_STRVAR(PySSL_get_server_hostname_doc,
"The currently set server hostname (for SNI).");
static PyObject *
PySSL_get_owner(PySSLSocket *self, void *c)
{
PyObject *owner;
if (self->owner == NULL)
Py_RETURN_NONE;
owner = PyWeakref_GetObject(self->owner);
Py_INCREF(owner);
return owner;
}
static int
PySSL_set_owner(PySSLSocket *self, PyObject *value, void *c)
{
Py_XDECREF(self->owner);
self->owner = PyWeakref_NewRef(value, NULL);
if (self->owner == NULL)
return -1;
return 0;
}
PyDoc_STRVAR(PySSL_get_owner_doc,
"The Python-level owner of this object.\
Passed as \"self\" in servername callback.");
static void PySSL_dealloc(PySSLSocket *self) static void PySSL_dealloc(PySSLSocket *self)
{ {
...@@ -1492,6 +1580,8 @@ static void PySSL_dealloc(PySSLSocket *self) ...@@ -1492,6 +1580,8 @@ static void PySSL_dealloc(PySSLSocket *self)
SSL_free(self->ssl); SSL_free(self->ssl);
Py_XDECREF(self->Socket); Py_XDECREF(self->Socket);
Py_XDECREF(self->ctx); Py_XDECREF(self->ctx);
Py_XDECREF(self->server_hostname);
Py_XDECREF(self->owner);
PyObject_Del(self); PyObject_Del(self);
} }
...@@ -1508,10 +1598,10 @@ check_socket_and_wait_for_timeout(PySocketSockObject *s, int writing) ...@@ -1508,10 +1598,10 @@ check_socket_and_wait_for_timeout(PySocketSockObject *s, int writing)
int rc; int rc;
/* Nothing to do unless we're in timeout mode (not non-blocking) */ /* Nothing to do unless we're in timeout mode (not non-blocking) */
if (s->sock_timeout < 0.0) if ((s == NULL) || (s->sock_timeout == 0.0))
return SOCKET_IS_BLOCKING;
else if (s->sock_timeout == 0.0)
return SOCKET_IS_NONBLOCKING; return SOCKET_IS_NONBLOCKING;
else if (s->sock_timeout < 0.0)
return SOCKET_IS_BLOCKING;
/* Guard against closed socket */ /* Guard against closed socket */
if (s->sock_fd < 0) if (s->sock_fd < 0)
...@@ -1572,18 +1662,19 @@ static PyObject *PySSL_SSLwrite(PySSLSocket *self, PyObject *args) ...@@ -1572,18 +1662,19 @@ static PyObject *PySSL_SSLwrite(PySSLSocket *self, PyObject *args)
int sockstate; int sockstate;
int err; int err;
int nonblocking; int nonblocking;
PySocketSockObject *sock PySocketSockObject *sock = GET_SOCKET(self);
= (PySocketSockObject *) PyWeakref_GetObject(self->Socket);
if (sock != NULL) {
if (((PyObject*)sock) == Py_None) { if (((PyObject*)sock) == Py_None) {
_setSSLError("Underlying socket connection gone", _setSSLError("Underlying socket connection gone",
PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__); PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
return NULL; return NULL;
} }
Py_INCREF(sock); Py_INCREF(sock);
}
if (!PyArg_ParseTuple(args, "y*:write", &buf)) { if (!PyArg_ParseTuple(args, "y*:write", &buf)) {
Py_DECREF(sock); Py_XDECREF(sock);
return NULL; return NULL;
} }
...@@ -1593,10 +1684,12 @@ static PyObject *PySSL_SSLwrite(PySSLSocket *self, PyObject *args) ...@@ -1593,10 +1684,12 @@ static PyObject *PySSL_SSLwrite(PySSLSocket *self, PyObject *args)
goto error; goto error;
} }
if (sock != NULL) {
/* just in case the blocking state of the socket has been changed */ /* just in case the blocking state of the socket has been changed */
nonblocking = (sock->sock_timeout >= 0.0); nonblocking = (sock->sock_timeout >= 0.0);
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
}
sockstate = check_socket_and_wait_for_timeout(sock, 1); sockstate = check_socket_and_wait_for_timeout(sock, 1);
if (sockstate == SOCKET_HAS_TIMED_OUT) { if (sockstate == SOCKET_HAS_TIMED_OUT) {
...@@ -1640,7 +1733,7 @@ static PyObject *PySSL_SSLwrite(PySSLSocket *self, PyObject *args) ...@@ -1640,7 +1733,7 @@ static PyObject *PySSL_SSLwrite(PySSLSocket *self, PyObject *args)
} }
} while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE); } while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE);
Py_DECREF(sock); Py_XDECREF(sock);
PyBuffer_Release(&buf); PyBuffer_Release(&buf);
if (len > 0) if (len > 0)
return PyLong_FromLong(len); return PyLong_FromLong(len);
...@@ -1648,7 +1741,7 @@ static PyObject *PySSL_SSLwrite(PySSLSocket *self, PyObject *args) ...@@ -1648,7 +1741,7 @@ static PyObject *PySSL_SSLwrite(PySSLSocket *self, PyObject *args)
return PySSL_SetError(self, len, __FILE__, __LINE__); return PySSL_SetError(self, len, __FILE__, __LINE__);
error: error:
Py_DECREF(sock); Py_XDECREF(sock);
PyBuffer_Release(&buf); PyBuffer_Release(&buf);
return NULL; return NULL;
} }
...@@ -1688,15 +1781,16 @@ static PyObject *PySSL_SSLread(PySSLSocket *self, PyObject *args) ...@@ -1688,15 +1781,16 @@ static PyObject *PySSL_SSLread(PySSLSocket *self, PyObject *args)
int sockstate; int sockstate;
int err; int err;
int nonblocking; int nonblocking;
PySocketSockObject *sock PySocketSockObject *sock = GET_SOCKET(self);
= (PySocketSockObject *) PyWeakref_GetObject(self->Socket);
if (sock != NULL) {
if (((PyObject*)sock) == Py_None) { if (((PyObject*)sock) == Py_None) {
_setSSLError("Underlying socket connection gone", _setSSLError("Underlying socket connection gone",
PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__); PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
return NULL; return NULL;
} }
Py_INCREF(sock); Py_INCREF(sock);
}
buf.obj = NULL; buf.obj = NULL;
buf.buf = NULL; buf.buf = NULL;
...@@ -1722,10 +1816,12 @@ static PyObject *PySSL_SSLread(PySSLSocket *self, PyObject *args) ...@@ -1722,10 +1816,12 @@ static PyObject *PySSL_SSLread(PySSLSocket *self, PyObject *args)
} }
} }
if (sock != NULL) {
/* just in case the blocking state of the socket has been changed */ /* just in case the blocking state of the socket has been changed */
nonblocking = (sock->sock_timeout >= 0.0); nonblocking = (sock->sock_timeout >= 0.0);
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
}
/* first check if there are bytes ready to be read */ /* first check if there are bytes ready to be read */
PySSL_BEGIN_ALLOW_THREADS PySSL_BEGIN_ALLOW_THREADS
...@@ -1781,7 +1877,7 @@ static PyObject *PySSL_SSLread(PySSLSocket *self, PyObject *args) ...@@ -1781,7 +1877,7 @@ static PyObject *PySSL_SSLread(PySSLSocket *self, PyObject *args)
} }
done: done:
Py_DECREF(sock); Py_XDECREF(sock);
if (!buf_passed) { if (!buf_passed) {
_PyBytes_Resize(&dest, count); _PyBytes_Resize(&dest, count);
return dest; return dest;
...@@ -1792,7 +1888,7 @@ done: ...@@ -1792,7 +1888,7 @@ done:
} }
error: error:
Py_DECREF(sock); Py_XDECREF(sock);
if (!buf_passed) if (!buf_passed)
Py_XDECREF(dest); Py_XDECREF(dest);
else else
...@@ -1809,9 +1905,9 @@ static PyObject *PySSL_SSLshutdown(PySSLSocket *self) ...@@ -1809,9 +1905,9 @@ static PyObject *PySSL_SSLshutdown(PySSLSocket *self)
{ {
int err, ssl_err, sockstate, nonblocking; int err, ssl_err, sockstate, nonblocking;
int zeros = 0; int zeros = 0;
PySocketSockObject *sock PySocketSockObject *sock = GET_SOCKET(self);
= (PySocketSockObject *) PyWeakref_GetObject(self->Socket);
if (sock != NULL) {
/* Guard against closed socket */ /* Guard against closed socket */
if ((((PyObject*)sock) == Py_None) || (sock->sock_fd < 0)) { if ((((PyObject*)sock) == Py_None) || (sock->sock_fd < 0)) {
_setSSLError("Underlying socket connection gone", _setSSLError("Underlying socket connection gone",
...@@ -1824,6 +1920,7 @@ static PyObject *PySSL_SSLshutdown(PySSLSocket *self) ...@@ -1824,6 +1920,7 @@ static PyObject *PySSL_SSLshutdown(PySSLSocket *self)
nonblocking = (sock->sock_timeout >= 0.0); nonblocking = (sock->sock_timeout >= 0.0);
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
}
while (1) { while (1) {
PySSL_BEGIN_ALLOW_THREADS PySSL_BEGIN_ALLOW_THREADS
...@@ -1881,15 +1978,17 @@ static PyObject *PySSL_SSLshutdown(PySSLSocket *self) ...@@ -1881,15 +1978,17 @@ static PyObject *PySSL_SSLshutdown(PySSLSocket *self)
} }
if (err < 0) { if (err < 0) {
Py_DECREF(sock); Py_XDECREF(sock);
return PySSL_SetError(self, err, __FILE__, __LINE__); return PySSL_SetError(self, err, __FILE__, __LINE__);
} }
else if (sock)
/* It's already INCREF'ed */ /* It's already INCREF'ed */
return (PyObject *) sock; return (PyObject *) sock;
else
Py_RETURN_NONE;
error: error:
Py_DECREF(sock); Py_XDECREF(sock);
return NULL; return NULL;
} }
...@@ -1937,6 +2036,12 @@ If the TLS handshake is not yet complete, None is returned"); ...@@ -1937,6 +2036,12 @@ If the TLS handshake is not yet complete, None is returned");
static PyGetSetDef ssl_getsetlist[] = { static PyGetSetDef ssl_getsetlist[] = {
{"context", (getter) PySSL_get_context, {"context", (getter) PySSL_get_context,
(setter) PySSL_set_context, PySSL_set_context_doc}, (setter) PySSL_set_context, PySSL_set_context_doc},
{"server_side", (getter) PySSL_get_server_side, NULL,
PySSL_get_server_side_doc},
{"server_hostname", (getter) PySSL_get_server_hostname, NULL,
PySSL_get_server_hostname_doc},
{"owner", (getter) PySSL_get_owner, (setter) PySSL_set_owner,
PySSL_get_owner_doc},
{NULL}, /* sentinel */ {NULL}, /* sentinel */
}; };
...@@ -2825,13 +2930,48 @@ context_wrap_socket(PySSLContext *self, PyObject *args, PyObject *kwds) ...@@ -2825,13 +2930,48 @@ context_wrap_socket(PySSLContext *self, PyObject *args, PyObject *kwds)
#endif #endif
} }
res = (PyObject *) newPySSLSocket(self, sock, server_side, res = (PyObject *) newPySSLSocket(self, sock, server_side, hostname,
hostname); NULL, NULL);
if (hostname != NULL) if (hostname != NULL)
PyMem_Free(hostname); PyMem_Free(hostname);
return res; return res;
} }
static PyObject *
context_wrap_bio(PySSLContext *self, PyObject *args, PyObject *kwds)
{
char *kwlist[] = {"incoming", "outgoing", "server_side",
"server_hostname", NULL};
int server_side;
char *hostname = NULL;
PyObject *hostname_obj = Py_None, *res;
PySSLMemoryBIO *incoming, *outgoing;
/* server_hostname is either None (or absent), or to be encoded
using the idna encoding. */
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!O!i|O:_wrap_bio", kwlist,
&PySSLMemoryBIO_Type, &incoming,
&PySSLMemoryBIO_Type, &outgoing,
&server_side, &hostname_obj))
return NULL;
if (hostname_obj != Py_None) {
#if HAVE_SNI
if (!PyArg_Parse(hostname_obj, "et", "idna", &hostname))
return NULL;
#else
PyErr_SetString(PyExc_ValueError, "server_hostname is not supported "
"by your OpenSSL library");
return NULL;
#endif
}
res = (PyObject *) newPySSLSocket(self, NULL, server_side, hostname,
incoming, outgoing);
PyMem_Free(hostname);
return res;
}
static PyObject * static PyObject *
session_stats(PySSLContext *self, PyObject *unused) session_stats(PySSLContext *self, PyObject *unused)
{ {
...@@ -2938,11 +3078,25 @@ _servername_callback(SSL *s, int *al, void *args) ...@@ -2938,11 +3078,25 @@ _servername_callback(SSL *s, int *al, void *args)
ssl = SSL_get_app_data(s); ssl = SSL_get_app_data(s);
assert(PySSLSocket_Check(ssl)); assert(PySSLSocket_Check(ssl));
/* The servername callback expects a argument that represents the current
* SSL connection and that has a .context attribute that can be changed to
* identify the requested hostname. Since the official API is the Python
* level API we want to pass the callback a Python level object rather than
* a _ssl.SSLSocket instance. If there's an "owner" (typically an
* SSLObject) that will be passed. Otherwise if there's a socket then that
* will be passed. If both do not exist only then the C-level object is
* passed. */
if (ssl->owner)
ssl_socket = PyWeakref_GetObject(ssl->owner);
else if (ssl->Socket)
ssl_socket = PyWeakref_GetObject(ssl->Socket); ssl_socket = PyWeakref_GetObject(ssl->Socket);
else
ssl_socket = (PyObject *) ssl;
Py_INCREF(ssl_socket); Py_INCREF(ssl_socket);
if (ssl_socket == Py_None) { if (ssl_socket == Py_None)
goto error; goto error;
}
if (servername == NULL) { if (servername == NULL) {
result = PyObject_CallFunctionObjArgs(ssl_ctx->set_hostname, ssl_socket, result = PyObject_CallFunctionObjArgs(ssl_ctx->set_hostname, ssl_socket,
...@@ -3171,6 +3325,8 @@ static PyGetSetDef context_getsetlist[] = { ...@@ -3171,6 +3325,8 @@ static PyGetSetDef context_getsetlist[] = {
static struct PyMethodDef context_methods[] = { static struct PyMethodDef context_methods[] = {
{"_wrap_socket", (PyCFunction) context_wrap_socket, {"_wrap_socket", (PyCFunction) context_wrap_socket,
METH_VARARGS | METH_KEYWORDS, NULL}, METH_VARARGS | METH_KEYWORDS, NULL},
{"_wrap_bio", (PyCFunction) context_wrap_bio,
METH_VARARGS | METH_KEYWORDS, NULL},
{"set_ciphers", (PyCFunction) set_ciphers, {"set_ciphers", (PyCFunction) set_ciphers,
METH_VARARGS, NULL}, METH_VARARGS, NULL},
{"_set_npn_protocols", (PyCFunction) _set_npn_protocols, {"_set_npn_protocols", (PyCFunction) _set_npn_protocols,
...@@ -3240,6 +3396,225 @@ static PyTypeObject PySSLContext_Type = { ...@@ -3240,6 +3396,225 @@ static PyTypeObject PySSLContext_Type = {
}; };
/*
* MemoryBIO objects
*/
static PyObject *
memory_bio_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
char *kwlist[] = {NULL};
BIO *bio;
PySSLMemoryBIO *self;
if (!PyArg_ParseTupleAndKeywords(args, kwds, ":MemoryBIO", kwlist))
return NULL;
bio = BIO_new(BIO_s_mem());
if (bio == NULL) {
PyErr_SetString(PySSLErrorObject,
"failed to allocate BIO");
return NULL;
}
/* Since our BIO is non-blocking an empty read() does not indicate EOF,
* just that no data is currently available. The SSL routines should retry
* the read, which we can achieve by calling BIO_set_retry_read(). */
BIO_set_retry_read(bio);
BIO_set_mem_eof_return(bio, -1);
assert(type != NULL && type->tp_alloc != NULL);
self = (PySSLMemoryBIO *) type->tp_alloc(type, 0);
if (self == NULL) {
BIO_free(bio);
return NULL;
}
self->bio = bio;
self->eof_written = 0;
return (PyObject *) self;
}
static void
memory_bio_dealloc(PySSLMemoryBIO *self)
{
BIO_free(self->bio);
Py_TYPE(self)->tp_free(self);
}
static PyObject *
memory_bio_get_pending(PySSLMemoryBIO *self, void *c)
{
return PyLong_FromLong(BIO_ctrl_pending(self->bio));
}
PyDoc_STRVAR(PySSL_memory_bio_pending_doc,
"The number of bytes pending in the memory BIO.");
static PyObject *
memory_bio_get_eof(PySSLMemoryBIO *self, void *c)
{
return PyBool_FromLong((BIO_ctrl_pending(self->bio) == 0)
&& self->eof_written);
}
PyDoc_STRVAR(PySSL_memory_bio_eof_doc,
"Whether the memory BIO is at EOF.");
static PyObject *
memory_bio_read(PySSLMemoryBIO *self, PyObject *args)
{
int len = -1, avail, nbytes;
PyObject *result;
if (!PyArg_ParseTuple(args, "|i:read", &len))
return NULL;
avail = BIO_ctrl_pending(self->bio);
if ((len < 0) || (len > avail))
len = avail;
result = PyBytes_FromStringAndSize(NULL, len);
if ((result == NULL) || (len == 0))
return result;
nbytes = BIO_read(self->bio, PyBytes_AS_STRING(result), len);
/* There should never be any short reads but check anyway. */
if ((nbytes < len) && (_PyBytes_Resize(&result, len) < 0)) {
Py_DECREF(result);
return NULL;
}
return result;
}
PyDoc_STRVAR(PySSL_memory_bio_read_doc,
"read([len]) -> bytes\n\
\n\
Read up to len bytes from the memory BIO.\n\
\n\
If len is not specified, read the entire buffer.\n\
If the return value is an empty bytes instance, this means either\n\
EOF or that no data is available. Use the \"eof\" property to\n\
distinguish between the two.");
static PyObject *
memory_bio_write(PySSLMemoryBIO *self, PyObject *args)
{
Py_buffer buf;
int nbytes;
if (!PyArg_ParseTuple(args, "y*:write", &buf))
return NULL;
if (buf.len > INT_MAX) {
PyErr_Format(PyExc_OverflowError,
"string longer than %d bytes", INT_MAX);
goto error;
}
if (self->eof_written) {
PyErr_SetString(PySSLErrorObject,
"cannot write() after write_eof()");
goto error;
}
nbytes = BIO_write(self->bio, buf.buf, buf.len);
if (nbytes < 0) {
_setSSLError(NULL, 0, __FILE__, __LINE__);
goto error;
}
PyBuffer_Release(&buf);
return PyLong_FromLong(nbytes);
error:
PyBuffer_Release(&buf);
return NULL;
}
PyDoc_STRVAR(PySSL_memory_bio_write_doc,
"write(b) -> len\n\
\n\
Writes the bytes b into the memory BIO. Returns the number\n\
of bytes written.");
static PyObject *
memory_bio_write_eof(PySSLMemoryBIO *self, PyObject *args)
{
self->eof_written = 1;
/* After an EOF is written, a zero return from read() should be a real EOF
* i.e. it should not be retried. Clear the SHOULD_RETRY flag. */
BIO_clear_retry_flags(self->bio);
BIO_set_mem_eof_return(self->bio, 0);
Py_RETURN_NONE;
}
PyDoc_STRVAR(PySSL_memory_bio_write_eof_doc,
"write_eof()\n\
\n\
Write an EOF marker to the memory BIO.\n\
When all data has been read, the \"eof\" property will be True.");
static PyGetSetDef memory_bio_getsetlist[] = {
{"pending", (getter) memory_bio_get_pending, NULL,
PySSL_memory_bio_pending_doc},
{"eof", (getter) memory_bio_get_eof, NULL,
PySSL_memory_bio_eof_doc},
{NULL}, /* sentinel */
};
static struct PyMethodDef memory_bio_methods[] = {
{"read", (PyCFunction) memory_bio_read,
METH_VARARGS, PySSL_memory_bio_read_doc},
{"write", (PyCFunction) memory_bio_write,
METH_VARARGS, PySSL_memory_bio_write_doc},
{"write_eof", (PyCFunction) memory_bio_write_eof,
METH_NOARGS, PySSL_memory_bio_write_eof_doc},
{NULL, NULL} /* sentinel */
};
static PyTypeObject PySSLMemoryBIO_Type = {
PyVarObject_HEAD_INIT(NULL, 0)
"_ssl.MemoryBIO", /*tp_name*/
sizeof(PySSLMemoryBIO), /*tp_basicsize*/
0, /*tp_itemsize*/
(destructor)memory_bio_dealloc, /*tp_dealloc*/
0, /*tp_print*/
0, /*tp_getattr*/
0, /*tp_setattr*/
0, /*tp_reserved*/
0, /*tp_repr*/
0, /*tp_as_number*/
0, /*tp_as_sequence*/
0, /*tp_as_mapping*/
0, /*tp_hash*/
0, /*tp_call*/
0, /*tp_str*/
0, /*tp_getattro*/
0, /*tp_setattro*/
0, /*tp_as_buffer*/
Py_TPFLAGS_DEFAULT, /*tp_flags*/
0, /*tp_doc*/
0, /*tp_traverse*/
0, /*tp_clear*/
0, /*tp_richcompare*/
0, /*tp_weaklistoffset*/
0, /*tp_iter*/
0, /*tp_iternext*/
memory_bio_methods, /*tp_methods*/
0, /*tp_members*/
memory_bio_getsetlist, /*tp_getset*/
0, /*tp_base*/
0, /*tp_dict*/
0, /*tp_descr_get*/
0, /*tp_descr_set*/
0, /*tp_dictoffset*/
0, /*tp_init*/
0, /*tp_alloc*/
memory_bio_new, /*tp_new*/
};
#ifdef HAVE_OPENSSL_RAND #ifdef HAVE_OPENSSL_RAND
...@@ -3927,6 +4302,8 @@ PyInit__ssl(void) ...@@ -3927,6 +4302,8 @@ PyInit__ssl(void)
return NULL; return NULL;
if (PyType_Ready(&PySSLSocket_Type) < 0) if (PyType_Ready(&PySSLSocket_Type) < 0)
return NULL; return NULL;
if (PyType_Ready(&PySSLMemoryBIO_Type) < 0)
return NULL;
m = PyModule_Create(&_sslmodule); m = PyModule_Create(&_sslmodule);
if (m == NULL) if (m == NULL)
...@@ -3990,6 +4367,9 @@ PyInit__ssl(void) ...@@ -3990,6 +4367,9 @@ PyInit__ssl(void)
if (PyDict_SetItemString(d, "_SSLSocket", if (PyDict_SetItemString(d, "_SSLSocket",
(PyObject *)&PySSLSocket_Type) != 0) (PyObject *)&PySSLSocket_Type) != 0)
return NULL; return NULL;
if (PyDict_SetItemString(d, "MemoryBIO",
(PyObject *)&PySSLMemoryBIO_Type) != 0)
return NULL;
PyModule_AddIntConstant(m, "SSL_ERROR_ZERO_RETURN", PyModule_AddIntConstant(m, "SSL_ERROR_ZERO_RETURN",
PY_SSL_ERROR_ZERO_RETURN); PY_SSL_ERROR_ZERO_RETURN);
PyModule_AddIntConstant(m, "SSL_ERROR_WANT_READ", PyModule_AddIntConstant(m, "SSL_ERROR_WANT_READ",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment