ssl.py 19.5 KB
Newer Older
1 2 3
# Wrapper module for _ssl, providing some additional facilities
# implemented in Python.  Written by Bill Janssen.

4
"""This module provides some more Pythonic support for SSL.

Object types:

  SSLContext -- holds SSL configuration (certificates, private key,
                verification mode, ciphers) and wraps sockets with it
  SSLSocket -- subtype of socket.socket which does SSL over the socket

Exceptions:

  SSLError -- exception raised for I/O errors
  CertificateError -- raised by match_hostname() when a certificate does
                      not match the requested hostname

Functions:

  cert_time_to_seconds -- convert time string used for the certificate
                          notBefore and notAfter fields to integer
                          seconds past the Epoch (the time values
                          returned from time.time())

  get_server_certificate (ADDR) -- fetch the certificate provided by the
                          server at ADDR, a (HOST, PORT) pair, and return
                          it PEM-encoded.  The certificate is validated
                          only if ca_certs is given.

Integer constants:

SSL_ERROR_ZERO_RETURN
SSL_ERROR_WANT_READ
SSL_ERROR_WANT_WRITE
SSL_ERROR_WANT_X509_LOOKUP
SSL_ERROR_SYSCALL
SSL_ERROR_SSL
SSL_ERROR_WANT_CONNECT

SSL_ERROR_EOF
SSL_ERROR_INVALID_ERROR_CODE

The following group defines the certificate requirements that one side
allows or requires from the other side:

CERT_NONE - no certificates from the other side are required (or will
            be looked at if provided)
CERT_OPTIONAL - certificates are not required, but if provided will be
                validated, and if validation fails, the connection will
                also fail
CERT_REQUIRED - certificates are required, and will be validated, and
                if validation fails, the connection will also fail

The following constants identify various SSL protocol variants:

PROTOCOL_SSLv2 (only present when the _ssl module provides it)
PROTOCOL_SSLv3
PROTOCOL_SSLv23
PROTOCOL_TLSv1
"""

57
import textwrap
58
import re
59 60

import _ssl             # if we can't import it, let the error propagate
61

62
from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
63
from _ssl import _SSLContext, SSLError
64
from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
65
from _ssl import OP_ALL, OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_TLSv1
66
from _ssl import RAND_status, RAND_egd, RAND_add
67 68 69 70 71 72 73 74 75 76 77
from _ssl import (
    SSL_ERROR_ZERO_RETURN,
    SSL_ERROR_WANT_READ,
    SSL_ERROR_WANT_WRITE,
    SSL_ERROR_WANT_X509_LOOKUP,
    SSL_ERROR_SYSCALL,
    SSL_ERROR_SSL,
    SSL_ERROR_WANT_CONNECT,
    SSL_ERROR_EOF,
    SSL_ERROR_INVALID_ERROR_CODE,
    )
78
from _ssl import HAS_SNI
79
from _ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1
80 81
from _ssl import _OPENSSL_API_VERSION

82 83 84 85 86 87 88 89 90 91 92
# Human-readable names for the protocol constants; see get_protocol_name().
_PROTOCOL_NAMES = dict([
    (PROTOCOL_TLSv1, "TLSv1"),
    (PROTOCOL_SSLv23, "SSLv23"),
    (PROTOCOL_SSLv3, "SSLv3"),
])
# _ssl does not always provide PROTOCOL_SSLv2, so register its name only
# when the constant can be imported.
try:
    from _ssl import PROTOCOL_SSLv2
except ImportError:
    pass
else:
    _PROTOCOL_NAMES.update({PROTOCOL_SSLv2: "SSLv2"})
93 94

from socket import getnameinfo as _getnameinfo
Bill Janssen's avatar
Bill Janssen committed
95
from socket import error as socket_error
96
from socket import socket, AF_INET, SOCK_STREAM
97
import base64        # for DER-to-PEM translation
98
import traceback
99
import errno
100 101


102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
class CertificateError(ValueError):
    """Raised by match_hostname() when a certificate cannot be matched
    against the requested hostname."""


def _dnsname_to_pat(dn):
    pats = []
    for frag in dn.split(r'.'):
        if frag == '*':
            # When '*' is a fragment by itself, it matches a non-empty dotless
            # fragment.
            pats.append('[^.]+')
        else:
            # Otherwise, '*' matches any dotless fragment.
            frag = re.escape(frag)
            pats.append(frag.replace(r'\*', '[^.]*'))
    return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE)


def match_hostname(cert, hostname):
    """Check that *cert* is valid for *hostname*.

    *cert* is a decoded certificate as returned by
    SSLSocket.getpeercert().  RFC 2818 rules are mostly followed, but IP
    addresses are not accepted for *hostname*.

    Returns None on success.  Raises ValueError if *cert* is empty or
    None, and CertificateError if no DNS name in the certificate matches.
    """
    if not cert:
        raise ValueError("empty or no certificate")

    candidates = []
    for kind, name in cert.get('subjectAltName', ()):
        if kind != 'DNS':
            continue
        if _dnsname_to_pat(name).match(hostname):
            return
        candidates.append(name)

    if not candidates:
        # The subject is consulted only when subjectAltName carries no
        # dNSName entry.
        # XXX according to RFC 2818, the most specific Common Name
        # must be used.
        for rdn in cert.get('subject', ()):
            for kind, name in rdn:
                if kind != 'commonName':
                    continue
                if _dnsname_to_pat(name).match(hostname):
                    return
                candidates.append(name)

    if not candidates:
        raise CertificateError("no appropriate commonName or "
            "subjectAltName fields were found")
    if len(candidates) == 1:
        raise CertificateError("hostname %r "
            "doesn't match %r"
            % (hostname, candidates[0]))
    raise CertificateError("hostname %r "
        "doesn't match either of %s"
        % (hostname, ', '.join(map(repr, candidates))))


161 162 163 164 165 166 167 168 169 170 171 172 173 174
class SSLContext(_SSLContext):
    """An SSLContext holds various SSL-related configuration options and
    data, such as certificates and possibly a private key.

    The protocol it was created with is kept in the ``protocol`` slot.
    """

    __slots__ = ('protocol',)

    def __new__(cls, protocol, *args, **kwargs):
        # Only *protocol* is forwarded to the C base type; any further
        # arguments are accepted and ignored here.
        instance = _SSLContext.__new__(cls, protocol)
        return instance

    def __init__(self, protocol):
        self.protocol = protocol

    def wrap_socket(self, sock, server_side=False,
                    do_handshake_on_connect=True,
                    suppress_ragged_eofs=True,
                    server_hostname=None):
        """Return an SSLSocket wrapping *sock* that uses this context."""
        options = dict(server_side=server_side,
                       do_handshake_on_connect=do_handshake_on_connect,
                       suppress_ragged_eofs=suppress_ragged_eofs,
                       server_hostname=server_hostname)
        return SSLSocket(sock=sock, _context=self, **options)


class SSLSocket(socket):
    """This class implements a subtype of socket.socket that wraps
    the underlying OS socket in an SSL context when necessary, and
    provides read and write methods over that channel.

    The SSL configuration comes either from an SSLContext passed as the
    private ``_context`` argument (this is how SSLContext.wrap_socket
    builds instances) or from the legacy keyword arguments (keyfile,
    certfile, cert_reqs, ssl_version, ca_certs, ciphers), from which a new
    SSLContext is created.  While no SSL object exists (``self._sslobj``
    is None), the send/recv family falls back to plain socket behavior.
    """

    def __init__(self, sock=None, keyfile=None, certfile=None,
                 server_side=False, cert_reqs=CERT_NONE,
                 ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                 do_handshake_on_connect=True,
                 family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
                 suppress_ragged_eofs=True, ciphers=None,
                 server_hostname=None,
                 _context=None):

        if _context:
            self.context = _context
        else:
            if server_side and not certfile:
                raise ValueError("certfile must be specified for server-side "
                                 "operations")
            if keyfile and not certfile:
                raise ValueError("certfile must be specified")
            if certfile and not keyfile:
                # Without a separate keyfile, the private key is loaded from
                # certfile as well.
                keyfile = certfile
            self.context = SSLContext(ssl_version)
            self.context.verify_mode = cert_reqs
            if ca_certs:
                self.context.load_verify_locations(ca_certs)
            if certfile:
                self.context.load_cert_chain(certfile, keyfile)
            if ciphers:
                self.context.set_ciphers(ciphers)
            # Remember the legacy settings; accept() reuses them.
            self.keyfile = keyfile
            self.certfile = certfile
            self.cert_reqs = cert_reqs
            self.ssl_version = ssl_version
            self.ca_certs = ca_certs
            self.ciphers = ciphers
        if server_side and server_hostname:
            raise ValueError("server_hostname can only be specified "
                             "in client mode")
        self.server_side = server_side
        self.server_hostname = server_hostname
        self.do_handshake_on_connect = do_handshake_on_connect
        self.suppress_ragged_eofs = suppress_ragged_eofs
        connected = False
        if sock is not None:
            socket.__init__(self,
                            family=sock.family,
                            type=sock.type,
                            proto=sock.proto,
                            fileno=sock.fileno())
            self.settimeout(sock.gettimeout())
            # see if it's connected
            try:
                sock.getpeername()
            except socket_error as e:
                if e.errno != errno.ENOTCONN:
                    raise
            else:
                connected = True
            # The descriptor now belongs to self; detach it from *sock* so
            # that closing *sock* does not close this connection.
            sock.detach()
        elif fileno is not None:
            socket.__init__(self, fileno=fileno)
        else:
            socket.__init__(self, family=family, type=type, proto=proto)

        self._closed = False
        self._sslobj = None
        self._connected = connected
        if connected:
            # create the SSL object
            try:
                self._sslobj = self.context._wrap_socket(self, server_side,
                                                         server_hostname)
                if do_handshake_on_connect:
                    timeout = self.gettimeout()
                    if timeout == 0.0:
                        # non-blocking
                        raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
                    self.do_handshake()

            except socket_error:
                self.close()
                # Bare raise keeps the original traceback intact.
                raise

    def dup(self):
        """Always fails: duplicating SSL sockets is not supported."""
        # NotImplemented is a constant, not an exception; calling it raised
        # TypeError instead of the intended error.
        raise NotImplementedError("Can't dup() %s instances" %
                                  self.__class__.__name__)

    def _checkClosed(self, msg=None):
        # raise an exception here if you wish to check for spurious closes
        pass

    def read(self, len=0, buffer=None):
        """Read up to LEN bytes and return them.
        Return zero-length string on EOF.

        If *buffer* is given, the data is read into it and the number of
        bytes read is returned instead (0 on a suppressed ragged EOF).
        """

        self._checkClosed()
        try:
            if buffer is not None:
                v = self._sslobj.read(len, buffer)
            else:
                v = self._sslobj.read(len or 1024)
            return v
        except SSLError as x:
            if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
                if buffer is not None:
                    return 0
                else:
                    return b''
            else:
                raise

    def write(self, data):
        """Write DATA to the underlying SSL channel.  Returns
        number of bytes of DATA actually transmitted."""

        self._checkClosed()
        return self._sslobj.write(data)

    def getpeercert(self, binary_form=False):
        """Returns a formatted version of the data in the
        certificate provided by the other end of the SSL channel.
        Return None if no certificate was provided, {} if a
        certificate was provided, but not validated."""

        self._checkClosed()
        return self._sslobj.peer_certificate(binary_form)

    def cipher(self):
        """Return the SSL object's cipher() result, or None if there is no
        SSL object."""
        self._checkClosed()
        if not self._sslobj:
            return None
        else:
            return self._sslobj.cipher()

    def send(self, data, flags=0):
        """Send *data*; over SSL, return 0 when the SSL layer reports
        WANT_READ/WANT_WRITE instead of raising."""
        self._checkClosed()
        if self._sslobj:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to send() on %s" %
                    self.__class__)
            while True:
                try:
                    v = self._sslobj.write(data)
                except SSLError as x:
                    if x.args[0] == SSL_ERROR_WANT_READ:
                        return 0
                    elif x.args[0] == SSL_ERROR_WANT_WRITE:
                        return 0
                    else:
                        raise
                else:
                    return v
        else:
            return socket.send(self, data, flags)

    def sendto(self, data, flags_or_addr, addr=None):
        """Plain-socket sendto(); refused once an SSL object exists."""
        self._checkClosed()
        if self._sslobj:
            raise ValueError("sendto not allowed on instances of %s" %
                             self.__class__)
        elif addr is None:
            return socket.sendto(self, data, flags_or_addr)
        else:
            return socket.sendto(self, data, flags_or_addr, addr)

    def sendall(self, data, flags=0):
        """Send all of *data*; over SSL, loop on send() and return the
        total length."""
        self._checkClosed()
        if self._sslobj:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to sendall() on %s" %
                    self.__class__)
            amount = len(data)
            count = 0
            while (count < amount):
                v = self.send(data[count:])
                count += v
            return amount
        else:
            return socket.sendall(self, data, flags)

    def recv(self, buflen=1024, flags=0):
        """Receive up to *buflen* bytes (via read() when SSL is active)."""
        self._checkClosed()
        if self._sslobj:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to recv() on %s" %
                    self.__class__)
            return self.read(buflen)
        else:
            return socket.recv(self, buflen, flags)

    def recv_into(self, buffer, nbytes=None, flags=0):
        """Receive into *buffer*; *nbytes* defaults to len(buffer), or 1024
        when *buffer* is empty."""
        self._checkClosed()
        if buffer and (nbytes is None):
            nbytes = len(buffer)
        elif nbytes is None:
            nbytes = 1024
        if self._sslobj:
            if flags != 0:
                raise ValueError(
                  "non-zero flags not allowed in calls to recv_into() on %s" %
                  self.__class__)
            return self.read(nbytes, buffer)
        else:
            return socket.recv_into(self, buffer, nbytes, flags)

    def recvfrom(self, buflen=1024, flags=0):
        """Plain-socket recvfrom(); refused once an SSL object exists."""
        self._checkClosed()
        if self._sslobj:
            raise ValueError("recvfrom not allowed on instances of %s" %
                             self.__class__)
        else:
            return socket.recvfrom(self, buflen, flags)

    def recvfrom_into(self, buffer, nbytes=None, flags=0):
        """Plain-socket recvfrom_into(); refused once an SSL object exists."""
        self._checkClosed()
        if self._sslobj:
            raise ValueError("recvfrom_into not allowed on instances of %s" %
                             self.__class__)
        else:
            return socket.recvfrom_into(self, buffer, nbytes, flags)

    def pending(self):
        """Return the SSL object's pending() count, or 0 without SSL."""
        self._checkClosed()
        if self._sslobj:
            return self._sslobj.pending()
        else:
            return 0

    def shutdown(self, how):
        """Drop the SSL object (without calling its shutdown(); see
        unwrap() for that) and shut down the socket."""
        self._checkClosed()
        self._sslobj = None
        socket.shutdown(self, how)

    def unwrap(self):
        """Shut down the SSL layer and return the SSL object's shutdown()
        result.  Raises ValueError if there is no SSL object."""
        if self._sslobj:
            s = self._sslobj.shutdown()
            self._sslobj = None
            return s
        else:
            raise ValueError("No SSL wrapper around " + str(self))

    def _real_close(self):
        self._sslobj = None
        socket._real_close(self)

    def do_handshake(self, block=False):
        """Perform a TLS/SSL handshake.

        If *block* is true and the socket is non-blocking (timeout 0.0),
        the handshake runs in blocking mode; the original timeout is
        restored afterwards in every case.
        """

        timeout = self.gettimeout()
        try:
            if timeout == 0.0 and block:
                self.settimeout(None)
            self._sslobj.do_handshake()
        finally:
            self.settimeout(timeout)

    def _real_connect(self, addr, connect_ex):
        """Shared body of connect() and connect_ex(): create the client-side
        SSL object, connect, and handshake if configured to."""
        if self.server_side:
            raise ValueError("can't connect in server-side mode")
        # Here we assume that the socket is client-side, and not
        # connected at the time of the call.  We connect it, then wrap it.
        if self._connected:
            raise ValueError("attempt to connect already-connected SSLSocket!")
        self._sslobj = self.context._wrap_socket(self, False, self.server_hostname)
        try:
            if connect_ex:
                rc = socket.connect_ex(self, addr)
            else:
                rc = None
                socket.connect(self, addr)
            if not rc:
                if self.do_handshake_on_connect:
                    self.do_handshake()
                self._connected = True
            return rc
        except socket_error:
            self._sslobj = None
            raise

    def connect(self, addr):
        """Connects to remote ADDR, and then wraps the connection in
        an SSL channel."""
        self._real_connect(addr, False)

    def connect_ex(self, addr):
        """Connects to remote ADDR, and then wraps the connection in
        an SSL channel."""
        return self._real_connect(addr, True)

    def accept(self):
        """Accepts a new connection from a remote client, and returns
        a tuple containing that new connection wrapped with a server-side
        SSL channel, and the address of the remote client."""

        newsock, addr = socket.accept(self)
        try:
            settings = dict(keyfile=self.keyfile, certfile=self.certfile,
                            cert_reqs=self.cert_reqs,
                            ssl_version=self.ssl_version,
                            ca_certs=self.ca_certs,
                            ciphers=self.ciphers)
        except AttributeError:
            # Sockets created via SSLContext.wrap_socket never record the
            # legacy settings; wrap the connection with the same context.
            settings = dict(_context=self.context)
        return (SSLSocket(sock=newsock,
                          server_side=True,
                          do_handshake_on_connect=
                              self.do_handshake_on_connect,
                          suppress_ragged_eofs=self.suppress_ragged_eofs,
                          **settings),
                addr)

    def __del__(self):
        self._real_close()

501

502 503
def wrap_socket(sock, keyfile=None, certfile=None,
                server_side=False, cert_reqs=CERT_NONE,
                ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                do_handshake_on_connect=True,
                suppress_ragged_eofs=True, ciphers=None):
    """Wrap *sock* in an SSLSocket configured from the given arguments.

    All arguments are forwarded unchanged to the SSLSocket constructor.
    """
    settings = {
        'keyfile': keyfile,
        'certfile': certfile,
        'server_side': server_side,
        'cert_reqs': cert_reqs,
        'ssl_version': ssl_version,
        'ca_certs': ca_certs,
        'do_handshake_on_connect': do_handshake_on_connect,
        'suppress_ragged_eofs': suppress_ragged_eofs,
        'ciphers': ciphers,
    }
    return SSLSocket(sock=sock, **settings)
514

515 516 517
# some utility functions

def cert_time_to_seconds(cert_time):
    """Takes a date-time string in standard ASN1_print form
    ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return
    a Python time value in seconds past the epoch.

    The timestamp is interpreted as UTC ("GMT"), and the month name is
    parsed independently of the current locale.  The result is a float.

    Raises ValueError if *cert_time* is not in the expected format.
    """
    import calendar
    import time

    # English month abbreviations as written by ASN1_print; %b in strptime
    # would depend on the process locale instead.
    months = ("Jan", "Feb", "Mar", "Apr", "May", "Jun",
              "Jul", "Aug", "Sep", "Oct", "Nov", "Dec")
    month_name = cert_time[:3].title()
    if month_name not in months:
        raise ValueError("time data %r does not match format %r"
                         % (cert_time, "%b %d %H:%M:%S %Y GMT"))
    tt = time.strptime(cert_time[3:], " %d %H:%M:%S %Y GMT")
    # calendar.timegm treats the tuple as UTC; time.mktime would wrongly
    # apply the local timezone offset.
    fields = (tt[0], months.index(month_name) + 1) + tuple(tt[2:6])
    return float(calendar.timegm(fields))

525 526 527 528 529 530 531
# Boundary lines that enclose a PEM-encoded certificate.
PEM_HEADER, PEM_FOOTER = ("-----BEGIN CERTIFICATE-----",
                          "-----END CERTIFICATE-----")

def DER_cert_to_PEM_cert(der_cert_bytes):
    """Takes a certificate in binary DER format and returns the
    PEM version of it as a string.

    The base64 body is split into lines of at most 64 characters between
    PEM_HEADER and PEM_FOOTER, and the result ends with a newline.
    """
    encoded = base64.standard_b64encode(der_cert_bytes).decode('ASCII', 'strict')
    # Base64 text contains no whitespace, so fixed-width slicing produces
    # the same 64-column lines as wrapping it.
    body = '\n'.join(encoded[pos:pos + 64]
                     for pos in range(0, len(encoded), 64))
    return '%s\n%s\n%s\n' % (PEM_HEADER, body, PEM_FOOTER)
536 537 538 539 540 541 542 543 544 545 546 547

def PEM_cert_to_DER_cert(pem_cert_string):
    """Takes a certificate in ASCII PEM format and returns the
    DER-encoded version of it as a byte sequence.

    Surrounding whitespace is ignored at both ends.  Raises ValueError if
    the text does not start with PEM_HEADER or end with PEM_FOOTER.
    """
    # Strip once so leading whitespace is tolerated just like trailing
    # whitespace; previously only the end was stripped before checking.
    stripped = pem_cert_string.strip()
    if not stripped.startswith(PEM_HEADER):
        raise ValueError("Invalid PEM encoding; must start with %s"
                         % PEM_HEADER)
    if not stripped.endswith(PEM_FOOTER):
        raise ValueError("Invalid PEM encoding; must end with %s"
                         % PEM_FOOTER)
    d = stripped[len(PEM_HEADER):-len(PEM_FOOTER)]
    return base64.decodebytes(d.encode('ASCII', 'strict'))
549

550
def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
    """Retrieve the certificate from the server at the specified address,
    and return it as a PEM-encoded string.
    If 'ca_certs' is specified, validate the server cert against it.
    If 'ssl_version' is specified, use it in the connection attempt.

    *addr* is the address passed to connect(), e.g. a (host, port) pair.
    The socket is closed even if connecting or fetching the certificate
    raises.
    """
    cert_reqs = CERT_NONE if ca_certs is None else CERT_REQUIRED
    # The outer "with" closes the raw socket if wrapping fails; after a
    # successful wrap its descriptor is owned by the SSL socket.
    with socket() as raw:
        with wrap_socket(raw, ssl_version=ssl_version,
                         cert_reqs=cert_reqs, ca_certs=ca_certs) as s:
            s.connect(addr)
            dercert = s.getpeercert(True)
    return DER_cert_to_PEM_cert(dercert)

568
def get_protocol_name(protocol_code):
    """Return the name registered for *protocol_code* in _PROTOCOL_NAMES,
    or '<unknown>' if there is none."""
    try:
        return _PROTOCOL_NAMES[protocol_code]
    except KeyError:
        return '<unknown>'