# Wrapper module for _ssl, providing some additional facilities
# implemented in Python.  Written by Bill Janssen.

"""This module provides some more Pythonic support for SSL.

Object types:

  SSLContext -- holds SSL-related configuration such as the protocol,
                certificates and possibly a private key
  SSLSocket -- subtype of socket.socket which does SSL over the socket

Exceptions:

  SSLError -- exception raised for I/O errors

Functions:

  cert_time_to_seconds -- convert a time string used for the certificate
                          notBefore and notAfter fields to seconds past
                          the Epoch (the kind of value returned by
                          time.time())

  get_server_certificate (ADDR) -- fetch the certificate provided by the
                          server running at ADDR, a (HOST, PORT) pair, and
                          return it PEM-encoded.  The certificate is only
                          validated if ca_certs is given.

Integer constants:

SSL_ERROR_ZERO_RETURN
SSL_ERROR_WANT_READ
SSL_ERROR_WANT_WRITE
SSL_ERROR_WANT_X509_LOOKUP
SSL_ERROR_SYSCALL
SSL_ERROR_SSL
SSL_ERROR_WANT_CONNECT

SSL_ERROR_EOF
SSL_ERROR_INVALID_ERROR_CODE

The following group defines the certificate requirements that one side
allows or requires from the other side:

CERT_NONE - no certificates from the other side are required (or will
            be looked at if provided)
CERT_OPTIONAL - certificates are not required, but if provided will be
                validated, and if validation fails, the connection will
                also fail
CERT_REQUIRED - certificates are required, and will be validated, and
                if validation fails, the connection will also fail

The following constants identify various SSL protocol variants:

PROTOCOL_SSLv2
PROTOCOL_SSLv3
PROTOCOL_SSLv23
PROTOCOL_TLSv1
"""

57
import textwrap
58 59

import _ssl             # if we can't import it, let the error propagate
60

61
from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
62
from _ssl import _SSLContext, SSLError
63
from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
64 65
from _ssl import (PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23,
                  PROTOCOL_TLSv1)
66
from _ssl import OP_ALL, OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_TLSv1
67
from _ssl import RAND_status, RAND_egd, RAND_add
68 69 70 71 72 73 74 75 76 77 78
from _ssl import (
    SSL_ERROR_ZERO_RETURN,
    SSL_ERROR_WANT_READ,
    SSL_ERROR_WANT_WRITE,
    SSL_ERROR_WANT_X509_LOOKUP,
    SSL_ERROR_SYSCALL,
    SSL_ERROR_SSL,
    SSL_ERROR_WANT_CONNECT,
    SSL_ERROR_EOF,
    SSL_ERROR_INVALID_ERROR_CODE,
    )
79 80

from socket import getnameinfo as _getnameinfo
Bill Janssen's avatar
Bill Janssen committed
81
from socket import error as socket_error
82
from socket import socket, AF_INET, SOCK_STREAM
83
import base64        # for DER-to-PEM translation
84
import traceback
85
import errno
86 87


88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
class SSLContext(_SSLContext):
    """Container for SSL configuration shared by many sockets.

    Builds on the C-level ``_SSLContext`` (options, certificates and
    possibly a private key) and additionally records the protocol constant
    it was created with in the ``protocol`` attribute.
    """

    __slots__ = ('protocol',)

    def __new__(cls, protocol, *args, **kwargs):
        # Only the protocol reaches the C constructor; any extra arguments
        # are not forwarded.
        instance = _SSLContext.__new__(cls, protocol)
        return instance

    def __init__(self, protocol):
        self.protocol = protocol

    def wrap_socket(self, sock, server_side=False,
                    do_handshake_on_connect=True,
                    suppress_ragged_eofs=True):
        """Return an SSLSocket wrapping *sock* that uses this context."""
        ssl_sock = SSLSocket(sock=sock,
                             _context=self,
                             server_side=server_side,
                             suppress_ragged_eofs=suppress_ragged_eofs,
                             do_handshake_on_connect=do_handshake_on_connect)
        return ssl_sock


class SSLSocket(socket):
    """This class implements a subtype of socket.socket that wraps
    the underlying OS socket in an SSL context when necessary, and
    provides read and write methods over that channel.

    While no SSL object is attached (``self._sslobj`` is None), the
    send/recv family of methods falls through to plain socket.socket.
    """

    def __init__(self, sock=None, keyfile=None, certfile=None,
                 server_side=False, cert_reqs=CERT_NONE,
                 ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                 do_handshake_on_connect=True,
                 family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
                 suppress_ragged_eofs=True, ciphers=None,
                 _context=None):
        """Create an SSL-capable socket.

        If *_context* is given it is used as-is and keyfile, certfile,
        cert_reqs, ssl_version, ca_certs and ciphers are ignored; otherwise
        a new SSLContext is built from those arguments (and they are kept
        as attributes of the same names).

        If *sock* is given, its file descriptor is taken over and *sock* is
        detached.  If it turns out to be connected, the SSL object is
        created immediately and, when *do_handshake_on_connect* is true,
        the handshake is performed.  Otherwise the OS socket is created from
        *fileno*, or from *family*/*type*/*proto*.

        Raises ValueError when *server_side* is requested without a
        *certfile* (only when no *_context* is given), or when a handshake
        on connect is requested for a connected non-blocking socket.
        """
        if _context:
            self.context = _context
        else:
            if server_side and not certfile:
                raise ValueError("certfile must be specified for server-side "
                                 "operations")
            if certfile and not keyfile:
                # No separate key file given: load the key from certfile too.
                keyfile = certfile
            self.context = SSLContext(ssl_version)
            self.context.verify_mode = cert_reqs
            if ca_certs:
                self.context.load_verify_locations(ca_certs)
            if certfile:
                self.context.load_cert_chain(certfile, keyfile)
            if ciphers:
                self.context.set_ciphers(ciphers)
            self.keyfile = keyfile
            self.certfile = certfile
            self.cert_reqs = cert_reqs
            self.ssl_version = ssl_version
            self.ca_certs = ca_certs
            self.ciphers = ciphers
        self.server_side = server_side
        self.do_handshake_on_connect = do_handshake_on_connect
        self.suppress_ragged_eofs = suppress_ragged_eofs
        connected = False
        if sock is not None:
            socket.__init__(self,
                            family=sock.family,
                            type=sock.type,
                            proto=sock.proto,
                            fileno=sock.fileno())
            self.settimeout(sock.gettimeout())
            # Give up sock's ownership of the descriptor *before* probing it,
            # so an unexpected getpeername() error cannot leave both objects
            # owning (and later closing) the same file descriptor.
            sock.detach()
            # see if it's connected
            try:
                self.getpeername()
            except socket_error as e:
                if e.errno != errno.ENOTCONN:
                    raise
            else:
                connected = True
        elif fileno is not None:
            socket.__init__(self, fileno=fileno)
        else:
            socket.__init__(self, family=family, type=type, proto=proto)

        self._closed = False
        self._sslobj = None
        if connected:
            # create the SSL object
            try:
                self._sslobj = self.context._wrap_socket(self, server_side)
                if do_handshake_on_connect:
                    timeout = self.gettimeout()
                    if timeout == 0.0:
                        # non-blocking
                        raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
                    self.do_handshake()

            except socket_error:
                self.close()
                raise

    def dup(self):
        """Duplicating an SSL socket is not supported."""
        raise NotImplementedError("Can't dup() %s instances" %
                                  self.__class__.__name__)

    def _checkClosed(self, msg=None):
        # raise an exception here if you wish to check for spurious closes
        pass

    def read(self, len=0, buffer=None):
        """Read up to LEN bytes and return them.

        If BUFFER is given (even an empty one), read into it and return the
        number of bytes read instead.  When the peer closes the connection
        without an SSL shutdown and suppress_ragged_eofs is true, return
        b'' (or 0 when BUFFER is given) instead of raising SSLError."""

        self._checkClosed()
        try:
            if buffer is not None:
                return self._sslobj.read(buffer, len)
            return self._sslobj.read(len or 1024)
        except SSLError as x:
            if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
                return 0 if buffer is not None else b''
            raise

    def write(self, data):
        """Write DATA to the underlying SSL channel.  Returns
        number of bytes of DATA actually transmitted."""

        self._checkClosed()
        return self._sslobj.write(data)

    def getpeercert(self, binary_form=False):
        """Returns a formatted version of the data in the
        certificate provided by the other end of the SSL channel.
        Return None if no certificate was provided, {} if a
        certificate was provided, but not validated."""

        self._checkClosed()
        return self._sslobj.peer_certificate(binary_form)

    def cipher(self):
        """Return the SSL object's cipher() result, or None if no SSL
        object is attached."""
        self._checkClosed()
        if not self._sslobj:
            return None
        return self._sslobj.cipher()

    def send(self, data, flags=0):
        """Send DATA, through SSL if an SSL object is attached.

        Returns the number of bytes sent; returns 0 if the SSL layer
        reports SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE.  Non-zero
        FLAGS are rejected with ValueError on an SSL channel."""
        self._checkClosed()
        if self._sslobj:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to send() on %s" %
                    self.__class__)
            try:
                return self._sslobj.write(data)
            except SSLError as x:
                if x.args[0] in (SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE):
                    # The SSL layer cannot make progress right now; report
                    # that nothing was sent so the caller can retry.
                    return 0
                raise
        else:
            return socket.send(self, data, flags)

    def sendto(self, data, addr, flags=0):
        """Plain-socket sendto(); not allowed on an SSL channel."""
        self._checkClosed()
        if self._sslobj:
            raise ValueError("sendto not allowed on instances of %s" %
                             self.__class__)
        else:
            return socket.sendto(self, data, addr, flags)

    def sendall(self, data, flags=0):
        """Send all of DATA; returns len(DATA) on an SSL channel."""
        self._checkClosed()
        if self._sslobj:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to sendall() on %s" %
                    self.__class__)
            amount = len(data)
            count = 0
            # NOTE(review): send() returns 0 on WANT_READ/WANT_WRITE, so on a
            # non-blocking socket this loop can spin until progress is made.
            while (count < amount):
                v = self.send(data[count:])
                count += v
            return amount
        else:
            return socket.sendall(self, data, flags)

    def recv(self, buflen=1024, flags=0):
        """Receive up to BUFLEN bytes; non-zero FLAGS are rejected with
        ValueError on an SSL channel."""
        self._checkClosed()
        if self._sslobj:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to recv() on %s" %
                    self.__class__)
            return self.read(buflen)
        else:
            return socket.recv(self, buflen, flags)

    def recv_into(self, buffer, nbytes=None, flags=0):
        """Receive into BUFFER and return the number of bytes received.

        NBYTES defaults to len(BUFFER) (so an empty buffer asks for 0
        bytes); 1024 is used only when no buffer object is supplied."""
        self._checkClosed()
        if nbytes is None:
            nbytes = len(buffer) if buffer is not None else 1024
        if self._sslobj:
            if flags != 0:
                raise ValueError(
                  "non-zero flags not allowed in calls to recv_into() on %s" %
                  self.__class__)
            return self.read(nbytes, buffer)
        else:
            return socket.recv_into(self, buffer, nbytes, flags)

    def recvfrom(self, buflen=1024, flags=0):
        """Plain-socket recvfrom(); not allowed on an SSL channel."""
        # socket.recvfrom() takes (bufsize[, flags]) -- there is no address
        # argument to forward.
        self._checkClosed()
        if self._sslobj:
            raise ValueError("recvfrom not allowed on instances of %s" %
                             self.__class__)
        else:
            return socket.recvfrom(self, buflen, flags)

    def recvfrom_into(self, buffer, nbytes=None, flags=0):
        """Plain-socket recvfrom_into(); not allowed on an SSL channel."""
        self._checkClosed()
        if self._sslobj:
            raise ValueError("recvfrom_into not allowed on instances of %s" %
                             self.__class__)
        else:
            return socket.recvfrom_into(self, buffer, nbytes, flags)

    def pending(self):
        """Return the SSL object's pending() count, or 0 without SSL."""
        self._checkClosed()
        if self._sslobj:
            return self._sslobj.pending()
        else:
            return 0

    def shutdown(self, how):
        """Drop the SSL object (no SSL-level shutdown is performed) and
        shut down the OS socket."""
        self._checkClosed()
        self._sslobj = None
        socket.shutdown(self, how)

    def unwrap(self):
        """Shut down the SSL layer and return the SSL object's shutdown()
        result (presumably the underlying socket -- confirm against _ssl).

        Raises ValueError if no SSL object is attached."""
        if self._sslobj:
            s = self._sslobj.shutdown()
            self._sslobj = None
            return s
        else:
            raise ValueError("No SSL wrapper around " + str(self))

    def _real_close(self):
        # Drop the SSL object before closing the OS-level socket.
        self._sslobj = None
        socket._real_close(self)

    def do_handshake(self, block=False):
        """Perform a TLS/SSL handshake.

        If BLOCK is true and the socket is non-blocking (timeout 0.0), the
        handshake runs in blocking mode; the original timeout is always
        restored afterwards."""

        timeout = self.gettimeout()
        try:
            if timeout == 0.0 and block:
                self.settimeout(None)
            self._sslobj.do_handshake()
        finally:
            self.settimeout(timeout)

    def connect(self, addr):
        """Connects to remote ADDR, and then wraps the connection in
        an SSL channel."""
        if self.server_side:
            raise ValueError("can't connect in server-side mode")
        # Here we assume that the socket is client-side, and not
        # connected at the time of the call.  We connect it, then wrap it.
        if self._sslobj:
            raise ValueError("attempt to connect already-connected SSLSocket!")
        socket.connect(self, addr)
        self._sslobj = self.context._wrap_socket(self, False)
        try:
            if self.do_handshake_on_connect:
                self.do_handshake()
        except BaseException:
            # Leave the socket unwrapped if the handshake fails.
            self._sslobj = None
            raise

    def accept(self):
        """Accepts a new connection from a remote client, and returns
        a tuple containing that new connection wrapped with a server-side
        SSL channel, and the address of the remote client."""

        newsock, addr = socket.accept(self)
        # Reuse this socket's context: it carries the same configuration the
        # legacy keyword arguments would rebuild, and it also exists when this
        # socket was created via SSLContext.wrap_socket(), which never sets
        # keyfile/certfile/cert_reqs/ssl_version/ca_certs/ciphers.
        return (SSLSocket(sock=newsock,
                          server_side=True,
                          do_handshake_on_connect=
                              self.do_handshake_on_connect,
                          suppress_ragged_eofs=self.suppress_ragged_eofs,
                          _context=self.context),
                addr)

    def __del__(self):
        self._real_close()


def wrap_socket(sock, keyfile=None, certfile=None,
                server_side=False, cert_reqs=CERT_NONE,
                ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                do_handshake_on_connect=True,
                suppress_ragged_eofs=True, ciphers=None):
    """Wrap the existing socket *sock* in an SSLSocket.

    Convenience function: every argument is forwarded unchanged to the
    SSLSocket constructor (no context is passed, so SSLSocket builds its
    own SSLContext from the certificate/protocol arguments).
    """
    options = dict(keyfile=keyfile, certfile=certfile,
                   server_side=server_side, cert_reqs=cert_reqs,
                   ssl_version=ssl_version, ca_certs=ca_certs,
                   do_handshake_on_connect=do_handshake_on_connect,
                   suppress_ragged_eofs=suppress_ragged_eofs,
                   ciphers=ciphers)
    return SSLSocket(sock=sock, **options)


# some utility functions

def cert_time_to_seconds(cert_time):
    """Takes a date-time string in standard ASN1_print form
    ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return
    a Python time value in seconds past the epoch.

    The string is interpreted as GMT (the only timezone the format
    accepts), so the result does not depend on the machine's local
    timezone.  The value is returned as a float, like time.time().
    Raises ValueError if the string does not match the format.
    """

    import time
    import calendar

    parsed = time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT")
    # time.mktime() would treat the parsed struct_time as *local* time and
    # shift the result by the local UTC offset; certificate times are GMT,
    # so convert with calendar.timegm() instead.
    # NOTE(review): "%b" is locale-dependent in strptime; non-English
    # LC_TIME locales may fail to parse English month abbreviations.
    return float(calendar.timegm(parsed))


# Delimiter lines surrounding the base64 body of a PEM certificate.
PEM_HEADER = "-----BEGIN CERTIFICATE-----"
PEM_FOOTER = "-----END CERTIFICATE-----"


def DER_cert_to_PEM_cert(der_cert_bytes):
    """Takes a certificate in binary DER format and returns the
    PEM version of it as a string."""

    f = str(base64.standard_b64encode(der_cert_bytes), 'ASCII', 'strict')
    # The base64 body is wrapped at 64 columns between header and footer.
    return (PEM_HEADER + '\n' +
            textwrap.fill(f, 64) + '\n' +
            PEM_FOOTER + '\n')


def PEM_cert_to_DER_cert(pem_cert_string):
    """Takes a certificate in ASCII PEM format and returns the
    DER-encoded version of it as a byte sequence.

    Surrounding whitespace is ignored.  Raises ValueError if the text does
    not begin with PEM_HEADER or end with PEM_FOOTER."""

    # Strip once and validate both ends on the same string; previously the
    # header was checked on the unstripped input, so leading whitespace or
    # newlines were rejected even though trailing ones were accepted.
    stripped = pem_cert_string.strip()
    if not stripped.startswith(PEM_HEADER):
        raise ValueError("Invalid PEM encoding; must start with %s"
                         % PEM_HEADER)
    if not stripped.endswith(PEM_FOOTER):
        raise ValueError("Invalid PEM encoding; must end with %s"
                         % PEM_FOOTER)
    d = stripped[len(PEM_HEADER):-len(PEM_FOOTER)]
    # decodebytes skips the embedded newlines of the wrapped base64 body.
    return base64.decodebytes(d.encode('ASCII', 'strict'))


def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
    """Retrieve the certificate from the server at the specified address,
    and return it as a PEM-encoded string.
    If 'ca_certs' is specified, validate the server cert against it.
    If 'ssl_version' is specified, use it in the connection attempt.

    'addr' must be a (host, port) pair."""

    # Unpacking rejects anything that is not a 2-item address early.
    host, port = addr
    cert_reqs = CERT_REQUIRED if ca_certs is not None else CERT_NONE
    s = wrap_socket(socket(), ssl_version=ssl_version,
                    cert_reqs=cert_reqs, ca_certs=ca_certs)
    try:
        s.connect(addr)
        dercert = s.getpeercert(True)
    finally:
        # Close even when connect() or the handshake fails, so the OS
        # socket is not leaked until garbage collection.
        s.close()
    return DER_cert_to_PEM_cert(dercert)


def get_protocol_name(protocol_code):
    """Return a short human-readable name for a PROTOCOL_* constant.

    Values that match none of the known constants yield "<unknown>".
    """
    if protocol_code == PROTOCOL_TLSv1:
        name = "TLSv1"
    elif protocol_code == PROTOCOL_SSLv23:
        name = "SSLv23"
    elif protocol_code == PROTOCOL_SSLv2:
        name = "SSLv2"
    elif protocol_code == PROTOCOL_SSLv3:
        name = "SSLv3"
    else:
        name = "<unknown>"
    return name