#
# Copyright (C) 2006-2019  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from binascii import a2b_hex
from collections import OrderedDict
from functools import wraps
import MySQLdb
from MySQLdb import DataError, IntegrityError, \
    OperationalError, ProgrammingError
from MySQLdb.constants.CR import SERVER_GONE_ERROR, SERVER_LOST
from MySQLdb.constants.ER import DATA_TOO_LONG, DUP_ENTRY, NO_SUCH_TABLE
# BBB: the following 2 constants were added to mysqlclient 1.3.8
DROP_LAST_PARTITION = 1508
SAME_NAME_PARTITION = 1517
from array import array
from hashlib import sha1
import os
import re
import string
import struct
import sys
import time

from . import LOG_QUERIES, DatabaseFailure
from .manager import DatabaseManager, splitOIDField
from neo.lib import logging, util
from neo.lib.interfaces import implements
from neo.lib.protocol import ZERO_OID, ZERO_TID, ZERO_HASH


class MysqlError(DatabaseFailure):

    def __init__(self, exc, query=None):
        self.exc = exc
        self.query = query

    code = property(lambda self: self.exc.args[0])

    def __str__(self):
        msg = 'MySQL error %s: %s' % self.exc.args
        return msg if self.query is None else '%s\nQuery: %s' % (
            msg, getPrintableQuery(self.query[:1000]))


def getPrintableQuery(query, max=70):
    return ''.join(c if c in string.printable and c not in '\t\x0b\x0c\r'
        else '\\x%02x' % ord(c) for c in query)
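# e.g. getPrintableQuery('ab\x00\tc') == 'ab\\x00\\x09c': NUL and TAB are
# hex-escaped while other printable characters pass through unchanged.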

def auto_reconnect(wrapped):
    def wrapper(self, *args):
        # Try 3 times at most. When it fails too often for the same
        # query then the disconnection is likely caused by this query.
        # We don't want to enter into an infinite loop.
        retry = 2
        while 1:
            try:
                return wrapped(self, *args)
            except OperationalError as m:
                # IDEA: Is it safe to retry in case of DISK_FULL ?
                # XXX:  However, this would be another case of failure that
                #       would be unnoticed by other nodes (ADMIN & MASTER).
                #       When there are replicas, it may be preferred to not
                #       retry.
                if (self._active
                    or SERVER_GONE_ERROR != m.args[0] != SERVER_LOST
                    or not retry):
                    if self.LOCK:
                        raise MysqlError(m, *args)
                    raise # caught higher up for secondary connections
                logging.info('the MySQL server is gone; reconnecting')
                assert not self._deferred
                self.close()
                retry -= 1
    return wraps(wrapped)(wrapper)


@implements
class MySQLDatabaseManager(DatabaseManager):
    """This class manages a database on MySQL."""

    VERSION = 3
    ENGINES = "InnoDB", "RocksDB", "TokuDB"
    _engine = ENGINES[0] # default engine

    _use_partition = False

    _max_allowed_packet = 32769 * 1024
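    # Minimum required value for the server's max_allowed_packet variable:
    # _tryConnect refuses to start if the server value is lower, then
    # replaces this attribute with the actual value, which
    # storeTransaction uses to split oversized statements.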

    def _parse(self, database):
        """ Get the database credentials (username, password, database) """
        # expected pattern : [user[:password]@]database[(~|.|/)unix_socket]
        self.user, self.passwd, self.db, self.socket = re.match(
            '(?:([^:]+)(?::(.*))?@)?([^~./]+)(.+)?$', database).groups()
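        # e.g. 'neo:secret@neo1~/var/run/mysqld/mysqld.sock' yields
        # ('neo', 'secret', 'neo1', '~/var/run/mysqld/mysqld.sock'),
        # and a bare 'neo1' yields (None, None, 'neo1', None).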

    def _close(self):
        try:
            conn = self.__dict__.pop('conn')
        except KeyError:
            return
        conn.close()

    def __getattr__(self, attr):
        if attr == 'conn':
            self._tryConnect()
        return super(MySQLDatabaseManager, self).__getattr__(attr)

    def _tryConnect(self):
        kwd = {'db' : self.db}
        if self.user:
            kwd['user'] = self.user
            if self.passwd is not None:
                kwd['passwd'] = self.passwd
        if self.socket:
            kwd['unix_socket'] = os.path.expanduser(self.socket)
        logging.info('connecting to MySQL on the database %s with user %s',
                     self.db, self.user)
        self._active = 0
        if self._wait < 0:
            timeout_at = None
        else:
            timeout_at = time.time() + self._wait
        last = None
        while True:
            try:
                self.conn = MySQLdb.connect(**kwd)
                break
            except Exception as e:
                if None is not timeout_at <= time.time():
                    raise
                e = str(e)
                if last == e:
                    log = logging.debug
                else:
                    last = e
                    log = logging.exception
                log('Connection to MySQL failed, retrying.')
                time.sleep(1)
        self._config = {}
        conn = self.conn
        conn.autocommit(False)
        conn.query("SET SESSION group_concat_max_len = %u" % (2**32-1))
        conn.set_sql_mode("TRADITIONAL,NO_ENGINE_SUBSTITUTION")
        def query(sql):
            conn.query(sql)
            r = conn.store_result()
            return r.fetch_row(r.num_rows())
        if self.LOCK:
            (locked,), = query("SELECT GET_LOCK('%s.%s', 0)"
                % (self.db, self.LOCK))
            if not locked:
                sys.exit(self.LOCKED)
        (name, value), = query(
            "SHOW VARIABLES WHERE variable_name='max_allowed_packet'")
        if int(value) < self._max_allowed_packet:
            raise DatabaseFailure("Global variable %r is too small."
                " Minimal value must be %uk."
                % (name, self._max_allowed_packet // 1024))
        self._max_allowed_packet = int(value)
        try:
            self._dedup = bool(query(
                "SHOW INDEX FROM data WHERE key_name='hash'"))
        except ProgrammingError as e:
            if e.args[0] != NO_SUCH_TABLE:
                raise
            self._dedup = None
        if not self.LOCK:
            # Prevent automatic reconnection for secondary connections.
            self._active = 1
            self._commit = self.conn.commit

    _connect = auto_reconnect(_tryConnect)

    def autoReconnect(self, f):
        assert self._active and not self.LOCK
        @auto_reconnect
        def try_once(self):
            if self._active:
                try:
                    f()
                finally:
                    self._active = 0
                return True
        while not try_once(self):
            # Avoid reconnecting too often.
            # Since this is used to wrap an arbitrary long process and
            # not just a single query, we can't limit the number of retries.
            time.sleep(5)
            self._connect()

    def _commit(self):
        # XXX: Should we translate OperationalError into MysqlError ?
        self.conn.commit()
        self._active = 0

    @auto_reconnect
    def query(self, query):
        """Query data from a database."""
        if LOG_QUERIES:
            logging.debug('querying %s...',
                getPrintableQuery(query.split('\n', 1)[0][:70]))
        conn = self.conn
        conn.query(query)
        if query.startswith("SELECT "):
            r = conn.store_result()
            return tuple([
                tuple([d.tostring() if isinstance(d, array) else d
                      for d in row])
                for row in r.fetch_row(r.num_rows())])
        r = query.split(None, 1)[0]
        if r in ("INSERT", "REPLACE", "DELETE", "UPDATE"):
            self._active = 1
        else:
            assert r in ("ALTER", "CREATE", "DROP"), query

    @property
    def escape(self):
        """Escape special characters in a string."""
        return self.conn.escape_string

    def _getDevPath(self):
        # BBB: MySQL is moving to Performance Schema.
        return self.query("SELECT * FROM information_schema.global_variables"
                          " WHERE variable_name='datadir'")[0][1]

    def erase(self):
        self.query("DROP TABLE IF EXISTS"
            " config, pt, trans, obj, data, bigdata, ttrans, tobj")

    def nonempty(self, table):
        try:
            return bool(self.query("SELECT 1 FROM %s LIMIT 1" % table))
        except ProgrammingError as e:
            if e.args[0] != NO_SUCH_TABLE:
                raise
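        # Implicit 'return None' when the table does not exist, which
        # callers (e.g. _alterTable, _setup) use to tell "missing table"
        # apart from "empty table".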

    def _alterTable(self, schema_dict, table, select="*"):
        q = self.query
        new = 'new_' + table
        if self.nonempty(table) is None:
            if self.nonempty(new) is None:
                return
        else:
            q("DROP TABLE IF EXISTS " + new)
            q(schema_dict.pop(table) % new
              + " SELECT %s FROM %s" % (select, table))
            q("DROP TABLE " + table)
        q("ALTER TABLE %s RENAME TO %s" % (new, table))

    def _migrate1(self, _):
        self._checkNoUnfinishedTransactions()
        self.query("DROP TABLE IF EXISTS ttrans")

    def _migrate2(self, schema_dict):
        self._alterTable(schema_dict, 'obj')

    def _migrate3(self, schema_dict):
        self._alterTable(schema_dict, 'pt', "rid as `partition`, nid,"
            " CASE state"
            " WHEN 0 THEN -1"  # UP_TO_DATE
            " WHEN 2 THEN -2"  # FEEDING
            " ELSE 1-state"
            " END as tid")

    # Let's wait for a more important change to clean up,
    # so that users can still downgrade.
    if 0:
      def _migrate4(self, schema_dict):
        self._setConfiguration('partitions', None)

    def _setup(self, dedup=False):
        self._config.clear()
        q = self.query
        p = engine = self._engine
        schema_dict = OrderedDict()

        # The table "config" stores configuration
        # parameters which affect the persistent data.
        schema_dict['config'] = """CREATE TABLE %s (
                  name VARBINARY(255) NOT NULL PRIMARY KEY,
                  value VARBINARY(255) NULL
              ) ENGINE=""" + engine

        # The table "pt" stores a partition table.
        schema_dict['pt'] = """CREATE TABLE %s (
                 `partition` SMALLINT UNSIGNED NOT NULL,
                 nid INT NOT NULL,
                 tid BIGINT NOT NULL,
                 PRIMARY KEY (`partition`, nid)
             ) ENGINE=""" + engine

        if self._use_partition:
            p += """ PARTITION BY LIST (`partition`) (
                PARTITION dummy VALUES IN (NULL))"""

        if engine == "RocksDB":
            cf = lambda name, rev=False: " COMMENT '%scf_neo_%s'" % (
                'rev:' if rev else '', name)
        else:
            cf = lambda *_: ''

        # The table "trans" stores information on committed transactions.
        schema_dict['trans'] = """CREATE TABLE %s (
                 `partition` SMALLINT UNSIGNED NOT NULL,
                 tid BIGINT UNSIGNED NOT NULL,
                 packed BOOLEAN NOT NULL,
                 oids MEDIUMBLOB NOT NULL,
                 user BLOB NOT NULL,
                 description BLOB NOT NULL,
                 ext BLOB NOT NULL,
                 ttid BIGINT UNSIGNED NOT NULL,
                 PRIMARY KEY (`partition`, tid){}
             ) ENGINE={}""".format(cf('append_meta'), p)

        # The table "obj" stores committed object metadata.
        schema_dict['obj'] = """CREATE TABLE %s (
                 `partition` SMALLINT UNSIGNED NOT NULL,
                 oid BIGINT UNSIGNED NOT NULL,
                 tid BIGINT UNSIGNED NOT NULL,
                 data_id BIGINT UNSIGNED NULL,
                 value_tid BIGINT UNSIGNED NULL,
                 PRIMARY KEY (`partition`, oid, tid){},
                 KEY tid (`partition`, tid, oid){},
                 KEY (data_id){}
             ) ENGINE={}""".format(cf('obj_pk', True),
                 cf('append_meta'), cf('append_meta'), p)

        if engine == "TokuDB":
            engine += " compression='tokudb_uncompressed'"

        # The table "data" stores object data.
        # We'd like to have partial index on 'hash' column (e.g. hash(4))
        # but 'UNIQUE' constraint would not work as expected.
        schema_dict['data'] = """CREATE TABLE %s (
                 id BIGINT UNSIGNED NOT NULL,
                 hash BINARY(20) NOT NULL,
                 compression TINYINT UNSIGNED NULL,
                 value MEDIUMBLOB NOT NULL,
                 PRIMARY KEY (id){}{}
             ) ENGINE={}""".format(cf('append'), """,
                 UNIQUE (hash, compression)""" + cf('no_comp') if dedup else "",
                 engine)

        schema_dict['bigdata'] = """CREATE TABLE %s (
                 id INT UNSIGNED NOT NULL AUTO_INCREMENT,
                 value MEDIUMBLOB NOT NULL,
                 PRIMARY KEY (id){}
             ) ENGINE={}""".format(cf('append'), p)

        # The table "ttrans" stores information on uncommitted transactions.
        schema_dict['ttrans'] = """CREATE TABLE %s (
                 `partition` SMALLINT UNSIGNED NOT NULL,
                 tid BIGINT UNSIGNED,
                 packed BOOLEAN NOT NULL,
                 oids MEDIUMBLOB NOT NULL,
                 user BLOB NOT NULL,
                 description BLOB NOT NULL,
                 ext BLOB NOT NULL,
                 ttid BIGINT UNSIGNED NOT NULL,
                 PRIMARY KEY (ttid){}
             ) ENGINE={}""".format(cf('no_comp'), p)

        # The table "tobj" stores uncommitted object metadata.
        schema_dict['tobj'] = """CREATE TABLE %s (
                 `partition` SMALLINT UNSIGNED NOT NULL,
                 oid BIGINT UNSIGNED NOT NULL,
                 tid BIGINT UNSIGNED NOT NULL,
                 data_id BIGINT UNSIGNED NULL,
                 value_tid BIGINT UNSIGNED NULL,
                 PRIMARY KEY (tid, oid){}
             ) ENGINE={}""".format(cf('no_comp'), p)

        if self.nonempty('config') is None:
            q(schema_dict.pop('config') % 'config')
            self._setConfiguration('version', self.VERSION)
        else:
            self.migrate(schema_dict)

        for table, schema in schema_dict.iteritems():
            q(schema % ('IF NOT EXISTS ' + table))

        if self._dedup is None:
            self._dedup = dedup

        self._uncommitted_data.update(q("SELECT data_id, count(*)"
            " FROM tobj WHERE data_id IS NOT NULL GROUP BY data_id"))

    def getConfiguration(self, key):
        try:
            return self._config[key]
        except KeyError:
            sql_key = self.escape(str(key))
            try:
                r = self.query("SELECT value FROM config WHERE name = '%s'"
                               % sql_key)[0][0]
            except IndexError:
                r = None
            self._config[key] = r
            return r

    def _setConfiguration(self, key, value):
        q = self.query
        e = self.escape
        self._config[key] = value
        k = e(str(key))
        if value is None:
            q("DELETE FROM config WHERE name = '%s'" % k)
            return
        value = str(value)
        sql = "REPLACE INTO config VALUES ('%s', '%s')" % (k, e(value))
        try:
            q(sql)
        except DataError as e:
            if e.args[0] != DATA_TOO_LONG or len(value) < 256 or key != "zodb":
                raise
            q("ALTER TABLE config MODIFY value VARBINARY(%s) NULL" % len(value))
            q(sql)

    def _getMaxPartition(self):
        return self.query("SELECT MAX(`partition`) FROM pt")[0][0]

    def _getPartitionTable(self):
        return self.query("SELECT * FROM pt")

    def _getLastTID(self, partition, max_tid=None):
        x = " WHERE `partition`=%s" % partition
        if max_tid:
            x += " AND tid<=%s" % max_tid
        (tid,), = self.query(
            "SELECT MAX(tid) as t FROM trans FORCE INDEX (PRIMARY)" + x)
        return tid

    def _getLastIDs(self, partition):
        q = self.query
        x = " WHERE `partition`=%s" % partition
        (oid,), = q("SELECT MAX(oid) FROM obj FORCE INDEX (PRIMARY)" + x)
        (tid,), = q("SELECT MAX(tid) FROM obj FORCE INDEX (tid)" + x)
        return tid, oid

    def _getDataLastId(self, partition):
        return self.query("SELECT MAX(id) FROM data WHERE %s <= id AND id < %s"
            % (partition << 48, (partition + 1) << 48))[0][0]
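        # Data ids embed the partition number in their high bits
        # (partition << 48), so each partition allocates from its own
        # disjoint 48-bit id range (see also storeData and _data_last_ids).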

    def _getUnfinishedTIDDict(self):
        q = self.query
        return q("SELECT ttid, tid FROM ttrans"), (ttid
            for ttid, in q("SELECT DISTINCT tid FROM tobj"))

    def getFinalTID(self, ttid):
        ttid = util.u64(ttid)
        # MariaDB is smart enough to realize that 'ttid' is constant.
        r = self.query("SELECT tid FROM trans"
            " WHERE `partition`=%s AND tid>=ttid AND ttid=%s LIMIT 1"
            % (self._getReadablePartition(ttid), ttid))
        if r:
            return util.p64(r[0][0])

    def getLastObjectTID(self, oid):
        oid = util.u64(oid)
        r = self.query("SELECT tid FROM obj FORCE INDEX(PRIMARY)"
                       " WHERE `partition`=%d AND oid=%d"
                       " ORDER BY tid DESC LIMIT 1"
                       % (self._getReadablePartition(oid), oid))
        return util.p64(r[0][0]) if r else None

    def _getNextTID(self, *args): # partition, oid, tid
        r = self.query("SELECT tid FROM obj"
                       " FORCE INDEX(PRIMARY)"
                       " WHERE `partition`=%d AND oid=%d AND tid>%d"
                       " ORDER BY tid LIMIT 1" % args)
        return r[0][0] if r else None

    def _getObject(self, oid, tid=None, before_tid=None):
        q = self.query
        partition = self._getReadablePartition(oid)
        sql = ('SELECT tid, compression, data.hash, value, value_tid'
               ' FROM obj FORCE INDEX(PRIMARY)'
               ' LEFT JOIN data ON (obj.data_id = data.id)'
               ' WHERE `partition` = %d AND oid = %d') % (partition, oid)
        if before_tid is not None:
            sql += ' AND tid < %d ORDER BY tid DESC LIMIT 1' % before_tid
        elif tid is not None:
            sql += ' AND tid = %d' % tid
        else:
            # XXX I want to express "HAVING tid = MAX(tid)", but
            # MySQL does not use an index for a HAVING clause!
            sql += ' ORDER BY tid DESC LIMIT 1'
        r = q(sql)
        try:
            serial, compression, checksum, data, value_serial = r[0]
        except IndexError:
            return None
        if compression and compression & 0x80:
            compression &= 0x7f
            data = ''.join(self._bigData(data))
        return (serial, self._getNextTID(partition, oid, serial),
                compression, checksum, data, value_serial)

    def _changePartitionTable(self, cell_list, reset=False):
        offset_list = []
        q = self.query
        if reset:
            q("DELETE FROM pt")
        for offset, nid, tid in cell_list:
            # TODO: this logic should move out of database manager
            # add 'dropCells(cell_list)' to API and use one query
            if tid is None:
                q("DELETE FROM pt WHERE `partition` = %d AND nid = %d"
                  % (offset, nid))
            else:
                offset_list.append(offset)
                q("INSERT INTO pt VALUES (%d, %d, %d)"
                  " ON DUPLICATE KEY UPDATE tid = %d"
                  % (offset, nid, tid, tid))
        if self._use_partition:
            for offset in offset_list:
                add = """ALTER TABLE %%s ADD PARTITION (
                    PARTITION p%u VALUES IN (%u))""" % (offset, offset)
                for table in 'trans', 'obj':
                    try:
                        self.query(add % table)
                    except MysqlError as e:
                        if e.code != SAME_NAME_PARTITION:
                            raise

    def dropPartitions(self, offset_list):
        q = self.query
        # XXX: these queries are inefficient (execution time increases with
        # row count, although we use indexes) when there are rows to
        # delete. It should be done as an idle task, by chunks.
        for partition in offset_list:
            where = " WHERE `partition`=%d" % partition
            data_id_list = [x for x, in
                q("SELECT DISTINCT data_id FROM obj FORCE INDEX(tid)"
                  "%s AND data_id IS NOT NULL" % where)]
            if not self._use_partition:
                q("DELETE FROM obj" + where)
                q("DELETE FROM trans" + where)
            self._pruneData(data_id_list)
        if self._use_partition:
            drop = "ALTER TABLE %s DROP PARTITION" + \
                ','.join(' p%u' % i for i in offset_list)
            for table in 'trans', 'obj':
                try:
                    self.query(drop % table)
                except MysqlError as e:
                    if e.code != DROP_LAST_PARTITION:
                        raise

    def _getUnfinishedDataIdList(self):
        return [x for x, in self.query(
            "SELECT data_id FROM tobj WHERE data_id IS NOT NULL")]

    def dropPartitionsTemporary(self, offset_list=None):
        where = "" if offset_list is None else \
            " WHERE `partition` IN (%s)" % ','.join(map(str, offset_list))
        q = self.query
        q("DELETE FROM tobj" + where)
        q("DELETE FROM ttrans" + where)

    def storeTransaction(self, tid, object_list, transaction, temporary = True):
        e = self.escape
        u64 = util.u64
        tid = u64(tid)
        if temporary:
            obj_table = 'tobj'
            trans_table = 'ttrans'
        else:
            obj_table = 'obj'
            trans_table = 'trans'
        q = self.query
        sql = ["REPLACE INTO %s VALUES " % obj_table]
        values_max = self._max_allowed_packet - len(sql[0])
        values_size = 0
        for oid, data_id, value_serial in object_list:
            oid = u64(oid)
            partition = self._getPartition(oid)
            if value_serial:
                value_serial = u64(value_serial)
                (data_id,), = q("SELECT data_id FROM obj"
                    " WHERE `partition`=%d AND oid=%d AND tid=%d"
                    % (partition, oid, value_serial))
                if temporary:
                    self.holdData(data_id)
            else:
                value_serial = 'NULL'
            value = "(%s,%s,%s,%s,%s)," % (
                partition, oid, tid,
                'NULL' if data_id is None else data_id,
                value_serial)
            values_size += len(value)
            # actually: max_values < values_size + EXTRA - len(final comma)
            # (test_max_allowed_packet checks that EXTRA == 2)
            if values_max <= values_size:
                sql[-1] = sql[-1][:-1] # remove final comma
                q(''.join(sql))
                del sql[1:]
                values_size = len(value)
            sql.append(value)
        if values_size:
            sql[-1] = value[:-1] # remove final comma
            q(''.join(sql))
        if transaction:
            oid_list, user, desc, ext, packed, ttid = transaction
            partition = self._getPartition(tid)
            assert packed in (0, 1)
            q("REPLACE INTO %s VALUES (%s,%s,%s,'%s','%s','%s','%s',%s)" % (
                trans_table, partition, 'NULL' if temporary else tid, packed,
                e(''.join(oid_list)), e(user), e(desc), e(ext), u64(ttid)))

    _structLL = struct.Struct(">LL")
    _unpackLL = _structLL.unpack
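    # _structLL packs/unpacks the 8-byte pointer stored in data.value for
    # values over the 16 MiB MEDIUMBLOB limit: two big-endian uint32, the
    # id of the first 'bigdata' chunk and the total length; the payload
    # spans (length + 0x7fffff) >> 23 chunks of at most 8 MiB each.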

    def getOrphanList(self):
        return [x for x, in self.query(
            "SELECT id FROM data LEFT JOIN obj ON (id=data_id)"
            " WHERE data_id IS NULL")]

    def _pruneData(self, data_id_list):
        data_id_list = set(data_id_list).difference(self._uncommitted_data)
        if data_id_list:
            q = self.query
            id_list = []
            bigid_list = []
            for id, value in q("SELECT id, IF(compression < 128, NULL, value)"
                               " FROM data LEFT JOIN obj ON (id = data_id)"
                               " WHERE id IN (%s) AND data_id IS NULL"
                               % ",".join(map(str, data_id_list))):
                id_list.append(str(id))
                if value:
                    bigdata_id, length = self._unpackLL(value)
                    bigid_list += xrange(bigdata_id,
                                         bigdata_id + (length + 0x7fffff >> 23))
            if id_list:
                q("DELETE FROM data WHERE id IN (%s)" % ",".join(id_list))
                if bigid_list:
                    q("DELETE FROM bigdata WHERE id IN (%s)"
                      % ",".join(map(str, bigid_list)))
                return len(id_list)
        return 0

    def _bigData(self, value):
        bigdata_id, length = self._unpackLL(value)
        q = self.query
        return (q("SELECT value FROM bigdata WHERE id=%s" % i)[0][0]
            for i in xrange(bigdata_id,
                            bigdata_id + (length + 0x7fffff >> 23)))

    def storeData(self, checksum, oid, data, compression, _pack=_structLL.pack):
        e = self.escape
        checksum = e(checksum)
        if 0x1000000 <= len(data): # 16M (MEDIUMBLOB limit)
            compression |= 0x80
            q = self.query
            if self._dedup:
                for r, d in q("SELECT id, value FROM data"
                              " WHERE hash='%s' AND compression=%s"
                              % (checksum, compression)):
                    i = 0
                    for d in self._bigData(d):
                        j = i + len(d)
                        if data[i:j] != d:
                            raise IntegrityError(DUP_ENTRY)
                        i = j
                    if j != len(data):
                        raise IntegrityError(DUP_ENTRY)
                    return r
            i = 'NULL'
            length = len(data)
            for j in xrange(0, length, 0x800000): # 8M
                q("INSERT INTO bigdata VALUES (%s, '%s')"
                  % (i, e(data[j:j+0x800000])))
                if not j:
                    i = bigdata_id = self.conn.insert_id()
                i += 1
            data = _pack(bigdata_id, length)
        p = self._getPartition(util.u64(oid))
        r = self._data_last_ids[p]
        try:
            self.query("INSERT INTO data VALUES (%s, '%s', %d, '%s')" %
                       (r, checksum, compression, e(data)))
        except IntegrityError as e:
            if e.args[0] == DUP_ENTRY:
                (r, d), = self.query("SELECT id, value FROM data"
                                     " WHERE hash='%s' AND compression=%s"
                                     % (checksum, compression))
                if d == data:
                    return r
            raise
        self._data_last_ids[p] = r + 1
        return r

    def loadData(self, data_id):
        compression, hash, value = self.query(
            "SELECT compression, hash, value FROM data where id=%s"
            % data_id)[0]
        if compression and compression & 0x80:
            compression &= 0x7f
            value = ''.join(self._bigData(value))
        return compression, hash, value

    del _structLL

    def _getDataTID(self, oid, tid=None, before_tid=None):
        sql = ('SELECT tid, value_tid FROM obj FORCE INDEX(PRIMARY)'
               ' WHERE `partition` = %d AND oid = %d'
              ) % (self._getReadablePartition(oid), oid)
        if tid is not None:
            sql += ' AND tid = %d' % tid
        elif before_tid is not None:
            sql += ' AND tid < %d ORDER BY tid DESC LIMIT 1' % before_tid
        else:
            # XXX I want to express "HAVING tid = MAX(tid)", but
            # MySQL does not use an index for a HAVING clause!
            sql += ' ORDER BY tid DESC LIMIT 1'
        r = self.query(sql)
        return r[0] if r else (None, None)

    def lockTransaction(self, tid, ttid):
        u64 = util.u64
        self.query("UPDATE ttrans SET tid=%d WHERE ttid=%d LIMIT 1"
                   % (u64(tid), u64(ttid)))
        self.commit()

    def unlockTransaction(self, tid, ttid, trans, obj):
        q = self.query
        u64 = util.u64
        tid = u64(tid)
        if trans:
            q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid)
            q("DELETE FROM ttrans WHERE tid=%d" % tid)
            if not obj:
                return
        sql = " FROM tobj WHERE tid=%d" % u64(ttid)
        data_id_list = [x for x, in q("SELECT data_id%s AND data_id IS NOT NULL"
                                      % sql)]
        q("INSERT INTO obj SELECT `partition`, oid, %d, data_id, value_tid %s"
          % (tid, sql))
        q("DELETE" + sql)
        self.releaseData(data_id_list)

    def abortTransaction(self, ttid):
        ttid = util.u64(ttid)
        q = self.query
        q("DELETE FROM tobj WHERE tid=%s" % ttid)
        q("DELETE FROM ttrans WHERE ttid=%s" % ttid)

    def deleteTransaction(self, tid):
        tid = util.u64(tid)
        self.query("DELETE FROM trans WHERE `partition`=%s AND tid=%s" %
            (self._getPartition(tid), tid))

    def deleteObject(self, oid, serial=None):
        u64 = util.u64
        oid = u64(oid)
        sql = " FROM obj WHERE `partition`=%d AND oid=%d" \
            % (self._getPartition(oid), oid)
        if serial:
            sql += ' AND tid=%d' % u64(serial)
        q = self.query
        data_id_list = [x for x, in q(
            "SELECT DISTINCT data_id%s AND data_id IS NOT NULL" % sql)]
        q("DELETE" + sql)
        self._pruneData(data_id_list)

    def _deleteRange(self, partition, min_tid=None, max_tid=None):
        sql = " WHERE `partition`=%d" % partition
        if min_tid is not None:
            sql += " AND %d < tid" % min_tid
        if max_tid is not None:
            sql += " AND tid <= %d" % max_tid
        q = self.query
        q("DELETE FROM trans" + sql)
        sql = " FROM obj" + sql
        data_id_list = [x for x, in q(
            "SELECT DISTINCT data_id%s AND data_id IS NOT NULL" % sql)]
        q("DELETE" + sql)
        self._pruneData(data_id_list)

    def getTransaction(self, tid, all = False):
        tid = util.u64(tid)
        q = self.query
        r = q("SELECT oids, user, description, ext, packed, ttid"
              " FROM trans WHERE `partition` = %d AND tid = %d"
              % (self._getReadablePartition(tid), tid))
        if not r and all:
            r = q("SELECT oids, user, description, ext, packed, ttid"
                  " FROM ttrans WHERE tid = %d" % tid)
        if r:
            oids, user, desc, ext, packed, ttid = r[0]
            oid_list = splitOIDField(tid, oids)
            return oid_list, user, desc, ext, bool(packed), util.p64(ttid)

    def getObjectHistory(self, oid, offset, length):
        # FIXME: This method doesn't take the client's current transaction id
        # as a parameter, which means it can return transactions in the future
        # of the client's transaction.
        oid = util.u64(oid)
        p64 = util.p64
        r = self.query("SELECT tid, IF(compression < 128, LENGTH(value),"
            "  CAST(CONV(HEX(SUBSTR(value, 5, 4)), 16, 10) AS INT))"
            " FROM obj FORCE INDEX(PRIMARY)"
            " LEFT JOIN data ON (obj.data_id = data.id)"
            " WHERE `partition` = %d AND oid = %d AND tid >= %d"
            " ORDER BY tid DESC LIMIT %d, %d" %
            (self._getReadablePartition(oid), oid,
             self._getPackTID(), offset, length))
        if r:
            return [(p64(tid), length or 0) for tid, length in r]

    def _fetchObject(self, oid, tid):
        r = self.query(
            'SELECT tid, compression, data.hash, value, value_tid'
            ' FROM obj FORCE INDEX(PRIMARY)'
            ' LEFT JOIN data ON (obj.data_id = data.id)'
            ' WHERE `partition` = %d AND oid = %d AND tid = %d'
            % (self._getReadablePartition(oid), oid, tid))
        if r:
            r = r[0]
            compression = r[1]
            if compression and compression & 0x80:
                return (r[0], compression & 0x7f, r[2],
                    ''.join(self._bigData(r[3])), r[4])
            return r

    def getReplicationObjectList(self, min_tid, max_tid, length, partition,
            min_oid):
        u64 = util.u64
        p64 = util.p64
        min_tid = u64(min_tid)
        r = self.query('SELECT tid, oid FROM obj FORCE INDEX(tid)'
                       ' WHERE `partition` = %d AND tid <= %d'
                       ' AND (tid = %d AND %d <= oid OR %d < tid)'
                       ' ORDER BY tid ASC, oid ASC LIMIT %d' % (
            partition, u64(max_tid), min_tid, u64(min_oid), min_tid, length))
        return [(p64(serial), p64(oid)) for serial, oid in r]

    def _getTIDList(self, offset, length, partition_list):
        return (t[0] for t in self.query(
            "SELECT tid FROM trans WHERE `partition` in (%s)"
            " ORDER BY tid DESC LIMIT %d,%d"
            % (','.join(map(str, partition_list)), offset, length)))

    def getReplicationTIDList(self, min_tid, max_tid, length, partition):
        u64 = util.u64
        p64 = util.p64
        min_tid = u64(min_tid)
        max_tid = u64(max_tid)
        r = self.query("""SELECT tid FROM trans
                    WHERE `partition` = %(partition)d
                    AND tid >= %(min_tid)d AND tid <= %(max_tid)d
                    ORDER BY tid ASC LIMIT %(length)d""" % {
            'partition': partition,
            'min_tid': min_tid,
            'max_tid': max_tid,
            'length': length,
        })
        return [p64(t[0]) for t in r]

    def _updatePackFuture(self, oid, orig_serial, max_serial):
        q = self.query
        # Before deleting this object's revision, see if there is any
        # transaction referencing its value at max_serial or above.
        # If there is, copy value to the first future transaction. Any further
        # reference is just updated to point to the new data location.
        value_serial = None
        kw = {
          'partition': self._getReadablePartition(oid),
          'oid': oid,
          'orig_tid': orig_serial,
          'max_tid': max_serial,
          'new_tid': 'NULL',
        }
        for kw['table'] in 'obj', 'tobj':
            for kw['tid'], in q('SELECT tid FROM %(table)s'
                  ' WHERE `partition`=%(partition)d AND oid=%(oid)d'
                  ' AND tid>=%(max_tid)d AND value_tid=%(orig_tid)d'
                  ' ORDER BY tid ASC' % kw):
                q('UPDATE %(table)s SET value_tid=%(new_tid)s'
                  ' WHERE `partition`=%(partition)d AND oid=%(oid)d'
                  ' AND tid=%(tid)d' % kw)
                if value_serial is None:
                    # First found, mark its serial for future reference.
                    kw['new_tid'] = value_serial = kw['tid']
        return value_serial

    def pack(self, tid, updateObjectDataForPack):
        # TODO: unit test (along with updatePackFuture)
        p64 = util.p64
        tid = util.u64(tid)
        updatePackFuture = self._updatePackFuture
        getPartition = self._getReadablePartition
        q = self.query
        self._setPackTID(tid)
        for count, oid, max_serial in q("SELECT COUNT(*) - 1, oid, MAX(tid)"
                                        " FROM obj FORCE INDEX(PRIMARY)"
                                        " WHERE tid <= %d GROUP BY oid"
                                        % tid):
            partition = getPartition(oid)
            if q("SELECT 1 FROM obj WHERE `partition` = %d"
                 " AND oid = %d AND tid = %d AND data_id IS NULL"
                 % (partition, oid, max_serial)):
                max_serial += 1
            elif not count:
                continue
            # There are things to delete for this object
            data_id_set = set()
            sql = ' FROM obj WHERE `partition`=%d AND oid=%d' \
                ' AND tid<%d' % (partition, oid, max_serial)
            for serial, data_id in q('SELECT tid, data_id' + sql):
                data_id_set.add(data_id)
                new_serial = updatePackFuture(oid, serial, max_serial)
                if new_serial:
                    new_serial = p64(new_serial)
                updateObjectDataForPack(p64(oid), p64(serial),
                                        new_serial, data_id)
            q('DELETE' + sql)
            data_id_set.discard(None)
            self._pruneData(data_id_set)
        self.commit()

    def checkTIDRange(self, partition, length, min_tid, max_tid):
        count, tid_checksum, max_tid = self.query(
            """SELECT COUNT(*), SHA1(GROUP_CONCAT(tid SEPARATOR ",")), MAX(tid)
               FROM (SELECT tid FROM trans
                     WHERE `partition` = %(partition)s
                       AND tid >= %(min_tid)d
                       AND tid <= %(max_tid)d
                     ORDER BY tid ASC %(limit)s) AS t""" % {
            'partition': partition,
            'min_tid': util.u64(min_tid),
            'max_tid': util.u64(max_tid),
            'limit': '' if length is None else 'LIMIT %u' % length,
        })[0]
        if count:
            return count, a2b_hex(tid_checksum), util.p64(max_tid)
        return 0, ZERO_HASH, ZERO_TID
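    # checkTIDRange has MySQL compute the SHA1 over the comma-joined TID
    # list; checkSerialRange below computes checksums of the same form in
    # Python, so the two serializations must stay identical.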

    def checkSerialRange(self, partition, length, min_tid, max_tid, min_oid):
        u64 = util.u64
        # We don't ask MySQL to compute everything (like in checkTIDRange)
        # because it's difficult to get the last serial _for the last oid_.
        # We would need a function (that could be named 'LAST') that returns
        # the last grouped value, instead of the greatest one.
        r = self.query(
            """SELECT tid, oid
               FROM obj FORCE INDEX(tid)
               WHERE `partition` = %(partition)s
                 AND tid <= %(max_tid)d
                 AND (tid > %(min_tid)d OR
                      tid = %(min_tid)d AND oid >= %(min_oid)d)
               ORDER BY tid, oid %(limit)s""" % {
            'min_oid': u64(min_oid),
            'min_tid': u64(min_tid),
            'max_tid': u64(max_tid),
            'limit': '' if length is None else 'LIMIT %u' % length,
            'partition': partition,
        })
        if r:
            p64 = util.p64
            return (len(r),
                    sha1(','.join(str(x[0]) for x in r)).digest(),
                    p64(r[-1][0]),
                    sha1(','.join(str(x[1]) for x in r)).digest(),
                    p64(r[-1][1]))
        return 0, ZERO_HASH, ZERO_TID, ZERO_HASH, ZERO_OID

    def _cmdline(self):
        for x in ('u', self.user), ('p', self.passwd), ('S', self.socket):
            if x[1]:
                yield '-%s%s' % x
        yield self.db

    def dump(self):
        import subprocess
        cmd = ['mysqldump', '--compact', '--hex-blob']
        cmd += self._cmdline()
        return subprocess.check_output(cmd)

    def _restore(self, sql):
        import subprocess
        cmd = ['mysql']
        cmd += self._cmdline()
        p = subprocess.Popen(cmd, stdin=subprocess.PIPE)
        p.communicate(sql)
        retcode = p.wait()
        if retcode:
            raise subprocess.CalledProcessError(retcode, cmd)