Commit d8a7a177 authored by Vincent Pelletier's avatar Vincent Pelletier

Improve replication SQL queries.

It is more efficient to provide a boundary value than a row count range.
This fixes replication on partitions with a large number of objects, revisions
or transactions: query time is now constant where it used to increase, causing
timeout problems when query duration exceeded ping time + ping timeout (11s
currently).

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2221 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 44b434d5
......@@ -256,6 +256,12 @@ class EventHandler(object):
def answerTIDs(self, conn, tid_list):
raise UnexpectedPacketError
def askTIDsFrom(self, conn, min_tid, length, partition):
raise UnexpectedPacketError
def answerTIDsFrom(self, conn, tid_list):
raise UnexpectedPacketError
def askTransactionInformation(self, conn, tid):
raise UnexpectedPacketError
......@@ -269,7 +275,13 @@ class EventHandler(object):
def answerObjectHistory(self, conn, oid, history_list):
raise UnexpectedPacketError
def askOIDs(self, conn, first, last, partition):
def askObjectHistoryFrom(self, conn, oid, min_serial, length):
raise UnexpectedPacketError
def answerObjectHistoryFrom(self, conn, oid, history_list):
raise UnexpectedPacketError
def askOIDs(self, conn, min_oid, length, partition):
raise UnexpectedPacketError
def answerOIDs(self, conn, oid_list):
......@@ -414,11 +426,15 @@ class EventHandler(object):
d[Packets.AnswerObject] = self.answerObject
d[Packets.AskTIDs] = self.askTIDs
d[Packets.AnswerTIDs] = self.answerTIDs
d[Packets.AskTIDsFrom] = self.askTIDsFrom
d[Packets.AnswerTIDsFrom] = self.answerTIDsFrom
d[Packets.AskTransactionInformation] = self.askTransactionInformation
d[Packets.AnswerTransactionInformation] = \
self.answerTransactionInformation
d[Packets.AskObjectHistory] = self.askObjectHistory
d[Packets.AnswerObjectHistory] = self.answerObjectHistory
d[Packets.AskObjectHistoryFrom] = self.askObjectHistoryFrom
d[Packets.AnswerObjectHistoryFrom] = self.answerObjectHistoryFrom
d[Packets.AskOIDs] = self.askOIDs
d[Packets.AnswerOIDs] = self.answerOIDs
d[Packets.AskPartitionList] = self.askPartitionList
......
......@@ -109,6 +109,8 @@ INVALID_OID = '\xff' * 8
INVALID_PTID = '\0' * 8
INVALID_SERIAL = INVALID_TID
INVALID_PARTITION = 0xffffffff
ZERO_TID = '\0' * 8
ZERO_OID = '\0' * 8
OID_LEN = len(INVALID_OID)
UUID_NAMESPACES = {
......@@ -1024,7 +1026,7 @@ class AnswerObject(Packet):
class AskTIDs(Packet):
"""
Ask for TIDs between a range of offsets. The order of TIDs is descending,
and the range is [first, last). C, S -> S.
and the range is [first, last). C -> S.
"""
_header_format = '!QQL'
......@@ -1036,7 +1038,7 @@ class AskTIDs(Packet):
class AnswerTIDs(Packet):
"""
Answer the requested TIDs. S -> C, S.
Answer the requested TIDs. S -> C.
"""
_header_format = '!L'
_list_entry_format = '8s'
......@@ -1060,6 +1062,25 @@ class AnswerTIDs(Packet):
tid_list.append(tid)
return (tid_list,)
class AskTIDsFrom(Packet):
"""
Ask for length TIDs starting at min_tid. The order of TIDs is ascending.
S -> S.
"""
_header_format = '!8sLL'
def _encode(self, min_tid, length, partition):
return pack(self._header_format, min_tid, length, partition)
def _decode(self, body):
return unpack(self._header_format, body) # min_tid, length, partition
class AnswerTIDsFrom(AnswerTIDs):
"""
Answer the requested TIDs. S -> S
"""
pass
class AskTransactionInformation(Packet):
"""
Ask information about a transaction. Any -> S.
......@@ -1105,7 +1126,7 @@ class AnswerTransactionInformation(Packet):
class AskObjectHistory(Packet):
"""
Ask history information for a given object. The order of serials is
descending, and the range is [first, last]. C, S -> S.
descending, and the range is [first, last]. C -> S.
"""
_header_format = '!8sQQ'
......@@ -1118,7 +1139,7 @@ class AskObjectHistory(Packet):
class AnswerObjectHistory(Packet):
"""
Answer history information (serial, size) for an object. S -> C, S.
Answer history information (serial, size) for an object. S -> C.
"""
_header_format = '!8sL'
_list_entry_format = '!8sL'
......@@ -1144,18 +1165,40 @@ class AnswerObjectHistory(Packet):
history_list.append((serial, size))
return (oid, history_list)
class AskObjectHistoryFrom(Packet):
"""
Ask history information for a given object. The order of serials is
ascending, and starts at (or above) min_serial. S -> S.
"""
_header_format = '!8s8sL'
def _encode(self, oid, min_serial, length):
return pack(self._header_format, oid, min_serial, length)
def _decode(self, body):
return unpack(self._header_format, body) # oid, min_serial, length
class AnswerObjectHistoryFrom(AskFinishTransaction):
"""
Answer the requested serials. S -> S.
"""
# This is similar to AskFinishTransaction as TID size is identical to OID
# size:
# - we have a single OID (TID in AskFinishTransaction)
# - we have a list of TIDs (OIDs in AskFinishTransaction)
pass
class AskOIDs(Packet):
"""
Ask for OIDs between a range of offsets. The order of OIDs is descending,
and the range is [first, last). S -> S.
Ask for length OIDs starting at min_oid. S -> S.
"""
_header_format = '!QQL'
_header_format = '!8sLL'
def _encode(self, first, last, partition):
return pack(self._header_format, first, last, partition)
def _encode(self, min_oid, length, partition):
return pack(self._header_format, min_oid, length, partition)
def _decode(self, body):
return unpack(self._header_format, body) # first, last, partition
return unpack(self._header_format, body) # min_oid, length, partition
class AnswerOIDs(Packet):
"""
......@@ -1787,6 +1830,14 @@ class PacketRegistry(dict):
0x0034,
AskHasLock,
AnswerHasLock)
AskTIDsFrom, AnswerTIDsFrom = register(
0x0035,
AskTIDsFrom,
AnswerTIDsFrom)
AskObjectHistoryFrom, AnswerObjectHistoryFrom = register(
0x0036,
AskObjectHistoryFrom,
AnswerObjectHistoryFrom)
# build a "singleton"
Packets = PacketRegistry()
......
......@@ -263,8 +263,8 @@ class DatabaseManager(object):
area as well."""
raise NotImplementedError
def getOIDList(self, offset, length, num_partitions, partition_list):
"""Return a list of OIDs in descending order from an offset,
def getOIDList(self, min_oid, length, num_partitions, partition_list):
"""Return a list of OIDs in ascending order from a minimal oid,
at most the specified length. The list of partitions are passed
to filter out non-applicable TIDs."""
raise NotImplementedError
......@@ -276,15 +276,20 @@ class DatabaseManager(object):
If there is no such object ID in a database, return None."""
raise NotImplementedError
def getObjectHistoryFrom(self, oid, min_serial, length):
"""Return a list of length serials for a given object ID at (or above)
min_serial, sorted in ascending order."""
raise NotImplementedError
def getTIDList(self, offset, length, num_partitions, partition_list):
"""Return a list of TIDs in ascending order from an offset,
at most the specified length. The list of partitions are passed
to filter out non-applicable TIDs."""
raise NotImplementedError
def getReplicationTIDList(self, offset, length, num_partitions,
def getReplicationTIDList(self, min_tid, length, num_partitions,
partition_list):
"""Return a list of TIDs in descending order from an offset,
"""Return a list of TIDs in ascending order from an initial tid value,
at most the specified length. The list of partitions are passed
to filter out non-applicable TIDs."""
raise NotImplementedError
......
......@@ -618,12 +618,18 @@ class MySQLDatabaseManager(DatabaseManager):
return oid_list, user, desc, ext, bool(packed)
return None
def getOIDList(self, offset, length, num_partitions, partition_list):
def getOIDList(self, min_oid, length, num_partitions,
partition_list):
q = self.query
r = q("""SELECT DISTINCT oid FROM obj WHERE MOD(oid, %d) in (%s)
ORDER BY oid DESC LIMIT %d,%d""" \
% (num_partitions, ','.join([str(p) for p in partition_list]),
offset, length))
r = q("""SELECT DISTINCT oid FROM obj WHERE
MOD(oid, %(num_partitions)d) in (%(partitions)s)
AND oid >= %(min_oid)d
ORDER BY oid ASC LIMIT %(length)d""" % {
'num_partitions': num_partitions,
'partitions': ','.join([str(p) for p in partition_list]),
'min_oid': util.u64(min_oid),
'length': length,
})
return [util.p64(t[0]) for t in r]
def _getObjectLength(self, oid, value_serial):
......@@ -662,6 +668,19 @@ class MySQLDatabaseManager(DatabaseManager):
return result
return None
def getObjectHistoryFrom(self, oid, min_serial, length):
q = self.query
oid = util.u64(oid)
p64 = util.p64
r = q("""SELECT serial FROM obj
WHERE oid = %(oid)d AND serial >= %(min_serial)d
ORDER BY serial ASC LIMIT %(length)d""" % {
'oid': oid,
'min_serial': util.u64(min_serial),
'length': length,
})
return [p64(t[0]) for t in r]
def getTIDList(self, offset, length, num_partitions, partition_list):
q = self.query
r = q("""SELECT tid FROM trans WHERE MOD(tid, %d) in (%s)
......@@ -671,13 +690,18 @@ class MySQLDatabaseManager(DatabaseManager):
offset, length))
return [util.p64(t[0]) for t in r]
def getReplicationTIDList(self, offset, length, num_partitions, partition_list):
def getReplicationTIDList(self, min_tid, length, num_partitions,
partition_list):
q = self.query
r = q("""SELECT tid FROM trans WHERE MOD(tid, %d) in (%s)
ORDER BY tid ASC LIMIT %d,%d""" \
% (num_partitions,
','.join([str(p) for p in partition_list]),
offset, length))
r = q("""SELECT tid FROM trans WHERE
MOD(tid, %(num_partitions)d) in (%(partitions)s)
AND tid >= %(min_tid)d
ORDER BY tid ASC LIMIT %(length)d""" % {
'num_partitions': num_partitions,
'partitions': ','.join([str(p) for p in partition_list]),
'min_tid': util.u64(min_tid),
'length': length,
})
return [util.p64(t[0]) for t in r]
def getTIDListPresent(self, tid_list):
......
......@@ -65,16 +65,6 @@ class BaseMasterHandler(EventHandler):
class BaseClientAndStorageOperationHandler(EventHandler):
""" Accept requests common to client and storage nodes """
def askObjectHistory(self, conn, oid, first, last):
if first >= last:
raise protocol.ProtocolError( 'invalid offsets')
app = self.app
history_list = app.dm.getObjectHistory(oid, first, last - first)
if history_list is None:
history_list = []
conn.answer(Packets.AnswerObjectHistory(oid, history_list))
def askTransactionInformation(self, conn, tid):
app = self.app
t = app.dm.getTransaction(tid)
......
......@@ -144,3 +144,13 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
state = LockState.GRANTED_TO_OTHER
conn.answer(Packets.AnswerHasLock(oid, state))
def askObjectHistory(self, conn, oid, first, last):
if first >= last:
raise protocol.ProtocolError( 'invalid offsets')
app = self.app
history_list = app.dm.getObjectHistory(oid, first, last - first)
if history_list is None:
history_list = []
conn.answer(Packets.AnswerObjectHistory(oid, history_list))
......@@ -19,7 +19,8 @@
from neo import logging
from neo.handler import EventHandler
from neo.protocol import Packets
from neo.protocol import Packets, ZERO_TID, ZERO_OID
from neo import util
def checkConnectionIsReplicatorConnection(func):
def decorator(self, conn, *args, **kw):
......@@ -31,6 +32,10 @@ def checkConnectionIsReplicatorConnection(func):
return result
return decorator
def add64(packed, offset):
"""Add a python number to a 64-bits packed value"""
return util.p64(util.u64(packed) + offset)
class ReplicationHandler(EventHandler):
"""This class handles events for replications."""
......@@ -48,7 +53,7 @@ class ReplicationHandler(EventHandler):
conn.setUUID(uuid)
@checkConnectionIsReplicatorConnection
def answerTIDs(self, conn, tid_list):
def answerTIDsFrom(self, conn, tid_list):
app = self.app
if tid_list:
# If I have pending TIDs, check which TIDs I don't have, and
......@@ -59,18 +64,15 @@ class ReplicationHandler(EventHandler):
conn.ask(Packets.AskTransactionInformation(tid), timeout=300)
# And, ask more TIDs.
app.replicator.tid_offset += 1000
offset = app.replicator.tid_offset
p = Packets.AskTIDs(offset, offset + 1000,
p = Packets.AskTIDsFrom(add64(tid_list[-1], 1), 1000,
app.replicator.current_partition.getRID())
conn.ask(p, timeout=300)
else:
# If no more TID, a replication of transactions is finished.
# So start to replicate objects now.
p = Packets.AskOIDs(0, 1000,
p = Packets.AskOIDs(ZERO_OID, 1000,
app.replicator.current_partition.getRID())
conn.ask(p, timeout=300)
app.replicator.oid_offset = 0
@checkConnectionIsReplicatorConnection
def answerTransactionInformation(self, conn, tid,
......@@ -84,10 +86,11 @@ class ReplicationHandler(EventHandler):
def answerOIDs(self, conn, oid_list):
app = self.app
if oid_list:
app.replicator.next_oid = add64(oid_list[-1], 1)
# Pick one up, and ask the history.
oid = oid_list.pop()
conn.ask(Packets.AskObjectHistory(oid, 0, 1000), timeout=300)
app.replicator.serial_offset = 0
conn.ask(Packets.AskObjectHistoryFrom(oid, ZERO_TID, 1000),
timeout=300)
app.replicator.oid_list = oid_list
else:
# Nothing remains, so the replication for this partition is
......@@ -95,34 +98,29 @@ class ReplicationHandler(EventHandler):
app.replicator.replication_done = True
@checkConnectionIsReplicatorConnection
def answerObjectHistory(self, conn, oid, history_list):
def answerObjectHistoryFrom(self, conn, oid, serial_list):
app = self.app
if history_list:
if serial_list:
# Check if I have objects, request those which I don't have.
serial_list = [t[0] for t in history_list]
present_serial_list = app.dm.getSerialListPresent(oid, serial_list)
serial_set = set(serial_list) - set(present_serial_list)
for serial in serial_set:
conn.ask(Packets.AskObject(oid, serial, None), timeout=300)
# And, ask more serials.
app.replicator.serial_offset += 1000
offset = app.replicator.serial_offset
p = Packets.AskObjectHistory(oid, offset, offset + 1000)
conn.ask(p, timeout=300)
conn.ask(Packets.AskObjectHistoryFrom(oid,
add64(serial_list[-1], 1), 1000), timeout=300)
else:
# This OID is finished. So advance to next.
oid_list = app.replicator.oid_list
if oid_list:
# If I have more pending OIDs, pick one up.
oid = oid_list.pop()
conn.ask(Packets.AskObjectHistory(oid, 0, 1000), timeout=300)
app.replicator.serial_offset = 0
conn.ask(Packets.AskObjectHistoryFrom(oid, ZERO_TID, 1000),
timeout=300)
else:
# Otherwise, acquire more OIDs.
app.replicator.oid_offset += 1000
offset = app.replicator.oid_offset
p = Packets.AskOIDs(offset, offset + 1000,
p = Packets.AskOIDs(app.replicator.next_oid, 1000,
app.replicator.current_partition.getRID())
conn.ask(p, timeout=300)
......
......@@ -30,36 +30,34 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler):
tid = app.dm.getLastTID()
conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID()))
def askOIDs(self, conn, first, last, partition):
def askOIDs(self, conn, min_oid, length, partition):
# This method is complicated, because I must return OIDs only
# about usable partitions assigned to me.
if first >= last:
raise protocol.ProtocolError('invalid offsets')
app = self.app
if partition == protocol.INVALID_PARTITION:
partition_list = app.pt.getAssignedPartitionList(app.uuid)
else:
partition_list = [partition]
oid_list = app.dm.getOIDList(first, last - first,
oid_list = app.dm.getOIDList(min_oid, length,
app.pt.getPartitions(), partition_list)
conn.answer(Packets.AnswerOIDs(oid_list))
def askTIDs(self, conn, first, last, partition):
def askTIDsFrom(self, conn, min_tid, length, partition):
# This method is complicated, because I must return TIDs only
# about usable partitions assigned to me.
if first >= last:
raise protocol.ProtocolError('invalid offsets')
app = self.app
if partition == protocol.INVALID_PARTITION:
partition_list = app.pt.getAssignedPartitionList(app.uuid)
else:
partition_list = [partition]
tid_list = app.dm.getReplicationTIDList(first, last - first,
tid_list = app.dm.getReplicationTIDList(min_tid, length,
app.pt.getPartitions(), partition_list)
conn.answer(Packets.AnswerTIDs(tid_list))
conn.answer(Packets.AnswerTIDsFrom(tid_list))
def askObjectHistoryFrom(self, conn, oid, min_serial, length):
app = self.app
history_list = app.dm.getObjectHistoryFrom(oid, min_serial, length)
conn.answer(Packets.AnswerObjectHistoryFrom(oid, history_list))
......@@ -19,7 +19,7 @@ from neo import logging
from random import choice
from neo.storage.handlers import replication
from neo.protocol import NodeTypes, NodeStates, CellStates, Packets
from neo.protocol import NodeTypes, NodeStates, CellStates, Packets, ZERO_TID
from neo.connection import ClientConnection
from neo.util import dump
......@@ -38,7 +38,7 @@ class Partition(object):
def setCriticalTID(self, tid):
if tid is None:
tid = '\x00' * 8
tid = ZERO_TID
self.tid = tid
def safe(self, min_pending_tid):
......@@ -81,7 +81,6 @@ class Replicator(object):
self.app = app
self.new_partition_dict = self._getOutdatedPartitionList()
self.critical_tid_dict = {}
self.tid_offset = 0
self.reset()
def reset(self):
......@@ -172,8 +171,8 @@ class Replicator(object):
app.uuid, app.server, app.name)
self.current_connection.ask(p)
self.tid_offset = 0
p = Packets.AskTIDs(0, 1000, self.current_partition.getRID())
p = Packets.AskTIDsFrom(ZERO_TID, 1000,
self.current_partition.getRID())
self.current_connection.ask(p, timeout=300)
self.replication_done = False
......
......@@ -364,9 +364,15 @@ class NeoTestBase(unittest.TestCase):
def checkAnswerTids(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerTIDs, **kw)
def checkAnswerTidsFrom(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerTIDsFrom, **kw)
def checkAnswerObjectHistory(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObjectHistory, **kw)
def checkAnswerObjectHistoryFrom(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObjectHistoryFrom, **kw)
def checkAnswerStoreTransaction(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerStoreTransaction, **kw)
......
......@@ -113,28 +113,19 @@ class StorageStorageHandlerTests(NeoTestBase):
self.assertEquals(len(self.app.event_queue), 0)
self.checkAnswerObject(conn)
def test_25_askTIDs1(self):
# invalid offsets => error
app = self.app
app.pt = Mock()
app.dm = Mock()
conn = self.getFakeConnection()
self.checkProtocolErrorRaised(self.operation.askTIDs, conn, 1, 1, None)
self.assertEquals(len(app.pt.mockGetNamedCalls('getCellList')), 0)
self.assertEquals(len(app.dm.mockGetNamedCalls('getReplicationTIDList')), 0)
def test_25_askTIDs2(self):
def test_25_askTIDsFrom1(self):
# well case => answer
conn = self.getFakeConnection()
self.app.dm = Mock({'getReplicationTIDList': (INVALID_TID, )})
self.app.pt = Mock({'getPartitions': 1})
self.operation.askTIDs(conn, 1, 2, 1)
tid = self.getNextTID()
self.operation.askTIDsFrom(conn, tid, 2, 1)
calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
self.assertEquals(len(calls), 1)
calls[0].checkArgs(1, 1, 1, [1, ])
self.checkAnswerTids(conn)
calls[0].checkArgs(tid, 2, 1, [1, ])
self.checkAnswerTidsFrom(conn)
def test_25_askTIDs3(self):
def test_25_askTIDsFrom2(self):
# invalid partition => answer usable partitions
conn = self.getFakeConnection()
cell = Mock({'getUUID':self.app.uuid})
......@@ -144,59 +135,39 @@ class StorageStorageHandlerTests(NeoTestBase):
'getPartitions': 1,
'getAssignedPartitionList': [0],
})
self.operation.askTIDs(conn, 1, 2, INVALID_PARTITION)
tid = self.getNextTID()
self.operation.askTIDsFrom(conn, tid, 2, INVALID_PARTITION)
self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1)
calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
self.assertEquals(len(calls), 1)
calls[0].checkArgs(1, 1, 1, [0, ])
self.checkAnswerTids(conn)
calls[0].checkArgs(tid, 2, 1, [0, ])
self.checkAnswerTidsFrom(conn)
def test_26_askObjectHistory1(self):
# invalid offsets => error
app = self.app
app.dm = Mock()
conn = self.getFakeConnection()
self.checkProtocolErrorRaised(self.operation.askObjectHistory, conn,
1, 1, None)
self.assertEquals(len(app.dm.mockGetNamedCalls('getObjectHistory')), 0)
def test_26_askObjectHistory2(self):
oid1 = self.getOID(1)
oid2 = self.getOID(2)
def test_26_askObjectHistoryFrom(self):
oid = self.getOID(2)
min_tid = self.getNextTID()
tid = self.getNextTID()
# first case: empty history
conn = self.getFakeConnection()
self.app.dm = Mock({'getObjectHistory': None})
self.operation.askObjectHistory(conn, oid1, 1, 2)
self.checkAnswerObjectHistory(conn)
# second case: not empty history
conn = self.getFakeConnection()
self.app.dm = Mock({'getObjectHistory': [(tid, 0, ), ]})
self.operation.askObjectHistory(conn, oid2, 1, 2)
self.checkAnswerObjectHistory(conn)
self.app.dm = Mock({'getObjectHistoryFrom': [tid]})
self.operation.askObjectHistoryFrom(conn, oid, min_tid, 2)
self.checkAnswerObjectHistoryFrom(conn)
calls = self.app.dm.mockGetNamedCalls('getObjectHistoryFrom')
self.assertEquals(len(calls), 1)
calls[0].checkArgs(oid, min_tid, 2)
def test_25_askOIDs1(self):
# invalid offsets => error
app = self.app
app.pt = Mock()
app.dm = Mock()
conn = self.getFakeConnection()
self.checkProtocolErrorRaised(self.operation.askOIDs, conn, 1, 1, None)
self.assertEquals(len(app.pt.mockGetNamedCalls('getCellList')), 0)
self.assertEquals(len(app.dm.mockGetNamedCalls('getOIDList')), 0)
def test_25_askOIDs2(self):
# well case > answer OIDs
conn = self.getFakeConnection()
self.app.pt = Mock({'getPartitions': 1})
self.app.dm = Mock({'getOIDList': (INVALID_OID, )})
self.operation.askOIDs(conn, 1, 2, 1)
oid = self.getOID(1)
self.operation.askOIDs(conn, oid, 2, 1)
calls = self.app.dm.mockGetNamedCalls('getOIDList')
self.assertEquals(len(calls), 1)
calls[0].checkArgs(1, 1, 1, [1, ])
calls[0].checkArgs(oid, 2, 1, [1, ])
self.checkAnswerOids(conn)
def test_25_askOIDs3(self):
def test_25_askOIDs2(self):
# invalid partition => answer usable partitions
conn = self.getFakeConnection()
cell = Mock({'getUUID':self.app.uuid})
......@@ -206,11 +177,12 @@ class StorageStorageHandlerTests(NeoTestBase):
'getPartitions': 1,
'getAssignedPartitionList': [0],
})
self.operation.askOIDs(conn, 1, 2, INVALID_PARTITION)
oid = self.getOID(1)
self.operation.askOIDs(conn, oid, 2, INVALID_PARTITION)
self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1)
calls = self.app.dm.mockGetNamedCalls('getOIDList')
self.assertEquals(len(calls), 1)
calls[0].checkArgs(1, 1, 1, [0])
calls[0].checkArgs(oid, 2, 1, [0])
self.checkAnswerOids(conn)
......
......@@ -457,20 +457,20 @@ class StorageMySQSLdbTests(NeoTestBase):
self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid)
# get oids
result = self.db.getOIDList(0, 4, 1, [0])
result = self.db.getOIDList(oid1, 4, 1, [0])
self.checkSet(result, [oid1, oid2, oid3, oid4])
result = self.db.getOIDList(0, 4, 2, [0])
result = self.db.getOIDList(oid1, 4, 2, [0])
self.checkSet(result, [oid1, oid3])
result = self.db.getOIDList(0, 4, 2, [0, 1])
result = self.db.getOIDList(oid1, 4, 2, [0, 1])
self.checkSet(result, [oid1, oid2, oid3, oid4])
result = self.db.getOIDList(0, 4, 3, [0])
result = self.db.getOIDList(oid1, 4, 3, [0])
self.checkSet(result, [oid1, oid4])
# get a subset of oids
result = self.db.getOIDList(2, 4, 1, [0])
result = self.db.getOIDList(oid1, 2, 1, [0])
self.checkSet(result, [oid1, oid2])
result = self.db.getOIDList(0, 2, 1, [0])
result = self.db.getOIDList(oid3, 2, 1, [0])
self.checkSet(result, [oid3, oid4])
result = self.db.getOIDList(0, 1, 3, [0])
result = self.db.getOIDList(oid2, 1, 3, [0])
self.checkSet(result, [oid4])
def test_getObjectHistory(self):
......@@ -496,23 +496,18 @@ class StorageMySQSLdbTests(NeoTestBase):
result = self.db.getObjectHistory(oid, 2, 3)
self.assertEqual(result, None)
def test_getTIDList(self):
def _storeTransactions(self, count):
# use OID generator to know result of tid % N
tid1, tid2, tid3, tid4 = self.getOIDs(4)
tid_list = self.getOIDs(count)
oid = self.getOID(1)
txn1, objs1 = self.getTransaction([oid])
txn2, objs2 = self.getTransaction([oid])
txn3, objs3 = self.getTransaction([oid])
txn4, objs4 = self.getTransaction([oid])
# store four transaction
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.storeTransaction(tid3, objs3, txn3)
self.db.storeTransaction(tid4, objs4, txn4)
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
self.db.finishTransaction(tid3)
self.db.finishTransaction(tid4)
for tid in tid_list:
txn, objs = self.getTransaction([oid])
self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid)
return tid_list
def test_getTIDList(self):
tid1, tid2, tid3, tid4 = self._storeTransactions(4)
# get tids
result = self.db.getTIDList(0, 4, 1, [0])
self.checkSet(result, [tid1, tid2, tid3, tid4])
......@@ -530,6 +525,25 @@ class StorageMySQSLdbTests(NeoTestBase):
result = self.db.getTIDList(0, 1, 3, [0])
self.checkSet(result, [tid4])
def test_getReplicationTIDList(self):
tid1, tid2, tid3, tid4 = self._storeTransactions(4)
# get tids
result = self.db.getReplicationTIDList(tid1, 4, 1, [0])
self.checkSet(result, [tid1, tid2, tid3, tid4])
result = self.db.getReplicationTIDList(tid1, 4, 2, [0])
self.checkSet(result, [tid1, tid3])
result = self.db.getReplicationTIDList(tid1, 4, 2, [0, 1])
self.checkSet(result, [tid1, tid2, tid3, tid4])
result = self.db.getReplicationTIDList(tid1, 4, 3, [0])
self.checkSet(result, [tid1, tid4])
# get a subset of tids
result = self.db.getReplicationTIDList(tid3, 4, 1, [0])
self.checkSet(result, [tid3, tid4])
result = self.db.getReplicationTIDList(tid1, 2, 1, [0])
self.checkSet(result, [tid1, tid2])
result = self.db.getReplicationTIDList(tid1, 1, 3, [1])
self.checkSet(result, [tid2])
def test_getTIDListPresent(self):
oid = self.getOID(1)
tid1, tid2, tid3, tid4 = self.getTIDs(4)
......
......@@ -269,13 +269,16 @@ class ProtocolTests(NeoTestBase):
self.assertEqual(p_oid_list, oid_list)
def test_36_askFinishTransaction(self):
self._testXIDAndYIDList(Packets.AskFinishTransaction)
def _testXIDAndYIDList(self, packet):
oid1 = self.getNextTID()
oid2 = self.getNextTID()
oid3 = self.getNextTID()
oid4 = self.getNextTID()
tid = self.getNextTID()
oid_list = [oid1, oid2, oid3, oid4]
p = Packets.AskFinishTransaction(tid, oid_list)
p = packet(tid, oid_list)
p_tid, p_oid_list = p.decode()
self.assertEqual(p_tid, tid)
self.assertEqual(p_oid_list, oid_list)
......@@ -404,12 +407,15 @@ class ProtocolTests(NeoTestBase):
self.assertEqual(partition, 5)
def test_50_answerTIDs(self):
self._test_AnswerTIDs(Packets.AnswerTIDs)
def _test_AnswerTIDs(self, packet):
tid1 = self.getNextTID()
tid2 = self.getNextTID()
tid3 = self.getNextTID()
tid4 = self.getNextTID()
tid_list = [tid1, tid2, tid3, tid4]
p = Packets.AnswerTIDs(tid_list)
p = packet(tid_list)
p_tid_list = p.decode()[0]
self.assertEqual(p_tid_list, tid_list)
......@@ -457,10 +463,11 @@ class ProtocolTests(NeoTestBase):
self.assertEqual(oid, poid)
def test_55_askOIDs(self):
p = Packets.AskOIDs(1, 10, 5)
first, last, partition = p.decode()
self.assertEqual(first, 1)
self.assertEqual(last, 10)
oid = self.getOID(1)
p = Packets.AskOIDs(oid, 1000, 5)
min_oid, length, partition = p.decode()
self.assertEqual(min_oid, oid)
self.assertEqual(length, 1000)
self.assertEqual(partition, 5)
def test_56_answerOIDs(self):
......@@ -602,6 +609,30 @@ class ProtocolTests(NeoTestBase):
msg = 'test'
self.assertEqual(Packets.Notify(msg).decode(), (msg, ))
def test_AskTIDsFrom(self):
tid = self.getNextTID()
p = Packets.AskTIDsFrom(tid, 1000, 5)
min_tid, length, partition = p.decode()
self.assertEqual(min_tid, tid)
self.assertEqual(length, 1000)
self.assertEqual(partition, 5)
def test_AnswerTIDsFrom(self):
self._test_AnswerTIDs(Packets.AnswerTIDsFrom)
def test_AskObjectHistoryFrom(self):
oid = self.getOID(1)
min_serial = self.getNextTID()
length = 5
p = Packets.AskObjectHistoryFrom(oid, min_serial, length)
p_oid, p_min_serial, p_length = p.decode()
self.assertEqual(p_oid, oid)
self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_length, length)
def test_AnswerObjectHistoryFrom(self):
self._testXIDAndYIDList(Packets.AnswerObjectHistoryFrom)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment