Commit 6959fbe6 authored by Julien Muchembled's avatar Julien Muchembled

storage: remove useless 'num_partitions' parameter from backend methods

parent 3a363b85
......@@ -125,6 +125,10 @@ class BTreeDatabaseManager(DatabaseManager):
super(BTreeDatabaseManager, self).__init__()
self.setup(reset=1)
@property
def _num_partitions(self):
return self._config['partitions']
def setup(self, reset=0):
if reset:
self._data = OOBTree()
......@@ -289,8 +293,9 @@ class BTreeDatabaseManager(DatabaseManager):
self.unlockData(checksum_list)
self._pruneData(checksum_set)
def dropPartitions(self, num_partitions, offset_list):
def dropPartitions(self, offset_list):
offset_list = frozenset(offset_list)
num_partitions = self._num_partitions
def same_partition(key, _):
return key % num_partitions in offset_list
batchDelete(self._obj, same_partition, self._objDeleterCallback)
......@@ -400,7 +405,8 @@ class BTreeDatabaseManager(DatabaseManager):
except KeyError:
pass
def deleteTransactionsAbove(self, num_partitions, partition, tid, max_tid):
def deleteTransactionsAbove(self, partition, tid, max_tid):
num_partitions = self._num_partitions
def same_partition(key, _):
return key % num_partitions == partition
batchDelete(self._trans, same_partition,
......@@ -421,13 +427,13 @@ class BTreeDatabaseManager(DatabaseManager):
if not tserial:
del obj[oid]
def deleteObjectsAbove(self, num_partitions, partition, oid, serial,
max_tid):
def deleteObjectsAbove(self, partition, oid, serial, max_tid):
obj = self._obj
u64 = util.u64
oid = u64(oid)
serial = u64(serial)
max_tid = u64(max_tid)
num_partitions = self._num_partitions
if oid % num_partitions == partition:
try:
tserial = obj[oid]
......@@ -460,20 +466,6 @@ class BTreeDatabaseManager(DatabaseManager):
result = (list(oid_list), user, desc, ext, packed)
return result
def getOIDList(self, min_oid, length, num_partitions,
partition_list):
p64 = util.p64
partition_list = frozenset(partition_list)
result = []
append = result.append
for oid in safeIter(self._obj.keys, min=min_oid):
if oid % num_partitions in partition_list:
if length == 0:
break
length -= 1
append(p64(oid))
return result
def _getObjectLength(self, oid, value_serial):
if value_serial is None:
raise CreationUndone
......@@ -521,13 +513,14 @@ class BTreeDatabaseManager(DatabaseManager):
return result
def getObjectHistoryFrom(self, min_oid, min_serial, max_serial, length,
num_partitions, partition):
partition):
u64 = util.u64
p64 = util.p64
min_oid = u64(min_oid)
min_serial = u64(min_serial)
max_serial = u64(max_serial)
result = {}
num_partitions = self._num_partitions
for oid, tserial in safeIter(self._obj.items, min=min_oid):
if oid % num_partitions == partition:
if length == 0:
......@@ -553,12 +546,13 @@ class BTreeDatabaseManager(DatabaseManager):
break
return result
def getTIDList(self, offset, length, num_partitions, partition_list):
def getTIDList(self, offset, length, partition_list):
p64 = util.p64
partition_list = frozenset(partition_list)
result = []
append = result.append
trans_iter = descKeys(self._trans)
num_partitions = self._num_partitions
while offset > 0:
tid = trans_iter.next()
if tid % num_partitions in partition_list:
......@@ -571,12 +565,12 @@ class BTreeDatabaseManager(DatabaseManager):
append(p64(tid))
return result
def getReplicationTIDList(self, min_tid, max_tid, length, num_partitions,
partition):
def getReplicationTIDList(self, min_tid, max_tid, length, partition):
p64 = util.p64
u64 = util.u64
result = []
append = result.append
num_partitions = self._num_partitions
for tid in safeIter(self._trans.keys, min=u64(min_tid), max=u64(max_tid)):
if tid % num_partitions == partition:
if length == 0:
......@@ -633,9 +627,10 @@ class BTreeDatabaseManager(DatabaseManager):
return not tserial
batchDelete(self._obj, obj_callback, self._objDeleterCallback)
def checkTIDRange(self, min_tid, max_tid, length, num_partitions, partition):
def checkTIDRange(self, min_tid, max_tid, length, partition):
if length:
tid_list = []
num_partitions = self._num_partitions
for tid in safeIter(self._trans.keys, min=util.u64(min_tid),
max=util.u64(max_tid)):
if tid % num_partitions == partition:
......@@ -648,14 +643,14 @@ class BTreeDatabaseManager(DatabaseManager):
util.p64(tid_list[-1]))
return 0, ZERO_HASH, ZERO_TID
def checkSerialRange(self, min_oid, min_serial, max_tid, length,
num_partitions, partition):
def checkSerialRange(self, min_oid, min_serial, max_tid, length, partition):
if length:
u64 = util.u64
min_oid = u64(min_oid)
max_tid = u64(max_tid)
oid_list = []
serial_list = []
num_partitions = self._num_partitions
for oid, tserial in safeIter(self._obj.items, min=min_oid):
if oid % num_partitions == partition:
try:
......
......@@ -282,7 +282,7 @@ class DatabaseManager(object):
thrown away."""
raise NotImplementedError
def dropPartitions(self, num_partitions, offset_list):
def dropPartitions(self, offset_list):
""" Drop any data of non-assigned partitions for a given UUID """
raise NotImplementedError('this method must be overriden')
......@@ -461,7 +461,7 @@ class DatabaseManager(object):
an oid list"""
raise NotImplementedError
def deleteTransactionsAbove(self, num_partitions, partition, tid, max_tid):
def deleteTransactionsAbove(self, partition, tid, max_tid):
"""Delete all transactions above given TID (inclued) in given
partition, but never above max_tid (in case transactions are committed
during replication)."""
......@@ -472,8 +472,7 @@ class DatabaseManager(object):
given oid."""
raise NotImplementedError
def deleteObjectsAbove(self, num_partitions, partition, oid, serial,
max_tid):
def deleteObjectsAbove(self, partition, oid, serial, max_tid):
"""Delete all objects above given OID and serial (inclued) in given
partition, but never above max_tid (in case objects are stored during
replication)"""
......@@ -495,20 +494,19 @@ class DatabaseManager(object):
raise NotImplementedError
def getObjectHistoryFrom(self, oid, min_serial, max_serial, length,
num_partitions, partition):
partition):
"""Return a dict of length serials grouped by oid at (or above)
min_oid and min_serial and below max_serial, for given partition,
sorted in ascending order."""
raise NotImplementedError
def getTIDList(self, offset, length, num_partitions, partition_list):
def getTIDList(self, offset, length, partition_list):
"""Return a list of TIDs in ascending order from an offset,
at most the specified length. The list of partitions are passed
to filter out non-applicable TIDs."""
raise NotImplementedError
def getReplicationTIDList(self, min_tid, max_tid, length, num_partitions,
partition):
def getReplicationTIDList(self, min_tid, max_tid, length, partition):
"""Return a list of TIDs in ascending order from an initial tid value,
at most the specified length up to max_tid. The partition number is
passed to filter out non-applicable TIDs."""
......@@ -531,15 +529,13 @@ class DatabaseManager(object):
"""
raise NotImplementedError
def checkTIDRange(self, min_tid, max_tid, length, num_partitions, partition):
def checkTIDRange(self, min_tid, max_tid, length, partition):
"""
Generate a diggest from transaction list.
min_tid (packed)
TID at which verification starts.
length (int)
Maximum number of records to include in result.
num_partitions, partition (int, int)
Specifies concerned partition.
Returns a 3-tuple:
- number of records actually found
......@@ -550,8 +546,7 @@ class DatabaseManager(object):
"""
raise NotImplementedError
def checkSerialRange(self, min_oid, min_serial, max_tid, length,
num_partitions, partition):
def checkSerialRange(self, min_oid, min_serial, max_tid, length, partition):
"""
Generate a diggest from object list.
min_oid (packed)
......@@ -560,8 +555,6 @@ class DatabaseManager(object):
Serial of min_oid object at which search should start.
length
Maximum number of records to include in result.
num_partitions, partition (int, int)
Specifies concerned partition.
Returns a 5-tuple:
- number of records actually found
......
......@@ -380,7 +380,7 @@ class MySQLDatabaseManager(DatabaseManager):
def setPartitionTable(self, ptid, cell_list):
self.doSetPartitionTable(ptid, cell_list, True)
def dropPartitions(self, num_partitions, offset_list):
def dropPartitions(self, offset_list):
q = self.query
self.begin()
try:
......@@ -566,7 +566,7 @@ class MySQLDatabaseManager(DatabaseManager):
raise
self.commit()
def deleteTransactionsAbove(self, num_partitions, partition, tid, max_tid):
def deleteTransactionsAbove(self, partition, tid, max_tid):
self.begin()
try:
self.query('DELETE FROM trans WHERE partition=%(partition)d AND '
......@@ -598,8 +598,7 @@ class MySQLDatabaseManager(DatabaseManager):
raise
self.commit()
def deleteObjectsAbove(self, num_partitions, partition, oid, serial,
max_tid):
def deleteObjectsAbove(self, partition, oid, serial, max_tid):
q = self.query
u64 = util.u64
oid = u64(oid)
......@@ -675,7 +674,7 @@ class MySQLDatabaseManager(DatabaseManager):
return None
def getObjectHistoryFrom(self, min_oid, min_serial, max_serial, length,
num_partitions, partition):
partition):
q = self.query
u64 = util.u64
p64 = util.p64
......@@ -703,15 +702,14 @@ class MySQLDatabaseManager(DatabaseManager):
serial_list.append(p64(serial))
return dict((p64(x), y) for x, y in result.iteritems())
def getTIDList(self, offset, length, num_partitions, partition_list):
def getTIDList(self, offset, length, partition_list):
q = self.query
r = q("""SELECT tid FROM trans WHERE partition in (%s)
ORDER BY tid DESC LIMIT %d,%d""" \
% (','.join(map(str, partition_list)), offset, length))
return [util.p64(t[0]) for t in r]
def getReplicationTIDList(self, min_tid, max_tid, length, num_partitions,
partition):
def getReplicationTIDList(self, min_tid, max_tid, length, partition):
q = self.query
u64 = util.u64
p64 = util.p64
......@@ -794,7 +792,7 @@ class MySQLDatabaseManager(DatabaseManager):
raise
self.commit()
def checkTIDRange(self, min_tid, max_tid, length, num_partitions, partition):
def checkTIDRange(self, min_tid, max_tid, length, partition):
count, tid_checksum, max_tid = self.query(
"""SELECT COUNT(*), SHA1(GROUP_CONCAT(tid SEPARATOR ",")), MAX(tid)
FROM (SELECT tid FROM trans
......@@ -811,8 +809,7 @@ class MySQLDatabaseManager(DatabaseManager):
return count, a2b_hex(tid_checksum), util.p64(max_tid)
return 0, ZERO_HASH, ZERO_TID
def checkSerialRange(self, min_oid, min_serial, max_tid, length,
num_partitions, partition):
def checkSerialRange(self, min_oid, min_serial, max_tid, length, partition):
u64 = util.u64
# We don't ask MySQL to compute everything (like in checkTIDRange)
# because it's difficult to get the last serial _for the last oid_.
......
......@@ -98,14 +98,11 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
data_serial, ttid, unlock, time.time())
def askTIDsFrom(self, conn, min_tid, max_tid, length, partition_list):
app = self.app
getReplicationTIDList = app.dm.getReplicationTIDList
partitions = app.pt.getPartitions()
getReplicationTIDList = self.app.dm.getReplicationTIDList
tid_list = []
extend = tid_list.extend
for partition in partition_list:
extend(getReplicationTIDList(min_tid, max_tid, length,
partitions, partition))
extend(getReplicationTIDList(min_tid, max_tid, length, partition))
conn.answer(Packets.AnswerTIDsFrom(tid_list))
def askTIDs(self, conn, first, last, partition):
......@@ -120,8 +117,7 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
else:
partition_list = [partition]
tid_list = app.dm.getTIDList(first, last - first,
app.pt.getPartitions(), partition_list)
tid_list = app.dm.getTIDList(first, last - first, partition_list)
conn.answer(Packets.AnswerTIDs(tid_list))
def askObjectUndoSerial(self, conn, ttid, ltid, undone_tid, oid_list):
......
......@@ -48,9 +48,8 @@ class InitializationHandler(BaseMasterHandler):
unassigned_set.remove(offset)
# delete objects database
if unassigned_set:
neo.lib.logging.debug(
'drop data for partitions %r' % unassigned_set)
app.dm.dropPartitions(num_partitions, unassigned_set)
neo.lib.logging.debug('drop data for partitions %r', unassigned_set)
app.dm.dropPartitions(unassigned_set)
app.dm.setPartitionTable(ptid, cell_list)
......
......@@ -300,8 +300,7 @@ class ReplicationHandler(EventHandler):
" length=%s, count=%s, max_tid=%x, last_tid=%x,"
" critical_tid=%x)", offset, u64(pkt_min_tid), length, count,
u64(max_tid), u64(last_tid), u64(critical_tid))
app.dm.deleteTransactionsAbove(app.pt.getPartitions(),
offset, last_tid, critical_tid)
app.dm.deleteTransactionsAbove(offset, last_tid, critical_tid)
# If no more TID, a replication of transactions is finished.
# So start to replicate objects now.
ask(self._doAskCheckSerialRange(ZERO_OID, ZERO_TID, critical_tid))
......@@ -339,8 +338,7 @@ class ReplicationHandler(EventHandler):
offset, u64(min_oid), u64(min_serial), length, count,
u64(max_oid), u64(max_serial), u64(last_oid), u64(last_serial),
u64(max_tid))
app.dm.deleteObjectsAbove(app.pt.getPartitions(),
offset, last_oid, last_serial, max_tid)
app.dm.deleteObjectsAbove(offset, last_oid, last_serial, max_tid)
# Nothing remains, so the replication for this partition is
# finished.
replicator.setReplicationDone()
......
......@@ -34,32 +34,27 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler):
def askTIDsFrom(self, conn, min_tid, max_tid, length, partition_list):
assert len(partition_list) == 1, partition_list
partition = partition_list[0]
app = self.app
tid_list = app.dm.getReplicationTIDList(min_tid, max_tid, length,
app.pt.getPartitions(), partition)
tid_list = self.app.dm.getReplicationTIDList(min_tid, max_tid, length,
partition_list[0])
conn.answer(Packets.AnswerTIDsFrom(tid_list))
def askObjectHistoryFrom(self, conn, min_oid, min_serial, max_serial,
length, partition):
app = self.app
object_dict = app.dm.getObjectHistoryFrom(min_oid, min_serial, max_serial,
length, app.pt.getPartitions(), partition)
object_dict = self.app.dm.getObjectHistoryFrom(min_oid, min_serial,
max_serial, length, partition)
conn.answer(Packets.AnswerObjectHistoryFrom(object_dict))
def askCheckTIDRange(self, conn, min_tid, max_tid, length, partition):
app = self.app
count, tid_checksum, max_tid = app.dm.checkTIDRange(min_tid, max_tid,
length, app.pt.getPartitions(), partition)
count, tid_checksum, max_tid = self.app.dm.checkTIDRange(min_tid,
max_tid, length, partition)
conn.answer(Packets.AnswerCheckTIDRange(min_tid, length,
count, tid_checksum, max_tid))
def askCheckSerialRange(self, conn, min_oid, min_serial, max_tid, length,
partition):
app = self.app
count, oid_checksum, max_oid, serial_checksum, max_serial = \
app.dm.checkSerialRange(min_oid, min_serial, max_tid, length,
app.pt.getPartitions(), partition)
self.app.dm.checkSerialRange(min_oid, min_serial, max_tid, length,
partition)
conn.answer(Packets.AnswerCheckSerialRange(min_oid, min_serial, length,
count, oid_checksum, max_oid, serial_checksum, max_serial))
......@@ -342,29 +342,23 @@ class Replicator(object):
self.task_list = []
def checkTIDRange(self, min_tid, max_tid, length, partition):
app = self.app
self._addTask(('TID', min_tid, length), app.dm.checkTIDRange,
(min_tid, max_tid, length, app.pt.getPartitions(), partition))
self._addTask(('TID', min_tid, length),
self.app.dm.checkTIDRange, (min_tid, max_tid, length, partition))
def checkSerialRange(self, min_oid, min_serial, max_tid, length,
partition):
app = self.app
self._addTask(('Serial', min_oid, min_serial, length),
app.dm.checkSerialRange, (min_oid, min_serial, max_tid, length,
app.pt.getPartitions(), partition))
self.app.dm.checkSerialRange, (min_oid, min_serial, max_tid, length,
partition))
def getTIDsFrom(self, min_tid, max_tid, length, partition):
app = self.app
self._addTask('TIDsFrom',
app.dm.getReplicationTIDList, (min_tid, max_tid, length,
app.pt.getPartitions(), partition))
self._addTask('TIDsFrom', self.app.dm.getReplicationTIDList,
(min_tid, max_tid, length, partition))
def getObjectHistoryFrom(self, min_oid, min_serial, max_serial, length,
partition):
app = self.app
self._addTask('ObjectHistoryFrom',
app.dm.getObjectHistoryFrom, (min_oid, min_serial, max_serial,
length, app.pt.getPartitions(), partition))
self._addTask('ObjectHistoryFrom', self.app.dm.getObjectHistoryFrom,
(min_oid, min_serial, max_serial, length, partition))
def _getCheckResult(self, key):
return self.task_dict.pop(key).getResult()
......
......@@ -149,7 +149,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
self.operation.askTIDs(conn, 1, 2, 1)
calls = self.app.dm.mockGetNamedCalls('getTIDList')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(1, 1, 1, [1, ])
calls[0].checkArgs(1, 1, [1, ])
self.checkAnswerTids(conn)
def test_25_askTIDs3(self):
......@@ -166,7 +166,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
self.assertEqual(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1)
calls = self.app.dm.mockGetNamedCalls('getTIDList')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(1, 1, 1, [0])
calls[0].checkArgs(1, 1, [0])
self.checkAnswerTids(conn)
def test_26_askObjectHistory1(self):
......
......@@ -92,35 +92,35 @@ class ReplicationTests(NeoUnitTestBase):
oapp.replicator.processDelayedTasks()
process |= rconn.process(ohandler, oconn)
# check transactions
for tid in reference.getTIDList(0, MAX_TRANSACTIONS, 1, [0]):
for tid in reference.getTIDList(0, MAX_TRANSACTIONS, [0]):
self.assertEqual(
reference.getTransaction(tid),
outdated.getTransaction(tid),
)
for tid in outdated.getTIDList(0, MAX_TRANSACTIONS, 1, [0]):
for tid in outdated.getTIDList(0, MAX_TRANSACTIONS, [0]):
self.assertEqual(
outdated.getTransaction(tid),
reference.getTransaction(tid),
)
# check transactions
params = (ZERO_TID, '\xFF' * 8, MAX_TRANSACTIONS, 1, 0)
params = ZERO_TID, '\xFF' * 8, MAX_TRANSACTIONS, 0
self.assertEqual(
reference.getReplicationTIDList(*params),
outdated.getReplicationTIDList(*params),
)
# check objects
params = (ZERO_OID, ZERO_TID, '\xFF' * 8, MAX_OBJECTS, 1, 0)
params = ZERO_OID, ZERO_TID, '\xFF' * 8, MAX_OBJECTS, 0
self.assertEqual(
reference.getObjectHistoryFrom(*params),
outdated.getObjectHistoryFrom(*params),
)
def buildStorage(self, transactions, objects, name='BTree', config=None):
def buildStorage(self, transactions, objects, name='BTree', database=None):
def makeid(oid_or_tid):
return pack('!Q', oid_or_tid)
storage = buildDatabaseManager(name, config)
storage.getNumPartitions = lambda: 1
storage = buildDatabaseManager(name, database)
storage.setup(reset=True)
storage.setNumPartitions(1)
storage._transactions = transactions
storage._objects = objects
# store transactions
......
......@@ -380,7 +380,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
# ...and delete partition tail
calls = app.dm.mockGetNamedCalls('deleteTransactionsAbove')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(num_partitions, rid, add64(max_tid, 1), ZERO_TID)
calls[0].checkArgs(rid, add64(max_tid, 1), ZERO_TID)
def test_answerCheckTIDRangeDifferentBigChunk(self):
min_tid = self.getNextTID()
......@@ -531,8 +531,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
# ...and delete partition tail
calls = app.dm.mockGetNamedCalls('deleteObjectsAbove')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(num_partitions, rid, max_oid, add64(max_serial, 1),
ZERO_TID)
calls[0].checkArgs(rid, max_oid, add64(max_serial, 1), ZERO_TID)
def test_answerCheckSerialRangeDifferentBigChunk(self):
min_oid = self.getOID(1)
......@@ -626,8 +625,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
# ...and delete partition tail
calls = app.dm.mockGetNamedCalls('deleteObjectsAbove')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(num_partitions, rid, max_oid, add64(max_serial, 1),
critical_tid)
calls[0].checkArgs(rid, max_oid, add64(max_serial, 1), critical_tid)
if __name__ == "__main__":
unittest.main()
......@@ -64,7 +64,7 @@ class StorageDBTests(NeoUnitTestBase):
if num_partitions == n:
return
if num_partitions < n:
db.dropPartitions(n, range(num_partitions, n))
db.dropPartitions(n)
db.setNumPartitions(num_partitions)
self.assertEqual(num_partitions, db.getNumPartitions())
uuid = self.getNewUUID()
......@@ -374,7 +374,7 @@ class StorageDBTests(NeoUnitTestBase):
txn, objs = self.getTransaction([oid1])
self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid)
self.db.deleteTransactionsAbove(2, 0, tid2, tid3)
self.db.deleteTransactionsAbove(0, tid2, tid3)
# Right partition, below cutoff
self.assertNotEqual(self.db.getTransaction(tid1, True), None)
# Wrong partition, above cutoff
......@@ -411,11 +411,11 @@ class StorageDBTests(NeoUnitTestBase):
txn, objs = self.getTransaction([oid1, oid2, oid3])
self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid)
self.db.deleteObjectsAbove(2, 0, oid1, tid2, tid3)
self.db.deleteObjectsAbove(0, oid1, tid2, tid3)
# Check getObjectHistoryFrom because MySQL adapter use two tables
# that must be synchronized
self.assertEqual(self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID,
MAX_TID, 10, 2, 0), {oid1: [tid1]})
MAX_TID, 10, 0), {oid1: [tid1]})
# Right partition, below cutoff
self.assertNotEqual(self.db.getObject(oid1, tid=tid1), None)
# Right partition, above tid cutoff
......@@ -492,34 +492,32 @@ class StorageDBTests(NeoUnitTestBase):
self.db.finishTransaction(tid5)
# Check full result
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, MAX_TID, 10,
2, 0)
0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2, tid4],
})
# Lower bound is inclusive
result = self.db.getObjectHistoryFrom(oid1, tid1, MAX_TID, 10, 2, 0)
result = self.db.getObjectHistoryFrom(oid1, tid1, MAX_TID, 10, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2, tid4],
})
# Upper bound is inclusive
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, tid3, 10,
2, 0)
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, tid3, 10, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2],
})
# Length is total number of serials
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, MAX_TID, 3,
2, 0)
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, MAX_TID, 3, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2],
})
# Partition constraints are honored
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, MAX_TID, 10,
2, 1)
1)
self.assertEqual(result, {
oid3: [tid5],
})
......@@ -539,19 +537,19 @@ class StorageDBTests(NeoUnitTestBase):
tid1, tid2, tid3, tid4 = self._storeTransactions(4)
# get tids
# - all partitions
result = self.db.getTIDList(0, 4, 2, [0, 1])
result = self.db.getTIDList(0, 4, [0, 1])
self.checkSet(result, [tid1, tid2, tid3, tid4])
# - one partition
result = self.db.getTIDList(0, 4, 2, [0])
result = self.db.getTIDList(0, 4, [0])
self.checkSet(result, [tid1, tid3])
result = self.db.getTIDList(0, 4, 2, [1])
result = self.db.getTIDList(0, 4, [1])
self.checkSet(result, [tid2, tid4])
# get a subset of tids
result = self.db.getTIDList(0, 1, 2, [0])
result = self.db.getTIDList(0, 1, [0])
self.checkSet(result, [tid3]) # desc order
result = self.db.getTIDList(1, 1, 2, [1])
result = self.db.getTIDList(1, 1, [1])
self.checkSet(result, [tid2])
result = self.db.getTIDList(2, 2, 2, [0])
result = self.db.getTIDList(2, 2, [0])
self.checkSet(result, [])
def test_getReplicationTIDList(self):
......@@ -559,22 +557,22 @@ class StorageDBTests(NeoUnitTestBase):
tid1, tid2, tid3, tid4 = self._storeTransactions(4)
# get tids
# - all
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 2, 0)
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 0)
self.checkSet(result, [tid1, tid3])
# - one partition
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 2, 0)
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 0)
self.checkSet(result, [tid1, tid3])
# - another partition
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 2, 1)
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 1)
self.checkSet(result, [tid2, tid4])
# - min_tid is inclusive
result = self.db.getReplicationTIDList(tid3, MAX_TID, 10, 2, 0)
result = self.db.getReplicationTIDList(tid3, MAX_TID, 10, 0)
self.checkSet(result, [tid3])
# - max tid is inclusive
result = self.db.getReplicationTIDList(ZERO_TID, tid2, 10, 2, 0)
result = self.db.getReplicationTIDList(ZERO_TID, tid2, 10, 0)
self.checkSet(result, [tid1])
# - limit
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 1, 2, 0)
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 1, 0)
self.checkSet(result, [tid1])
def test_findUndoTID(self):
......
......@@ -124,7 +124,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
self.operation.askTIDsFrom(conn, tid, tid2, 2, [1])
calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid, tid2, 2, 1, 1)
calls[0].checkArgs(tid, tid2, 2, 1)
self.checkAnswerTidsFrom(conn)
def test_26_askObjectHistoryFrom(self):
......@@ -145,8 +145,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
self.checkAnswerObjectHistoryFrom(conn)
calls = self.app.dm.mockGetNamedCalls('getObjectHistoryFrom')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_oid, min_serial, max_serial, length,
num_partitions, partition)
calls[0].checkArgs(min_oid, min_serial, max_serial, length, partition)
def test_askCheckTIDRange(self):
count = 1
......@@ -162,7 +161,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
self.operation.askCheckTIDRange(conn, min_tid, max_tid, length, partition)
calls = self.app.dm.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_tid, max_tid, length, num_partitions, partition)
calls[0].checkArgs(min_tid, max_tid, length, partition)
pmin_tid, plength, pcount, ptid_checksum, pmax_tid = \
self.checkAnswerPacket(conn, Packets.AnswerCheckTIDRange,
decode=True)
......@@ -191,8 +190,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
max_serial, length, partition)
calls = self.app.dm.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_oid, min_serial, max_serial, length,
num_partitions, partition)
calls[0].checkArgs(min_oid, min_serial, max_serial, length, partition)
pmin_oid, pmin_serial, plength, pcount, poid_checksum, pmax_oid, \
pserial_checksum, pmax_serial = self.checkAnswerPacket(conn,
Packets.AnswerCheckSerialRange, decode=True)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment