Commit a4c06242 authored by Julien Muchembled's avatar Julien Muchembled

Review aborting of transactions

parent dc7a129c
...@@ -595,17 +595,14 @@ class Application(ThreadedApplication): ...@@ -595,17 +595,14 @@ class Application(ThreadedApplication):
txn_context = self._txn_container.pop(transaction) txn_context = self._txn_container.pop(transaction)
if txn_context is None: if txn_context is None:
return return
p = Packets.AbortTransaction(txn_context['ttid'])
# cancel transaction on all those nodes
conns = [self.master_conn]
for uuid in txn_context['involved_nodes']:
node = self.nm.getByUUID(uuid)
if node is not None:
conns.append(self.cp.getConnForNode(node))
for conn in conns:
if conn is not None:
try: try:
conn.notify(p) notify = self.master_conn.notify
except AttributeError:
pass
else:
try:
notify(Packets.AbortTransaction(txn_context['ttid'],
txn_context['involved_nodes']))
except ConnectionClosed: except ConnectionClosed:
pass pass
# We don't need to flush queue, as it won't be reused by future # We don't need to flush queue, as it won't be reused by future
......
...@@ -976,10 +976,11 @@ class StoreObject(Packet): ...@@ -976,10 +976,11 @@ class StoreObject(Packet):
class AbortTransaction(Packet): class AbortTransaction(Packet):
""" """
Abort a transaction. C -> S, PM. Abort a transaction. C -> PM -> S.
""" """
_fmt = PStruct('abort_transaction', _fmt = PStruct('abort_transaction',
PTID('tid'), PTID('tid'),
PFUUIDList, # unused for PM -> S
) )
class StoreTransaction(Packet): class StoreTransaction(Packet):
......
...@@ -585,6 +585,6 @@ class Application(BaseApplication): ...@@ -585,6 +585,6 @@ class Application(BaseApplication):
def isStorageReady(self, uuid): def isStorageReady(self, uuid):
return uuid in self.storage_ready_dict return uuid in self.storage_ready_dict
def getStorageReadySet(self, readiness): def getStorageReadySet(self, readiness=float('inf')):
return {k for k, v in self.storage_ready_dict.iteritems() return {k for k, v in self.storage_ready_dict.iteritems()
if v <= readiness} if v <= readiness}
...@@ -74,6 +74,12 @@ class ClientServiceHandler(MasterHandler): ...@@ -74,6 +74,12 @@ class ClientServiceHandler(MasterHandler):
node.ask(p, timeout=60) node.ask(p, timeout=60)
else: else:
conn.answer(Errors.IncompleteTransaction()) conn.answer(Errors.IncompleteTransaction())
# It's simpler to abort automatically rather than asking the client
# to send a notification on tpc_abort, since it would have keep the
# transaction longer in list of transactions.
# This should happen so rarely that we don't try to minimize the
# number of abort notifications by looking the modified partitions.
self.abortTransaction(conn, ttid, app.getStorageReadySet())
def askFinalTID(self, conn, ttid): def askFinalTID(self, conn, ttid):
tm = self.app.tm tm = self.app.tm
...@@ -102,9 +108,24 @@ class ClientServiceHandler(MasterHandler): ...@@ -102,9 +108,24 @@ class ClientServiceHandler(MasterHandler):
else: else:
conn.answer(Packets.AnswerPack(False)) conn.answer(Packets.AnswerPack(False))
def abortTransaction(self, conn, tid): def abortTransaction(self, conn, tid, uuid_list):
# BUG: The replicator may wait this transaction to be finished. # Consider a failure when the connection between the storage and the
self.app.tm.abort(tid, conn.getUUID()) # client breaks while the answer to the first write is sent back.
# In other words, the client can not know the exact set of nodes that
# know this transaction, and it sends us all nodes it considered for
# writing.
# We must also add those that are waiting for this transaction to be
# finished (returned by tm.abort), because they may have join the
# cluster after that the client started to abort.
app = self.app
involved = app.tm.abort(tid, conn.getUUID())
involved.update(uuid_list)
involved.intersection_update(app.getStorageReadySet())
if involved:
p = Packets.AbortTransaction(tid, ())
getByUUID = app.nm.getByUUID
for involved in involved:
getByUUID(involved).notify(p)
# like ClientServiceHandler but read-only & only for tid <= backup_tid # like ClientServiceHandler but read-only & only for tid <= backup_tid
......
...@@ -393,10 +393,12 @@ class TransactionManager(EventQueue): ...@@ -393,10 +393,12 @@ class TransactionManager(EventQueue):
Abort a transaction Abort a transaction
""" """
logging.debug('Abort TXN %s for %s', dump(ttid), uuid_str(uuid)) logging.debug('Abort TXN %s for %s', dump(ttid), uuid_str(uuid))
if self[ttid].isPrepared(): txn = self[ttid]
if txn.isPrepared():
raise ProtocolError("commit already requested for ttid %s" raise ProtocolError("commit already requested for ttid %s"
% dump(ttid)) % dump(ttid))
del self[ttid] del self[ttid]
return txn._notification_set
def lock(self, ttid, uuid): def lock(self, ttid, uuid):
""" """
......
...@@ -57,9 +57,6 @@ class ClientOperationHandler(EventHandler): ...@@ -57,9 +57,6 @@ class ClientOperationHandler(EventHandler):
compression, checksum, data, data_serial) compression, checksum, data, data_serial)
conn.answer(p) conn.answer(p)
def abortTransaction(self, conn, ttid):
self.app.tm.abort(ttid)
def askStoreTransaction(self, conn, ttid, *txn_info): def askStoreTransaction(self, conn, ttid, *txn_info):
self.app.tm.register(conn, ttid) self.app.tm.register(conn, ttid)
self.app.tm.vote(ttid, txn_info) self.app.tm.vote(ttid, txn_info)
...@@ -201,7 +198,6 @@ class ClientReadOnlyOperationHandler(ClientOperationHandler): ...@@ -201,7 +198,6 @@ class ClientReadOnlyOperationHandler(ClientOperationHandler):
conn.answer(Errors.ReadOnlyAccess( conn.answer(Errors.ReadOnlyAccess(
'read-only access because cluster is in backuping mode')) 'read-only access because cluster is in backuping mode'))
abortTransaction = _readOnly
askStoreTransaction = _readOnly askStoreTransaction = _readOnly
askVoteTransaction = _readOnly askVoteTransaction = _readOnly
askStoreObject = _readOnly askStoreObject = _readOnly
......
...@@ -57,6 +57,10 @@ class MasterOperationHandler(BaseMasterHandler): ...@@ -57,6 +57,10 @@ class MasterOperationHandler(BaseMasterHandler):
def notifyUnlockInformation(self, conn, ttid): def notifyUnlockInformation(self, conn, ttid):
self.app.tm.unlock(ttid) self.app.tm.unlock(ttid)
def abortTransaction(self, conn, ttid, _):
self.app.tm.abort(ttid)
self.app.replicator.transactionFinished(ttid)
def askPack(self, conn, tid): def askPack(self, conn, tid):
app = self.app app = self.app
logging.info('Pack started, up to %s...', dump(tid)) logging.info('Pack started, up to %s...', dump(tid))
......
...@@ -135,13 +135,18 @@ class Replicator(object): ...@@ -135,13 +135,18 @@ class Replicator(object):
self.replicate_dict[offset] = max_tid self.replicate_dict[offset] = max_tid
self._nextPartition() self._nextPartition()
def transactionFinished(self, ttid, max_tid): def transactionFinished(self, ttid, max_tid=None):
""" Callback from MasterOperationHandler """ """ Callback from MasterOperationHandler """
try:
self.ttid_set.remove(ttid) self.ttid_set.remove(ttid)
except KeyError:
assert max_tid is None, max_tid
return
min_ttid = min(self.ttid_set) if self.ttid_set else INVALID_TID min_ttid = min(self.ttid_set) if self.ttid_set else INVALID_TID
for offset, p in self.partition_dict.iteritems(): for offset, p in self.partition_dict.iteritems():
if p.max_ttid and p.max_ttid < min_ttid: if p.max_ttid and p.max_ttid < min_ttid:
p.max_ttid = None p.max_ttid = None
if max_tid:
self.replicate_dict[offset] = max_tid self.replicate_dict[offset] = max_tid
self._nextPartition() self._nextPartition()
...@@ -355,7 +360,9 @@ class Replicator(object): ...@@ -355,7 +360,9 @@ class Replicator(object):
p = self.partition_dict[offset] p = self.partition_dict[offset]
p.next_obj = add64(tid, 1) p.next_obj = add64(tid, 1)
self.updateBackupTID() self.updateBackupTID()
if not p.max_ttid: if p.max_ttid:
logging.debug("unfinished transactions: %r", self.ttid_set)
else:
self.app.tm.replicated(offset, tid) self.app.tm.replicated(offset, tid)
logging.debug("partition %u replicated up to %s from %r", logging.debug("partition %u replicated up to %s from %r",
offset, dump(tid), self.current_node) offset, dump(tid), self.current_node)
......
...@@ -324,9 +324,8 @@ class TransactionManager(EventQueue): ...@@ -324,9 +324,8 @@ class TransactionManager(EventQueue):
Note: does not alter persistent content. Note: does not alter persistent content.
""" """
if ttid not in self._transaction_dict: if ttid not in self._transaction_dict:
# the tid may be unknown as the transaction is aborted on every node assert not even_if_locked
# of the partition, even if no data was received (eg. conflict on # See how the master processes AbortTransaction from the client.
# another node)
return return
logging.debug('Abort TXN %s', dump(ttid)) logging.debug('Abort TXN %s', dump(ttid))
transaction = self._transaction_dict[ttid] transaction = self._transaction_dict[ttid]
......
...@@ -37,6 +37,7 @@ from time import time ...@@ -37,6 +37,7 @@ from time import time
from struct import pack, unpack from struct import pack, unpack
from unittest.case import _ExpectedFailure, _UnexpectedSuccess from unittest.case import _ExpectedFailure, _UnexpectedSuccess
try: try:
from transaction.interfaces import IDataManager
from ZODB.utils import newTid from ZODB.utils import newTid
except ImportError: except ImportError:
pass pass
...@@ -378,6 +379,30 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -378,6 +379,30 @@ class NeoUnitTestBase(NeoTestBase):
return packet return packet
class TransactionalResource(object):
class _sortKey(object):
def __init__(self, last):
self._last = last
def __cmp__(self, other):
assert type(self) is not type(other), other
return 1 if self._last else -1
def __init__(self, txn, last, **kw):
self.sortKey = lambda: self._sortKey(last)
for k in kw:
assert callable(IDataManager.get(k)), k
self.__dict__.update(kw)
txn.get().join(self)
def __getattr__(self, attr):
if callable(IDataManager.get(attr)):
return lambda *_: None
return self.__getattribute__(attr)
class Patch(object): class Patch(object):
""" """
Patch attributes and revert later automatically. Patch attributes and revert later automatically.
......
...@@ -956,13 +956,12 @@ class NEOThreadedTest(NeoTestBase): ...@@ -956,13 +956,12 @@ class NEOThreadedTest(NeoTestBase):
return obj return obj
return unpickler return unpickler
class newThread(threading.Thread): class newPausedThread(threading.Thread):
def __init__(self, func, *args, **kw): def __init__(self, func, *args, **kw):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.__target = func, args, kw self.__target = func, args, kw
self.daemon = True self.daemon = True
self.start()
def run(self): def run(self):
try: try:
...@@ -978,6 +977,12 @@ class NEOThreadedTest(NeoTestBase): ...@@ -978,6 +977,12 @@ class NEOThreadedTest(NeoTestBase):
del self.__exc_info del self.__exc_info
raise etype, value, tb raise etype, value, tb
class newThread(newPausedThread):
def __init__(self, *args, **kw):
NEOThreadedTest.newPausedThread.__init__(self, *args, **kw)
self.start()
def commitWithStorageFailure(self, client, txn): def commitWithStorageFailure(self, client, txn):
with Patch(client, _getFinalTID=lambda *_: None): with Patch(client, _getFinalTID=lambda *_: None):
self.assertRaises(ConnectionClosed, txn.commit) self.assertRaises(ConnectionClosed, txn.commit)
......
...@@ -32,7 +32,7 @@ from neo.lib.connection import ServerConnection, MTClientConnection ...@@ -32,7 +32,7 @@ from neo.lib.connection import ServerConnection, MTClientConnection
from neo.lib.exception import DatabaseFailure, StoppedOperation from neo.lib.exception import DatabaseFailure, StoppedOperation
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \ from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
ZERO_OID, ZERO_TID ZERO_OID, ZERO_TID
from .. import expectedFailure, Patch from .. import expectedFailure, Patch, TransactionalResource
from . import ConnectionFilter, LockLock, NEOThreadedTest, with_cluster from . import ConnectionFilter, LockLock, NEOThreadedTest, with_cluster
from neo.lib.util import add64, makeChecksum, p64, u64 from neo.lib.util import add64, makeChecksum, p64, u64
from neo.client.exception import NEOPrimaryMasterLost, NEOStorageError from neo.client.exception import NEOPrimaryMasterLost, NEOStorageError
...@@ -1463,5 +1463,97 @@ class Test(NEOThreadedTest): ...@@ -1463,5 +1463,97 @@ class Test(NEOThreadedTest):
value_list.append(r[x].value) value_list.append(r[x].value)
self.assertEqual(value_list, range(3)) self.assertEqual(value_list, range(3))
@with_cluster(replicas=1, partitions=3, storage_count=3)
def testMasterArbitratingVote(self, cluster):
"""
p\S 1 2 3
0 U U .
1 . U U
2 U . U
With the above setup, check when a client C1 fails to connect to S2
and another C2 fails to connect to S1.
For the first 2 scenarios:
- C1 first votes (the master accepts)
- C2 vote is delayed until C1 decides to finish or abort
"""
def delayAbort(conn, packet):
return isinstance(packet, Packets.AbortTransaction)
def noConnection(jar, storage):
return Patch(jar.db().storage.app.cp,
getConnForNode=lambda orig, node:
None if node.getUUID() == storage.uuid else orig(node))
def c1_vote(txn):
def vote(orig, *args):
result = orig(*args)
ll()
return result
with LockLock() as ll, Patch(cluster.master.tm, vote=vote):
commit2.start()
ll()
if c1_aborts:
raise Exception
pt = [{x.getUUID() for x in x}
for x in cluster.master.pt.partition_list]
cluster.storage_list.sort(key=lambda x:
(x.uuid not in pt[0], x.uuid in pt[1]))
pt = 'UU.|.UU|U.U'
self.assertPartitionTable(cluster, pt)
s1, s2, s3 = cluster.storage_list
t1, c1 = cluster.getTransaction()
with cluster.newClient(1) as db:
t2, c2 = cluster.getTransaction(db)
with noConnection(c1, s2), noConnection(c2, s1):
cluster.client.cp.connection_dict[s2.uuid].close()
self.tic()
for c1_aborts in 0, 1:
# 0: C1 finishes, C2 vote fails
# 1: C1 aborts, C2 finishes
#
# Although we try to modify the same oid, there's no
# conflict because each storage node sees a single
# and different transaction: vote to storages is done
# in parallel, and the master must be involved as an
# arbitrator, which ultimately rejects 1 of the 2
# transactions, preferably before the second phase of
# the commit.
t1.begin(); c1.root()._p_changed = 1
t2.begin(); c2.root()._p_changed = 1
commit2 = self.newPausedThread(t2.commit)
TransactionalResource(t1, 1, tpc_vote=c1_vote)
with ConnectionFilter() as f:
if not c1_aborts:
f.add(delayAbort)
f.delayAskFetchTransactions(lambda _:
f.discard(delayAbort))
try:
t1.commit()
self.assertFalse(c1_aborts)
except Exception:
self.assertTrue(c1_aborts)
try:
commit2.join()
self.assertTrue(c1_aborts)
except NEOStorageError:
self.assertFalse(c1_aborts)
self.tic()
self.assertPartitionTable(cluster,
'OU.|.UU|O.U' if c1_aborts else 'UO.|.OU|U.U')
self.tic()
self.assertPartitionTable(cluster, pt)
# S3 fails while C1 starts to finish
with ConnectionFilter() as f:
f.add(lambda conn, packet: conn.getUUID() == s3.uuid and
isinstance(packet, Packets.AcceptIdentification))
t1.begin(); c1.root()._p_changed = 1
TransactionalResource(t1, 0, tpc_finish=lambda *_:
cluster.master.nm.getByUUID(s3.uuid)
.getConnection().close())
self.assertRaises(NEOStorageError, t1.commit)
self.assertPartitionTable(cluster, 'UU.|.UO|U.O')
self.tic()
self.assertPartitionTable(cluster, pt)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment