Commit b0753366 authored by Vincent Pelletier's avatar Vincent Pelletier

Implement rsync-ish replication.

For further description, see storage/handlers/replication.py .

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2295 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 01735352
...@@ -275,16 +275,10 @@ class EventHandler(object): ...@@ -275,16 +275,10 @@ class EventHandler(object):
def answerObjectHistory(self, conn, oid, history_list): def answerObjectHistory(self, conn, oid, history_list):
raise UnexpectedPacketError raise UnexpectedPacketError
def askObjectHistoryFrom(self, conn, oid, min_serial, length): def askObjectHistoryFrom(self, conn, oid, min_serial, length, partition):
raise UnexpectedPacketError raise UnexpectedPacketError
def answerObjectHistoryFrom(self, conn, oid, history_list): def answerObjectHistoryFrom(self, conn, object_dict):
raise UnexpectedPacketError
def askOIDs(self, conn, min_oid, length, partition):
raise UnexpectedPacketError
def answerOIDs(self, conn, oid_list):
raise UnexpectedPacketError raise UnexpectedPacketError
def askPartitionList(self, conn, min_offset, max_offset, uuid): def askPartitionList(self, conn, min_offset, max_offset, uuid):
...@@ -358,6 +352,21 @@ class EventHandler(object): ...@@ -358,6 +352,21 @@ class EventHandler(object):
def answerPack(self, conn, status): def answerPack(self, conn, status):
raise UnexpectedPacketError raise UnexpectedPacketError
def askCheckTIDRange(self, conn, min_tid, length, partition):
raise UnexpectedPacketError
def answerCheckTIDRange(self, conn, min_tid, length, count, tid_checksum,
max_tid):
raise UnexpectedPacketError
def askCheckSerialRange(self, conn, min_oid, min_serial, length,
partition):
raise UnexpectedPacketError
def answerCheckSerialRange(self, conn, min_oid, min_serial, length, count,
oid_checksum, max_oid, serial_checksum, max_serial):
raise UnexpectedPacketError
# Error packet handlers. # Error packet handlers.
...@@ -450,8 +459,6 @@ class EventHandler(object): ...@@ -450,8 +459,6 @@ class EventHandler(object):
d[Packets.AnswerObjectHistory] = self.answerObjectHistory d[Packets.AnswerObjectHistory] = self.answerObjectHistory
d[Packets.AskObjectHistoryFrom] = self.askObjectHistoryFrom d[Packets.AskObjectHistoryFrom] = self.askObjectHistoryFrom
d[Packets.AnswerObjectHistoryFrom] = self.answerObjectHistoryFrom d[Packets.AnswerObjectHistoryFrom] = self.answerObjectHistoryFrom
d[Packets.AskOIDs] = self.askOIDs
d[Packets.AnswerOIDs] = self.answerOIDs
d[Packets.AskPartitionList] = self.askPartitionList d[Packets.AskPartitionList] = self.askPartitionList
d[Packets.AnswerPartitionList] = self.answerPartitionList d[Packets.AnswerPartitionList] = self.answerPartitionList
d[Packets.AskNodeList] = self.askNodeList d[Packets.AskNodeList] = self.askNodeList
...@@ -476,6 +483,10 @@ class EventHandler(object): ...@@ -476,6 +483,10 @@ class EventHandler(object):
d[Packets.AnswerBarrier] = self.answerBarrier d[Packets.AnswerBarrier] = self.answerBarrier
d[Packets.AskPack] = self.askPack d[Packets.AskPack] = self.askPack
d[Packets.AnswerPack] = self.answerPack d[Packets.AnswerPack] = self.answerPack
d[Packets.AskCheckTIDRange] = self.askCheckTIDRange
d[Packets.AnswerCheckTIDRange] = self.answerCheckTIDRange
d[Packets.AskCheckSerialRange] = self.askCheckSerialRange
d[Packets.AnswerCheckSerialRange] = self.answerCheckSerialRange
return d return d
......
...@@ -113,6 +113,7 @@ INVALID_PARTITION = 0xffffffff ...@@ -113,6 +113,7 @@ INVALID_PARTITION = 0xffffffff
ZERO_TID = '\0' * 8 ZERO_TID = '\0' * 8
ZERO_OID = '\0' * 8 ZERO_OID = '\0' * 8
OID_LEN = len(INVALID_OID) OID_LEN = len(INVALID_OID)
TID_LEN = len(INVALID_TID)
UUID_NAMESPACES = { UUID_NAMESPACES = {
NodeTypes.STORAGE: 'S', NodeTypes.STORAGE: 'S',
...@@ -1167,63 +1168,47 @@ class AnswerObjectHistory(Packet): ...@@ -1167,63 +1168,47 @@ class AnswerObjectHistory(Packet):
class AskObjectHistoryFrom(Packet): class AskObjectHistoryFrom(Packet):
""" """
Ask history information for a given object. The order of serials is Ask history information for a given object. The order of serials is
ascending, and starts at (or above) min_serial. S -> S. ascending, and starts at (or above) min_serial for min_oid. S -> S.
""" """
_header_format = '!8s8sL' _header_format = '!8s8sLL'
def _encode(self, oid, min_serial, length): def _encode(self, min_oid, min_serial, length, partition):
return pack(self._header_format, oid, min_serial, length) return pack(self._header_format, min_oid, min_serial, length,
partition)
def _decode(self, body): def _decode(self, body):
return unpack(self._header_format, body) # oid, min_serial, length # min_oid, min_serial, length, partition
return unpack(self._header_format, body)
class AnswerObjectHistoryFrom(AskFinishTransaction): class AnswerObjectHistoryFrom(Packet):
""" """
Answer the requested serials. S -> S. Answer the requested serials. S -> S.
""" """
# This is similar to AskFinishTransaction as TID size is identical to OID
# size:
# - we have a single OID (TID in AskFinishTransaction)
# - we have a list of TIDs (OIDs in AskFinishTransaction)
pass
class AskOIDs(Packet):
"""
Ask for length OIDs starting at min_oid. S -> S.
"""
_header_format = '!8sLL'
def _encode(self, min_oid, length, partition):
return pack(self._header_format, min_oid, length, partition)
def _decode(self, body):
return unpack(self._header_format, body) # min_oid, length, partition
class AnswerOIDs(Packet):
"""
Answer the requested OIDs. S -> S.
"""
_header_format = '!L' _header_format = '!L'
_list_entry_format = '8s' _list_entry_format = '!8sL'
_list_entry_len = calcsize(_list_entry_format) _list_entry_len = calcsize(_list_entry_format)
def _encode(self, oid_list): def _encode(self, object_dict):
body = [pack(self._header_format, len(oid_list))] body = [pack(self._header_format, len(object_dict))]
body.extend(oid_list) append = body.append
extend = body.extend
list_entry_format = self._list_entry_format
for oid, serial_list in object_dict.iteritems():
append(pack(list_entry_format, oid, len(serial_list)))
extend(serial_list)
return ''.join(body) return ''.join(body)
def _decode(self, body): def _decode(self, body):
offset = self._header_len body = StringIO(body)
(n,) = unpack(self._header_format, body[:offset]) read = body.read
oid_list = []
list_entry_format = self._list_entry_format list_entry_format = self._list_entry_format
list_entry_len = self._list_entry_len list_entry_len = self._list_entry_len
for _ in xrange(n): object_dict = {}
next_offset = offset + list_entry_len dict_len = unpack(self._header_format, read(self._header_len))[0]
oid = unpack(list_entry_format, body[offset:next_offset])[0] for _ in xrange(dict_len):
offset = next_offset oid, serial_len = unpack(list_entry_format, read(list_entry_len))
oid_list.append(oid) object_dict[oid] = [read(TID_LEN) for _ in xrange(serial_len)]
return (oid_list,) return (object_dict, )
class AskPartitionList(Packet): class AskPartitionList(Packet):
""" """
...@@ -1660,6 +1645,73 @@ class AnswerPack(Packet): ...@@ -1660,6 +1645,73 @@ class AnswerPack(Packet):
def _decode(self, body): def _decode(self, body):
return (bool(unpack(self._header_format, body)[0]), ) return (bool(unpack(self._header_format, body)[0]), )
class AskCheckTIDRange(Packet):
"""
Ask some stats about a range of transactions.
Used to know if there are differences between a replicating node and
reference node.
S -> S
"""
_header_format = '!8sLL'
def _encode(self, min_tid, length, partition):
return pack(self._header_format, min_tid, length, partition)
def _decode(self, body):
return unpack(self._header_format, body) # min_tid, length, partition
class AnswerCheckTIDRange(Packet):
"""
Stats about a range of transactions.
Used to know if there are differences between a replicating node and
reference node.
S -> S
"""
_header_format = '!8sLLQ8s'
def _encode(self, min_tid, length, count, tid_checksum, max_tid):
return pack(self._header_format, min_tid, length, count, tid_checksum,
max_tid)
def _decode(self, body):
# min_tid, length, partition, count, tid_checksum, max_tid
return unpack(self._header_format, body)
class AskCheckSerialRange(Packet):
"""
Ask some stats about a range of object history.
Used to know if there are differences between a replicating node and
reference node.
S -> S
"""
_header_format = '!8s8sLL'
def _encode(self, min_oid, min_serial, length, partition):
return pack(self._header_format, min_oid, min_serial, length,
partition)
def _decode(self, body):
# min_oid, min_serial, length, partition
return unpack(self._header_format, body)
class AnswerCheckSerialRange(Packet):
"""
Stats about a range of object history.
Used to know if there are differences between a replicating node and
reference node.
S -> S
"""
_header_format = '!8s8sLLQ8sQ8s'
def _encode(self, min_oid, min_serial, length, count, oid_checksum,
max_oid, serial_checksum, max_serial):
return pack(self._header_format, min_oid, min_serial, length, count,
oid_checksum, max_oid, serial_checksum, max_serial)
def _decode(self, body):
# min_oid, min_serial, length, count, oid_checksum, max_oid,
# serial_checksum, max_serial
return unpack(self._header_format, body)
class Error(Packet): class Error(Packet):
""" """
Error is a special type of message, because this can be sent against Error is a special type of message, because this can be sent against
...@@ -1844,10 +1896,6 @@ class PacketRegistry(dict): ...@@ -1844,10 +1896,6 @@ class PacketRegistry(dict):
0x001F, 0x001F,
AskObjectHistory, AskObjectHistory,
AnswerObjectHistory) AnswerObjectHistory)
AskOIDs, AnswerOIDs = register(
0x0020,
AskOIDs,
AnswerOIDs)
AskPartitionList, AnswerPartitionList = register( AskPartitionList, AnswerPartitionList = register(
0x0021, 0x0021,
AskPartitionList, AskPartitionList,
...@@ -1903,6 +1951,16 @@ class PacketRegistry(dict): ...@@ -1903,6 +1951,16 @@ class PacketRegistry(dict):
0x0038, 0x0038,
AskPack, AskPack,
AnswerPack) AnswerPack)
AskCheckTIDRange, AnswerCheckTIDRange = register(
0x0039,
AskCheckTIDRange,
AnswerCheckTIDRange,
)
AskCheckSerialRange, AnswerCheckSerialRange = register(
0x003A,
AskCheckSerialRange,
AnswerCheckSerialRange,
)
# build a "singleton" # build a "singleton"
Packets = PacketRegistry() Packets = PacketRegistry()
......
...@@ -288,6 +288,12 @@ class Application(object): ...@@ -288,6 +288,12 @@ class Application(object):
while True: while True:
em.poll(1) em.poll(1)
if self.replicator.pending(): if self.replicator.pending():
# Call processDelayedTasks before act, so tasks added in the
# act call are executed after one poll call, so that sent
# packets are already on the network and delayed task
# processing happens in parallel with the same task on the
# other storage node.
self.replicator.processDelayedTasks()
self.replicator.act() self.replicator.act()
def wait(self): def wait(self):
......
...@@ -274,6 +274,11 @@ class DatabaseManager(object): ...@@ -274,6 +274,11 @@ class DatabaseManager(object):
area.""" area."""
raise NotImplementedError raise NotImplementedError
def deleteObject(self, oid, serial=None):
"""Delete given object. If serial is given, only delete that serial for
given oid."""
raise NotImplementedError
def getTransaction(self, tid, all = False): def getTransaction(self, tid, all = False):
"""Return a tuple of the list of OIDs, user information, """Return a tuple of the list of OIDs, user information,
a description, and extension information, for a given transaction a description, and extension information, for a given transaction
...@@ -282,12 +287,6 @@ class DatabaseManager(object): ...@@ -282,12 +287,6 @@ class DatabaseManager(object):
area as well.""" area as well."""
raise NotImplementedError raise NotImplementedError
def getOIDList(self, min_oid, length, num_partitions, partition_list):
"""Return a list of OIDs in ascending order from a minimal oid,
at most the specified length. The list of partitions are passed
to filter out non-applicable TIDs."""
raise NotImplementedError
def getObjectHistory(self, oid, offset = 0, length = 1): def getObjectHistory(self, oid, offset = 0, length = 1):
"""Return a list of serials and sizes for a given object ID. """Return a list of serials and sizes for a given object ID.
The length specifies the maximum size of such a list. Result starts The length specifies the maximum size of such a list. Result starts
...@@ -295,9 +294,11 @@ class DatabaseManager(object): ...@@ -295,9 +294,11 @@ class DatabaseManager(object):
If there is no such object ID in a database, return None.""" If there is no such object ID in a database, return None."""
raise NotImplementedError raise NotImplementedError
def getObjectHistoryFrom(self, oid, min_serial, length): def getObjectHistoryFrom(self, oid, min_serial, length, num_partitions,
"""Return a list of length serials for a given object ID at (or above) partition):
min_serial, sorted in ascending order.""" """Return a dict of length serials grouped by oid at (or above)
min_oid and min_serial, for given partition, sorted in ascending
order."""
raise NotImplementedError raise NotImplementedError
def getTIDList(self, offset, length, num_partitions, partition_list): def getTIDList(self, offset, length, num_partitions, partition_list):
...@@ -307,20 +308,10 @@ class DatabaseManager(object): ...@@ -307,20 +308,10 @@ class DatabaseManager(object):
raise NotImplementedError raise NotImplementedError
def getReplicationTIDList(self, min_tid, length, num_partitions, def getReplicationTIDList(self, min_tid, length, num_partitions,
partition_list): partition):
"""Return a list of TIDs in ascending order from an initial tid value, """Return a list of TIDs in ascending order from an initial tid value,
at most the specified length. The list of partitions are passed at most the specified length. The partition number is passed to filter
to filter out non-applicable TIDs.""" out non-applicable TIDs."""
raise NotImplementedError
def getTIDListPresent(self, tid_list):
"""Return a list of TIDs which are present in a database among
the given list."""
raise NotImplementedError
def getSerialListPresent(self, oid, serial_list):
"""Return a list of serials which are present in a database among
the given list."""
raise NotImplementedError raise NotImplementedError
def pack(self, tid, updateObjectDataForPack): def pack(self, tid, updateObjectDataForPack):
......
...@@ -24,7 +24,7 @@ import string ...@@ -24,7 +24,7 @@ import string
from neo.storage.database import DatabaseManager from neo.storage.database import DatabaseManager
from neo.exception import DatabaseFailure from neo.exception import DatabaseFailure
from neo.protocol import CellStates from neo.protocol import CellStates, ZERO_OID, ZERO_TID
from neo import util from neo import util
LOG_QUERIES = False LOG_QUERIES = False
...@@ -576,6 +576,23 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -576,6 +576,23 @@ class MySQLDatabaseManager(DatabaseManager):
raise raise
self.commit() self.commit()
def deleteObject(self, oid, serial=None):
u64 = util.u64
query_param_dict = {
'oid': u64(oid),
}
query_fmt = 'DELETE FROM obj WHERE oid = %(oid)d'
if serial is not None:
query_param_dict['serial'] = u64(serial)
query_fmt = query_fmt + ' AND serial = %(serial)d'
self.begin()
try:
self.query(query_fmt % query_param_dict)
except:
self.rollback()
raise
self.commit()
def getTransaction(self, tid, all = False): def getTransaction(self, tid, all = False):
q = self.query q = self.query
tid = util.u64(tid) tid = util.u64(tid)
...@@ -594,20 +611,6 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -594,20 +611,6 @@ class MySQLDatabaseManager(DatabaseManager):
return oid_list, user, desc, ext, bool(packed) return oid_list, user, desc, ext, bool(packed)
return None return None
def getOIDList(self, min_oid, length, num_partitions,
partition_list):
q = self.query
r = q("""SELECT DISTINCT oid FROM obj WHERE
MOD(oid, %(num_partitions)d) in (%(partitions)s)
AND oid >= %(min_oid)d
ORDER BY oid ASC LIMIT %(length)d""" % {
'num_partitions': num_partitions,
'partitions': ','.join([str(p) for p in partition_list]),
'min_oid': util.u64(min_oid),
'length': length,
})
return [util.p64(t[0]) for t in r]
def _getObjectLength(self, oid, value_serial): def _getObjectLength(self, oid, value_serial):
if value_serial is None: if value_serial is None:
raise CreationUndone raise CreationUndone
...@@ -646,18 +649,32 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -646,18 +649,32 @@ class MySQLDatabaseManager(DatabaseManager):
return result return result
return None return None
def getObjectHistoryFrom(self, oid, min_serial, length): def getObjectHistoryFrom(self, min_oid, min_serial, length, num_partitions,
partition):
q = self.query q = self.query
oid = util.u64(oid) u64 = util.u64
p64 = util.p64 p64 = util.p64
r = q("""SELECT serial FROM obj min_oid = u64(min_oid)
WHERE oid = %(oid)d AND serial >= %(min_serial)d min_serial = u64(min_serial)
ORDER BY serial ASC LIMIT %(length)d""" % { r = q('SELECT oid, serial FROM obj '
'oid': oid, 'WHERE ((oid = %(min_oid)d AND serial >= %(min_serial)d) OR '
'min_serial': util.u64(min_serial), 'oid > %(min_oid)d) AND '
'MOD(oid, %(num_partitions)d) = %(partition)s '
'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % {
'min_oid': min_oid,
'min_serial': min_serial,
'length': length, 'length': length,
'num_partitions': num_partitions,
'partition': partition,
}) })
return [p64(t[0]) for t in r] result = {}
for oid, serial in r:
try:
serial_list = result[oid]
except KeyError:
serial_list = result[oid] = []
serial_list.append(p64(serial))
return dict((p64(x), y) for x, y in result.iteritems())
def getTIDList(self, offset, length, num_partitions, partition_list): def getTIDList(self, offset, length, num_partitions, partition_list):
q = self.query q = self.query
...@@ -669,32 +686,19 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -669,32 +686,19 @@ class MySQLDatabaseManager(DatabaseManager):
return [util.p64(t[0]) for t in r] return [util.p64(t[0]) for t in r]
def getReplicationTIDList(self, min_tid, length, num_partitions, def getReplicationTIDList(self, min_tid, length, num_partitions,
partition_list): partition):
q = self.query q = self.query
r = q("""SELECT tid FROM trans WHERE r = q("""SELECT tid FROM trans WHERE
MOD(tid, %(num_partitions)d) in (%(partitions)s) MOD(tid, %(num_partitions)d) = %(partition)d
AND tid >= %(min_tid)d AND tid >= %(min_tid)d
ORDER BY tid ASC LIMIT %(length)d""" % { ORDER BY tid ASC LIMIT %(length)d""" % {
'num_partitions': num_partitions, 'num_partitions': num_partitions,
'partitions': ','.join([str(p) for p in partition_list]), 'partition': partition,
'min_tid': util.u64(min_tid), 'min_tid': util.u64(min_tid),
'length': length, 'length': length,
}) })
return [util.p64(t[0]) for t in r] return [util.p64(t[0]) for t in r]
def getTIDListPresent(self, tid_list):
q = self.query
r = q("""SELECT tid FROM trans WHERE tid in (%s)""" \
% ','.join([str(util.u64(tid)) for tid in tid_list]))
return [util.p64(t[0]) for t in r]
def getSerialListPresent(self, oid, serial_list):
q = self.query
oid = util.u64(oid)
r = q("""SELECT serial FROM obj WHERE oid = %d AND serial in (%s)""" \
% (oid, ','.join([str(util.u64(serial)) for serial in serial_list])))
return [util.p64(t[0]) for t in r]
def _updatePackFuture(self, oid, orig_serial, max_serial, def _updatePackFuture(self, oid, orig_serial, max_serial,
updateObjectDataForPack): updateObjectDataForPack):
q = self.query q = self.query
...@@ -783,4 +787,54 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -783,4 +787,54 @@ class MySQLDatabaseManager(DatabaseManager):
self.rollback() self.rollback()
raise raise
self.commit() self.commit()
def checkTIDRange(self, min_tid, length, num_partitions, partition):
# XXX: XOR is a lame checksum
count, tid_checksum, max_tid = self.query('SELECT COUNT(*), '
'BIT_XOR(tid), MAX(tid) FROM ('
'SELECT tid FROM trans '
'WHERE MOD(tid, %(num_partitions)d) = %(partition)s '
'AND tid >= %(min_tid)d '
'ORDER BY tid ASC LIMIT %(length)d'
') AS foo' % {
'num_partitions': num_partitions,
'partition': partition,
'min_tid': util.u64(min_tid),
'length': length,
})[0]
if count == 0:
tid_checksum = 0
max_tid = ZERO_TID
else:
max_tid = util.p64(max_tid)
return count, tid_checksum, max_tid
def checkSerialRange(self, min_oid, min_serial, length, num_partitions,
partition):
# XXX: XOR is a lame checksum
u64 = util.u64
p64 = util.p64
r = self.query('SELECT oid, serial FROM obj WHERE '
'(oid > %(min_oid)d OR '
'(oid = %(min_oid)d AND serial >= %(min_serial)d)) '
'AND MOD(oid, %(num_partitions)d) = %(partition)s '
'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % {
'min_oid': u64(min_oid),
'min_serial': u64(min_serial),
'length': length,
'num_partitions': num_partitions,
'partition': partition,
})
count = len(r)
oid_checksum = serial_checksum = 0
if count == 0:
max_oid = ZERO_OID
max_serial = ZERO_TID
else:
for max_oid, max_serial in r:
oid_checksum ^= max_oid
serial_checksum ^= max_serial
max_oid = p64(max_oid)
max_serial = p64(max_serial)
return count, oid_checksum, max_oid, serial_checksum, max_serial
This diff is collapsed.
...@@ -30,34 +30,32 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -30,34 +30,32 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler):
tid = app.dm.getLastTID() tid = app.dm.getLastTID()
conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID())) conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID()))
def askOIDs(self, conn, min_oid, length, partition):
# This method is complicated, because I must return OIDs only
# about usable partitions assigned to me.
app = self.app
if partition == protocol.INVALID_PARTITION:
partition_list = app.pt.getAssignedPartitionList(app.uuid)
else:
partition_list = [partition]
oid_list = app.dm.getOIDList(min_oid, length,
app.pt.getPartitions(), partition_list)
conn.answer(Packets.AnswerOIDs(oid_list))
def askTIDsFrom(self, conn, min_tid, length, partition): def askTIDsFrom(self, conn, min_tid, length, partition):
# This method is complicated, because I must return TIDs only
# about usable partitions assigned to me.
app = self.app app = self.app
if partition == protocol.INVALID_PARTITION:
partition_list = app.pt.getAssignedPartitionList(app.uuid)
else:
partition_list = [partition]
tid_list = app.dm.getReplicationTIDList(min_tid, length, tid_list = app.dm.getReplicationTIDList(min_tid, length,
app.pt.getPartitions(), partition_list) app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerTIDsFrom(tid_list)) conn.answer(Packets.AnswerTIDsFrom(tid_list))
def askObjectHistoryFrom(self, conn, oid, min_serial, length): def askObjectHistoryFrom(self, conn, min_oid, min_serial, length,
partition):
app = self.app
object_dict = app.dm.getObjectHistoryFrom(min_oid, min_serial, length,
app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerObjectHistoryFrom(object_dict))
def askCheckTIDRange(self, conn, min_tid, length, partition):
app = self.app
count, tid_checksum, max_tid = app.dm.checkTIDRange(min_tid, length,
app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerCheckTIDRange(min_tid, length, count,
tid_checksum, max_tid))
def askCheckSerialRange(self, conn, min_oid, min_serial, length,
partition):
app = self.app app = self.app
history_list = app.dm.getObjectHistoryFrom(oid, min_serial, length) count, oid_checksum, max_oid, serial_checksum, max_serial = \
conn.answer(Packets.AnswerObjectHistoryFrom(oid, history_list)) app.dm.checkSerialRange(min_oid, min_serial, length,
app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerCheckSerialRange(min_oid, min_serial, length,
count, oid_checksum, max_oid, serial_checksum, max_serial))
...@@ -46,6 +46,46 @@ class Partition(object): ...@@ -46,6 +46,46 @@ class Partition(object):
return tid is not None and ( return tid is not None and (
min_pending_tid is None or tid < min_pending_tid) min_pending_tid is None or tid < min_pending_tid)
class Task(object):
"""
A Task is a callable to execute at another time, with given parameters.
Execution result is kept and can be retrieved later.
"""
_func = None
_args = None
_kw = None
_result = None
_processed = False
def __init__(self, func, args=(), kw=None):
self._func = func
self._args = args
if kw is None:
kw = {}
self._kw = kw
def process(self):
if self._processed:
raise ValueError, 'You cannot process a single Task twice'
self._processed = True
self._result = self._func(*self._args, **self._kw)
def getResult(self):
# Should we instead execute immediately rather than raising ?
if not self._processed:
raise ValueError, 'You cannot get a result until task is executed'
return self._result
def __repr__(self):
fmt = '<%s at %x %r(*%r, **%r)%%s>' % (self.__class__.__name__,
id(self), self._func, self._args, self._kw)
if self._processed:
extra = ' => %r' % (self._result, )
else:
extra = ''
return fmt % (extra, )
class Replicator(object): class Replicator(object):
"""This class handles replications of objects and transactions. """This class handles replications of objects and transactions.
...@@ -98,21 +138,23 @@ class Replicator(object): ...@@ -98,21 +138,23 @@ class Replicator(object):
# didn't answer yet. # didn't answer yet.
# unfinished_tid_list # unfinished_tid_list
# The list of unfinished TIDs known by master node. # The list of unfinished TIDs known by master node.
# oid_list
# List of OIDs to replicate. Doesn't contains currently-replicated
# object.
# XXX: not defined here
# XXX: accessed (r/w) directly by ReplicationHandler
# next_oid
# Next OID to ask when oid_list is empty.
# XXX: not defined here
# XXX: accessed (r/w) directly by ReplicationHandler
# replication_done # replication_done
# False if we know there is something to replicate. # False if we know there is something to replicate.
# True when current_partition is replicated, or we don't know yet if # True when current_partition is replicated, or we don't know yet if
# there is something to replicate # there is something to replicate
# XXX: accessed (w) directly by ReplicationHandler # XXX: accessed (w) directly by ReplicationHandler
new_partition_dict = None
critical_tid_dict = None
partition_dict = None
task_list = None
task_dict = None
current_partition = None
current_connection = None
waiting_for_unfinished_tids = None
unfinished_tid_list = None
replication_done = None
def __init__(self, app): def __init__(self, app):
self.app = app self.app = app
...@@ -129,6 +171,8 @@ class Replicator(object): ...@@ -129,6 +171,8 @@ class Replicator(object):
def reset(self): def reset(self):
"""Reset attributes to restart replicating.""" """Reset attributes to restart replicating."""
self.task_list = []
self.task_dict = {}
self.current_partition = None self.current_partition = None
self.current_connection = None self.current_connection = None
self.waiting_for_unfinished_tids = False self.waiting_for_unfinished_tids = False
...@@ -213,15 +257,12 @@ class Replicator(object): ...@@ -213,15 +257,12 @@ class Replicator(object):
p = Packets.RequestIdentification(NodeTypes.STORAGE, p = Packets.RequestIdentification(NodeTypes.STORAGE,
app.uuid, app.server, app.name) app.uuid, app.server, app.name)
self.current_connection.ask(p) self.current_connection.ask(p)
else:
p = Packets.AskTIDsFrom(ZERO_TID, 1000, self.current_connection.getHandler().startReplication(
self.current_partition.getRID()) self.current_connection)
self.current_connection.ask(p, timeout=300)
self.replication_done = False self.replication_done = False
def _finishReplication(self): def _finishReplication(self):
app = self.app
# TODO: remove try..except: pass # TODO: remove try..except: pass
try: try:
self.partition_dict.pop(self.current_partition.getRID()) self.partition_dict.pop(self.current_partition.getRID())
...@@ -243,7 +284,11 @@ class Replicator(object): ...@@ -243,7 +284,11 @@ class Replicator(object):
self._askCriticalTID() self._askCriticalTID()
if self.current_partition is not None: if self.current_partition is not None:
if self.replication_done: # Don't end replication until we have received all expected
# answers, as we might have asked object data just before the last
# AnswerCheckSerialRange.
if self.replication_done and \
not self.current_connection.isPending():
# finish a replication # finish a replication
logging.info('replication is done for %s' % logging.info('replication is done for %s' %
(self.current_partition.getRID(), )) (self.current_partition.getRID(), ))
...@@ -289,3 +334,57 @@ class Replicator(object): ...@@ -289,3 +334,57 @@ class Replicator(object):
and not self.new_partition_dict.has_key(rid): and not self.new_partition_dict.has_key(rid):
self.new_partition_dict[rid] = Partition(rid) self.new_partition_dict[rid] = Partition(rid)
def _addTask(self, key, func, args=(), kw=None):
task = Task(func, args, kw)
task_dict = self.task_dict
if key in task_dict:
raise ValueError, 'Task with key %r already exists (%r), cannot ' \
'add %r' % (key, task_dict[key], task)
task_dict[key] = task
self.task_list.append(task)
def processDelayedTasks(self):
task_list = self.task_list
if task_list:
for task in task_list:
task.process()
self.task_list = []
def checkTIDRange(self, min_tid, length, partition):
app = self.app
self._addTask(('TID', min_tid, length), app.dm.checkTIDRange,
(min_tid, length, app.pt.getPartitions(), partition))
def checkSerialRange(self, min_oid, min_serial, length, partition):
app = self.app
self._addTask(('Serial', min_oid, min_serial, length),
app.dm.checkSerialRange, (min_oid, min_serial, length,
app.pt.getPartitions(), partition))
def getTIDsFrom(self, min_tid, length, partition):
app = self.app
self._addTask('TIDsFrom',
app.dm.getReplicationTIDList, (min_tid, length,
app.pt.getPartitions(), partition))
def getObjectHistoryFrom(self, min_oid, min_serial, length, partition):
app = self.app
self._addTask('ObjectHistoryFrom',
app.dm.getObjectHistoryFrom, (min_oid, min_serial, length,
app.pt.getPartitions(), partition))
def _getCheckResult(self, key):
return self.task_dict.pop(key).getResult()
def getTIDCheckResult(self, min_tid, length):
return self._getCheckResult(('TID', min_tid, length))
def getSerialCheckResult(self, min_oid, min_serial, length):
return self._getCheckResult(('Serial', min_oid, min_serial, length))
def getTIDsFromResult(self):
return self._getCheckResult('TIDsFrom')
def getObjectHistoryFromResult(self):
return self._getCheckResult('ObjectHistoryFrom')
This diff is collapsed.
This diff is collapsed.
...@@ -21,7 +21,7 @@ from collections import deque ...@@ -21,7 +21,7 @@ from collections import deque
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.handlers.storage import StorageOperationHandler from neo.storage.handlers.storage import StorageOperationHandler
from neo.protocol import INVALID_PARTITION from neo.protocol import INVALID_PARTITION, Packets
from neo.protocol import INVALID_TID, INVALID_OID, INVALID_SERIAL from neo.protocol import INVALID_TID, INVALID_OID, INVALID_SERIAL
class StorageStorageHandlerTests(NeoTestBase): class StorageStorageHandlerTests(NeoTestBase):
...@@ -113,7 +113,7 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -113,7 +113,7 @@ class StorageStorageHandlerTests(NeoTestBase):
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.checkAnswerObject(conn) self.checkAnswerObject(conn)
def test_25_askTIDsFrom1(self): def test_25_askTIDsFrom(self):
# well case => answer # well case => answer
conn = self.getFakeConnection() conn = self.getFakeConnection()
self.app.dm = Mock({'getReplicationTIDList': (INVALID_TID, )}) self.app.dm = Mock({'getReplicationTIDList': (INVALID_TID, )})
...@@ -122,69 +122,85 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -122,69 +122,85 @@ class StorageStorageHandlerTests(NeoTestBase):
self.operation.askTIDsFrom(conn, tid, 2, 1) self.operation.askTIDsFrom(conn, tid, 2, 1)
calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList') calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
calls[0].checkArgs(tid, 2, 1, [1, ]) calls[0].checkArgs(tid, 2, 1, 1)
self.checkAnswerTidsFrom(conn)
def test_25_askTIDsFrom2(self):
# invalid partition => answer usable partitions
conn = self.getFakeConnection()
cell = Mock({'getUUID':self.app.uuid})
self.app.dm = Mock({'getReplicationTIDList': (INVALID_TID, )})
self.app.pt = Mock({
'getCellList': (cell, ),
'getPartitions': 1,
'getAssignedPartitionList': [0],
})
tid = self.getNextTID()
self.operation.askTIDsFrom(conn, tid, 2, INVALID_PARTITION)
self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1)
calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
self.assertEquals(len(calls), 1)
calls[0].checkArgs(tid, 2, 1, [0, ])
self.checkAnswerTidsFrom(conn) self.checkAnswerTidsFrom(conn)
def test_26_askObjectHistoryFrom(self): def test_26_askObjectHistoryFrom(self):
oid = self.getOID(2) min_oid = self.getOID(2)
min_tid = self.getNextTID() min_serial = self.getNextTID()
length = 4
partition = 8
num_partitions = 16
tid = self.getNextTID() tid = self.getNextTID()
conn = self.getFakeConnection() conn = self.getFakeConnection()
self.app.dm = Mock({'getObjectHistoryFrom': [tid]}) self.app.dm = Mock({'getObjectHistoryFrom': {min_oid: [tid]},})
self.operation.askObjectHistoryFrom(conn, oid, min_tid, 2) self.app.pt = Mock({
'getPartitions': num_partitions,
})
self.operation.askObjectHistoryFrom(conn, min_oid, min_serial, length,
partition)
self.checkAnswerObjectHistoryFrom(conn) self.checkAnswerObjectHistoryFrom(conn)
calls = self.app.dm.mockGetNamedCalls('getObjectHistoryFrom') calls = self.app.dm.mockGetNamedCalls('getObjectHistoryFrom')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
calls[0].checkArgs(oid, min_tid, 2) calls[0].checkArgs(min_oid, min_serial, length, num_partitions,
partition)
def test_25_askOIDs1(self): def test_askCheckTIDRange(self):
# well case > answer OIDs count = 1
tid_checksum = 2
min_tid = self.getNextTID()
num_partitions = 4
length = 5
partition = 6
max_tid = self.getNextTID()
self.app.dm = Mock({'checkTIDRange': (count, tid_checksum, max_tid)})
self.app.pt = Mock({'getPartitions': num_partitions})
conn = self.getFakeConnection() conn = self.getFakeConnection()
self.app.pt = Mock({'getPartitions': 1}) self.operation.askCheckTIDRange(conn, min_tid, length, partition)
self.app.dm = Mock({'getOIDList': (INVALID_OID, )}) calls = self.app.dm.mockGetNamedCalls('checkTIDRange')
oid = self.getOID(1) self.assertEqual(len(calls), 1)
self.operation.askOIDs(conn, oid, 2, 1) calls[0].checkArgs(min_tid, length, num_partitions, partition)
calls = self.app.dm.mockGetNamedCalls('getOIDList') pmin_tid, plength, pcount, ptid_checksum, pmax_tid = \
self.assertEquals(len(calls), 1) self.checkAnswerPacket(conn, Packets.AnswerCheckTIDRange,
calls[0].checkArgs(oid, 2, 1, [1, ]) decode=True)
self.checkAnswerOids(conn) self.assertEqual(min_tid, pmin_tid)
self.assertEqual(length, plength)
def test_25_askOIDs2(self): self.assertEqual(count, pcount)
# invalid partition => answer usable partitions self.assertEqual(tid_checksum, ptid_checksum)
self.assertEqual(max_tid, pmax_tid)
def test_askCheckSerialRange(self):
count = 1
oid_checksum = 2
min_oid = self.getOID(1)
num_partitions = 4
length = 5
partition = 6
serial_checksum = 7
min_serial = self.getNextTID()
max_serial = self.getNextTID()
max_oid = self.getOID(2)
self.app.dm = Mock({'checkSerialRange': (count, oid_checksum, max_oid,
serial_checksum, max_serial)})
self.app.pt = Mock({'getPartitions': num_partitions})
conn = self.getFakeConnection() conn = self.getFakeConnection()
cell = Mock({'getUUID':self.app.uuid}) self.operation.askCheckSerialRange(conn, min_oid, min_serial, length,
self.app.dm = Mock({'getOIDList': (INVALID_OID, )}) partition)
self.app.pt = Mock({ calls = self.app.dm.mockGetNamedCalls('checkSerialRange')
'getCellList': (cell, ), self.assertEqual(len(calls), 1)
'getPartitions': 1, calls[0].checkArgs(min_oid, min_serial, length, num_partitions,
'getAssignedPartitionList': [0], partition)
}) pmin_oid, pmin_serial, plength, pcount, poid_checksum, pmax_oid, \
oid = self.getOID(1) pserial_checksum, pmax_serial = self.checkAnswerPacket(conn,
self.operation.askOIDs(conn, oid, 2, INVALID_PARTITION) Packets.AnswerCheckSerialRange, decode=True)
self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1) self.assertEqual(min_oid, pmin_oid)
calls = self.app.dm.mockGetNamedCalls('getOIDList') self.assertEqual(min_serial, pmin_serial)
self.assertEquals(len(calls), 1) self.assertEqual(length, plength)
calls[0].checkArgs(oid, 2, 1, [0]) self.assertEqual(count, pcount)
self.checkAnswerOids(conn) self.assertEqual(oid_checksum, poid_checksum)
self.assertEqual(max_oid, pmax_oid)
self.assertEqual(serial_checksum, pserial_checksum)
self.assertEqual(max_serial, pmax_serial)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -19,7 +19,7 @@ import unittest ...@@ -19,7 +19,7 @@ import unittest
import MySQLdb import MySQLdb
from mock import Mock from mock import Mock
from neo.util import dump, p64, u64 from neo.util import dump, p64, u64
from neo.protocol import CellStates, INVALID_PTID from neo.protocol import CellStates, INVALID_PTID, ZERO_OID, ZERO_TID
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
from neo.exception import DatabaseFailure from neo.exception import DatabaseFailure
from neo.storage.database.mysqldb import MySQLDatabaseManager from neo.storage.database.mysqldb import MySQLDatabaseManager
...@@ -441,6 +441,23 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -441,6 +441,23 @@ class StorageMySQSLdbTests(NeoTestBase):
self.assertEqual(self.db.getTransaction(tid1, True), None) self.assertEqual(self.db.getTransaction(tid1, True), None)
self.assertEqual(self.db.getTransaction(tid2, True), None) self.assertEqual(self.db.getTransaction(tid2, True), None)
def test_deleteObject(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1, oid2])
txn2, objs2 = self.getTransaction([oid1, oid2])
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
self.db.deleteObject(oid1)
self.assertEqual(self.db.getObject(oid1, tid=tid1), None)
self.assertEqual(self.db.getObject(oid1, tid=tid2), None)
self.db.deleteObject(oid2, serial=tid1)
self.assertEqual(self.db.getObject(oid2, tid=tid1), False)
self.assertEqual(self.db.getObject(oid2, tid=tid2), (tid2, None) + \
objs2[1][1:])
def test_getTransaction(self): def test_getTransaction(self):
oid1, oid2 = self.getOIDs(2) oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2) tid1, tid2 = self.getTIDs(2)
...@@ -459,30 +476,6 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -459,30 +476,6 @@ class StorageMySQSLdbTests(NeoTestBase):
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False)) self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False))
self.assertEqual(self.db.getTransaction(tid2, False), None) self.assertEqual(self.db.getTransaction(tid2, False), None)
def test_getOIDList(self):
# store four objects
oid1, oid2, oid3, oid4 = self.getOIDs(4)
tid = self.getNextTID()
txn, objs = self.getTransaction([oid1, oid2, oid3, oid4])
self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid)
# get oids
result = self.db.getOIDList(oid1, 4, 1, [0])
self.checkSet(result, [oid1, oid2, oid3, oid4])
result = self.db.getOIDList(oid1, 4, 2, [0])
self.checkSet(result, [oid1, oid3])
result = self.db.getOIDList(oid1, 4, 2, [0, 1])
self.checkSet(result, [oid1, oid2, oid3, oid4])
result = self.db.getOIDList(oid1, 4, 3, [0])
self.checkSet(result, [oid1, oid4])
# get a subset of oids
result = self.db.getOIDList(oid1, 2, 1, [0])
self.checkSet(result, [oid1, oid2])
result = self.db.getOIDList(oid3, 2, 1, [0])
self.checkSet(result, [oid3, oid4])
result = self.db.getOIDList(oid2, 1, 3, [0])
self.checkSet(result, [oid4])
def test_getObjectHistory(self): def test_getObjectHistory(self):
oid = self.getOID(1) oid = self.getOID(1)
tid1, tid2, tid3 = self.getTIDs(3) tid1, tid2, tid3 = self.getTIDs(3)
...@@ -506,6 +499,50 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -506,6 +499,50 @@ class StorageMySQSLdbTests(NeoTestBase):
result = self.db.getObjectHistory(oid, 2, 3) result = self.db.getObjectHistory(oid, 2, 3)
self.assertEqual(result, None) self.assertEqual(result, None)
def test_getObjectHistoryFrom(self):
oid1 = self.getOID(0)
oid2 = self.getOID(1)
tid1, tid2, tid3, tid4 = self.getTIDs(4)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2])
txn3, objs3 = self.getTransaction([oid1])
txn4, objs4 = self.getTransaction([oid2])
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.storeTransaction(tid3, objs3, txn3)
self.db.storeTransaction(tid4, objs4, txn4)
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
self.db.finishTransaction(tid3)
self.db.finishTransaction(tid4)
# Check full result
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 10, 1, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2, tid4],
})
# Lower bound is inclusive
result = self.db.getObjectHistoryFrom(oid1, tid1, 10, 1, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2, tid4],
})
# Length is total number of serials
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 3, 1, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2],
})
# Partition constraints are honored
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 10, 2, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
})
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 10, 2, 1)
self.assertEqual(result, {
oid2: [tid2, tid4],
})
def _storeTransactions(self, count): def _storeTransactions(self, count):
# use OID generator to know result of tid % N # use OID generator to know result of tid % N
tid_list = self.getOIDs(count) tid_list = self.getOIDs(count)
...@@ -538,59 +575,20 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -538,59 +575,20 @@ class StorageMySQSLdbTests(NeoTestBase):
def test_getReplicationTIDList(self): def test_getReplicationTIDList(self):
tid1, tid2, tid3, tid4 = self._storeTransactions(4) tid1, tid2, tid3, tid4 = self._storeTransactions(4)
# get tids # get tids
result = self.db.getReplicationTIDList(tid1, 4, 1, [0]) result = self.db.getReplicationTIDList(tid1, 4, 1, 0)
self.checkSet(result, [tid1, tid2, tid3, tid4]) self.checkSet(result, [tid1, tid2, tid3, tid4])
result = self.db.getReplicationTIDList(tid1, 4, 2, [0]) result = self.db.getReplicationTIDList(tid1, 4, 2, 0)
self.checkSet(result, [tid1, tid3]) self.checkSet(result, [tid1, tid3])
result = self.db.getReplicationTIDList(tid1, 4, 2, [0, 1]) result = self.db.getReplicationTIDList(tid1, 4, 3, 0)
self.checkSet(result, [tid1, tid2, tid3, tid4])
result = self.db.getReplicationTIDList(tid1, 4, 3, [0])
self.checkSet(result, [tid1, tid4]) self.checkSet(result, [tid1, tid4])
# get a subset of tids # get a subset of tids
result = self.db.getReplicationTIDList(tid3, 4, 1, [0]) result = self.db.getReplicationTIDList(tid3, 4, 1, 0)
self.checkSet(result, [tid3, tid4]) self.checkSet(result, [tid3, tid4])
result = self.db.getReplicationTIDList(tid1, 2, 1, [0]) result = self.db.getReplicationTIDList(tid1, 2, 1, 0)
self.checkSet(result, [tid1, tid2]) self.checkSet(result, [tid1, tid2])
result = self.db.getReplicationTIDList(tid1, 1, 3, [1]) result = self.db.getReplicationTIDList(tid1, 1, 3, 1)
self.checkSet(result, [tid2]) self.checkSet(result, [tid2])
def test_getTIDListPresent(self):
oid = self.getOID(1)
tid1, tid2, tid3, tid4 = self.getTIDs(4)
txn1, objs1 = self.getTransaction([oid])
txn4, objs4 = self.getTransaction([oid])
# four tids, two missing
self.db.storeTransaction(tid1, objs1, txn1)
self.db.finishTransaction(tid1)
self.db.storeTransaction(tid4, objs4, txn4)
self.db.finishTransaction(tid4)
result = self.db.getTIDListPresent([tid1, tid2, tid3, tid4])
self.checkSet(result, [tid1, tid4])
result = self.db.getTIDListPresent([tid1, tid2])
self.checkSet(result, [tid1])
self.assertEqual(self.db.getTIDListPresent([tid2, tid3]), [])
def test_getSerialListPresent(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2, tid3, tid4 = self.getTIDs(4)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid1])
txn3, objs3 = self.getTransaction([oid2])
txn4, objs4 = self.getTransaction([oid2])
# four object, one revision each
self.db.storeTransaction(tid1, objs1, txn1)
self.db.finishTransaction(tid1)
self.db.storeTransaction(tid4, objs4, txn4)
self.db.finishTransaction(tid4)
result = self.db.getSerialListPresent(oid1, [tid1, tid2])
self.checkSet(result, [tid1])
result = self.db.getSerialListPresent(oid2, [tid3, tid4])
self.checkSet(result, [tid4])
result = self.db.getSerialListPresent(oid1, [tid2])
self.assertEqual(result, [])
result = self.db.getSerialListPresent(oid2, [tid3])
self.assertEqual(result, [])
def test__getObjectData(self): def test__getObjectData(self):
db = self.db db = self.db
db.setup(reset=True) db.setup(reset=True)
......
...@@ -458,24 +458,6 @@ class ProtocolTests(NeoTestBase): ...@@ -458,24 +458,6 @@ class ProtocolTests(NeoTestBase):
self.assertEqual(p_hist_list, hist_list) self.assertEqual(p_hist_list, hist_list)
self.assertEqual(oid, poid) self.assertEqual(oid, poid)
def test_55_askOIDs(self):
oid = self.getOID(1)
p = Packets.AskOIDs(oid, 1000, 5)
min_oid, length, partition = p.decode()
self.assertEqual(min_oid, oid)
self.assertEqual(length, 1000)
self.assertEqual(partition, 5)
def test_56_answerOIDs(self):
oid1 = self.getNextTID()
oid2 = self.getNextTID()
oid3 = self.getNextTID()
oid4 = self.getNextTID()
oid_list = [oid1, oid2, oid3, oid4]
p = Packets.AnswerOIDs(oid_list)
p_oid_list = p.decode()[0]
self.assertEqual(p_oid_list, oid_list)
def test_57_notifyReplicationDone(self): def test_57_notifyReplicationDone(self):
offset = 10 offset = 10
p = Packets.NotifyReplicationDone(offset) p = Packets.NotifyReplicationDone(offset)
...@@ -626,14 +608,82 @@ class ProtocolTests(NeoTestBase): ...@@ -626,14 +608,82 @@ class ProtocolTests(NeoTestBase):
oid = self.getOID(1) oid = self.getOID(1)
min_serial = self.getNextTID() min_serial = self.getNextTID()
length = 5 length = 5
p = Packets.AskObjectHistoryFrom(oid, min_serial, length) partition = 4
p_oid, p_min_serial, p_length = p.decode() p = Packets.AskObjectHistoryFrom(oid, min_serial, length, partition)
p_oid, p_min_serial, p_length, p_partition = p.decode()
self.assertEqual(p_oid, oid) self.assertEqual(p_oid, oid)
self.assertEqual(p_min_serial, min_serial) self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_length, length) self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition)
def test_AnswerObjectHistoryFrom(self): def test_AnswerObjectHistoryFrom(self):
self._testXIDAndYIDList(Packets.AnswerObjectHistoryFrom) object_dict = {}
for int_oid in xrange(4):
object_dict[self.getOID(int_oid)] = [self.getNextTID() \
for _ in xrange(5)]
p = Packets.AnswerObjectHistoryFrom(object_dict)
p_object_dict = p.decode()[0]
self.assertEqual(object_dict, p_object_dict)
def test_AskCheckTIDRange(self):
min_tid = self.getNextTID()
length = 2
partition = 4
p = Packets.AskCheckTIDRange(min_tid, length, partition)
p_min_tid, p_length, p_partition = p.decode()
self.assertEqual(p_min_tid, min_tid)
self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition)
def test_AnswerCheckTIDRange(self):
min_tid = self.getNextTID()
length = 2
count = 1
tid_checksum = 42
max_tid = self.getNextTID()
p = Packets.AnswerCheckTIDRange(min_tid, length, count, tid_checksum,
max_tid)
p_min_tid, p_length, p_count, p_tid_checksum, p_max_tid = p.decode()
self.assertEqual(p_min_tid, min_tid)
self.assertEqual(p_length, length)
self.assertEqual(p_count, count)
self.assertEqual(p_tid_checksum, tid_checksum)
self.assertEqual(p_max_tid, max_tid)
def test_AskCheckSerialRange(self):
min_oid = self.getOID(1)
min_serial = self.getNextTID()
length = 2
partition = 4
p = Packets.AskCheckSerialRange(min_oid, min_serial, length, partition)
p_min_oid, p_min_serial, p_length, p_partition = p.decode()
self.assertEqual(p_min_oid, min_oid)
self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition)
def test_AnswerCheckSerialRange(self):
min_oid = self.getOID(1)
min_serial = self.getNextTID()
length = 2
count = 1
oid_checksum = 24
max_oid = self.getOID(5)
tid_checksum = 42
max_serial = self.getNextTID()
p = Packets.AnswerCheckSerialRange(min_oid, min_serial, length, count,
oid_checksum, max_oid, tid_checksum, max_serial)
p_min_oid, p_min_serial, p_length, p_count, p_oid_checksum, \
p_max_oid, p_tid_checksum, p_max_serial = p.decode()
self.assertEqual(p_min_oid, min_oid)
self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_length, length)
self.assertEqual(p_count, count)
self.assertEqual(p_oid_checksum, oid_checksum)
self.assertEqual(p_max_oid, max_oid)
self.assertEqual(p_tid_checksum, tid_checksum)
self.assertEqual(p_max_serial, max_serial)
def test_AskPack(self): def test_AskPack(self):
tid = self.getNextTID() tid = self.getNextTID()
......
...@@ -56,6 +56,8 @@ UNIT_TEST_MODULES = [ ...@@ -56,6 +56,8 @@ UNIT_TEST_MODULES = [
'neo.tests.storage.testVerificationHandler', 'neo.tests.storage.testVerificationHandler',
'neo.tests.storage.testIdentificationHandler', 'neo.tests.storage.testIdentificationHandler',
'neo.tests.storage.testTransactions', 'neo.tests.storage.testTransactions',
'neo.tests.storage.testReplicationHandler',
'neo.tests.storage.testReplicator',
# client application # client application
'neo.tests.client.testClientApp', 'neo.tests.client.testClientApp',
'neo.tests.client.testMasterHandler', 'neo.tests.client.testMasterHandler',
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment