# Wrapper module for _ssl, providing some additional facilities
# implemented in Python.  Written by Bill Janssen.

"""This module provides some more Pythonic support for SSL.

Object types:

  SSLSocket -- subtype of socket.socket which does SSL over the socket

Exceptions:

  SSLError -- exception raised for I/O errors

Functions:

  cert_time_to_seconds -- convert time string used for certificate
                          notBefore and notAfter functions to integer
                          seconds past the Epoch (the time values
                          returned from time.time())

  get_server_certificate (ADDR) -- fetch the certificate provided
                          by the server at ADDR, a (HOST, PORT) pair.
                          The certificate is validated only if ca_certs
                          is supplied.

Integer constants:

SSL_ERROR_ZERO_RETURN
SSL_ERROR_WANT_READ
SSL_ERROR_WANT_WRITE
SSL_ERROR_WANT_X509_LOOKUP
SSL_ERROR_SYSCALL
SSL_ERROR_SSL
SSL_ERROR_WANT_CONNECT

SSL_ERROR_EOF
SSL_ERROR_INVALID_ERROR_CODE

The following group define certificate requirements that one side is
allowing/requiring from the other side:

CERT_NONE - no certificates from the other side are required (or will
            be looked at if provided)
CERT_OPTIONAL - certificates are not required, but if provided will be
                validated, and if validation fails, the connection will
                also fail
CERT_REQUIRED - certificates are required, and will be validated, and
                if validation fails, the connection will also fail

The following constants identify various SSL protocol variants:

PROTOCOL_SSLv2
PROTOCOL_SSLv3
PROTOCOL_SSLv23
PROTOCOL_TLSv1
"""

import base64        # for DER-to-PEM translation
import calendar
import textwrap
import traceback

import _ssl             # if we can't import it, let the error propagate

from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
from _ssl import SSLError
from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
from _ssl import (PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23,
                  PROTOCOL_TLSv1)
from _ssl import RAND_status, RAND_egd, RAND_add
from _ssl import (
    SSL_ERROR_ZERO_RETURN,
    SSL_ERROR_WANT_READ,
    SSL_ERROR_WANT_WRITE,
    SSL_ERROR_WANT_X509_LOOKUP,
    SSL_ERROR_SYSCALL,
    SSL_ERROR_SSL,
    SSL_ERROR_WANT_CONNECT,
    SSL_ERROR_EOF,
    SSL_ERROR_INVALID_ERROR_CODE,
    )

from socket import getnameinfo as _getnameinfo
from socket import error as socket_error
from socket import dup as _dup
from socket import socket, AF_INET, SOCK_STREAM


class SSLSocket(socket):

    """This class implements a subtype of socket.socket that wraps
    the underlying OS socket in an SSL context when necessary, and
    provides read and write methods over that channel."""

    def __init__(self, sock=None, keyfile=None, certfile=None,
                 server_side=False, cert_reqs=CERT_NONE,
                 ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                 do_handshake_on_connect=True,
                 family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
                 suppress_ragged_eofs=True, ciphers=None):
        """Create the socket and, if it is already connected, wrap it.

        If SOCK is given, its file descriptor is duplicated into this
        object and SOCK itself is closed.  Otherwise FILENO, or else
        FAMILY/TYPE/PROTO, are used to build the underlying socket.
        If CERTFILE is given without KEYFILE, the certificate file is
        used for both.

        Raises ValueError if DO_HANDSHAKE_ON_CONNECT is true for an
        already-connected socket whose timeout is 0.0 (non-blocking).
        A socket error raised while wrapping or handshaking closes
        this socket before propagating.
        """

        if sock is not None:
            socket.__init__(self,
                            family=sock.family,
                            type=sock.type,
                            proto=sock.proto,
                            fileno=_dup(sock.fileno()))
            sock.close()
        elif fileno is not None:
            socket.__init__(self, fileno=fileno)
        else:
            socket.__init__(self, family=family, type=type, proto=proto)

        self._closed = False

        if certfile and not keyfile:
            keyfile = certfile
        # see if it's connected
        try:
            socket.getpeername(self)
        except socket_error:
            # no, no connection yet
            self._sslobj = None
        else:
            # yes, create the SSL object
            try:
                self._sslobj = _ssl.sslwrap(self, server_side,
                                            keyfile, certfile,
                                            cert_reqs, ssl_version, ca_certs,
                                            ciphers)
                if do_handshake_on_connect:
                    timeout = self.gettimeout()
                    if timeout == 0.0:
                        # non-blocking
                        raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
                    self.do_handshake()

            except socket_error:
                self.close()
                # bare raise keeps the original traceback intact
                raise

        # Remembered so that connect() and accept() can wrap later.
        self.keyfile = keyfile
        self.certfile = certfile
        self.cert_reqs = cert_reqs
        self.ssl_version = ssl_version
        self.ca_certs = ca_certs
        self.ciphers = ciphers
        self.do_handshake_on_connect = do_handshake_on_connect
        self.suppress_ragged_eofs = suppress_ragged_eofs

    def dup(self):
        """Duplicating an SSL socket is unsupported; always raises
        NotImplementedError."""

        # NotImplemented is a constant, not an exception class; calling
        # it would produce a confusing TypeError instead.
        raise NotImplementedError("Can't dup() %s instances" %
                                  self.__class__.__name__)

    def _checkClosed(self, msg=None):
        # raise an exception here if you wish to check for spurious closes
        pass

    def read(self, len=0, buffer=None):
        """Read up to LEN bytes and return them.
        Return zero-length string on EOF.

        If BUFFER is given (and non-empty) the data is read into it
        instead.  With no LEN and no BUFFER, up to 1024 bytes are
        requested.  An SSL_ERROR_EOF error is turned into an EOF
        result (0 when reading into BUFFER, b'' otherwise) when
        suppress_ragged_eofs is set; any other SSLError propagates."""

        self._checkClosed()
        try:
            if buffer:
                return self._sslobj.read(buffer, len)
            return self._sslobj.read(len or 1024)
        except SSLError as x:
            if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
                if buffer:
                    return 0
                return b''
            raise

    def write(self, data):
        """Write DATA to the underlying SSL channel.  Returns
        number of bytes of DATA actually transmitted."""

        self._checkClosed()
        return self._sslobj.write(data)

    def getpeercert(self, binary_form=False):
        """Returns a formatted version of the data in the
        certificate provided by the other end of the SSL channel.
        Return None if no certificate was provided, {} if a
        certificate was provided, but not validated."""

        self._checkClosed()
        return self._sslobj.peer_certificate(binary_form)

    def cipher(self):
        """Return the SSL object's cipher() result, or None when the
        socket is not currently wrapped."""

        self._checkClosed()
        if not self._sslobj:
            return None
        return self._sslobj.cipher()

    def send(self, data, flags=0):
        """Send DATA, over the SSL channel when one is active.

        On the SSL path non-zero FLAGS raise ValueError, and a
        WANT_READ / WANT_WRITE condition is reported as 0 bytes
        written.  Without an SSL object this defers to socket.send."""

        self._checkClosed()
        if not self._sslobj:
            return socket.send(self, data, flags)
        if flags != 0:
            raise ValueError(
                "non-zero flags not allowed in calls to send() on %s" %
                self.__class__)
        try:
            return self._sslobj.write(data)
        except SSLError as x:
            if x.args[0] in (SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE):
                # nothing could be written right now
                return 0
            raise

    def sendto(self, data, addr, flags=0):
        """Defer to socket.sendto; raises ValueError when an SSL
        channel is active."""

        self._checkClosed()
        if self._sslobj:
            raise ValueError("sendto not allowed on instances of %s" %
                             self.__class__)
        return socket.sendto(self, data, addr, flags)

    def sendall(self, data, flags=0):
        """Send all of DATA.

        On the SSL path this loops over send() until everything has
        been written and returns len(DATA); non-zero FLAGS raise
        ValueError, matching send().  Without an SSL object this
        defers to socket.sendall."""

        self._checkClosed()
        if not self._sslobj:
            return socket.sendall(self, data, flags)
        if flags != 0:
            # previously the flags were silently dropped on this path
            raise ValueError(
                "non-zero flags not allowed in calls to sendall() on %s" %
                self.__class__)
        amount = len(data)
        count = 0
        while count < amount:
            count += self.send(data[count:])
        return amount

    def recv(self, buflen=1024, flags=0):
        """Receive up to BUFLEN bytes.  On the SSL path non-zero FLAGS
        raise ValueError and the data comes from read()."""

        self._checkClosed()
        if not self._sslobj:
            return socket.recv(self, buflen, flags)
        if flags != 0:
            raise ValueError(
                "non-zero flags not allowed in calls to recv() on %s" %
                self.__class__)
        return self.read(buflen)

    def recv_into(self, buffer, nbytes=None, flags=0):
        """Receive into BUFFER.  When NBYTES is None it defaults to
        len(BUFFER) for a non-empty buffer and to 1024 otherwise.  On
        the SSL path non-zero FLAGS raise ValueError."""

        self._checkClosed()
        if buffer and (nbytes is None):
            nbytes = len(buffer)
        elif nbytes is None:
            nbytes = 1024
        if not self._sslobj:
            return socket.recv_into(self, buffer, nbytes, flags)
        if flags != 0:
            raise ValueError(
              "non-zero flags not allowed in calls to recv_into() on %s" %
              self.__class__)
        return self.read(nbytes, buffer)

    def recvfrom(self, buflen=1024, flags=0):
        """Defer to socket.recvfrom; raises ValueError when an SSL
        channel is active."""

        # The signature mirrors socket.recvfrom(bufsize, flags).  An
        # extra leading "addr" argument used to be accepted here and
        # forwarded, which made every plain-socket call fail.
        self._checkClosed()
        if self._sslobj:
            raise ValueError("recvfrom not allowed on instances of %s" %
                             self.__class__)
        return socket.recvfrom(self, buflen, flags)

    def recvfrom_into(self, buffer, nbytes=None, flags=0):
        """Defer to socket.recvfrom_into; raises ValueError when an SSL
        channel is active."""

        self._checkClosed()
        if self._sslobj:
            raise ValueError("recvfrom_into not allowed on instances of %s" %
                             self.__class__)
        if nbytes is None:
            # socket.recvfrom_into needs an integer; 0 asks it to use
            # the full size of BUFFER.
            nbytes = 0
        return socket.recvfrom_into(self, buffer, nbytes, flags)

    def pending(self):
        """Return the SSL object's pending() count, or 0 when the
        socket is not wrapped."""

        self._checkClosed()
        if self._sslobj:
            return self._sslobj.pending()
        return 0

    def shutdown(self, how):
        """Drop the SSL object and shut down the underlying socket."""

        self._checkClosed()
        self._sslobj = None
        socket.shutdown(self, how)

    def unwrap(self):
        """Shut down the SSL layer and return the result of the SSL
        object's shutdown().  Raises ValueError if the socket is not
        wrapped."""

        if not self._sslobj:
            raise ValueError("No SSL wrapper around " + str(self))
        s = self._sslobj.shutdown()
        self._sslobj = None
        return s

    def _real_close(self):
        # Release the SSL object before closing the socket itself.
        self._sslobj = None
        socket._real_close(self)

    def do_handshake(self, block=False):
        """Perform a TLS/SSL handshake.

        If BLOCK is true and the socket's timeout is 0.0, the timeout
        is lifted for the duration of the handshake; the previous
        timeout is always restored afterwards."""

        timeout = self.gettimeout()
        try:
            if timeout == 0.0 and block:
                self.settimeout(None)
            self._sslobj.do_handshake()
        finally:
            self.settimeout(timeout)

    def connect(self, addr):
        """Connects to remote ADDR, and then wraps the connection in
        an SSL channel."""

        # Here we assume that the socket is client-side, and not
        # connected at the time of the call.  We connect it, then wrap it.
        if self._sslobj:
            raise ValueError("attempt to connect already-connected SSLSocket!")
        socket.connect(self, addr)
        self._sslobj = _ssl.sslwrap(self, False, self.keyfile, self.certfile,
                                    self.cert_reqs, self.ssl_version,
                                    self.ca_certs, self.ciphers)
        try:
            if self.do_handshake_on_connect:
                self.do_handshake()
        except:
            # cleanup only: forget the half-set-up SSL object, re-raise
            self._sslobj = None
            raise

    def accept(self):
        """Accepts a new connection from a remote client, and returns
        a tuple containing that new connection wrapped with a server-side
        SSL channel, and the address of the remote client."""

        newsock, addr = socket.accept(self)
        wrapped = SSLSocket(sock=newsock,
                            keyfile=self.keyfile, certfile=self.certfile,
                            server_side=True,
                            cert_reqs=self.cert_reqs,
                            ssl_version=self.ssl_version,
                            ca_certs=self.ca_certs,
                            ciphers=self.ciphers,
                            do_handshake_on_connect=
                                self.do_handshake_on_connect)
        return (wrapped, addr)

    def __del__(self):
        self._real_close()


def wrap_socket(sock, keyfile=None, certfile=None,
                server_side=False, cert_reqs=CERT_NONE,
                ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                do_handshake_on_connect=True,
                suppress_ragged_eofs=True, ciphers=None):
    """Return an SSLSocket built around SOCK.

    Every argument is passed through unchanged to the SSLSocket
    constructor under the same keyword name."""

    return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile,
                     server_side=server_side, cert_reqs=cert_reqs,
                     ssl_version=ssl_version, ca_certs=ca_certs,
                     do_handshake_on_connect=do_handshake_on_connect,
                     suppress_ragged_eofs=suppress_ragged_eofs,
                     ciphers=ciphers)

# some utility functions

def cert_time_to_seconds(cert_time):
    """Takes a date-time string in standard ASN1_print form
    ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return
    a Python time value in seconds past the epoch.

    The string must end in the literal "GMT".  Raises ValueError (from
    time.strptime) if CERT_TIME does not match that layout."""

    import time
    parsed = time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT")
    # The timestamp is in GMT, so it must be converted with timegm();
    # time.mktime() would interpret it in the local timezone and give
    # a result that is off by the machine's UTC offset.
    return calendar.timegm(parsed)

# Delimiter lines that bracket the base64 body of a PEM certificate.
PEM_HEADER = "-----BEGIN CERTIFICATE-----"
PEM_FOOTER = "-----END CERTIFICATE-----"

def DER_cert_to_PEM_cert(der_cert_bytes):
    """Takes a certificate in binary DER format and returns the
    PEM version of it as a string.

    The DER bytes are base64-encoded, wrapped at 64 characters per
    line, and placed between PEM_HEADER and PEM_FOOTER; the result
    ends with a newline."""

    encoded = base64.standard_b64encode(der_cert_bytes).decode('ASCII', 'strict')
    return (PEM_HEADER + '\n' +
            textwrap.fill(encoded, 64) + '\n' +
            PEM_FOOTER + '\n')

def PEM_cert_to_DER_cert(pem_cert_string):
    """Takes a certificate in ASCII PEM format and returns the
    DER-encoded version of it as a byte sequence

    Surrounding whitespace is ignored.  Raises ValueError if the
    remaining text does not start with PEM_HEADER or does not end
    with PEM_FOOTER."""

    # Strip once and validate the same text that gets sliced below;
    # the header check used to run on the unstripped string, so
    # leading whitespace was rejected while trailing was accepted.
    stripped = pem_cert_string.strip()
    if not stripped.startswith(PEM_HEADER):
        raise ValueError("Invalid PEM encoding; must start with %s"
                         % PEM_HEADER)
    if not stripped.endswith(PEM_FOOTER):
        raise ValueError("Invalid PEM encoding; must end with %s"
                         % PEM_FOOTER)
    d = stripped[len(PEM_HEADER):-len(PEM_FOOTER)]
    return base64.decodebytes(d.encode('ASCII', 'strict'))

def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
    """Retrieve the certificate from the server at the specified address,
    and return it as a PEM-encoded string.
    If 'ca_certs' is specified, validate the server cert against it.
    If 'ssl_version' is specified, use it in the connection attempt."""

    # Unpacking is kept as a check that addr is a two-item pair.
    host, port = addr
    if ca_certs is not None:
        cert_reqs = CERT_REQUIRED
    else:
        cert_reqs = CERT_NONE
    s = wrap_socket(socket(), ssl_version=ssl_version,
                    cert_reqs=cert_reqs, ca_certs=ca_certs)
    try:
        s.connect(addr)
        dercert = s.getpeercert(True)
    finally:
        # close the socket even when connecting or the handshake fails
        s.close()
    return DER_cert_to_PEM_cert(dercert)

def get_protocol_name(protocol_code):
    """Return a short display name for one of the PROTOCOL_*
    constants, or "<unknown>" for any other value."""

    if protocol_code == PROTOCOL_TLSv1:
        return "TLSv1"
    if protocol_code == PROTOCOL_SSLv23:
        return "SSLv23"
    if protocol_code == PROTOCOL_SSLv2:
        return "SSLv2"
    if protocol_code == PROTOCOL_SSLv3:
        return "SSLv3"
    return "<unknown>"