Commit 7ee7ff4e authored by Julien Muchembled

master: never abort a prepared transaction

This fixes the following crash (for example when a client disconnects during
tpc_finish):

Traceback (most recent call last):
  ...
  File "neo/master/handlers/storage.py", line 68, in answerInformationLocked
    self.app.tm.lock(ttid, conn.getUUID())
  File "neo/master/transactions.py", line 338, in lock
    if self._ttid_dict[ttid].lock(uuid) and self._queue[0][1] == ttid:
IndexError: list index out of range
parent 7aecdada
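
The gist of the fix, readable in the diff below: the old TransactionManager.abortFor() dequeued every pending entry of the disconnecting client, including transactions that were already prepared, while leaving those prepared transactions in _ttid_dict; when a storage node later sent AnswerInformationLocked, lock() still found the transaction but indexed an empty queue, hence the IndexError above. The following is a minimal, simplified sketch of the new behaviour (plain Python with invented class shapes, not the actual classes in neo/master/transactions.py, which also track TIDs, storage UUIDs, the commit queue and callbacks): a prepared transaction whose client is lost becomes an orphan and is still committed; only transactions that never reached prepare are dropped.

    class Transaction(object):
        """Toy stand-in for the real master-side Transaction."""

        def __init__(self, node):
            self._node = node
            self._prepared = False

        def prepare(self):
            self._prepared = True

        def clientLost(self, node):
            # Same rule as the new Transaction.clientLost() in the diff:
            # keep prepared transactions (as orphans), drop only the rest.
            if self._node is node:
                if self._prepared:
                    self._node = None   # orphan: the 2PC still completes
                else:
                    return True         # caller may remove it
            return False


    class TransactionManager(object):
        """Toy stand-in holding a ttid -> Transaction mapping."""

        def __init__(self):
            self._ttid_dict = {}

        def clientLost(self, node):
            for ttid, txn in list(self._ttid_dict.items()):
                if txn.clientLost(node):
                    del self._ttid_dict[ttid]
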
@@ -540,8 +540,6 @@ class Application(BaseApplication):
         if node is not None and node.isConnected():
             node.getConnection().notify(notify_finished)
-        # remove transaction from manager
-        self.tm.remove(transaction_node.getUUID(), ttid)
         assert self.last_transaction < tid, (self.last_transaction, tid)
         self.setLastTransaction(tid)
...
@@ -26,7 +26,7 @@ class ClientServiceHandler(MasterHandler):
         if app.listening_conn: # if running
             node = app.nm.getByUUID(conn.getUUID())
             assert node is not None
-            app.tm.abortFor(node)
+            app.tm.clientLost(node)
             node.setState(NodeStates.DOWN)
             app.broadcastNodesInformation([node])
             app.nm.remove(node)
@@ -100,5 +100,5 @@ class ClientServiceHandler(MasterHandler):
             conn.answer(Packets.AnswerPack(False))

     def abortTransaction(self, conn, tid):
-        self.app.tm.remove(conn.getUUID(), tid)
+        self.app.tm.abort(tid, conn.getUUID())
@@ -39,7 +39,7 @@ class StorageServiceHandler(BaseServiceHandler):
         app = self.app
         node = app.nm.getByUUID(conn.getUUID())
         super(StorageServiceHandler, self).connectionLost(conn, new_state)
-        app.tm.forget(conn.getUUID())
+        app.tm.storageLost(conn.getUUID())
         if (app.getClusterState() == ClusterStates.BACKINGUP
             # Also check if we're exiting, because backup_app is not usable
             # in this case. Maybe cluster state should be set to something
@@ -61,10 +61,6 @@ class StorageServiceHandler(BaseServiceHandler):
         conn.answer(p)

     def answerInformationLocked(self, conn, ttid):
-        tm = self.app.tm
-        if ttid not in tm:
-            raise ProtocolError('Unknown transaction')
-        # transaction locked on this storage node
         self.app.tm.lock(ttid, conn.getUUID())

     def notifyPartitionCorrupted(self, conn, partition, cell_list):
...
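
The explicit "ttid not in tm" check removed above is safe to drop because unknown ttids now fail inside the transaction manager itself: the new TransactionManager.__getitem__ (see the transactions.py diff below) raises ProtocolError for an unknown ttid. A minimal sketch of that lookup pattern, with a locally defined ProtocolError standing in for the real protocol exception:

    class ProtocolError(Exception):
        """Stand-in for the real protocol error class."""


    class TransactionManager(object):
        def __init__(self):
            self._ttid_dict = {}

        def __getitem__(self, ttid):
            try:
                return self._ttid_dict[ttid]
            except KeyError:
                # handlers no longer need a separate membership check
                raise ProtocolError("unknown ttid %r" % (ttid,))


    tm = TransactionManager()
    try:
        tm['\x00' * 8]          # unknown ttid
    except ProtocolError as e:
        print(e)
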
@@ -14,6 +14,7 @@
 # You should have received a copy of the GNU General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.

+from collections import deque
 from time import time
 from struct import pack, unpack
 from neo.lib import logging
@@ -121,7 +122,7 @@ class Transaction(object):
         self._lock_wait_uuid_set = set(uuid_list)
         self._prepared = True

-    def forget(self, uuid):
+    def storageLost(self, uuid):
         """
         Given storage was lost while waiting for its lock, stop waiting
         for it.
@@ -136,6 +137,14 @@ class Transaction(object):
             return self.locked()
         return False

+    def clientLost(self, node):
+        if self._node is node:
+            if self._prepared:
+                self._node = None # orphan
+            else:
+                return True # abort
+        return False
+
     def lock(self, uuid):
         """
         Define that a node has locked the transaction
@@ -163,19 +172,26 @@ class TransactionManager(object):
     def reset(self):
         # ttid -> transaction
         self._ttid_dict = {}
-        # node -> transactions mapping
-        self._node_dict = {}
         self._last_oid = ZERO_OID
         self._last_tid = ZERO_TID
         # queue filled with ttids pointing to transactions with increasing tids
-        self._queue = []
+        self._queue = deque()

     def __getitem__(self, ttid):
         """
         Return the transaction object for this TID
         """
-        # XXX: used by unit tests only
-        return self._ttid_dict[ttid]
+        try:
+            return self._ttid_dict[ttid]
+        except KeyError:
+            raise ProtocolError("unknown ttid %s" % dump(ttid))
+
+    def __delitem__(self, ttid):
+        try:
+            self._queue.remove(ttid)
+        except ValueError:
+            pass
+        del self._ttid_dict[ttid]

     def __contains__(self, ttid):
         """
@@ -272,61 +288,44 @@ class TransactionManager(object):
         """
         if tid is None:
             # No TID requested, generate a temporary one
-            ttid = self._nextTID()
+            tid = self._nextTID()
         else:
             # Use of specific TID requested, queue it immediately and update
             # last TID.
-            self._queue.append((node.getUUID(), tid))
+            self._queue.append(tid)
             self.setLastTID(tid)
-            ttid = tid
-        txn = Transaction(node, ttid)
-        self._ttid_dict[ttid] = txn
-        self._node_dict.setdefault(node, {})[ttid] = txn
+        txn = self._ttid_dict[tid] = Transaction(node, tid)
         logging.debug('Begin %s', txn)
-        return ttid
+        return tid

     def prepare(self, ttid, divisor, oid_list, uuid_list, msg_id):
         """
         Prepare a transaction to be finished
         """
-        # XXX: not efficient but the list should be often small
-        try:
-            txn = self._ttid_dict[ttid]
-        except KeyError:
-            raise ProtocolError("unknown ttid %s" % dump(ttid))
-        node = txn.getNode()
-        for _, tid in self._queue:
-            if ttid == tid:
-                break
+        txn = self[ttid]
+        # maybe not the fastest but _queue should be often small
+        if ttid in self._queue:
+            tid = ttid
         else:
             tid = self._nextTID(ttid, divisor)
-            self._queue.append((node.getUUID(), ttid))
+            self._queue.append(ttid)
         logging.debug('Finish TXN %s for %s (was %s)',
-            dump(tid), node, dump(ttid))
+            dump(tid), txn.getNode(), dump(ttid))
         txn.prepare(tid, oid_list, uuid_list, msg_id)
         # check if greater and foreign OID was stored
         if oid_list:
             self.setLastOID(max(oid_list))
         return tid

-    def remove(self, uuid, ttid):
+    def abort(self, ttid, uuid):
         """
-        Remove a transaction, commited or aborted
+        Abort a transaction
         """
-        logging.debug('Remove TXN %s', dump(ttid))
-        try:
-            # only in case of an import:
-            self._queue.remove((uuid, ttid))
-        except ValueError:
-            # finish might not have been started
-            pass
-        ttid_dict = self._ttid_dict
-        if ttid in ttid_dict:
-            txn = ttid_dict[ttid]
-            node = txn.getNode()
-            # ...and tried to finish
-            del ttid_dict[ttid]
-            del self._node_dict[node][ttid]
+        logging.debug('Abort TXN %s for %s', dump(ttid), uuid_str(uuid))
+        if self[ttid].isPrepared():
+            raise ProtocolError("commit already requested for ttid %s"
+                                % dump(ttid))
+        del self[ttid]

     def lock(self, ttid, uuid):
         """
@@ -335,19 +334,20 @@ class TransactionManager(object):
         instanciation time.
         """
         logging.debug('Lock TXN %s for %s', dump(ttid), uuid_str(uuid))
-        if self._ttid_dict[ttid].lock(uuid) and self._queue[0][1] == ttid:
+        if self[ttid].lock(uuid) and self._queue[0] == ttid:
             # all storage are locked and we unlock the commit queue
             self._unlockPending()

-    def forget(self, uuid):
+    def storageLost(self, uuid):
         """
         A storage node has been lost, don't expect a reply from it for
         current transactions
         """
         unlock = False
         for ttid, txn in self._ttid_dict.iteritems():
-            if txn.forget(uuid) and self._queue[0][1] == ttid:
+            if txn.storageLost(uuid) and self._queue[0] == ttid:
                 unlock = True
+                # do not break: we must call forget() on all transactions
         if unlock:
             self._unlockPending()
@@ -359,41 +359,20 @@ class TransactionManager(object):
         is required is when some storages are already busy by other tasks.
         """
         queue = self._queue
-        pop = queue.pop
-        insert = queue.insert
-        on_commit = self._on_commit
-        ttid_dict = self._ttid_dict
+        self._on_commit(self._ttid_dict.pop(queue.popleft()))
         while queue:
-            uuid, ttid = pop(0)
-            txn = ttid_dict[ttid]
-            if txn.locked():
-                on_commit(txn)
-            else:
-                insert(0, (uuid, ttid))
+            ttid = queue[0]
+            txn = self._ttid_dict[ttid]
+            if not txn.locked():
                 break
+            del queue[0], self._ttid_dict[ttid]
+            self._on_commit(txn)

-    def abortFor(self, node):
-        """
-        Abort pending transactions initiated by a node
-        """
-        # BUG: As soon as we have started to lock a transaction,
-        #      we should complete it even if the client is lost.
-        #      Of course, we won't reply to the FinishTransaction
-        #      finish but we'll send invalidations to all clients.
-        logging.debug('Abort TXN for %s', node)
+    def clientLost(self, node):
         uuid = node.getUUID()
-        # XXX: this loop is usefull only during an import
-        for nuuid, ntid in list(self._queue):
-            if nuuid == uuid:
-                self._queue.remove((uuid, ntid))
-        if node in self._node_dict:
-            # remove transactions
-            remove = self.remove
-            for ttid in self._node_dict[node].keys():
-                if not self._ttid_dict[ttid].isPrepared():
-                    remove(uuid, ttid)
-            # discard node entry
-            del self._node_dict[node]
+        for txn in self._ttid_dict.values():
+            if txn.clientLost(node):
+                del self[txn.getTTID()]

     def log(self):
         logging.info('Transactions:')
...
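
One more note on the TransactionManager changes above: the commit queue now holds bare ttids in a deque instead of (uuid, ttid) tuples in a list, which is why lock() compares self._queue[0] == ttid and why __delitem__ silently skips transactions that were never queued. Below is a rough, simplified model of draining such a queue in TID order, loosely patterned on the new _unlockPending() (assumed shapes only: transactions are plain objects with a locked() method, _on_commit is an ordinary callback):

    from collections import deque


    class TinyQueueManager(object):
        """Hypothetical, stripped-down model of the commit queue handling."""

        def __init__(self, on_commit):
            self._on_commit = on_commit
            self._ttid_dict = {}
            self._queue = deque()       # ttids, oldest first

        def _unlockPending(self):
            queue = self._queue
            # the head of the queue is known to be locked when this is called
            self._on_commit(self._ttid_dict.pop(queue.popleft()))
            # then commit every following transaction that is already locked
            while queue:
                ttid = queue[0]
                txn = self._ttid_dict[ttid]
                if not txn.locked():
                    break
                del queue[0], self._ttid_dict[ttid]
                self._on_commit(txn)
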
@@ -130,7 +130,6 @@ class MasterClientHandlerTests(NeoUnitTestBase):
         self.assertFalse(self.app.isStorageReady(storage_uuid))
         service.askFinishTransaction(conn, ttid, (), ())
         self.checkNoPacketSent(storage_conn)
-        self.app.tm.abortFor(self.app.nm.getByUUID(client_uuid))
         # ...but AskLockInformation is sent if it is ready
         self.app.setStorageReady(storage_uuid)
         self.assertTrue(self.app.isStorageReady(storage_uuid))
...
@@ -85,41 +85,15 @@ class testTransactionManager(NeoUnitTestBase):
         self.assertEqual(len(callback.getNamedCalls('__call__')), 0)
         txnman.lock(ttid, uuid2)
         self.assertEqual(len(callback.getNamedCalls('__call__')), 1)
-        # transaction finished
-        txnman.remove(client_uuid, ttid)
         self.assertEqual(txnman.registerForNotification(uuid1), [])

-    def testAbortFor(self):
-        oid_list = [self.makeOID(1), ]
-        storage_1_uuid, node1 = self.makeNode(NodeTypes.STORAGE)
-        storage_2_uuid, node2 = self.makeNode(NodeTypes.STORAGE)
-        client_uuid, client = self.makeNode(NodeTypes.CLIENT)
-        txnman = TransactionManager(lambda tid, txn: None)
-        # register 4 transactions made by two nodes
-        self.assertEqual(txnman.registerForNotification(storage_1_uuid), [])
-        ttid1 = txnman.begin(client)
-        tid1 = txnman.prepare(ttid1, 1, oid_list, [storage_1_uuid], 1)
-        self.assertEqual(txnman.registerForNotification(storage_1_uuid), [ttid1])
-        # abort transactions of another node, transaction stays
-        txnman.abortFor(node2)
-        self.assertEqual(txnman.registerForNotification(storage_1_uuid), [ttid1])
-        # abort transactions of requesting node, transaction is not removed
-        # because the transaction is prepared and must remains until the end of
-        # the 2PC
-        txnman.abortFor(node1)
-        self.assertEqual(txnman.registerForNotification(storage_1_uuid), [ttid1])
-        self.assertTrue(txnman.hasPending())
-        # ...and the lock is available
-        txnman.begin(client, self.getNextTID())
-
-    def test_forget(self):
+    def test_storageLost(self):
         client1 = Mock({'__hash__': 1})
         client2 = Mock({'__hash__': 2})
         client3 = Mock({'__hash__': 3})
         storage_1_uuid = self.getStorageUUID()
         storage_2_uuid = self.getStorageUUID()
         oid_list = [self.makeOID(1), ]
-        client_uuid = self.getClientUUID()
         tm = TransactionManager(lambda tid, txn: None)

         # Transaction 1: 2 storage nodes involved, one will die and the other
@@ -133,9 +107,9 @@ class testTransactionManager(NeoUnitTestBase):
         self.assertFalse(t1.locked())
         # Storage 1 dies:
         # t1 is over
-        self.assertTrue(t1.forget(storage_1_uuid))
+        self.assertTrue(t1.storageLost(storage_1_uuid))
         self.assertEqual(t1.getUUIDList(), [storage_2_uuid])
-        tm.remove(client_uuid, tid1)
+        del tm[ttid1]
         # Transaction 2: 2 storage nodes involved, one will die
         msg_id_2 = 2
@@ -146,10 +120,10 @@ class testTransactionManager(NeoUnitTestBase):
         self.assertFalse(t2.locked())
         # Storage 1 dies:
         # t2 still waits for storage 2
-        self.assertFalse(t2.forget(storage_1_uuid))
+        self.assertFalse(t2.storageLost(storage_1_uuid))
         self.assertEqual(t2.getUUIDList(), [storage_2_uuid])
         self.assertTrue(t2.lock(storage_2_uuid))
-        tm.remove(client_uuid, tid2)
+        del tm[ttid2]
         # Transaction 3: 1 storage node involved, which won't die
         msg_id_3 = 3
@@ -160,10 +134,10 @@ class testTransactionManager(NeoUnitTestBase):
         self.assertFalse(t3.locked())
         # Storage 1 dies:
         # t3 doesn't care
-        self.assertFalse(t3.forget(storage_1_uuid))
+        self.assertFalse(t3.storageLost(storage_1_uuid))
         self.assertEqual(t3.getUUIDList(), [storage_2_uuid])
         self.assertTrue(t3.lock(storage_2_uuid))
-        tm.remove(client_uuid, tid3)
+        del tm[ttid3]

     def testTIDUtils(self):
         """
@@ -204,13 +178,13 @@ class testTransactionManager(NeoUnitTestBase):
         ttid2 = self.getNextTID()
         tid1 = tm.begin(client, ttid1)
         self.assertEqual(tid1, ttid1)
-        tm.remove(client_uuid, tid1)
+        del tm[ttid1]
         # Without a requested TID, lock spans from prepare to remove only
         ttid3 = tm.begin(client)
         ttid4 = tm.begin(client) # Doesn't raise
         node = Mock({'getUUID': client_uuid, '__hash__': 0})
         tid4 = tm.prepare(ttid4, 1, [], [], 0)
-        tm.remove(client_uuid, tid4)
+        del tm[ttid4]
         tm.prepare(ttid3, 1, [], [], 0)

     def testClientDisconectsAfterBegin(self):
@@ -219,7 +193,7 @@ class testTransactionManager(NeoUnitTestBase):
         tid1 = self.getNextTID()
         tid2 = self.getNextTID()
         tm.begin(node1, tid1)
-        tm.abortFor(node1)
+        tm.clientLost(node1)
         self.assertTrue(tid1 not in tm)

     def testUnlockPending(self):
...