Commit 6959fbe6 authored by Julien Muchembled's avatar Julien Muchembled

storage: remove useless 'num_partitions' parameter from backend methods

parent 3a363b85
...@@ -125,6 +125,10 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -125,6 +125,10 @@ class BTreeDatabaseManager(DatabaseManager):
super(BTreeDatabaseManager, self).__init__() super(BTreeDatabaseManager, self).__init__()
self.setup(reset=1) self.setup(reset=1)
@property
def _num_partitions(self):
return self._config['partitions']
def setup(self, reset=0): def setup(self, reset=0):
if reset: if reset:
self._data = OOBTree() self._data = OOBTree()
...@@ -289,8 +293,9 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -289,8 +293,9 @@ class BTreeDatabaseManager(DatabaseManager):
self.unlockData(checksum_list) self.unlockData(checksum_list)
self._pruneData(checksum_set) self._pruneData(checksum_set)
def dropPartitions(self, num_partitions, offset_list): def dropPartitions(self, offset_list):
offset_list = frozenset(offset_list) offset_list = frozenset(offset_list)
num_partitions = self._num_partitions
def same_partition(key, _): def same_partition(key, _):
return key % num_partitions in offset_list return key % num_partitions in offset_list
batchDelete(self._obj, same_partition, self._objDeleterCallback) batchDelete(self._obj, same_partition, self._objDeleterCallback)
...@@ -400,7 +405,8 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -400,7 +405,8 @@ class BTreeDatabaseManager(DatabaseManager):
except KeyError: except KeyError:
pass pass
def deleteTransactionsAbove(self, num_partitions, partition, tid, max_tid): def deleteTransactionsAbove(self, partition, tid, max_tid):
num_partitions = self._num_partitions
def same_partition(key, _): def same_partition(key, _):
return key % num_partitions == partition return key % num_partitions == partition
batchDelete(self._trans, same_partition, batchDelete(self._trans, same_partition,
...@@ -421,13 +427,13 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -421,13 +427,13 @@ class BTreeDatabaseManager(DatabaseManager):
if not tserial: if not tserial:
del obj[oid] del obj[oid]
def deleteObjectsAbove(self, num_partitions, partition, oid, serial, def deleteObjectsAbove(self, partition, oid, serial, max_tid):
max_tid):
obj = self._obj obj = self._obj
u64 = util.u64 u64 = util.u64
oid = u64(oid) oid = u64(oid)
serial = u64(serial) serial = u64(serial)
max_tid = u64(max_tid) max_tid = u64(max_tid)
num_partitions = self._num_partitions
if oid % num_partitions == partition: if oid % num_partitions == partition:
try: try:
tserial = obj[oid] tserial = obj[oid]
...@@ -460,20 +466,6 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -460,20 +466,6 @@ class BTreeDatabaseManager(DatabaseManager):
result = (list(oid_list), user, desc, ext, packed) result = (list(oid_list), user, desc, ext, packed)
return result return result
def getOIDList(self, min_oid, length, num_partitions,
partition_list):
p64 = util.p64
partition_list = frozenset(partition_list)
result = []
append = result.append
for oid in safeIter(self._obj.keys, min=min_oid):
if oid % num_partitions in partition_list:
if length == 0:
break
length -= 1
append(p64(oid))
return result
def _getObjectLength(self, oid, value_serial): def _getObjectLength(self, oid, value_serial):
if value_serial is None: if value_serial is None:
raise CreationUndone raise CreationUndone
...@@ -521,13 +513,14 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -521,13 +513,14 @@ class BTreeDatabaseManager(DatabaseManager):
return result return result
def getObjectHistoryFrom(self, min_oid, min_serial, max_serial, length, def getObjectHistoryFrom(self, min_oid, min_serial, max_serial, length,
num_partitions, partition): partition):
u64 = util.u64 u64 = util.u64
p64 = util.p64 p64 = util.p64
min_oid = u64(min_oid) min_oid = u64(min_oid)
min_serial = u64(min_serial) min_serial = u64(min_serial)
max_serial = u64(max_serial) max_serial = u64(max_serial)
result = {} result = {}
num_partitions = self._num_partitions
for oid, tserial in safeIter(self._obj.items, min=min_oid): for oid, tserial in safeIter(self._obj.items, min=min_oid):
if oid % num_partitions == partition: if oid % num_partitions == partition:
if length == 0: if length == 0:
...@@ -553,12 +546,13 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -553,12 +546,13 @@ class BTreeDatabaseManager(DatabaseManager):
break break
return result return result
def getTIDList(self, offset, length, num_partitions, partition_list): def getTIDList(self, offset, length, partition_list):
p64 = util.p64 p64 = util.p64
partition_list = frozenset(partition_list) partition_list = frozenset(partition_list)
result = [] result = []
append = result.append append = result.append
trans_iter = descKeys(self._trans) trans_iter = descKeys(self._trans)
num_partitions = self._num_partitions
while offset > 0: while offset > 0:
tid = trans_iter.next() tid = trans_iter.next()
if tid % num_partitions in partition_list: if tid % num_partitions in partition_list:
...@@ -571,12 +565,12 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -571,12 +565,12 @@ class BTreeDatabaseManager(DatabaseManager):
append(p64(tid)) append(p64(tid))
return result return result
def getReplicationTIDList(self, min_tid, max_tid, length, num_partitions, def getReplicationTIDList(self, min_tid, max_tid, length, partition):
partition):
p64 = util.p64 p64 = util.p64
u64 = util.u64 u64 = util.u64
result = [] result = []
append = result.append append = result.append
num_partitions = self._num_partitions
for tid in safeIter(self._trans.keys, min=u64(min_tid), max=u64(max_tid)): for tid in safeIter(self._trans.keys, min=u64(min_tid), max=u64(max_tid)):
if tid % num_partitions == partition: if tid % num_partitions == partition:
if length == 0: if length == 0:
...@@ -633,9 +627,10 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -633,9 +627,10 @@ class BTreeDatabaseManager(DatabaseManager):
return not tserial return not tserial
batchDelete(self._obj, obj_callback, self._objDeleterCallback) batchDelete(self._obj, obj_callback, self._objDeleterCallback)
def checkTIDRange(self, min_tid, max_tid, length, num_partitions, partition): def checkTIDRange(self, min_tid, max_tid, length, partition):
if length: if length:
tid_list = [] tid_list = []
num_partitions = self._num_partitions
for tid in safeIter(self._trans.keys, min=util.u64(min_tid), for tid in safeIter(self._trans.keys, min=util.u64(min_tid),
max=util.u64(max_tid)): max=util.u64(max_tid)):
if tid % num_partitions == partition: if tid % num_partitions == partition:
...@@ -648,14 +643,14 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -648,14 +643,14 @@ class BTreeDatabaseManager(DatabaseManager):
util.p64(tid_list[-1])) util.p64(tid_list[-1]))
return 0, ZERO_HASH, ZERO_TID return 0, ZERO_HASH, ZERO_TID
def checkSerialRange(self, min_oid, min_serial, max_tid, length, def checkSerialRange(self, min_oid, min_serial, max_tid, length, partition):
num_partitions, partition):
if length: if length:
u64 = util.u64 u64 = util.u64
min_oid = u64(min_oid) min_oid = u64(min_oid)
max_tid = u64(max_tid) max_tid = u64(max_tid)
oid_list = [] oid_list = []
serial_list = [] serial_list = []
num_partitions = self._num_partitions
for oid, tserial in safeIter(self._obj.items, min=min_oid): for oid, tserial in safeIter(self._obj.items, min=min_oid):
if oid % num_partitions == partition: if oid % num_partitions == partition:
try: try:
......
...@@ -282,7 +282,7 @@ class DatabaseManager(object): ...@@ -282,7 +282,7 @@ class DatabaseManager(object):
thrown away.""" thrown away."""
raise NotImplementedError raise NotImplementedError
def dropPartitions(self, num_partitions, offset_list): def dropPartitions(self, offset_list):
""" Drop any data of non-assigned partitions for a given UUID """ """ Drop any data of non-assigned partitions for a given UUID """
raise NotImplementedError('this method must be overriden') raise NotImplementedError('this method must be overriden')
...@@ -461,7 +461,7 @@ class DatabaseManager(object): ...@@ -461,7 +461,7 @@ class DatabaseManager(object):
an oid list""" an oid list"""
raise NotImplementedError raise NotImplementedError
def deleteTransactionsAbove(self, num_partitions, partition, tid, max_tid): def deleteTransactionsAbove(self, partition, tid, max_tid):
"""Delete all transactions above given TID (inclued) in given """Delete all transactions above given TID (inclued) in given
partition, but never above max_tid (in case transactions are committed partition, but never above max_tid (in case transactions are committed
during replication).""" during replication)."""
...@@ -472,8 +472,7 @@ class DatabaseManager(object): ...@@ -472,8 +472,7 @@ class DatabaseManager(object):
given oid.""" given oid."""
raise NotImplementedError raise NotImplementedError
def deleteObjectsAbove(self, num_partitions, partition, oid, serial, def deleteObjectsAbove(self, partition, oid, serial, max_tid):
max_tid):
"""Delete all objects above given OID and serial (inclued) in given """Delete all objects above given OID and serial (inclued) in given
partition, but never above max_tid (in case objects are stored during partition, but never above max_tid (in case objects are stored during
replication)""" replication)"""
...@@ -495,20 +494,19 @@ class DatabaseManager(object): ...@@ -495,20 +494,19 @@ class DatabaseManager(object):
raise NotImplementedError raise NotImplementedError
def getObjectHistoryFrom(self, oid, min_serial, max_serial, length, def getObjectHistoryFrom(self, oid, min_serial, max_serial, length,
num_partitions, partition): partition):
"""Return a dict of length serials grouped by oid at (or above) """Return a dict of length serials grouped by oid at (or above)
min_oid and min_serial and below max_serial, for given partition, min_oid and min_serial and below max_serial, for given partition,
sorted in ascending order.""" sorted in ascending order."""
raise NotImplementedError raise NotImplementedError
def getTIDList(self, offset, length, num_partitions, partition_list): def getTIDList(self, offset, length, partition_list):
"""Return a list of TIDs in ascending order from an offset, """Return a list of TIDs in ascending order from an offset,
at most the specified length. The list of partitions are passed at most the specified length. The list of partitions are passed
to filter out non-applicable TIDs.""" to filter out non-applicable TIDs."""
raise NotImplementedError raise NotImplementedError
def getReplicationTIDList(self, min_tid, max_tid, length, num_partitions, def getReplicationTIDList(self, min_tid, max_tid, length, partition):
partition):
"""Return a list of TIDs in ascending order from an initial tid value, """Return a list of TIDs in ascending order from an initial tid value,
at most the specified length up to max_tid. The partition number is at most the specified length up to max_tid. The partition number is
passed to filter out non-applicable TIDs.""" passed to filter out non-applicable TIDs."""
...@@ -531,15 +529,13 @@ class DatabaseManager(object): ...@@ -531,15 +529,13 @@ class DatabaseManager(object):
""" """
raise NotImplementedError raise NotImplementedError
def checkTIDRange(self, min_tid, max_tid, length, num_partitions, partition): def checkTIDRange(self, min_tid, max_tid, length, partition):
""" """
Generate a diggest from transaction list. Generate a diggest from transaction list.
min_tid (packed) min_tid (packed)
TID at which verification starts. TID at which verification starts.
length (int) length (int)
Maximum number of records to include in result. Maximum number of records to include in result.
num_partitions, partition (int, int)
Specifies concerned partition.
Returns a 3-tuple: Returns a 3-tuple:
- number of records actually found - number of records actually found
...@@ -550,8 +546,7 @@ class DatabaseManager(object): ...@@ -550,8 +546,7 @@ class DatabaseManager(object):
""" """
raise NotImplementedError raise NotImplementedError
def checkSerialRange(self, min_oid, min_serial, max_tid, length, def checkSerialRange(self, min_oid, min_serial, max_tid, length, partition):
num_partitions, partition):
""" """
Generate a diggest from object list. Generate a diggest from object list.
min_oid (packed) min_oid (packed)
...@@ -560,8 +555,6 @@ class DatabaseManager(object): ...@@ -560,8 +555,6 @@ class DatabaseManager(object):
Serial of min_oid object at which search should start. Serial of min_oid object at which search should start.
length length
Maximum number of records to include in result. Maximum number of records to include in result.
num_partitions, partition (int, int)
Specifies concerned partition.
Returns a 5-tuple: Returns a 5-tuple:
- number of records actually found - number of records actually found
......
...@@ -380,7 +380,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -380,7 +380,7 @@ class MySQLDatabaseManager(DatabaseManager):
def setPartitionTable(self, ptid, cell_list): def setPartitionTable(self, ptid, cell_list):
self.doSetPartitionTable(ptid, cell_list, True) self.doSetPartitionTable(ptid, cell_list, True)
def dropPartitions(self, num_partitions, offset_list): def dropPartitions(self, offset_list):
q = self.query q = self.query
self.begin() self.begin()
try: try:
...@@ -566,7 +566,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -566,7 +566,7 @@ class MySQLDatabaseManager(DatabaseManager):
raise raise
self.commit() self.commit()
def deleteTransactionsAbove(self, num_partitions, partition, tid, max_tid): def deleteTransactionsAbove(self, partition, tid, max_tid):
self.begin() self.begin()
try: try:
self.query('DELETE FROM trans WHERE partition=%(partition)d AND ' self.query('DELETE FROM trans WHERE partition=%(partition)d AND '
...@@ -598,8 +598,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -598,8 +598,7 @@ class MySQLDatabaseManager(DatabaseManager):
raise raise
self.commit() self.commit()
def deleteObjectsAbove(self, num_partitions, partition, oid, serial, def deleteObjectsAbove(self, partition, oid, serial, max_tid):
max_tid):
q = self.query q = self.query
u64 = util.u64 u64 = util.u64
oid = u64(oid) oid = u64(oid)
...@@ -675,7 +674,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -675,7 +674,7 @@ class MySQLDatabaseManager(DatabaseManager):
return None return None
def getObjectHistoryFrom(self, min_oid, min_serial, max_serial, length, def getObjectHistoryFrom(self, min_oid, min_serial, max_serial, length,
num_partitions, partition): partition):
q = self.query q = self.query
u64 = util.u64 u64 = util.u64
p64 = util.p64 p64 = util.p64
...@@ -703,15 +702,14 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -703,15 +702,14 @@ class MySQLDatabaseManager(DatabaseManager):
serial_list.append(p64(serial)) serial_list.append(p64(serial))
return dict((p64(x), y) for x, y in result.iteritems()) return dict((p64(x), y) for x, y in result.iteritems())
def getTIDList(self, offset, length, num_partitions, partition_list): def getTIDList(self, offset, length, partition_list):
q = self.query q = self.query
r = q("""SELECT tid FROM trans WHERE partition in (%s) r = q("""SELECT tid FROM trans WHERE partition in (%s)
ORDER BY tid DESC LIMIT %d,%d""" \ ORDER BY tid DESC LIMIT %d,%d""" \
% (','.join(map(str, partition_list)), offset, length)) % (','.join(map(str, partition_list)), offset, length))
return [util.p64(t[0]) for t in r] return [util.p64(t[0]) for t in r]
def getReplicationTIDList(self, min_tid, max_tid, length, num_partitions, def getReplicationTIDList(self, min_tid, max_tid, length, partition):
partition):
q = self.query q = self.query
u64 = util.u64 u64 = util.u64
p64 = util.p64 p64 = util.p64
...@@ -794,7 +792,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -794,7 +792,7 @@ class MySQLDatabaseManager(DatabaseManager):
raise raise
self.commit() self.commit()
def checkTIDRange(self, min_tid, max_tid, length, num_partitions, partition): def checkTIDRange(self, min_tid, max_tid, length, partition):
count, tid_checksum, max_tid = self.query( count, tid_checksum, max_tid = self.query(
"""SELECT COUNT(*), SHA1(GROUP_CONCAT(tid SEPARATOR ",")), MAX(tid) """SELECT COUNT(*), SHA1(GROUP_CONCAT(tid SEPARATOR ",")), MAX(tid)
FROM (SELECT tid FROM trans FROM (SELECT tid FROM trans
...@@ -811,8 +809,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -811,8 +809,7 @@ class MySQLDatabaseManager(DatabaseManager):
return count, a2b_hex(tid_checksum), util.p64(max_tid) return count, a2b_hex(tid_checksum), util.p64(max_tid)
return 0, ZERO_HASH, ZERO_TID return 0, ZERO_HASH, ZERO_TID
def checkSerialRange(self, min_oid, min_serial, max_tid, length, def checkSerialRange(self, min_oid, min_serial, max_tid, length, partition):
num_partitions, partition):
u64 = util.u64 u64 = util.u64
# We don't ask MySQL to compute everything (like in checkTIDRange) # We don't ask MySQL to compute everything (like in checkTIDRange)
# because it's difficult to get the last serial _for the last oid_. # because it's difficult to get the last serial _for the last oid_.
......
...@@ -98,14 +98,11 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -98,14 +98,11 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
data_serial, ttid, unlock, time.time()) data_serial, ttid, unlock, time.time())
def askTIDsFrom(self, conn, min_tid, max_tid, length, partition_list): def askTIDsFrom(self, conn, min_tid, max_tid, length, partition_list):
app = self.app getReplicationTIDList = self.app.dm.getReplicationTIDList
getReplicationTIDList = app.dm.getReplicationTIDList
partitions = app.pt.getPartitions()
tid_list = [] tid_list = []
extend = tid_list.extend extend = tid_list.extend
for partition in partition_list: for partition in partition_list:
extend(getReplicationTIDList(min_tid, max_tid, length, extend(getReplicationTIDList(min_tid, max_tid, length, partition))
partitions, partition))
conn.answer(Packets.AnswerTIDsFrom(tid_list)) conn.answer(Packets.AnswerTIDsFrom(tid_list))
def askTIDs(self, conn, first, last, partition): def askTIDs(self, conn, first, last, partition):
...@@ -120,8 +117,7 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -120,8 +117,7 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
else: else:
partition_list = [partition] partition_list = [partition]
tid_list = app.dm.getTIDList(first, last - first, tid_list = app.dm.getTIDList(first, last - first, partition_list)
app.pt.getPartitions(), partition_list)
conn.answer(Packets.AnswerTIDs(tid_list)) conn.answer(Packets.AnswerTIDs(tid_list))
def askObjectUndoSerial(self, conn, ttid, ltid, undone_tid, oid_list): def askObjectUndoSerial(self, conn, ttid, ltid, undone_tid, oid_list):
......
...@@ -48,9 +48,8 @@ class InitializationHandler(BaseMasterHandler): ...@@ -48,9 +48,8 @@ class InitializationHandler(BaseMasterHandler):
unassigned_set.remove(offset) unassigned_set.remove(offset)
# delete objects database # delete objects database
if unassigned_set: if unassigned_set:
neo.lib.logging.debug( neo.lib.logging.debug('drop data for partitions %r', unassigned_set)
'drop data for partitions %r' % unassigned_set) app.dm.dropPartitions(unassigned_set)
app.dm.dropPartitions(num_partitions, unassigned_set)
app.dm.setPartitionTable(ptid, cell_list) app.dm.setPartitionTable(ptid, cell_list)
......
...@@ -300,8 +300,7 @@ class ReplicationHandler(EventHandler): ...@@ -300,8 +300,7 @@ class ReplicationHandler(EventHandler):
" length=%s, count=%s, max_tid=%x, last_tid=%x," " length=%s, count=%s, max_tid=%x, last_tid=%x,"
" critical_tid=%x)", offset, u64(pkt_min_tid), length, count, " critical_tid=%x)", offset, u64(pkt_min_tid), length, count,
u64(max_tid), u64(last_tid), u64(critical_tid)) u64(max_tid), u64(last_tid), u64(critical_tid))
app.dm.deleteTransactionsAbove(app.pt.getPartitions(), app.dm.deleteTransactionsAbove(offset, last_tid, critical_tid)
offset, last_tid, critical_tid)
# If no more TID, a replication of transactions is finished. # If no more TID, a replication of transactions is finished.
# So start to replicate objects now. # So start to replicate objects now.
ask(self._doAskCheckSerialRange(ZERO_OID, ZERO_TID, critical_tid)) ask(self._doAskCheckSerialRange(ZERO_OID, ZERO_TID, critical_tid))
...@@ -339,8 +338,7 @@ class ReplicationHandler(EventHandler): ...@@ -339,8 +338,7 @@ class ReplicationHandler(EventHandler):
offset, u64(min_oid), u64(min_serial), length, count, offset, u64(min_oid), u64(min_serial), length, count,
u64(max_oid), u64(max_serial), u64(last_oid), u64(last_serial), u64(max_oid), u64(max_serial), u64(last_oid), u64(last_serial),
u64(max_tid)) u64(max_tid))
app.dm.deleteObjectsAbove(app.pt.getPartitions(), app.dm.deleteObjectsAbove(offset, last_oid, last_serial, max_tid)
offset, last_oid, last_serial, max_tid)
# Nothing remains, so the replication for this partition is # Nothing remains, so the replication for this partition is
# finished. # finished.
replicator.setReplicationDone() replicator.setReplicationDone()
......
...@@ -34,32 +34,27 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -34,32 +34,27 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler):
def askTIDsFrom(self, conn, min_tid, max_tid, length, partition_list): def askTIDsFrom(self, conn, min_tid, max_tid, length, partition_list):
assert len(partition_list) == 1, partition_list assert len(partition_list) == 1, partition_list
partition = partition_list[0] tid_list = self.app.dm.getReplicationTIDList(min_tid, max_tid, length,
app = self.app partition_list[0])
tid_list = app.dm.getReplicationTIDList(min_tid, max_tid, length,
app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerTIDsFrom(tid_list)) conn.answer(Packets.AnswerTIDsFrom(tid_list))
def askObjectHistoryFrom(self, conn, min_oid, min_serial, max_serial, def askObjectHistoryFrom(self, conn, min_oid, min_serial, max_serial,
length, partition): length, partition):
app = self.app object_dict = self.app.dm.getObjectHistoryFrom(min_oid, min_serial,
object_dict = app.dm.getObjectHistoryFrom(min_oid, min_serial, max_serial, max_serial, length, partition)
length, app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerObjectHistoryFrom(object_dict)) conn.answer(Packets.AnswerObjectHistoryFrom(object_dict))
def askCheckTIDRange(self, conn, min_tid, max_tid, length, partition): def askCheckTIDRange(self, conn, min_tid, max_tid, length, partition):
app = self.app count, tid_checksum, max_tid = self.app.dm.checkTIDRange(min_tid,
count, tid_checksum, max_tid = app.dm.checkTIDRange(min_tid, max_tid, max_tid, length, partition)
length, app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerCheckTIDRange(min_tid, length, conn.answer(Packets.AnswerCheckTIDRange(min_tid, length,
count, tid_checksum, max_tid)) count, tid_checksum, max_tid))
def askCheckSerialRange(self, conn, min_oid, min_serial, max_tid, length, def askCheckSerialRange(self, conn, min_oid, min_serial, max_tid, length,
partition): partition):
app = self.app
count, oid_checksum, max_oid, serial_checksum, max_serial = \ count, oid_checksum, max_oid, serial_checksum, max_serial = \
app.dm.checkSerialRange(min_oid, min_serial, max_tid, length, self.app.dm.checkSerialRange(min_oid, min_serial, max_tid, length,
app.pt.getPartitions(), partition) partition)
conn.answer(Packets.AnswerCheckSerialRange(min_oid, min_serial, length, conn.answer(Packets.AnswerCheckSerialRange(min_oid, min_serial, length,
count, oid_checksum, max_oid, serial_checksum, max_serial)) count, oid_checksum, max_oid, serial_checksum, max_serial))
...@@ -342,29 +342,23 @@ class Replicator(object): ...@@ -342,29 +342,23 @@ class Replicator(object):
self.task_list = [] self.task_list = []
def checkTIDRange(self, min_tid, max_tid, length, partition): def checkTIDRange(self, min_tid, max_tid, length, partition):
app = self.app self._addTask(('TID', min_tid, length),
self._addTask(('TID', min_tid, length), app.dm.checkTIDRange, self.app.dm.checkTIDRange, (min_tid, max_tid, length, partition))
(min_tid, max_tid, length, app.pt.getPartitions(), partition))
def checkSerialRange(self, min_oid, min_serial, max_tid, length, def checkSerialRange(self, min_oid, min_serial, max_tid, length,
partition): partition):
app = self.app
self._addTask(('Serial', min_oid, min_serial, length), self._addTask(('Serial', min_oid, min_serial, length),
app.dm.checkSerialRange, (min_oid, min_serial, max_tid, length, self.app.dm.checkSerialRange, (min_oid, min_serial, max_tid, length,
app.pt.getPartitions(), partition)) partition))
def getTIDsFrom(self, min_tid, max_tid, length, partition): def getTIDsFrom(self, min_tid, max_tid, length, partition):
app = self.app self._addTask('TIDsFrom', self.app.dm.getReplicationTIDList,
self._addTask('TIDsFrom', (min_tid, max_tid, length, partition))
app.dm.getReplicationTIDList, (min_tid, max_tid, length,
app.pt.getPartitions(), partition))
def getObjectHistoryFrom(self, min_oid, min_serial, max_serial, length, def getObjectHistoryFrom(self, min_oid, min_serial, max_serial, length,
partition): partition):
app = self.app self._addTask('ObjectHistoryFrom', self.app.dm.getObjectHistoryFrom,
self._addTask('ObjectHistoryFrom', (min_oid, min_serial, max_serial, length, partition))
app.dm.getObjectHistoryFrom, (min_oid, min_serial, max_serial,
length, app.pt.getPartitions(), partition))
def _getCheckResult(self, key): def _getCheckResult(self, key):
return self.task_dict.pop(key).getResult() return self.task_dict.pop(key).getResult()
......
...@@ -149,7 +149,7 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -149,7 +149,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
self.operation.askTIDs(conn, 1, 2, 1) self.operation.askTIDs(conn, 1, 2, 1)
calls = self.app.dm.mockGetNamedCalls('getTIDList') calls = self.app.dm.mockGetNamedCalls('getTIDList')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(1, 1, 1, [1, ]) calls[0].checkArgs(1, 1, [1, ])
self.checkAnswerTids(conn) self.checkAnswerTids(conn)
def test_25_askTIDs3(self): def test_25_askTIDs3(self):
...@@ -166,7 +166,7 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -166,7 +166,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
self.assertEqual(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1) self.assertEqual(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1)
calls = self.app.dm.mockGetNamedCalls('getTIDList') calls = self.app.dm.mockGetNamedCalls('getTIDList')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(1, 1, 1, [0]) calls[0].checkArgs(1, 1, [0])
self.checkAnswerTids(conn) self.checkAnswerTids(conn)
def test_26_askObjectHistory1(self): def test_26_askObjectHistory1(self):
......
...@@ -92,35 +92,35 @@ class ReplicationTests(NeoUnitTestBase): ...@@ -92,35 +92,35 @@ class ReplicationTests(NeoUnitTestBase):
oapp.replicator.processDelayedTasks() oapp.replicator.processDelayedTasks()
process |= rconn.process(ohandler, oconn) process |= rconn.process(ohandler, oconn)
# check transactions # check transactions
for tid in reference.getTIDList(0, MAX_TRANSACTIONS, 1, [0]): for tid in reference.getTIDList(0, MAX_TRANSACTIONS, [0]):
self.assertEqual( self.assertEqual(
reference.getTransaction(tid), reference.getTransaction(tid),
outdated.getTransaction(tid), outdated.getTransaction(tid),
) )
for tid in outdated.getTIDList(0, MAX_TRANSACTIONS, 1, [0]): for tid in outdated.getTIDList(0, MAX_TRANSACTIONS, [0]):
self.assertEqual( self.assertEqual(
outdated.getTransaction(tid), outdated.getTransaction(tid),
reference.getTransaction(tid), reference.getTransaction(tid),
) )
# check transactions # check transactions
params = (ZERO_TID, '\xFF' * 8, MAX_TRANSACTIONS, 1, 0) params = ZERO_TID, '\xFF' * 8, MAX_TRANSACTIONS, 0
self.assertEqual( self.assertEqual(
reference.getReplicationTIDList(*params), reference.getReplicationTIDList(*params),
outdated.getReplicationTIDList(*params), outdated.getReplicationTIDList(*params),
) )
# check objects # check objects
params = (ZERO_OID, ZERO_TID, '\xFF' * 8, MAX_OBJECTS, 1, 0) params = ZERO_OID, ZERO_TID, '\xFF' * 8, MAX_OBJECTS, 0
self.assertEqual( self.assertEqual(
reference.getObjectHistoryFrom(*params), reference.getObjectHistoryFrom(*params),
outdated.getObjectHistoryFrom(*params), outdated.getObjectHistoryFrom(*params),
) )
def buildStorage(self, transactions, objects, name='BTree', config=None): def buildStorage(self, transactions, objects, name='BTree', database=None):
def makeid(oid_or_tid): def makeid(oid_or_tid):
return pack('!Q', oid_or_tid) return pack('!Q', oid_or_tid)
storage = buildDatabaseManager(name, config) storage = buildDatabaseManager(name, database)
storage.getNumPartitions = lambda: 1
storage.setup(reset=True) storage.setup(reset=True)
storage.setNumPartitions(1)
storage._transactions = transactions storage._transactions = transactions
storage._objects = objects storage._objects = objects
# store transactions # store transactions
......
...@@ -380,7 +380,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -380,7 +380,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
# ...and delete partition tail # ...and delete partition tail
calls = app.dm.mockGetNamedCalls('deleteTransactionsAbove') calls = app.dm.mockGetNamedCalls('deleteTransactionsAbove')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(num_partitions, rid, add64(max_tid, 1), ZERO_TID) calls[0].checkArgs(rid, add64(max_tid, 1), ZERO_TID)
def test_answerCheckTIDRangeDifferentBigChunk(self): def test_answerCheckTIDRangeDifferentBigChunk(self):
min_tid = self.getNextTID() min_tid = self.getNextTID()
...@@ -531,8 +531,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -531,8 +531,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
# ...and delete partition tail # ...and delete partition tail
calls = app.dm.mockGetNamedCalls('deleteObjectsAbove') calls = app.dm.mockGetNamedCalls('deleteObjectsAbove')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(num_partitions, rid, max_oid, add64(max_serial, 1), calls[0].checkArgs(rid, max_oid, add64(max_serial, 1), ZERO_TID)
ZERO_TID)
def test_answerCheckSerialRangeDifferentBigChunk(self): def test_answerCheckSerialRangeDifferentBigChunk(self):
min_oid = self.getOID(1) min_oid = self.getOID(1)
...@@ -626,8 +625,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -626,8 +625,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
# ...and delete partition tail # ...and delete partition tail
calls = app.dm.mockGetNamedCalls('deleteObjectsAbove') calls = app.dm.mockGetNamedCalls('deleteObjectsAbove')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(num_partitions, rid, max_oid, add64(max_serial, 1), calls[0].checkArgs(rid, max_oid, add64(max_serial, 1), critical_tid)
critical_tid)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -64,7 +64,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -64,7 +64,7 @@ class StorageDBTests(NeoUnitTestBase):
if num_partitions == n: if num_partitions == n:
return return
if num_partitions < n: if num_partitions < n:
db.dropPartitions(n, range(num_partitions, n)) db.dropPartitions(n)
db.setNumPartitions(num_partitions) db.setNumPartitions(num_partitions)
self.assertEqual(num_partitions, db.getNumPartitions()) self.assertEqual(num_partitions, db.getNumPartitions())
uuid = self.getNewUUID() uuid = self.getNewUUID()
...@@ -374,7 +374,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -374,7 +374,7 @@ class StorageDBTests(NeoUnitTestBase):
txn, objs = self.getTransaction([oid1]) txn, objs = self.getTransaction([oid1])
self.db.storeTransaction(tid, objs, txn) self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid) self.db.finishTransaction(tid)
self.db.deleteTransactionsAbove(2, 0, tid2, tid3) self.db.deleteTransactionsAbove(0, tid2, tid3)
# Right partition, below cutoff # Right partition, below cutoff
self.assertNotEqual(self.db.getTransaction(tid1, True), None) self.assertNotEqual(self.db.getTransaction(tid1, True), None)
# Wrong partition, above cutoff # Wrong partition, above cutoff
...@@ -411,11 +411,11 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -411,11 +411,11 @@ class StorageDBTests(NeoUnitTestBase):
txn, objs = self.getTransaction([oid1, oid2, oid3]) txn, objs = self.getTransaction([oid1, oid2, oid3])
self.db.storeTransaction(tid, objs, txn) self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid) self.db.finishTransaction(tid)
self.db.deleteObjectsAbove(2, 0, oid1, tid2, tid3) self.db.deleteObjectsAbove(0, oid1, tid2, tid3)
# Check getObjectHistoryFrom because MySQL adapter use two tables # Check getObjectHistoryFrom because MySQL adapter use two tables
# that must be synchronized # that must be synchronized
self.assertEqual(self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, self.assertEqual(self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID,
MAX_TID, 10, 2, 0), {oid1: [tid1]}) MAX_TID, 10, 0), {oid1: [tid1]})
# Right partition, below cutoff # Right partition, below cutoff
self.assertNotEqual(self.db.getObject(oid1, tid=tid1), None) self.assertNotEqual(self.db.getObject(oid1, tid=tid1), None)
# Right partition, above tid cutoff # Right partition, above tid cutoff
...@@ -492,34 +492,32 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -492,34 +492,32 @@ class StorageDBTests(NeoUnitTestBase):
self.db.finishTransaction(tid5) self.db.finishTransaction(tid5)
# Check full result # Check full result
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, MAX_TID, 10, result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, MAX_TID, 10,
2, 0) 0)
self.assertEqual(result, { self.assertEqual(result, {
oid1: [tid1, tid3], oid1: [tid1, tid3],
oid2: [tid2, tid4], oid2: [tid2, tid4],
}) })
# Lower bound is inclusive # Lower bound is inclusive
result = self.db.getObjectHistoryFrom(oid1, tid1, MAX_TID, 10, 2, 0) result = self.db.getObjectHistoryFrom(oid1, tid1, MAX_TID, 10, 0)
self.assertEqual(result, { self.assertEqual(result, {
oid1: [tid1, tid3], oid1: [tid1, tid3],
oid2: [tid2, tid4], oid2: [tid2, tid4],
}) })
# Upper bound is inclusive # Upper bound is inclusive
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, tid3, 10, result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, tid3, 10, 0)
2, 0)
self.assertEqual(result, { self.assertEqual(result, {
oid1: [tid1, tid3], oid1: [tid1, tid3],
oid2: [tid2], oid2: [tid2],
}) })
# Length is total number of serials # Length is total number of serials
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, MAX_TID, 3, result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, MAX_TID, 3, 0)
2, 0)
self.assertEqual(result, { self.assertEqual(result, {
oid1: [tid1, tid3], oid1: [tid1, tid3],
oid2: [tid2], oid2: [tid2],
}) })
# Partition constraints are honored # Partition constraints are honored
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, MAX_TID, 10, result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, MAX_TID, 10,
2, 1) 1)
self.assertEqual(result, { self.assertEqual(result, {
oid3: [tid5], oid3: [tid5],
}) })
...@@ -539,19 +537,19 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -539,19 +537,19 @@ class StorageDBTests(NeoUnitTestBase):
tid1, tid2, tid3, tid4 = self._storeTransactions(4) tid1, tid2, tid3, tid4 = self._storeTransactions(4)
# get tids # get tids
# - all partitions # - all partitions
result = self.db.getTIDList(0, 4, 2, [0, 1]) result = self.db.getTIDList(0, 4, [0, 1])
self.checkSet(result, [tid1, tid2, tid3, tid4]) self.checkSet(result, [tid1, tid2, tid3, tid4])
# - one partition # - one partition
result = self.db.getTIDList(0, 4, 2, [0]) result = self.db.getTIDList(0, 4, [0])
self.checkSet(result, [tid1, tid3]) self.checkSet(result, [tid1, tid3])
result = self.db.getTIDList(0, 4, 2, [1]) result = self.db.getTIDList(0, 4, [1])
self.checkSet(result, [tid2, tid4]) self.checkSet(result, [tid2, tid4])
# get a subset of tids # get a subset of tids
result = self.db.getTIDList(0, 1, 2, [0]) result = self.db.getTIDList(0, 1, [0])
self.checkSet(result, [tid3]) # desc order self.checkSet(result, [tid3]) # desc order
result = self.db.getTIDList(1, 1, 2, [1]) result = self.db.getTIDList(1, 1, [1])
self.checkSet(result, [tid2]) self.checkSet(result, [tid2])
result = self.db.getTIDList(2, 2, 2, [0]) result = self.db.getTIDList(2, 2, [0])
self.checkSet(result, []) self.checkSet(result, [])
def test_getReplicationTIDList(self): def test_getReplicationTIDList(self):
...@@ -559,22 +557,22 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -559,22 +557,22 @@ class StorageDBTests(NeoUnitTestBase):
tid1, tid2, tid3, tid4 = self._storeTransactions(4) tid1, tid2, tid3, tid4 = self._storeTransactions(4)
# get tids # get tids
# - all # - all
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 2, 0) result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 0)
self.checkSet(result, [tid1, tid3]) self.checkSet(result, [tid1, tid3])
# - one partition # - one partition
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 2, 0) result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 0)
self.checkSet(result, [tid1, tid3]) self.checkSet(result, [tid1, tid3])
# - another partition # - another partition
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 2, 1) result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 1)
self.checkSet(result, [tid2, tid4]) self.checkSet(result, [tid2, tid4])
# - min_tid is inclusive # - min_tid is inclusive
result = self.db.getReplicationTIDList(tid3, MAX_TID, 10, 2, 0) result = self.db.getReplicationTIDList(tid3, MAX_TID, 10, 0)
self.checkSet(result, [tid3]) self.checkSet(result, [tid3])
# - max tid is inclusive # - max tid is inclusive
result = self.db.getReplicationTIDList(ZERO_TID, tid2, 10, 2, 0) result = self.db.getReplicationTIDList(ZERO_TID, tid2, 10, 0)
self.checkSet(result, [tid1]) self.checkSet(result, [tid1])
# - limit # - limit
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 1, 2, 0) result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 1, 0)
self.checkSet(result, [tid1]) self.checkSet(result, [tid1])
def test_findUndoTID(self): def test_findUndoTID(self):
......
...@@ -124,7 +124,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase): ...@@ -124,7 +124,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
self.operation.askTIDsFrom(conn, tid, tid2, 2, [1]) self.operation.askTIDsFrom(conn, tid, tid2, 2, [1])
calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList') calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid, tid2, 2, 1, 1) calls[0].checkArgs(tid, tid2, 2, 1)
self.checkAnswerTidsFrom(conn) self.checkAnswerTidsFrom(conn)
def test_26_askObjectHistoryFrom(self): def test_26_askObjectHistoryFrom(self):
...@@ -145,8 +145,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase): ...@@ -145,8 +145,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
self.checkAnswerObjectHistoryFrom(conn) self.checkAnswerObjectHistoryFrom(conn)
calls = self.app.dm.mockGetNamedCalls('getObjectHistoryFrom') calls = self.app.dm.mockGetNamedCalls('getObjectHistoryFrom')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_oid, min_serial, max_serial, length, calls[0].checkArgs(min_oid, min_serial, max_serial, length, partition)
num_partitions, partition)
def test_askCheckTIDRange(self): def test_askCheckTIDRange(self):
count = 1 count = 1
...@@ -162,7 +161,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase): ...@@ -162,7 +161,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
self.operation.askCheckTIDRange(conn, min_tid, max_tid, length, partition) self.operation.askCheckTIDRange(conn, min_tid, max_tid, length, partition)
calls = self.app.dm.mockGetNamedCalls('checkTIDRange') calls = self.app.dm.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_tid, max_tid, length, num_partitions, partition) calls[0].checkArgs(min_tid, max_tid, length, partition)
pmin_tid, plength, pcount, ptid_checksum, pmax_tid = \ pmin_tid, plength, pcount, ptid_checksum, pmax_tid = \
self.checkAnswerPacket(conn, Packets.AnswerCheckTIDRange, self.checkAnswerPacket(conn, Packets.AnswerCheckTIDRange,
decode=True) decode=True)
...@@ -191,8 +190,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase): ...@@ -191,8 +190,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
max_serial, length, partition) max_serial, length, partition)
calls = self.app.dm.mockGetNamedCalls('checkSerialRange') calls = self.app.dm.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_oid, min_serial, max_serial, length, calls[0].checkArgs(min_oid, min_serial, max_serial, length, partition)
num_partitions, partition)
pmin_oid, pmin_serial, plength, pcount, poid_checksum, pmax_oid, \ pmin_oid, pmin_serial, plength, pcount, poid_checksum, pmax_oid, \
pserial_checksum, pmax_serial = self.checkAnswerPacket(conn, pserial_checksum, pmax_serial = self.checkAnswerPacket(conn,
Packets.AnswerCheckSerialRange, decode=True) Packets.AnswerCheckSerialRange, decode=True)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment