ssl.py 13.8 KB
Newer Older
1 2 3
# Wrapper module for _ssl, providing some additional facilities
# implemented in Python.  Written by Bill Janssen.

"""This module provides some more Pythonic support for SSL.

Object types:

  SSLSocket -- subtype of socket.socket which does SSL over the socket

Exceptions:

  SSLError -- exception raised for I/O errors

Functions:

  cert_time_to_seconds -- convert time string used for certificate
                          notBefore and notAfter fields to seconds past
                          the Epoch (the time values returned from
                          time.time())

  get_server_certificate (ADDR, SSL_VERSION, CA_CERTS) -- fetch the
                          certificate provided by the server at ADDR, a
                          (HOST, PORT) pair, as a PEM-encoded string.  The
                          certificate is validated only if CA_CERTS is
                          given.

  wrap_socket -- wrap an existing socket in an SSLSocket

  DER_cert_to_PEM_cert, PEM_cert_to_DER_cert -- convert a certificate
                          between binary DER and ASCII PEM encodings

  get_protocol_name -- return the name of a PROTOCOL_* constant

Integer constants:

SSL_ERROR_ZERO_RETURN
SSL_ERROR_WANT_READ
SSL_ERROR_WANT_WRITE
SSL_ERROR_WANT_X509_LOOKUP
SSL_ERROR_SYSCALL
SSL_ERROR_SSL
SSL_ERROR_WANT_CONNECT

SSL_ERROR_EOF
SSL_ERROR_INVALID_ERROR_CODE

The following group defines certificate requirements that one side is
allowing/requiring from the other side:

CERT_NONE - no certificates from the other side are required (or will
            be looked at if provided)
CERT_OPTIONAL - certificates are not required, but if provided will be
                validated, and if validation fails, the connection will
                also fail
CERT_REQUIRED - certificates are required, and will be validated, and
                if validation fails, the connection will also fail

The following constants identify various SSL protocol variants:

PROTOCOL_SSLv2
PROTOCOL_SSLv3
PROTOCOL_SSLv23
PROTOCOL_TLSv1
"""

57
import os, sys, textwrap
58 59

import _ssl             # if we can't import it, let the error propagate
60 61

from _ssl import SSLError
62
from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
63 64
from _ssl import (PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23,
                  PROTOCOL_TLSv1)
65
from _ssl import RAND_status, RAND_egd, RAND_add
66 67 68 69 70 71 72 73 74 75 76
from _ssl import (
    SSL_ERROR_ZERO_RETURN,
    SSL_ERROR_WANT_READ,
    SSL_ERROR_WANT_WRITE,
    SSL_ERROR_WANT_X509_LOOKUP,
    SSL_ERROR_SYSCALL,
    SSL_ERROR_SSL,
    SSL_ERROR_WANT_CONNECT,
    SSL_ERROR_EOF,
    SSL_ERROR_INVALID_ERROR_CODE,
    )
77

Bill Janssen's avatar
Bill Janssen committed
78
from socket import socket, AF_INET, SOCK_STREAM, error
79
from socket import getnameinfo as _getnameinfo
Bill Janssen's avatar
Bill Janssen committed
80
from socket import error as socket_error
81
from socket import dup as _dup
82
import base64        # for DER-to-PEM translation
83

84
class SSLSocket(socket):

    """This class implements a subtype of socket.socket that wraps
    the underlying OS socket in an SSL context when necessary, and
    provides read and write methods over that channel."""

    def __init__(self, sock=None, keyfile=None, certfile=None,
                 server_side=False, cert_reqs=CERT_NONE,
                 ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                 do_handshake_on_connect=True,
                 family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
                 suppress_ragged_eofs=True):
        """Create an SSL-capable socket.

        If SOCK is given, its file descriptor is duplicated into this
        object and SOCK itself is closed.  Otherwise the socket is built
        from FILENO, or freshly from FAMILY/TYPE/PROTO.

        If the resulting socket is already connected, the SSL object is
        created immediately (and the handshake run when
        DO_HANDSHAKE_ON_CONNECT is true); otherwise that happens later in
        connect().  KEYFILE defaults to CERTFILE when only CERTFILE is
        given.  If SSL setup on a connected socket fails with a socket
        error, this socket is closed and the error re-raised.
        """
        self._base = None

        if sock is not None:
            # Take over SOCK's connection: duplicate its descriptor into
            # this object, then close the original (copied from
            # socket.accept()).
            fd = sock.fileno()
            nfd = _dup(fd)
            socket.__init__(self, family=sock.family, type=sock.type,
                            proto=sock.proto, fileno=nfd)
            sock.close()
            sock = None
        elif fileno is not None:
            socket.__init__(self, fileno=fileno)
        else:
            socket.__init__(self, family=family, type=type, proto=proto)

        self._closed = False

        if certfile and not keyfile:
            keyfile = certfile

        # See if it's connected.  getpeername() raises a socket error on an
        # unconnected socket; only that means "no connection yet", so other
        # exceptions (e.g. KeyboardInterrupt) are no longer swallowed.
        try:
            socket.getpeername(self)
        except socket_error:
            self._sslobj = None
        else:
            # Already connected: create the SSL object now.
            try:
                self._sslobj = _ssl.sslwrap(self, server_side,
                                            keyfile, certfile,
                                            cert_reqs, ssl_version, ca_certs)
                if do_handshake_on_connect:
                    self.do_handshake()
            except socket_error:
                self.close()
                raise

        # NOTE(review): when SOCK was supplied it was closed and reset to
        # None above, so _base always ends up None here.
        self._base = sock
        self.keyfile = keyfile
        self.certfile = certfile
        self.cert_reqs = cert_reqs
        self.ssl_version = ssl_version
        self.ca_certs = ca_certs
        self.do_handshake_on_connect = do_handshake_on_connect
        self.suppress_ragged_eofs = suppress_ragged_eofs

    def dup(self):
        """SSL sockets cannot be duplicated; always raises
        NotImplementedError."""
        # NotImplemented (the constant) is not callable; calling it would
        # surface as a confusing TypeError instead of this error.
        raise NotImplementedError("Can't dup() %s instances" %
                                  self.__class__.__name__)

    def _checkClosed(self, msg=None):
        # raise an exception here if you wish to check for spurious closes
        pass

    def read(self, len=1024, buffer=None):
        """Read up to LEN bytes and return them.
        Return zero-length string on EOF.

        When BUFFER is given (and truthy) the SSL object reads into it
        instead.  An SSL_ERROR_EOF is reported as b'' when
        suppress_ragged_eofs is set; any other SSLError propagates."""

        self._checkClosed()
        try:
            if buffer:
                return self._sslobj.read(buffer, len)
            else:
                return self._sslobj.read(len)
        except SSLError as x:
            if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
                return b''
            else:
                raise

    def write(self, data):
        """Write DATA to the underlying SSL channel.  Returns
        number of bytes of DATA actually transmitted."""

        self._checkClosed()
        return self._sslobj.write(data)

    def getpeercert(self, binary_form=False):
        """Returns a formatted version of the data in the
        certificate provided by the other end of the SSL channel.
        Return None if no certificate was provided, {} if a
        certificate was provided, but not validated."""

        self._checkClosed()
        return self._sslobj.peer_certificate(binary_form)

    def cipher(self):
        """Return the SSL object's cipher() result, or None when no SSL
        object exists."""
        self._checkClosed()
        if not self._sslobj:
            return None
        else:
            return self._sslobj.cipher()

    def send(self, data, flags=0):
        """Send DATA and return the number of bytes sent.

        Over SSL, FLAGS must be 0, and 0 is returned when the SSL layer
        reports SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE.  Without an
        SSL object this is plain socket.send()."""
        self._checkClosed()
        if not self._sslobj:
            return socket.send(self, data, flags)
        if flags != 0:
            raise ValueError(
                "non-zero flags not allowed in calls to send() on %s" %
                self.__class__)
        try:
            return self._sslobj.write(data)
        except SSLError as x:
            if x.args[0] in (SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE):
                return 0
            raise

    def sendto(self, data, addr, flags=0):
        """Unencrypted datagram send; refused once SSL is active.

        Overriding socket.sendto() matters: the inherited method would
        silently send plaintext on an SSL-wrapped socket."""
        self._checkClosed()
        if self._sslobj:
            raise ValueError("sendto not allowed on instances of %s" %
                             self.__class__)
        # socket.sendto() takes FLAGS *before* the address.
        return socket.sendto(self, data, flags, addr)

    def send_to(self, data, addr, flags=0):
        """Misspelled alias of sendto(), kept for backward compatibility."""
        return self.sendto(data, addr, flags)

    def sendall(self, data, flags=0):
        """Send all of DATA.  Over SSL, loops on send() until every byte
        is written and returns len(DATA)."""
        self._checkClosed()
        if self._sslobj:
            # NOTE(review): send() returns 0 on WANT_READ/WANT_WRITE, so on
            # a non-blocking socket this loop spins until progress is made.
            amount = len(data)
            count = 0
            while (count < amount):
                v = self.send(data[count:])
                count += v
            return amount
        else:
            return socket.sendall(self, data, flags)

    def recv(self, buflen=1024, flags=0):
        """Receive up to BUFLEN bytes.  Over SSL, FLAGS must be 0 and an
        SSL_ERROR_WANT_READ is retried."""
        self._checkClosed()
        if self._sslobj:
            if flags != 0:
                raise ValueError(
                  "non-zero flags not allowed in calls to recv() on %s" %
                  self.__class__)
            # NOTE(review): WANT_READ is retried immediately, which busy-waits
            # on a non-blocking socket.
            while True:
                try:
                    return self.read(buflen)
                except SSLError as x:
                    if x.args[0] == SSL_ERROR_WANT_READ:
                        continue
                    else:
                        raise
        else:
            return socket.recv(self, buflen, flags)

    def recv_into(self, buffer, nbytes=None, flags=0):
        """Receive into BUFFER.  NBYTES defaults to len(BUFFER) (or 1024
        for an empty BUFFER).  Over SSL, FLAGS must be 0 and an
        SSL_ERROR_WANT_READ is retried."""
        self._checkClosed()
        if buffer and (nbytes is None):
            nbytes = len(buffer)
        elif nbytes is None:
            nbytes = 1024
        if self._sslobj:
            if flags != 0:
                raise ValueError(
                  "non-zero flags not allowed in calls to recv_into() on %s" %
                  self.__class__)
            while True:
                try:
                    return self.read(nbytes, buffer)
                except SSLError as x:
                    if x.args[0] == SSL_ERROR_WANT_READ:
                        continue
                    else:
                        raise
        else:
            return socket.recv_into(self, buffer, nbytes, flags)

    def recvfrom(self, buflen=1024, flags=0):
        """Unencrypted datagram receive; refused once SSL is active.

        Overriding socket.recvfrom() keeps callers from reading raw
        ciphertext off an SSL-wrapped socket."""
        self._checkClosed()
        if self._sslobj:
            raise ValueError("recvfrom not allowed on instances of %s" %
                             self.__class__)
        return socket.recvfrom(self, buflen, flags)

    def recv_from(self, addr, buflen=1024, flags=0):
        """Misspelled alias of recvfrom(), kept for backward compatibility.
        ADDR is ignored (socket.recvfrom() takes no address)."""
        return self.recvfrom(buflen, flags)

    def pending(self):
        """Return the SSL object's pending() count, or 0 without SSL."""
        self._checkClosed()
        if self._sslobj:
            return self._sslobj.pending()
        else:
            return 0

    def shutdown(self, how):
        """Drop the SSL object and shut down the underlying socket."""
        self._checkClosed()
        self._sslobj = None
        socket.shutdown(self, how)

    def _real_close(self):
        # Drop the SSL object, close any base socket, then close our own fd.
        self._sslobj = None
        if self._base:
            self._base.close()
        socket._real_close(self)

    def do_handshake(self):
        """Perform a TLS/SSL handshake."""

        try:
            self._sslobj.do_handshake()
        except BaseException:
            # A failed handshake leaves the connection unusable for SSL.
            self._sslobj = None
            raise

    def connect(self, addr):
        """Connects to remote ADDR, and then wraps the connection in
        an SSL channel."""

        # Here we assume that the socket is client-side, and not
        # connected at the time of the call.  We connect it, then wrap it.
        if self._sslobj:
            raise ValueError("attempt to connect already-connected SSLSocket!")
        socket.connect(self, addr)
        self._sslobj = _ssl.sslwrap(self, False, self.keyfile, self.certfile,
                                    self.cert_reqs, self.ssl_version,
                                    self.ca_certs)
        if self.do_handshake_on_connect:
            self.do_handshake()

    def accept(self):
        """Accepts a new connection from a remote client, and returns
        a tuple containing that new connection wrapped with a server-side
        SSL channel, and the address of the remote client."""

        newsock, addr = socket.accept(self)
        return (SSLSocket(sock=newsock,
                          keyfile=self.keyfile, certfile=self.certfile,
                          server_side=True,
                          cert_reqs=self.cert_reqs,
                          ssl_version=self.ssl_version,
                          ca_certs=self.ca_certs,
                          do_handshake_on_connect=
                              self.do_handshake_on_connect,
                          # Inherit the listener's ragged-EOF policy too.
                          suppress_ragged_eofs=self.suppress_ragged_eofs),
                addr)
340 341 342 343


def wrap_socket(sock, keyfile=None, certfile=None,
                server_side=False, cert_reqs=CERT_NONE,
                ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                do_handshake_on_connect=True,
                suppress_ragged_eofs=True):
    """Wrap the existing socket SOCK in an SSLSocket and return it.

    Every argument is passed straight through to the SSLSocket
    constructor.  SUPPRESS_RAGGED_EOFS (default True, as in SSLSocket)
    controls whether an unexpected SSL EOF is reported as an empty read
    instead of an SSLError.
    """
    return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile,
                     server_side=server_side, cert_reqs=cert_reqs,
                     ssl_version=ssl_version, ca_certs=ca_certs,
                     do_handshake_on_connect=do_handshake_on_connect,
                     suppress_ragged_eofs=suppress_ragged_eofs)
351

352 353 354
# some utility functions

def cert_time_to_seconds(cert_time):
    """Takes a date-time string in standard ASN1_print form
    ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return
    a Python time value in seconds past the epoch.

    The string must end with the literal "GMT" and is interpreted as
    UTC regardless of the local timezone.  The result is a float.
    Raises ValueError if CERT_TIME does not match that format.
    """

    import calendar
    import time
    parsed = time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT")
    # The format pins the timezone to GMT, so PARSED is UTC.  time.mktime()
    # would treat it as *local* time and skew the result by the host's UTC
    # offset; calendar.timegm() is the UTC inverse of time.gmtime().
    return float(calendar.timegm(parsed))

362 363 364 365 366 367 368
# Armor lines that delimit a single PEM-encoded certificate.
PEM_HEADER = "-----BEGIN CERTIFICATE-----"
PEM_FOOTER = "-----END CERTIFICATE-----"

def DER_cert_to_PEM_cert(der_cert_bytes):
    """Takes a certificate in binary DER format and returns the
    PEM version of it as a string.

    The base64 body is wrapped at 64 characters per line, framed by
    PEM_HEADER and PEM_FOOTER, and terminated by a newline."""

    f = str(base64.standard_b64encode(der_cert_bytes), 'ASCII', 'strict')
    return (PEM_HEADER + '\n' +
            textwrap.fill(f, 64) + '\n' +
            PEM_FOOTER + '\n')

def PEM_cert_to_DER_cert(pem_cert_string):
    """Takes a certificate in ASCII PEM format and returns the
    DER-encoded version of it as a byte sequence.

    Surrounding whitespace is ignored.  Raises ValueError if the
    stripped text does not start with PEM_HEADER and end with
    PEM_FOOTER."""

    # Strip once so the header and footer checks treat surrounding
    # whitespace the same way (previously only the footer check stripped).
    stripped = pem_cert_string.strip()
    if not stripped.startswith(PEM_HEADER):
        raise ValueError("Invalid PEM encoding; must start with %s"
                         % PEM_HEADER)
    if not stripped.endswith(PEM_FOOTER):
        raise ValueError("Invalid PEM encoding; must end with %s"
                         % PEM_FOOTER)
    d = stripped[len(PEM_HEADER):-len(PEM_FOOTER)]
    # standard_b64decode decodes via binascii.a2b_base64, as the removed
    # base64.decodestring did, so embedded newlines are still skipped.
    return base64.standard_b64decode(d.encode('ASCII', 'strict'))
386

387
def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
    """Retrieve the certificate from the server at the specified address,
    and return it as a PEM-encoded string.
    If 'ca_certs' is specified, validate the server cert against it.
    If 'ssl_version' is specified, use it in the connection attempt.

    ADDR must be a (host, port) pair.  The connection is always closed
    before returning, including when connecting or fetching fails."""

    # Unpacking fails fast unless ADDR is a 2-item (host, port) pair,
    # before any socket is created.
    _host, _port = addr
    if ca_certs is not None:
        cert_reqs = CERT_REQUIRED
    else:
        cert_reqs = CERT_NONE
    s = wrap_socket(socket(), ssl_version=ssl_version,
                    cert_reqs=cert_reqs, ca_certs=ca_certs)
    try:
        s.connect(addr)
        dercert = s.getpeercert(True)
    finally:
        # Previously a failed connect/getpeercert leaked the socket.
        s.close()
    return DER_cert_to_PEM_cert(dercert)

405
def get_protocol_name(protocol_code):
    """Map a PROTOCOL_* constant to its short name.

    Returns "TLSv1", "SSLv23", "SSLv2" or "SSLv3" for the matching
    constant, and "<unknown>" for any other value."""
    candidates = (
        (PROTOCOL_TLSv1, "TLSv1"),
        (PROTOCOL_SSLv23, "SSLv23"),
        (PROTOCOL_SSLv2, "SSLv2"),
        (PROTOCOL_SSLv3, "SSLv3"),
    )
    for code, label in candidates:
        if protocol_code == code:
            return label
    return "<unknown>"