Commit 69975d76 authored by Grégory Wisniewski's avatar Grégory Wisniewski

Make replication work with temporary TIDs.

- Storage nodes start to replicate a partition when all transactions that were
pending when the oudated partition was added are committed.
- Transactions are registered by the master from the tpc_begin step.
Signed-off-by: default avatarGrégory <gregory@nexedi.com>

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2649 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 5410225c
...@@ -184,7 +184,7 @@ class EventHandler(object): ...@@ -184,7 +184,7 @@ class EventHandler(object):
def askUnfinishedTransactions(self, conn): def askUnfinishedTransactions(self, conn):
raise UnexpectedPacketError raise UnexpectedPacketError
def answerUnfinishedTransactions(self, conn, tid_list): def answerUnfinishedTransactions(self, conn, max_tid, ttid_list):
raise UnexpectedPacketError raise UnexpectedPacketError
def askObjectPresent(self, conn, oid, tid): def askObjectPresent(self, conn, oid, tid):
...@@ -229,6 +229,9 @@ class EventHandler(object): ...@@ -229,6 +229,9 @@ class EventHandler(object):
def notifyUnlockInformation(self, conn, ttid): def notifyUnlockInformation(self, conn, ttid):
raise UnexpectedPacketError raise UnexpectedPacketError
def notifyTransactionFinished(self, conn, ttid, max_tid):
raise UnexpectedPacketError
def askStoreObject(self, conn, oid, serial, def askStoreObject(self, conn, oid, serial,
compression, checksum, data, data_serial, ttid, unlock): compression, checksum, data, data_serial, ttid, unlock):
raise UnexpectedPacketError raise UnexpectedPacketError
...@@ -506,6 +509,7 @@ class EventHandler(object): ...@@ -506,6 +509,7 @@ class EventHandler(object):
d[Packets.AnswerLastTransaction] = self.answerLastTransaction d[Packets.AnswerLastTransaction] = self.answerLastTransaction
d[Packets.AskCheckCurrentSerial] = self.askCheckCurrentSerial d[Packets.AskCheckCurrentSerial] = self.askCheckCurrentSerial
d[Packets.AnswerCheckCurrentSerial] = self.answerCheckCurrentSerial d[Packets.AnswerCheckCurrentSerial] = self.answerCheckCurrentSerial
d[Packets.NotifyTransactionFinished] = self.notifyTransactionFinished
return d return d
......
...@@ -661,18 +661,18 @@ class AnswerUnfinishedTransactions(Packet): ...@@ -661,18 +661,18 @@ class AnswerUnfinishedTransactions(Packet):
""" """
Answer unfinished transactions S -> PM. Answer unfinished transactions S -> PM.
""" """
_header_format = '!L' _header_format = '!8sL'
_list_entry_format = '8s' _list_entry_format = '8s'
_list_entry_len = calcsize(_list_entry_format) _list_entry_len = calcsize(_list_entry_format)
def _encode(self, tid_list): def _encode(self, max_tid, tid_list):
body = [pack(self._header_format, len(tid_list))] body = [pack(self._header_format, max_tid, len(tid_list))]
body.extend(tid_list) body.extend(tid_list)
return ''.join(body) return ''.join(body)
def _decode(self, body): def _decode(self, body):
offset = self._header_len offset = self._header_len
(n,) = unpack(self._header_format, body[:offset]) (max_tid, n) = unpack(self._header_format, body[:offset])
tid_list = [] tid_list = []
list_entry_format = self._list_entry_format list_entry_format = self._list_entry_format
list_entry_len = self._list_entry_len list_entry_len = self._list_entry_len
...@@ -681,7 +681,7 @@ class AnswerUnfinishedTransactions(Packet): ...@@ -681,7 +681,7 @@ class AnswerUnfinishedTransactions(Packet):
tid = unpack(list_entry_format, body[offset:next_offset])[0] tid = unpack(list_entry_format, body[offset:next_offset])[0]
offset = next_offset offset = next_offset
tid_list.append(tid) tid_list.append(tid)
return (tid_list,) return (max_tid, tid_list)
class AskObjectPresent(Packet): class AskObjectPresent(Packet):
""" """
...@@ -784,6 +784,18 @@ class AskFinishTransaction(Packet): ...@@ -784,6 +784,18 @@ class AskFinishTransaction(Packet):
oid_list.append(oid) oid_list.append(oid)
return (tid, oid_list) return (tid, oid_list)
class NotifyTransactionFinished(Packet):
"""
Notify that a transaction blocking a replication is now finished
M -> S
"""
def _encode(self, ttid, max_tid):
return _encodeTID(ttid) + _encodeTID(max_tid)
def _decode(self, body):
(ttid, max_tid) = unpack('8s8s', body)
return (ttid, max_tid)
class AnswerTransactionFinished(Packet): class AnswerTransactionFinished(Packet):
""" """
Answer when a transaction is finished. PM -> C. Answer when a transaction is finished. PM -> C.
...@@ -2044,6 +2056,10 @@ class PacketRegistry(dict): ...@@ -2044,6 +2056,10 @@ class PacketRegistry(dict):
AskCheckCurrentSerial, AskCheckCurrentSerial,
AnswerCheckCurrentSerial, AnswerCheckCurrentSerial,
) )
NotifyTransactionFinished = register(
0x003E,
NotifyTransactionFinished,
)
# build a "singleton" # build a "singleton"
Packets = PacketRegistry() Packets = PacketRegistry()
......
...@@ -585,6 +585,13 @@ class Application(object): ...@@ -585,6 +585,13 @@ class Application(object):
for storage_uuid in txn.getUUIDList(): for storage_uuid in txn.getUUIDList():
getByUUID(storage_uuid).getConnection().notify(notify_unlock) getByUUID(storage_uuid).getConnection().notify(notify_unlock)
# Notify storage that have replications blocked by this transaction
notify_finished = Packets.NotifyTransactionFinished(ttid, tid)
for storage_uuid in txn.getNotificationUUIDList():
node = getByUUID(storage_uuid)
if node is not None and node.isConnected():
node.getConnection().notify(notify_finished)
# remove transaction from manager # remove transaction from manager
self.tm.remove(transaction_node.getUUID(), ttid) self.tm.remove(transaction_node.getUUID(), ttid)
self.setLastTransaction(tid) self.setLastTransaction(tid)
......
...@@ -51,8 +51,9 @@ class ClientServiceHandler(MasterHandler): ...@@ -51,8 +51,9 @@ class ClientServiceHandler(MasterHandler):
""" """
A client request a TID, nothing is kept about it until the finish. A client request a TID, nothing is kept about it until the finish.
""" """
conn.answer(Packets.AnswerBeginTransaction(self.app.tm.begin( app = self.app
conn.getUUID(), tid))) node = app.nm.getByUUID(conn.getUUID())
conn.answer(Packets.AnswerBeginTransaction(app.tm.begin(node, tid)))
def askNewOIDs(self, conn, num_oids): def askNewOIDs(self, conn, num_oids):
app = self.app app = self.app
...@@ -84,9 +85,8 @@ class ClientServiceHandler(MasterHandler): ...@@ -84,9 +85,8 @@ class ClientServiceHandler(MasterHandler):
usable_uuid_set = set((x.getUUID() for x in identified_node_list)) usable_uuid_set = set((x.getUUID() for x in identified_node_list))
partitions = app.pt.getPartitions() partitions = app.pt.getPartitions()
peer_id = conn.getPeerId() peer_id = conn.getPeerId()
node = app.nm.getByUUID(conn.getUUID()) tid = app.tm.prepare(ttid, partitions, oid_list, usable_uuid_set,
tid = app.tm.prepare(node, ttid, partitions, oid_list, peer_id)
usable_uuid_set, peer_id)
# check if greater and foreign OID was stored # check if greater and foreign OID was stored
if app.tm.updateLastOID(oid_list): if app.tm.updateLastOID(oid_list):
......
...@@ -59,7 +59,10 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -59,7 +59,10 @@ class StorageServiceHandler(BaseServiceHandler):
conn.answer(Packets.AnswerLastIDs(loid, ltid, app.pt.getID())) conn.answer(Packets.AnswerLastIDs(loid, ltid, app.pt.getID()))
def askUnfinishedTransactions(self, conn): def askUnfinishedTransactions(self, conn):
p = Packets.AnswerUnfinishedTransactions(self.app.tm.getPendingList()) tm = self.app.tm
pending_list = tm.registerForNotification(conn.getUUID())
last_tid = tm.getLastTID()
p = Packets.AnswerUnfinishedTransactions(last_tid, pending_list)
conn.answer(p) conn.answer(p)
def answerInformationLocked(self, conn, ttid): def answerInformationLocked(self, conn, ttid):
......
...@@ -91,29 +91,31 @@ class Transaction(object): ...@@ -91,29 +91,31 @@ class Transaction(object):
""" """
A pending transaction A pending transaction
""" """
_tid = None
_msg_id = None
_oid_list = None
_prepared = False
# uuid dict hold flag to known who has locked the transaction
_uuid_set = None
_lock_wait_uuid_set = None
def __init__(self, node, ttid, tid, oid_list, uuid_list, msg_id): def __init__(self, node, ttid):
""" """
Prepare the transaction, set OIDs and UUIDs related to it Prepare the transaction, set OIDs and UUIDs related to it
""" """
self._node = node self._node = node
self._ttid = ttid self._ttid = ttid
self._tid = tid
self._oid_list = oid_list
self._msg_id = msg_id
# uuid dict hold flag to known who has locked the transaction
self._uuid_set = set(uuid_list)
self._lock_wait_uuid_set = set(uuid_list)
self._birth = time() self._birth = time()
self._prepared = False # store storage uuids that must be notified at commit
self._notification_set = set()
def __repr__(self): def __repr__(self):
return "<%s(client=%r, tid=%r, oids=%r, storages=%r, age=%.2fs) at %x>" % ( return "<%s(client=%r, tid=%r, oids=%r, storages=%r, age=%.2fs) at %x>" % (
self.__class__.__name__, self.__class__.__name__,
self._node, self._node,
dump(self._tid), dump(self._tid),
[dump(x) for x in self._oid_list], [dump(x) for x in self._oid_list or ()],
[dump(x) for x in self._uuid_set], [dump(x) for x in self._uuid_set or ()],
time() - self._birth, time() - self._birth,
id(self), id(self),
) )
...@@ -161,6 +163,19 @@ class Transaction(object): ...@@ -161,6 +163,19 @@ class Transaction(object):
""" """
return self._prepared return self._prepared
def registerForNotification(self, uuid):
"""
Register a storage node that requires a notification at commit
"""
self._notification_set.add(uuid)
def getNotificationUUIDList(self):
"""
Returns the list of storage waiting for the transaction to be
finished
"""
return list(self._notification_set)
def prepare(self, tid, oid_list, uuid_list, msg_id): def prepare(self, tid, oid_list, uuid_list, msg_id):
self._tid = tid self._tid = tid
...@@ -332,31 +347,42 @@ class TransactionManager(object): ...@@ -332,31 +347,42 @@ class TransactionManager(object):
""" """
return bool(self._ttid_dict) return bool(self._ttid_dict)
def getPendingList(self): def registerForNotification(self, uuid):
""" """
Return the list of pending transaction IDs Return the list of pending transaction IDs
""" """
return [txn.getTID() for txn in self._ttid_dict.values()] # remember that this node must be notified when pending transactions
# will be finished
for txn in self._ttid_dict.itervalues():
txn.registerForNotification(uuid)
return set(self._ttid_dict.keys())
def begin(self, uuid, tid=None): def begin(self, node, tid=None):
""" """
Generate a new TID Generate a new TID
""" """
if tid is None: if tid is None:
# No TID requested, generate a temporary one # No TID requested, generate a temporary one
tid = self.getTTID() ttid = self.getTTID()
else: else:
# Use of specific TID requested, queue it immediately and update # Use of specific TID requested, queue it immediately and update
# last TID. # last TID.
self._queue.append((uuid, tid)) self._queue.append((node.getUUID(), tid))
self.setLastTID(tid) self.setLastTID(tid)
return tid ttid = tid
txn = Transaction(node, ttid)
self._ttid_dict[ttid] = txn
self._node_dict.setdefault(node, {})[ttid] = txn
neo.lib.logging.debug('Begin %s for %s', txn, node)
return ttid
def prepare(self, node, ttid, divisor, oid_list, uuid_list, msg_id): def prepare(self, ttid, divisor, oid_list, uuid_list, msg_id):
""" """
Prepare a transaction to be finished Prepare a transaction to be finished
""" """
# XXX: not efficient but the list should be often small # XXX: not efficient but the list should be often small
txn = self._ttid_dict[ttid]
node = txn.getNode()
for _, tid in self._queue: for _, tid in self._queue:
if ttid == tid: if ttid == tid:
break break
...@@ -365,9 +391,7 @@ class TransactionManager(object): ...@@ -365,9 +391,7 @@ class TransactionManager(object):
self._queue.append((node.getUUID(), ttid)) self._queue.append((node.getUUID(), ttid))
neo.lib.logging.debug('Finish TXN %s for %s (was %s)', neo.lib.logging.debug('Finish TXN %s for %s (was %s)',
dump(tid), node, dump(ttid)) dump(tid), node, dump(ttid))
txn = Transaction(node, ttid, tid, oid_list, uuid_list, msg_id) txn.prepare(tid, oid_list, uuid_list, msg_id)
self._ttid_dict[ttid] = txn
self._node_dict.setdefault(node, {})[ttid] = txn
return tid return tid
def remove(self, uuid, ttid): def remove(self, uuid, ttid):
...@@ -383,7 +407,6 @@ class TransactionManager(object): ...@@ -383,7 +407,6 @@ class TransactionManager(object):
ttid_dict = self._ttid_dict ttid_dict = self._ttid_dict
if ttid in ttid_dict: if ttid in ttid_dict:
txn = ttid_dict[ttid] txn = ttid_dict[ttid]
tid = txn.getTID()
node = txn.getNode() node = txn.getNode()
# ...and tried to finish # ...and tried to finish
del ttid_dict[ttid] del ttid_dict[ttid]
......
...@@ -195,7 +195,7 @@ class VerificationManager(BaseServiceHandler): ...@@ -195,7 +195,7 @@ class VerificationManager(BaseServiceHandler):
# approved during recovery, there is no need to check them here. # approved during recovery, there is no need to check them here.
pass pass
def answerUnfinishedTransactions(self, conn, tid_list): def answerUnfinishedTransactions(self, conn, max_tid, tid_list):
uuid = conn.getUUID() uuid = conn.getUUID()
neo.lib.logging.info('got unfinished transactions %s from %r', neo.lib.logging.info('got unfinished transactions %s from %r',
[dump(tid) for tid in tid_list], conn) [dump(tid) for tid in tid_list], conn)
......
...@@ -27,8 +27,11 @@ class MasterOperationHandler(BaseMasterHandler): ...@@ -27,8 +27,11 @@ class MasterOperationHandler(BaseMasterHandler):
def answerLastIDs(self, conn, loid, ltid, lptid): def answerLastIDs(self, conn, loid, ltid, lptid):
self.app.replicator.setCriticalTID(ltid) self.app.replicator.setCriticalTID(ltid)
def answerUnfinishedTransactions(self, conn, tid_list): def answerUnfinishedTransactions(self, conn, max_tid, ttid_list):
self.app.replicator.setUnfinishedTIDList(tid_list) self.app.replicator.setUnfinishedTIDList(max_tid, ttid_list)
def notifyTransactionFinished(self, conn, ttid, max_tid):
self.app.replicator.transactionFinished(ttid, max_tid)
def notifyPartitionChanges(self, conn, ptid, cell_list): def notifyPartitionChanges(self, conn, ptid, cell_list):
"""This is very similar to Send Partition Table, except that """This is very similar to Send Partition Table, except that
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
import neo import neo
from neo.storage.handlers import BaseMasterHandler from neo.storage.handlers import BaseMasterHandler
from neo.lib.protocol import Packets, Errors, ProtocolError from neo.lib.protocol import Packets, Errors, ProtocolError, INVALID_TID
from neo.lib.util import dump from neo.lib.util import dump
from neo.lib.exception import OperationFailure from neo.lib.exception import OperationFailure
...@@ -62,7 +62,7 @@ class VerificationHandler(BaseMasterHandler): ...@@ -62,7 +62,7 @@ class VerificationHandler(BaseMasterHandler):
def askUnfinishedTransactions(self, conn): def askUnfinishedTransactions(self, conn):
tid_list = self.app.dm.getUnfinishedTIDList() tid_list = self.app.dm.getUnfinishedTIDList()
conn.answer(Packets.AnswerUnfinishedTransactions(tid_list)) conn.answer(Packets.AnswerUnfinishedTransactions(INVALID_TID, tid_list))
def askTransactionInformation(self, conn, tid): def askTransactionInformation(self, conn, tid):
app = self.app app = self.app
......
...@@ -26,22 +26,26 @@ from neo.lib.util import dump ...@@ -26,22 +26,26 @@ from neo.lib.util import dump
class Partition(object): class Partition(object):
"""This class abstracts the state of a partition.""" """This class abstracts the state of a partition."""
def __init__(self, offset, tid): def __init__(self, offset, max_tid, ttid_list):
self.offset = offset self._offset = offset
if tid is None: self._pending_ttid_list = ttid_list
tid = ZERO_TID # pending upper bound
self.tid = tid self._critical_tid = max_tid
def getOffset(self): def getOffset(self):
return self.offset return self._offset
def getCriticalTID(self): def getCriticalTID(self):
return self.tid return self._critical_tid
def safe(self, min_pending_tid): def transactionFinished(self, ttid, max_tid):
tid = self.tid self._pending_ttid_list.remove(ttid)
return tid is not None and ( assert max_tid is not None
min_pending_tid is None or tid < min_pending_tid) # final upper bound
self._critical_tid = max_tid
def safe(self):
return not self._pending_ttid_list
class Task(object): class Task(object):
""" """
...@@ -115,23 +119,18 @@ class Replicator(object): ...@@ -115,23 +119,18 @@ class Replicator(object):
I ask only non-existing data. """ I ask only non-existing data. """
# new_partition_set # new_partition_set
# outdated partitions for which no critical tid was asked to primary # outdated partitions for which no pending transactions was asked to
# master yet # primary master yet
# critical_tid_list
# outdated partitions for which a critical tid was asked to primary
# master, but not answered so far
# partition_dict # partition_dict
# outdated partitions (with or without a critical tid - if without, it # outdated partitions with pending transaction and temporary critical
# was asked to primary master) # tid
# current_partition # current_partition
# partition being currently synchronised # partition being currently synchronised
# current_connection # current_connection
# connection to a storage node we are replicating from # connection to a storage node we are replicating from
# waiting_for_unfinished_tids # waiting_for_unfinished_tids
# unfinished_tid_list has been asked to primary master node, but it # unfinished tids have been asked to primary master node, but it
# didn't answer yet. # didn't answer yet.
# unfinished_tid_list
# The list of unfinished TIDs known by master node.
# replication_done # replication_done
# False if we know there is something to replicate. # False if we know there is something to replicate.
# True when current_partition is replicated, or we don't know yet if # True when current_partition is replicated, or we don't know yet if
...@@ -140,13 +139,11 @@ class Replicator(object): ...@@ -140,13 +139,11 @@ class Replicator(object):
current_partition = None current_partition = None
current_connection = None current_connection = None
waiting_for_unfinished_tids = False waiting_for_unfinished_tids = False
unfinished_tid_list = None
replication_done = True replication_done = True
def __init__(self, app): def __init__(self, app):
self.app = app self.app = app
self.new_partition_set = set() self.new_partition_set = set()
self.critical_tid_list = []
self.partition_dict = {} self.partition_dict = {}
self.task_list = [] self.task_list = []
self.task_dict = {} self.task_dict = {}
...@@ -156,7 +153,6 @@ class Replicator(object): ...@@ -156,7 +153,6 @@ class Replicator(object):
When connection to primary master is lost, stop waiting for unfinished When connection to primary master is lost, stop waiting for unfinished
transactions. transactions.
""" """
self.critical_tid_list = []
self.waiting_for_unfinished_tids = False self.waiting_for_unfinished_tids = False
def storageLost(self): def storageLost(self):
...@@ -182,13 +178,11 @@ class Replicator(object): ...@@ -182,13 +178,11 @@ class Replicator(object):
self.task_dict = {} self.task_dict = {}
self.current_partition = None self.current_partition = None
self.current_connection = None self.current_connection = None
self.unfinished_tid_list = None
self.replication_done = True self.replication_done = True
def pending(self): def pending(self):
"""Return whether there is any pending partition.""" """Return whether there is any pending partition."""
return len(self.partition_dict) or len(self.new_partition_set) \ return len(self.partition_dict) or len(self.new_partition_set)
or self.critical_tid_list
def getCurrentOffset(self): def getCurrentOffset(self):
assert self.current_partition is not None assert self.current_partition is not None
...@@ -205,25 +199,21 @@ class Replicator(object): ...@@ -205,25 +199,21 @@ class Replicator(object):
def isCurrentConnection(self, conn): def isCurrentConnection(self, conn):
return self.current_connection is conn return self.current_connection is conn
def setCriticalTID(self, tid): def setUnfinishedTIDList(self, max_tid, ttid_list):
"""This is a callback from MasterOperationHandler."""
neo.lib.logging.debug('setting critical TID %s to %s', dump(tid),
', '.join([str(p) for p in self.critical_tid_list]))
for offset in self.critical_tid_list:
self.partition_dict[offset] = Partition(offset, tid)
self.critical_tid_list = []
def _askCriticalTID(self):
self.app.master_conn.ask(Packets.AskLastIDs())
self.critical_tid_list.extend(self.new_partition_set)
self.new_partition_set.clear()
def setUnfinishedTIDList(self, tid_list):
"""This is a callback from MasterOperationHandler.""" """This is a callback from MasterOperationHandler."""
neo.lib.logging.debug('setting unfinished TIDs %s', neo.lib.logging.debug('setting unfinished TTIDs %s',
','.join([dump(tid) for tid in tid_list])) ','.join([dump(tid) for tid in ttid_list]))
# all new outdated partition must wait those ttid
new_partition_set = self.new_partition_set
while new_partition_set:
offset = new_partition_set.pop()
self.partition_dict[offset] = Partition(offset, max_tid, ttid_list)
self.waiting_for_unfinished_tids = False self.waiting_for_unfinished_tids = False
self.unfinished_tid_list = tid_list
def transactionFinished(self, ttid, max_tid):
""" Callback from MasterOperationHandler """
partition = self.partition_dict[self.app.pt.getPartition(ttid)]
partition.transactionFinished(ttid, max_tid)
def _askUnfinishedTIDs(self): def _askUnfinishedTIDs(self):
conn = self.app.master_conn conn = self.app.master_conn
...@@ -283,10 +273,6 @@ class Replicator(object): ...@@ -283,10 +273,6 @@ class Replicator(object):
self.current_connection = None self.current_connection = None
def act(self): def act(self):
# If the new partition list is not empty, I must ask a critical
# TID to a primary master node.
if self.new_partition_set:
self._askCriticalTID()
if self.current_partition is not None: if self.current_partition is not None:
# Don't end replication until we have received all expected # Don't end replication until we have received all expected
...@@ -305,24 +291,22 @@ class Replicator(object): ...@@ -305,24 +291,22 @@ class Replicator(object):
neo.lib.logging.debug('waiting for unfinished tids') neo.lib.logging.debug('waiting for unfinished tids')
return return
if self.unfinished_tid_list is None: if self.new_partition_set:
# Ask pending transactions. # Ask pending transactions.
neo.lib.logging.debug('asking unfinished tids') neo.lib.logging.debug('asking unfinished tids')
self._askUnfinishedTIDs() self._askUnfinishedTIDs()
return return
# Try to select something. # Try to select something.
if len(self.unfinished_tid_list):
min_unfinished_tid = min(self.unfinished_tid_list)
else:
min_unfinished_tid = None
for partition in self.partition_dict.values(): for partition in self.partition_dict.values():
if partition.safe(min_unfinished_tid): # XXX: replication could start up to the initial critical tid, that
# is below the pending transactions, then finish when all pending
# transactions are committed.
if partition.safe():
self.current_partition = partition self.current_partition = partition
break break
else: else:
# Not yet. # Not yet.
self.unfinished_tid_list = None
neo.lib.logging.debug('not ready yet') neo.lib.logging.debug('not ready yet')
return return
......
...@@ -283,7 +283,7 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -283,7 +283,7 @@ class NeoUnitTestBase(NeoTestBase):
def checkNotifyPacket(self, conn, packet_type, packet_number=0, decode=False): def checkNotifyPacket(self, conn, packet_type, packet_number=0, decode=False):
""" Check if a notify-packet with the right type is sent """ """ Check if a notify-packet with the right type is sent """
calls = conn.mockGetNamedCalls('notify') calls = conn.mockGetNamedCalls('notify')
self.assertTrue(len(calls) > packet_number) self.assertTrue(len(calls) > packet_number, (len(calls), packet_number))
packet = calls[packet_number].getParam(0) packet = calls[packet_number].getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEquals(packet.getType(), packet_type) self.assertEquals(packet.getType(), packet_type)
...@@ -324,6 +324,9 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -324,6 +324,9 @@ class NeoUnitTestBase(NeoTestBase):
def checkNotifyUnlockInformation(self, conn, **kw): def checkNotifyUnlockInformation(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.NotifyUnlockInformation, **kw) return self.checkNotifyPacket(conn, Packets.NotifyUnlockInformation, **kw)
def checkNotifyTransactionFinished(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.NotifyTransactionFinished, **kw)
def checkRequestIdentification(self, conn, **kw): def checkRequestIdentification(self, conn, **kw):
return self.checkAskPacket(conn, Packets.RequestIdentification, **kw) return self.checkAskPacket(conn, Packets.RequestIdentification, **kw)
......
...@@ -15,12 +15,14 @@ ...@@ -15,12 +15,14 @@
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import time
import unittest import unittest
import transaction import transaction
from persistent import Persistent from persistent import Persistent
from neo.tests.functional import NEOCluster, NEOFunctionalTest from neo.tests.functional import NEOCluster, NEOFunctionalTest
from neo.lib.protocol import ClusterStates, NodeStates from neo.lib.protocol import ClusterStates, NodeStates
from ZODB.tests.StorageTestBase import zodb_pickle
from MySQLdb import ProgrammingError from MySQLdb import ProgrammingError
from MySQLdb.constants.ER import NO_SUCH_TABLE from MySQLdb.constants.ER import NO_SUCH_TABLE
...@@ -522,5 +524,46 @@ class StorageTests(NEOFunctionalTest): ...@@ -522,5 +524,46 @@ class StorageTests(NEOFunctionalTest):
self.neo.expectClusterRecovering() self.neo.expectClusterRecovering()
self.neo.expectOudatedCells(number=10) self.neo.expectOudatedCells(number=10)
def testReplicationBlockedByUnfinished(self):
# start a cluster with 1 of 2 storages and a replica
(started, stopped) = self.__setup(storage_number=2, replicas=1,
pending_number=1, partitions=10)
self.neo.expectRunning(started[0])
self.neo.expectStorageNotKnown(stopped[0])
self.neo.expectOudatedCells(number=0)
self.neo.expectClusterRunning()
self.__populate()
self.neo.expectOudatedCells(number=0)
# start a transaction that will block the end of the replication
db, conn = self.neo.getZODBConnection()
st = conn._storage
t = transaction.Transaction()
t.user = 'user'
t.description = 'desc'
oid = st.new_oid()
rev = '\0' * 8
data = zodb_pickle(PObject(42))
st.tpc_begin(t)
st.store(oid, rev, data, '', t)
# start the oudated storage
stopped[0].start()
self.neo.expectPending(stopped[0])
self.neo.neoctl.enableStorageList([stopped[0].getUUID()])
self.neo.expectRunning(stopped[0])
self.neo.expectClusterRunning()
self.neo.expectAssignedCells(started[0], 10)
self.neo.expectAssignedCells(stopped[0], 10)
# wait a bit, replication must not happen. This hack is required
# because we cannot gather informations directly from the storages
time.sleep(10)
self.neo.expectOudatedCells(number=10)
# finish the transaction, the replication must happen and finish
st.tpc_vote(t)
st.tpc_finish(t)
self.neo.expectOudatedCells(number=0, timeout=10)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -74,11 +74,12 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -74,11 +74,12 @@ class MasterClientHandlerTests(NeoUnitTestBase):
}) })
# client call it # client call it
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port) client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
client_node = self.app.nm.getByUUID(client_uuid)
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
service.askBeginTransaction(conn, None) service.askBeginTransaction(conn, None)
calls = tm.mockGetNamedCalls('begin') calls = tm.mockGetNamedCalls('begin')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(client_uuid, None) calls[0].checkArgs(client_node, None)
self.checkAnswerBeginTransaction(conn) self.checkAnswerBeginTransaction(conn)
# Client asks for a TID # Client asks for a TID
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
...@@ -86,7 +87,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -86,7 +87,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
service.askBeginTransaction(conn, tid1) service.askBeginTransaction(conn, tid1)
calls = tm.mockGetNamedCalls('begin') calls = tm.mockGetNamedCalls('begin')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(client_uuid, None) calls[0].checkArgs(client_node, None)
args = self.checkAnswerBeginTransaction(conn, decode=True) args = self.checkAnswerBeginTransaction(conn, decode=True)
self.assertEqual(args, (tid1, )) self.assertEqual(args, (tid1, ))
...@@ -142,9 +143,10 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -142,9 +143,10 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.assertTrue(self.app.isStorageReady(storage_uuid)) self.assertTrue(self.app.isStorageReady(storage_uuid))
service.askFinishTransaction(conn, ttid, oid_list) service.askFinishTransaction(conn, ttid, oid_list)
self.checkAskLockInformation(storage_conn) self.checkAskLockInformation(storage_conn)
self.assertEquals(len(self.app.tm.getPendingList()), 1) self.assertEquals(len(self.app.tm.registerForNotification(storage_uuid)), 1)
txn = self.app.tm[ttid] txn = self.app.tm[ttid]
self.assertEquals(txn.getTID(), self.app.tm.getPendingList()[0]) pending_ttid = list(self.app.tm.registerForNotification(storage_uuid))[0]
self.assertEquals(ttid, pending_ttid)
self.assertEquals(len(txn.getOIDList()), 0) self.assertEquals(len(txn.getOIDList()), 0)
self.assertEquals(len(txn.getUUIDList()), 1) self.assertEquals(len(txn.getUUIDList()), 1)
......
...@@ -101,8 +101,8 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -101,8 +101,8 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
oid_list = self.getOID(), self.getOID() oid_list = self.getOID(), self.getOID()
msg_id = 1 msg_id = 1
# register a transaction # register a transaction
ttid = self.app.tm.begin(client_1.getUUID()) ttid = self.app.tm.begin(client_1)
tid = self.app.tm.prepare(client_1, ttid, 1, oid_list, uuid_list, tid = self.app.tm.prepare(ttid, 1, oid_list, uuid_list,
msg_id) msg_id)
self.assertTrue(ttid in self.app.tm) self.assertTrue(ttid in self.app.tm)
# the first storage acknowledge the lock # the first storage acknowledge the lock
...@@ -141,17 +141,17 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -141,17 +141,17 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
# give a uuid # give a uuid
service.askUnfinishedTransactions(conn) service.askUnfinishedTransactions(conn)
packet = self.checkAnswerUnfinishedTransactions(conn) packet = self.checkAnswerUnfinishedTransactions(conn)
tid_list, = packet.decode() max_tid, tid_list = packet.decode()
self.assertEqual(tid_list, []) self.assertEqual(tid_list, [])
# create some transaction # create some transaction
node, conn = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, node, conn = self.identifyToMasterNode(node_type=NodeTypes.CLIENT,
port=self.client_port) port=self.client_port)
self.app.tm.prepare(node, self.getNextTID(), 1, ttid = self.app.tm.begin(node)
self.app.tm.prepare(ttid, 1,
[self.getOID(1)], [node.getUUID()], 1) [self.getOID(1)], [node.getUUID()], 1)
conn = self.getFakeConnection(node.getUUID(), self.storage_address) conn = self.getFakeConnection(node.getUUID(), self.storage_address)
service.askUnfinishedTransactions(conn) service.askUnfinishedTransactions(conn)
packet = self.checkAnswerUnfinishedTransactions(conn) max_tid, tid_list = self.checkAnswerUnfinishedTransactions(conn, decode=True)
(tid_list, ) = packet.decode()
self.assertEqual(len(tid_list), 1) self.assertEqual(len(tid_list), 1)
def _testWithMethod(self, method, state): def _testWithMethod(self, method, state):
...@@ -208,26 +208,28 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -208,26 +208,28 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
# Transaction 1: 2 storage nodes involved, one will die and the other # Transaction 1: 2 storage nodes involved, one will die and the other
# already answered node lock # already answered node lock
msg_id_1 = 1 msg_id_1 = 1
ttid1 = tm.begin(node1.getUUID()) ttid1 = tm.begin(client1)
tid1 = tm.prepare(client1, ttid1, 1, oid_list, tid1 = tm.prepare(ttid1, 1, oid_list,
[node1.getUUID(), node2.getUUID()], msg_id_1) [node1.getUUID(), node2.getUUID()], msg_id_1)
tm.lock(ttid1, node2.getUUID()) tm.lock(ttid1, node2.getUUID())
# storage 1 request a notification at commit
tm. registerForNotification(node1.getUUID())
self.checkNoPacketSent(cconn1) self.checkNoPacketSent(cconn1)
# Storage 1 dies # Storage 1 dies
node1.setTemporarilyDown() node1.setTemporarilyDown()
self.service.nodeLost(conn1, node1) self.service.nodeLost(conn1, node1)
# T1: last locking node lost, client receives AnswerTransactionFinished # T1: last locking node lost, client receives AnswerTransactionFinished
self.checkAnswerTransactionFinished(cconn1) self.checkAnswerTransactionFinished(cconn1)
self.checkNotifyTransactionFinished(conn1)
self.checkNotifyUnlockInformation(conn2) self.checkNotifyUnlockInformation(conn2)
self.checkNoPacketSent(conn1)
# ...and notifications are sent to other clients # ...and notifications are sent to other clients
self.checkInvalidateObjects(cconn2) self.checkInvalidateObjects(cconn2)
self.checkInvalidateObjects(cconn3) self.checkInvalidateObjects(cconn3)
# Transaction 2: 2 storage nodes involved, one will die # Transaction 2: 2 storage nodes involved, one will die
msg_id_2 = 2 msg_id_2 = 2
ttid2 = tm.begin(node1.getUUID()) ttid2 = tm.begin(node1)
tid2 = tm.prepare(client2, ttid2, 1, oid_list, tid2 = tm.prepare(ttid2, 1, oid_list,
[node1.getUUID(), node2.getUUID()], msg_id_2) [node1.getUUID(), node2.getUUID()], msg_id_2)
# T2: pending locking answer, client keeps waiting # T2: pending locking answer, client keeps waiting
self.checkNoPacketSent(cconn2, check_notify=False) self.checkNoPacketSent(cconn2, check_notify=False)
...@@ -235,8 +237,8 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -235,8 +237,8 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
# Transaction 3: 1 storage node involved, which won't die # Transaction 3: 1 storage node involved, which won't die
msg_id_3 = 3 msg_id_3 = 3
ttid3 = tm.begin(node1.getUUID()) ttid3 = tm.begin(node1)
tid3 = tm.prepare(client3, ttid3, 1, oid_list, tid3 = tm.prepare(ttid3, 1, oid_list,
[node2.getUUID(), ], msg_id_3) [node2.getUUID(), ], msg_id_3)
# T3: action not significant to this transacion, so no response # T3: action not significant to this transacion, so no response
self.checkNoPacketSent(cconn3, check_notify=False) self.checkNoPacketSent(cconn3, check_notify=False)
......
...@@ -37,7 +37,7 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -37,7 +37,7 @@ class testTransactionManager(NeoUnitTestBase):
def makeNode(self, i): def makeNode(self, i):
uuid = self.makeUUID(i) uuid = self.makeUUID(i)
node = Mock({'getUUID': uuid, '__hash__': 0}) node = Mock({'getUUID': uuid, '__hash__': i, '__repr__': 'FakeNode'})
return uuid, node return uuid, node
def testTransaction(self): def testTransaction(self):
...@@ -49,7 +49,8 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -49,7 +49,8 @@ class testTransactionManager(NeoUnitTestBase):
uuid_list = (uuid1, uuid2) = [self.makeUUID(1), self.makeUUID(2)] uuid_list = (uuid1, uuid2) = [self.makeUUID(1), self.makeUUID(2)]
msg_id = 1 msg_id = 1
# create transaction object # create transaction object
txn = Transaction(node, ttid, tid, oid_list, uuid_list, msg_id) txn = Transaction(node, ttid)
txn.prepare(tid, oid_list, uuid_list, msg_id)
self.assertEqual(txn.getUUIDList(), uuid_list) self.assertEqual(txn.getUUIDList(), uuid_list)
self.assertEqual(txn.getOIDList(), oid_list) self.assertEqual(txn.getOIDList(), oid_list)
# lock nodes one by one # lock nodes one by one
...@@ -69,16 +70,16 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -69,16 +70,16 @@ class testTransactionManager(NeoUnitTestBase):
callback = Mock() callback = Mock()
txnman = TransactionManager(on_commit=callback) txnman = TransactionManager(on_commit=callback)
self.assertFalse(txnman.hasPending()) self.assertFalse(txnman.hasPending())
self.assertEqual(txnman.getPendingList(), []) self.assertEqual(txnman.registerForNotification(uuid1), set())
# begin the transaction # begin the transaction
ttid = txnman.begin(client_uuid) ttid = txnman.begin(node)
self.assertTrue(ttid is not None) self.assertTrue(ttid is not None)
self.assertFalse(txnman.hasPending()) self.assertEqual(len(txnman.registerForNotification(uuid1)), 1)
self.assertEqual(len(txnman.getPendingList()), 0) self.assertTrue(txnman.hasPending())
# prepare the transaction # prepare the transaction
tid = txnman.prepare(node, ttid, 1, oid_list, uuid_list, msg_id) tid = txnman.prepare(ttid, 1, oid_list, uuid_list, msg_id)
self.assertTrue(txnman.hasPending()) self.assertTrue(txnman.hasPending())
self.assertEqual(txnman.getPendingList()[0], tid) self.assertEqual(txnman.registerForNotification(uuid1), set([ttid]))
txn = txnman[ttid] txn = txnman[ttid]
self.assertEqual(txn.getTID(), tid) self.assertEqual(txn.getTID(), tid)
self.assertEqual(txn.getUUIDList(), list(uuid_list)) self.assertEqual(txn.getUUIDList(), list(uuid_list))
...@@ -90,30 +91,30 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -90,30 +91,30 @@ class testTransactionManager(NeoUnitTestBase):
self.assertEqual(len(callback.getNamedCalls('__call__')), 1) self.assertEqual(len(callback.getNamedCalls('__call__')), 1)
# transaction finished # transaction finished
txnman.remove(client_uuid, ttid) txnman.remove(client_uuid, ttid)
self.assertEqual(txnman.getPendingList(), []) self.assertEqual(txnman.registerForNotification(uuid1), set())
def testAbortFor(self): def testAbortFor(self):
node1 = Mock({'__hash__': 1})
node2 = Mock({'__hash__': 2})
oid_list = [self.makeOID(1), ] oid_list = [self.makeOID(1), ]
storage_1_uuid = self.makeUUID(1) storage_1_uuid, node1 = self.makeNode(1)
storage_2_uuid = self.makeUUID(2) storage_2_uuid, node2 = self.makeNode(2)
client_uuid = self.makeUUID(3) client_uuid, client = self.makeNode(3)
txnman = TransactionManager(lambda tid, txn: None) txnman = TransactionManager(lambda tid, txn: None)
# register 4 transactions made by two nodes # register 4 transactions made by two nodes
self.assertEqual(txnman.getPendingList(), []) self.assertEqual(txnman.registerForNotification(storage_1_uuid), set())
ttid1 = txnman.begin(client_uuid) ttid1 = txnman.begin(client)
tid1 = txnman.prepare(node1, ttid1, 1, oid_list, [storage_1_uuid], 1) tid1 = txnman.prepare(ttid1, 1, oid_list, [storage_1_uuid], 1)
self.assertEqual(txnman.getPendingList(), [tid1]) self.assertEqual(txnman.registerForNotification(storage_1_uuid), set([ttid1]))
# abort transactions of another node, transaction stays # abort transactions of another node, transaction stays
txnman.abortFor(node2) txnman.abortFor(node2)
self.assertEqual(txnman.getPendingList(), [tid1]) self.assertEqual(txnman.registerForNotification(storage_1_uuid), set([ttid1]))
# abort transactions of requesting node, transaction is removed # abort transactions of requesting node, transaction is not removed
# because the transaction is prepared and must remains until the end of
# the 2PC
txnman.abortFor(node1) txnman.abortFor(node1)
self.assertEqual(txnman.getPendingList(), []) self.assertEqual(txnman.registerForNotification(storage_1_uuid), set([ttid1]))
self.assertFalse(txnman.hasPending()) self.assertTrue(txnman.hasPending())
# ...and the lock is available # ...and the lock is available
txnman.begin(client_uuid, self.getNextTID()) txnman.begin(client, self.getNextTID())
def test_getNextOIDList(self): def test_getNextOIDList(self):
txnman = TransactionManager(lambda tid, txn: None) txnman = TransactionManager(lambda tid, txn: None)
...@@ -141,8 +142,8 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -141,8 +142,8 @@ class testTransactionManager(NeoUnitTestBase):
# Transaction 1: 2 storage nodes involved, one will die and the other # Transaction 1: 2 storage nodes involved, one will die and the other
# already answered node lock # already answered node lock
msg_id_1 = 1 msg_id_1 = 1
ttid1 = tm.begin(client_uuid) ttid1 = tm.begin(client1)
tid1 = tm.prepare(client1, ttid1, 1, oid_list, tid1 = tm.prepare(ttid1, 1, oid_list,
[storage_1_uuid, storage_2_uuid], msg_id_1) [storage_1_uuid, storage_2_uuid], msg_id_1)
tm.lock(ttid1, storage_2_uuid) tm.lock(ttid1, storage_2_uuid)
t1 = tm[ttid1] t1 = tm[ttid1]
...@@ -155,8 +156,8 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -155,8 +156,8 @@ class testTransactionManager(NeoUnitTestBase):
# Transaction 2: 2 storage nodes involved, one will die # Transaction 2: 2 storage nodes involved, one will die
msg_id_2 = 2 msg_id_2 = 2
ttid2 = tm.begin(client_uuid) ttid2 = tm.begin(client2)
tid2 = tm.prepare(client2, ttid2, 1, oid_list, tid2 = tm.prepare(ttid2, 1, oid_list,
[storage_1_uuid, storage_2_uuid], msg_id_2) [storage_1_uuid, storage_2_uuid], msg_id_2)
t2 = tm[ttid2] t2 = tm[ttid2]
self.assertFalse(t2.locked()) self.assertFalse(t2.locked())
...@@ -169,8 +170,8 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -169,8 +170,8 @@ class testTransactionManager(NeoUnitTestBase):
# Transaction 3: 1 storage node involved, which won't die # Transaction 3: 1 storage node involved, which won't die
msg_id_3 = 3 msg_id_3 = 3
ttid3 = tm.begin(client_uuid) ttid3 = tm.begin(client3)
tid3 = tm.prepare(client3, ttid3, 1, oid_list, [storage_2_uuid, ], tid3 = tm.prepare(ttid3, 1, oid_list, [storage_2_uuid, ],
msg_id_3) msg_id_3)
t3 = tm[ttid3] t3 = tm[ttid3]
self.assertFalse(t3.locked()) self.assertFalse(t3.locked())
...@@ -213,29 +214,28 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -213,29 +214,28 @@ class testTransactionManager(NeoUnitTestBase):
strictly increasing order. strictly increasing order.
Note: this implementation might change later, to allow more paralelism. Note: this implementation might change later, to allow more paralelism.
""" """
client_uuid = self.makeUUID(3) client_uuid, client = self.makeNode(1)
tm = TransactionManager(lambda tid, txn: None) tm = TransactionManager(lambda tid, txn: None)
# With a requested TID, lock spans from begin to remove # With a requested TID, lock spans from begin to remove
ttid1 = self.getNextTID() ttid1 = self.getNextTID()
ttid2 = self.getNextTID() ttid2 = self.getNextTID()
tid1 = tm.begin(client_uuid, ttid1) tid1 = tm.begin(client, ttid1)
self.assertEqual(tid1, ttid1) self.assertEqual(tid1, ttid1)
tm.remove(client_uuid, tid1) tm.remove(client_uuid, tid1)
# Without a requested TID, lock spans from prepare to remove only # Without a requested TID, lock spans from prepare to remove only
ttid3 = tm.begin(client_uuid) ttid3 = tm.begin(client)
ttid4 = tm.begin(client_uuid) # Doesn't raise ttid4 = tm.begin(client) # Doesn't raise
node = Mock({'getUUID': client_uuid, '__hash__': 0}) node = Mock({'getUUID': client_uuid, '__hash__': 0})
tid4 = tm.prepare(node, ttid4, 1, [], [], 0) tid4 = tm.prepare(ttid4, 1, [], [], 0)
tm.remove(client_uuid, tid4) tm.remove(client_uuid, tid4)
tm.prepare(node, ttid3, 1, [], [], 0) tm.prepare(ttid3, 1, [], [], 0)
def testClientDisconectsAfterBegin(self): def testClientDisconectsAfterBegin(self):
client1_uuid = self.makeUUID(1) client_uuid1, node1 = self.makeNode(1)
tm = TransactionManager(lambda tid, txn: None) tm = TransactionManager(lambda tid, txn: None)
tid1 = self.getNextTID() tid1 = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
tm.begin(client1_uuid, tid1) tm.begin(node1, tid1)
node1 = Mock({'getUUID': client1_uuid, '__hash__': 0})
tm.abortFor(node1) tm.abortFor(node1)
self.assertTrue(tid1 not in tm) self.assertTrue(tid1 not in tm)
...@@ -245,10 +245,10 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -245,10 +245,10 @@ class testTransactionManager(NeoUnitTestBase):
uuid2, node2 = self.makeNode(2) uuid2, node2 = self.makeNode(2)
storage_uuid = self.makeUUID(3) storage_uuid = self.makeUUID(3)
tm = TransactionManager(callback) tm = TransactionManager(callback)
ttid1 = tm.begin(uuid1) ttid1 = tm.begin(node1)
ttid2 = tm.begin(uuid2) ttid2 = tm.begin(node2)
tid1 = tm.prepare(node1, ttid1, 1, [], [storage_uuid], 0) tid1 = tm.prepare(ttid1, 1, [], [storage_uuid], 0)
tid2 = tm.prepare(node2, ttid2, 1, [], [storage_uuid], 0) tid2 = tm.prepare(ttid2, 1, [], [storage_uuid], 0)
tm.lock(ttid2, storage_uuid) tm.lock(ttid2, storage_uuid)
# txn 2 is still blocked by txn 1 # txn 2 is still blocked by txn 1
self.assertEqual(len(callback.getNamedCalls('__call__')), 0) self.assertEqual(len(callback.getNamedCalls('__call__')), 0)
......
...@@ -124,14 +124,14 @@ class MasterVerificationTests(NeoUnitTestBase): ...@@ -124,14 +124,14 @@ class MasterVerificationTests(NeoUnitTestBase):
self.assertEquals(len(self.verification._uuid_set), 0) self.assertEquals(len(self.verification._uuid_set), 0)
self.assertEquals(len(self.verification._tid_set), 0) self.assertEquals(len(self.verification._tid_set), 0)
new_tid = self.getNextTID() new_tid = self.getNextTID()
verification.answerUnfinishedTransactions(conn, [new_tid]) verification.answerUnfinishedTransactions(conn, new_tid, [new_tid])
self.assertEquals(len(self.verification._tid_set), 0) self.assertEquals(len(self.verification._tid_set), 0)
# update dict # update dict
conn = self.getFakeConnection(uuid, self.storage_address) conn = self.getFakeConnection(uuid, self.storage_address)
self.verification._uuid_set.add(uuid) self.verification._uuid_set.add(uuid)
self.assertEquals(len(self.verification._tid_set), 0) self.assertEquals(len(self.verification._tid_set), 0)
new_tid = self.getNextTID(new_tid) new_tid = self.getNextTID(new_tid)
verification.answerUnfinishedTransactions(conn, [new_tid,]) verification.answerUnfinishedTransactions(conn, new_tid, [new_tid])
self.assertTrue(uuid not in self.verification._uuid_set) self.assertTrue(uuid not in self.verification._uuid_set)
self.assertEquals(len(self.verification._tid_set), 1) self.assertEquals(len(self.verification._tid_set), 1)
self.assertTrue(new_tid in self.verification._tid_set) self.assertTrue(new_tid in self.verification._tid_set)
......
...@@ -190,11 +190,12 @@ class StorageMasterHandlerTests(NeoUnitTestBase): ...@@ -190,11 +190,12 @@ class StorageMasterHandlerTests(NeoUnitTestBase):
self.app.replicator = Mock() self.app.replicator = Mock()
self.operation.answerUnfinishedTransactions( self.operation.answerUnfinishedTransactions(
conn=conn, conn=conn,
tid_list=(INVALID_TID, ), max_tid=INVALID_TID,
ttid_list=(INVALID_TID, ),
) )
calls = self.app.replicator.mockGetNamedCalls('setUnfinishedTIDList') calls = self.app.replicator.mockGetNamedCalls('setUnfinishedTIDList')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
calls[0].checkArgs((INVALID_TID, )) calls[0].checkArgs(INVALID_TID, (INVALID_TID, ))
def test_askPack(self): def test_askPack(self):
self.app.dm = Mock({'pack': None}) self.app.dm = Mock({'pack': None})
......
...@@ -40,7 +40,6 @@ class StorageReplicatorTests(NeoUnitTestBase): ...@@ -40,7 +40,6 @@ class StorageReplicatorTests(NeoUnitTestBase):
}) })
replicator = Replicator(app) replicator = Replicator(app)
self.assertEqual(replicator.new_partition_set, set()) self.assertEqual(replicator.new_partition_set, set())
replicator.replication_done = False
replicator.populate() replicator.populate()
self.assertEqual(replicator.new_partition_set, set([0])) self.assertEqual(replicator.new_partition_set, set([0]))
...@@ -50,40 +49,32 @@ class StorageReplicatorTests(NeoUnitTestBase): ...@@ -50,40 +49,32 @@ class StorageReplicatorTests(NeoUnitTestBase):
replicator.task_dict = {'foo': 'bar'} replicator.task_dict = {'foo': 'bar'}
replicator.current_partition = 'foo' replicator.current_partition = 'foo'
replicator.current_connection = 'foo' replicator.current_connection = 'foo'
replicator.unfinished_tid_list = ['foo']
replicator.replication_done = 'foo' replicator.replication_done = 'foo'
replicator.reset() replicator.reset()
self.assertEqual(replicator.task_list, []) self.assertEqual(replicator.task_list, [])
self.assertEqual(replicator.task_dict, {}) self.assertEqual(replicator.task_dict, {})
self.assertEqual(replicator.current_partition, None) self.assertEqual(replicator.current_partition, None)
self.assertEqual(replicator.current_connection, None) self.assertEqual(replicator.current_connection, None)
self.assertEqual(replicator.unfinished_tid_list, None)
self.assertTrue(replicator.replication_done) self.assertTrue(replicator.replication_done)
def test_setCriticalTID(self): def test_setCriticalTID(self):
replicator = Replicator(None)
critical_tid = self.getNextTID() critical_tid = self.getNextTID()
partition = Partition(0, critical_tid) partition = Partition(0, critical_tid, [])
self.assertEqual(partition.getCriticalTID(), critical_tid) self.assertEqual(partition.getCriticalTID(), critical_tid)
self.assertEqual(partition.getOffset(), 0)
def test_setUnfinishedTIDList(self):
replicator = Replicator(None)
replicator.waiting_for_unfinished_tids = True
assert replicator.unfinished_tid_list is None, \
replicator.unfinished_tid_list
tid_list = [self.getNextTID(), ]
replicator.setUnfinishedTIDList(tid_list)
self.assertEqual(replicator.unfinished_tid_list, tid_list)
self.assertFalse(replicator.waiting_for_unfinished_tids)
def test_act(self): def test_act(self):
# Also tests "pending" # Also tests "pending"
uuid = self.getNewUUID() uuid = self.getNewUUID()
master_uuid = self.getNewUUID() master_uuid = self.getNewUUID()
bad_unfinished_tid = self.getNextTID() critical_tid_0 = self.getNextTID()
critical_tid = self.getNextTID() critical_tid_1 = self.getNextTID()
unfinished_tid = self.getNextTID() critical_tid_2 = self.getNextTID()
unfinished_ttid_1 = self.getOID(1)
unfinished_ttid_2 = self.getOID(2)
app = Mock() app = Mock()
app.server = ('127.0.0.1', 10000)
app.name = 'fake cluster'
app.em = Mock({ app.em = Mock({
'register': None, 'register': None,
}) })
...@@ -105,6 +96,7 @@ class StorageReplicatorTests(NeoUnitTestBase): ...@@ -105,6 +96,7 @@ class StorageReplicatorTests(NeoUnitTestBase):
app.pt = Mock({ app.pt = Mock({
'getCellList': [running_cell, unknown_cell], 'getCellList': [running_cell, unknown_cell],
'getOutdatedOffsetListFor': [0], 'getOutdatedOffsetListFor': [0],
'getPartition': 0,
}) })
node_conn_handler = Mock({ node_conn_handler = Mock({
'startReplication': None, 'startReplication': None,
...@@ -119,37 +111,28 @@ class StorageReplicatorTests(NeoUnitTestBase): ...@@ -119,37 +111,28 @@ class StorageReplicatorTests(NeoUnitTestBase):
app.master_conn = self.getFakeConnection(uuid=master_uuid) app.master_conn = self.getFakeConnection(uuid=master_uuid)
self.assertTrue(replicator.pending()) self.assertTrue(replicator.pending())
replicator.act() replicator.act()
# ask last IDs to infer critical_tid and unfinished tids # ask unfinished tids
act() act()
last_ids, unfinished_tids = [x.getParam(0) for x in \ unfinished_tids = app.master_conn.mockGetNamedCalls('ask')[0].getParam(0)
app.master_conn.mockGetNamedCalls('ask')] self.assertTrue(replicator.new_partition_set)
self.assertEqual(last_ids.getType(), Packets.AskLastIDs) self.assertEqual(unfinished_tids.getType(), Packets.AskUnfinishedTransactions)
self.assertFalse(replicator.new_partition_set)
self.assertEqual(unfinished_tids.getType(),
Packets.AskUnfinishedTransactions)
self.assertTrue(replicator.waiting_for_unfinished_tids) self.assertTrue(replicator.waiting_for_unfinished_tids)
# nothing happens until waiting_for_unfinished_tids becomes False # nothing happens until waiting_for_unfinished_tids becomes False
act() act()
self.checkNoPacketSent(app.master_conn) self.checkNoPacketSent(app.master_conn)
self.assertTrue(replicator.waiting_for_unfinished_tids) self.assertTrue(replicator.waiting_for_unfinished_tids)
# Send answers (garanteed to happen in this order)
replicator.setCriticalTID(critical_tid)
act()
self.checkNoPacketSent(app.master_conn)
self.assertTrue(replicator.waiting_for_unfinished_tids)
# first time, there is an unfinished tid before critical tid, # first time, there is an unfinished tid before critical tid,
# replication cannot start, and unfinished TIDs are asked again # replication cannot start, and unfinished TIDs are asked again
replicator.setUnfinishedTIDList([unfinished_tid, bad_unfinished_tid]) replicator.setUnfinishedTIDList(critical_tid_0,
[unfinished_ttid_1, unfinished_ttid_2])
self.assertFalse(replicator.waiting_for_unfinished_tids) self.assertFalse(replicator.waiting_for_unfinished_tids)
# Note: detection that nothing can be replicated happens on first call # Note: detection that nothing can be replicated happens on first call
# and unfinished tids are asked again on second call. This is ok, but # and unfinished tids are asked again on second call. This is ok, but
# might change, so just call twice. # might change, so just call twice.
act() act()
replicator.transactionFinished(unfinished_ttid_1, critical_tid_1)
act() act()
self.checkAskPacket(app.master_conn, Packets.AskUnfinishedTransactions) replicator.transactionFinished(unfinished_ttid_2, critical_tid_2)
self.assertTrue(replicator.waiting_for_unfinished_tids)
# this time, critical tid check should be satisfied
replicator.setUnfinishedTIDList([unfinished_tid, ])
replicator.current_connection = node_conn replicator.current_connection = node_conn
act() act()
self.assertEqual(replicator.current_partition, self.assertEqual(replicator.current_partition,
...@@ -174,8 +157,6 @@ class StorageReplicatorTests(NeoUnitTestBase): ...@@ -174,8 +157,6 @@ class StorageReplicatorTests(NeoUnitTestBase):
'isPending': False, 'isPending': False,
}) })
act() act()
# unfinished tid list will not be asked again
self.assertTrue(replicator.unfinished_tid_list)
# also, replication is over # also, replication is over
self.assertFalse(replicator.pending()) self.assertFalse(replicator.pending())
......
...@@ -161,7 +161,7 @@ class StorageVerificationHandlerTests(NeoUnitTestBase): ...@@ -161,7 +161,7 @@ class StorageVerificationHandlerTests(NeoUnitTestBase):
}) })
conn = self.getMasterConnection() conn = self.getMasterConnection()
self.verification.askUnfinishedTransactions(conn) self.verification.askUnfinishedTransactions(conn)
(tid_list, ) = self.checkAnswerUnfinishedTransactions(conn, decode=True) (max_tid, tid_list) = self.checkAnswerUnfinishedTransactions(conn, decode=True)
self.assertEqual(len(tid_list), 0) self.assertEqual(len(tid_list), 0)
call_list = self.app.dm.mockGetNamedCalls('getUnfinishedTIDList') call_list = self.app.dm.mockGetNamedCalls('getUnfinishedTIDList')
self.assertEqual(len(call_list), 1) self.assertEqual(len(call_list), 1)
...@@ -173,7 +173,7 @@ class StorageVerificationHandlerTests(NeoUnitTestBase): ...@@ -173,7 +173,7 @@ class StorageVerificationHandlerTests(NeoUnitTestBase):
}) })
conn = self.getMasterConnection() conn = self.getMasterConnection()
self.verification.askUnfinishedTransactions(conn) self.verification.askUnfinishedTransactions(conn)
(tid_list, ) = self.checkAnswerUnfinishedTransactions(conn, decode=True) (max_tid, tid_list) = self.checkAnswerUnfinishedTransactions(conn, decode=True)
self.assertEqual(len(tid_list), 1) self.assertEqual(len(tid_list), 1)
self.assertEqual(u64(tid_list[0]), 4) self.assertEqual(u64(tid_list[0]), 4)
......
...@@ -195,13 +195,15 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -195,13 +195,15 @@ class ProtocolTests(NeoUnitTestBase):
self.assertEqual(p.decode(), ()) self.assertEqual(p.decode(), ())
def test_27_answerUnfinishedTransaction(self): def test_27_answerUnfinishedTransaction(self):
tid = self.getNextTID()
tid1 = self.getNextTID() tid1 = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
tid3 = self.getNextTID() tid3 = self.getNextTID()
tid4 = self.getNextTID() tid4 = self.getNextTID()
tid_list = [tid1, tid2, tid3, tid4] tid_list = [tid1, tid2, tid3, tid4]
p = Packets.AnswerUnfinishedTransactions(tid_list) p = Packets.AnswerUnfinishedTransactions(tid, tid_list)
p_tid_list = p.decode()[0] p_tid, p_tid_list = p.decode()
self.assertEqual(p_tid, tid)
self.assertEqual(p_tid_list, tid_list) self.assertEqual(p_tid_list, tid_list)
def test_28_askObjectPresent(self): def test_28_askObjectPresent(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment