x509.py 8.95 KB
Newer Older
1 2
# -*- coding: utf-8 -*-
import calendar, hashlib, hmac, logging, os, struct, subprocess, threading, time
3 4 5
from OpenSSL import crypto
from . import utils

6
def newHmacSecret():
    """Return a fresh HMAC secret derived from the current time.

    The current time, expressed as an integer number of microseconds, is
    handed to utils.newHmacSecret(), which builds the actual secret.
    """
    now_us = int(time.time() * 1000000)
    return utils.newHmacSecret(now_us)
8

9 10 11 12 13 14
def networkFromCa(ca):
    """Return the network prefix encoded in the CA serial number.

    The result is the binary representation of the serial number, as a
    string of '0'/'1' characters, without its most significant digit.
    """
    serial = ca.get_serial_number()
    # Drop the leading digit of the binary form; the remaining bits
    # are the prefix.
    return format(serial, 'b')[1:]

def subnetFromCert(cert):
    """Return the Common Name (CN) of the subject of `cert`."""
    subject = cert.get_subject()
    return subject.CN

15 16 17
def notBefore(cert):
    """Return the start of validity of `cert` as seconds since the epoch.

    The 'YYYYMMDDhhmmssZ' value returned by cert.get_notBefore() is
    interpreted as UTC.
    """
    stamp = cert.get_notBefore()
    parsed = time.strptime(stamp, '%Y%m%d%H%M%SZ')
    return calendar.timegm(parsed)

18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
def notAfter(cert):
    """Return the end of validity of `cert` as seconds since the epoch.

    The 'YYYYMMDDhhmmssZ' value returned by cert.get_notAfter() is
    interpreted as UTC.
    """
    stamp = cert.get_notAfter()
    parsed = time.strptime(stamp, '%Y%m%d%H%M%SZ')
    return calendar.timegm(parsed)

def openssl(*args):
    """Start `openssl` with the given arguments and return the process.

    stdin, stdout and stderr of the child are all connected to pipes.
    """
    command = ('openssl',) + args
    pipe = subprocess.PIPE
    return utils.Popen(command, stdin=pipe, stdout=pipe, stderr=pipe)

def encrypt(cert, data):
    """Encrypt `data` with `openssl rsautl` using the certificate `cert`.

    `cert` is the certificate content (not a path). It is handed to
    openssl through a pipe, referenced as /proc/self/fd/<n>, while `data`
    is sent on the standard input of the process.

    Return what openssl wrote on its standard output. Raise
    subprocess.CalledProcessError, carrying the standard error output,
    if openssl exits with a non-zero status.
    """
    read_fd, write_fd = os.pipe()
    try:
        # Feed the certificate from another thread so that this one is
        # free to start openssl and exchange data with it.
        feeder = threading.Thread(target=os.write, args=(write_fd, cert))
        feeder.start()
        key_source = '/proc/self/fd/%u' % read_fd
        proc = openssl('rsautl', '-encrypt', '-certin', '-inkey', key_source)
        out, err = proc.communicate(data)
    finally:
        os.close(read_fd)
        os.close(write_fd)
    status = proc.returncode
    if status:
        raise subprocess.CalledProcessError(status, 'openssl', err)
    return out

41 42 43
def fingerprint(cert, alg='sha1'):
    """Return a hashlib object fed with the ASN.1 (DER) dump of `cert`.

    `alg` is any algorithm name accepted by hashlib.new(). The caller
    picks the output form, e.g. with .digest().
    """
    der = crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)
    return hashlib.new(alg, der)

44
def maybe_renew(path, cert, info, renew, force=False):
    """Renew the certificate stored at `path` when it is close to expiry.

    path  -- file containing the PEM certificate; replaced on renewal
    cert  -- certificate object currently in use
    info  -- label identifying the certificate in log messages
    renew -- callable returning the new certificate in PEM format
             (a false value means that nothing new is available)
    force -- if true, the first renewal attempt is made without looking
             at the expiry date

    Return a tuple (cert, next_time): the certificate to use from now on
    and the time at which this function should be called again. Renewal
    is due RENEW_PERIOD seconds before the certificate expires. If a
    renewal attempt fails or yields nothing new, an error is logged and
    the next attempt is scheduled one day later.
    """
    # Imported here rather than at module level.
    # NOTE(review): presumably to avoid a circular import - confirm.
    from .registry import RENEW_PERIOD
    while True:
        if force:
            # Only the first iteration is forced.
            force = False
        else:
            next_renew = notAfter(cert) - RENEW_PERIOD
            if time.time() < next_renew:
                return cert, next_renew
        try:
            pem = renew()
            if not pem or pem == crypto.dump_certificate(
                  crypto.FILETYPE_PEM, cert):
                # Nothing new: there is no exception to report.
                exc_info = 0
                break
            cert = crypto.load_certificate(crypto.FILETYPE_PEM, pem)
        except Exception:
            # Log the traceback along with the error message below.
            exc_info = 1
            break
        # Write next to the target then rename over it, so that `path`
        # never holds a partially written certificate.
        new_path = path + '.new'
        with open(new_path, 'w') as f:
            f.write(pem)
        try:
            # Best effort: keep the owner and group of the replaced file.
            s = os.stat(path)
            os.chown(new_path, s.st_uid, s.st_gid)
        except OSError:
            pass
        os.rename(new_path, path)
        logging.info("%s renewed until %s UTC",
            info, time.asctime(time.gmtime(notAfter(cert))))
        # Loop to compute the next renewal time of the new certificate.
    logging.error("%s not renewed. Will retry tomorrow.",
                  info, exc_info=exc_info)
    return cert, time.time() + 86400

78

79 80 81
class VerifyError(Exception):
    """Raised when a certificate or a signed network version is rejected.

    In this module it is raised with 3 arguments: error code, depth and
    message. Code and depth are None when openssl did not report them.
    """

82 83 84 85
class NewSessionError(Exception):
    """Raised when a proposed session secret is not greater than the
    current one.

    The arguments are the current secret and the rejected one.
    """


86 87 88 89 90 91 92 93 94 95 96 97
class Cert(object):
    """CA certificate, private key and (optionally) certificate of a node.

    Attributes:
      ca_path, key_path, cert_path -- paths given to the constructor
      ca   -- CA certificate object
      key  -- private key object
      cert -- node certificate object, only set if a path was given
    """

    def __init__(self, ca, key, cert=None):
        """Load the PEM files at paths `ca`, `key` and, if given, `cert`.

        The node certificate is checked with loadVerify(), which raises
        VerifyError if it is not valid.
        """
        self.ca_path = ca
        self.cert_path = cert
        self.key_path = key
        with open(ca) as f:
            self.ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
        with open(key) as f:
            self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read())
        if cert:
            with open(cert) as f:
                self.cert = self.loadVerify(f.read())

    @property
    def prefix(self):
        """Subnet of the node certificate, converted by utils.binFromSubnet."""
        return utils.binFromSubnet(subnetFromCert(self.cert))

    @property
    def network(self):
        """Network prefix derived from the CA serial number."""
        return networkFromCa(self.ca)

    @property
    def subject_serial(self):
        """serialNumber field of the certificate subject, as an integer."""
        return int(self.cert.get_subject().serialNumber)

    @property
    def openvpn_args(self):
        """Command-line options giving the CA, cert and key paths to OpenVPN."""
        return ('--ca', self.ca_path,
                '--cert', self.cert_path,
                '--key', self.key_path)

    def maybeRenew(self, registry, crl):
        """Renew the node and CA certificates if needed.

        The node certificate renewal is forced if its serial number is
        in `crl`. Return the earliest time at which a new check is due.
        """
        self.cert, next_renew = maybe_renew(self.cert_path, self.cert,
              "Certificate", lambda: registry.renewCertificate(self.prefix),
              self.cert.get_serial_number() in crl)
        self.ca, ca_renew = maybe_renew(self.ca_path, self.ca,
              "CA Certificate", registry.getCa)
        return min(next_renew, ca_renew)

    def loadVerify(self, cert, strict=False, type=crypto.FILETYPE_PEM):
        """Load a certificate and verify it against the CA with openssl.

        cert   -- certificate data, in the format given by `type`
        strict -- if true, verify at the current time; otherwise verify
                  at the most recent notBefore date of the CA and of the
                  certificate (but never in the future)
        type   -- crypto.FILETYPE_* constant describing `cert`

        Return the loaded certificate object. Raise
        VerifyError(code, depth, message) if it can not be loaded or
        verified. Code and depth are None unless reported by openssl.
        """
        try:
            r = crypto.load_certificate(type, cert)
        except crypto.Error:
            raise VerifyError(None, None, 'unable to load certificate')
        if type != crypto.FILETYPE_PEM:
            # The certificate is passed to openssl in PEM format.
            cert = crypto.dump_certificate(crypto.FILETYPE_PEM, r)
        args = ['verify', '-CAfile', self.ca_path]
        if not strict:
            args += '-attime', str(min(int(time.time()),
                max(notBefore(self.ca), notBefore(r))))
        p = openssl(*args)
        out, err = p.communicate(cert)
        if err is None: # utils.Popen failed with ENOMEM
            raise VerifyError(None, None,
                "error running openssl, assuming cert is invalid")
        # BBB: Old OpenSSL could return 0 in case of errors, so the output
        #      is always parsed, whatever the exit status is.
        # BBB: With old versions of openssl, detailed
        #      error is printed to standard output.
        for output in err, out:
            for line in output.splitlines():
                if line.startswith('error '):
                    # Expected form: 'error <code> at <depth> depth ...: <msg>'
                    head, msg = line.split(':', 1)
                    _, code, _, depth, _ = head.split(None, 4)
                    raise VerifyError(int(code), int(depth), msg.strip())
        if p.returncode:
            # Fail closed: openssl reported a failure in a form that was
            # not recognized above, so the certificate is not trusted.
            raise VerifyError(None, None,
                'openssl verify failed (exit status %s)' % p.returncode)
        return r

    def verify(self, sign, data):
        """Check that `sign` is the SHA-512 signature of `data` by the CA.

        Return nothing. Errors raised by crypto.verify are propagated.
        """
        crypto.verify(self.ca, sign, data, 'sha512')

    def sign(self, data):
        """Return the SHA-512 signature of `data` made with the private key."""
        return crypto.sign(self.key, data, 'sha512')

    def decrypt(self, data):
        """Decrypt `data` with `openssl rsautl` and the private key file.

        Return the standard output of openssl. Raise
        subprocess.CalledProcessError, carrying the standard error output,
        if openssl exits with a non-zero status.
        """
        p = openssl('rsautl', '-decrypt', '-inkey', self.key_path)
        out, err = p.communicate(data)
        if p.returncode:
            raise subprocess.CalledProcessError(p.returncode, 'openssl', err)
        return out

    def verifyVersion(self, version):
        """Check the signature embedded in a network version string.

        The first n bytes of `version` are the signed data and the rest
        is the signature, checked with verify(). n is 1 plus the value of
        the first byte shifted right by 5 bits.

        Raise VerifyError if `version` is empty or wrongly signed.
        """
        try:
            n = 1 + (ord(version[0]) >> 5)
            self.verify(version[n:], version[:n])
        except (IndexError, crypto.Error):
            raise VerifyError(None, None, 'invalid network version')

173 174 175 176 177 178 179

class Peer(object):
    """State of the authenticated UDP exchange with one remote node.

    UDP:    A ─────────────────────────────────────────────> B

    hello0:    0, A
               1, fingerprint(B), A
    hello:     2, X = encrypt(B, secret), sign(A, X)
    !hello:    #, type, value, hmac(secret, payload)
               └── payload ──┘

    new secret > old secret
    (timestamp + random bits)

    Reject messages with # smaller than or equal to previously processed.

    Yes, we do UDP on purpose. The only drawbacks are:
    - The limited size of packets, but they are big enough for a network
      using 4096-bits RSA keys.
    - hello0 packets (0 & 1) are subject to DoS, because verifying a
      certificate uses much CPU. A solution would be to use TCP until the
      secret is exchanged and continue with UDP.

    The fingerprint is only used to quickly know if peer's certificate has
    changed. It must be short enough to not exceed packet size when using
    certificates with 4096-bit keys. A weak algorithm is ok as long as there
    is no accidental collision. So SHA-1 looks fine.
    """
    # _hello: time before which hello0() does not produce any message.
    # _last:  None if a valid message was decoded (or a session accepted)
    #         since the last call to sent(); otherwise the time recorded
    #         by sent(), or 0 if nothing was sent yet.
    _hello = _last = 0
    # Initial secret shared by all instances, created when the class body
    # is executed. Secrets accepted by newSession() must be greater.
    _key = newHmacSecret()
    serial = None
    stop_date = float('inf')
    version = ''

    def __init__(self, prefix):
        """Create the state for the peer identified by `prefix`."""
        self.prefix = prefix

    @property
    def connected(self):
        """True if the peer answered since the last message recorded by
        sent(), or if that message was sent less than 60 seconds ago."""
        return self._last is None or time.time() < self._last + 60

    # Same property as in Cert: it reads the 'cert' attribute of the peer.
    subject_serial = Cert.subject_serial

    def __ne__(self, other):
        """Unsupported: only '<' and '>' comparisons are allowed."""
        raise AssertionError
    __eq__ = __ge__ = __le__ = __ne__

    # Peers are ordered by prefix and can be compared directly to a prefix
    # given as a string.
    def __gt__(self, other):
        return self.prefix > (other if type(other) is str else other.prefix)
    def __lt__(self, other):
        return self.prefix < (other if type(other) is str else other.prefix)

    def hello0(self, cert):
        """Return the hello0 message to send, or None if it is too early.

        cert -- our own certificate object, appended in ASN.1 (DER) format

        The message starts with type 1 followed by the fingerprint of the
        peer certificate if it is known, else with type 0.
        """
        if self._hello < time.time():
            try:
                msg = '\0\0\0\1' + fingerprint(self.cert).digest()
            except AttributeError:
                # The certificate of the peer is not known yet.
                msg = '\0\0\0\0'
            return msg + crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)

    def hello0Sent(self):
        """Record that hello0 was sent: no new one for the next 60 seconds."""
        self._hello = time.time() + 60

    def hello(self, cert):
        """Start a new session and return the hello message to send.

        cert -- object whose sign() method signs the encrypted secret
                (a Cert instance has such a method)

        A new secret is generated, encrypted with the certificate of the
        peer, and the sequence numbers are reset.
        """
        key = self._key = newHmacSecret()
        h = encrypt(crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert),
                    key)
        self._i = self._j = 2
        self._last = 0
        return '\0\0\0\2' + h + cert.sign(h)

    def _hmac(self, msg):
        """Return the HMAC-SHA1 digest of `msg` keyed with the secret."""
        return hmac.HMAC(self._key, msg, hashlib.sha1).digest()

    def newSession(self, key):
        """Accept the secret `key` proposed by the peer.

        Raise NewSessionError(current, key) if `key` is not greater than
        the current secret. Otherwise, sequence numbers are reset.
        """
        if key <= self._key:
            raise NewSessionError(self._key, key)
        self._key = key
        self._i = self._j = 2
        self._last = None

    def verify(self, sign, data):
        """Check that `sign` is the SHA-512 signature of `data` by the peer.

        Return nothing. Errors raised by crypto.verify are propagated.
        """
        crypto.verify(self.cert, sign, data, 'sha512')

    # Sequence numbers are 32-bit unsigned integers in network byte order.
    # The struct is only needed to bind the default arguments below.
    seqno_struct = struct.Struct("!L")

    def decode(self, msg, _unpack=seqno_struct.unpack):
        """Decode a received message.

        For hello messages (sequence number <= 2), return the tuple
        (seqno, remaining bytes) without any check. For other messages,
        return the payload if the trailing HMAC is valid and the sequence
        number is greater than the last accepted one, else None.
        """
        seqno, = _unpack(msg[:4])
        if seqno <= 2:
            return seqno, msg[4:]
        i = -utils.HMAC_LEN
        # The packet comes from the network: compare in constant time so
        # that the expected HMAC is not leaked through timing.
        if (hmac.compare_digest(self._hmac(msg[:i]), msg[i:])
                and self._i < seqno):
            self._last = None
            self._i = seqno
            return msg[4:i]

    def encode(self, msg, _pack=seqno_struct.pack):
        """Return `msg` prefixed with the next sequence number and followed
        by the HMAC of both."""
        self._j += 1
        msg = _pack(self._j) + msg
        return msg + self._hmac(msg)

    del seqno_struct

    def sent(self):
        """Record the time of a sent message, unless an earlier one is
        still waiting for an answer."""
        if not self._last:
            self._last = time.time()