Commit 83c02447 authored by Vincent Pelletier's avatar Vincent Pelletier

Use value_serial for undo support.

This mimics what FileStorage uses (file offsets) but in a relational manner.
This offloads decision of the ability to undo a transaction to storages,
avoiding 3 data loads for each object in the transaction at client side.
This also makes Neo refuse to undo transactions where object data would happen
to be equal between current value and undone value.
Finally, this is required to make database pack work properly (namely, it
prevents loosing objects which are orphans at pack TID, but are reachable
after it thanks to a transactional undo).

Also, extend storage's transaction manager so database adapter can fetch data
already sent by client in the same transaction, so it can undo multiple
transactions at once. Requires making object lock re-entrant (done in this
commit).

git-svn-id: https://svn.erp5.org/repos/neo/trunk@1978 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 987351fb
...@@ -112,6 +112,8 @@ class ThreadContext(object): ...@@ -112,6 +112,8 @@ class ThreadContext(object):
'node_tids': {}, 'node_tids': {},
'node_ready': False, 'node_ready': False,
'asked_object': 0, 'asked_object': 0,
'undo_conflict_oid_list': [],
'undo_error_oid_list': [],
} }
...@@ -805,31 +807,54 @@ class Application(object): ...@@ -805,31 +807,54 @@ class Application(object):
else: else:
raise NEOStorageError('undo failed') raise NEOStorageError('undo failed')
if self.local_var.txn_info['packed']:
UndoError('non-undoable transaction')
tid = self.local_var.tid
undo_conflict_oid_list = self.local_var.undo_conflict_oid_list = []
undo_error_oid_list = self.local_var.undo_error_oid_list = []
ask_undo_transaction = Packets.AskUndoTransaction(tid, undone_tid)
getConnForNode = self.cp.getConnForNode
for storage_node in self.nm.getStorageList():
storage_conn = getConnForNode(storage_node)
storage_conn.ask(ask_undo_transaction)
# Wait for all AnswerUndoTransaction.
self.waitResponses()
# Don't do any handling for "live" conflicts, raise
if undo_conflict_oid_list:
raise ConflictError(oid=undo_conflict_oid_list[0], serials=(tid,
undone_tid), data=None)
# Try to resolve undo conflicts
for oid in undo_error_oid_list:
def loadBefore(oid, tid):
try:
result = self._load(oid, tid=tid)
except NEOStorageNotFoundError:
raise UndoError("Object not found while resolving undo " \
"conflict")
return result[:2]
# Load the latest version we are supposed to see
data, data_tid = loadBefore(oid, tid)
# Load the version we were undoing to
undo_data, _ = loadBefore(oid, undone_tid)
# Resolve conflict
new_data = tryToResolveConflict(oid, data_tid, undone_tid, undo_data,
data)
if new_data is None:
raise UndoError('Some data were modified by a later ' \
'transaction', oid)
else:
self.store(oid, data_tid, new_data, '', self.local_var.txn)
oid_list = self.local_var.txn_info['oids'] oid_list = self.local_var.txn_info['oids']
# Second get object data from storage node using loadBefore # Consistency checking: all oids of the transaction must have been
data_dict = {} # reported as undone
# XXX: this way causes each object to be loaded 3 times from storage, data_dict = self.local_var.data_dict
# this work should rather be offloaded to it.
for oid in oid_list: for oid in oid_list:
current_data = self.load(oid)[0] assert oid in data_dict, repr(oid)
after_data = self.loadSerial(oid, undone_tid)
if current_data != after_data:
raise UndoError("non-undoable transaction", oid)
try:
data = self.loadBefore(oid, undone_tid)[0]
except NEOStorageNotFoundError:
if oid == '\x00' * 8:
# Refuse undoing root object creation.
raise UndoError("no previous record", oid)
else:
# Undo object creation
data = ''
data_dict[oid] = data
# Third do transaction with old data
for oid, data in data_dict.iteritems():
self.store(oid, undone_tid, data, None, txn)
self.waitStoreResponses(tryToResolveConflict)
return self.local_var.tid, oid_list return self.local_var.tid, oid_list
def _insertMetadata(self, txn_info, extension): def _insertMetadata(self, txn_info, extension):
......
...@@ -19,6 +19,7 @@ from ZODB.TimeStamp import TimeStamp ...@@ -19,6 +19,7 @@ from ZODB.TimeStamp import TimeStamp
from neo.client.handlers import BaseHandler, AnswerBaseHandler from neo.client.handlers import BaseHandler, AnswerBaseHandler
from neo.protocol import NodeTypes, ProtocolError from neo.protocol import NodeTypes, ProtocolError
from neo.util import dump
class StorageEventHandler(BaseHandler): class StorageEventHandler(BaseHandler):
...@@ -58,7 +59,10 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -58,7 +59,10 @@ class StorageAnswersHandler(AnswerBaseHandler):
""" Handle all messages related to ZODB operations """ """ Handle all messages related to ZODB operations """
def answerObject(self, conn, oid, start_serial, end_serial, def answerObject(self, conn, oid, start_serial, end_serial,
compression, checksum, data): compression, checksum, data, data_serial):
if data_serial is not None:
raise ValueError, 'Storage should never send non-None ' \
'data_serial to clients, got %s' % (dump(data_serial), )
self.app.local_var.asked_object = (oid, start_serial, end_serial, self.app.local_var.asked_object = (oid, start_serial, end_serial,
compression, checksum, data) compression, checksum, data)
...@@ -112,3 +116,12 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -112,3 +116,12 @@ class StorageAnswersHandler(AnswerBaseHandler):
def answerTIDs(self, conn, tid_list): def answerTIDs(self, conn, tid_list):
self.app.local_var.node_tids[conn.getUUID()] = tid_list self.app.local_var.node_tids[conn.getUUID()] = tid_list
def answerUndoTransaction(self, conn, oid_list, error_oid_list,
conflict_oid_list):
local_var = self.app.local_var
local_var.undo_conflict_oid_list.extend(conflict_oid_list)
local_var.undo_error_oid_list.extend(error_oid_list)
data_dict = local_var.data_dict
for oid in oid_list:
data_dict[oid] = ''
...@@ -247,7 +247,7 @@ class EventHandler(object): ...@@ -247,7 +247,7 @@ class EventHandler(object):
raise UnexpectedPacketError raise UnexpectedPacketError
def answerObject(self, conn, oid, serial_start, def answerObject(self, conn, oid, serial_start,
serial_end, compression, checksum, data): serial_end, compression, checksum, data, data_serial):
raise UnexpectedPacketError raise UnexpectedPacketError
def askTIDs(self, conn, first, last, partition): def askTIDs(self, conn, first, last, partition):
...@@ -323,6 +323,11 @@ class EventHandler(object): ...@@ -323,6 +323,11 @@ class EventHandler(object):
def notifyReplicationDone(self, conn, offset): def notifyReplicationDone(self, conn, offset):
raise UnexpectedPacketError raise UnexpectedPacketError
def askUndoTransaction(self, conn, tid, undone_tid):
raise UnexpectedPacketError
def answerUndoTransaction(self, conn, oid_list, error_oid_list, conflict_oid_list):
raise UnexpectedPacketError
# Error packet handlers. # Error packet handlers.
...@@ -427,6 +432,8 @@ class EventHandler(object): ...@@ -427,6 +432,8 @@ class EventHandler(object):
d[Packets.NotifyClusterInformation] = self.notifyClusterInformation d[Packets.NotifyClusterInformation] = self.notifyClusterInformation
d[Packets.NotifyLastOID] = self.notifyLastOID d[Packets.NotifyLastOID] = self.notifyLastOID
d[Packets.NotifyReplicationDone] = self.notifyReplicationDone d[Packets.NotifyReplicationDone] = self.notifyReplicationDone
d[Packets.AskUndoTransaction] = self.askUndoTransaction
d[Packets.AnswerUndoTransaction] = self.answerUndoTransaction
return d return d
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
from struct import pack, unpack, error, calcsize from struct import pack, unpack, error, calcsize
from socket import inet_ntoa, inet_aton from socket import inet_ntoa, inet_aton
from neo.profiling import profiler_decorator from neo.profiling import profiler_decorator
from cStringIO import StringIO
from neo.util import Enum from neo.util import Enum
...@@ -102,6 +103,7 @@ INVALID_OID = '\xff' * 8 ...@@ -102,6 +103,7 @@ INVALID_OID = '\xff' * 8
INVALID_PTID = '\0' * 8 INVALID_PTID = '\0' * 8
INVALID_SERIAL = INVALID_TID INVALID_SERIAL = INVALID_TID
INVALID_PARTITION = 0xffffffff INVALID_PARTITION = 0xffffffff
OID_LEN = len(INVALID_OID)
UUID_NAMESPACES = { UUID_NAMESPACES = {
NodeTypes.STORAGE: 'S', NodeTypes.STORAGE: 'S',
...@@ -988,25 +990,30 @@ class AnswerObject(Packet): ...@@ -988,25 +990,30 @@ class AnswerObject(Packet):
""" """
Answer the requested object. S -> C. Answer the requested object. S -> C.
""" """
_header_format = '!8s8s8sBL' _header_format = '!8s8s8s8sBL'
def _encode(self, oid, serial_start, serial_end, compression, def _encode(self, oid, serial_start, serial_end, compression,
checksum, data): checksum, data, data_serial):
if serial_start is None: if serial_start is None:
serial_start = INVALID_TID serial_start = INVALID_TID
if serial_end is None: if serial_end is None:
serial_end = INVALID_TID serial_end = INVALID_TID
if data_serial is None:
data_serial = INVALID_TID
return pack(self._header_format, oid, serial_start, serial_end, return pack(self._header_format, oid, serial_start, serial_end,
compression, checksum) + _encodeString(data) data_serial, compression, checksum) + _encodeString(data)
def _decode(self, body): def _decode(self, body):
header_len = self._header_len header_len = self._header_len
r = unpack(self._header_format, body[:header_len]) r = unpack(self._header_format, body[:header_len])
oid, serial_start, serial_end, compression, checksum = r oid, serial_start, serial_end, data_serial, compression, checksum = r
if serial_end == INVALID_TID: if serial_end == INVALID_TID:
serial_end = None serial_end = None
if data_serial == INVALID_TID:
data_serial = None
(data, _) = _decodeString(body, 'data', offset=header_len) (data, _) = _decodeString(body, 'data', offset=header_len)
return (oid, serial_start, serial_end, compression, checksum, data) return (oid, serial_start, serial_end, compression, checksum, data,
data_serial)
class AskTIDs(Packet): class AskTIDs(Packet):
""" """
...@@ -1354,7 +1361,8 @@ class AnswerNewNodes(Packet): ...@@ -1354,7 +1361,8 @@ class AnswerNewNodes(Packet):
def _encode(self, uuid_list): def _encode(self, uuid_list):
list_header_format = self._list_header_format list_header_format = self._list_header_format
# an empty list means no new nodes # an empty list means no new nodes
uuid_list = [pack(list_header_format, _encodeUUID(uuid)) for uuid in uuid_list] uuid_list = [pack(list_header_format, _encodeUUID(uuid)) for \
uuid in uuid_list]
return pack(self._header_format, len(uuid_list)) + ''.join(uuid_list) return pack(self._header_format, len(uuid_list)) + ''.join(uuid_list)
def _decode(self, body): def _decode(self, body):
...@@ -1472,6 +1480,56 @@ class NotifyLastOID(Packet): ...@@ -1472,6 +1480,56 @@ class NotifyLastOID(Packet):
(loid, ) = unpack('8s', body) (loid, ) = unpack('8s', body)
return (loid, ) return (loid, )
class AskUndoTransaction(Packet):
"""
Ask storage to undo given transaction
C -> S
"""
def _encode(self, tid, undone_tid):
return _encodeTID(tid) + _encodeTID(undone_tid)
def _decode(self, body):
tid = _decodeTID(body[:8])
undone_tid = _decodeTID(body[8:])
return (tid, undone_tid)
class AnswerUndoTransaction(Packet):
"""
Answer an undo request, telling if undo could be done, with an oid list.
If undo failed, the list contains oid(s) causing problems.
If undo succeeded; the list contains all undone oids for given storage.
S -> C
"""
_header_format = '!LLL'
def _encode(self, oid_list, error_oid_list, conflict_oid_list):
body = StringIO()
write = body.write
oid_list_list = [oid_list, error_oid_list, conflict_oid_list]
write(pack(self._header_format, *[len(x) for x in oid_list_list]))
for oid_list in oid_list_list:
for oid in oid_list:
write(oid)
return body.getvalue()
def _decode(self, body):
body = StringIO(body)
read = body.read
oid_list_len, error_oid_list_len, conflict_oid_list_len = unpack(
self._header_format, read(self._header_len))
oid_list = []
error_oid_list = []
conflict_oid_list = []
for some_list, some_list_len in (
(oid_list, oid_list_len),
(error_oid_list, error_oid_list_len),
(conflict_oid_list, conflict_oid_list_len),
):
append = some_list.append
for _ in xrange(some_list_len):
append(read(OID_LEN))
return (oid_list, error_oid_list, conflict_oid_list)
class Error(Packet): class Error(Packet):
""" """
Error is a special type of message, because this can be sent against Error is a special type of message, because this can be sent against
...@@ -1671,6 +1729,10 @@ class PacketRegistry(dict): ...@@ -1671,6 +1729,10 @@ class PacketRegistry(dict):
AnswerClusterState) AnswerClusterState)
NotifyLastOID = register(0x0030, NotifyLastOID) NotifyLastOID = register(0x0030, NotifyLastOID)
NotifyReplicationDone = register(0x0031, NotifyReplicationDone) NotifyReplicationDone = register(0x0031, NotifyReplicationDone)
AskUndoTransaction, AnswerUndoTransaction = register(
0x0033,
AskUndoTransaction,
AnswerUndoTransaction)
# build a "singleton" # build a "singleton"
Packets = PacketRegistry() Packets = PacketRegistry()
......
...@@ -236,6 +236,16 @@ class DatabaseManager(object): ...@@ -236,6 +236,16 @@ class DatabaseManager(object):
pack state (True for packed).""" pack state (True for packed)."""
raise NotImplementedError raise NotImplementedError
def getTransactionUndoData(self, tid, undone_tid,
getObjectFromTransaction):
"""Undo transaction with "undone_tid" tid. "tid" is the tid of the
transaction in which the undo happens.
getObjectFromTransaction is a callback allowing to find object data
stored to this storage in the same transaction (it is useful for
example when undoing twice in the same transaction).
"""
raise NotImplementedError
def finishTransaction(self, tid): def finishTransaction(self, tid):
"""Finish a transaction specified by a given ID, by moving """Finish a transaction specified by a given ID, by moving
temporarily data to a finished area.""" temporarily data to a finished area."""
......
...@@ -29,6 +29,9 @@ from neo import util ...@@ -29,6 +29,9 @@ from neo import util
LOG_QUERIES = False LOG_QUERIES = False
class CreationUndone(Exception):
pass
def splitOIDField(tid, oids): def splitOIDField(tid, oids):
if (len(oids) % 8) != 0 or len(oids) == 0: if (len(oids) % 8) != 0 or len(oids) == 0:
raise DatabaseFailure('invalid oids length for tid %d: %d' % (tid, raise DatabaseFailure('invalid oids length for tid %d: %d' % (tid,
...@@ -157,9 +160,10 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -157,9 +160,10 @@ class MySQLDatabaseManager(DatabaseManager):
q("""CREATE TABLE IF NOT EXISTS obj ( q("""CREATE TABLE IF NOT EXISTS obj (
oid BIGINT UNSIGNED NOT NULL, oid BIGINT UNSIGNED NOT NULL,
serial BIGINT UNSIGNED NOT NULL, serial BIGINT UNSIGNED NOT NULL,
compression TINYINT UNSIGNED NOT NULL, compression TINYINT UNSIGNED NULL,
checksum INT UNSIGNED NOT NULL, checksum INT UNSIGNED NULL,
value LONGBLOB NOT NULL, value LONGBLOB NULL,
value_serial BIGINT UNSIGNED NULL,
PRIMARY KEY (oid, serial) PRIMARY KEY (oid, serial)
) ENGINE = InnoDB""") ) ENGINE = InnoDB""")
...@@ -177,9 +181,10 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -177,9 +181,10 @@ class MySQLDatabaseManager(DatabaseManager):
q("""CREATE TABLE IF NOT EXISTS tobj ( q("""CREATE TABLE IF NOT EXISTS tobj (
oid BIGINT UNSIGNED NOT NULL, oid BIGINT UNSIGNED NOT NULL,
serial BIGINT UNSIGNED NOT NULL, serial BIGINT UNSIGNED NOT NULL,
compression TINYINT UNSIGNED NOT NULL, compression TINYINT UNSIGNED NULL,
checksum INT UNSIGNED NOT NULL, checksum INT UNSIGNED NULL,
value LONGBLOB NOT NULL value LONGBLOB NULL,
value_serial BIGINT UNSIGNED NULL
) ENGINE = InnoDB""") ) ENGINE = InnoDB""")
def getConfiguration(self, key): def getConfiguration(self, key):
...@@ -259,27 +264,44 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -259,27 +264,44 @@ class MySQLDatabaseManager(DatabaseManager):
return True return True
return False return False
def getObject(self, oid, tid = None, before_tid = None): def _getObjectData(self, oid, value_serial, tid):
if value_serial is None:
raise CreationUndone
if value_serial >= tid:
raise ValueError, "Incorrect value reference found for " \
"oid %d at tid %d: reference = %d" % (oid, value_serial, tid)
r = self.query("""SELECT compression, checksum, value, """ \
"""value_serial FROM obj WHERE oid = %d AND serial = %d""" % (
oid, value_serial))
compression, checksum, value, next_value_serial = r[0]
if value is None:
logging.info("Multiple levels of indirection when " \
"searching for object data for oid %d at tid %d. This " \
"causes suboptimal performance." % (oid, value_serial))
value_serial, compression, checksum, value = self._getObjectData(
oid, next_value_serial, value_serial)
return value_serial, compression, checksum, value
def _getObject(self, oid, tid=None, before_tid=None):
q = self.query q = self.query
oid = util.u64(oid)
if tid is not None: if tid is not None:
tid = util.u64(tid) r = q("""SELECT serial, compression, checksum, value, value_serial
r = q("""SELECT serial, compression, checksum, value FROM obj FROM obj
WHERE oid = %d AND serial = %d""" \ WHERE oid = %d AND serial = %d""" \
% (oid, tid)) % (oid, tid))
try: try:
serial, compression, checksum, data = r[0] serial, compression, checksum, data, value_serial = r[0]
next_serial = None next_serial = None
except IndexError: except IndexError:
return None return None
elif before_tid is not None: elif before_tid is not None:
before_tid = util.u64(before_tid) r = q("""SELECT serial, compression, checksum, value, value_serial
r = q("""SELECT serial, compression, checksum, value FROM obj FROM obj
WHERE oid = %d AND serial < %d WHERE oid = %d AND serial < %d
ORDER BY serial DESC LIMIT 1""" \ ORDER BY serial DESC LIMIT 1""" \
% (oid, before_tid)) % (oid, before_tid))
try: try:
serial, compression, checksum, data = r[0] serial, compression, checksum, data, value_serial = r[0]
except IndexError: except IndexError:
return None return None
r = q("""SELECT serial FROM obj r = q("""SELECT serial FROM obj
...@@ -293,20 +315,52 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -293,20 +315,52 @@ class MySQLDatabaseManager(DatabaseManager):
else: else:
# XXX I want to express "HAVING serial = MAX(serial)", but # XXX I want to express "HAVING serial = MAX(serial)", but
# MySQL does not use an index for a HAVING clause! # MySQL does not use an index for a HAVING clause!
r = q("""SELECT serial, compression, checksum, value FROM obj r = q("""SELECT serial, compression, checksum, value, value_serial
FROM obj
WHERE oid = %d ORDER BY serial DESC LIMIT 1""" \ WHERE oid = %d ORDER BY serial DESC LIMIT 1""" \
% oid) % oid)
try: try:
serial, compression, checksum, data = r[0] serial, compression, checksum, data, value_serial = r[0]
next_serial = None next_serial = None
except IndexError: except IndexError:
return None return None
if serial is not None: return serial, next_serial, compression, checksum, data, value_serial
serial = util.p64(serial)
if next_serial is not None: def getObject(self, oid, tid=None, before_tid=None, resolve_data=True):
next_serial = util.p64(next_serial) # TODO: resolve_data must be unit-tested
return serial, next_serial, compression, checksum, data u64 = util.u64
p64 = util.p64
oid = u64(oid)
if tid is not None:
tid = u64(tid)
if before_tid is not None:
before_tid = u64(before_tid)
result = self._getObject(oid, tid, before_tid)
if result is not None:
serial, next_serial, compression, checksum, data, data_serial = \
result
if data is None and resolve_data:
try:
_, compression, checksum, data = self._getObjectData(oid,
data_serial, serial)
except CreationUndone:
# XXX: why is a special case needed here ?
if tid is None:
return None
compression = 0
# XXX: this is the valid checksum for empty string
checksum = 1
data = ''
data_serial = None
if serial is not None:
serial = p64(serial)
if next_serial is not None:
next_serial = p64(next_serial)
if data_serial is not None:
data_serial = p64(data_serial)
result = serial, next_serial, compression, checksum, data, data_serial
return result
def doSetPartitionTable(self, ptid, cell_list, reset): def doSetPartitionTable(self, ptid, cell_list, reset):
q = self.query q = self.query
...@@ -376,11 +430,25 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -376,11 +430,25 @@ class MySQLDatabaseManager(DatabaseManager):
self.begin() self.begin()
try: try:
for oid, compression, checksum, data in object_list: for oid, compression, checksum, data, value_serial in object_list:
oid = util.u64(oid) oid = util.u64(oid)
data = e(data) if data is None:
q("""REPLACE INTO %s VALUES (%d, %d, %d, %d, '%s')""" \ compression = checksum = data = 'NULL'
% (obj_table, oid, tid, compression, checksum, data)) else:
# TODO: unit-test this raise
if value_serial is not None:
raise ValueError, 'Either data or value_serial ' \
'must be None (oid %d, tid %d)' % (oid, tid)
compression = '%d' % (compression, )
checksum = '%d' % (checksum, )
data = "'%s'" % (e(data), )
if value_serial is None:
value_serial = 'NULL'
else:
value_serial = '%d' % (value_serial, )
q("""REPLACE INTO %s VALUES (%d, %d, %s, %s, %s, %s)""" \
% (obj_table, oid, tid, compression, checksum, data,
value_serial))
if transaction is not None: if transaction is not None:
oid_list, user, desc, ext, packed = transaction oid_list, user, desc, ext, packed = transaction
packed = packed and 1 or 0 packed = packed and 1 or 0
...@@ -395,6 +463,107 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -395,6 +463,107 @@ class MySQLDatabaseManager(DatabaseManager):
raise raise
self.commit() self.commit()
def _getDataTIDFromData(self, oid, result):
tid, next_serial, compression, checksum, data, value_serial = result
if data is None:
try:
data_serial = self._getObjectData(oid, value_serial, tid)[0]
except CreationUndone:
data_serial = None
else:
data_serial = tid
return tid, data_serial
def _getDataTID(self, oid, tid=None, before_tid=None):
"""
Return a 2-tuple:
tid (int)
tid corresponding to received parameters
serial
tid at which actual object data is located
If 'tid is None', requested object and transaction could
not be found.
If 'serial is None', requested object exist but has no data (its creation
has been undone).
If 'tid == serial', it means that requested transaction
contains object data.
Otherwise, it's an undo transaction which did not involve conflict
resolution.
"""
result = self._getObject(oid, tid=tid, before_tid=before_tid)
if result is None:
result = (None, None)
else:
result = self._getDataTIDFromData(oid, result)
return result
def _findUndoTID(self, oid, tid, undone_tid, transaction_object):
"""
oid, undone_tid (ints)
Object to undo for given transaction
tid (int)
Client's transaction (he can't see objects past this value).
Return a 2-tuple:
current_tid (p64)
TID of most recent version of the object client's transaction can
see. This is used later to detect current conflicts (eg, another
client modifying the same object in parallel)
data_tid (int)
TID containing (without indirection) the data prior to undone
transaction.
-1 if object was modified by later transaction.
None if object doesn't exist prior to transaction being undone
(its creation is being undone).
"""
_getDataTID = self._getDataTID
if transaction_object is not None:
# transaction_object:
# oid, compression, ...
# Expected value:
# serial, next_serial, compression, ...
current_tid, current_data_tid = self._getDataTIDFromData(oid,
(tid, None) + transaction_object[1:])
else:
current_tid, current_data_tid = _getDataTID(oid, before_tid=tid)
assert current_tid is not None, (oid, tid, transaction_object)
found_undone_tid, undone_data_tid = _getDataTID(oid, tid=undone_tid)
assert found_undone_tid is not None, (oid, undone_tid)
if undone_data_tid not in (current_data_tid, tid):
# data from the transaction we want to undo is modified by a later
# transaction. It is up to the client node to decide what to do
# (undo error of conflict resolution).
data_tid = -1
else:
# Load object data as it was before given transaction.
# It can be None, in which case it means we are undoing object
# creation.
_, data_tid = _getDataTID(oid, before_tid=undone_tid)
return util.p64(current_tid), data_tid
def getTransactionUndoData(self, tid, undone_tid,
getObjectFromTransaction):
q = self.query
p64 = util.p64
u64 = util.u64
_findUndoTID = self._findUndoTID
p_tid = tid
tid = u64(tid)
undone_tid = u64(undone_tid)
if undone_tid > tid:
# Replace with an exception reaching client (TIDNotFound)
raise ValueError, 'Can\'t undo in future: %d > %d' % (
undone_tid, tid)
result = {}
for (oid, ) in q("""SELECT oid FROM obj WHERE serial = %d""" % (
undone_tid, )):
p_oid = p64(oid)
result[p_oid] = _findUndoTID(oid, tid, undone_tid,
getObjectFromTransaction(p_tid, p_oid))
return result
def finishTransaction(self, tid): def finishTransaction(self, tid):
q = self.query q = self.query
tid = util.u64(tid) tid = util.u64(tid)
...@@ -453,14 +622,40 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -453,14 +622,40 @@ class MySQLDatabaseManager(DatabaseManager):
offset, length)) offset, length))
return [util.p64(t[0]) for t in r] return [util.p64(t[0]) for t in r]
def _getObjectLength(self, oid, value_serial):
if value_serial is None:
raise CreationUndone
r = self.query("""SELECT LENGTH(value), value_serial FROM obj """ \
"""WHERE oid = %d AND serial = %d""" % (oid, value_serial))
length, value_serial = r[0]
if length is None:
logging.info("Multiple levels of indirection when " \
"searching for object data for oid %d at tid %d. This " \
"causes suboptimal performance." % (oid, value_serial))
length = self._getObjectLength(oid, value_serial)
return length
def getObjectHistory(self, oid, offset = 0, length = 1): def getObjectHistory(self, oid, offset = 0, length = 1):
# FIXME: This method doesn't take client's current ransaction id as
# parameter, which means it can return transactions in the future of
# client's transaction.
q = self.query q = self.query
oid = util.u64(oid) oid = util.u64(oid)
r = q("""SELECT serial, LENGTH(value) FROM obj WHERE oid = %d p64 = util.p64
ORDER BY serial DESC LIMIT %d, %d""" \ r = q("""SELECT serial, LENGTH(value), value_serial FROM obj
WHERE oid = %d ORDER BY serial DESC LIMIT %d, %d""" \
% (oid, offset, length)) % (oid, offset, length))
if r: if r:
return [(util.p64(serial), length) for serial, length in r] result = []
append = result.append
for serial, length, value_serial in r:
if length is None:
try:
length = self._getObjectLength(oid, value_serial)
except CreationUndone:
length = 0
append((p64(serial), length))
return result
return None return None
def getTIDList(self, offset, length, num_partitions, partition_list): def getTIDList(self, offset, length, num_partitions, partition_list):
......
...@@ -82,19 +82,22 @@ class BaseClientAndStorageOperationHandler(EventHandler): ...@@ -82,19 +82,22 @@ class BaseClientAndStorageOperationHandler(EventHandler):
t[4], t[0]) t[4], t[0])
conn.answer(p) conn.answer(p)
def _askObject(self, oid, serial, tid):
raise NotImplementedError
def askObject(self, conn, oid, serial, tid): def askObject(self, conn, oid, serial, tid):
app = self.app app = self.app
if self.app.tm.loadLocked(oid): if self.app.tm.loadLocked(oid):
# Delay the response. # Delay the response.
app.queueEvent(self.askObject, conn, oid, serial, tid) app.queueEvent(self.askObject, conn, oid, serial, tid)
return return
o = app.dm.getObject(oid, serial, tid) o = self._askObject(oid, serial, tid)
if o is not None: if o is not None:
serial, next_serial, compression, checksum, data = o serial, next_serial, compression, checksum, data, data_serial = o
logging.debug('oid = %s, serial = %s, next_serial = %s', logging.debug('oid = %s, serial = %s, next_serial = %s',
dump(oid), dump(serial), dump(next_serial)) dump(oid), dump(serial), dump(next_serial))
p = Packets.AnswerObject(oid, serial, next_serial, p = Packets.AnswerObject(oid, serial, next_serial,
compression, checksum, data) compression, checksum, data, data_serial)
else: else:
logging.debug('oid = %s not found', dump(oid)) logging.debug('oid = %s not found', dump(oid))
p = Errors.OidNotFound('%s does not exist' % dump(oid)) p = Errors.OidNotFound('%s does not exist' % dump(oid))
......
...@@ -22,6 +22,9 @@ from neo.storage.transactions import ConflictError, DelayedError ...@@ -22,6 +22,9 @@ from neo.storage.transactions import ConflictError, DelayedError
class ClientOperationHandler(BaseClientAndStorageOperationHandler): class ClientOperationHandler(BaseClientAndStorageOperationHandler):
def _askObject(self, oid, serial, tid):
return self.app.dm.getObject(oid, serial, tid)
def timeoutExpired(self, conn): def timeoutExpired(self, conn):
self.app.tm.abortFor(conn.getUUID()) self.app.tm.abortFor(conn.getUUID())
BaseClientAndStorageOperationHandler.timeoutExpired(self, conn) BaseClientAndStorageOperationHandler.timeoutExpired(self, conn)
...@@ -49,7 +52,7 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -49,7 +52,7 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
uuid = conn.getUUID() uuid = conn.getUUID()
try: try:
self.app.tm.storeObject(uuid, tid, serial, oid, compression, self.app.tm.storeObject(uuid, tid, serial, oid, compression,
checksum, data) checksum, data, None)
conn.answer(Packets.AnswerStoreObject(0, oid, serial)) conn.answer(Packets.AnswerStoreObject(0, oid, serial))
except ConflictError, err: except ConflictError, err:
# resolvable or not # resolvable or not
...@@ -76,3 +79,38 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -76,3 +79,38 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
app.pt.getPartitions(), partition_list) app.pt.getPartitions(), partition_list)
conn.answer(Packets.AnswerTIDs(tid_list)) conn.answer(Packets.AnswerTIDs(tid_list))
def askUndoTransaction(self, conn, tid, undone_tid):
app = self.app
tm = app.tm
storeObject = tm.storeObject
uuid = conn.getUUID()
oid_list = []
error_oid_list = []
conflict_oid_list = []
undo_tid_dict = app.dm.getTransactionUndoData(tid, undone_tid,
tm.getObjectFromTransaction)
for oid, (current_serial, undone_value_serial) in \
undo_tid_dict.iteritems():
if undone_value_serial == -1:
# Some data were modified by a later transaction
# This must be propagated to client, who will
# attempt a conflict resolution, and store resolved
# data.
to_append_list = error_oid_list
else:
try:
storeObject(uuid, tid, current_serial, oid, None,
None, None, undone_value_serial)
except ConflictError:
to_append_list = conflict_oid_list
except DelayedError:
app.queueEvent(self.askUndoTransaction, conn, tid,
undone_tid)
return
else:
to_append_list = oid_list
to_append_list.append(oid)
conn.answer(Packets.AnswerUndoTransaction(oid_list, error_oid_list,
conflict_oid_list))
...@@ -129,13 +129,13 @@ class ReplicationHandler(EventHandler): ...@@ -129,13 +129,13 @@ class ReplicationHandler(EventHandler):
conn.ask(p, timeout=300) conn.ask(p, timeout=300)
def answerObject(self, conn, oid, serial_start, def answerObject(self, conn, oid, serial_start,
serial_end, compression, checksum, data): serial_end, compression, checksum, data, data_serial):
app = self.app app = self.app
if app.replicator.current_connection is not conn: if app.replicator.current_connection is not conn:
return return
# Directly store the transaction. # Directly store the transaction.
obj = (oid, compression, checksum, data) obj = (oid, compression, checksum, data, data_serial)
app.dm.storeTransaction(serial_start, [obj], None, False) app.dm.storeTransaction(serial_start, [obj], None, False)
del obj del obj
del data del data
......
...@@ -21,6 +21,9 @@ from neo.protocol import Packets ...@@ -21,6 +21,9 @@ from neo.protocol import Packets
class StorageOperationHandler(BaseClientAndStorageOperationHandler): class StorageOperationHandler(BaseClientAndStorageOperationHandler):
def _askObject(self, oid, serial, tid):
return self.app.dm.getObject(oid, serial, tid, resolve_data=False)
def askLastIDs(self, conn): def askLastIDs(self, conn):
app = self.app app = self.app
oid = app.dm.getLastOID() oid = app.dm.getLastOID()
......
...@@ -71,11 +71,15 @@ class Transaction(object): ...@@ -71,11 +71,15 @@ class Transaction(object):
# assert self._transaction is not None # assert self._transaction is not None
self._transaction = (oid_list, user, desc, ext, packed) self._transaction = (oid_list, user, desc, ext, packed)
def addObject(self, oid, compression, checksum, data): def addObject(self, oid, compression, checksum, data, value_serial):
""" """
Add an object to the transaction Add an object to the transaction
""" """
self._object_dict[oid] = (oid, compression, checksum, data) self._object_dict[oid] = (oid, compression, checksum, data,
value_serial)
def getObject(self, oid):
return self._object_dict.get(oid)
def getObjectList(self): def getObjectList(self):
return self._object_dict.values() return self._object_dict.values()
...@@ -118,6 +122,16 @@ class TransactionManager(object): ...@@ -118,6 +122,16 @@ class TransactionManager(object):
self._transaction_dict[tid] = transaction self._transaction_dict[tid] = transaction
return transaction return transaction
def getObjectFromTransaction(self, tid, oid):
"""
Return object data for given running transaction.
Return None if not found.
"""
result = self._transaction_dict.get(tid)
if result is not None:
result = result.getObject(oid)
return result
def setLastOID(self, oid): def setLastOID(self, oid):
assert oid >= self._loid assert oid >= self._loid
self._loid = oid self._loid = oid
...@@ -168,30 +182,35 @@ class TransactionManager(object): ...@@ -168,30 +182,35 @@ class TransactionManager(object):
transaction = self._getTransaction(tid, uuid) transaction = self._getTransaction(tid, uuid)
transaction.prepare(oid_list, user, desc, ext, packed) transaction.prepare(oid_list, user, desc, ext, packed)
def storeObject(self, uuid, tid, serial, oid, compression, checksum, data): def storeObject(self, uuid, tid, serial, oid, compression, checksum, data,
value_serial):
""" """
Store an object received from client node Store an object received from client node
""" """
# check if the object if locked # check if the object if locked
locking_tid = self._store_lock_dict.get(oid, None) locking_tid = self._store_lock_dict.get(oid, None)
if locking_tid is not None: if locking_tid == tid:
if locking_tid < tid: logging.info('Transaction %s storing %s more than once', dump(tid),
# a previous transaction lock this object, retry later dump(oid))
raise DelayedError else:
# If a newer transaction already locks this object, if locking_tid is not None:
# do not try to resolve a conflict, so return immediately. if locking_tid < tid:
logging.info('unresolvable conflict in %s', dump(oid)) # a previous transaction lock this object, retry later
raise ConflictError(locking_tid) raise DelayedError
# If a newer transaction already locks this object,
# check if this is generated from the latest revision. # do not try to resolve a conflict, so return immediately.
history_list = self._app.dm.getObjectHistory(oid) logging.info('unresolvable conflict in %s', dump(oid))
if history_list and history_list[0][0] != serial: raise ConflictError(locking_tid)
logging.info('resolvable conflict in %s', dump(oid))
raise ConflictError(history_list[0][0]) # check if this is generated from the latest revision.
history_list = self._app.dm.getObjectHistory(oid)
if history_list and history_list[0][0] != serial:
logging.info('resolvable conflict in %s', dump(oid))
raise ConflictError(history_list[0][0])
# store object # store object
transaction = self._getTransaction(tid, uuid) transaction = self._getTransaction(tid, uuid)
transaction.addObject(oid, compression, checksum, data) transaction.addObject(oid, compression, checksum, data, value_serial)
self._store_lock_dict[oid] = tid self._store_lock_dict[oid] = tid
# update loid # update loid
......
...@@ -220,7 +220,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -220,7 +220,7 @@ class ClientApplicationTests(NeoTestBase):
oid = self.makeOID() oid = self.makeOID()
tid1 = self.makeTID(1) tid1 = self.makeTID(1)
tid2 = self.makeTID(2) tid2 = self.makeTID(2)
an_object = (1, oid, tid1, tid2, 0, makeChecksum('OBJ'), 'OBJ') an_object = (1, oid, tid1, tid2, 0, makeChecksum('OBJ'), 'OBJ', None)
# connection to SN close # connection to SN close
self.assertTrue(oid not in mq) self.assertTrue(oid not in mq)
packet = Errors.OidNotFound('') packet = Errors.OidNotFound('')
...@@ -260,7 +260,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -260,7 +260,7 @@ class ClientApplicationTests(NeoTestBase):
'fakeReceived': packet, 'fakeReceived': packet,
}) })
app.cp = Mock({ 'getConnForCell' : conn}) app.cp = Mock({ 'getConnForCell' : conn})
app.local_var.asked_object = an_object app.local_var.asked_object = an_object[:-1]
result = app.load(oid) result = app.load(oid)
self.assertEquals(result, ('OBJ', tid1)) self.assertEquals(result, ('OBJ', tid1))
self.checkAskObject(conn) self.checkAskObject(conn)
...@@ -299,7 +299,8 @@ class ClientApplicationTests(NeoTestBase): ...@@ -299,7 +299,8 @@ class ClientApplicationTests(NeoTestBase):
# now a cached version ewxists but should not be hit # now a cached version ewxists but should not be hit
mq.store(oid, (tid1, 'WRONG')) mq.store(oid, (tid1, 'WRONG'))
self.assertTrue(oid in mq) self.assertTrue(oid in mq)
another_object = (1, oid, tid2, INVALID_SERIAL, 0, makeChecksum('RIGHT'), 'RIGHT') another_object = (1, oid, tid2, INVALID_SERIAL, 0,
makeChecksum('RIGHT'), 'RIGHT', None)
packet = Packets.AnswerObject(*another_object[1:]) packet = Packets.AnswerObject(*another_object[1:])
packet.setId(0) packet.setId(0)
conn = Mock({ conn = Mock({
...@@ -307,7 +308,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -307,7 +308,7 @@ class ClientApplicationTests(NeoTestBase):
'fakeReceived': packet, 'fakeReceived': packet,
}) })
app.cp = Mock({ 'getConnForCell' : conn}) app.cp = Mock({ 'getConnForCell' : conn})
app.local_var.asked_object = another_object app.local_var.asked_object = another_object[:-1]
result = app.loadSerial(oid, tid1) result = app.loadSerial(oid, tid1)
self.assertEquals(result, 'RIGHT') self.assertEquals(result, 'RIGHT')
self.checkAskObject(conn) self.checkAskObject(conn)
...@@ -334,7 +335,8 @@ class ClientApplicationTests(NeoTestBase): ...@@ -334,7 +335,8 @@ class ClientApplicationTests(NeoTestBase):
self.assertRaises(NEOStorageNotFoundError, app.loadBefore, oid, tid2) self.assertRaises(NEOStorageNotFoundError, app.loadBefore, oid, tid2)
self.checkAskObject(conn) self.checkAskObject(conn)
# no previous versions -> return None # no previous versions -> return None
an_object = (1, oid, tid2, INVALID_SERIAL, 0, makeChecksum(''), '') an_object = (1, oid, tid2, INVALID_SERIAL, 0, makeChecksum(''), '',
None)
packet = Packets.AnswerObject(*an_object[1:]) packet = Packets.AnswerObject(*an_object[1:])
packet.setId(0) packet.setId(0)
conn = Mock({ conn = Mock({
...@@ -342,7 +344,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -342,7 +344,7 @@ class ClientApplicationTests(NeoTestBase):
'fakeReceived': packet, 'fakeReceived': packet,
}) })
app.cp = Mock({ 'getConnForCell' : conn}) app.cp = Mock({ 'getConnForCell' : conn})
app.local_var.asked_object = an_object app.local_var.asked_object = an_object[:-1]
result = app.loadBefore(oid, tid1) result = app.loadBefore(oid, tid1)
self.assertEquals(result, None) self.assertEquals(result, None)
# object should not have been cached # object should not have been cached
...@@ -350,7 +352,8 @@ class ClientApplicationTests(NeoTestBase): ...@@ -350,7 +352,8 @@ class ClientApplicationTests(NeoTestBase):
# as for loadSerial, the object is cached but should be loaded from db # as for loadSerial, the object is cached but should be loaded from db
mq.store(oid, (tid1, 'WRONG')) mq.store(oid, (tid1, 'WRONG'))
self.assertTrue(oid in mq) self.assertTrue(oid in mq)
another_object = (1, oid, tid1, tid2, 0, makeChecksum('RIGHT'), 'RIGHT') another_object = (1, oid, tid1, tid2, 0, makeChecksum('RIGHT'),
'RIGHT', None)
packet = Packets.AnswerObject(*another_object[1:]) packet = Packets.AnswerObject(*another_object[1:])
packet.setId(0) packet.setId(0)
conn = Mock({ conn = Mock({
...@@ -731,65 +734,97 @@ class ClientApplicationTests(NeoTestBase): ...@@ -731,65 +734,97 @@ class ClientApplicationTests(NeoTestBase):
self.storeObject(app, oid=oid2, data='O2V2') self.storeObject(app, oid=oid2, data='O2V2')
self.voteTransaction(app) self.voteTransaction(app)
self.askFinishTransaction(app) self.askFinishTransaction(app)
# undo 2 -> not end tid # undo 1 -> undoing non-last TID, and conflict resolution succeeded
u1p1 = Packets.AnswerTransactionInformation(tid1, '', '', '',
False, (oid2, ))
u1p2 = Packets.AnswerUndoTransaction([], [oid2], [])
# undo 2 -> undoing non-last TID, and conflict resolution failed
u2p1 = Packets.AnswerTransactionInformation(tid2, '', '', '', u2p1 = Packets.AnswerTransactionInformation(tid2, '', '', '',
False, (oid2, )) False, (oid2, ))
u2p2 = Packets.AnswerObject(oid2, tid1, tid2, 0, makeChecksum('O2V1'), 'O2V1') u2p2 = Packets.AnswerUndoTransaction([], [oid2], [])
u2p3 = Packets.AnswerObject(oid2, tid2, tid3, 0, makeChecksum('O2V2'), 'O2V2') # undo 3 -> "live" conflict (another transaction modifying the object
# undo 3 -> conflict # we want to undo)
u3p1 = Packets.AnswerTransactionInformation(tid3, '', '', '', u3p1 = Packets.AnswerTransactionInformation(tid3, '', '', '',
False, (oid3, )) False, (oid3, ))
u3p2 = Packets.AnswerObject(oid3, tid3, tid3, 0, makeChecksum('O3V1'), 'O3V1') u3p2 = Packets.AnswerUndoTransaction([], [], [oid3])
u3p3 = Packets.AnswerObject(oid3, tid3, tid3, 0, makeChecksum('O3V1'), 'O3V1') # undo 4 -> undoing last tid
u3p4 = Packets.AnswerObject(oid3, tid3, tid3, 0, makeChecksum('O3V1'), 'O3V1')
u3p5 = Packets.AnswerStoreObject(conflicting=1, oid=oid3, serial=tid2)
# undo 4 -> ok
u4p1 = Packets.AnswerTransactionInformation(tid3, '', '', '', u4p1 = Packets.AnswerTransactionInformation(tid3, '', '', '',
False, (oid1, )) False, (oid1, ))
u4p2 = Packets.AnswerObject(oid1, tid3, tid3, 0, makeChecksum('O1V1'), 'O1V1') u4p2 = Packets.AnswerUndoTransaction([oid1], [], [])
u4p3 = Packets.AnswerObject(oid1, tid3, tid3, 0, makeChecksum('O1V1'), 'O1V1')
u4p4 = Packets.AnswerObject(oid1, tid3, tid3, 0, makeChecksum('O1V1'), 'O1V1')
u4p5 = Packets.AnswerStoreObject(conflicting=0, oid=oid1, serial=tid2)
# test logic # test logic
packets = (u2p1, u2p2, u2p3, u3p1, u3p2, u3p3, u3p4, u3p5, u4p1, u4p2, packets = (u1p1, u1p2, u2p1, u2p2, u3p1, u3p2, u4p1, u4p2)
u4p3, u4p4, u4p5)
for i, p in enumerate(packets): for i, p in enumerate(packets):
p.setId(i) p.setId(i)
storage_address = ('127.0.0.1', 10010) storage_address = ('127.0.0.1', 10010)
conn = Mock({ conn = Mock({
'getNextId': 1, 'getNextId': 1,
'fakeReceived': ReturnValues( 'fakeReceived': ReturnValues(
u2p1, u2p2, u2p3, u1p1,
u4p1, u4p2, u4p3, u4p4, u2p1,
u3p1, u3p2, u3p3, u3p4, u4p1,
u3p1,
), ),
'getAddress': storage_address, 'getAddress': storage_address,
}) })
cell = Mock({ 'getAddress': 'FakeServer', 'getState': 'FakeState', }) cell = Mock({
'getAddress': 'FakeServer',
'getState': 'FakeState',
})
app.pt = Mock({ app.pt = Mock({
'getCellListForTID': (cell, ), 'getCellListForTID': (cell, ),
'getCellListForOID': (cell, ), 'getCellListForOID': (cell, ),
}) })
app.cp = Mock({ 'getConnForCell': conn}) app.cp = Mock({'getConnForCell': conn, 'getConnForNode': conn})
marker = [] def tryToResolveConflict(oid, conflict_serial, serial, data,
def tryToResolveConflict(oid, conflict_serial, serial, data): committedData=''):
marker.append(1) marker.append(1)
return resolution_result
class Dispatcher(object): class Dispatcher(object):
def pending(self, queue): def pending(self, queue):
return not queue.empty() return not queue.empty()
app.dispatcher = Dispatcher() app.dispatcher = Dispatcher()
def _load(oid, tid=None, serial=None):
assert tid is not None
assert serial is None, serial
return ('dummy', oid, tid)
app._load = _load
app.nm.createStorage(address=storage_address) app.nm.createStorage(address=storage_address)
txn4 = self.beginTransaction(app, tid=tid4)
# all start here # all start here
app.local_var.clear()
txn4 = self.beginTransaction(app, tid=tid4)
marker = []
resolution_result = 'solved'
app.local_var.queue.put((conn, u1p2))
app.undo(tid1, txn4, tryToResolveConflict)
self.assertEquals(marker, [1])
app.local_var.clear()
txn4 = self.beginTransaction(app, tid=tid4)
marker = []
resolution_result = None
app.local_var.queue.put((conn, u2p2))
self.assertRaises(UndoError, app.undo, tid2, txn4, self.assertRaises(UndoError, app.undo, tid2, txn4,
tryToResolveConflict) tryToResolveConflict)
app.local_var.queue.put((conn, u4p5)) self.assertEquals(marker, [1])
app.local_var.clear()
txn4 = self.beginTransaction(app, tid=tid4)
marker = []
resolution_result = None
app.local_var.queue.put((conn, u4p2))
self.assertEquals(app.undo(tid3, txn4, tryToResolveConflict), self.assertEquals(app.undo(tid3, txn4, tryToResolveConflict),
(tid4, [oid1, ])) (tid4, [oid1, ]))
app.local_var.queue.put((conn, u3p5)) self.assertEquals(marker, [])
app.local_var.clear()
txn4 = self.beginTransaction(app, tid=tid4)
marker = []
resolution_result = None
app.local_var.queue.put((conn, u3p2))
self.assertRaises(ConflictError, app.undo, tid3, txn4, self.assertRaises(ConflictError, app.undo, tid3, txn4,
tryToResolveConflict) tryToResolveConflict)
self.assertEquals(marker, [1]) self.assertEquals(marker, [])
self.askFinishTransaction(app) self.askFinishTransaction(app)
def test_undoLog(self): def test_undoLog(self):
......
...@@ -76,10 +76,14 @@ class StorageAnswerHandlerTests(NeoTestBase): ...@@ -76,10 +76,14 @@ class StorageAnswerHandlerTests(NeoTestBase):
oid = self.getOID(0) oid = self.getOID(0)
tid1 = self.getNextTID() tid1 = self.getNextTID()
tid2 = self.getNextTID(tid1) tid2 = self.getNextTID(tid1)
the_object = (oid, tid1, tid2, 0, '', 'DATA') the_object = (oid, tid1, tid2, 0, '', 'DATA', None)
self.app.local_var.asked_object = None self.app.local_var.asked_object = None
self.handler.answerObject(conn, *the_object) self.handler.answerObject(conn, *the_object)
self.assertEqual(self.app.local_var.asked_object, the_object) self.assertEqual(self.app.local_var.asked_object, the_object[:-1])
# Check handler raises on non-None data_serial.
the_object = (oid, tid1, tid2, 0, '', 'DATA', self.getNextTID())
self.app.local_var.asked_object = None
self.assertRaises(ValueError, self.handler.answerObject, conn, *the_object)
def test_answerStoreObject(self): def test_answerStoreObject(self):
conn = self.getConnection() conn = self.getConnection()
...@@ -162,6 +166,28 @@ class StorageAnswerHandlerTests(NeoTestBase): ...@@ -162,6 +166,28 @@ class StorageAnswerHandlerTests(NeoTestBase):
self.assertTrue(uuid in self.app.local_var.node_tids) self.assertTrue(uuid in self.app.local_var.node_tids)
self.assertEqual(self.app.local_var.node_tids[uuid], tid_list) self.assertEqual(self.app.local_var.node_tids[uuid], tid_list)
def test_answerUndoTransaction(self):
local_var = self.app.local_var
undo_conflict_oid_list = local_var.undo_conflict_oid_list = []
undo_error_oid_list = local_var.undo_error_oid_list = []
data_dict = local_var.data_dict = {}
conn = None # Nothing is done on connection in this handler
# Nothing undone, check nothing changed
self.handler.answerUndoTransaction(conn, [], [], [])
self.assertEqual(undo_conflict_oid_list, [])
self.assertEqual(undo_error_oid_list, [])
self.assertEqual(data_dict, {})
# One OID for each case, check they are inserted in expected local_var
# entries.
oid_1 = self.getOID(0)
oid_2 = self.getOID(1)
oid_3 = self.getOID(2)
self.handler.answerUndoTransaction(conn, [oid_1], [oid_2], [oid_3])
self.assertEqual(undo_conflict_oid_list, [oid_3])
self.assertEqual(undo_error_oid_list, [oid_2])
self.assertEqual(data_dict, {oid_1: ''})
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
...@@ -20,10 +20,11 @@ from mock import Mock ...@@ -20,10 +20,11 @@ from mock import Mock
from collections import deque from collections import deque
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.transactions import ConflictError from neo.storage.transactions import ConflictError, DelayedError
from neo.storage.handlers.client import ClientOperationHandler from neo.storage.handlers.client import ClientOperationHandler
from neo.protocol import INVALID_PARTITION from neo.protocol import INVALID_PARTITION
from neo.protocol import INVALID_TID, INVALID_OID, INVALID_SERIAL from neo.protocol import INVALID_TID, INVALID_OID, INVALID_SERIAL
from neo.protocol import Packets
class StorageClientHandlerTests(NeoTestBase): class StorageClientHandlerTests(NeoTestBase):
...@@ -126,7 +127,7 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -126,7 +127,7 @@ class StorageClientHandlerTests(NeoTestBase):
def test_24_askObject3(self): def test_24_askObject3(self):
# object found => answer # object found => answer
self.app.dm = Mock({'getObject': ('', '', 0, 0, '', )}) self.app.dm = Mock({'getObject': ('', '', 0, 0, '', None)})
conn = Mock({}) conn = Mock({})
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=INVALID_OID, self.operation.askObject(conn, oid=INVALID_OID,
...@@ -225,7 +226,7 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -225,7 +226,7 @@ class StorageClientHandlerTests(NeoTestBase):
self.operation.askStoreObject(conn, oid, serial, comp, checksum, self.operation.askStoreObject(conn, oid, serial, comp, checksum,
data, tid) data, tid)
self._checkStoreObjectCalled(uuid, tid, serial, oid, comp, self._checkStoreObjectCalled(uuid, tid, serial, oid, comp,
checksum, data) checksum, data, None)
self.checkAnswerStoreObject(conn) self.checkAnswerStoreObject(conn)
def test_askStoreObject2(self): def test_askStoreObject2(self):
...@@ -250,5 +251,44 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -250,5 +251,44 @@ class StorageClientHandlerTests(NeoTestBase):
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid) calls[0].checkArgs(tid)
def test_askUndoTransaction(self):
conn = self._getConnection()
tid = self.getNextTID()
undone_tid = self.getNextTID()
oid_1 = self.getNextTID()
oid_2 = self.getNextTID()
oid_3 = self.getNextTID()
oid_4 = self.getNextTID()
def getTransactionUndoData(tid, undone_tid, getObjectFromTransaction):
return {
oid_1: (1, 1),
oid_2: (1, -1),
oid_3: (1, 2),
oid_4: (1, 3),
}
self.app.dm.getTransactionUndoData = getTransactionUndoData
original_storeObject = self.app.tm.storeObject
def storeObject(uuid, tid, serial, oid, *args, **kw):
if oid == oid_3:
raise ConflictError(0)
elif oid == oid_4 and delay_store:
raise DelayedError
return original_storeObject(uuid, tid, serial, oid, *args, **kw)
self.app.tm.storeObject = storeObject
# Check if delaying a store (of oid_4) is supported
delay_store = True
self.operation.askUndoTransaction(conn, tid, undone_tid)
self.checkNoPacketSent(conn)
delay_store = False
self.operation.askUndoTransaction(conn, tid, undone_tid)
oid_list_1, oid_list_2, oid_list_3 = self.checkAnswerPacket(conn,
Packets.AnswerUndoTransaction, decode=True)
# Compare sets as order doens't matter here.
self.assertEqual(set(oid_list_1), set([oid_1, oid_4]))
self.assertEqual(oid_list_2, [oid_2])
self.assertEqual(oid_list_3, [oid_3])
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -93,12 +93,13 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -93,12 +93,13 @@ class StorageStorageHandlerTests(NeoTestBase):
calls = self.app.dm.mockGetNamedCalls('getObject') calls = self.app.dm.mockGetNamedCalls('getObject')
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
calls[0].checkArgs(INVALID_OID, INVALID_TID, INVALID_TID) calls[0].checkArgs(INVALID_OID, INVALID_TID, INVALID_TID,
resolve_data=False)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
def test_24_askObject3(self): def test_24_askObject3(self):
# object found => answer # object found => answer
self.app.dm = Mock({'getObject': ('', '', 0, 0, '', )}) self.app.dm = Mock({'getObject': ('', '', 0, 0, '', None)})
conn = Mock({}) conn = Mock({})
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=INVALID_OID, self.operation.askObject(conn, oid=INVALID_OID,
......
...@@ -286,25 +286,25 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -286,25 +286,25 @@ class StorageMySQSLdbTests(NeoTestBase):
checksum, value) values (%d, %d, 0, 0, '')""" % checksum, value) values (%d, %d, 0, 0, '')""" %
(u64(oid1), u64(tid1))) (u64(oid1), u64(tid1)))
result = self.db.getObject(oid1, tid1) result = self.db.getObject(oid1, tid1)
self.assertEquals(result, (tid1, None, 0, 0, '')) self.assertEquals(result, (tid1, None, 0, 0, '', None))
# before_tid specified, object not present # before_tid specified, object not present
result = self.db.getObject(oid2, before_tid=tid2) result = self.db.getObject(oid2, before_tid=tid2)
self.assertEquals(result, None) self.assertEquals(result, None)
# before_tid specified, object present, no next serial # before_tid specified, object present, no next serial
result = self.db.getObject(oid1, before_tid=tid2) result = self.db.getObject(oid1, before_tid=tid2)
self.assertEquals(result, (tid1, None, 0, 0, '')) self.assertEquals(result, (tid1, None, 0, 0, '', None))
# before_tid specified, object present, next serial exists # before_tid specified, object present, next serial exists
self.db.query("""insert into obj (oid, serial, compression, self.db.query("""insert into obj (oid, serial, compression,
checksum, value) values (%d, %d, 0, 0, '')""" % checksum, value) values (%d, %d, 0, 0, '')""" %
(u64(oid1), u64(tid2))) (u64(oid1), u64(tid2)))
result = self.db.getObject(oid1, before_tid=tid2) result = self.db.getObject(oid1, before_tid=tid2)
self.assertEquals(result, (tid1, tid2, 0, 0, '')) self.assertEquals(result, (tid1, tid2, 0, 0, '', None))
# no tid specified, retreive last object transaction, object unknown # no tid specified, retreive last object transaction, object unknown
result = self.db.getObject(oid2) result = self.db.getObject(oid2)
self.assertEquals(result, None) self.assertEquals(result, None)
# same but object found # same but object found
result = self.db.getObject(oid1) result = self.db.getObject(oid1)
self.assertEquals(result, (tid2, None, 0, 0, '')) self.assertEquals(result, (tid2, None, 0, 0, '', None))
def test_23_changePartitionTable(self): def test_23_changePartitionTable(self):
# two sn, two partitions # two sn, two partitions
...@@ -401,7 +401,7 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -401,7 +401,7 @@ class StorageMySQSLdbTests(NeoTestBase):
# data set # data set
tid = '\x00' * 7 + '\x01' tid = '\x00' * 7 + '\x01'
oid1, oid2 = '\x00' * 7 + '\x01', '\x00' * 7 + '\x02' oid1, oid2 = '\x00' * 7 + '\x01', '\x00' * 7 + '\x02'
object_list = ( (oid1, 0, 0, ''), (oid2, 0, 0, ''),) object_list = ( (oid1, 0, 0, '', None), (oid2, 0, 0, '', None),)
transaction = ((oid1, oid2), 'user', 'desc', 'ext', False) transaction = ((oid1, oid2), 'user', 'desc', 'ext', False)
# store objects in temporary table # store objects in temporary table
self.db.setup() self.db.setup()
...@@ -442,7 +442,7 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -442,7 +442,7 @@ class StorageMySQSLdbTests(NeoTestBase):
# data set # data set
tid1, tid2 = '\x00' * 7 + '\x01', '\x00' * 7 + '\x02' tid1, tid2 = '\x00' * 7 + '\x01', '\x00' * 7 + '\x02'
oid1, oid2 = '\x00' * 7 + '\x01', '\x00' * 7 + '\x02' oid1, oid2 = '\x00' * 7 + '\x01', '\x00' * 7 + '\x02'
object_list = ( (oid1, 0, 0, ''), (oid2, 0, 0, ''),) object_list = ( (oid1, 0, 0, '', None), (oid2, 0, 0, '', None),)
transaction = ((oid1, oid2), 'u', 'd', 'e', False) transaction = ((oid1, oid2), 'u', 'd', 'e', False)
self.db.setup(reset=True) self.db.setup(reset=True)
# store two temporary transactions # store two temporary transactions
...@@ -457,16 +457,16 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -457,16 +457,16 @@ class StorageMySQSLdbTests(NeoTestBase):
# t1 should be finished # t1 should be finished
result = self.db.query('select * from obj order by oid asc') result = self.db.query('select * from obj order by oid asc')
self.assertEquals(len(result), 2) self.assertEquals(len(result), 2)
self.assertEquals(result[0], (1L, 1L, 0, 0, '')) self.assertEquals(result[0], (1L, 1L, 0, 0, '', None))
self.assertEquals(result[1], (2L, 1L, 0, 0, '')) self.assertEquals(result[1], (2L, 1L, 0, 0, '', None))
result = self.db.query('select * from trans') result = self.db.query('select * from trans')
self.assertEquals(len(result), 1) self.assertEquals(len(result), 1)
self.assertEquals(result[0], (1L, 0, oid1 + oid2, 'u', 'd', 'e',)) self.assertEquals(result[0], (1L, 0, oid1 + oid2, 'u', 'd', 'e',))
# t2 should stay in temporary tables # t2 should stay in temporary tables
result = self.db.query('select * from tobj order by oid asc') result = self.db.query('select * from tobj order by oid asc')
self.assertEquals(len(result), 2) self.assertEquals(len(result), 2)
self.assertEquals(result[0], (1L, 2L, 0, 0, '')) self.assertEquals(result[0], (1L, 2L, 0, 0, '', None))
self.assertEquals(result[1], (2L, 2L, 0, 0, '')) self.assertEquals(result[1], (2L, 2L, 0, 0, '', None))
result = self.db.query('select * from ttrans') result = self.db.query('select * from ttrans')
self.assertEquals(len(result), 1) self.assertEquals(len(result), 1)
self.assertEquals(result[0], (2L, 0, oid1 + oid2, 'u', 'd', 'e',)) self.assertEquals(result[0], (2L, 0, oid1 + oid2, 'u', 'd', 'e',))
...@@ -475,7 +475,7 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -475,7 +475,7 @@ class StorageMySQSLdbTests(NeoTestBase):
# data set # data set
tid1, tid2 = '\x00' * 7 + '\x01', '\x00' * 7 + '\x02' tid1, tid2 = '\x00' * 7 + '\x01', '\x00' * 7 + '\x02'
oid1, oid2 = '\x00' * 7 + '\x01', '\x00' * 7 + '\x02' oid1, oid2 = '\x00' * 7 + '\x01', '\x00' * 7 + '\x02'
object_list = ( (oid1, 0, 0, ''), (oid2, 0, 0, ''),) object_list = ( (oid1, 0, 0, '', None), (oid2, 0, 0, '', None),)
transaction = ((oid1, oid2), 'u', 'd', 'e', False) transaction = ((oid1, oid2), 'u', 'd', 'e', False)
self.db.setup(reset=True) self.db.setup(reset=True)
# store two transactions in both state # store two transactions in both state
...@@ -496,15 +496,15 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -496,15 +496,15 @@ class StorageMySQSLdbTests(NeoTestBase):
# t2 not altered # t2 not altered
result = self.db.query('select * from tobj order by oid asc') result = self.db.query('select * from tobj order by oid asc')
self.assertEquals(len(result), 2) self.assertEquals(len(result), 2)
self.assertEquals(result[0], (1L, 2L, 0, 0, '')) self.assertEquals(result[0], (1L, 2L, 0, 0, '', None))
self.assertEquals(result[1], (2L, 2L, 0, 0, '')) self.assertEquals(result[1], (2L, 2L, 0, 0, '', None))
result = self.db.query('select * from ttrans') result = self.db.query('select * from ttrans')
self.assertEquals(len(result), 1) self.assertEquals(len(result), 1)
self.assertEquals(result[0], (2L, 0, oid1 + oid2, 'u', 'd', 'e',)) self.assertEquals(result[0], (2L, 0, oid1 + oid2, 'u', 'd', 'e',))
result = self.db.query('select * from obj order by oid asc') result = self.db.query('select * from obj order by oid asc')
self.assertEquals(len(result), 2) self.assertEquals(len(result), 2)
self.assertEquals(result[0], (1L, 2L, 0, 0, '')) self.assertEquals(result[0], (1L, 2L, 0, 0, '', None))
self.assertEquals(result[1], (2L, 2L, 0, 0, '')) self.assertEquals(result[1], (2L, 2L, 0, 0, '', None))
result = self.db.query('select * from trans') result = self.db.query('select * from trans')
self.assertEquals(len(result), 1) self.assertEquals(len(result), 1)
self.assertEquals(result[0], (2L, 0, oid1 + oid2, 'u', 'd', 'e',)) self.assertEquals(result[0], (2L, 0, oid1 + oid2, 'u', 'd', 'e',))
...@@ -516,17 +516,17 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -516,17 +516,17 @@ class StorageMySQSLdbTests(NeoTestBase):
# t2 not altered and t1 stay in obj/trans tables # t2 not altered and t1 stay in obj/trans tables
result = self.db.query('select * from tobj order by oid asc') result = self.db.query('select * from tobj order by oid asc')
self.assertEquals(len(result), 2) self.assertEquals(len(result), 2)
self.assertEquals(result[0], (1L, 2L, 0, 0, '')) self.assertEquals(result[0], (1L, 2L, 0, 0, '', None))
self.assertEquals(result[1], (2L, 2L, 0, 0, '')) self.assertEquals(result[1], (2L, 2L, 0, 0, '', None))
result = self.db.query('select * from ttrans') result = self.db.query('select * from ttrans')
self.assertEquals(len(result), 1) self.assertEquals(len(result), 1)
self.assertEquals(result[0], (2L, 0, oid1 + oid2, 'u', 'd', 'e',)) self.assertEquals(result[0], (2L, 0, oid1 + oid2, 'u', 'd', 'e',))
result = self.db.query('select * from obj order by oid, serial asc') result = self.db.query('select * from obj order by oid, serial asc')
self.assertEquals(len(result), 4) self.assertEquals(len(result), 4)
self.assertEquals(result[0], (1L, 1L, 0, 0, '')) self.assertEquals(result[0], (1L, 1L, 0, 0, '', None))
self.assertEquals(result[1], (1L, 2L, 0, 0, '')) self.assertEquals(result[1], (1L, 2L, 0, 0, '', None))
self.assertEquals(result[2], (2L, 1L, 0, 0, '')) self.assertEquals(result[2], (2L, 1L, 0, 0, '', None))
self.assertEquals(result[3], (2L, 2L, 0, 0, '')) self.assertEquals(result[3], (2L, 2L, 0, 0, '', None))
result = self.db.query('select * from trans order by tid asc') result = self.db.query('select * from trans order by tid asc')
self.assertEquals(len(result), 2) self.assertEquals(len(result), 2)
self.assertEquals(result[0], (1L, 0, oid1 + oid2, 'u', 'd', 'e',)) self.assertEquals(result[0], (1L, 0, oid1 + oid2, 'u', 'd', 'e',))
...@@ -566,7 +566,7 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -566,7 +566,7 @@ class StorageMySQSLdbTests(NeoTestBase):
tid = '\x00' * 7 + '\x01' tid = '\x00' * 7 + '\x01'
oid1, oid2, oid3, oid4 = ['\x00' * 7 + chr(i) for i in xrange(4)] oid1, oid2, oid3, oid4 = ['\x00' * 7 + chr(i) for i in xrange(4)]
for oid in (oid1, oid2, oid3, oid4): for oid in (oid1, oid2, oid3, oid4):
self.db.query("replace into obj values (%d, %d, 0, 0, '')" % self.db.query("replace into obj values (%d, %d, 0, 0, '', NULL)" %
(u64(oid), u64(tid))) (u64(oid), u64(tid)))
# get all oids for all partitions # get all oids for all partitions
result = self.db.getOIDList(0, 4, 2, (0, 1)) result = self.db.getOIDList(0, 4, 2, (0, 1))
...@@ -593,7 +593,7 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -593,7 +593,7 @@ class StorageMySQSLdbTests(NeoTestBase):
tids = ['\x00' * 7 + chr(i) for i in xrange(4)] tids = ['\x00' * 7 + chr(i) for i in xrange(4)]
oid = '\x00' * 8 oid = '\x00' * 8
for tid in tids: for tid in tids:
self.db.query("replace into obj values (%d, %d, 0, 0, '')" % self.db.query("replace into obj values (%d, %d, 0, 0, '', NULL)" %
(u64(oid), u64(tid))) (u64(oid), u64(tid)))
# unkwown object # unkwown object
result = self.db.getObjectHistory(oid='\x01' * 8) result = self.db.getObjectHistory(oid='\x01' * 8)
...@@ -663,8 +663,8 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -663,8 +663,8 @@ class StorageMySQSLdbTests(NeoTestBase):
tid1, tid2, tid3, tid4 = tids tid1, tid2, tid3, tid4 = tids
oid = '\x00' * 8 oid = '\x00' * 8
for tid in tids: for tid in tids:
self.db.query("replace into obj values (%d, %d, 0, 0, '')" % (u64(oid), self.db.query("replace into obj values (%d, %d, 0, 0, '', NULL)" \
u64(tid))) % (u64(oid), u64(tid)))
# all match # all match
result = self.db.getSerialListPresent(oid, tids) result = self.db.getSerialListPresent(oid, tids)
expected = list(tids) expected = list(tids)
...@@ -676,6 +676,248 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -676,6 +676,248 @@ class StorageMySQSLdbTests(NeoTestBase):
result = self.db.getSerialListPresent(oid, (tid1, tid3)) result = self.db.getSerialListPresent(oid, (tid1, tid3))
self.assertEquals(sorted(result), [tid1, tid3]) self.assertEquals(sorted(result), [tid1, tid3])
def test__getObjectData(self):
db = self.db
db.setup(reset=True)
tid0 = self.getNextTID()
tid1 = self.getNextTID()
tid2 = self.getNextTID()
tid3 = self.getNextTID()
assert tid0 < tid1 < tid2 < tid3
oid1 = self.getOID(1)
oid2 = self.getOID(2)
oid3 = self.getOID(3)
db.storeTransaction(
tid1, (
(oid1, 0, 0, 'foo', None),
(oid2, None, None, None, u64(tid0)),
(oid3, None, None, None, u64(tid2)),
), None, temporary=False)
db.storeTransaction(
tid2, (
(oid1, None, None, None, u64(tid1)),
(oid2, None, None, None, u64(tid1)),
(oid3, 0, 0, 'bar', None),
), None, temporary=False)
original_getObjectData = db._getObjectData
def _getObjectData(*args, **kw):
call_counter.append(1)
return original_getObjectData(*args, **kw)
db._getObjectData = _getObjectData
# NOTE: all tests are done as if values were fetched by _getObject, so
# there is already one indirection level.
# oid1 at tid1: data is immediately found
call_counter = []
self.assertEqual(
db._getObjectData(u64(oid1), u64(tid1), u64(tid3)),
(u64(tid1), 0, 0, 'foo'))
self.assertEqual(sum(call_counter), 1)
# oid2 at tid1: missing data in table, raise IndexError on next
# recursive call
call_counter = []
self.assertRaises(IndexError, db._getObjectData, u64(oid2), u64(tid1),
u64(tid3))
self.assertEqual(sum(call_counter), 2)
# oid3 at tid1: data_serial grater than row's tid, raise ValueError
# on next recursive call - even if data does exist at that tid (see
# "oid3 at tid2" case below)
call_counter = []
self.assertRaises(ValueError, db._getObjectData, u64(oid3), u64(tid1),
u64(tid3))
self.assertEqual(sum(call_counter), 2)
# Same with wrong parameters (tid0 < tid1)
call_counter = []
self.assertRaises(ValueError, db._getObjectData, u64(oid3), u64(tid1),
u64(tid0))
self.assertEqual(sum(call_counter), 1)
# Same with wrong parameters (tid1 == tid1)
call_counter = []
self.assertRaises(ValueError, db._getObjectData, u64(oid3), u64(tid1),
u64(tid1))
self.assertEqual(sum(call_counter), 1)
# oid1 at tid2: data is found after ons recursive call
call_counter = []
self.assertEqual(
db._getObjectData(u64(oid1), u64(tid2), u64(tid3)),
(u64(tid1), 0, 0, 'foo'))
self.assertEqual(sum(call_counter), 2)
# oid2 at tid2: missing data in table, raise IndexError after two
# recursive calls
call_counter = []
self.assertRaises(IndexError, db._getObjectData, u64(oid2), u64(tid2),
u64(tid3))
self.assertEqual(sum(call_counter), 3)
# oid3 at tid2: data is immediately found
call_counter = []
self.assertEqual(
db._getObjectData(u64(oid3), u64(tid2), u64(tid3)),
(u64(tid2), 0, 0, 'bar'))
self.assertEqual(sum(call_counter), 1)
def test__getDataTIDFromData(self):
db = self.db
db.setup(reset=True)
tid1 = self.getNextTID()
tid2 = self.getNextTID()
oid1 = self.getOID(1)
db.storeTransaction(
tid1, (
(oid1, 0, 0, 'foo', None),
), None, temporary=False)
db.storeTransaction(
tid2, (
(oid1, None, None, None, u64(tid1)),
), None, temporary=False)
self.assertEqual(
db._getDataTIDFromData(u64(oid1),
db._getObject(u64(oid1), tid=u64(tid1))),
(u64(tid1), u64(tid1)))
self.assertEqual(
db._getDataTIDFromData(u64(oid1),
db._getObject(u64(oid1), tid=u64(tid2))),
(u64(tid2), u64(tid1)))
def test__getDataTID(self):
db = self.db
db.setup(reset=True)
tid1 = self.getNextTID()
tid2 = self.getNextTID()
oid1 = self.getOID(1)
db.storeTransaction(
tid1, (
(oid1, 0, 0, 'foo', None),
), None, temporary=False)
db.storeTransaction(
tid2, (
(oid1, None, None, None, u64(tid1)),
), None, temporary=False)
self.assertEqual(
db._getDataTID(u64(oid1), tid=u64(tid1)),
(u64(tid1), u64(tid1)))
self.assertEqual(
db._getDataTID(u64(oid1), tid=u64(tid2)),
(u64(tid2), u64(tid1)))
def test__findUndoTID(self):
db = self.db
db.setup(reset=True)
tid1 = self.getNextTID()
tid2 = self.getNextTID()
tid3 = self.getNextTID()
tid4 = self.getNextTID()
oid1 = self.getOID(1)
db.storeTransaction(
tid1, (
(oid1, 0, 0, 'foo', None),
), None, temporary=False)
# Undoing oid1 tid1, OK: tid1 is latest
# Result: current tid is tid1, data_tid is None (undoing object
# creation)
self.assertEqual(
db._findUndoTID(u64(oid1), u64(tid4), u64(tid1), None),
(tid1, None))
# Store a new transaction
db.storeTransaction(
tid2, (
(oid1, 0, 0, 'bar', None),
), None, temporary=False)
# Undoing oid1 tid2, OK: tid2 is latest
# Result: current tid is tid2, data_tid is tid1
self.assertEqual(
db._findUndoTID(u64(oid1), u64(tid4), u64(tid2), None),
(tid2, u64(tid1)))
# Undoing oid1 tid1, Error: tid2 is latest
# Result: current tid is tid2, data_tid is -1
self.assertEqual(
db._findUndoTID(u64(oid1), u64(tid4), u64(tid1), None),
(tid2, -1))
# Undoing oid1 tid1 with tid2 being undone in same transaction,
# OK: tid1 is latest
# Result: current tid is tid2, data_tid is None (undoing object
# creation)
# Explanation of transaction_object: oid1, no data but a data serial
# to tid1
self.assertEqual(
db._findUndoTID(u64(oid1), u64(tid4), u64(tid1),
(u64(oid1), None, None, None, u64(tid1))),
(tid4, None))
# Store a new transaction
db.storeTransaction(
tid3, (
(oid1, None, None, None, u64(tid1)),
), None, temporary=False)
# Undoing oid1 tid1, OK: tid3 is latest with tid1 data
# Result: current tid is tid2, data_tid is None (undoing object
# creation)
self.assertEqual(
db._findUndoTID(u64(oid1), u64(tid4), u64(tid1), None),
(tid3, None))
def test_getTransactionUndoData(self):
db = self.db
db.setup(reset=True)
tid1 = self.getNextTID()
tid2 = self.getNextTID()
tid3 = self.getNextTID()
tid4 = self.getNextTID()
tid5 = self.getNextTID()
assert tid1 < tid2 < tid3 < tid4 < tid5
oid1 = self.getOID(1)
oid2 = self.getOID(2)
oid3 = self.getOID(3)
oid4 = self.getOID(4)
oid5 = self.getOID(5)
db.storeTransaction(
tid1, (
(oid1, 0, 0, 'foo1', None),
(oid2, 0, 0, 'foo2', None),
(oid3, 0, 0, 'foo3', None),
(oid4, 0, 0, 'foo5', None),
), None, temporary=False)
db.storeTransaction(
tid2, (
(oid1, 0, 0, 'bar1', None),
(oid2, None, None, None, None),
(oid3, 0, 0, 'bar3', None),
), None, temporary=False)
db.storeTransaction(
tid3, (
(oid3, 0, 0, 'baz3', None),
(oid5, 0, 0, 'foo6', None),
), None, temporary=False)
def getObjectFromTransaction(tid, oid):
return None
self.assertEqual(
db.getTransactionUndoData(tid4, tid2, getObjectFromTransaction),
{
oid1: (tid2, u64(tid1)), # can be undone
oid2: (tid2, u64(tid1)), # can be undone (creation redo)
oid3: (tid3, -1), # cannot be undone
# oid4 & oid5: not present because not ins undone transaction
})
# Cannot undo future transaction
self.assertRaises(ValueError, db.getTransactionUndoData, tid4, tid5,
getObjectFromTransaction)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -52,8 +52,8 @@ class TransactionTests(NeoTestBase): ...@@ -52,8 +52,8 @@ class TransactionTests(NeoTestBase):
def testObjects(self): def testObjects(self):
txn = Transaction(self.getNewUUID(), self.getNextTID()) txn = Transaction(self.getNewUUID(), self.getNextTID())
oid1, oid2 = self.getOID(1), self.getOID(2) oid1, oid2 = self.getOID(1), self.getOID(2)
object1 = (oid1, 1, '1', 'O1') object1 = (oid1, 1, '1', 'O1', None)
object2 = (oid2, 1, '2', 'O2') object2 = (oid2, 1, '2', 'O2', None)
self.assertEqual(txn.getObjectList(), []) self.assertEqual(txn.getObjectList(), [])
self.assertEqual(txn.getOIDList(), []) self.assertEqual(txn.getOIDList(), [])
txn.addObject(*object1) txn.addObject(*object1)
...@@ -63,6 +63,14 @@ class TransactionTests(NeoTestBase): ...@@ -63,6 +63,14 @@ class TransactionTests(NeoTestBase):
self.assertEqual(txn.getObjectList(), [object1, object2]) self.assertEqual(txn.getObjectList(), [object1, object2])
self.assertEqual(txn.getOIDList(), [oid1, oid2]) self.assertEqual(txn.getOIDList(), [oid1, oid2])
def test_getObject(self):
oid_1 = self.getOID(1)
oid_2 = self.getOID(2)
txn = Transaction(self.getNewUUID(), self.getNextTID())
object_info = (oid_1, None, None, None, None)
txn.addObject(*object_info)
self.assertEqual(txn.getObject(oid_2), None)
self.assertEqual(txn.getObject(oid_1), object_info)
class TransactionManagerTests(NeoTestBase): class TransactionManagerTests(NeoTestBase):
...@@ -81,7 +89,7 @@ class TransactionManagerTests(NeoTestBase): ...@@ -81,7 +89,7 @@ class TransactionManagerTests(NeoTestBase):
def _getObject(self, value): def _getObject(self, value):
oid = self.getOID(value) oid = self.getOID(value)
serial = self.getNextTID() serial = self.getNextTID()
return (serial, (oid, 1, str(value), 'O' + str(value))) return (serial, (oid, 1, str(value), 'O' + str(value), None))
def _checkTransactionStored(self, *args): def _checkTransactionStored(self, *args):
calls = self.app.dm.mockGetNamedCalls('storeTransaction') calls = self.app.dm.mockGetNamedCalls('storeTransaction')
...@@ -254,6 +262,19 @@ class TransactionManagerTests(NeoTestBase): ...@@ -254,6 +262,19 @@ class TransactionManagerTests(NeoTestBase):
self.assertFalse(tid in self.manager) self.assertFalse(tid in self.manager)
self.assertFalse(self.manager.loadLocked(obj[0])) self.assertFalse(self.manager.loadLocked(obj[0]))
def test_getObjectFromTransaction(self):
uuid = self.getNewUUID()
tid1, txn1 = self._getTransaction()
tid2, txn2 = self._getTransaction()
serial1, obj1 = self._getObject(1)
serial2, obj2 = self._getObject(2)
self.manager.storeObject(uuid, tid1, serial1, *obj1)
self.assertEqual(self.manager.getObjectFromTransaction(tid2, obj1[0]),
None)
self.assertEqual(self.manager.getObjectFromTransaction(tid1, obj2[0]),
None)
self.assertEqual(self.manager.getObjectFromTransaction(tid1, obj1[0]),
obj1)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -382,14 +382,18 @@ class ProtocolTests(NeoTestBase): ...@@ -382,14 +382,18 @@ class ProtocolTests(NeoTestBase):
oid = self.getNextTID() oid = self.getNextTID()
serial_start = self.getNextTID() serial_start = self.getNextTID()
serial_end = self.getNextTID() serial_end = self.getNextTID()
p = Packets.AnswerObject(oid, serial_start, serial_end, 1, 55, "to",) data_serial = self.getNextTID()
poid, pserial_start, pserial_end, compression, checksum, data= p.decode() p = Packets.AnswerObject(oid, serial_start, serial_end, 1, 55, "to",
data_serial)
poid, pserial_start, pserial_end, compression, checksum, data, \
pdata_serial = p.decode()
self.assertEqual(oid, poid) self.assertEqual(oid, poid)
self.assertEqual(serial_start, pserial_start) self.assertEqual(serial_start, pserial_start)
self.assertEqual(serial_end, pserial_end) self.assertEqual(serial_end, pserial_end)
self.assertEqual(compression, 1) self.assertEqual(compression, 1)
self.assertEqual(checksum, 55) self.assertEqual(checksum, 55)
self.assertEqual(data, "to") self.assertEqual(data, "to")
self.assertEqual(pdata_serial, data_serial)
def test_49_askTIDs(self): def test_49_askTIDs(self):
p = Packets.AskTIDs(1, 10, 5) p = Packets.AskTIDs(1, 10, 5)
...@@ -474,6 +478,24 @@ class ProtocolTests(NeoTestBase): ...@@ -474,6 +478,24 @@ class ProtocolTests(NeoTestBase):
p_offset = p.decode()[0] p_offset = p.decode()[0]
self.assertEqual(p_offset, offset) self.assertEqual(p_offset, offset)
def test_askUndoTransaction(self):
tid = self.getNextTID()
undone_tid = self.getNextTID()
p = Packets.AskUndoTransaction(tid, undone_tid)
p_tid, p_undone_tid = p.decode()
self.assertEqual(p_tid, tid)
self.assertEqual(p_undone_tid, undone_tid)
def test_answerUndoTransaction(self):
oid_list_1 = [self.getNextTID()]
oid_list_2 = [self.getNextTID(), self.getNextTID()]
oid_list_3 = [self.getNextTID(), self.getNextTID(), self.getNextTID()]
p = Packets.AnswerUndoTransaction(oid_list_1, oid_list_2, oid_list_3)
p_oid_list_1, p_oid_list_2, p_oid_list_3 = p.decode()
self.assertEqual(p_oid_list_1, oid_list_1)
self.assertEqual(p_oid_list_2, oid_list_2)
self.assertEqual(p_oid_list_3, oid_list_3)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment