Commit b0753366 authored by Vincent Pelletier's avatar Vincent Pelletier

Implement rsync-ish replication.

For further description, see storage/handlers/replication.py .

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2295 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 01735352
......@@ -275,16 +275,10 @@ class EventHandler(object):
def answerObjectHistory(self, conn, oid, history_list):
raise UnexpectedPacketError
def askObjectHistoryFrom(self, conn, oid, min_serial, length):
def askObjectHistoryFrom(self, conn, oid, min_serial, length, partition):
raise UnexpectedPacketError
def answerObjectHistoryFrom(self, conn, oid, history_list):
raise UnexpectedPacketError
def askOIDs(self, conn, min_oid, length, partition):
raise UnexpectedPacketError
def answerOIDs(self, conn, oid_list):
def answerObjectHistoryFrom(self, conn, object_dict):
raise UnexpectedPacketError
def askPartitionList(self, conn, min_offset, max_offset, uuid):
......@@ -358,6 +352,21 @@ class EventHandler(object):
def answerPack(self, conn, status):
raise UnexpectedPacketError
def askCheckTIDRange(self, conn, min_tid, length, partition):
raise UnexpectedPacketError
def answerCheckTIDRange(self, conn, min_tid, length, count, tid_checksum,
max_tid):
raise UnexpectedPacketError
def askCheckSerialRange(self, conn, min_oid, min_serial, length,
partition):
raise UnexpectedPacketError
def answerCheckSerialRange(self, conn, min_oid, min_serial, length, count,
oid_checksum, max_oid, serial_checksum, max_serial):
raise UnexpectedPacketError
# Error packet handlers.
......@@ -450,8 +459,6 @@ class EventHandler(object):
d[Packets.AnswerObjectHistory] = self.answerObjectHistory
d[Packets.AskObjectHistoryFrom] = self.askObjectHistoryFrom
d[Packets.AnswerObjectHistoryFrom] = self.answerObjectHistoryFrom
d[Packets.AskOIDs] = self.askOIDs
d[Packets.AnswerOIDs] = self.answerOIDs
d[Packets.AskPartitionList] = self.askPartitionList
d[Packets.AnswerPartitionList] = self.answerPartitionList
d[Packets.AskNodeList] = self.askNodeList
......@@ -476,6 +483,10 @@ class EventHandler(object):
d[Packets.AnswerBarrier] = self.answerBarrier
d[Packets.AskPack] = self.askPack
d[Packets.AnswerPack] = self.answerPack
d[Packets.AskCheckTIDRange] = self.askCheckTIDRange
d[Packets.AnswerCheckTIDRange] = self.answerCheckTIDRange
d[Packets.AskCheckSerialRange] = self.askCheckSerialRange
d[Packets.AnswerCheckSerialRange] = self.answerCheckSerialRange
return d
......
......@@ -113,6 +113,7 @@ INVALID_PARTITION = 0xffffffff
ZERO_TID = '\0' * 8
ZERO_OID = '\0' * 8
OID_LEN = len(INVALID_OID)
TID_LEN = len(INVALID_TID)
UUID_NAMESPACES = {
NodeTypes.STORAGE: 'S',
......@@ -1167,63 +1168,47 @@ class AnswerObjectHistory(Packet):
class AskObjectHistoryFrom(Packet):
"""
Ask history information for a given object. The order of serials is
ascending, and starts at (or above) min_serial. S -> S.
ascending, and starts at (or above) min_serial for min_oid. S -> S.
"""
_header_format = '!8s8sL'
_header_format = '!8s8sLL'
def _encode(self, oid, min_serial, length):
return pack(self._header_format, oid, min_serial, length)
def _encode(self, min_oid, min_serial, length, partition):
return pack(self._header_format, min_oid, min_serial, length,
partition)
def _decode(self, body):
return unpack(self._header_format, body) # oid, min_serial, length
# min_oid, min_serial, length, partition
return unpack(self._header_format, body)
class AnswerObjectHistoryFrom(AskFinishTransaction):
class AnswerObjectHistoryFrom(Packet):
"""
Answer the requested serials. S -> S.
"""
# This is similar to AskFinishTransaction as TID size is identical to OID
# size:
# - we have a single OID (TID in AskFinishTransaction)
# - we have a list of TIDs (OIDs in AskFinishTransaction)
pass
class AskOIDs(Packet):
"""
Ask for length OIDs starting at min_oid. S -> S.
"""
_header_format = '!8sLL'
def _encode(self, min_oid, length, partition):
return pack(self._header_format, min_oid, length, partition)
def _decode(self, body):
return unpack(self._header_format, body) # min_oid, length, partition
class AnswerOIDs(Packet):
"""
Answer the requested OIDs. S -> S.
"""
_header_format = '!L'
_list_entry_format = '8s'
_list_entry_format = '!8sL'
_list_entry_len = calcsize(_list_entry_format)
def _encode(self, oid_list):
body = [pack(self._header_format, len(oid_list))]
body.extend(oid_list)
def _encode(self, object_dict):
body = [pack(self._header_format, len(object_dict))]
append = body.append
extend = body.extend
list_entry_format = self._list_entry_format
for oid, serial_list in object_dict.iteritems():
append(pack(list_entry_format, oid, len(serial_list)))
extend(serial_list)
return ''.join(body)
def _decode(self, body):
offset = self._header_len
(n,) = unpack(self._header_format, body[:offset])
oid_list = []
body = StringIO(body)
read = body.read
list_entry_format = self._list_entry_format
list_entry_len = self._list_entry_len
for _ in xrange(n):
next_offset = offset + list_entry_len
oid = unpack(list_entry_format, body[offset:next_offset])[0]
offset = next_offset
oid_list.append(oid)
return (oid_list,)
object_dict = {}
dict_len = unpack(self._header_format, read(self._header_len))[0]
for _ in xrange(dict_len):
oid, serial_len = unpack(list_entry_format, read(list_entry_len))
object_dict[oid] = [read(TID_LEN) for _ in xrange(serial_len)]
return (object_dict, )
class AskPartitionList(Packet):
"""
......@@ -1660,6 +1645,73 @@ class AnswerPack(Packet):
def _decode(self, body):
return (bool(unpack(self._header_format, body)[0]), )
class AskCheckTIDRange(Packet):
"""
Ask some stats about a range of transactions.
Used to know if there are differences between a replicating node and
reference node.
S -> S
"""
_header_format = '!8sLL'
def _encode(self, min_tid, length, partition):
return pack(self._header_format, min_tid, length, partition)
def _decode(self, body):
return unpack(self._header_format, body) # min_tid, length, partition
class AnswerCheckTIDRange(Packet):
"""
Stats about a range of transactions.
Used to know if there are differences between a replicating node and
reference node.
S -> S
"""
_header_format = '!8sLLQ8s'
def _encode(self, min_tid, length, count, tid_checksum, max_tid):
return pack(self._header_format, min_tid, length, count, tid_checksum,
max_tid)
def _decode(self, body):
# min_tid, length, partition, count, tid_checksum, max_tid
return unpack(self._header_format, body)
class AskCheckSerialRange(Packet):
"""
Ask some stats about a range of object history.
Used to know if there are differences between a replicating node and
reference node.
S -> S
"""
_header_format = '!8s8sLL'
def _encode(self, min_oid, min_serial, length, partition):
return pack(self._header_format, min_oid, min_serial, length,
partition)
def _decode(self, body):
# min_oid, min_serial, length, partition
return unpack(self._header_format, body)
class AnswerCheckSerialRange(Packet):
"""
Stats about a range of object history.
Used to know if there are differences between a replicating node and
reference node.
S -> S
"""
_header_format = '!8s8sLLQ8sQ8s'
def _encode(self, min_oid, min_serial, length, count, oid_checksum,
max_oid, serial_checksum, max_serial):
return pack(self._header_format, min_oid, min_serial, length, count,
oid_checksum, max_oid, serial_checksum, max_serial)
def _decode(self, body):
# min_oid, min_serial, length, count, oid_checksum, max_oid,
# serial_checksum, max_serial
return unpack(self._header_format, body)
class Error(Packet):
"""
Error is a special type of message, because this can be sent against
......@@ -1844,10 +1896,6 @@ class PacketRegistry(dict):
0x001F,
AskObjectHistory,
AnswerObjectHistory)
AskOIDs, AnswerOIDs = register(
0x0020,
AskOIDs,
AnswerOIDs)
AskPartitionList, AnswerPartitionList = register(
0x0021,
AskPartitionList,
......@@ -1903,6 +1951,16 @@ class PacketRegistry(dict):
0x0038,
AskPack,
AnswerPack)
AskCheckTIDRange, AnswerCheckTIDRange = register(
0x0039,
AskCheckTIDRange,
AnswerCheckTIDRange,
)
AskCheckSerialRange, AnswerCheckSerialRange = register(
0x003A,
AskCheckSerialRange,
AnswerCheckSerialRange,
)
# build a "singleton"
Packets = PacketRegistry()
......
......@@ -288,6 +288,12 @@ class Application(object):
while True:
em.poll(1)
if self.replicator.pending():
# Call processDelayedTasks before act, so tasks added in the
# act call are executed after one poll call, so that sent
# packets are already on the network and delayed task
# processing happens in parallel with the same task on the
# other storage node.
self.replicator.processDelayedTasks()
self.replicator.act()
def wait(self):
......
......@@ -274,6 +274,11 @@ class DatabaseManager(object):
area."""
raise NotImplementedError
def deleteObject(self, oid, serial=None):
"""Delete given object. If serial is given, only delete that serial for
given oid."""
raise NotImplementedError
def getTransaction(self, tid, all = False):
"""Return a tuple of the list of OIDs, user information,
a description, and extension information, for a given transaction
......@@ -282,12 +287,6 @@ class DatabaseManager(object):
area as well."""
raise NotImplementedError
def getOIDList(self, min_oid, length, num_partitions, partition_list):
"""Return a list of OIDs in ascending order from a minimal oid,
at most the specified length. The list of partitions are passed
to filter out non-applicable TIDs."""
raise NotImplementedError
def getObjectHistory(self, oid, offset = 0, length = 1):
"""Return a list of serials and sizes for a given object ID.
The length specifies the maximum size of such a list. Result starts
......@@ -295,9 +294,11 @@ class DatabaseManager(object):
If there is no such object ID in a database, return None."""
raise NotImplementedError
def getObjectHistoryFrom(self, oid, min_serial, length):
"""Return a list of length serials for a given object ID at (or above)
min_serial, sorted in ascending order."""
def getObjectHistoryFrom(self, oid, min_serial, length, num_partitions,
partition):
"""Return a dict of length serials grouped by oid at (or above)
min_oid and min_serial, for given partition, sorted in ascending
order."""
raise NotImplementedError
def getTIDList(self, offset, length, num_partitions, partition_list):
......@@ -307,20 +308,10 @@ class DatabaseManager(object):
raise NotImplementedError
def getReplicationTIDList(self, min_tid, length, num_partitions,
partition_list):
partition):
"""Return a list of TIDs in ascending order from an initial tid value,
at most the specified length. The list of partitions are passed
to filter out non-applicable TIDs."""
raise NotImplementedError
def getTIDListPresent(self, tid_list):
"""Return a list of TIDs which are present in a database among
the given list."""
raise NotImplementedError
def getSerialListPresent(self, oid, serial_list):
"""Return a list of serials which are present in a database among
the given list."""
at most the specified length. The partition number is passed to filter
out non-applicable TIDs."""
raise NotImplementedError
def pack(self, tid, updateObjectDataForPack):
......
......@@ -24,7 +24,7 @@ import string
from neo.storage.database import DatabaseManager
from neo.exception import DatabaseFailure
from neo.protocol import CellStates
from neo.protocol import CellStates, ZERO_OID, ZERO_TID
from neo import util
LOG_QUERIES = False
......@@ -576,6 +576,23 @@ class MySQLDatabaseManager(DatabaseManager):
raise
self.commit()
def deleteObject(self, oid, serial=None):
u64 = util.u64
query_param_dict = {
'oid': u64(oid),
}
query_fmt = 'DELETE FROM obj WHERE oid = %(oid)d'
if serial is not None:
query_param_dict['serial'] = u64(serial)
query_fmt = query_fmt + ' AND serial = %(serial)d'
self.begin()
try:
self.query(query_fmt % query_param_dict)
except:
self.rollback()
raise
self.commit()
def getTransaction(self, tid, all = False):
q = self.query
tid = util.u64(tid)
......@@ -594,20 +611,6 @@ class MySQLDatabaseManager(DatabaseManager):
return oid_list, user, desc, ext, bool(packed)
return None
def getOIDList(self, min_oid, length, num_partitions,
partition_list):
q = self.query
r = q("""SELECT DISTINCT oid FROM obj WHERE
MOD(oid, %(num_partitions)d) in (%(partitions)s)
AND oid >= %(min_oid)d
ORDER BY oid ASC LIMIT %(length)d""" % {
'num_partitions': num_partitions,
'partitions': ','.join([str(p) for p in partition_list]),
'min_oid': util.u64(min_oid),
'length': length,
})
return [util.p64(t[0]) for t in r]
def _getObjectLength(self, oid, value_serial):
if value_serial is None:
raise CreationUndone
......@@ -646,18 +649,32 @@ class MySQLDatabaseManager(DatabaseManager):
return result
return None
def getObjectHistoryFrom(self, oid, min_serial, length):
def getObjectHistoryFrom(self, min_oid, min_serial, length, num_partitions,
partition):
q = self.query
oid = util.u64(oid)
u64 = util.u64
p64 = util.p64
r = q("""SELECT serial FROM obj
WHERE oid = %(oid)d AND serial >= %(min_serial)d
ORDER BY serial ASC LIMIT %(length)d""" % {
'oid': oid,
'min_serial': util.u64(min_serial),
min_oid = u64(min_oid)
min_serial = u64(min_serial)
r = q('SELECT oid, serial FROM obj '
'WHERE ((oid = %(min_oid)d AND serial >= %(min_serial)d) OR '
'oid > %(min_oid)d) AND '
'MOD(oid, %(num_partitions)d) = %(partition)s '
'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % {
'min_oid': min_oid,
'min_serial': min_serial,
'length': length,
'num_partitions': num_partitions,
'partition': partition,
})
return [p64(t[0]) for t in r]
result = {}
for oid, serial in r:
try:
serial_list = result[oid]
except KeyError:
serial_list = result[oid] = []
serial_list.append(p64(serial))
return dict((p64(x), y) for x, y in result.iteritems())
def getTIDList(self, offset, length, num_partitions, partition_list):
q = self.query
......@@ -669,32 +686,19 @@ class MySQLDatabaseManager(DatabaseManager):
return [util.p64(t[0]) for t in r]
def getReplicationTIDList(self, min_tid, length, num_partitions,
partition_list):
partition):
q = self.query
r = q("""SELECT tid FROM trans WHERE
MOD(tid, %(num_partitions)d) in (%(partitions)s)
MOD(tid, %(num_partitions)d) = %(partition)d
AND tid >= %(min_tid)d
ORDER BY tid ASC LIMIT %(length)d""" % {
'num_partitions': num_partitions,
'partitions': ','.join([str(p) for p in partition_list]),
'partition': partition,
'min_tid': util.u64(min_tid),
'length': length,
})
return [util.p64(t[0]) for t in r]
def getTIDListPresent(self, tid_list):
q = self.query
r = q("""SELECT tid FROM trans WHERE tid in (%s)""" \
% ','.join([str(util.u64(tid)) for tid in tid_list]))
return [util.p64(t[0]) for t in r]
def getSerialListPresent(self, oid, serial_list):
q = self.query
oid = util.u64(oid)
r = q("""SELECT serial FROM obj WHERE oid = %d AND serial in (%s)""" \
% (oid, ','.join([str(util.u64(serial)) for serial in serial_list])))
return [util.p64(t[0]) for t in r]
def _updatePackFuture(self, oid, orig_serial, max_serial,
updateObjectDataForPack):
q = self.query
......@@ -783,4 +787,54 @@ class MySQLDatabaseManager(DatabaseManager):
self.rollback()
raise
self.commit()
def checkTIDRange(self, min_tid, length, num_partitions, partition):
# XXX: XOR is a lame checksum
count, tid_checksum, max_tid = self.query('SELECT COUNT(*), '
'BIT_XOR(tid), MAX(tid) FROM ('
'SELECT tid FROM trans '
'WHERE MOD(tid, %(num_partitions)d) = %(partition)s '
'AND tid >= %(min_tid)d '
'ORDER BY tid ASC LIMIT %(length)d'
') AS foo' % {
'num_partitions': num_partitions,
'partition': partition,
'min_tid': util.u64(min_tid),
'length': length,
})[0]
if count == 0:
tid_checksum = 0
max_tid = ZERO_TID
else:
max_tid = util.p64(max_tid)
return count, tid_checksum, max_tid
def checkSerialRange(self, min_oid, min_serial, length, num_partitions,
partition):
# XXX: XOR is a lame checksum
u64 = util.u64
p64 = util.p64
r = self.query('SELECT oid, serial FROM obj WHERE '
'(oid > %(min_oid)d OR '
'(oid = %(min_oid)d AND serial >= %(min_serial)d)) '
'AND MOD(oid, %(num_partitions)d) = %(partition)s '
'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % {
'min_oid': u64(min_oid),
'min_serial': u64(min_serial),
'length': length,
'num_partitions': num_partitions,
'partition': partition,
})
count = len(r)
oid_checksum = serial_checksum = 0
if count == 0:
max_oid = ZERO_OID
max_serial = ZERO_TID
else:
for max_oid, max_serial in r:
oid_checksum ^= max_oid
serial_checksum ^= max_serial
max_oid = p64(max_oid)
max_serial = p64(max_serial)
return count, oid_checksum, max_oid, serial_checksum, max_serial
This diff is collapsed.
......@@ -30,34 +30,32 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler):
tid = app.dm.getLastTID()
conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID()))
def askOIDs(self, conn, min_oid, length, partition):
# This method is complicated, because I must return OIDs only
# about usable partitions assigned to me.
app = self.app
if partition == protocol.INVALID_PARTITION:
partition_list = app.pt.getAssignedPartitionList(app.uuid)
else:
partition_list = [partition]
oid_list = app.dm.getOIDList(min_oid, length,
app.pt.getPartitions(), partition_list)
conn.answer(Packets.AnswerOIDs(oid_list))
def askTIDsFrom(self, conn, min_tid, length, partition):
# This method is complicated, because I must return TIDs only
# about usable partitions assigned to me.
app = self.app
if partition == protocol.INVALID_PARTITION:
partition_list = app.pt.getAssignedPartitionList(app.uuid)
else:
partition_list = [partition]
tid_list = app.dm.getReplicationTIDList(min_tid, length,
app.pt.getPartitions(), partition_list)
app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerTIDsFrom(tid_list))
def askObjectHistoryFrom(self, conn, oid, min_serial, length):
def askObjectHistoryFrom(self, conn, min_oid, min_serial, length,
partition):
app = self.app
object_dict = app.dm.getObjectHistoryFrom(min_oid, min_serial, length,
app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerObjectHistoryFrom(object_dict))
def askCheckTIDRange(self, conn, min_tid, length, partition):
app = self.app
count, tid_checksum, max_tid = app.dm.checkTIDRange(min_tid, length,
app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerCheckTIDRange(min_tid, length, count,
tid_checksum, max_tid))
def askCheckSerialRange(self, conn, min_oid, min_serial, length,
partition):
app = self.app
history_list = app.dm.getObjectHistoryFrom(oid, min_serial, length)
conn.answer(Packets.AnswerObjectHistoryFrom(oid, history_list))
count, oid_checksum, max_oid, serial_checksum, max_serial = \
app.dm.checkSerialRange(min_oid, min_serial, length,
app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerCheckSerialRange(min_oid, min_serial, length,
count, oid_checksum, max_oid, serial_checksum, max_serial))
......@@ -46,6 +46,46 @@ class Partition(object):
return tid is not None and (
min_pending_tid is None or tid < min_pending_tid)
class Task(object):
"""
A Task is a callable to execute at another time, with given parameters.
Execution result is kept and can be retrieved later.
"""
_func = None
_args = None
_kw = None
_result = None
_processed = False
def __init__(self, func, args=(), kw=None):
self._func = func
self._args = args
if kw is None:
kw = {}
self._kw = kw
def process(self):
if self._processed:
raise ValueError, 'You cannot process a single Task twice'
self._processed = True
self._result = self._func(*self._args, **self._kw)
def getResult(self):
# Should we instead execute immediately rather than raising ?
if not self._processed:
raise ValueError, 'You cannot get a result until task is executed'
return self._result
def __repr__(self):
fmt = '<%s at %x %r(*%r, **%r)%%s>' % (self.__class__.__name__,
id(self), self._func, self._args, self._kw)
if self._processed:
extra = ' => %r' % (self._result, )
else:
extra = ''
return fmt % (extra, )
class Replicator(object):
"""This class handles replications of objects and transactions.
......@@ -98,21 +138,23 @@ class Replicator(object):
# didn't answer yet.
# unfinished_tid_list
# The list of unfinished TIDs known by master node.
# oid_list
# List of OIDs to replicate. Doesn't contains currently-replicated
# object.
# XXX: not defined here
# XXX: accessed (r/w) directly by ReplicationHandler
# next_oid
# Next OID to ask when oid_list is empty.
# XXX: not defined here
# XXX: accessed (r/w) directly by ReplicationHandler
# replication_done
# False if we know there is something to replicate.
# True when current_partition is replicated, or we don't know yet if
# there is something to replicate
# XXX: accessed (w) directly by ReplicationHandler
new_partition_dict = None
critical_tid_dict = None
partition_dict = None
task_list = None
task_dict = None
current_partition = None
current_connection = None
waiting_for_unfinished_tids = None
unfinished_tid_list = None
replication_done = None
def __init__(self, app):
self.app = app
......@@ -129,6 +171,8 @@ class Replicator(object):
def reset(self):
"""Reset attributes to restart replicating."""
self.task_list = []
self.task_dict = {}
self.current_partition = None
self.current_connection = None
self.waiting_for_unfinished_tids = False
......@@ -213,15 +257,12 @@ class Replicator(object):
p = Packets.RequestIdentification(NodeTypes.STORAGE,
app.uuid, app.server, app.name)
self.current_connection.ask(p)
p = Packets.AskTIDsFrom(ZERO_TID, 1000,
self.current_partition.getRID())
self.current_connection.ask(p, timeout=300)
else:
self.current_connection.getHandler().startReplication(
self.current_connection)
self.replication_done = False
def _finishReplication(self):
app = self.app
# TODO: remove try..except: pass
try:
self.partition_dict.pop(self.current_partition.getRID())
......@@ -243,7 +284,11 @@ class Replicator(object):
self._askCriticalTID()
if self.current_partition is not None:
if self.replication_done:
# Don't end replication until we have received all expected
# answers, as we might have asked object data just before the last
# AnswerCheckSerialRange.
if self.replication_done and \
not self.current_connection.isPending():
# finish a replication
logging.info('replication is done for %s' %
(self.current_partition.getRID(), ))
......@@ -289,3 +334,57 @@ class Replicator(object):
and not self.new_partition_dict.has_key(rid):
self.new_partition_dict[rid] = Partition(rid)
def _addTask(self, key, func, args=(), kw=None):
task = Task(func, args, kw)
task_dict = self.task_dict
if key in task_dict:
raise ValueError, 'Task with key %r already exists (%r), cannot ' \
'add %r' % (key, task_dict[key], task)
task_dict[key] = task
self.task_list.append(task)
def processDelayedTasks(self):
task_list = self.task_list
if task_list:
for task in task_list:
task.process()
self.task_list = []
def checkTIDRange(self, min_tid, length, partition):
app = self.app
self._addTask(('TID', min_tid, length), app.dm.checkTIDRange,
(min_tid, length, app.pt.getPartitions(), partition))
def checkSerialRange(self, min_oid, min_serial, length, partition):
app = self.app
self._addTask(('Serial', min_oid, min_serial, length),
app.dm.checkSerialRange, (min_oid, min_serial, length,
app.pt.getPartitions(), partition))
def getTIDsFrom(self, min_tid, length, partition):
app = self.app
self._addTask('TIDsFrom',
app.dm.getReplicationTIDList, (min_tid, length,
app.pt.getPartitions(), partition))
def getObjectHistoryFrom(self, min_oid, min_serial, length, partition):
app = self.app
self._addTask('ObjectHistoryFrom',
app.dm.getObjectHistoryFrom, (min_oid, min_serial, length,
app.pt.getPartitions(), partition))
def _getCheckResult(self, key):
return self.task_dict.pop(key).getResult()
def getTIDCheckResult(self, min_tid, length):
return self._getCheckResult(('TID', min_tid, length))
def getSerialCheckResult(self, min_oid, min_serial, length):
return self._getCheckResult(('Serial', min_oid, min_serial, length))
def getTIDsFromResult(self):
return self._getCheckResult('TIDsFrom')
def getObjectHistoryFromResult(self):
return self._getCheckResult('ObjectHistoryFrom')
This diff is collapsed.
This diff is collapsed.
......@@ -21,7 +21,7 @@ from collections import deque
from neo.tests import NeoTestBase
from neo.storage.app import Application
from neo.storage.handlers.storage import StorageOperationHandler
from neo.protocol import INVALID_PARTITION
from neo.protocol import INVALID_PARTITION, Packets
from neo.protocol import INVALID_TID, INVALID_OID, INVALID_SERIAL
class StorageStorageHandlerTests(NeoTestBase):
......@@ -113,7 +113,7 @@ class StorageStorageHandlerTests(NeoTestBase):
self.assertEquals(len(self.app.event_queue), 0)
self.checkAnswerObject(conn)
def test_25_askTIDsFrom1(self):
def test_25_askTIDsFrom(self):
# well case => answer
conn = self.getFakeConnection()
self.app.dm = Mock({'getReplicationTIDList': (INVALID_TID, )})
......@@ -122,69 +122,85 @@ class StorageStorageHandlerTests(NeoTestBase):
self.operation.askTIDsFrom(conn, tid, 2, 1)
calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
self.assertEquals(len(calls), 1)
calls[0].checkArgs(tid, 2, 1, [1, ])
self.checkAnswerTidsFrom(conn)
def test_25_askTIDsFrom2(self):
# invalid partition => answer usable partitions
conn = self.getFakeConnection()
cell = Mock({'getUUID':self.app.uuid})
self.app.dm = Mock({'getReplicationTIDList': (INVALID_TID, )})
self.app.pt = Mock({
'getCellList': (cell, ),
'getPartitions': 1,
'getAssignedPartitionList': [0],
})
tid = self.getNextTID()
self.operation.askTIDsFrom(conn, tid, 2, INVALID_PARTITION)
self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1)
calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
self.assertEquals(len(calls), 1)
calls[0].checkArgs(tid, 2, 1, [0, ])
calls[0].checkArgs(tid, 2, 1, 1)
self.checkAnswerTidsFrom(conn)
def test_26_askObjectHistoryFrom(self):
oid = self.getOID(2)
min_tid = self.getNextTID()
min_oid = self.getOID(2)
min_serial = self.getNextTID()
length = 4
partition = 8
num_partitions = 16
tid = self.getNextTID()
conn = self.getFakeConnection()
self.app.dm = Mock({'getObjectHistoryFrom': [tid]})
self.operation.askObjectHistoryFrom(conn, oid, min_tid, 2)
self.app.dm = Mock({'getObjectHistoryFrom': {min_oid: [tid]},})
self.app.pt = Mock({
'getPartitions': num_partitions,
})
self.operation.askObjectHistoryFrom(conn, min_oid, min_serial, length,
partition)
self.checkAnswerObjectHistoryFrom(conn)
calls = self.app.dm.mockGetNamedCalls('getObjectHistoryFrom')
self.assertEquals(len(calls), 1)
calls[0].checkArgs(oid, min_tid, 2)
calls[0].checkArgs(min_oid, min_serial, length, num_partitions,
partition)
def test_25_askOIDs1(self):
# well case > answer OIDs
def test_askCheckTIDRange(self):
count = 1
tid_checksum = 2
min_tid = self.getNextTID()
num_partitions = 4
length = 5
partition = 6
max_tid = self.getNextTID()
self.app.dm = Mock({'checkTIDRange': (count, tid_checksum, max_tid)})
self.app.pt = Mock({'getPartitions': num_partitions})
conn = self.getFakeConnection()
self.app.pt = Mock({'getPartitions': 1})
self.app.dm = Mock({'getOIDList': (INVALID_OID, )})
oid = self.getOID(1)
self.operation.askOIDs(conn, oid, 2, 1)
calls = self.app.dm.mockGetNamedCalls('getOIDList')
self.assertEquals(len(calls), 1)
calls[0].checkArgs(oid, 2, 1, [1, ])
self.checkAnswerOids(conn)
def test_25_askOIDs2(self):
# invalid partition => answer usable partitions
self.operation.askCheckTIDRange(conn, min_tid, length, partition)
calls = self.app.dm.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_tid, length, num_partitions, partition)
pmin_tid, plength, pcount, ptid_checksum, pmax_tid = \
self.checkAnswerPacket(conn, Packets.AnswerCheckTIDRange,
decode=True)
self.assertEqual(min_tid, pmin_tid)
self.assertEqual(length, plength)
self.assertEqual(count, pcount)
self.assertEqual(tid_checksum, ptid_checksum)
self.assertEqual(max_tid, pmax_tid)
def test_askCheckSerialRange(self):
count = 1
oid_checksum = 2
min_oid = self.getOID(1)
num_partitions = 4
length = 5
partition = 6
serial_checksum = 7
min_serial = self.getNextTID()
max_serial = self.getNextTID()
max_oid = self.getOID(2)
self.app.dm = Mock({'checkSerialRange': (count, oid_checksum, max_oid,
serial_checksum, max_serial)})
self.app.pt = Mock({'getPartitions': num_partitions})
conn = self.getFakeConnection()
cell = Mock({'getUUID':self.app.uuid})
self.app.dm = Mock({'getOIDList': (INVALID_OID, )})
self.app.pt = Mock({
'getCellList': (cell, ),
'getPartitions': 1,
'getAssignedPartitionList': [0],
})
oid = self.getOID(1)
self.operation.askOIDs(conn, oid, 2, INVALID_PARTITION)
self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1)
calls = self.app.dm.mockGetNamedCalls('getOIDList')
self.assertEquals(len(calls), 1)
calls[0].checkArgs(oid, 2, 1, [0])
self.checkAnswerOids(conn)
self.operation.askCheckSerialRange(conn, min_oid, min_serial, length,
partition)
calls = self.app.dm.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_oid, min_serial, length, num_partitions,
partition)
pmin_oid, pmin_serial, plength, pcount, poid_checksum, pmax_oid, \
pserial_checksum, pmax_serial = self.checkAnswerPacket(conn,
Packets.AnswerCheckSerialRange, decode=True)
self.assertEqual(min_oid, pmin_oid)
self.assertEqual(min_serial, pmin_serial)
self.assertEqual(length, plength)
self.assertEqual(count, pcount)
self.assertEqual(oid_checksum, poid_checksum)
self.assertEqual(max_oid, pmax_oid)
self.assertEqual(serial_checksum, pserial_checksum)
self.assertEqual(max_serial, pmax_serial)
if __name__ == "__main__":
unittest.main()
......@@ -19,7 +19,7 @@ import unittest
import MySQLdb
from mock import Mock
from neo.util import dump, p64, u64
from neo.protocol import CellStates, INVALID_PTID
from neo.protocol import CellStates, INVALID_PTID, ZERO_OID, ZERO_TID
from neo.tests import NeoTestBase
from neo.exception import DatabaseFailure
from neo.storage.database.mysqldb import MySQLDatabaseManager
......@@ -441,6 +441,23 @@ class StorageMySQSLdbTests(NeoTestBase):
self.assertEqual(self.db.getTransaction(tid1, True), None)
self.assertEqual(self.db.getTransaction(tid2, True), None)
def test_deleteObject(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1, oid2])
txn2, objs2 = self.getTransaction([oid1, oid2])
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
self.db.deleteObject(oid1)
self.assertEqual(self.db.getObject(oid1, tid=tid1), None)
self.assertEqual(self.db.getObject(oid1, tid=tid2), None)
self.db.deleteObject(oid2, serial=tid1)
self.assertEqual(self.db.getObject(oid2, tid=tid1), False)
self.assertEqual(self.db.getObject(oid2, tid=tid2), (tid2, None) + \
objs2[1][1:])
def test_getTransaction(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
......@@ -459,30 +476,6 @@ class StorageMySQSLdbTests(NeoTestBase):
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False))
self.assertEqual(self.db.getTransaction(tid2, False), None)
def test_getOIDList(self):
# store four objects
oid1, oid2, oid3, oid4 = self.getOIDs(4)
tid = self.getNextTID()
txn, objs = self.getTransaction([oid1, oid2, oid3, oid4])
self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid)
# get oids
result = self.db.getOIDList(oid1, 4, 1, [0])
self.checkSet(result, [oid1, oid2, oid3, oid4])
result = self.db.getOIDList(oid1, 4, 2, [0])
self.checkSet(result, [oid1, oid3])
result = self.db.getOIDList(oid1, 4, 2, [0, 1])
self.checkSet(result, [oid1, oid2, oid3, oid4])
result = self.db.getOIDList(oid1, 4, 3, [0])
self.checkSet(result, [oid1, oid4])
# get a subset of oids
result = self.db.getOIDList(oid1, 2, 1, [0])
self.checkSet(result, [oid1, oid2])
result = self.db.getOIDList(oid3, 2, 1, [0])
self.checkSet(result, [oid3, oid4])
result = self.db.getOIDList(oid2, 1, 3, [0])
self.checkSet(result, [oid4])
def test_getObjectHistory(self):
oid = self.getOID(1)
tid1, tid2, tid3 = self.getTIDs(3)
......@@ -506,6 +499,50 @@ class StorageMySQSLdbTests(NeoTestBase):
result = self.db.getObjectHistory(oid, 2, 3)
self.assertEqual(result, None)
def test_getObjectHistoryFrom(self):
oid1 = self.getOID(0)
oid2 = self.getOID(1)
tid1, tid2, tid3, tid4 = self.getTIDs(4)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2])
txn3, objs3 = self.getTransaction([oid1])
txn4, objs4 = self.getTransaction([oid2])
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.storeTransaction(tid3, objs3, txn3)
self.db.storeTransaction(tid4, objs4, txn4)
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
self.db.finishTransaction(tid3)
self.db.finishTransaction(tid4)
# Check full result
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 10, 1, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2, tid4],
})
# Lower bound is inclusive
result = self.db.getObjectHistoryFrom(oid1, tid1, 10, 1, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2, tid4],
})
# Length is total number of serials
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 3, 1, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2],
})
# Partition constraints are honored
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 10, 2, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
})
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 10, 2, 1)
self.assertEqual(result, {
oid2: [tid2, tid4],
})
def _storeTransactions(self, count):
# use OID generator to know result of tid % N
tid_list = self.getOIDs(count)
......@@ -538,59 +575,20 @@ class StorageMySQSLdbTests(NeoTestBase):
def test_getReplicationTIDList(self):
tid1, tid2, tid3, tid4 = self._storeTransactions(4)
# get tids
result = self.db.getReplicationTIDList(tid1, 4, 1, [0])
result = self.db.getReplicationTIDList(tid1, 4, 1, 0)
self.checkSet(result, [tid1, tid2, tid3, tid4])
result = self.db.getReplicationTIDList(tid1, 4, 2, [0])
result = self.db.getReplicationTIDList(tid1, 4, 2, 0)
self.checkSet(result, [tid1, tid3])
result = self.db.getReplicationTIDList(tid1, 4, 2, [0, 1])
self.checkSet(result, [tid1, tid2, tid3, tid4])
result = self.db.getReplicationTIDList(tid1, 4, 3, [0])
result = self.db.getReplicationTIDList(tid1, 4, 3, 0)
self.checkSet(result, [tid1, tid4])
# get a subset of tids
result = self.db.getReplicationTIDList(tid3, 4, 1, [0])
result = self.db.getReplicationTIDList(tid3, 4, 1, 0)
self.checkSet(result, [tid3, tid4])
result = self.db.getReplicationTIDList(tid1, 2, 1, [0])
result = self.db.getReplicationTIDList(tid1, 2, 1, 0)
self.checkSet(result, [tid1, tid2])
result = self.db.getReplicationTIDList(tid1, 1, 3, [1])
result = self.db.getReplicationTIDList(tid1, 1, 3, 1)
self.checkSet(result, [tid2])
def test_getTIDListPresent(self):
oid = self.getOID(1)
tid1, tid2, tid3, tid4 = self.getTIDs(4)
txn1, objs1 = self.getTransaction([oid])
txn4, objs4 = self.getTransaction([oid])
# four tids, two missing
self.db.storeTransaction(tid1, objs1, txn1)
self.db.finishTransaction(tid1)
self.db.storeTransaction(tid4, objs4, txn4)
self.db.finishTransaction(tid4)
result = self.db.getTIDListPresent([tid1, tid2, tid3, tid4])
self.checkSet(result, [tid1, tid4])
result = self.db.getTIDListPresent([tid1, tid2])
self.checkSet(result, [tid1])
self.assertEqual(self.db.getTIDListPresent([tid2, tid3]), [])
def test_getSerialListPresent(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2, tid3, tid4 = self.getTIDs(4)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid1])
txn3, objs3 = self.getTransaction([oid2])
txn4, objs4 = self.getTransaction([oid2])
# four object, one revision each
self.db.storeTransaction(tid1, objs1, txn1)
self.db.finishTransaction(tid1)
self.db.storeTransaction(tid4, objs4, txn4)
self.db.finishTransaction(tid4)
result = self.db.getSerialListPresent(oid1, [tid1, tid2])
self.checkSet(result, [tid1])
result = self.db.getSerialListPresent(oid2, [tid3, tid4])
self.checkSet(result, [tid4])
result = self.db.getSerialListPresent(oid1, [tid2])
self.assertEqual(result, [])
result = self.db.getSerialListPresent(oid2, [tid3])
self.assertEqual(result, [])
def test__getObjectData(self):
db = self.db
db.setup(reset=True)
......
......@@ -458,24 +458,6 @@ class ProtocolTests(NeoTestBase):
self.assertEqual(p_hist_list, hist_list)
self.assertEqual(oid, poid)
def test_55_askOIDs(self):
oid = self.getOID(1)
p = Packets.AskOIDs(oid, 1000, 5)
min_oid, length, partition = p.decode()
self.assertEqual(min_oid, oid)
self.assertEqual(length, 1000)
self.assertEqual(partition, 5)
def test_56_answerOIDs(self):
oid1 = self.getNextTID()
oid2 = self.getNextTID()
oid3 = self.getNextTID()
oid4 = self.getNextTID()
oid_list = [oid1, oid2, oid3, oid4]
p = Packets.AnswerOIDs(oid_list)
p_oid_list = p.decode()[0]
self.assertEqual(p_oid_list, oid_list)
def test_57_notifyReplicationDone(self):
offset = 10
p = Packets.NotifyReplicationDone(offset)
......@@ -626,14 +608,82 @@ class ProtocolTests(NeoTestBase):
oid = self.getOID(1)
min_serial = self.getNextTID()
length = 5
p = Packets.AskObjectHistoryFrom(oid, min_serial, length)
p_oid, p_min_serial, p_length = p.decode()
partition = 4
p = Packets.AskObjectHistoryFrom(oid, min_serial, length, partition)
p_oid, p_min_serial, p_length, p_partition = p.decode()
self.assertEqual(p_oid, oid)
self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition)
def test_AnswerObjectHistoryFrom(self):
self._testXIDAndYIDList(Packets.AnswerObjectHistoryFrom)
object_dict = {}
for int_oid in xrange(4):
object_dict[self.getOID(int_oid)] = [self.getNextTID() \
for _ in xrange(5)]
p = Packets.AnswerObjectHistoryFrom(object_dict)
p_object_dict = p.decode()[0]
self.assertEqual(object_dict, p_object_dict)
def test_AskCheckTIDRange(self):
min_tid = self.getNextTID()
length = 2
partition = 4
p = Packets.AskCheckTIDRange(min_tid, length, partition)
p_min_tid, p_length, p_partition = p.decode()
self.assertEqual(p_min_tid, min_tid)
self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition)
def test_AnswerCheckTIDRange(self):
min_tid = self.getNextTID()
length = 2
count = 1
tid_checksum = 42
max_tid = self.getNextTID()
p = Packets.AnswerCheckTIDRange(min_tid, length, count, tid_checksum,
max_tid)
p_min_tid, p_length, p_count, p_tid_checksum, p_max_tid = p.decode()
self.assertEqual(p_min_tid, min_tid)
self.assertEqual(p_length, length)
self.assertEqual(p_count, count)
self.assertEqual(p_tid_checksum, tid_checksum)
self.assertEqual(p_max_tid, max_tid)
def test_AskCheckSerialRange(self):
min_oid = self.getOID(1)
min_serial = self.getNextTID()
length = 2
partition = 4
p = Packets.AskCheckSerialRange(min_oid, min_serial, length, partition)
p_min_oid, p_min_serial, p_length, p_partition = p.decode()
self.assertEqual(p_min_oid, min_oid)
self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition)
def test_AnswerCheckSerialRange(self):
min_oid = self.getOID(1)
min_serial = self.getNextTID()
length = 2
count = 1
oid_checksum = 24
max_oid = self.getOID(5)
tid_checksum = 42
max_serial = self.getNextTID()
p = Packets.AnswerCheckSerialRange(min_oid, min_serial, length, count,
oid_checksum, max_oid, tid_checksum, max_serial)
p_min_oid, p_min_serial, p_length, p_count, p_oid_checksum, \
p_max_oid, p_tid_checksum, p_max_serial = p.decode()
self.assertEqual(p_min_oid, min_oid)
self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_length, length)
self.assertEqual(p_count, count)
self.assertEqual(p_oid_checksum, oid_checksum)
self.assertEqual(p_max_oid, max_oid)
self.assertEqual(p_tid_checksum, tid_checksum)
self.assertEqual(p_max_serial, max_serial)
def test_AskPack(self):
tid = self.getNextTID()
......
......@@ -56,6 +56,8 @@ UNIT_TEST_MODULES = [
'neo.tests.storage.testVerificationHandler',
'neo.tests.storage.testIdentificationHandler',
'neo.tests.storage.testTransactions',
'neo.tests.storage.testReplicationHandler',
'neo.tests.storage.testReplicator',
# client application
'neo.tests.client.testClientApp',
'neo.tests.client.testMasterHandler',
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment