Commit 005cb829 authored by Julien Muchembled's avatar Julien Muchembled

Guess last oid from 'obj' tables instead of saving if in 'config'

This fixes a bug in backup mode where 'loid' was not updated.
parent 1aab7d19
...@@ -24,6 +24,7 @@ SQL commands to migrate each storage from NEO 0.10.x:: ...@@ -24,6 +24,7 @@ SQL commands to migrate each storage from NEO 0.10.x::
UPDATE pt, uuid SET pt.uuid=uuid.new, state=state-1 WHERE pt.old=uuid.old; UPDATE pt, uuid SET pt.uuid=uuid.new, state=state-1 WHERE pt.old=uuid.old;
ALTER TABLE pt DROP old, ADD PRIMARY KEY (rid, uuid); ALTER TABLE pt DROP old, ADD PRIMARY KEY (rid, uuid);
UPDATE config, uuid SET config.value=uuid.new WHERE config.name='uuid' AND uuid.old=config.value; UPDATE config, uuid SET config.value=uuid.new WHERE config.name='uuid' AND uuid.old=config.value;
DELETE FROM config WHERE name='loid';
NEO 0.10 NEO 0.10
======== ========
......
...@@ -26,7 +26,7 @@ except ImportError: ...@@ -26,7 +26,7 @@ except ImportError:
pass pass
# The protocol version (major, minor). # The protocol version (major, minor).
PROTOCOL_VERSION = (10, 1) PROTOCOL_VERSION = (11, 1)
# Size restrictions. # Size restrictions.
MIN_PACKET_SIZE = 10 MIN_PACKET_SIZE = 10
...@@ -1192,14 +1192,6 @@ class ClusterState(Packet): ...@@ -1192,14 +1192,6 @@ class ClusterState(Packet):
PEnum('state', ClusterStates), PEnum('state', ClusterStates),
) )
class NotifyLastOID(Packet):
"""
Notify last OID generated
"""
_fmt = PStruct('notify_last_oid',
POID('last_oid'),
)
class ObjectUndoSerial(Packet): class ObjectUndoSerial(Packet):
""" """
Ask storage the serial where object data is when undoing given transaction, Ask storage the serial where object data is when undoing given transaction,
...@@ -1682,8 +1674,6 @@ class Packets(dict): ...@@ -1682,8 +1674,6 @@ class Packets(dict):
ClusterInformation) ClusterInformation)
AskClusterState, AnswerClusterState = register( AskClusterState, AnswerClusterState = register(
ClusterState) ClusterState)
NotifyLastOID = register(
NotifyLastOID)
AskObjectUndoSerial, AnswerObjectUndoSerial = register( AskObjectUndoSerial, AnswerObjectUndoSerial = register(
ObjectUndoSerial) ObjectUndoSerial)
AskHasLock, AnswerHasLock = register( AskHasLock, AnswerHasLock = register(
......
...@@ -267,13 +267,6 @@ class Application(object): ...@@ -267,13 +267,6 @@ class Application(object):
if selector(node): if selector(node):
node.notify(packet) node.notify(packet)
def broadcastLastOID(self):
oid = self.tm.getLastOID()
logging.debug('Broadcast last OID to storages : %s', dump(oid))
packet = Packets.NotifyLastOID(oid)
for node in self.nm.getStorageList(only_identified=True):
node.notify(packet)
def provideService(self): def provideService(self):
""" """
This is the normal mode for a primary master node. Handle transactions This is the normal mode for a primary master node. Handle transactions
......
...@@ -55,9 +55,7 @@ class ClientServiceHandler(MasterHandler): ...@@ -55,9 +55,7 @@ class ClientServiceHandler(MasterHandler):
conn.answer(Packets.AnswerBeginTransaction(app.tm.begin(node, tid))) conn.answer(Packets.AnswerBeginTransaction(app.tm.begin(node, tid)))
def askNewOIDs(self, conn, num_oids): def askNewOIDs(self, conn, num_oids):
app = self.app conn.answer(Packets.AnswerNewOIDs(self.app.tm.getNextOIDList(num_oids)))
conn.answer(Packets.AnswerNewOIDs(app.tm.getNextOIDList(num_oids)))
app.broadcastLastOID()
def askFinishTransaction(self, conn, ttid, oid_list): def askFinishTransaction(self, conn, ttid, oid_list):
app = self.app app = self.app
...@@ -87,10 +85,6 @@ class ClientServiceHandler(MasterHandler): ...@@ -87,10 +85,6 @@ class ClientServiceHandler(MasterHandler):
tid = app.tm.prepare(ttid, partitions, oid_list, usable_uuid_set, tid = app.tm.prepare(ttid, partitions, oid_list, usable_uuid_set,
peer_id) peer_id)
# check if greater and foreign OID was stored
if app.tm.updateLastOID(oid_list):
app.broadcastLastOID()
# Request locking data. # Request locking data.
# build a new set as we may not send the message to all nodes as some # build a new set as we may not send the message to all nodes as some
# might be not reachable at that time # might be not reachable at that time
......
...@@ -257,17 +257,6 @@ class TransactionManager(object): ...@@ -257,17 +257,6 @@ class TransactionManager(object):
self._last_oid = oid_list[-1] self._last_oid = oid_list[-1]
return oid_list return oid_list
def updateLastOID(self, oid_list):
"""
Updates the last oid with the max of those supplied if greater than
the current known, returns True if changed
"""
max_oid = oid_list and max(oid_list) or None # oid_list might be empty
if max_oid > self._last_oid:
self._last_oid = max_oid
return True
return False
def setLastOID(self, oid): def setLastOID(self, oid):
self._last_oid = max(oid, self._last_oid) self._last_oid = max(oid, self._last_oid)
...@@ -391,6 +380,9 @@ class TransactionManager(object): ...@@ -391,6 +380,9 @@ class TransactionManager(object):
logging.debug('Finish TXN %s for %s (was %s)', logging.debug('Finish TXN %s for %s (was %s)',
dump(tid), node, dump(ttid)) dump(tid), node, dump(ttid))
txn.prepare(tid, oid_list, uuid_list, msg_id) txn.prepare(tid, oid_list, uuid_list, msg_id)
# check if greater and foreign OID was stored
if oid_list:
self.setLastOID(max(oid_list))
return tid return tid
def remove(self, uuid, ttid): def remove(self, uuid, ttid):
......
...@@ -190,12 +190,6 @@ class VerificationManager(BaseServiceHandler): ...@@ -190,12 +190,6 @@ class VerificationManager(BaseServiceHandler):
return uuid_set return uuid_set
def answerLastIDs(self, conn, loid, ltid, lptid, backup_tid):
# FIXME: this packet should not allowed here, the master already
# accepted the current partition table end IDs. As there were manually
# approved during recovery, there is no need to check them here.
raise RuntimeError
def answerUnfinishedTransactions(self, conn, max_tid, tid_list): def answerUnfinishedTransactions(self, conn, max_tid, tid_list):
uuid = conn.getUUID() uuid = conn.getUUID()
logging.info('got unfinished transactions %s from %r', logging.info('got unfinished transactions %s from %r',
......
...@@ -171,18 +171,6 @@ class DatabaseManager(object): ...@@ -171,18 +171,6 @@ class DatabaseManager(object):
ptid = str(ptid) ptid = str(ptid)
self.setConfiguration('ptid', ptid) self.setConfiguration('ptid', ptid)
def getLastOID(self):
"""
Returns the last OID used
"""
return util.bin(self.getConfiguration('loid'))
def setLastOID(self, loid):
"""
Set the last OID used
"""
self.setConfiguration('loid', util.dump(loid))
def getBackupTID(self): def getBackupTID(self):
return util.bin(self.getConfiguration('backup_tid')) return util.bin(self.getConfiguration('backup_tid'))
...@@ -195,18 +183,18 @@ class DatabaseManager(object): ...@@ -195,18 +183,18 @@ class DatabaseManager(object):
node, and a cell state.""" node, and a cell state."""
raise NotImplementedError raise NotImplementedError
def _getLastTIDs(self, all=True): def _getLastIDs(self, all=True):
raise NotImplementedError raise NotImplementedError
def getLastTIDs(self, all=True): def getLastIDs(self, all=True):
trans, obj = self._getLastTIDs() trans, obj, oid = self._getLastIDs()
if trans: if trans:
tid = max(trans.itervalues()) tid = max(trans.itervalues())
if obj: if obj:
tid = max(tid, max(obj.itervalues())) tid = max(tid, max(obj.itervalues()))
else: else:
tid = max(obj.itervalues()) if obj else None tid = max(obj.itervalues()) if obj else None
return tid, trans, obj return tid, trans, obj, oid
def getUnfinishedTIDList(self): def getUnfinishedTIDList(self):
"""Return a list of unfinished transaction's IDs.""" """Return a list of unfinished transaction's IDs."""
......
...@@ -269,7 +269,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -269,7 +269,7 @@ class MySQLDatabaseManager(DatabaseManager):
def getPartitionTable(self): def getPartitionTable(self):
return self.query("SELECT * FROM pt") return self.query("SELECT * FROM pt")
def _getLastTIDs(self, all=True): def _getLastIDs(self, all=True):
p64 = util.p64 p64 = util.p64
with self as q: with self as q:
trans = dict((partition, p64(tid)) trans = dict((partition, p64(tid))
...@@ -278,14 +278,18 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -278,14 +278,18 @@ class MySQLDatabaseManager(DatabaseManager):
obj = dict((partition, p64(tid)) obj = dict((partition, p64(tid))
for partition, tid in q("SELECT partition, MAX(tid)" for partition, tid in q("SELECT partition, MAX(tid)"
" FROM obj GROUP BY partition")) " FROM obj GROUP BY partition"))
oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj"
" GROUP BY partition) as t")[0][0]
if all: if all:
tid = q("SELECT MAX(tid) FROM ttrans")[0][0] tid = q("SELECT MAX(tid) FROM ttrans")[0][0]
if tid is not None: if tid is not None:
trans[None] = p64(tid) trans[None] = p64(tid)
tid = q("SELECT MAX(tid) FROM tobj")[0][0] tid, toid = q("SELECT MAX(tid), MAX(oid) FROM tobj")[0]
if tid is not None: if tid is not None:
obj[None] = p64(tid) obj[None] = p64(tid)
return trans, obj if toid is not None and (oid < toid or oid is None):
oid = toid
return trans, obj, None if oid is None else p64(oid)
def getUnfinishedTIDList(self): def getUnfinishedTIDList(self):
tid_set = set() tid_set = set()
......
...@@ -224,7 +224,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -224,7 +224,7 @@ class SQLiteDatabaseManager(DatabaseManager):
def getPartitionTable(self): def getPartitionTable(self):
return self.query("SELECT * FROM pt") return self.query("SELECT * FROM pt")
def _getLastTIDs(self, all=True): def _getLastIDs(self, all=True):
p64 = util.p64 p64 = util.p64
with self as q: with self as q:
trans = dict((partition, p64(tid)) trans = dict((partition, p64(tid))
...@@ -233,14 +233,18 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -233,14 +233,18 @@ class SQLiteDatabaseManager(DatabaseManager):
obj = dict((partition, p64(tid)) obj = dict((partition, p64(tid))
for partition, tid in q("SELECT partition, MAX(tid)" for partition, tid in q("SELECT partition, MAX(tid)"
" FROM obj GROUP BY partition")) " FROM obj GROUP BY partition"))
oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj"
" GROUP BY partition) as t").next()[0]
if all: if all:
tid = q("SELECT MAX(tid) FROM ttrans").fetchone()[0] tid = q("SELECT MAX(tid) FROM ttrans").next()[0]
if tid is not None: if tid is not None:
trans[None] = p64(tid) trans[None] = p64(tid)
tid = q("SELECT MAX(tid) FROM tobj").fetchone()[0] tid, toid = q("SELECT MAX(tid), MAX(oid) FROM tobj").next()
if tid is not None: if tid is not None:
obj[None] = p64(tid) obj[None] = p64(tid)
return trans, obj if toid is not None and (oid < toid or oid is None):
oid = toid
return trans, obj, None if oid is None else p64(oid)
def getUnfinishedTIDList(self): def getUnfinishedTIDList(self):
p64 = util.p64 p64 = util.p64
......
...@@ -35,9 +35,6 @@ class BaseMasterHandler(EventHandler): ...@@ -35,9 +35,6 @@ class BaseMasterHandler(EventHandler):
def notifyClusterInformation(self, conn, state): def notifyClusterInformation(self, conn, state):
self.app.changeClusterState(state) self.app.changeClusterState(state)
def notifyLastOID(self, conn, oid):
self.app.dm.setLastOID(oid)
def notifyNodeInformation(self, conn, node_list): def notifyNodeInformation(self, conn, node_list):
"""Store information on nodes, only if this is sent by a primary """Store information on nodes, only if this is sent by a primary
master node.""" master node."""
......
...@@ -47,7 +47,6 @@ class InitializationHandler(BaseMasterHandler): ...@@ -47,7 +47,6 @@ class InitializationHandler(BaseMasterHandler):
app.dm.setPartitionTable(ptid, cell_list) app.dm.setPartitionTable(ptid, cell_list)
def answerLastIDs(self, conn, loid, ltid, lptid, backup_tid): def answerLastIDs(self, conn, loid, ltid, lptid, backup_tid):
self.app.dm.setLastOID(loid)
self.app.dm.setBackupTID(backup_tid) self.app.dm.setBackupTID(backup_tid)
def notifyPartitionChanges(self, conn, ptid, cell_list): def notifyPartitionChanges(self, conn, ptid, cell_list):
......
...@@ -25,9 +25,10 @@ class VerificationHandler(BaseMasterHandler): ...@@ -25,9 +25,10 @@ class VerificationHandler(BaseMasterHandler):
def askLastIDs(self, conn): def askLastIDs(self, conn):
app = self.app app = self.app
ltid, _, _, loid = app.dm.getLastIDs()
conn.answer(Packets.AnswerLastIDs( conn.answer(Packets.AnswerLastIDs(
app.dm.getLastOID(), loid,
app.dm.getLastTIDs()[0], ltid,
app.pt.getID(), app.pt.getID(),
app.dm.getBackupTID())) app.dm.getBackupTID()))
......
...@@ -125,7 +125,7 @@ class Replicator(object): ...@@ -125,7 +125,7 @@ class Replicator(object):
self.replicate_dict = {} self.replicate_dict = {}
self.source_dict = {} self.source_dict = {}
self.ttid_set = set() self.ttid_set = set()
last_tid, last_trans_dict, last_obj_dict = app.dm.getLastTIDs() last_tid, last_trans_dict, last_obj_dict, _ = app.dm.getLastIDs()
backup_tid = app.dm.getBackupTID() backup_tid = app.dm.getBackupTID()
if backup_tid and last_tid < backup_tid: if backup_tid and last_tid < backup_tid:
last_tid = backup_tid last_tid = backup_tid
...@@ -154,7 +154,7 @@ class Replicator(object): ...@@ -154,7 +154,7 @@ class Replicator(object):
abort = False abort = False
added_list = [] added_list = []
app = self.app app = self.app
last_tid, last_trans_dict, last_obj_dict = app.dm.getLastTIDs() last_tid, last_trans_dict, last_obj_dict, _ = app.dm.getLastIDs()
for offset, uuid, state in cell_list: for offset, uuid, state in cell_list:
if uuid == app.uuid: if uuid == app.uuid:
if state in (CellStates.DISCARDED, CellStates.CORRUPTED): if state in (CellStates.DISCARDED, CellStates.CORRUPTED):
......
...@@ -102,9 +102,6 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -102,9 +102,6 @@ class MasterClientHandlerTests(NeoUnitTestBase):
node.setConnection(conn) node.setConnection(conn)
service.askNewOIDs(conn, 1) service.askNewOIDs(conn, 1)
self.assertTrue(self.app.tm.getLastOID() > oid1) self.assertTrue(self.app.tm.getLastOID() > oid1)
for node in self.app.nm.getStorageList():
conn = node.getConnection()
self.assertEqual(self.checkNotifyLastOID(conn, decode=True), (oid2,))
def test_09_askFinishTransaction(self): def test_09_askFinishTransaction(self):
service = self.service service = self.service
......
...@@ -126,13 +126,6 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -126,13 +126,6 @@ class StorageDBTests(NeoUnitTestBase):
result = db.getPartitionTable() result = db.getPartitionTable()
self.assertEqual(set(result), set([cell1, cell2])) self.assertEqual(set(result), set([cell1, cell2]))
def test_getLastOID(self):
db = self.getDB()
oid1 = self.getOID(1)
db.setLastOID(oid1)
result1 = db.getLastOID()
self.assertEqual(result1, oid1)
def getOIDs(self, count): def getOIDs(self, count):
return map(self.getOID, xrange(count)) return map(self.getOID, xrange(count))
...@@ -153,22 +146,23 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -153,22 +146,23 @@ class StorageDBTests(NeoUnitTestBase):
def checkSet(self, list1, list2): def checkSet(self, list1, list2):
self.assertEqual(set(list1), set(list2)) self.assertEqual(set(list1), set(list2))
def test_getLastTIDs(self): def test_getLastIDs(self):
tid1, tid2, tid3, tid4 = self.getTIDs(4) tid1, tid2, tid3, tid4 = self.getTIDs(4)
oid1, oid2 = self.getOIDs(2) oid1, oid2 = self.getOIDs(2)
txn, objs = self.getTransaction([oid1, oid2]) txn, objs = self.getTransaction([oid1, oid2])
self.db.storeTransaction(tid1, objs, txn, False) self.db.storeTransaction(tid1, objs, txn, False)
self.db.storeTransaction(tid2, objs, txn, False) self.db.storeTransaction(tid2, objs, txn, False)
self.assertEqual(self.db.getLastTIDs(), (tid2, {0: tid2}, {0: tid2})) self.assertEqual(self.db.getLastIDs(),
(tid2, {0: tid2}, {0: tid2}, oid2))
self.db.storeTransaction(tid3, objs, txn) self.db.storeTransaction(tid3, objs, txn)
tids = {0: tid2, None: tid3} tids = {0: tid2, None: tid3}
self.assertEqual(self.db.getLastTIDs(), (tid3, tids, tids)) self.assertEqual(self.db.getLastIDs(), (tid3, tids, tids, oid2))
self.db.storeTransaction(tid4, objs, None) self.db.storeTransaction(tid4, objs, None)
self.assertEqual(self.db.getLastTIDs(), self.assertEqual(self.db.getLastIDs(),
(tid4, tids, {0: tid2, None: tid4})) (tid4, tids, {0: tid2, None: tid4}, oid2))
self.db.finishTransaction(tid3) self.db.finishTransaction(tid3)
self.assertEqual(self.db.getLastTIDs(), self.assertEqual(self.db.getLastIDs(),
(tid4, {0: tid3}, {0: tid3, None: tid4})) (tid4, {0: tid3}, {0: tid3, None: tid4}, oid2))
def test_getUnfinishedTIDList(self): def test_getUnfinishedTIDList(self):
tid1, tid2, tid3, tid4 = self.getTIDs(4) tid1, tid2, tid3, tid4 = self.getTIDs(4)
...@@ -320,7 +314,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -320,7 +314,7 @@ class StorageDBTests(NeoUnitTestBase):
txn1, objs1 = self.getTransaction([oid1]) txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2]) txn2, objs2 = self.getTransaction([oid2])
# nothing in database # nothing in database
self.assertEqual(self.db.getLastTIDs(), (None, {}, {})) self.assertEqual(self.db.getLastIDs(), (None, {}, {}, None))
self.assertEqual(self.db.getUnfinishedTIDList(), []) self.assertEqual(self.db.getUnfinishedTIDList(), [])
self.assertEqual(self.db.getObject(oid1), None) self.assertEqual(self.db.getObject(oid1), None)
self.assertEqual(self.db.getObject(oid2), None) self.assertEqual(self.db.getObject(oid2), None)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment