Commit 52c5d862 authored by Julien Muchembled's avatar Julien Muchembled

storage: fix severe performance issue by committing backend only at key moments

parent 4741e38e
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from neo.lib import logging, util from neo.lib import logging, util
from neo.lib.exception import DatabaseFailure
from neo.lib.protocol import ZERO_TID from neo.lib.protocol import ZERO_TID
class CreationUndone(Exception): class CreationUndone(Exception):
...@@ -28,7 +27,6 @@ class DatabaseManager(object): ...@@ -28,7 +27,6 @@ class DatabaseManager(object):
""" """
Initialize the object. Initialize the object.
""" """
self._under_transaction = False
self._wait = wait self._wait = wait
self._parse(database) self._parse(database)
...@@ -50,34 +48,9 @@ class DatabaseManager(object): ...@@ -50,34 +48,9 @@ class DatabaseManager(object):
""" """
raise NotImplementedError raise NotImplementedError
def __enter__(self):
"""
Begin a transaction
"""
if self._under_transaction:
raise DatabaseFailure('A transaction has already begun')
r = self.begin()
self._under_transaction = True
return r
def __exit__(self, exc_type, exc_value, tb):
if not self._under_transaction:
raise DatabaseFailure('The transaction has not begun')
self._under_transaction = False
if exc_type is None:
self.commit()
else:
self.rollback()
def begin(self):
pass
def commit(self): def commit(self):
pass pass
def rollback(self):
pass
def _getPartition(self, oid_or_tid): def _getPartition(self, oid_or_tid):
return oid_or_tid % self.getNumPartitions() return oid_or_tid % self.getNumPartitions()
...@@ -91,11 +64,8 @@ class DatabaseManager(object): ...@@ -91,11 +64,8 @@ class DatabaseManager(object):
""" """
Set a configuration value Set a configuration value
""" """
if self._under_transaction:
self._setConfiguration(key, value)
else:
with self:
self._setConfiguration(key, value) self._setConfiguration(key, value)
self.commit()
def _setConfiguration(self, key, value): def _setConfiguration(self, key, value):
raise NotImplementedError raise NotImplementedError
...@@ -344,8 +314,8 @@ class DatabaseManager(object): ...@@ -344,8 +314,8 @@ class DatabaseManager(object):
else: else:
del refcount[data_id] del refcount[data_id]
if prune: if prune:
with self:
self._pruneData(data_id_list) self._pruneData(data_id_list)
self.commit()
__getDataTID = set() __getDataTID = set()
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
...@@ -465,11 +435,11 @@ class DatabaseManager(object): ...@@ -465,11 +435,11 @@ class DatabaseManager(object):
def truncate(self, tid): def truncate(self, tid):
assert tid not in (None, ZERO_TID), tid assert tid not in (None, ZERO_TID), tid
with self:
assert self.getBackupTID() assert self.getBackupTID()
self.setBackupTID(tid) self.setBackupTID(tid)
for partition in xrange(self.getNumPartitions()): for partition in xrange(self.getNumPartitions()):
self._deleteRange(partition, tid) self._deleteRange(partition, tid)
self.commit()
def getTransaction(self, tid, all = False): def getTransaction(self, tid, all = False):
"""Return a tuple of the list of OIDs, user information, """Return a tuple of the list of OIDs, user information,
......
...@@ -93,23 +93,10 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -93,23 +93,10 @@ class MySQLDatabaseManager(DatabaseManager):
self.conn.query("SET SESSION group_concat_max_len = -1") self.conn.query("SET SESSION group_concat_max_len = -1")
self.conn.set_sql_mode("TRADITIONAL,NO_ENGINE_SUBSTITUTION") self.conn.set_sql_mode("TRADITIONAL,NO_ENGINE_SUBSTITUTION")
def begin(self):
q = self.query
q("BEGIN")
return q
if LOG_QUERIES:
def commit(self): def commit(self):
logging.debug('committing...') logging.debug('committing...')
self.conn.commit() self.conn.commit()
def rollback(self):
logging.debug('aborting...')
self.conn.rollback()
else:
commit = property(lambda self: self.conn.commit)
rollback = property(lambda self: self.conn.rollback)
def query(self, query): def query(self, query):
"""Query data from a database.""" """Query data from a database."""
conn = self.conn conn = self.conn
...@@ -271,7 +258,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -271,7 +258,7 @@ class MySQLDatabaseManager(DatabaseManager):
def _getLastIDs(self, all=True): def _getLastIDs(self, all=True):
p64 = util.p64 p64 = util.p64
with self as q: q = self.query
trans = dict((partition, p64(tid)) trans = dict((partition, p64(tid))
for partition, tid in q("SELECT partition, MAX(tid)" for partition, tid in q("SELECT partition, MAX(tid)"
" FROM trans GROUP BY partition")) " FROM trans GROUP BY partition"))
...@@ -292,23 +279,17 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -292,23 +279,17 @@ class MySQLDatabaseManager(DatabaseManager):
return trans, obj, None if oid is None else p64(oid) return trans, obj, None if oid is None else p64(oid)
def getUnfinishedTIDList(self): def getUnfinishedTIDList(self):
tid_set = set() p64 = util.p64
with self as q: return [p64(t[0]) for t in self.query("SELECT tid FROM ttrans"
r = q("""SELECT tid FROM ttrans""") " UNION SELECT tid FROM tobj")]
tid_set.update((util.p64(t[0]) for t in r))
r = q("""SELECT tid FROM tobj""")
tid_set.update((util.p64(t[0]) for t in r))
return list(tid_set)
def objectPresent(self, oid, tid, all = True): def objectPresent(self, oid, tid, all = True):
oid = util.u64(oid) oid = util.u64(oid)
tid = util.u64(tid) tid = util.u64(tid)
partition = self._getPartition(oid) q = self.query
with self as q: return q("SELECT 1 FROM obj WHERE partition=%d AND oid=%d AND tid=%d"
return q("SELECT oid FROM obj WHERE partition=%d AND oid=%d AND " % (self._getPartition(oid), oid, tid)) or all and \
"tid=%d" % (partition, oid, tid)) or all and \ q("SELECT 1 FROM tobj WHERE tid=%d AND oid=%d" % (tid, oid))
q("SELECT oid FROM tobj WHERE tid=%d AND oid=%d"
% (tid, oid))
def _getObject(self, oid, tid=None, before_tid=None): def _getObject(self, oid, tid=None, before_tid=None):
q = self.query q = self.query
...@@ -339,19 +320,19 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -339,19 +320,19 @@ class MySQLDatabaseManager(DatabaseManager):
def doSetPartitionTable(self, ptid, cell_list, reset): def doSetPartitionTable(self, ptid, cell_list, reset):
offset_list = [] offset_list = []
with self as q: q = self.query
if reset: if reset:
q("""TRUNCATE pt""") q("TRUNCATE pt")
for offset, uuid, state in cell_list: for offset, uuid, state in cell_list:
# TODO: this logic should move out of database manager # TODO: this logic should move out of database manager
# add 'dropCells(cell_list)' to API and use one query # add 'dropCells(cell_list)' to API and use one query
if state == CellStates.DISCARDED: if state == CellStates.DISCARDED:
q("""DELETE FROM pt WHERE rid = %d AND uuid = %d""" q("DELETE FROM pt WHERE rid = %d AND uuid = %d"
% (offset, uuid)) % (offset, uuid))
else: else:
offset_list.append(offset) offset_list.append(offset)
q("""INSERT INTO pt VALUES (%d, %d, %d) q("INSERT INTO pt VALUES (%d, %d, %d)"
ON DUPLICATE KEY UPDATE state = %d""" \ " ON DUPLICATE KEY UPDATE state = %d"
% (offset, uuid, state, state)) % (offset, uuid, state, state))
self.setPTID(ptid) self.setPTID(ptid)
if self._use_partition: if self._use_partition:
...@@ -372,7 +353,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -372,7 +353,7 @@ class MySQLDatabaseManager(DatabaseManager):
self.doSetPartitionTable(ptid, cell_list, True) self.doSetPartitionTable(ptid, cell_list, True)
def dropPartitions(self, offset_list): def dropPartitions(self, offset_list):
with self as q: q = self.query
# XXX: these queries are inefficient (execution time increase with # XXX: these queries are inefficient (execution time increase with
# row count, although we use indexes) when there are rows to # row count, although we use indexes) when there are rows to
# delete. It should be done as an idle task, by chunks. # delete. It should be done as an idle task, by chunks.
...@@ -395,25 +376,23 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -395,25 +376,23 @@ class MySQLDatabaseManager(DatabaseManager):
raise raise
def dropUnfinishedData(self): def dropUnfinishedData(self):
with self as q: q = self.query
data_id_list = [x for x, in q("SELECT data_id FROM tobj") if x] data_id_list = [x for x, in q("SELECT data_id FROM tobj") if x]
q("""TRUNCATE tobj""") q("TRUNCATE tobj")
q("""TRUNCATE ttrans""") q("TRUNCATE ttrans")
self.unlockData(data_id_list, True) self.unlockData(data_id_list, True)
def storeTransaction(self, tid, object_list, transaction, temporary = True): def storeTransaction(self, tid, object_list, transaction, temporary = True):
e = self.escape e = self.escape
u64 = util.u64 u64 = util.u64
tid = u64(tid) tid = u64(tid)
if temporary: if temporary:
obj_table = 'tobj' obj_table = 'tobj'
trans_table = 'ttrans' trans_table = 'ttrans'
else: else:
obj_table = 'obj' obj_table = 'obj'
trans_table = 'trans' trans_table = 'trans'
q = self.query
with self as q:
for oid, data_id, value_serial in object_list: for oid, data_id, value_serial in object_list:
oid = u64(oid) oid = u64(oid)
partition = self._getPartition(oid) partition = self._getPartition(oid)
...@@ -428,7 +407,6 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -428,7 +407,6 @@ class MySQLDatabaseManager(DatabaseManager):
value_serial = 'NULL' value_serial = 'NULL'
q("REPLACE INTO %s VALUES (%d, %d, %d, %s, %s)" % (obj_table, q("REPLACE INTO %s VALUES (%d, %d, %d, %s, %s)" % (obj_table,
partition, oid, tid, data_id or 'NULL', value_serial)) partition, oid, tid, data_id or 'NULL', value_serial))
if transaction: if transaction:
oid_list, user, desc, ext, packed, ttid = transaction oid_list, user, desc, ext, packed, ttid = transaction
partition = self._getPartition(tid) partition = self._getPartition(tid)
...@@ -436,6 +414,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -436,6 +414,8 @@ class MySQLDatabaseManager(DatabaseManager):
q("REPLACE INTO %s VALUES (%d,%d,%i,'%s','%s','%s','%s',%d)" % ( q("REPLACE INTO %s VALUES (%d,%d,%i,'%s','%s','%s','%s',%d)" % (
trans_table, partition, tid, packed, e(''.join(oid_list)), trans_table, partition, tid, packed, e(''.join(oid_list)),
e(user), e(desc), e(ext), u64(ttid))) e(user), e(desc), e(ext), u64(ttid)))
if temporary:
self.commit()
def _pruneData(self, data_id_list): def _pruneData(self, data_id_list):
data_id_list = set(data_id_list).difference(self._uncommitted_data) data_id_list = set(data_id_list).difference(self._uncommitted_data)
...@@ -448,20 +428,17 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -448,20 +428,17 @@ class MySQLDatabaseManager(DatabaseManager):
def _storeData(self, checksum, data, compression): def _storeData(self, checksum, data, compression):
e = self.escape e = self.escape
checksum = e(checksum) checksum = e(checksum)
with self as q:
try: try:
q("INSERT INTO data VALUES (NULL, '%s', %d, '%s')" % self.query("INSERT INTO data VALUES (NULL, '%s', %d, '%s')" %
(checksum, compression, e(data))) (checksum, compression, e(data)))
except IntegrityError, (code, _): except IntegrityError, (code, _):
if code != DUP_ENTRY: if code == DUP_ENTRY:
raise (r, c, d), = self.query("SELECT id, compression, value"
(r, c, d), = q("SELECT id, compression, value"
" FROM data WHERE hash='%s'" % checksum) " FROM data WHERE hash='%s'" % checksum)
if c != compression or d != data: if c == compression and d == data:
raise
else:
r = self.conn.insert_id()
return r return r
raise
return self.conn.insert_id()
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
sql = ('SELECT tid, data_id, value_tid FROM obj' sql = ('SELECT tid, data_id, value_tid FROM obj'
...@@ -486,7 +463,6 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -486,7 +463,6 @@ class MySQLDatabaseManager(DatabaseManager):
def finishTransaction(self, tid): def finishTransaction(self, tid):
q = self.query q = self.query
tid = util.u64(tid) tid = util.u64(tid)
with self as q:
sql = " FROM tobj WHERE tid=%d" % tid sql = " FROM tobj WHERE tid=%d" % tid
data_id_list = [x for x, in q("SELECT data_id" + sql) if x] data_id_list = [x for x, in q("SELECT data_id" + sql) if x]
q("INSERT INTO obj SELECT *" + sql) q("INSERT INTO obj SELECT *" + sql)
...@@ -494,12 +470,13 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -494,12 +470,13 @@ class MySQLDatabaseManager(DatabaseManager):
q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid) q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid)
q("DELETE FROM ttrans WHERE tid=%d" % tid) q("DELETE FROM ttrans WHERE tid=%d" % tid)
self.unlockData(data_id_list) self.unlockData(data_id_list)
self.commit()
def deleteTransaction(self, tid, oid_list=()): def deleteTransaction(self, tid, oid_list=()):
u64 = util.u64 u64 = util.u64
tid = u64(tid) tid = u64(tid)
getPartition = self._getPartition getPartition = self._getPartition
with self as q: q = self.query
sql = " FROM tobj WHERE tid=%d" % tid sql = " FROM tobj WHERE tid=%d" % tid
data_id_list = [x for x, in q("SELECT data_id" + sql) if x] data_id_list = [x for x, in q("SELECT data_id" + sql) if x]
self.unlockData(data_id_list) self.unlockData(data_id_list)
...@@ -525,7 +502,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -525,7 +502,7 @@ class MySQLDatabaseManager(DatabaseManager):
% (self._getPartition(oid), oid) % (self._getPartition(oid), oid)
if serial: if serial:
sql += ' AND tid=%d' % u64(serial) sql += ' AND tid=%d' % u64(serial)
with self as q: q = self.query
data_id_list = [x for x, in q("SELECT DISTINCT data_id" + sql) if x] data_id_list = [x for x, in q("SELECT DISTINCT data_id" + sql) if x]
q("DELETE" + sql) q("DELETE" + sql)
self._pruneData(data_id_list) self._pruneData(data_id_list)
...@@ -545,7 +522,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -545,7 +522,7 @@ class MySQLDatabaseManager(DatabaseManager):
def getTransaction(self, tid, all = False): def getTransaction(self, tid, all = False):
tid = util.u64(tid) tid = util.u64(tid)
with self as q: q = self.query
r = q("SELECT oids, user, description, ext, packed, ttid" r = q("SELECT oids, user, description, ext, packed, ttid"
" FROM trans WHERE partition = %d AND tid = %d" " FROM trans WHERE partition = %d AND tid = %d"
% (self._getPartition(tid), tid)) % (self._getPartition(tid), tid))
...@@ -665,10 +642,10 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -665,10 +642,10 @@ class MySQLDatabaseManager(DatabaseManager):
tid = util.u64(tid) tid = util.u64(tid)
updatePackFuture = self._updatePackFuture updatePackFuture = self._updatePackFuture
getPartition = self._getPartition getPartition = self._getPartition
with self as q: q = self.query
self._setPackTID(tid) self._setPackTID(tid)
for count, oid, max_serial in q('SELECT COUNT(*) - 1, oid, ' for count, oid, max_serial in q("SELECT COUNT(*) - 1, oid, MAX(tid)"
'MAX(tid) FROM obj WHERE tid <= %d GROUP BY oid' " FROM obj WHERE tid <= %d GROUP BY oid"
% tid): % tid):
partition = getPartition(oid) partition = getPartition(oid)
if q("SELECT 1 FROM obj WHERE partition = %d" if q("SELECT 1 FROM obj WHERE partition = %d"
...@@ -691,6 +668,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -691,6 +668,7 @@ class MySQLDatabaseManager(DatabaseManager):
q('DELETE' + sql) q('DELETE' + sql)
data_id_set.discard(None) data_id_set.discard(None)
self._pruneData(data_id_set) self._pruneData(data_id_set)
self.commit()
def checkTIDRange(self, partition, length, min_tid, max_tid): def checkTIDRange(self, partition, length, min_tid, max_tid):
count, tid_checksum, max_tid = self.query( count, tid_checksum, max_tid = self.query(
......
...@@ -76,23 +76,13 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -76,23 +76,13 @@ class SQLiteDatabaseManager(DatabaseManager):
def _connect(self): def _connect(self):
logging.info('connecting to SQLite database %r', self.db) logging.info('connecting to SQLite database %r', self.db)
self.conn = sqlite3.connect(self.db, isolation_level=None, self.conn = sqlite3.connect(self.db, check_same_thread=False)
check_same_thread=False)
def begin(self):
q = self.query
retry_if_locked(q, "BEGIN IMMEDIATE")
return q
if LOG_QUERIES:
def commit(self): def commit(self):
logging.debug('committing...') logging.debug('committing...')
retry_if_locked(self.conn.commit) retry_if_locked(self.conn.commit)
def rollback(self): if LOG_QUERIES:
logging.debug('aborting...')
self.conn.rollback()
def query(self, query): def query(self, query):
printable_char_list = [] printable_char_list = []
for c in query.split('\n', 1)[0][:70]: for c in query.split('\n', 1)[0][:70]:
...@@ -102,10 +92,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -102,10 +92,7 @@ class SQLiteDatabaseManager(DatabaseManager):
logging.debug('querying %s...', ''.join(printable_char_list)) logging.debug('querying %s...', ''.join(printable_char_list))
return self.conn.execute(query) return self.conn.execute(query)
else: else:
rollback = property(lambda self: self.conn.rollback)
query = property(lambda self: self.conn.execute) query = property(lambda self: self.conn.execute)
def commit(self):
retry_if_locked(self.conn.commit)
def setup(self, reset = 0): def setup(self, reset = 0):
self._config.clear() self._config.clear()
...@@ -226,7 +213,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -226,7 +213,7 @@ class SQLiteDatabaseManager(DatabaseManager):
def _getLastIDs(self, all=True): def _getLastIDs(self, all=True):
p64 = util.p64 p64 = util.p64
with self as q: q = self.query
trans = dict((partition, p64(tid)) trans = dict((partition, p64(tid))
for partition, tid in q("SELECT partition, MAX(tid)" for partition, tid in q("SELECT partition, MAX(tid)"
" FROM trans GROUP BY partition")) " FROM trans GROUP BY partition"))
...@@ -248,22 +235,17 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -248,22 +235,17 @@ class SQLiteDatabaseManager(DatabaseManager):
def getUnfinishedTIDList(self): def getUnfinishedTIDList(self):
p64 = util.p64 p64 = util.p64
tid_set = set() return [p64(t[0]) for t in self.query("SELECT tid FROM ttrans"
with self as q: " UNION SELECT tid FROM tobj")]
tid_set.update((p64(t[0]) for t in q("SELECT tid FROM ttrans")))
tid_set.update((p64(t[0]) for t in q("SELECT tid FROM tobj")))
return list(tid_set)
def objectPresent(self, oid, tid, all=True): def objectPresent(self, oid, tid, all=True):
oid = util.u64(oid) oid = util.u64(oid)
tid = util.u64(tid) tid = util.u64(tid)
with self as q: q = self.query
r = q("SELECT 1 FROM obj WHERE partition=? AND oid=? AND tid=?", return q("SELECT 1 FROM obj WHERE partition=? AND oid=? AND tid=?",
(self._getPartition(oid), oid, tid)).fetchone() (self._getPartition(oid), oid, tid)).fetchone() or all and \
if not r and all: q("SELECT 1 FROM tobj WHERE tid=? AND oid=?",
r = q("SELECT 1 FROM tobj WHERE tid=? AND oid=?",
(tid, oid)).fetchone() (tid, oid)).fetchone()
return bool(r)
def _getObject(self, oid, tid=None, before_tid=None): def _getObject(self, oid, tid=None, before_tid=None):
q = self.query q = self.query
...@@ -292,7 +274,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -292,7 +274,7 @@ class SQLiteDatabaseManager(DatabaseManager):
return serial, r and r[0], compression, checksum, data, value_serial return serial, r and r[0], compression, checksum, data, value_serial
def doSetPartitionTable(self, ptid, cell_list, reset): def doSetPartitionTable(self, ptid, cell_list, reset):
with self as q: q = self.query
if reset: if reset:
q("DELETE FROM pt") q("DELETE FROM pt")
for offset, uuid, state in cell_list: for offset, uuid, state in cell_list:
...@@ -316,17 +298,17 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -316,17 +298,17 @@ class SQLiteDatabaseManager(DatabaseManager):
def dropPartitions(self, offset_list): def dropPartitions(self, offset_list):
where = " WHERE partition=?" where = " WHERE partition=?"
with self as q: q = self.query
for partition in offset_list: for partition in offset_list:
args = partition,
data_id_list = [x for x, in data_id_list = [x for x, in
q("SELECT DISTINCT data_id FROM obj" + where, q("SELECT DISTINCT data_id FROM obj" + where, args) if x]
(partition,)) if x] q("DELETE FROM obj" + where, args)
q("DELETE FROM obj" + where, (partition,)) q("DELETE FROM trans" + where, args)
q("DELETE FROM trans" + where, (partition,))
self._pruneData(data_id_list) self._pruneData(data_id_list)
def dropUnfinishedData(self): def dropUnfinishedData(self):
with self as q: q = self.query
data_id_list = [x for x, in q("SELECT data_id FROM tobj") if x] data_id_list = [x for x, in q("SELECT data_id FROM tobj") if x]
q("DELETE FROM tobj") q("DELETE FROM tobj")
q("DELETE FROM ttrans") q("DELETE FROM ttrans")
...@@ -337,7 +319,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -337,7 +319,7 @@ class SQLiteDatabaseManager(DatabaseManager):
tid = u64(tid) tid = u64(tid)
T = 't' if temporary else '' T = 't' if temporary else ''
obj_sql = "INSERT OR FAIL INTO %sobj VALUES (?,?,?,?,?)" % T obj_sql = "INSERT OR FAIL INTO %sobj VALUES (?,?,?,?,?)" % T
with self as q: q = self.query
for oid, data_id, value_serial in object_list: for oid, data_id, value_serial in object_list:
oid = u64(oid) oid = u64(oid)
partition = self._getPartition(oid) partition = self._getPartition(oid)
...@@ -360,7 +342,6 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -360,7 +342,6 @@ class SQLiteDatabaseManager(DatabaseManager):
if r == (data_id, value_serial): if r == (data_id, value_serial):
continue continue
raise raise
if transaction: if transaction:
oid_list, user, desc, ext, packed, ttid = transaction oid_list, user, desc, ext, packed, ttid = transaction
partition = self._getPartition(tid) partition = self._getPartition(tid)
...@@ -368,6 +349,8 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -368,6 +349,8 @@ class SQLiteDatabaseManager(DatabaseManager):
q("INSERT OR FAIL INTO %strans VALUES (?,?,?,?,?,?,?,?)" % T, q("INSERT OR FAIL INTO %strans VALUES (?,?,?,?,?,?,?,?)" % T,
(partition, tid, packed, buffer(''.join(oid_list)), (partition, tid, packed, buffer(''.join(oid_list)),
buffer(user), buffer(desc), buffer(ext), u64(ttid))) buffer(user), buffer(desc), buffer(ext), u64(ttid)))
if temporary:
self.commit()
def _pruneData(self, data_id_list): def _pruneData(self, data_id_list):
data_id_list = set(data_id_list).difference(self._uncommitted_data) data_id_list = set(data_id_list).difference(self._uncommitted_data)
...@@ -381,13 +364,12 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -381,13 +364,12 @@ class SQLiteDatabaseManager(DatabaseManager):
def _storeData(self, checksum, data, compression): def _storeData(self, checksum, data, compression):
H = buffer(checksum) H = buffer(checksum)
with self as q:
try: try:
return q("INSERT INTO data VALUES (NULL,?,?,?)", return self.query("INSERT INTO data VALUES (NULL,?,?,?)",
(H, compression, buffer(data))).lastrowid (H, compression, buffer(data))).lastrowid
except sqlite3.IntegrityError, e: except sqlite3.IntegrityError, e:
if e.args[0] == 'column hash is not unique': if e.args[0] == 'column hash is not unique':
(r, c, d), = q("SELECT id, compression, value" (r, c, d), = self.query("SELECT id, compression, value"
" FROM data WHERE hash=?", (H,)) " FROM data WHERE hash=?", (H,))
if c == compression and str(d) == data: if c == compression and str(d) == data:
return r return r
...@@ -415,21 +397,21 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -415,21 +397,21 @@ class SQLiteDatabaseManager(DatabaseManager):
def finishTransaction(self, tid): def finishTransaction(self, tid):
args = util.u64(tid), args = util.u64(tid),
with self as q: q = self.query
sql = " FROM tobj WHERE tid=?" sql = " FROM tobj WHERE tid=?"
data_id_list = [x for x, in q("SELECT data_id" + sql, args) if x] data_id_list = [x for x, in q("SELECT data_id" + sql, args) if x]
q("INSERT OR FAIL INTO obj SELECT *" + sql, args) q("INSERT OR FAIL INTO obj SELECT *" + sql, args)
q("DELETE FROM tobj WHERE tid=?", args) q("DELETE FROM tobj WHERE tid=?", args)
q("INSERT OR FAIL INTO trans SELECT * FROM ttrans WHERE tid=?", q("INSERT OR FAIL INTO trans SELECT * FROM ttrans WHERE tid=?", args)
args)
q("DELETE FROM ttrans WHERE tid=?", args) q("DELETE FROM ttrans WHERE tid=?", args)
self.unlockData(data_id_list) self.unlockData(data_id_list)
self.commit()
def deleteTransaction(self, tid, oid_list=()): def deleteTransaction(self, tid, oid_list=()):
u64 = util.u64 u64 = util.u64
tid = u64(tid) tid = u64(tid)
getPartition = self._getPartition getPartition = self._getPartition
with self as q: q = self.query
sql = " FROM tobj WHERE tid=?" sql = " FROM tobj WHERE tid=?"
data_id_list = [x for x, in q("SELECT data_id" + sql, (tid,)) if x] data_id_list = [x for x, in q("SELECT data_id" + sql, (tid,)) if x]
self.unlockData(data_id_list) self.unlockData(data_id_list)
...@@ -455,7 +437,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -455,7 +437,7 @@ class SQLiteDatabaseManager(DatabaseManager):
if serial: if serial:
sql += " AND tid=?" sql += " AND tid=?"
args.append(util.u64(serial)) args.append(util.u64(serial))
with self as q: q = self.query
data_id_list = [x for x, in q("SELECT DISTINCT data_id" + sql, args) data_id_list = [x for x, in q("SELECT DISTINCT data_id" + sql, args)
if x] if x]
q("DELETE" + sql, args) q("DELETE" + sql, args)
...@@ -480,7 +462,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -480,7 +462,7 @@ class SQLiteDatabaseManager(DatabaseManager):
def getTransaction(self, tid, all=False): def getTransaction(self, tid, all=False):
tid = util.u64(tid) tid = util.u64(tid)
with self as q: q = self.query
r = q("SELECT oids, user, description, ext, packed, ttid" r = q("SELECT oids, user, description, ext, packed, ttid"
" FROM trans WHERE partition=? AND tid=?", " FROM trans WHERE partition=? AND tid=?",
(self._getPartition(tid), tid)).fetchone() (self._getPartition(tid), tid)).fetchone()
...@@ -515,8 +497,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -515,8 +497,7 @@ class SQLiteDatabaseManager(DatabaseManager):
pack_tid = self._getPackTID() pack_tid = self._getPackTID()
result = [] result = []
append = result.append append = result.append
with self as q: for serial, length, value_serial in self.query("""\
for serial, length, value_serial in q("""\
SELECT tid, LENGTH(value), value_tid SELECT tid, LENGTH(value), value_tid
FROM obj LEFT JOIN data ON obj.data_id = data.id FROM obj LEFT JOIN data ON obj.data_id = data.id
WHERE partition=? AND oid=? AND tid>=? WHERE partition=? AND oid=? AND tid>=?
...@@ -587,10 +568,10 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -587,10 +568,10 @@ class SQLiteDatabaseManager(DatabaseManager):
tid = util.u64(tid) tid = util.u64(tid)
updatePackFuture = self._updatePackFuture updatePackFuture = self._updatePackFuture
getPartition = self._getPartition getPartition = self._getPartition
with self as q: q = self.query
self._setPackTID(tid) self._setPackTID(tid)
for count, oid, max_serial in q("SELECT COUNT(*) - 1, oid," for count, oid, max_serial in q("SELECT COUNT(*) - 1, oid, MAX(tid)"
" MAX(tid) FROM obj WHERE tid<=? GROUP BY oid", " FROM obj WHERE tid<=? GROUP BY oid",
(tid,)): (tid,)):
partition = getPartition(oid) partition = getPartition(oid)
if q("SELECT 1 FROM obj WHERE partition=?" if q("SELECT 1 FROM obj WHERE partition=?"
...@@ -613,6 +594,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -613,6 +594,7 @@ class SQLiteDatabaseManager(DatabaseManager):
q("DELETE" + sql, args) q("DELETE" + sql, args)
data_id_set.discard(None) data_id_set.discard(None)
self._pruneData(data_id_set) self._pruneData(data_id_set)
self.commit()
def checkTIDRange(self, partition, length, min_tid, max_tid): def checkTIDRange(self, partition, length, min_tid, max_tid):
count, tids, max_tid = self.query("""\ count, tids, max_tid = self.query("""\
......
...@@ -98,6 +98,7 @@ class StorageOperationHandler(EventHandler): ...@@ -98,6 +98,7 @@ class StorageOperationHandler(EventHandler):
for serial, oid_list in object_dict.iteritems(): for serial, oid_list in object_dict.iteritems():
for oid in oid_list: for oid in oid_list:
deleteObject(oid, serial) deleteObject(oid, serial)
self.app.dm.commit()
assert not pack_tid, "TODO" assert not pack_tid, "TODO"
if next_tid: if next_tid:
self.app.replicator.fetchObjects(next_tid, next_oid) self.app.replicator.fetchObjects(next_tid, next_oid)
......
...@@ -20,7 +20,6 @@ from mock import Mock ...@@ -20,7 +20,6 @@ from mock import Mock
from neo.lib.util import add64, dump, p64, u64 from neo.lib.util import add64, dump, p64, u64
from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID, MAX_TID from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID, MAX_TID
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.exception import DatabaseFailure
class StorageDBTests(NeoUnitTestBase): class StorageDBTests(NeoUnitTestBase):
...@@ -93,29 +92,6 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -93,29 +92,6 @@ class StorageDBTests(NeoUnitTestBase):
db = self.getDB() db = self.getDB()
self.checkConfigEntry(db.getPTID, db.setPTID, self.getPTID(1)) self.checkConfigEntry(db.getPTID, db.setPTID, self.getPTID(1))
def test_transaction(self):
db = self.getDB()
x = []
class DB(db.__class__):
begin = lambda self: x.append('begin')
commit = lambda self: x.append('commit')
rollback = lambda self: x.append('rollback')
db.__class__ = DB
with db:
self.assertEqual(x.pop(), 'begin')
self.assertEqual(x.pop(), 'commit')
try:
with db:
self.assertEqual(x.pop(), 'begin')
with db:
self.fail()
self.fail()
except DatabaseFailure:
pass
self.assertEqual(x.pop(), 'rollback')
self.assertRaises(DatabaseFailure, db.__exit__, None, None, None)
self.assertFalse(x)
def test_getPartitionTable(self): def test_getPartitionTable(self):
db = self.getDB() db = self.getDB()
ptid = self.getPTID(1) ptid = self.getPTID(1)
......
...@@ -300,8 +300,8 @@ class StorageApplication(ServerNode, neo.storage.app.Application): ...@@ -300,8 +300,8 @@ class StorageApplication(ServerNode, neo.storage.app.Application):
pass pass
def switchTables(self): def switchTables(self):
with self.dm as q: q = self.dm.query
for table in ('trans', 'obj'): for table in 'trans', 'obj':
q('ALTER TABLE %s RENAME TO tmp' % table) q('ALTER TABLE %s RENAME TO tmp' % table)
q('ALTER TABLE t%s RENAME TO %s' % (table, table)) q('ALTER TABLE t%s RENAME TO %s' % (table, table))
q('ALTER TABLE tmp RENAME TO t%s' % table) q('ALTER TABLE tmp RENAME TO t%s' % table)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment