registry.py 29.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
"""
Authenticated communication:

  handshake (hello):
    C->S: CN
    S->C: X = Encrypt(CN)(secret), Sign(CA)(X)

  call:
    C->S: CN, ..., HMAC(secret+1)(path_info?query_string)
    S->C: result, HMAC(secret+2)(result)

  secret+1 = SHA1(secret) to protect from replay attacks

  HMAC in custom header, base64-encoded

  To prevent anyone from breaking an existing session,
  keep 2 secrets for each client:
  - the last one that was really used by the client (!hello)
  - the one of the last handshake (hello)
"""
21
import base64, hmac, hashlib, httplib, inspect, json, logging
22
import mailbox, os, platform, random, select, smtplib, socket, sqlite3
23
import string, sys, threading, time, weakref, zlib
24
from collections import defaultdict, deque
25
from datetime import datetime
26 27
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
from email.mime.text import MIMEText
28
from operator import itemgetter
29
from OpenSSL import crypto
30
from urllib import splittype, splithost, unquote, urlencode
31
from . import ctl, tunnel, utils, version, x509
32 33

# HTTP header carrying the base64-encoded HMAC of an authenticated RPC
# request or response.
HMAC_HEADER = "Re6stHMAC"
# A certificate expiring within this many seconds (30 days) is reissued,
# with a fresh validity, by RegistryServer.renewCertificate.
RENEW_PERIOD = 30 * 86400
# Tokens older than this many seconds (100 days), and certificates that
# expired longer ago than that, are deleted by RegistryServer.onTimeout.
GRACE_PERIOD = 100 * 86400
# Keys of the 'config' table under which the Babel HMAC keys are stored.
BABEL_HMAC = 'babel_hmac0', 'babel_hmac1', 'babel_hmac2'


def rpc(f):
    """Declare *f* (a RegistryServer method) as callable over HTTP.

    Attach ``f.getcallargs``: a function that takes the RPC arguments
    (without ``self``), positionally or by keyword, and returns them as a
    dict keyed by parameter name, with defaults filled in. It raises
    TypeError on missing or unexpected arguments. RegistryServer uses it
    to validate query parameters, and RegistryClient to build queries.
    Only methods carrying this attribute can be invoked remotely.
    """
    try:
        getargspec = inspect.getfullargspec
    except AttributeError:  # Python 2
        getargspec = inspect.getargspec
    spec = getargspec(f)
    # Parameters must map 1:1 to query-string fields: no *args/**kw.
    # (Index 2 is 'keywords' in ArgSpec and 'varkw' in FullArgSpec.)
    assert not (spec.varargs or spec[2]), f
    self_name = spec.args[0]

    def getcallargs(*args, **kw):
        # Bind like a real call, with a placeholder for 'self', which is
        # then dropped so that only the RPC arguments are returned.
        callargs = inspect.getcallargs(f, None, *args, **kw)
        del callargs[self_name]
        return callargs

    f.getcallargs = getcallargs
    return f


def rpc_private(f):
    """Declare *f* as an RPC reserved to trusted callers.

    RegistryServer.handle_request answers 403 for any method carrying the
    ``_private`` attribute unless the client address, and any
    X-Forwarded-For header, are listed in config.authorized_origin. The
    attribute's value is irrelevant; only its presence matters.
    """
    setattr(f, '_private', None)
    return rpc(f)


class HTTPError(Exception):
    """HTTP error status to report to the other side of an RPC.

    On the server, RegistryServer.handle_request passes its args to
    ``request.send_error`` (e.g. ``HTTPError(httplib.FORBIDDEN)``). On the
    client, RegistryClient raises it as ``HTTPError(status, reason)``.
    """


class RegistryServer(object):
    """Registry of a re6st network.

    Issues and revokes node certificates (stored in an SQLite database),
    manages registration tokens, and serves the methods decorated with
    @rpc / @rpc_private through handle_request.
    """

    # (expiration time, shuffled candidate bootstrap peers), refreshed by
    # getBootstrapPeer.
    peers = 0, ()
    # Validity of newly issued certificates, in seconds (365 days).
    cert_duration = 365 * 86400

    def _geoiplookup(self, ip):
        # Default used when GEOIP2_MMDB is not set: country lookups are
        # refused. __init__ replaces it per instance when a database exists.
        raise HTTPError(httplib.BAD_REQUEST)

    def __init__(self, config):
        """Open the registry database, load the CA and set up babeld control.

        Creates the 'config', 'token', 'cert' and 'crl' tables if needed in
        the SQLite file config.db, loads the CA certificate/key
        (config.ca/config.key), and creates a ctl.Babel for
        <config.run>/babeld.sock. If $GEOIP2_MMDB is set, enables country
        lookups from it; otherwise exits when config.same_country is set.
        Finally runs onTimeout, then refreshes the network configuration,
        or (no registry prefix yet) creates the first Babel HMAC key.
        """
        self.config = config
        self.lock = threading.Lock()
        # client prefix -> [(hmac secret, protocol), ...]; see hello and
        # handle_request.
        self.sessions = {}
        self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)

        # Database initializing
        db_dir = os.path.dirname(self.config.db)
        if db_dir:
            utils.makedirs(db_dir)
        self.db = sqlite3.connect(self.config.db, isolation_level=None,
                                                  check_same_thread=False)
        self.db.text_factory = str
        utils.sqliteCreateTable(self.db, "config",
                "name TEXT PRIMARY KEY NOT NULL",
                "value")
        self.prefix = self.getConfig("prefix", None)
        self.version = str(self.getConfig("version", "\0")) # BBB: blob
        utils.sqliteCreateTable(self.db, "token",
                "token TEXT PRIMARY KEY NOT NULL",
                "email TEXT NOT NULL",
                "prefix_len INTEGER NOT NULL",
                "date INTEGER NOT NULL")
        utils.sqliteCreateTable(self.db, "cert",
                "prefix TEXT PRIMARY KEY NOT NULL",
                "email TEXT",
                "cert TEXT")
        # The empty prefix is the initial free slot that newPrefix splits.
        self.db.execute("INSERT OR IGNORE INTO cert VALUES ('',null,null)")
        utils.sqliteCreateTable(self.db, "crl",
                "serial INTEGER PRIMARY KEY NOT NULL",
                # Expiration date of revoked certificate.
                # TODO: purge rows with dates in the past.
                "date INTEGER NOT NULL")

        self.cert = x509.Cert(self.config.ca, self.config.key)
        # Get vpn network prefix
        self.network = self.cert.network
        logging.info("Network: %s/%u", utils.ipFromBin(self.network),
                                       len(self.network))
        self.email = self.cert.ca.get_subject().emailAddress

        self.peers_lock = threading.Lock()
        self.ctl = ctl.Babel(os.path.join(config.run, 'babeld.sock'),
            weakref.proxy(self), self.network)

        db = os.getenv('GEOIP2_MMDB')
        if db:
            from geoip2 import database, errors
            country = database.Reader(db).country
            def geoiplookup(ip):
                try:
                    return country(ip).country.iso_code.encode()
                except errors.AddressNotFoundError:
                    return
            self._geoiplookup = geoiplookup
        elif self.config.same_country:
            sys.exit("Can not respect 'same_country' network configuration"
                     " (GEOIP2_MMDB not set)")

        self.onTimeout()
        if self.prefix:
            with self.db:
                self.updateNetworkConfig()
        else:
            self.newHMAC(0)

    def getConfig(self, name, *default):
        """Return the value stored under *name* in the 'config' table.

        If the entry is missing, return default[0]; without a default,
        the tuple unpacking below raises ValueError.
        """
        r, = next(self.db.execute(
            "SELECT value FROM config WHERE name=?", (name,)), default)
        return r

    def setConfig(self, *name_value):
        """Insert or replace the (name, value) entry of the 'config' table."""
        self.db.execute("INSERT OR REPLACE INTO config VALUES (?, ?)",
                        name_value)

    def updateNetworkConfig(self, _it0=itemgetter(0)):
        """Rebuild self.network_config (sent to nodes by getNetworkConfig).

        If its JSON serialization differs from the stored 'last_config',
        the signed network version is incremented and persisted, and code 0
        is sent for this registry's prefix (see sendto).
        """
        kw = {
            'babel_default': 'max-rtt-penalty 5000 rtt-max 500 rtt-decay 125',
            'crl': map(_it0, self.db.execute(
                "SELECT serial FROM crl ORDER BY serial")),
            'protocol': version.protocol,
            'registry_prefix': self.prefix,
        }
        if self.config.ipv4:
            kw['ipv4'], kw['ipv4_sublen'] = self.config.ipv4
        if self.config.same_country:
            kw['same_country'] = self.config.same_country
        for x in ('client_count', 'encrypt', 'hello',
                  'max_clients', 'min_protocol', 'tunnel_refresh'):
            kw[x] = getattr(self.config, x)
        config = json.dumps(kw, sort_keys=True)
        if config != self.getConfig('last_config', None):
            self.increaseVersion()
            # BBB: Use buffer because of http://bugs.python.org/issue13676
            #      on Python 2.6
            self.setConfig('version', buffer(self.version))
            self.setConfig('last_config', config)
            self.sendto(self.prefix, 0)
        # The following entry lists values that are base64-encoded.
        kw[''] = 'version',
        kw['version'] = self.version.encode('base64')
        self.network_config = kw

    def increaseVersion(self):
        """Increment the version number and append the CA signature of it."""
        x = utils.packInteger(1 + utils.unpackInteger(self.version)[0])
        self.version = x + self.cert.sign(x)

    def sendto(self, prefix, code):
        """Send '<prefix>NUL<code byte>' to ('::1', tunnel.PORT)."""
        self.sock.sendto("%s\0%c" % (prefix, code), ('::1', tunnel.PORT))

    def recv(self, code):
        """Read one datagram from self.sock.

        Return (prefix, payload) if the datagram is '<prefix>NUL<byte>
        <payload>' where prefix is a binary string and byte equals *code*;
        otherwise (None, None).
        """
        try:
            prefix, msg = self.sock.recv(1<<16).split('\0', 1)
            int(prefix, 2)  # validate: must be a non-empty binary string
        except ValueError:
            pass
        else:
            if msg and ord(msg[0]) == code:
                return prefix, msg[1:]
        return None, None

    def select(self, r, w, t):
        """Event-loop hook: schedule onTimeout at self.timeout, if set.

        *r* and *w* are unused; *t* is presumably the caller's list of
        (deadline, callback) timers.
        """
        if self.timeout:
            t.append((self.timeout, self.onTimeout))

    def request_dump(self):
        """Ask babeld for a dump and wait until babel_dump() is called.

        Must be called with peers_lock held. Each wait round schedules
        abort() 5 seconds ahead; on ctl.BabelException the control
        connection is reset and the request retried once. If both attempts
        fail, return silently.
        """
        assert self.peers_lock.locked()
        def abort():
            raise ctl.BabelException
        self._wait_dump = True
        for _ in 0, 1:
            self.ctl.request_dump()
            try:
                while self._wait_dump:
                    args = {}, {}, ((time.time() + 5, abort),)
                    self.ctl.select(*args)
                    utils.select(*args)
                break
            except ctl.BabelException:
                self.ctl.reset()

    def babel_dump(self):
        # Callback, presumably invoked by self.ctl (which holds a proxy to
        # self) when the dump requested by request_dump is complete.
        self._wait_dump = False

    def iterCert(self):
        """Yield (X509 object, prefix, email) for each stored certificate.

        Rows whose cert column does not parse as PEM (e.g. 'reserved'
        prefixes, see newPrefix) are skipped.
        """
        for prefix, email, cert in self.db.execute(
                "SELECT * FROM cert WHERE cert IS NOT NULL"):
            try:
                yield (crypto.load_certificate(crypto.FILETYPE_PEM, cert),
                       prefix, email)
            except crypto.Error:
                pass

    def onTimeout(self):
        """Purge old tokens and certificates, and compute the next check.

        Deletes tokens created more than GRACE_PERIOD ago, and clears the
        cert and email of certificates that expired more than GRACE_PERIOD
        ago. self.timeout is then set to GRACE_PERIOD after the oldest
        remaining token date or certificate expiration, or None if there
        is none. Exits rather than delete this registry's own certificate.
        """
        # XXX: Because we use threads to process requests, the statements
        #      'self.timeout = 1' below have no effect as long as the
        #      'select' call does not return. Ideally, we should interrupt it.
        logging.info("Checking if there's any old entry in the database ...")
        not_after = None
        old = time.time() - GRACE_PERIOD
        q = self.db.execute
        with self.lock, self.db:
            q("BEGIN")
            for token, x in q("SELECT token, date FROM token"):
                if x <= old:
                    q("DELETE FROM token WHERE token=?", (token,))
                elif not_after is None or x < not_after:
                    not_after = x
            for cert, prefix, email in self.iterCert():
                x = x509.notAfter(cert)
                if x <= old:
                    if prefix == self.prefix:
                        logging.critical("Refuse to delete certificate"
                                         " of main node: wrong clock ?")
                        sys.exit(1)
                    logging.info("Delete %s: %s (invalid since %s)",
                        "certificate requested by '%s'" % email
                        if email else "anonymous certificate",
                        ", ".join("%s=%s" % c for c in
                                  cert.get_subject().get_components()),
                        datetime.utcfromtimestamp(x).isoformat())
                    q("UPDATE cert SET email=null, cert=null WHERE prefix=?",
                      (prefix,))
                elif not_after is None or x < not_after:
                    not_after = x
            # TODO: reduce 'cert' table by merging free slots
            #       (IOW, do the contrary of newPrefix)
            self.timeout = not_after and not_after + GRACE_PERIOD

    def handle_request(self, request, method, kw):
        """Run RPC *method* with arguments *kw* and write the HTTP response.

        Private methods get 403 unless the client address, and any
        X-Forwarded-For header, are in config.authorized_origin. Calls with
        a 'cn' argument must carry a valid HMAC header (see the module
        docstring); their response is then HMAC-signed as well. HTTPError
        is sent as is; any other failure is logged and answered with 500.
        """
        m = getattr(self, method)
        if hasattr(m, '_private'):
            authorized_origin = self.config.authorized_origin
            x_forwarded_for = request.headers.get('X-Forwarded-For')
            if request.client_address[0] not in authorized_origin or \
               x_forwarded_for and x_forwarded_for not in authorized_origin:
                return request.send_error(httplib.FORBIDDEN)
        key = m.getcallargs(**kw).get('cn')
        if key:
            h = base64.b64decode(request.headers[HMAC_HEADER])
            # Constant-time comparison where available (Python >= 2.7.7),
            # so response timing does not reveal how many leading bytes of
            # a forged HMAC are correct.
            digest_equal = getattr(hmac, 'compare_digest',
                                   lambda a, b: a == b)
            with self.lock:
                session = self.sessions[key]
                # Accept any of the secrets kept for this client (see hello).
                for key, protocol in session:
                    if digest_equal(h, hmac.HMAC(key, request.path,
                                                 hashlib.sha1).digest()):
                        break
                else:
                    raise Exception("Wrong HMAC")
                # SHA1(secret) signs the response, and SHA1(SHA1(secret))
                # becomes the only secret accepted for the next call.
                key = hashlib.sha1(key).digest()
                session[:] = (hashlib.sha1(key).digest(), protocol),
        else:
            logging.info("%s%s: %s, %s",
                method,
                '(' + utils.ipFromBin(x509.networkFromCa(self.cert.ca)
                                      + kw["client_prefix"])
                + ')' if method == 'hello' else '',
                request.headers.get("X-Forwarded-For") or
                request.headers.get("host"),
                request.headers.get("user-agent"))
        try:
            result = m(**kw)
        except HTTPError as e:
            return request.send_error(*e.args)
        except:
            # Top-level boundary of the request: log and answer 500.
            logging.warning(request.requestline, exc_info=1)
            return request.send_error(httplib.INTERNAL_SERVER_ERROR)
        if result:
            request.send_response(httplib.OK)
            request.send_header("Content-Length", str(len(result)))
        else:
            request.send_response(httplib.NO_CONTENT)
        if key:
            request.send_header(HMAC_HEADER, base64.b64encode(
                hmac.HMAC(key, result, hashlib.sha1).digest()))
        request.end_headers()
        if result:
            request.wfile.write(result)

    def getPeerProtocol(self, cn):
        """Return the protocol version *cn* announced in its hello.

        Expects the single session entry left by handle_request after an
        authenticated call.
        """
        session, = self.sessions[cn]
        return session[1]

    @rpc
    def hello(self, client_prefix, protocol='1'):
        """Start an authenticated session with *client_prefix*.

        Return a new HMAC secret encrypted with the client's certificate,
        followed by the CA signature of that ciphertext (both halves have
        the same length). The session keeps at most 2 secrets: its first
        existing one and this new one.
        """
        with self.lock:
            cert = self.getCert(client_prefix)
            key = utils.newHmacSecret()
            self.sessions.setdefault(client_prefix, [])[1:] = \
                (key, int(protocol)),
        key = x509.encrypt(cert, key)
        sign = self.cert.sign(key)
        assert len(key) == len(sign)
        return key + sign

    def getCert(self, client_prefix):
        """Return the PEM certificate of *client_prefix* (self.lock held).

        Raises StopIteration if there is none.
        """
        assert self.lock.locked()
        return self.db.execute("SELECT cert FROM cert"
                               " WHERE prefix=? AND cert IS NOT NULL",
                               (client_prefix,)).next()[0]

    @rpc_private
    def isToken(self, token):
        """Return "1" if *token* exists, else None."""
        with self.lock:
            if self.db.execute("SELECT 1 FROM token WHERE token = ?",
                               (token,)).fetchone():
                return "1"

    @rpc_private
    def deleteToken(self, token):
        """Delete *token*, if it exists."""
        with self.lock:
            self.db.execute("DELETE FROM token WHERE token = ?", (token,))

    @rpc_private
    def addToken(self, email, token):
        """Register a token for *email*; generate one if *token* is None.

        Return the generated token, or None when *token* was given. Raise
        HTTPError FORBIDDEN if config.prefix_length is not set, and
        CONFLICT if the given *token* already exists.
        """
        prefix_len = self.config.prefix_length
        if not prefix_len:
            raise HTTPError(httplib.FORBIDDEN)
        request = token is None
        # A token is redeemed for a certificate (requestCertificate), so it
        # must not be predictable: use the OS CSPRNG, not Mersenne Twister.
        rng = random.SystemRandom()
        with self.lock:
            while True:
                # Generating token
                if request:
                    token = ''.join(rng.sample(string.ascii_lowercase, 8))
                args = token, email, prefix_len, int(time.time())
                # Updating database
                try:
                    self.db.execute("INSERT INTO token VALUES (?,?,?,?)", args)
                    break
                except sqlite3.IntegrityError:
                    # Generated token collided: draw another one.
                    if not request:
                        raise HTTPError(httplib.CONFLICT)
            self.timeout = 1
        if request:
            return token

    @rpc
    def requestToken(self, email):
        """Create a token and send it to *email*.

        config.mailhost is either an mbox file (absolute path or existing
        file), to which the message is appended, or an SMTP server.
        Raise HTTPError FORBIDDEN if config.mailhost is not set.
        """
        if not self.config.mailhost:
            raise HTTPError(httplib.FORBIDDEN)

        token = self.addToken(email, None)

        # Creating and sending email
        msg = MIMEText('Hello, your token to join re6st network is: %s\n'
                       % token)
        msg['Subject'] = '[re6stnet] Token Request'
        if self.email:
            msg['From'] = self.email
        msg['To'] = email
        if os.path.isabs(self.config.mailhost) or \
           os.path.isfile(self.config.mailhost):
            with self.lock:
                m = mailbox.mbox(self.config.mailhost)
                try:
                    m.add(msg)
                finally:
                    m.close()
        else:
            s = smtplib.SMTP(self.config.mailhost)
            try:
                if self.config.smtp_starttls:
                    s.starttls()
                if self.config.smtp_user:
                    s.login(self.config.smtp_user, self.config.smtp_pwd)
                s.sendmail(self.email, email, msg.as_string())
                s.quit()
            finally:
                # Release the socket even if the SMTP dialog failed
                # (close() is a no-op after a successful quit()).
                s.close()

    def newPrefix(self, prefix_len):
        """Allocate a free prefix of *prefix_len* bits in the 'cert' table.

        Takes the longest free slot not longer than *prefix_len* and splits
        it: at each step the '1' half stays free and the '0' half is split
        further. The all-zero prefix of maximal length is marked 'reserved'
        and another one is allocated instead. Raises StopIteration if no
        slot is free.
        """
        max_len = 128 - len(self.network)
        assert 0 < prefix_len <= max_len
        try:
            prefix, = self.db.execute(
                "SELECT prefix FROM cert WHERE length(prefix) <= ?"
                " AND cert is null ORDER BY length(prefix) DESC",
                (prefix_len,)).next()
        except StopIteration:
            logging.error('No more free /%u prefix available', prefix_len)
            raise
        while len(prefix) < prefix_len:
            self.db.execute("UPDATE cert SET prefix = ? WHERE prefix = ?",
                            (prefix + '1', prefix))
            prefix += '0'
            self.db.execute("INSERT INTO cert VALUES (?,null,null)",
                            (prefix,))
        if len(prefix) < max_len or '1' in prefix:
            return prefix
        self.db.execute("UPDATE cert SET cert = 'reserved' WHERE prefix = ?",
                        (prefix,))
        return self.newPrefix(prefix_len)

    @rpc
    def requestCertificate(self, token, req):
        """Issue a certificate for the PEM certificate request *req*.

        With a *token*, the token is consumed and its email and prefix
        length are used (None is returned for an unknown token); without,
        an anonymous certificate of config.anonymous_prefix_length is
        issued. HTTPError FORBIDDEN if the relevant length is not set.
        The first certificate ever issued defines the registry's prefix.
        """
        req = crypto.load_certificate_request(crypto.FILETYPE_PEM, req)
        with self.lock:
            with self.db:
                if token:
                    if not self.config.prefix_length:
                        raise HTTPError(httplib.FORBIDDEN)
                    try:
                        token, email, prefix_len, _ = self.db.execute(
                            "SELECT * FROM token WHERE token = ?",
                            (token,)).next()
                    except StopIteration:
                        return
                    self.db.execute("DELETE FROM token WHERE token = ?",
                                    (token,))
                else:
                    prefix_len = self.config.anonymous_prefix_length
                    if not prefix_len:
                        raise HTTPError(httplib.FORBIDDEN)
                    email = None
                prefix = self.newPrefix(prefix_len)
                self.db.execute("UPDATE cert SET email = ? WHERE prefix = ?",
                                (email, prefix))
                if self.prefix is None:
                    self.prefix = prefix
                    self.setConfig('prefix', prefix)
                    self.updateNetworkConfig()
                subject = req.get_subject()
                subject.serialNumber = str(self.getSubjectSerial())
                return self.createCertificate(prefix, subject,
                                              req.get_pubkey())

    def getSubjectSerial(self):
        """Return the smallest subject serial number not used yet."""
        # Smallest unique number, for IPv4 support.
        serials = []
        for x in self.iterCert():
            serial = x[0].get_subject().serialNumber
            if serial:
                serials.append(int(serial))
        serials.sort()
        for serial, x in enumerate(serials):
            if serial != x:
                return serial
        return len(serials)

    def createCertificate(self, client_prefix, subject, pubkey, not_after=None):
        """Sign and store a certificate for *client_prefix*; return its PEM.

        The subject CN is set to '<int>/<len>' of the prefix. Validity
        starts now and lasts cert_duration, unless *not_after* (in the
        format returned by X509.get_notAfter) is given. The serial number
        comes from the 'serial' counter of the 'config' table.
        """
        cert = crypto.X509()
        cert.gmtime_adj_notBefore(0)
        if not_after:
            cert.set_notAfter(not_after)
        else:
            cert.gmtime_adj_notAfter(self.cert_duration)
        cert.set_issuer(self.cert.ca.get_subject())
        subject.CN = "%u/%u" % (int(client_prefix, 2), len(client_prefix))
        cert.set_subject(subject)
        cert.set_pubkey(pubkey)
        # Certificate serial, for revocation support. Contrary to
        # subject serial, it does not need to be as small as possible.
        serial = 1 + self.getConfig('serial', 0)
        self.setConfig('serial', serial)
        cert.set_serial_number(serial)
        cert.sign(self.cert.key, 'sha512')
        cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
        self.db.execute("UPDATE cert SET cert = ? WHERE prefix = ?",
                        (cert, client_prefix))
        self.timeout = 1
        return cert

    @rpc
    def renewCertificate(self, cn):
        """Return the certificate of *cn*, reissued if needed.

        A new certificate with a full validity is issued if the current one
        expires within RENEW_PERIOD; one with the same expiration date if
        its serial is in the CRL. Otherwise the current PEM is returned.
        """
        with self.lock:
            with self.db as db:
                pem = self.getCert(cn)
                cert = crypto.load_certificate(crypto.FILETYPE_PEM, pem)
                if x509.notAfter(cert) - RENEW_PERIOD < time.time():
                    not_after = None
                elif db.execute("SELECT count(*) FROM crl WHERE serial=?",
                                (cert.get_serial_number(),)).fetchone()[0]:
                    not_after = cert.get_notAfter()
                else:
                    return pem
                return self.createCertificate(cn,
                    cert.get_subject(), cert.get_pubkey(), not_after)

    @rpc
    def getCa(self):
        """Return the CA certificate in PEM format."""
        return crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert.ca)

    @rpc
    def getDh(self, cn):
        """Return the content of the file config.dh (presumably the
        Diffie-Hellman parameters)."""
        with open(self.config.dh) as f:
            return f.read()

    @rpc
    def getNetworkConfig(self, cn):
        """Return network_config as zlib-compressed JSON for *cn*.

        The Babel HMAC keys that are set are added, encrypted with the
        certificate of *cn* and base64-encoded: the first one (in
        BABEL_HMAC order) as 'babel_hmac_sign', the second one as
        'babel_hmac_accept'. An empty key is sent unchanged.
        """
        with self.lock:
            cert = self.getCert(cn)
            config = self.network_config.copy()
            keys = [self.getConfig(k, None) for k in BABEL_HMAC]
            for i, v in enumerate(v for v in keys if v is not None):
                config[('babel_hmac_sign', 'babel_hmac_accept')[i]] = \
                    v and x509.encrypt(cert, v).encode('base64')
        return zlib.compress(json.dumps(config))

    def _queryAddress(self, peer):
        """Ask for the address of *peer* (code 1) and return the answer.

        Waits up to 3 seconds in total; returns None on timeout.
        """
        self.sendto(peer, 1)
        s = self.sock,
        timeout = 3
        end = timeout + time.time()
        # Loop because there may be answers from previous requests.
        while select.select(s, (), (), timeout)[0]:
            prefix, msg = self.recv(1)
            if prefix == peer:
                return msg
            timeout = max(0, end - time.time())
        logging.info("Timeout while querying address for %s/%s",
                     int(peer, 2), len(peer))

    @rpc
    def getCountry(self, cn, address):
        """Return the country code of *address*, or None if unknown.

        HTTPError BAD_REQUEST if no GeoIP database is configured.
        """
        return self._geoiplookup(address)

    @rpc
    def getBootstrapPeer(self, cn):
        """Return, encrypted for *cn*, '<prefix> <address info>' of a peer.

        Candidates are the prefixes routed via babeld neighbours plus this
        registry, shuffled and cached for 60 seconds; each call consumes
        one. Returns None if the address query times out. Country fields
        are stripped for nodes whose protocol is older than 7.
        """
        with self.peers_lock:
            age, peers = self.peers
            if age < time.time() or not peers:
                self.request_dump()
                peers = [prefix
                    for neigh_routes in self.ctl.neighbours.itervalues()
                    for prefix in neigh_routes[1]
                    if prefix]
                peers.append(self.prefix)
                random.shuffle(peers)
                self.peers = time.time() + 60, peers
            peer = peers.pop()
            if peer == cn:
                # Very unlikely (e.g. peer restarted with empty cache),
                # so don't bother looping over above code
                # (in case 'peers' is empty).
                peer = self.prefix
        with self.lock:
            msg = self._queryAddress(peer)
            if msg is None:
                return
            # Remove country for old nodes
            if self.getPeerProtocol(cn) < 7:
                msg = ';'.join(','.join(a.split(',')[:3])
                               for a in msg.split(';'))
            cert = self.getCert(cn)
        msg = "%s %s" % (peer, msg)
        logging.info("Sending bootstrap peer: %s", msg)
        return x509.encrypt(cert, msg)

    @rpc_private
    def revoke(self, cn_or_serial):
        """Revoke a certificate given its serial number or CN ('int/len').

        By CN, the certificate is also removed from the 'cert' table and
        the node's session dropped. Unless the certificate has already
        expired, its serial is added to the CRL and the network
        configuration updated.
        """
        with self.lock, self.db:
            q = self.db.execute
            try:
                serial = int(cn_or_serial)
            except ValueError:
                prefix = utils.binFromSubnet(cn_or_serial)
                cert = self.getCert(prefix)
                q("UPDATE cert SET email=null, cert=null WHERE prefix=?",
                  (prefix,))
                cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
                serial = cert.get_serial_number()
                self.sessions.pop(prefix, None)
            else:
                cert, = (cert for cert, prefix, email in self.iterCert()
                              if cert.get_serial_number() == serial)
            not_after = x509.notAfter(cert)
            if time.time() < not_after:
                q("INSERT INTO crl VALUES (?,?)", (serial, not_after))
                self.updateNetworkConfig()

    def newHMAC(self, i, key=None):
        """Store *key* (default: 16 random bytes) as Babel HMAC key *i*."""
        if key is None:
            key = buffer(os.urandom(16))
        self.setConfig(BABEL_HMAC[i], key)

    def delHMAC(self, i):
        """Delete Babel HMAC key *i*."""
        self.db.execute("DELETE FROM config WHERE name=?", (BABEL_HMAC[i],))

    @rpc_private
    def updateHMAC(self):
        """Advance the Babel HMAC key rotation one step; bump the version.

        With keys 0, 1, 2 of BABEL_HMAC:
        - none set: 1 = new random key, 2 = '' (initialization);
        - 0 set, 1 unset: 1 = new random key;
        - 0 and 1 set: 2 = old 0, then 0 is deleted;
        - 0 unset, 1 set: 0 = old 1, then 1 and 2 are deleted.
        """
        with self.lock, self.db:
            keys = [self.getConfig(BABEL_HMAC[i], None) for i in (0, 1, 2)]
            if keys[0]:
                if keys[1]:
                    self.newHMAC(2, keys[0])
                    self.delHMAC(0)
                else:
                    self.newHMAC(1)
            elif keys[1]:
                self.newHMAC(0, keys[1])
                self.delHMAC(1)
                self.delHMAC(2)
            else:
                # Initialization of HMAC on the network
                self.newHMAC(1)
                self.newHMAC(2, '')
            self.increaseVersion()
            self.setConfig('version', buffer(self.version))
            self.network_config['version'] = self.version.encode('base64')
        self.sendto(self.prefix, 0)

    @rpc_private
    def getNodePrefix(self, email):
        """Return the subnet of the certificate registered with *email*
        (computed by x509.subnetFromCert), or None if there is none."""
        with self.lock, self.db:
            try:
                cert, = self.db.execute("SELECT cert FROM cert WHERE email = ?",
                                        (email,)).next()
            except StopIteration:
                return
        certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
        return x509.subnetFromCert(certificate)

    @rpc_private
    def getIPv6Address(self, email):
        """Return the IPv6 address of the node registered with *email*."""
        cn = self.getNodePrefix(email)
        if cn:
            return utils.ipFromBin(
                x509.networkFromCa(self.cert.ca)
                + utils.binFromSubnet(cn))

    @rpc_private
    def getIPv4Information(self, email):
        """Return the first field of the address info of the node
        registered with *email* (presumably its IPv4 address).

        Returns None if the node is unknown, not routed by any babeld
        neighbour, or does not answer the address query.
        """
        peer = self.getNodePrefix(email)
        if peer:
            peer = utils.binFromSubnet(peer)
            with self.peers_lock:
                self.request_dump()
                # any() is required here: a 'break' in a nested loop only
                # exits the inner loop, so a for/else on the outer loop
                # would always run its 'else' and give up.
                if not any(prefix == peer
                           for neigh_routes in self.ctl.neighbours.itervalues()
                           for prefix in neigh_routes[1]):
                    return
            logging.info("%s %s", email, peer)
            with self.lock:
                msg = self._queryAddress(peer)
            if msg:
                return msg.split(',')[0]

    @rpc_private
    def versions(self):
        """Return JSON mapping node prefixes to their code-4 answer.

        Nodes are those routed via babeld neighbours plus this registry;
        a node that did not answer maps to null. Collection stops once all
        queries are sent and nothing is received for 3 seconds.
        """
        with self.peers_lock:
            self.request_dump()
            peers = {prefix
                for neigh_routes in self.ctl.neighbours.itervalues()
                for prefix in neigh_routes[1]
                if prefix}
        peers.add(self.prefix)
        peer_dict = {}
        s = self.sock,
        with self.lock:
            while True:
                r, w, _ = select.select(s, s if peers else (), (), 3)
                if r:
                    prefix, ver = self.recv(4)
                    if prefix:
                        peer_dict[prefix] = ver
                if w:
                    prefix = peers.pop()
                    peer_dict[prefix] = None
                    self.sendto(prefix, 4)
                elif not r:
                    break
        return json.dumps(peer_dict)

    @rpc_private
    def topology(self):
        """Crawl the network from this registry; return the graph as JSON.

        Each node (written '<int>/<len>') is queried with code 5. Its
        answer is 'n p1 p2 ...': graph[node] gets the first n prefixes,
        and node is added to graph[p] for the remaining ones; newly seen
        prefixes are queried in turn. graph[''] lists the nodes that
        answered. Malformed answers are ignored.
        """
        fmt = lambda prefix: '%s/%s' % (int(prefix, 2), len(prefix))
        peers = deque((fmt(self.prefix),))
        graph = defaultdict(set)
        s = self.sock,
        with self.lock:
            while True:
                r, w, _ = select.select(s, s if peers else (), (), 3)
                if r:
                    prefix, x = self.recv(5)
                    if prefix and x:
                        prefix = fmt(prefix)
                        x = x.split()
                        try:
                            n = int(x.pop(0))
                        except (IndexError, ValueError):
                            # empty (whitespace-only) or non-numeric answer
                            continue
                        if n <= len(x) and prefix not in x:
                            graph[prefix].update(x[:n])
                            peers += set(x).difference(graph)
                            for other in x[n:]:
                                graph[other].add(prefix)
                            graph[''].add(prefix)
                if w:
                    self.sendto(utils.binFromSubnet(peers.popleft()), 5)
                elif not r:
                    break
        return json.dumps({k: list(v) for k, v in graph.iteritems()})


class RegistryClient(object):
    """HTTP(S) client for the RPC methods of a RegistryServer.

    Accessing an attribute named after a RegistryServer RPC (e.g.
    ``client.getCa()``) returns a function that sends the corresponding GET
    request and returns the response body, or None on failure. RPCs taking
    a 'cn' argument are authenticated with the HMAC scheme described in the
    module docstring; this requires *cert*, which must provide verify() and
    decrypt().
    """

    # Secret for the next authenticated call: None until a 'hello'
    # handshake is done, and reset after an unexpected response.
    _hmac = None
    user_agent = "re6stnet/%s, %s" % (version.version, platform.platform())

    def __init__(self, url, cert=None, auto_close=True):
        """Prepare a connection to the registry at *url* (http or https).

        If *auto_close* is true, the connection is closed after each
        successful call other than 'hello'.
        """
        self.cert = cert
        self.auto_close = auto_close
        scheme, host = splittype(url)
        host, path = splithost(host)
        self._conn = dict(http=httplib.HTTPConnection,
                          https=httplib.HTTPSConnection,
                          )[scheme](unquote(host), timeout=60)
        self._path = path.rstrip('/')

    def __getattr__(self, name):
        """Build, cache on the instance, and return the RPC stub for *name*.

        AttributeError if *name* is not an RPC method of RegistryServer.
        """
        getcallargs = getattr(RegistryServer, name).getcallargs

        def rpc(*args, **kw):
            kw = getcallargs(*args, **kw)
            query = '/' + name
            if kw:
                if any(type(v) is not str for v in kw.itervalues()):
                    raise TypeError("%s: RPC arguments must be str" % name)
                query += '?' + urlencode(kw)
            url = self._path + query
            client_prefix = kw.get('cn')
            retry = True
            try:
                while retry:
                    if client_prefix:
                        key = self._hmac
                        if not key:
                            # No secret yet: do the handshake. If the call
                            # still fails with a fresh secret, don't retry.
                            retry = False
                            h = self.hello(client_prefix,
                                           str(version.protocol))
                            n = len(h) // 2
                            self.cert.verify(h[n:], h[:n])
                            key = self.cert.decrypt(h[:n])
                        # The query is signed with the current secret, the
                        # response with SHA1(secret), and the next call
                        # will use SHA1(SHA1(secret)).
                        h = hmac.HMAC(key, query, hashlib.sha1).digest()
                        key = hashlib.sha1(key).digest()
                        self._hmac = hashlib.sha1(key).digest()
                    else:
                        retry = False
                    self._conn.putrequest('GET', url, skip_accept_encoding=1)
                    self._conn.putheader('User-Agent', self.user_agent)
                    if client_prefix:
                        self._conn.putheader(HMAC_HEADER, base64.b64encode(h))
                    self._conn.endheaders()
                    response = self._conn.getresponse()
                    body = response.read()
                    if response.status in (httplib.OK, httplib.NO_CONTENT):
                        if (not client_prefix or
                                hmac.HMAC(key, body, hashlib.sha1).digest() ==
                                base64.b64decode(response.msg[HMAC_HEADER])):
                            if self.auto_close and name != 'hello':
                                self._conn.close()
                            return body
                    elif response.status == httplib.FORBIDDEN:
                        # XXX: We should improve error handling, while making
                        #      sure re6st nodes don't crash on temporary errors.
                        #      This is currently good enough for re6st-conf, to
                        #      inform the user when registration is disabled.
                        raise HTTPError(response.status, response.reason)
                    # Unexpected answer: the secrets may be out of sync,
                    # so do a new handshake on the next attempt.
                    if client_prefix:
                        self._hmac = None
            except HTTPError:
                # Close the connection as the other failure paths do:
                # re-raising skips the close() at the end of this function.
                self._conn.close()
                raise
            except Exception:
                logging.info(url, exc_info=1)
            else:
                logging.info('%s\nUnexpected response %s %s',
                             url, response.status, response.reason)
            self._conn.close()

        setattr(self, name, rpc)
        return rpc