Commit 10a081bc authored by Grégory Wisniewski's avatar Grégory Wisniewski

Make checkSerialRange and checkTIDRange take critical tid into account.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2587 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 74bdb693
...@@ -348,14 +348,14 @@ class EventHandler(object): ...@@ -348,14 +348,14 @@ class EventHandler(object):
def answerPack(self, conn, status): def answerPack(self, conn, status):
raise UnexpectedPacketError raise UnexpectedPacketError
def askCheckTIDRange(self, conn, min_tid, length, partition): def askCheckTIDRange(self, conn, min_tid, max_tid, length, partition):
raise UnexpectedPacketError raise UnexpectedPacketError
def answerCheckTIDRange(self, conn, min_tid, length, count, tid_checksum, def answerCheckTIDRange(self, conn, min_tid, length, count, tid_checksum,
max_tid): max_tid):
raise UnexpectedPacketError raise UnexpectedPacketError
def askCheckSerialRange(self, conn, min_oid, min_serial, length, def askCheckSerialRange(self, conn, min_oid, min_serial, max_tid, length,
partition): partition):
raise UnexpectedPacketError raise UnexpectedPacketError
......
...@@ -1679,13 +1679,14 @@ class AskCheckTIDRange(Packet): ...@@ -1679,13 +1679,14 @@ class AskCheckTIDRange(Packet):
reference node. reference node.
S -> S S -> S
""" """
_header_format = '!8sLL' _header_format = '!8s8sLL'
def _encode(self, min_tid, length, partition): def _encode(self, min_tid, max_tid, length, partition):
return pack(self._header_format, min_tid, length, partition) return pack(self._header_format, min_tid, max_tid, length, partition)
def _decode(self, body): def _decode(self, body):
return unpack(self._header_format, body) # min_tid, length, partition # min_tid, max_tid, length, partition
return unpack(self._header_format, body)
class AnswerCheckTIDRange(Packet): class AnswerCheckTIDRange(Packet):
""" """
...@@ -1710,14 +1711,14 @@ class AskCheckSerialRange(Packet): ...@@ -1710,14 +1711,14 @@ class AskCheckSerialRange(Packet):
reference node. reference node.
S -> S S -> S
""" """
_header_format = '!8s8sLL' _header_format = '!8s8s8sLL'
def _encode(self, min_oid, min_serial, length, partition): def _encode(self, min_oid, min_serial, max_tid, length, partition):
return pack(self._header_format, min_oid, min_serial, length, return pack(self._header_format, min_oid, min_serial, max_tid, length,
partition) partition)
def _decode(self, body): def _decode(self, body):
# min_oid, min_serial, length, partition # min_oid, min_serial, max_tid, length, partition
return unpack(self._header_format, body) return unpack(self._header_format, body)
class AnswerCheckSerialRange(Packet): class AnswerCheckSerialRange(Packet):
......
...@@ -684,13 +684,15 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -684,13 +684,15 @@ class BTreeDatabaseManager(DatabaseManager):
return not tserial return not tserial
batchDelete(self._obj, obj_callback, recycle_subtrees=True) batchDelete(self._obj, obj_callback, recycle_subtrees=True)
def checkTIDRange(self, min_tid, length, num_partitions, partition): def checkTIDRange(self, min_tid, max_tid, length, num_partitions, partition):
# XXX: XOR is a lame checksum # XXX: XOR is a lame checksum
count = 0 count = 0
tid_checksum = 0 tid_checksum = 0
tid = 0 tid = 0
upper_bound = util.u64(max_tid)
max_tid = 0 max_tid = 0
for tid in safeIter(self._trans.keys, min=util.u64(min_tid)): for tid in safeIter(self._trans.keys, min=util.u64(min_tid),
max=upper_bound):
if tid % num_partitions == partition: if tid % num_partitions == partition:
if count >= length: if count >= length:
break break
...@@ -699,8 +701,8 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -699,8 +701,8 @@ class BTreeDatabaseManager(DatabaseManager):
count += 1 count += 1
return count, tid_checksum, util.p64(max_tid) return count, tid_checksum, util.p64(max_tid)
def checkSerialRange(self, min_oid, min_serial, length, num_partitions, def checkSerialRange(self, min_oid, min_serial, max_tid, length,
partition): num_partitions, partition):
# XXX: XOR is a lame checksum # XXX: XOR is a lame checksum
u64 = util.u64 u64 = util.u64
p64 = util.p64 p64 = util.p64
...@@ -712,7 +714,8 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -712,7 +714,8 @@ class BTreeDatabaseManager(DatabaseManager):
if oid % num_partitions == partition: if oid % num_partitions == partition:
if oid == min_oid: if oid == min_oid:
try: try:
serial_iter = tserial.keys(min=u64(min_serial)) serial_iter = tserial.keys(min=u64(min_serial),
max=u64(max_tid))
except ValueError: except ValueError:
continue continue
else: else:
......
...@@ -459,7 +459,7 @@ class DatabaseManager(object): ...@@ -459,7 +459,7 @@ class DatabaseManager(object):
""" """
raise NotImplementedError raise NotImplementedError
def checkTIDRange(self, min_tid, length, num_partitions, partition): def checkTIDRange(self, min_tid, max_tid, length, num_partitions, partition):
""" """
Generate a diggest from transaction list. Generate a diggest from transaction list.
min_tid (packed) min_tid (packed)
...@@ -478,8 +478,8 @@ class DatabaseManager(object): ...@@ -478,8 +478,8 @@ class DatabaseManager(object):
""" """
raise NotImplementedError raise NotImplementedError
def checkSerialRange(self, min_oid, min_serial, length, num_partitions, def checkSerialRange(self, min_oid, min_serial, max_tid, length,
partition): num_partitions, partition):
""" """
Generate a diggest from object list. Generate a diggest from object list.
min_oid (packed) min_oid (packed)
......
...@@ -825,17 +825,19 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -825,17 +825,19 @@ class MySQLDatabaseManager(DatabaseManager):
raise raise
self.commit() self.commit()
def checkTIDRange(self, min_tid, length, num_partitions, partition): def checkTIDRange(self, min_tid, max_tid, length, num_partitions, partition):
# XXX: XOR is a lame checksum # XXX: XOR is a lame checksum
count, tid_checksum, max_tid = self.query('SELECT COUNT(*), ' count, tid_checksum, max_tid = self.query('SELECT COUNT(*), '
'BIT_XOR(tid), MAX(tid) FROM (' 'BIT_XOR(tid), MAX(tid) FROM ('
'SELECT tid FROM trans ' 'SELECT tid FROM trans '
'WHERE partition = %(partition)s ' 'WHERE partition = %(partition)s '
'AND tid >= %(min_tid)d ' 'AND tid >= %(min_tid)d '
'AND tid <= %(max_tid)d '
'ORDER BY tid ASC LIMIT %(length)d' 'ORDER BY tid ASC LIMIT %(length)d'
') AS foo' % { ') AS foo' % {
'partition': partition, 'partition': partition,
'min_tid': util.u64(min_tid), 'min_tid': util.u64(min_tid),
'max_tid': util.u64(max_tid),
'length': length, 'length': length,
})[0] })[0]
if count == 0: if count == 0:
...@@ -845,18 +847,20 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -845,18 +847,20 @@ class MySQLDatabaseManager(DatabaseManager):
max_tid = util.p64(max_tid) max_tid = util.p64(max_tid)
return count, tid_checksum, max_tid return count, tid_checksum, max_tid
def checkSerialRange(self, min_oid, min_serial, length, num_partitions, def checkSerialRange(self, min_oid, min_serial, max_tid, length,
partition): num_partitions, partition):
# XXX: XOR is a lame checksum # XXX: XOR is a lame checksum
u64 = util.u64 u64 = util.u64
p64 = util.p64 p64 = util.p64
r = self.query('SELECT oid, serial FROM obj_short WHERE ' r = self.query('SELECT oid, serial FROM obj_short WHERE '
'partition = %(partition)s AND ' 'partition = %(partition)s AND '
'serial <= %(max_tid)d AND '
'(oid > %(min_oid)d OR ' '(oid > %(min_oid)d OR '
'(oid = %(min_oid)d AND serial >= %(min_serial)d)) ' '(oid = %(min_oid)d AND serial >= %(min_serial)d)) '
'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % { 'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % {
'min_oid': u64(min_oid), 'min_oid': u64(min_oid),
'min_serial': u64(min_serial), 'min_serial': u64(min_serial),
'max_tid': u64(max_tid),
'length': length, 'length': length,
'partition': partition, 'partition': partition,
}) })
......
...@@ -96,7 +96,8 @@ class ReplicationHandler(EventHandler): ...@@ -96,7 +96,8 @@ class ReplicationHandler(EventHandler):
self.startReplication(conn) self.startReplication(conn)
def startReplication(self, conn): def startReplication(self, conn):
conn.ask(self._doAskCheckTIDRange(ZERO_TID), timeout=300) max_tid = self.app.replicator.getCurrentCriticalTID()
conn.ask(self._doAskCheckTIDRange(ZERO_TID, max_tid), timeout=300)
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def answerTIDsFrom(self, conn, tid_list): def answerTIDsFrom(self, conn, tid_list):
...@@ -118,7 +119,9 @@ class ReplicationHandler(EventHandler): ...@@ -118,7 +119,9 @@ class ReplicationHandler(EventHandler):
if len(tid_list) == MIN_RANGE_LENGTH: if len(tid_list) == MIN_RANGE_LENGTH:
# If we received fewer, we knew it before sending AskTIDsFrom, and # If we received fewer, we knew it before sending AskTIDsFrom, and
# we should have finished TID replication at that time. # we should have finished TID replication at that time.
ask(self._doAskCheckTIDRange(add64(tid_list[-1], 1), RANGE_LENGTH)) max_tid = self.app.replicator.getCurrentCriticalTID()
ask(self._doAskCheckTIDRange(add64(tid_list[-1], 1), max_tid,
RANGE_LENGTH))
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def answerTransactionInformation(self, conn, tid, def answerTransactionInformation(self, conn, tid,
...@@ -160,8 +163,9 @@ class ReplicationHandler(EventHandler): ...@@ -160,8 +163,9 @@ class ReplicationHandler(EventHandler):
if not app.dm.objectPresent(oid, serial): if not app.dm.objectPresent(oid, serial):
ask(Packets.AskObject(oid, serial, None), timeout=300) ask(Packets.AskObject(oid, serial, None), timeout=300)
if sum((len(x) for x in object_dict.itervalues())) == MIN_RANGE_LENGTH: if sum((len(x) for x in object_dict.itervalues())) == MIN_RANGE_LENGTH:
max_tid = self.app.replicator.getCurrentCriticalTID()
ask(self._doAskCheckSerialRange(max_oid, add64(max_serial, 1), ask(self._doAskCheckSerialRange(max_oid, add64(max_serial, 1),
RANGE_LENGTH)) max_tid, RANGE_LENGTH))
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def answerObject(self, conn, oid, serial_start, def answerObject(self, conn, oid, serial_start,
...@@ -173,17 +177,19 @@ class ReplicationHandler(EventHandler): ...@@ -173,17 +177,19 @@ class ReplicationHandler(EventHandler):
del obj del obj
del data del data
def _doAskCheckSerialRange(self, min_oid, min_tid, length=RANGE_LENGTH): def _doAskCheckSerialRange(self, min_oid, min_tid, max_tid,
length=RANGE_LENGTH):
replicator = self.app.replicator replicator = self.app.replicator
partition = replicator.getCurrentRID() partition = replicator.getCurrentRID()
replicator.checkSerialRange(min_oid, min_tid, length, partition) check_args = (min_oid, min_tid, max_tid, length, partition)
return Packets.AskCheckSerialRange(min_oid, min_tid, length, partition) replicator.checkSerialRange(*check_args)
return Packets.AskCheckSerialRange(*check_args)
def _doAskCheckTIDRange(self, min_tid, length=RANGE_LENGTH): def _doAskCheckTIDRange(self, min_tid, max_tid, length=RANGE_LENGTH):
replicator = self.app.replicator replicator = self.app.replicator
partition = replicator.getCurrentRID() partition = replicator.getCurrentRID()
replicator.checkTIDRange(min_tid, length, partition) replicator.checkTIDRange(min_tid, max_tid, length, partition)
return Packets.AskCheckTIDRange(min_tid, length, partition) return Packets.AskCheckTIDRange(min_tid, max_tid, length, partition)
def _doAskTIDsFrom(self, min_tid, length): def _doAskTIDsFrom(self, min_tid, length):
replicator = self.app.replicator replicator = self.app.replicator
...@@ -269,7 +275,8 @@ class ReplicationHandler(EventHandler): ...@@ -269,7 +275,8 @@ class ReplicationHandler(EventHandler):
action = CHECK_DONE action = CHECK_DONE
params = (next_tid, ) params = (next_tid, )
else: else:
ask(self._doAskCheckTIDRange(min_tid, count)) max_tid = replicator.getCurrentCriticalTID()
ask(self._doAskCheckTIDRange(min_tid, max_tid, count))
if action == CHECK_DONE: if action == CHECK_DONE:
# Delete all transactions we might have which are beyond what peer # Delete all transactions we might have which are beyond what peer
# knows. # knows.
...@@ -278,7 +285,8 @@ class ReplicationHandler(EventHandler): ...@@ -278,7 +285,8 @@ class ReplicationHandler(EventHandler):
replicator.getCurrentRID(), last_tid) replicator.getCurrentRID(), last_tid)
# If no more TID, a replication of transactions is finished. # If no more TID, a replication of transactions is finished.
# So start to replicate objects now. # So start to replicate objects now.
ask(self._doAskCheckSerialRange(ZERO_OID, ZERO_TID)) max_tid = replicator.getCurrentCriticalTID()
ask(self._doAskCheckSerialRange(ZERO_OID, ZERO_TID, max_tid))
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def answerCheckSerialRange(self, conn, min_oid, min_serial, length, count, def answerCheckSerialRange(self, conn, min_oid, min_serial, length, count,
...@@ -299,7 +307,8 @@ class ReplicationHandler(EventHandler): ...@@ -299,7 +307,8 @@ class ReplicationHandler(EventHandler):
params = (next_params, ) params = (next_params, )
if action == CHECK_CHUNK: if action == CHECK_CHUNK:
((min_oid, min_serial), count) = params ((min_oid, min_serial), count) = params
ask(self._doAskCheckSerialRange(min_oid, min_serial, count)) max_tid = replicator.getCurrentCriticalTID()
ask(self._doAskCheckSerialRange(min_oid, min_serial, max_tid, count))
if action == CHECK_DONE: if action == CHECK_DONE:
# Delete all objects we might have which are beyond what peer # Delete all objects we might have which are beyond what peer
# knows. # knows.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from neo.storage.handlers import BaseClientAndStorageOperationHandler from neo.storage.handlers import BaseClientAndStorageOperationHandler
from neo.protocol import Packets from neo.protocol import Packets, MAX_TID
class StorageOperationHandler(BaseClientAndStorageOperationHandler): class StorageOperationHandler(BaseClientAndStorageOperationHandler):
...@@ -44,18 +44,18 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -44,18 +44,18 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler):
length, app.pt.getPartitions(), partition) length, app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerObjectHistoryFrom(object_dict)) conn.answer(Packets.AnswerObjectHistoryFrom(object_dict))
def askCheckTIDRange(self, conn, min_tid, length, partition): def askCheckTIDRange(self, conn, min_tid, max_tid, length, partition):
app = self.app app = self.app
count, tid_checksum, max_tid = app.dm.checkTIDRange(min_tid, length, count, tid_checksum, max_tid = app.dm.checkTIDRange(min_tid, max_tid,
app.pt.getPartitions(), partition) length, app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerCheckTIDRange(min_tid, length, count, conn.answer(Packets.AnswerCheckTIDRange(min_tid, length,
tid_checksum, max_tid)) count, tid_checksum, max_tid))
def askCheckSerialRange(self, conn, min_oid, min_serial, length, def askCheckSerialRange(self, conn, min_oid, min_serial, max_tid, length,
partition): partition):
app = self.app app = self.app
count, oid_checksum, max_oid, serial_checksum, max_serial = \ count, oid_checksum, max_oid, serial_checksum, max_serial = \
app.dm.checkSerialRange(min_oid, min_serial, length, app.dm.checkSerialRange(min_oid, min_serial, max_tid, length,
app.pt.getPartitions(), partition) app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerCheckSerialRange(min_oid, min_serial, length, conn.answer(Packets.AnswerCheckSerialRange(min_oid, min_serial, length,
count, oid_checksum, max_oid, serial_checksum, max_serial)) count, oid_checksum, max_oid, serial_checksum, max_serial))
......
...@@ -366,15 +366,16 @@ class Replicator(object): ...@@ -366,15 +366,16 @@ class Replicator(object):
task.process() task.process()
self.task_list = [] self.task_list = []
def checkTIDRange(self, min_tid, length, partition): def checkTIDRange(self, min_tid, max_tid, length, partition):
app = self.app app = self.app
self._addTask(('TID', min_tid, length), app.dm.checkTIDRange, self._addTask(('TID', min_tid, length), app.dm.checkTIDRange,
(min_tid, length, app.pt.getPartitions(), partition)) (min_tid, max_tid, length, app.pt.getPartitions(), partition))
def checkSerialRange(self, min_oid, min_serial, length, partition): def checkSerialRange(self, min_oid, min_serial, max_tid, length,
partition):
app = self.app app = self.app
self._addTask(('Serial', min_oid, min_serial, length), self._addTask(('Serial', min_oid, min_serial, length),
app.dm.checkSerialRange, (min_oid, min_serial, length, app.dm.checkSerialRange, (min_oid, min_serial, max_tid, length,
app.pt.getPartitions(), partition)) app.pt.getPartitions(), partition))
def getTIDsFrom(self, min_tid, max_tid, length, partition): def getTIDsFrom(self, min_tid, max_tid, length, partition):
......
...@@ -108,14 +108,14 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -108,14 +108,14 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
return FakeApp return FakeApp
def _checkReplicationStarted(self, conn, rid, replicator): def _checkReplicationStarted(self, conn, rid, replicator):
min_tid, length, partition = self.checkAskPacket(conn, min_tid, max_tid, length, partition = self.checkAskPacket(conn,
Packets.AskCheckTIDRange, decode=True) Packets.AskCheckTIDRange, decode=True)
self.assertEqual(min_tid, ZERO_TID) self.assertEqual(min_tid, ZERO_TID)
self.assertEqual(length, RANGE_LENGTH) self.assertEqual(length, RANGE_LENGTH)
self.assertEqual(partition, rid) self.assertEqual(partition, rid)
calls = replicator.mockGetNamedCalls('checkTIDRange') calls = replicator.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_tid, length, partition) calls[0].checkArgs(min_tid, max_tid, length, partition)
def _checkPacketTIDList(self, conn, tid_list, next_tid, app): def _checkPacketTIDList(self, conn, tid_list, next_tid, app):
packet_list = [x.getParam(0) for x in conn.mockGetNamedCalls('ask')] packet_list = [x.getParam(0) for x in conn.mockGetNamedCalls('ask')]
...@@ -294,14 +294,14 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -294,14 +294,14 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
# match. # match.
handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid) handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
# Result: go on with next chunk # Result: go on with next chunk
pmin_tid, plength, ppartition = self.checkAskPacket(conn, pmin_tid, pmax_tid, plength, ppartition = self.checkAskPacket(conn,
Packets.AskCheckTIDRange, decode=True) Packets.AskCheckTIDRange, decode=True)
self.assertEqual(pmin_tid, add64(max_tid, 1)) self.assertEqual(pmin_tid, add64(max_tid, 1))
self.assertEqual(plength, RANGE_LENGTH) self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, rid) self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkTIDRange') calls = app.replicator.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, plength, ppartition) calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition)
def test_answerCheckTIDSmallRangeIdenticalChunkWithNext(self): def test_answerCheckTIDSmallRangeIdenticalChunkWithNext(self):
min_tid = self.getNextTID() min_tid = self.getNextTID()
...@@ -318,14 +318,15 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -318,14 +318,15 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
# match. # match.
handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid) handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
# Result: go on with next chunk # Result: go on with next chunk
pmin_tid, plength, ppartition = self.checkAskPacket(conn, pmin_tid, pmax_tid, plength, ppartition = self.checkAskPacket(conn,
Packets.AskCheckTIDRange, decode=True) Packets.AskCheckTIDRange, decode=True)
self.assertEqual(pmax_tid, critical_tid)
self.assertEqual(pmin_tid, add64(max_tid, 1)) self.assertEqual(pmin_tid, add64(max_tid, 1))
self.assertEqual(plength, length / 2) self.assertEqual(plength, length / 2)
self.assertEqual(ppartition, rid) self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkTIDRange') calls = app.replicator.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, plength, ppartition) calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition)
def test_answerCheckTIDRangeIdenticalChunkAboveCriticalTID(self): def test_answerCheckTIDRangeIdenticalChunkAboveCriticalTID(self):
critical_tid = self.getNextTID() critical_tid = self.getNextTID()
...@@ -342,15 +343,15 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -342,15 +343,15 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
# match. # match.
handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid) handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
# Result: go on with object range checks # Result: go on with object range checks
pmin_oid, pmin_serial, plength, ppartition = self.checkAskPacket(conn, pmin_oid, pmin_serial, pmax_tid, plength, ppartition = \
Packets.AskCheckSerialRange, decode=True) self.checkAskPacket(conn, Packets.AskCheckSerialRange, decode=True)
self.assertEqual(pmin_oid, ZERO_OID) self.assertEqual(pmin_oid, ZERO_OID)
self.assertEqual(pmin_serial, ZERO_TID) self.assertEqual(pmin_serial, ZERO_TID)
self.assertEqual(plength, RANGE_LENGTH) self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, rid) self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange') calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition) calls[0].checkArgs(pmin_oid, pmin_serial, pmax_tid, plength, ppartition)
def test_answerCheckTIDRangeIdenticalChunkWithoutNext(self): def test_answerCheckTIDRangeIdenticalChunkWithoutNext(self):
min_tid = self.getNextTID() min_tid = self.getNextTID()
...@@ -367,15 +368,15 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -367,15 +368,15 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
handler.answerCheckTIDRange(conn, min_tid, length, length - 1, 0, handler.answerCheckTIDRange(conn, min_tid, length, length - 1, 0,
max_tid) max_tid)
# Result: go on with object range checks # Result: go on with object range checks
pmin_oid, pmin_serial, plength, ppartition = self.checkAskPacket(conn, pmin_oid, pmin_serial, pmax_tid, plength, ppartition = \
Packets.AskCheckSerialRange, decode=True) self.checkAskPacket(conn, Packets.AskCheckSerialRange, decode=True)
self.assertEqual(pmin_oid, ZERO_OID) self.assertEqual(pmin_oid, ZERO_OID)
self.assertEqual(pmin_serial, ZERO_TID) self.assertEqual(pmin_serial, ZERO_TID)
self.assertEqual(plength, RANGE_LENGTH) self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, rid) self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange') calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition) calls[0].checkArgs(pmin_oid, pmin_serial, pmax_tid, plength, ppartition)
# ...and delete partition tail # ...and delete partition tail
calls = app.dm.mockGetNamedCalls('deleteTransactionsAbove') calls = app.dm.mockGetNamedCalls('deleteTransactionsAbove')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
...@@ -396,14 +397,14 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -396,14 +397,14 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
# Peer has different data # Peer has different data
handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid) handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
# Result: ask again, length halved # Result: ask again, length halved
pmin_tid, plength, ppartition = self.checkAskPacket(conn, pmin_tid, pmax_tid, plength, ppartition = self.checkAskPacket(conn,
Packets.AskCheckTIDRange, decode=True) Packets.AskCheckTIDRange, decode=True)
self.assertEqual(pmin_tid, min_tid) self.assertEqual(pmin_tid, min_tid)
self.assertEqual(plength, length / 2) self.assertEqual(plength, length / 2)
self.assertEqual(ppartition, rid) self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkTIDRange') calls = app.replicator.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, plength, ppartition) calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition)
def test_answerCheckTIDRangeDifferentSmallChunkWithNext(self): def test_answerCheckTIDRangeDifferentSmallChunkWithNext(self):
min_tid = self.getNextTID() min_tid = self.getNextTID()
...@@ -474,15 +475,15 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -474,15 +475,15 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
handler.answerCheckSerialRange(conn, min_oid, min_serial, length, handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length, 0, max_oid, 1, max_serial) length, 0, max_oid, 1, max_serial)
# Result: go on with next chunk # Result: go on with next chunk
pmin_oid, pmin_serial, plength, ppartition = self.checkAskPacket(conn, pmin_oid, pmin_serial, pmax_tid, plength, ppartition = \
Packets.AskCheckSerialRange, decode=True) self.checkAskPacket(conn, Packets.AskCheckSerialRange, decode=True)
self.assertEqual(pmin_oid, max_oid) self.assertEqual(pmin_oid, max_oid)
self.assertEqual(pmin_serial, add64(max_serial, 1)) self.assertEqual(pmin_serial, add64(max_serial, 1))
self.assertEqual(plength, RANGE_LENGTH) self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, rid) self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange') calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition) calls[0].checkArgs(pmin_oid, pmin_serial, pmax_tid, plength, ppartition)
def test_answerCheckSerialSmallRangeIdenticalChunkWithNext(self): def test_answerCheckSerialSmallRangeIdenticalChunkWithNext(self):
min_oid = self.getOID(1) min_oid = self.getOID(1)
...@@ -499,15 +500,15 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -499,15 +500,15 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
handler.answerCheckSerialRange(conn, min_oid, min_serial, length, handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length, 0, max_oid, 1, max_serial) length, 0, max_oid, 1, max_serial)
# Result: go on with next chunk # Result: go on with next chunk
pmin_oid, pmin_serial, plength, ppartition = self.checkAskPacket(conn, pmin_oid, pmin_serial, pmax_tid, plength, ppartition = \
Packets.AskCheckSerialRange, decode=True) self.checkAskPacket(conn, Packets.AskCheckSerialRange, decode=True)
self.assertEqual(pmin_oid, max_oid) self.assertEqual(pmin_oid, max_oid)
self.assertEqual(pmin_serial, add64(max_serial, 1)) self.assertEqual(pmin_serial, add64(max_serial, 1))
self.assertEqual(plength, length / 2) self.assertEqual(plength, length / 2)
self.assertEqual(ppartition, rid) self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange') calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition) calls[0].checkArgs(pmin_oid, pmin_serial, pmax_tid, plength, ppartition)
def test_answerCheckSerialRangeIdenticalChunkWithoutNext(self): def test_answerCheckSerialRangeIdenticalChunkWithoutNext(self):
min_oid = self.getOID(1) min_oid = self.getOID(1)
...@@ -547,15 +548,15 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -547,15 +548,15 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
handler.answerCheckSerialRange(conn, min_oid, min_serial, length, handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length, 0, max_oid, 1, max_serial) length, 0, max_oid, 1, max_serial)
# Result: ask again, length halved # Result: ask again, length halved
pmin_oid, pmin_serial, plength, ppartition = self.checkAskPacket(conn, pmin_oid, pmin_serial, pmax_tid, plength, ppartition = \
Packets.AskCheckSerialRange, decode=True) self.checkAskPacket(conn, Packets.AskCheckSerialRange, decode=True)
self.assertEqual(pmin_oid, min_oid) self.assertEqual(pmin_oid, min_oid)
self.assertEqual(pmin_serial, min_serial) self.assertEqual(pmin_serial, min_serial)
self.assertEqual(plength, length / 2) self.assertEqual(plength, length / 2)
self.assertEqual(ppartition, rid) self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange') calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition) calls[0].checkArgs(pmin_oid, pmin_serial, pmax_tid, plength, ppartition)
def test_answerCheckSerialRangeDifferentSmallChunkWithNext(self): def test_answerCheckSerialRangeDifferentSmallChunkWithNext(self):
min_oid = self.getOID(1) min_oid = self.getOID(1)
......
...@@ -157,10 +157,10 @@ class StorageStorageHandlerTests(NeoUnitTestBase): ...@@ -157,10 +157,10 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
self.app.dm = Mock({'checkTIDRange': (count, tid_checksum, max_tid)}) self.app.dm = Mock({'checkTIDRange': (count, tid_checksum, max_tid)})
self.app.pt = Mock({'getPartitions': num_partitions}) self.app.pt = Mock({'getPartitions': num_partitions})
conn = self.getFakeConnection() conn = self.getFakeConnection()
self.operation.askCheckTIDRange(conn, min_tid, length, partition) self.operation.askCheckTIDRange(conn, min_tid, max_tid, length, partition)
calls = self.app.dm.mockGetNamedCalls('checkTIDRange') calls = self.app.dm.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_tid, length, num_partitions, partition) calls[0].checkArgs(min_tid, max_tid, length, num_partitions, partition)
pmin_tid, plength, pcount, ptid_checksum, pmax_tid = \ pmin_tid, plength, pcount, ptid_checksum, pmax_tid = \
self.checkAnswerPacket(conn, Packets.AnswerCheckTIDRange, self.checkAnswerPacket(conn, Packets.AnswerCheckTIDRange,
decode=True) decode=True)
...@@ -185,12 +185,12 @@ class StorageStorageHandlerTests(NeoUnitTestBase): ...@@ -185,12 +185,12 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
serial_checksum, max_serial)}) serial_checksum, max_serial)})
self.app.pt = Mock({'getPartitions': num_partitions}) self.app.pt = Mock({'getPartitions': num_partitions})
conn = self.getFakeConnection() conn = self.getFakeConnection()
self.operation.askCheckSerialRange(conn, min_oid, min_serial, length, self.operation.askCheckSerialRange(conn, min_oid, min_serial,
partition) max_serial, length, partition)
calls = self.app.dm.mockGetNamedCalls('checkSerialRange') calls = self.app.dm.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_oid, min_serial, length, num_partitions, calls[0].checkArgs(min_oid, min_serial, max_serial, length,
partition) num_partitions, partition)
pmin_oid, pmin_serial, plength, pcount, poid_checksum, pmax_oid, \ pmin_oid, pmin_serial, plength, pcount, poid_checksum, pmax_oid, \
pserial_checksum, pmax_serial = self.checkAnswerPacket(conn, pserial_checksum, pmax_serial = self.checkAnswerPacket(conn,
Packets.AnswerCheckSerialRange, decode=True) Packets.AnswerCheckSerialRange, decode=True)
......
...@@ -627,11 +627,13 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -627,11 +627,13 @@ class ProtocolTests(NeoUnitTestBase):
def test_AskCheckTIDRange(self): def test_AskCheckTIDRange(self):
min_tid = self.getNextTID() min_tid = self.getNextTID()
max_tid = self.getNextTID()
length = 2 length = 2
partition = 4 partition = 4
p = Packets.AskCheckTIDRange(min_tid, length, partition) p = Packets.AskCheckTIDRange(min_tid, max_tid, length, partition)
p_min_tid, p_length, p_partition = p.decode() p_min_tid, p_max_tid, p_length, p_partition = p.decode()
self.assertEqual(p_min_tid, min_tid) self.assertEqual(p_min_tid, min_tid)
self.assertEqual(p_max_tid, max_tid)
self.assertEqual(p_length, length) self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition) self.assertEqual(p_partition, partition)
...@@ -653,12 +655,15 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -653,12 +655,15 @@ class ProtocolTests(NeoUnitTestBase):
def test_AskCheckSerialRange(self): def test_AskCheckSerialRange(self):
min_oid = self.getOID(1) min_oid = self.getOID(1)
min_serial = self.getNextTID() min_serial = self.getNextTID()
max_tid = self.getNextTID()
length = 2 length = 2
partition = 4 partition = 4
p = Packets.AskCheckSerialRange(min_oid, min_serial, length, partition) p = Packets.AskCheckSerialRange(min_oid, min_serial, max_tid, length,
p_min_oid, p_min_serial, p_length, p_partition = p.decode() partition)
p_min_oid, p_min_serial, p_max_tid, p_length, p_partition = p.decode()
self.assertEqual(p_min_oid, min_oid) self.assertEqual(p_min_oid, min_oid)
self.assertEqual(p_min_serial, min_serial) self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_max_tid, max_tid)
self.assertEqual(p_length, length) self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition) self.assertEqual(p_partition, partition)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment