cache.py 10.9 KB
Newer Older
1
import json, logging, os, sqlite3, socket, subprocess, sys, time, zlib
2
from itertools import chain
3
from .registry import RegistryClient
4
from . import utils, version, x509
5

6
class Cache(object):
Guillaume Bury's avatar
Guillaume Bury committed
7

8 9
    def __init__(self, db_path, registry, cert, db_size=200):
        """Open the peer cache and load the network parameters.

        db_path -- SQLite file holding cached peers and the network config.
                   If it cannot be opened, it is renamed to db_path + '.bak'
                   and a new, empty one is created.
        registry -- handed to RegistryClient together with cert.
        cert -- our certificate object (its prefix, decrypt, verifyVersion
                and maybeRenew attributes are used).
        db_size -- size passed to _cacheMinimize when addPeer caches a
                   new peer.

        If the cached config cannot be verified, this retries (exponential
        backoff capped at 60 seconds) until updateConfig() returns a
        non-empty result. Exits the process if our protocol is older than
        the network's min_protocol.
        """
        self._prefix = cert.prefix
        self._db_size = db_size
        self._decrypt = cert.decrypt
        self._registry = RegistryClient(registry, cert)

        logging.info('Initialize cache ...')
        try:
            self._db = self._open(db_path)
        except sqlite3.DatabaseError:
            # Base class of OperationalError. It also covers a corrupted
            # file, for which sqlite3 raises a plain DatabaseError ("file is
            # not a database") that would otherwise abort startup.
            logging.exception("Start with empty cache")
            os.rename(db_path, db_path + '.bak')
            self._db = self._open(db_path)
        q = self._db.execute
        # Per-peer 'try' state lives in an attached in-memory database, so
        # it is never written to db_path and is rebuilt on every start.
        # Every cached peer except ourselves (prefix '') gets a row.
        q('ATTACH DATABASE ":memory:" AS volatile')
        q("""CREATE TABLE volatile.stat (
            peer TEXT PRIMARY KEY NOT NULL,
            try INTEGER NOT NULL DEFAULT 0)""")
        q("CREATE INDEX volatile.stat_try ON stat(try)")
        q("INSERT INTO volatile.stat (peer)"
          " SELECT prefix FROM peer WHERE prefix != ''")
        self._db.commit()
        self._loadConfig(self._selectConfig(q))
        try:
            cert.verifyVersion(self.version)
        except (AttributeError, x509.VerifyError):
            # Presumably no cached 'version' yet (AttributeError), or the
            # cached config does not verify: we need a fresh one.
            retry = 1
            while not self.updateConfig():
                time.sleep(retry)
                retry = min(60, retry * 2)
        else:
            if (# re6stnet upgraded after being unused for a long time.
                self.protocol < version.protocol
                # Always query the registry at startup in case we were down
                # when it tried to send us new parameters.
                or self._prefix == self.registry_prefix):
                self.updateConfig()
        self.next_renew = cert.maybeRenew(self._registry, self.crl)
        if version.protocol < self.min_protocol:
            logging.critical("Your version of re6stnet is too old."
                             " Please update.")
            sys.exit(1)
        self.warnProtocol()
        logging.info("Cache initialized.")

    def _open(self, path):
        """Open the SQLite database at *path*, creating its tables if needed.

        Tables: peer(prefix, address) and config(name, value). TEXT values
        are returned as str (text_factory), and durability is traded for
        speed (synchronous OFF, journal kept in memory).

        NOTE(review): isolation_level=None puts the connection in autocommit
        mode, so ``with db:`` blocks elsewhere do not open a transaction by
        themselves. Confirm that this is intended.
        """
        conn = sqlite3.connect(path, isolation_level=None)
        conn.text_factory = str
        for pragma in ("PRAGMA synchronous = OFF",
                       "PRAGMA journal_mode = MEMORY"):
            conn.execute(pragma)
        utils.sqliteCreateTable(conn, "peer",
                                "prefix TEXT PRIMARY KEY NOT NULL",
                                "address TEXT NOT NULL")
        utils.sqliteCreateTable(conn, "config",
                                "name TEXT PRIMARY KEY NOT NULL",
                                "value")
        return conn

    @staticmethod
    def _selectConfig(execute): # BBB: blob
        return ((k, str(v) if type(v) is buffer else v)
            for k, v in execute("SELECT * FROM config"))

71 72 73
    def _loadConfig(self, config):
        cls = self.__class__
        logging.debug("Loading network parameters:")
74
        self.crl = self.same_country = ()
75
        for k, v in config:
76 77 78 79 80 81 82 83
            if k == 'crl': # BBB
                k = 'crl:json'
            if k.endswith(':json'):
                k = k[:-5]
                v = json.loads(v)
                if k == 'crl':
                    v = set(v)
            if hasattr(cls, k):
84 85
                continue
            setattr(self, k, v)
86 87 88 89 90 91
            logging.debug("- %s: %r", k, v)

    def updateConfig(self):
        logging.info("Getting new network parameters from registry...")
        try:
            # TODO: When possible, the registry should be queried via the re6st.
92
            x = json.loads(zlib.decompress(
93
                self._registry.getNetworkConfig(self._prefix)))
94 95 96 97
            base64 = x.pop('', ())
            config = {}
            for k, v in x.iteritems():
                k = str(k)
Killian Lufau's avatar
Killian Lufau committed
98 99 100 101
                if k.startswith('babel_hmac'):
                    if v:
                        v = self._decrypt(v.decode('base64'))
                elif k in base64:
102 103 104 105 106 107 108
                    v = v.decode('base64')
                elif type(v) is unicode:
                    v = str(v)
                elif isinstance(v, (list, dict)):
                    k += ':json'
                    v = json.dumps(v)
                config[k] = v
109 110 111 112 113 114 115 116 117 118 119 120 121
        except socket.error, e:
            logging.warning(e)
            return
        except Exception:
            # Even if the response is authenticated, a mistake on the registry
            # should not kill the whole network in a few seconds.
            logging.exception("buggy registry ?")
            return
        # XXX: check version ?
        self.delay_restart = config.pop("delay_restart", 0)
        old = {}
        with self._db as db:
            remove = []
122
            for k, v in self._selectConfig(db.execute):
123 124 125 126
                if k in config:
                    old[k] = v
                    continue
                try:
127
                    delattr(self, k[:-5] if k.endswith(':json') else k)
128 129 130 131 132
                except AttributeError:
                    pass
                remove.append(k)
            db.execute("DELETE FROM config WHERE name in ('%s')"
                       % "','".join(remove))
133 134
            # BBB: Use buffer because of http://bugs.python.org/issue13676
            #      on Python 2.6
135
            db.executemany("INSERT OR REPLACE INTO config VALUES(?,?)",
Killian Lufau's avatar
Killian Lufau committed
136 137
                           ((k, buffer(v) if k in base64 or
                             k.startswith('babel_hmac') else v)
138
                            for k, v in config.iteritems()))
139
        self._loadConfig(config.iteritems())
140
        return [k[:-5] if k.endswith(':json') else k
Killian Lufau's avatar
Killian Lufau committed
141 142 143
                for k in chain(remove, (k
                    for k, v in config.iteritems()
                    if k not in old or old[k] != v))]
144 145 146 147 148 149

    def warnProtocol(self):
        """Warn when the network advertises a newer protocol than ours."""
        newer_available = self.protocol > version.protocol
        if newer_available:
            logging.warning("There's a new version of re6stnet:"
                            " you should update.")

    def getDh(self, path):
151 152 153 154
        # We'd like to do a full check here but
        #   from OpenSSL import SSL
        #   SSL.Context(SSL.TLSv1_METHOD).load_tmp_dh(path)
        # segfaults if file is corrupted.
155 156 157 158 159
        if not os.path.exists(path):
            retry = 1
            while True:
                try:
                    dh = self._registry.getDh(self._prefix)
160 161 162 163 164 165 166 167 168 169
                    if dh:
                        break
                    e = None
                except socket.error:
                    e = sys.exc_info()
                logging.warning(
                    "Failed to get DH parameters from the registry."
                    " Will retry in %s seconds", retry, exc_info=e)
                time.sleep(retry)
                retry = min(60, retry * 2)
170 171 172
            with open(path, "wb") as f:
                f.write(dh)

173 174 175 176 177 178 179 180 181
    def log(self):
        if logging.getLogger().isEnabledFor(5):
            logging.trace("Cache:")
            for prefix, address, _try in self._db.execute(
                    "SELECT peer.*, try FROM peer, volatile.stat"
                    " WHERE prefix=peer ORDER BY prefix"):
                logging.trace("- %s: %s%s", prefix, address,
                              ' (blacklisted)' if _try else '')

182 183 184 185 186 187 188 189 190 191 192 193 194
    def cacheMinimize(self, size):
        with self._db:
            self._cacheMinimize(size)

    def _cacheMinimize(self, size):
        a = self._db.execute(
            "SELECT peer FROM volatile.stat ORDER BY try, RANDOM() LIMIT ?,-1",
            (size,)).fetchall()
        if a:
            q = self._db.executemany
            q("DELETE FROM peer WHERE prefix IN (?)", a)
            q("DELETE FROM volatile.stat WHERE peer IN (?)", a)

195 196 197 198 199 200 201 202 203 204 205 206 207
    def connecting(self, prefix, connecting):
        self._db.execute("UPDATE volatile.stat SET try=? WHERE peer=?",
                         (connecting, prefix))

    def resetConnecting(self):
        self._db.execute("UPDATE volatile.stat SET try=0")

    def getAddress(self, prefix):
        r = self._db.execute("SELECT address FROM peer, volatile.stat"
                             " WHERE prefix=? AND prefix=peer AND try=0",
                             (prefix,)).fetchone()
        return r and r[0]

208 209
    @property
    def my_address(self):
210
        for x, in self._db.execute("SELECT address FROM peer WHERE prefix=''"):
211 212 213
            return x

    @my_address.setter
214 215 216 217 218 219 220
    def my_address(self, value):
        if value:
            with self._db as db:
                db.execute("INSERT OR REPLACE INTO peer VALUES ('', ?)",
                           (value,))
        else:
            del self.my_address
221 222 223 224

    @my_address.deleter
    def my_address(self):
        with self._db as db:
225
            db.execute("DELETE FROM peer WHERE prefix=''")
226

227 228 229 230 231 232 233 234 235 236
    # Exclude our own address from results in case it is there, which may
    # happen if a node change its certificate without clearing the cache.
    # IOW, one should probably always put our own address there.
    _get_peer_sql = "SELECT %s FROM peer, volatile.stat" \
                    " WHERE prefix=peer AND prefix!=? AND try=?"
    def getPeerList(self, failed=0, __sql=_get_peer_sql % "prefix, address"
                                                        + " ORDER BY RANDOM()"):
        return self._db.execute(__sql, (self._prefix, failed))
    def getPeerCount(self, failed=0, __sql=_get_peer_sql % "COUNT(*)"):
        return self._db.execute(__sql, (self._prefix, failed)).next()[0]
237 238

    def getBootstrapPeer(self):
239
        logging.info('Getting Boot peer...')
240
        try:
241
            bootpeer = self._registry.getBootstrapPeer(self._prefix)
242
            prefix, address = self._decrypt(bootpeer).split()
243
        except (socket.error, subprocess.CalledProcessError, ValueError), e:
244 245
            logging.warning('Failed to bootstrap (%s)',
                            e if bootpeer else 'no peer returned')
246
        else:
Julien Muchembled's avatar
Julien Muchembled committed
247 248 249
            if prefix != self._prefix:
                self.addPeer(prefix, address)
                return prefix, address
250 251
            logging.warning('Buggy registry sent us our own address')

252
    def addPeer(self, prefix, address, set_preferred=False):
253 254 255 256 257
        logging.debug('Adding peer %s: %s', prefix, address)
        with self._db:
            q = self._db.execute
            try:
                (a,), = q("SELECT address FROM peer WHERE prefix=?", (prefix,))
258 259 260 261 262 263 264 265 266 267 268
                if set_preferred:
                    preferred = address.split(';')
                    address = a
                else:
                    preferred = a.split(';')
                def key(a):
                    try:
                        return preferred.index(a)
                    except ValueError:
                        return len(preferred)
                address = ';'.join(sorted(address.split(';'), key=key))
269
            except ValueError:
270 271
                self._cacheMinimize(self._db_size)
                a = None
272
            if a != address:
273 274
                q("INSERT OR REPLACE INTO peer VALUES (?,?)", (prefix, address))
            q("INSERT OR REPLACE INTO volatile.stat VALUES (?,0)", (prefix,))
275 276 277 278 279 280

    def getCountry(self, ip):
        try:
            return self._registry.getCountry(self._prefix, ip)
        except socket.error, e:
            logging.warning('Failed to get country (%s)', ip)