# # Copyright (C) 2006-2009 Nexedi SA # # This program is free software; you can redistribute it and/or # modify it under the terms of the GNU General Public License # as published by the Free Software Foundation; either version 2 # of the License, or (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program; if not, write to the Free Software # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. import MySQLdb from MySQLdb import OperationalError from MySQLdb.constants.CR import SERVER_GONE_ERROR, SERVER_LOST from neo import logging from array import array import string from struct import pack, unpack from neo.storage.database import DatabaseManager from neo.exception import DatabaseFailure from neo.protocol import CellStates from neo import util LOG_QUERIES = False def p64(n): return pack('!Q', n) def u64(s): return unpack('!Q', s)[0] class MySQLDatabaseManager(DatabaseManager): """This class manages a database on MySQL.""" def __init__(self, database): super(MySQLDatabaseManager, self).__init__() self.user, self.passwd, self.db = self._parse(database) self.conn = None self._connect() def _parse(self, database): """ Get the database credentials (username, password, database) """ # expected pattern : [user[:password]@]database username = None password = None if '@' in database: (username, database) = database.split('@') if ':' in username: (username, password) = username.split(':') return (username, password, database) # XXX: unused ? def close(self): self.conn.close() def _connect(self): kwd = {'db' : self.db, 'user' : self.user} if self.passwd is not None: kwd['passwd'] = self.passwd logging.info('connecting to MySQL on the database %s with user %s', self.db, self.user) self.conn = MySQLdb.connect(**kwd) self.conn.autocommit(False) def _begin(self): self.query("""BEGIN""") def _commit(self): self.conn.commit() def _rollback(self): self.conn.rollback() def query(self, query): """Query data from a database.""" conn = self.conn try: if LOG_QUERIES: printable_char_list = [] for c in query.split('\n', 1)[0][:70]: if c not in string.printable or c in '\t\x0b\x0c\r': c = '\\x%02x' % ord(c) printable_char_list.append(c) query_part = ''.join(printable_char_list) logging.debug('querying %s...', query_part) conn.query(query) r = conn.store_result() if r is not None: new_r = [] for row in r.fetch_row(r.num_rows()): new_row = [] for d in row: if isinstance(d, array): d = d.tostring() new_row.append(d) new_r.append(tuple(new_row)) r = tuple(new_r) except OperationalError, m: if m[0] in (SERVER_GONE_ERROR, SERVER_LOST): logging.info('the MySQL server is gone; reconnecting') self.connect() return self.query(query) raise DatabaseFailure('MySQL error %d: %s' % (m[0], m[1])) return r def escape(self, s): """Escape special characters in a string.""" return self.conn.escape_string(s) def setup(self, reset = 0): q = self.query if reset: q("""DROP TABLE IF EXISTS config, pt, trans, obj, ttrans, tobj""") # The table "config" stores configuration parameters which affect the # persistent data. q("""CREATE TABLE IF NOT EXISTS config ( name VARBINARY(16) NOT NULL PRIMARY KEY, value VARBINARY(255) NOT NULL ) ENGINE = InnoDB""") # The table "pt" stores a partition table. q("""CREATE TABLE IF NOT EXISTS pt ( rid INT UNSIGNED NOT NULL, uuid CHAR(32) NOT NULL, state TINYINT UNSIGNED NOT NULL, PRIMARY KEY (rid, uuid) ) ENGINE = InnoDB""") # The table "trans" stores information on committed transactions. q("""CREATE TABLE IF NOT EXISTS trans ( tid BIGINT UNSIGNED NOT NULL PRIMARY KEY, oids MEDIUMBLOB NOT NULL, user BLOB NOT NULL, description BLOB NOT NULL, ext BLOB NOT NULL ) ENGINE = InnoDB""") # The table "obj" stores committed object data. q("""CREATE TABLE IF NOT EXISTS obj ( oid BIGINT UNSIGNED NOT NULL, serial BIGINT UNSIGNED NOT NULL, compression TINYINT UNSIGNED NOT NULL, checksum INT UNSIGNED NOT NULL, value MEDIUMBLOB NOT NULL, PRIMARY KEY (oid, serial) ) ENGINE = InnoDB""") # The table "ttrans" stores information on uncommitted transactions. q("""CREATE TABLE IF NOT EXISTS ttrans ( tid BIGINT UNSIGNED NOT NULL, oids MEDIUMBLOB NOT NULL, user BLOB NOT NULL, description BLOB NOT NULL, ext BLOB NOT NULL ) ENGINE = InnoDB""") # The table "tobj" stores uncommitted object data. q("""CREATE TABLE IF NOT EXISTS tobj ( oid BIGINT UNSIGNED NOT NULL, serial BIGINT UNSIGNED NOT NULL, compression TINYINT UNSIGNED NOT NULL, checksum INT UNSIGNED NOT NULL, value MEDIUMBLOB NOT NULL ) ENGINE = InnoDB""") def getConfiguration(self, key): q = self.query e = self.escape key = e(str(key)) r = q("""SELECT value FROM config WHERE name = '%s'""" % key) try: return r[0][0] except IndexError: return None def _setConfiguration(self, key, value): q = self.query e = self.escape key = e(str(key)) value = e(str(value)) q("""REPLACE INTO config VALUES ('%s', '%s')""" % (key, value)) def getPartitionTable(self): q = self.query cell_list = q("""SELECT rid, uuid, state FROM pt""") pt = [] for offset, uuid, state in cell_list: uuid = util.bin(uuid) pt.append((offset, uuid, state)) return pt def getLastTID(self, all = True): # XXX this does not consider serials in obj. # I am not sure if this is really harmful. For safety, # check for tobj only at the moment. The reason why obj is # not tested is that it is too slow to get the max serial # from obj when it has a huge number of objects, because # serial is the second part of the primary key, so the index # is not used in this case. If doing it, it is better to # make another index for serial, but I doubt the cost increase # is worth. q = self.query self.begin() ltid = q("""SELECT MAX(tid) FROM trans""")[0][0] if all: tmp_ltid = q("""SELECT MAX(tid) FROM ttrans""")[0][0] if ltid is None or (tmp_ltid is not None and ltid < tmp_ltid): ltid = tmp_ltid tmp_serial = q("""SELECT MAX(serial) FROM tobj""")[0][0] if ltid is None or (tmp_serial is not None and ltid < tmp_serial): ltid = tmp_serial self.commit() if ltid is not None: ltid = p64(ltid) return ltid def getUnfinishedTIDList(self): q = self.query tid_set = set() self.begin() r = q("""SELECT tid FROM ttrans""") tid_set.update((p64(t[0]) for t in r)) r = q("""SELECT serial FROM tobj""") self.commit() tid_set.update((p64(t[0]) for t in r)) return list(tid_set) def objectPresent(self, oid, tid, all = True): q = self.query oid = u64(oid) tid = u64(tid) self.begin() r = q("""SELECT oid FROM obj WHERE oid = %d AND serial = %d""" \ % (oid, tid)) if not r and all: r = q("""SELECT oid FROM tobj WHERE oid = %d AND serial = %d""" \ % (oid, tid)) self.commit() if r: return True return False def getObject(self, oid, tid = None, before_tid = None): q = self.query oid = u64(oid) if tid is not None: tid = u64(tid) r = q("""SELECT serial, compression, checksum, value FROM obj WHERE oid = %d AND serial = %d""" \ % (oid, tid)) try: serial, compression, checksum, data = r[0] next_serial = None except IndexError: return None elif before_tid is not None: before_tid = u64(before_tid) r = q("""SELECT serial, compression, checksum, value FROM obj WHERE oid = %d AND serial < %d ORDER BY serial DESC LIMIT 1""" \ % (oid, before_tid)) try: serial, compression, checksum, data = r[0] r = q("""SELECT serial FROM obj WHERE oid = %d AND serial >= %d ORDER BY serial LIMIT 1""" \ % (oid, before_tid)) try: next_serial = r[0][0] except IndexError: next_serial = None except IndexError: return None else: # XXX I want to express "HAVING serial = MAX(serial)", but # MySQL does not use an index for a HAVING clause! r = q("""SELECT serial, compression, checksum, value FROM obj WHERE oid = %d ORDER BY serial DESC LIMIT 1""" \ % oid) try: serial, compression, checksum, data = r[0] next_serial = None except IndexError: return None if serial is not None: serial = p64(serial) if next_serial is not None: next_serial = p64(next_serial) return serial, next_serial, compression, checksum, data def doSetPartitionTable(self, ptid, cell_list, reset): q = self.query e = self.escape self.begin() try: if reset: q("""TRUNCATE pt""") for offset, uuid, state in cell_list: uuid = e(util.dump(uuid)) if state == CellStates.DISCARDED: q("""DELETE FROM pt WHERE rid = %d AND uuid = '%s'""" \ % (offset, uuid)) else: q("""INSERT INTO pt VALUES (%d, '%s', %d) ON DUPLICATE KEY UPDATE state = %d""" \ % (offset, uuid, state, state)) self.setPTID(ptid) except: self.rollback() raise self.commit() def changePartitionTable(self, ptid, cell_list): self.doSetPartitionTable(ptid, cell_list, False) def setPartitionTable(self, ptid, cell_list): self.doSetPartitionTable(ptid, cell_list, True) def dropPartition(self, num_partitions, offset): q = self.query self.begin() try: q("""DELETE FROM obj WHERE MOD(oid, %d) = %d""" % (num_partitions, offset)) q("""DELETE FROM trans WHERE MOD(tid, %d) = %d""" % (num_partitions, offset)) except: self.rollback() raise self.commit() def dropUnfinishedData(self): q = self.query self.begin() try: q("""TRUNCATE tobj""") q("""TRUNCATE ttrans""") except: self.rollback() raise self.commit() def storeTransaction(self, tid, object_list, transaction, temporary = True): q = self.query e = self.escape tid = u64(tid) if temporary: obj_table = 'tobj' trans_table = 'ttrans' else: obj_table = 'obj' trans_table = 'trans' self.begin() try: for oid, compression, checksum, data in object_list: oid = u64(oid) data = e(data) q("""REPLACE INTO %s VALUES (%d, %d, %d, %d, '%s')""" \ % (obj_table, oid, tid, compression, checksum, data)) if transaction is not None: oid_list, user, desc, ext = transaction oids = e(''.join(oid_list)) user = e(user) desc = e(desc) ext = e(ext) q("""REPLACE INTO %s VALUES (%d, '%s', '%s', '%s', '%s')""" \ % (trans_table, tid, oids, user, desc, ext)) except: self.rollback() raise self.commit() def finishTransaction(self, tid): q = self.query tid = u64(tid) self.begin() try: q("""INSERT INTO obj SELECT * FROM tobj WHERE tobj.serial = %d""" \ % tid) q("""DELETE FROM tobj WHERE serial = %d""" % tid) q("""INSERT INTO trans SELECT * FROM ttrans WHERE ttrans.tid = %d""" % tid) q("""DELETE FROM ttrans WHERE tid = %d""" % tid) except: self.rollback() raise self.commit() def deleteTransaction(self, tid, all = False): q = self.query tid = u64(tid) self.begin() try: q("""DELETE FROM tobj WHERE serial = %d""" % tid) q("""DELETE FROM ttrans WHERE tid = %d""" % tid) if all: # Note that this can be very slow. q("""DELETE FROM obj WHERE serial = %d""" % tid) q("""DELETE FROM trans WHERE tid = %d""" % tid) except: self.rollback() raise self.commit() def getTransaction(self, tid, all = False): q = self.query tid = u64(tid) self.begin() r = q("""SELECT oids, user, description, ext FROM trans WHERE tid = %d""" \ % tid) if not r and all: r = q("""SELECT oids, user, description, ext FROM ttrans WHERE tid = %d""" \ % tid) self.commit() if r: oids, user, desc, ext = r[0] if (len(oids) % 8) != 0 or len(oids) == 0: raise DatabaseFailure('invalid oids for tid %x' % tid) oid_list = [] for i in xrange(0, len(oids), 8): oid_list.append(oids[i:i+8]) return oid_list, user, desc, ext return None def getOIDList(self, offset, length, num_partitions, partition_list): q = self.query r = q("""SELECT DISTINCT oid FROM obj WHERE MOD(oid, %d) in (%s) ORDER BY oid DESC LIMIT %d,%d""" \ % (num_partitions, ','.join([str(p) for p in partition_list]), offset, length)) return [p64(t[0]) for t in r] def getObjectHistory(self, oid, offset = 0, length = 1): q = self.query oid = u64(oid) r = q("""SELECT serial, LENGTH(value) FROM obj WHERE oid = %d ORDER BY serial DESC LIMIT %d, %d""" \ % (oid, offset, length)) if r: return [(p64(serial), length) for serial, length in r] return None def getTIDList(self, offset, length, num_partitions, partition_list): q = self.query r = q("""SELECT tid FROM trans WHERE MOD(tid, %d) in (%s) ORDER BY tid DESC LIMIT %d,%d""" \ % (num_partitions, ','.join([str(p) for p in partition_list]), offset, length)) return [p64(t[0]) for t in r] def getTIDListPresent(self, tid_list): q = self.query r = q("""SELECT tid FROM trans WHERE tid in (%s)""" \ % ','.join([str(u64(tid)) for tid in tid_list])) return [p64(t[0]) for t in r] def getSerialListPresent(self, oid, serial_list): q = self.query oid = u64(oid) r = q("""SELECT serial FROM obj WHERE oid = %d AND serial in (%s)""" \ % (oid, ','.join([str(u64(serial)) for serial in serial_list]))) return [p64(t[0]) for t in r]