Commit e5a0261c authored by Grégory Wisniewski's avatar Grégory Wisniewski

Request a TID only if not supplied by the ZODB.

Master transactions objects are instanciated during the finish phase only,
which means that any transaction known by the master is being committed.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2237 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent ffcdf197
...@@ -564,16 +564,14 @@ class Application(object): ...@@ -564,16 +564,14 @@ class Application(object):
return return
if self.local_var.txn is not None: if self.local_var.txn is not None:
raise NeoException, 'local_var is not clean in tpc_begin' raise NeoException, 'local_var is not clean in tpc_begin'
# ask the primary master to start a transaction, if no tid is supplied, # use the given TID or request a new one to the master
# the master will supply us one. Otherwise the requested tid will be self.local_var.tid = tid
# used if possible. if tid is None:
self.local_var.tid = None self._askPrimary(Packets.AskBeginTransaction())
self._askPrimary(Packets.AskBeginTransaction(tid)) if self.local_var.tid is None:
if self.local_var.tid is None: raise NEOStorageError('tpc_begin failed')
raise NEOStorageError('tpc_begin failed')
self.local_var.txn = transaction self.local_var.txn = transaction
@profiler_decorator @profiler_decorator
def store(self, oid, serial, data, version, transaction): def store(self, oid, serial, data, version, transaction):
"""Store object.""" """Store object."""
......
...@@ -196,7 +196,7 @@ class EventHandler(object): ...@@ -196,7 +196,7 @@ class EventHandler(object):
def commitTransaction(self, conn, tid): def commitTransaction(self, conn, tid):
raise UnexpectedPacketError raise UnexpectedPacketError
def askBeginTransaction(self, conn, tid): def askBeginTransaction(self, conn):
raise UnexpectedPacketError raise UnexpectedPacketError
def answerBeginTransaction(self, conn, tid): def answerBeginTransaction(self, conn, tid):
......
...@@ -48,12 +48,14 @@ class ClientServiceHandler(MasterHandler): ...@@ -48,12 +48,14 @@ class ClientServiceHandler(MasterHandler):
conn.answer(Packets.AnswerNodeInformation()) conn.answer(Packets.AnswerNodeInformation())
def abortTransaction(self, conn, tid): def abortTransaction(self, conn, tid):
self.app.tm.remove(tid) # nothing to remove.
pass
def askBeginTransaction(self, conn, tid): def askBeginTransaction(self, conn):
node = self.app.nm.getByUUID(conn.getUUID()) """
tid = self.app.tm.begin(node, tid) A client request a TID, nothing is kept about it until the finish.
conn.answer(Packets.AnswerBeginTransaction(tid)) """
conn.answer(Packets.AnswerBeginTransaction(self.app.tm.begin()))
def askNewOIDs(self, conn, num_oids): def askNewOIDs(self, conn, num_oids):
conn.answer(Packets.AnswerNewOIDs(self.app.tm.getNextOIDList(num_oids))) conn.answer(Packets.AnswerNewOIDs(self.app.tm.getNextOIDList(num_oids)))
...@@ -61,10 +63,6 @@ class ClientServiceHandler(MasterHandler): ...@@ -61,10 +63,6 @@ class ClientServiceHandler(MasterHandler):
def askFinishTransaction(self, conn, tid, oid_list): def askFinishTransaction(self, conn, tid, oid_list):
app = self.app app = self.app
# If the given transaction ID is later than the last TID, the peer
# is crazy.
if tid > self.app.tm.getLastTID():
raise ProtocolError('TID too big')
# Collect partitions related to this transaction. # Collect partitions related to this transaction.
getPartition = app.pt.getPartition getPartition = app.pt.getPartition
...@@ -91,5 +89,6 @@ class ClientServiceHandler(MasterHandler): ...@@ -91,5 +89,6 @@ class ClientServiceHandler(MasterHandler):
node.ask(p, timeout=60) node.ask(p, timeout=60)
used_uuid_set.add(node.getUUID()) used_uuid_set.add(node.getUUID())
app.tm.prepare(tid, oid_list, used_uuid_set, conn.getPeerId()) node = self.app.nm.getByUUID(conn.getUUID())
app.tm.prepare(node, tid, oid_list, used_uuid_set, conn.getPeerId())
...@@ -43,11 +43,8 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -43,11 +43,8 @@ class StorageServiceHandler(BaseServiceHandler):
self.app.outdateAndBroadcastPartition() self.app.outdateAndBroadcastPartition()
uuid = conn.getUUID() uuid = conn.getUUID()
for tid, transaction in self.app.tm.items(): for tid, transaction in self.app.tm.items():
# If this transaction was not "prepared" (see askFinishTransaction) # if a transaction is known, this means that it's being committed
# there is nothing to cleanup on it (it doesn't have the list of if transaction.forget(uuid):
# involved storage nodes yet). As such transaction would be detected
# as locked, we must also prevent _afterLock from being called.
if transaction.isPrepared() and transaction.forget(uuid):
self._afterLock(tid) self._afterLock(tid)
def askLastIDs(self, conn): def askLastIDs(self, conn):
......
...@@ -26,15 +26,16 @@ class Transaction(object): ...@@ -26,15 +26,16 @@ class Transaction(object):
A pending transaction A pending transaction
""" """
_prepared = False def __init__(self, node, tid, oid_list, uuid_list, msg_id):
"""
def __init__(self, node, tid): Prepare the transaction, set OIDs and UUIDs related to it
"""
self._node = node self._node = node
self._tid = tid self._tid = tid
self._oid_list = [] self._oid_list = oid_list
self._msg_id = None self._msg_id = msg_id
# uuid dict hold flag to known who has locked the transaction # uuid dict hold flag to known who has locked the transaction
self._uuid_dict = {} self._uuid_dict = dict.fromkeys(uuid_list, False)
self._birth = time() self._birth = time()
def __repr__(self): def __repr__(self):
...@@ -60,12 +61,6 @@ class Transaction(object): ...@@ -60,12 +61,6 @@ class Transaction(object):
""" """
return self._tid return self._tid
def isPrepared(self):
"""
"""
return self._prepared
def getMessageId(self): def getMessageId(self):
""" """
Returns the packet ID to use in the answer Returns the packet ID to use in the answer
...@@ -85,17 +80,6 @@ class Transaction(object): ...@@ -85,17 +80,6 @@ class Transaction(object):
return list(self._oid_list) return list(self._oid_list)
def prepare(self, oid_list, uuid_list, msg_id):
"""
Prepare the transaction, set OIDs and UUIDs related to it
"""
assert not self._oid_list
assert not self._uuid_dict
self._oid_list = oid_list
self._uuid_dict = dict.fromkeys(uuid_list, False)
self._msg_id = msg_id
self._prepared = True
def forget(self, uuid): def forget(self, uuid):
""" """
Given storage was lost while waiting for its lock, stop waiting Given storage was lost while waiting for its lock, stop waiting
...@@ -239,38 +223,25 @@ class TransactionManager(object): ...@@ -239,38 +223,25 @@ class TransactionManager(object):
""" """
return self._tid_dict.keys() return self._tid_dict.keys()
def begin(self, node, tid): def begin(self):
""" """
Begin a new transaction Generate a new TID
""" """
assert node is not None return self._nextTID()
if tid is not None and tid < self._last_tid:
logging.warn('Transaction began with a decreased TID: %s, ' \
'expected at least %s', tid, self._last_tid)
if tid is None:
# give a TID
tid = self._nextTID()
txn = Transaction(node, tid)
self._tid_dict[tid] = txn
self._node_dict.setdefault(node, {})[tid] = txn
self.setLastTID(tid)
return tid
def prepare(self, tid, oid_list, uuid_list, msg_id): def prepare(self, node, tid, oid_list, uuid_list, msg_id):
""" """
Prepare a transaction to be finished Prepare a transaction to be finished
""" """
assert tid in self._tid_dict, "Transaction not started" self.setLastTID(tid)
txn = self._tid_dict[tid] txn = Transaction(node, tid, oid_list, uuid_list, msg_id)
txn.prepare(oid_list, uuid_list, msg_id) self._tid_dict[tid] = txn
self._node_dict.setdefault(node, {})[tid] = txn
def remove(self, tid): def remove(self, tid):
""" """
Remove a transaction, commited or aborted Remove a transaction, commited or aborted
""" """
if tid not in self._tid_dict:
logging.warn('aborting transaction %s does not exist', dump(tid))
return
node = self._tid_dict[tid].getNode() node = self._tid_dict[tid].getNode()
# remove both mappings, node will be removed in abortFor # remove both mappings, node will be removed in abortFor
del self._tid_dict[tid] del self._tid_dict[tid]
......
...@@ -715,12 +715,6 @@ class AskBeginTransaction(Packet): ...@@ -715,12 +715,6 @@ class AskBeginTransaction(Packet):
""" """
Ask to begin a new transaction. C -> PM. Ask to begin a new transaction. C -> PM.
""" """
def _encode(self, tid):
return _encodeTID(tid)
def _decode(self, body):
(tid, ) = unpack('8s', body)
return (_decodeTID(tid), )
class AnswerBeginTransaction(Packet): class AnswerBeginTransaction(Packet):
""" """
......
...@@ -606,7 +606,6 @@ class ClientApplicationTests(NeoTestBase): ...@@ -606,7 +606,6 @@ class ClientApplicationTests(NeoTestBase):
# will check if there was just one call/packet : # will check if there was just one call/packet :
self.checkNotifyPacket(conn1, Packets.AbortTransaction) self.checkNotifyPacket(conn1, Packets.AbortTransaction)
self.checkNotifyPacket(conn2, Packets.AbortTransaction) self.checkNotifyPacket(conn2, Packets.AbortTransaction)
self.checkNotifyPacket(app.master_conn, Packets.AbortTransaction)
self.assertEquals(app.local_var.tid, None) self.assertEquals(app.local_var.tid, None)
self.assertEquals(app.local_var.txn, None) self.assertEquals(app.local_var.txn, None)
self.assertEquals(app.local_var.data_dict, {}) self.assertEquals(app.local_var.data_dict, {})
...@@ -672,7 +671,6 @@ class ClientApplicationTests(NeoTestBase): ...@@ -672,7 +671,6 @@ class ClientApplicationTests(NeoTestBase):
app.cp = ConnectionPool() app.cp = ConnectionPool()
# abort must be sent to storage 1 and 2 # abort must be sent to storage 1 and 2
app.tpc_abort(txn) app.tpc_abort(txn)
self.checkAbortTransaction(app.master_conn)
self.checkAbortTransaction(conn2) self.checkAbortTransaction(conn2)
self.checkAbortTransaction(conn3) self.checkAbortTransaction(conn3)
...@@ -1040,7 +1038,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -1040,7 +1038,7 @@ class ClientApplicationTests(NeoTestBase):
def _waitMessage_hook(app, conn, msg_id, handler=None): def _waitMessage_hook(app, conn, msg_id, handler=None):
self.test_ok = True self.test_ok = True
_waitMessage_old = Application._waitMessage _waitMessage_old = Application._waitMessage
packet = Packets.AskBeginTransaction(None) packet = Packets.AskBeginTransaction()
packet.setId(0) packet.setId(0)
Application._waitMessage = _waitMessage_hook Application._waitMessage = _waitMessage_hook
try: try:
...@@ -1066,7 +1064,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -1066,7 +1064,7 @@ class ClientApplicationTests(NeoTestBase):
self.test_ok = True self.test_ok = True
_waitMessage_old = Application._waitMessage _waitMessage_old = Application._waitMessage
Application._waitMessage = _waitMessage_hook Application._waitMessage = _waitMessage_hook
packet = Packets.AskBeginTransaction(None) packet = Packets.AskBeginTransaction()
packet.setId(0) packet.setId(0)
try: try:
app._askPrimary(packet) app._askPrimary(packet)
......
...@@ -72,11 +72,8 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -72,11 +72,8 @@ class MasterClientHandlerTests(NeoTestBase):
# client call it # client call it
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port) client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
service.askBeginTransaction(conn, None) service.askBeginTransaction(conn)
self.assertTrue(ltid < self.app.tm.getLastTID()) self.assertTrue(ltid < self.app.tm.getLastTID())
self.assertEqual(len(self.app.tm.getPendingList()), 1)
tid = self.app.tm.getPendingList()[0]
self.assertEquals(tid, self.app.tm.getLastTID())
def test_08_askNewOIDs(self): def test_08_askNewOIDs(self):
service = self.service service = self.service
...@@ -97,18 +94,6 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -97,18 +94,6 @@ class MasterClientHandlerTests(NeoTestBase):
def test_09_askFinishTransaction(self): def test_09_askFinishTransaction(self):
service = self.service service = self.service
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
# give an older tid than the PMN known, must abort
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
conn = self.getFakeConnection(client_uuid, self.client_address)
oid_list = []
upper, lower = unpack('!LL', self.app.tm.getLastTID())
new_tid = pack('!LL', upper, lower + 10)
self.checkProtocolErrorRaised(service.askFinishTransaction, conn,
new_tid, oid_list)
old_node = self.app.nm.getByUUID(uuid)
self.app.nm.remove(old_node)
self.app.pt.dropNode(old_node)
# do the right job # do the right job
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port) client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
storage_uuid = self.identifyToMasterNode() storage_uuid = self.identifyToMasterNode()
...@@ -119,7 +104,7 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -119,7 +104,7 @@ class MasterClientHandlerTests(NeoTestBase):
'getPartition': 0, 'getPartition': 0,
'getCellList': [Mock({'getUUID': storage_uuid})], 'getCellList': [Mock({'getUUID': storage_uuid})],
}) })
service.askBeginTransaction(conn, None) service.askBeginTransaction(conn)
oid_list = [] oid_list = []
tid = self.app.tm.getLastTID() tid = self.app.tm.getLastTID()
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
...@@ -169,17 +154,12 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -169,17 +154,12 @@ class MasterClientHandlerTests(NeoTestBase):
port = self.client_port) port = self.client_port)
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
lptid = self.app.pt.getID() lptid = self.app.pt.getID()
self.service.askBeginTransaction(conn, None)
self.service.askBeginTransaction(conn, None)
self.service.askBeginTransaction(conn, None)
self.assertEquals(self.app.nm.getByUUID(client_uuid).getState(), self.assertEquals(self.app.nm.getByUUID(client_uuid).getState(),
NodeStates.RUNNING) NodeStates.RUNNING)
self.assertEquals(len(self.app.tm.getPendingList()), 3)
method(conn) method(conn)
# node must be have been remove, and no more transaction must remains # node must be have been remove, and no more transaction must remains
self.assertEquals(self.app.nm.getByUUID(client_uuid), None) self.assertEquals(self.app.nm.getByUUID(client_uuid), None)
self.assertEquals(lptid, self.app.pt.getID()) self.assertEquals(lptid, self.app.pt.getID())
self.assertFalse(self.app.tm.hasPending())
def test_15_peerBroken(self): def test_15_peerBroken(self):
self.__testWithMethod(self.service.peerBroken, NodeStates.BROKEN) self.__testWithMethod(self.service.peerBroken, NodeStates.BROKEN)
......
...@@ -104,8 +104,8 @@ class MasterStorageHandlerTests(NeoTestBase): ...@@ -104,8 +104,8 @@ class MasterStorageHandlerTests(NeoTestBase):
oid_list = self.getOID(), self.getOID() oid_list = self.getOID(), self.getOID()
msg_id = 1 msg_id = 1
# register a transaction # register a transaction
tid = self.app.tm.begin(client_1, None) tid = self.app.tm.begin()
self.app.tm.prepare(tid, oid_list, uuid_list, msg_id) self.app.tm.prepare(client_1, tid, oid_list, uuid_list, msg_id)
self.assertTrue(tid in self.app.tm) self.assertTrue(tid in self.app.tm)
# the first storage acknowledge the lock # the first storage acknowledge the lock
self.service.answerInformationLocked(storage_conn_1, tid) self.service.answerInformationLocked(storage_conn_1, tid)
...@@ -148,9 +148,13 @@ class MasterStorageHandlerTests(NeoTestBase): ...@@ -148,9 +148,13 @@ class MasterStorageHandlerTests(NeoTestBase):
# create some transaction # create some transaction
node, conn = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, node, conn = self.identifyToMasterNode(node_type=NodeTypes.CLIENT,
port=self.client_port) port=self.client_port)
self.client_handler.askBeginTransaction(conn, None) def create_transaction(index):
self.client_handler.askBeginTransaction(conn, None) tid = self.getNextTID()
self.client_handler.askBeginTransaction(conn, None) oid_list = [self.getOID(index)]
self.app.tm.prepare(node, tid, oid_list, [node.getUUID()], index)
create_transaction(1)
create_transaction(2)
create_transaction(3)
conn = self.getFakeConnection(node.getUUID(), self.storage_address) conn = self.getFakeConnection(node.getUUID(), self.storage_address)
service.askUnfinishedTransactions(conn) service.askUnfinishedTransactions(conn)
packet = self.checkAnswerUnfinishedTransactions(conn) packet = self.checkAnswerUnfinishedTransactions(conn)
...@@ -214,17 +218,14 @@ class MasterStorageHandlerTests(NeoTestBase): ...@@ -214,17 +218,14 @@ class MasterStorageHandlerTests(NeoTestBase):
# Transaction 1: 2 storage nodes involved, one will die and the other # Transaction 1: 2 storage nodes involved, one will die and the other
# already answered node lock # already answered node lock
msg_id_1 = 1 msg_id_1 = 1
tm.begin(client1, tid1) tm.prepare(client1, tid1, oid_list, [node1.getUUID(), node2.getUUID()], msg_id_1)
tm.prepare(tid1, oid_list, [node1.getUUID(), node2.getUUID()], msg_id_1)
tm.lock(tid1, node2.getUUID()) tm.lock(tid1, node2.getUUID())
# Transaction 2: 2 storage nodes involved, one will die # Transaction 2: 2 storage nodes involved, one will die
msg_id_2 = 2 msg_id_2 = 2
tm.begin(client2, tid2) tm.prepare(client2, tid2, oid_list, [node1.getUUID(), node2.getUUID()], msg_id_2)
tm.prepare(tid2, oid_list, [node1.getUUID(), node2.getUUID()], msg_id_2)
# Transaction 3: 1 storage node involved, which won't die # Transaction 3: 1 storage node involved, which won't die
msg_id_3 = 3 msg_id_3 = 3
tm.begin(client3, tid3) tm.prepare(client3, tid3, oid_list, [node2.getUUID(), ], msg_id_3)
tm.prepare(tid3, oid_list, [node2.getUUID(), ], msg_id_3)
# Assert initial state # Assert initial state
self.checkNoPacketSent(cconn1) self.checkNoPacketSent(cconn1)
......
...@@ -37,13 +37,13 @@ class testTransactionManager(NeoTestBase): ...@@ -37,13 +37,13 @@ class testTransactionManager(NeoTestBase):
# test data # test data
node = Mock({'__repr__': 'Node'}) node = Mock({'__repr__': 'Node'})
tid = self.makeTID(1) tid = self.makeTID(1)
oid_list = (oid1, oid2) = (self.makeOID(1), self.makeOID(2)) oid_list = (oid1, oid2) = [self.makeOID(1), self.makeOID(2)]
uuid_list = (uuid1, uuid2) = (self.makeUUID(1), self.makeUUID(2)) uuid_list = (uuid1, uuid2) = [self.makeUUID(1), self.makeUUID(2)]
msg_id = 1 msg_id = 1
# create transaction object # create transaction object
txn = Transaction(node, tid) txn = Transaction(node, tid, oid_list, uuid_list, msg_id)
self.assertEqual(txn.getUUIDList(), []) self.assertEqual(txn.getUUIDList(), uuid_list)
txn.prepare(oid_list, uuid_list, msg_id) self.assertEqual(txn.getOIDList(), oid_list)
# lock nodes one by one # lock nodes one by one
self.assertFalse(txn.lock(uuid1)) self.assertFalse(txn.lock(uuid1))
self.assertTrue(txn.lock(uuid2)) self.assertTrue(txn.lock(uuid2))
...@@ -61,14 +61,15 @@ class testTransactionManager(NeoTestBase): ...@@ -61,14 +61,15 @@ class testTransactionManager(NeoTestBase):
self.assertFalse(txnman.hasPending()) self.assertFalse(txnman.hasPending())
self.assertEqual(txnman.getPendingList(), []) self.assertEqual(txnman.getPendingList(), [])
# begin the transaction # begin the transaction
tid = txnman.begin(node, None) tid = txnman.begin()
self.assertTrue(tid is not None) self.assertTrue(tid is not None)
self.assertFalse(txnman.hasPending())
self.assertEqual(len(txnman.getPendingList()), 0)
# prepare the transaction
txnman.prepare(node, tid, oid_list, uuid_list, msg_id)
self.assertTrue(txnman.hasPending()) self.assertTrue(txnman.hasPending())
self.assertEqual(len(txnman.getPendingList()), 1)
self.assertEqual(txnman.getPendingList()[0], tid) self.assertEqual(txnman.getPendingList()[0], tid)
self.assertEqual(txnman[tid].getTID(), tid) self.assertEqual(txnman[tid].getTID(), tid)
# prepare the transaction
txnman.prepare(tid, oid_list, uuid_list, msg_id)
txn = txnman[tid] txn = txnman[tid]
self.assertEqual(txn.getUUIDList(), list(uuid_list)) self.assertEqual(txn.getUUIDList(), list(uuid_list))
self.assertEqual(txn.getOIDList(), list(oid_list)) self.assertEqual(txn.getOIDList(), list(oid_list))
...@@ -82,12 +83,19 @@ class testTransactionManager(NeoTestBase): ...@@ -82,12 +83,19 @@ class testTransactionManager(NeoTestBase):
def testAbortFor(self): def testAbortFor(self):
node1 = Mock({'__hash__': 1}) node1 = Mock({'__hash__': 1})
node2 = Mock({'__hash__': 2}) node2 = Mock({'__hash__': 2})
oid_list = [self.makeOID(1), ]
storage_1_uuid = self.makeUUID(1)
storage_2_uuid = self.makeUUID(2)
txnman = TransactionManager() txnman = TransactionManager()
# register 4 transactions made by two nodes # register 4 transactions made by two nodes
tid11 = txnman.begin(node1, None) tid11 = txnman.begin()
tid12 = txnman.begin(node1, None) txnman.prepare(node1, tid11, oid_list, [storage_1_uuid], 1)
tid21 = txnman.begin(node2, None) tid12 = txnman.begin()
tid22 = txnman.begin(node2, None) txnman.prepare(node1, tid12, oid_list, [storage_1_uuid], 2)
tid21 = txnman.begin()
txnman.prepare(node2, tid21, oid_list, [storage_2_uuid], 3)
tid22 = txnman.begin()
txnman.prepare(node2, tid22, oid_list, [storage_2_uuid], 4)
self.assertTrue(tid11 < tid12 < tid21 < tid22) self.assertTrue(tid11 < tid12 < tid21 < tid22)
self.assertEqual(len(txnman.getPendingList()), 4) self.assertEqual(len(txnman.getPendingList()), 4)
# abort transactions of one node # abort transactions of one node
...@@ -120,7 +128,7 @@ class testTransactionManager(NeoTestBase): ...@@ -120,7 +128,7 @@ class testTransactionManager(NeoTestBase):
self.assertEqual(txnman.getLastTID(), None) self.assertEqual(txnman.getLastTID(), None)
# first transaction # first transaction
node1 = Mock({'__hash__': 1}) node1 = Mock({'__hash__': 1})
tid1 = txnman.begin(node1, None) tid1 = txnman.begin()
self.assertTrue(tid1 is not None) self.assertTrue(tid1 is not None)
self.assertEqual(txnman.getLastTID(), tid1) self.assertEqual(txnman.getLastTID(), tid1)
# set a new last TID # set a new last TID
...@@ -130,7 +138,7 @@ class testTransactionManager(NeoTestBase): ...@@ -130,7 +138,7 @@ class testTransactionManager(NeoTestBase):
self.assertTrue(ntid > tid1) self.assertTrue(ntid > tid1)
# new trancation # new trancation
node2 = Mock({'__hash__': 2}) node2 = Mock({'__hash__': 2})
tid2 = txnman.begin(node2, None) tid2 = txnman.begin()
self.assertTrue(tid2 is not None) self.assertTrue(tid2 is not None)
self.assertTrue(tid2 > ntid > tid1) self.assertTrue(tid2 > ntid > tid1)
...@@ -140,26 +148,23 @@ class testTransactionManager(NeoTestBase): ...@@ -140,26 +148,23 @@ class testTransactionManager(NeoTestBase):
client3 = Mock({'__hash__': 3}) client3 = Mock({'__hash__': 3})
storage_1_uuid = self.makeUUID(1) storage_1_uuid = self.makeUUID(1)
storage_2_uuid = self.makeUUID(2) storage_2_uuid = self.makeUUID(2)
tid1 = self.makeTID(1)
tid2 = self.makeTID(2)
tid3 = self.makeTID(3)
oid_list = [self.makeOID(1), ] oid_list = [self.makeOID(1), ]
tm = TransactionManager() tm = TransactionManager()
# Transaction 1: 2 storage nodes involved, one will die and the other # Transaction 1: 2 storage nodes involved, one will die and the other
# already answered node lock # already answered node lock
msg_id_1 = 1 msg_id_1 = 1
tm.begin(client1, tid1) tid1 = tm.begin()
tm.prepare(tid1, oid_list, [storage_1_uuid, storage_2_uuid], msg_id_1) tm.prepare(client1, tid1, oid_list, [storage_1_uuid, storage_2_uuid], msg_id_1)
tm.lock(tid1, storage_2_uuid) tm.lock(tid1, storage_2_uuid)
# Transaction 2: 2 storage nodes involved, one will die # Transaction 2: 2 storage nodes involved, one will die
msg_id_2 = 2 msg_id_2 = 2
tm.begin(client2, tid2) tid2 = tm.begin()
tm.prepare(tid2, oid_list, [storage_1_uuid, storage_2_uuid], msg_id_2) tm.prepare(client2, tid2, oid_list, [storage_1_uuid, storage_2_uuid], msg_id_2)
# Transaction 3: 1 storage node involved, which won't die # Transaction 3: 1 storage node involved, which won't die
msg_id_3 = 3 msg_id_3 = 3
tm.begin(client3, tid3) tid3 = tm.begin()
tm.prepare(tid3, oid_list, [storage_2_uuid, ], msg_id_3) tm.prepare(client3, tid3, oid_list, [storage_2_uuid, ], msg_id_3)
t1 = tm[tid1] t1 = tm[tid1]
t2 = tm[tid2] t2 = tm[tid2]
......
...@@ -852,7 +852,7 @@ class HandlerSwitcherTests(NeoTestBase): ...@@ -852,7 +852,7 @@ class HandlerSwitcherTests(NeoTestBase):
return packet return packet
def _makeRequest(self, msg_id): def _makeRequest(self, msg_id):
packet = Packets.AskBeginTransaction(self.getNextTID()) packet = Packets.AskBeginTransaction()
packet.setId(msg_id) packet.setId(msg_id)
return packet return packet
......
...@@ -238,14 +238,8 @@ class ProtocolTests(NeoTestBase): ...@@ -238,14 +238,8 @@ class ProtocolTests(NeoTestBase):
def test_32_askBeginTransaction(self): def test_32_askBeginTransaction(self):
# try with an invalid TID, None must be returned p = Packets.AskBeginTransaction()
tid = INVALID_TID self.assertEqual(p.decode(), ())
p = Packets.AskBeginTransaction(tid)
self.assertEqual(p.decode(), (None, ))
# and with another TID
tid = '\1' * 8
p = Packets.AskBeginTransaction(tid)
self.assertEqual(p.decode(), (tid, ))
def test_33_answerBeginTransaction(self): def test_33_answerBeginTransaction(self):
tid = self.getNextTID() tid = self.getNextTID()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment