Commit 83c02447 authored by Vincent Pelletier's avatar Vincent Pelletier

Use value_serial for undo support.

This mimics what FileStorage uses (file offsets) but in a relational manner.
This offloads decision of the ability to undo a transaction to storages,
avoiding 3 data loads for each object in the transaction at client side.
This also makes Neo refuse to undo transactions where object data would happen
to be equal between current value and undone value.
Finally, this is required to make database pack work properly (namely, it
prevents loosing objects which are orphans at pack TID, but are reachable
after it thanks to a transactional undo).

Also, extend storage's transaction manager so database adapter can fetch data
already sent by client in the same transaction, so it can undo multiple
transactions at once. Requires making object lock re-entrant (done in this
commit).

git-svn-id: https://svn.erp5.org/repos/neo/trunk@1978 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 987351fb
......@@ -112,6 +112,8 @@ class ThreadContext(object):
'node_tids': {},
'node_ready': False,
'asked_object': 0,
'undo_conflict_oid_list': [],
'undo_error_oid_list': [],
}
......@@ -805,31 +807,54 @@ class Application(object):
else:
raise NEOStorageError('undo failed')
oid_list = self.local_var.txn_info['oids']
# Second get object data from storage node using loadBefore
data_dict = {}
# XXX: this way causes each object to be loaded 3 times from storage,
# this work should rather be offloaded to it.
for oid in oid_list:
current_data = self.load(oid)[0]
after_data = self.loadSerial(oid, undone_tid)
if current_data != after_data:
raise UndoError("non-undoable transaction", oid)
if self.local_var.txn_info['packed']:
UndoError('non-undoable transaction')
tid = self.local_var.tid
undo_conflict_oid_list = self.local_var.undo_conflict_oid_list = []
undo_error_oid_list = self.local_var.undo_error_oid_list = []
ask_undo_transaction = Packets.AskUndoTransaction(tid, undone_tid)
getConnForNode = self.cp.getConnForNode
for storage_node in self.nm.getStorageList():
storage_conn = getConnForNode(storage_node)
storage_conn.ask(ask_undo_transaction)
# Wait for all AnswerUndoTransaction.
self.waitResponses()
# Don't do any handling for "live" conflicts, raise
if undo_conflict_oid_list:
raise ConflictError(oid=undo_conflict_oid_list[0], serials=(tid,
undone_tid), data=None)
# Try to resolve undo conflicts
for oid in undo_error_oid_list:
def loadBefore(oid, tid):
try:
data = self.loadBefore(oid, undone_tid)[0]
result = self._load(oid, tid=tid)
except NEOStorageNotFoundError:
if oid == '\x00' * 8:
# Refuse undoing root object creation.
raise UndoError("no previous record", oid)
raise UndoError("Object not found while resolving undo " \
"conflict")
return result[:2]
# Load the latest version we are supposed to see
data, data_tid = loadBefore(oid, tid)
# Load the version we were undoing to
undo_data, _ = loadBefore(oid, undone_tid)
# Resolve conflict
new_data = tryToResolveConflict(oid, data_tid, undone_tid, undo_data,
data)
if new_data is None:
raise UndoError('Some data were modified by a later ' \
'transaction', oid)
else:
# Undo object creation
data = ''
data_dict[oid] = data
self.store(oid, data_tid, new_data, '', self.local_var.txn)
# Third do transaction with old data
for oid, data in data_dict.iteritems():
self.store(oid, undone_tid, data, None, txn)
self.waitStoreResponses(tryToResolveConflict)
oid_list = self.local_var.txn_info['oids']
# Consistency checking: all oids of the transaction must have been
# reported as undone
data_dict = self.local_var.data_dict
for oid in oid_list:
assert oid in data_dict, repr(oid)
return self.local_var.tid, oid_list
def _insertMetadata(self, txn_info, extension):
......
......@@ -19,6 +19,7 @@ from ZODB.TimeStamp import TimeStamp
from neo.client.handlers import BaseHandler, AnswerBaseHandler
from neo.protocol import NodeTypes, ProtocolError
from neo.util import dump
class StorageEventHandler(BaseHandler):
......@@ -58,7 +59,10 @@ class StorageAnswersHandler(AnswerBaseHandler):
""" Handle all messages related to ZODB operations """
def answerObject(self, conn, oid, start_serial, end_serial,
compression, checksum, data):
compression, checksum, data, data_serial):
if data_serial is not None:
raise ValueError, 'Storage should never send non-None ' \
'data_serial to clients, got %s' % (dump(data_serial), )
self.app.local_var.asked_object = (oid, start_serial, end_serial,
compression, checksum, data)
......@@ -112,3 +116,12 @@ class StorageAnswersHandler(AnswerBaseHandler):
def answerTIDs(self, conn, tid_list):
self.app.local_var.node_tids[conn.getUUID()] = tid_list
def answerUndoTransaction(self, conn, oid_list, error_oid_list,
conflict_oid_list):
local_var = self.app.local_var
local_var.undo_conflict_oid_list.extend(conflict_oid_list)
local_var.undo_error_oid_list.extend(error_oid_list)
data_dict = local_var.data_dict
for oid in oid_list:
data_dict[oid] = ''
......@@ -247,7 +247,7 @@ class EventHandler(object):
raise UnexpectedPacketError
def answerObject(self, conn, oid, serial_start,
serial_end, compression, checksum, data):
serial_end, compression, checksum, data, data_serial):
raise UnexpectedPacketError
def askTIDs(self, conn, first, last, partition):
......@@ -323,6 +323,11 @@ class EventHandler(object):
def notifyReplicationDone(self, conn, offset):
raise UnexpectedPacketError
def askUndoTransaction(self, conn, tid, undone_tid):
raise UnexpectedPacketError
def answerUndoTransaction(self, conn, oid_list, error_oid_list, conflict_oid_list):
raise UnexpectedPacketError
# Error packet handlers.
......@@ -427,6 +432,8 @@ class EventHandler(object):
d[Packets.NotifyClusterInformation] = self.notifyClusterInformation
d[Packets.NotifyLastOID] = self.notifyLastOID
d[Packets.NotifyReplicationDone] = self.notifyReplicationDone
d[Packets.AskUndoTransaction] = self.askUndoTransaction
d[Packets.AnswerUndoTransaction] = self.answerUndoTransaction
return d
......
......@@ -18,6 +18,7 @@
from struct import pack, unpack, error, calcsize
from socket import inet_ntoa, inet_aton
from neo.profiling import profiler_decorator
from cStringIO import StringIO
from neo.util import Enum
......@@ -102,6 +103,7 @@ INVALID_OID = '\xff' * 8
INVALID_PTID = '\0' * 8
INVALID_SERIAL = INVALID_TID
INVALID_PARTITION = 0xffffffff
OID_LEN = len(INVALID_OID)
UUID_NAMESPACES = {
NodeTypes.STORAGE: 'S',
......@@ -988,25 +990,30 @@ class AnswerObject(Packet):
"""
Answer the requested object. S -> C.
"""
_header_format = '!8s8s8sBL'
_header_format = '!8s8s8s8sBL'
def _encode(self, oid, serial_start, serial_end, compression,
checksum, data):
checksum, data, data_serial):
if serial_start is None:
serial_start = INVALID_TID
if serial_end is None:
serial_end = INVALID_TID
if data_serial is None:
data_serial = INVALID_TID
return pack(self._header_format, oid, serial_start, serial_end,
compression, checksum) + _encodeString(data)
data_serial, compression, checksum) + _encodeString(data)
def _decode(self, body):
header_len = self._header_len
r = unpack(self._header_format, body[:header_len])
oid, serial_start, serial_end, compression, checksum = r
oid, serial_start, serial_end, data_serial, compression, checksum = r
if serial_end == INVALID_TID:
serial_end = None
if data_serial == INVALID_TID:
data_serial = None
(data, _) = _decodeString(body, 'data', offset=header_len)
return (oid, serial_start, serial_end, compression, checksum, data)
return (oid, serial_start, serial_end, compression, checksum, data,
data_serial)
class AskTIDs(Packet):
"""
......@@ -1354,7 +1361,8 @@ class AnswerNewNodes(Packet):
def _encode(self, uuid_list):
list_header_format = self._list_header_format
# an empty list means no new nodes
uuid_list = [pack(list_header_format, _encodeUUID(uuid)) for uuid in uuid_list]
uuid_list = [pack(list_header_format, _encodeUUID(uuid)) for \
uuid in uuid_list]
return pack(self._header_format, len(uuid_list)) + ''.join(uuid_list)
def _decode(self, body):
......@@ -1472,6 +1480,56 @@ class NotifyLastOID(Packet):
(loid, ) = unpack('8s', body)
return (loid, )
class AskUndoTransaction(Packet):
"""
Ask storage to undo given transaction
C -> S
"""
def _encode(self, tid, undone_tid):
return _encodeTID(tid) + _encodeTID(undone_tid)
def _decode(self, body):
tid = _decodeTID(body[:8])
undone_tid = _decodeTID(body[8:])
return (tid, undone_tid)
class AnswerUndoTransaction(Packet):
"""
Answer an undo request, telling if undo could be done, with an oid list.
If undo failed, the list contains oid(s) causing problems.
If undo succeeded; the list contains all undone oids for given storage.
S -> C
"""
_header_format = '!LLL'
def _encode(self, oid_list, error_oid_list, conflict_oid_list):
body = StringIO()
write = body.write
oid_list_list = [oid_list, error_oid_list, conflict_oid_list]
write(pack(self._header_format, *[len(x) for x in oid_list_list]))
for oid_list in oid_list_list:
for oid in oid_list:
write(oid)
return body.getvalue()
def _decode(self, body):
body = StringIO(body)
read = body.read
oid_list_len, error_oid_list_len, conflict_oid_list_len = unpack(
self._header_format, read(self._header_len))
oid_list = []
error_oid_list = []
conflict_oid_list = []
for some_list, some_list_len in (
(oid_list, oid_list_len),
(error_oid_list, error_oid_list_len),
(conflict_oid_list, conflict_oid_list_len),
):
append = some_list.append
for _ in xrange(some_list_len):
append(read(OID_LEN))
return (oid_list, error_oid_list, conflict_oid_list)
class Error(Packet):
"""
Error is a special type of message, because this can be sent against
......@@ -1671,6 +1729,10 @@ class PacketRegistry(dict):
AnswerClusterState)
NotifyLastOID = register(0x0030, NotifyLastOID)
NotifyReplicationDone = register(0x0031, NotifyReplicationDone)
AskUndoTransaction, AnswerUndoTransaction = register(
0x0033,
AskUndoTransaction,
AnswerUndoTransaction)
# build a "singleton"
Packets = PacketRegistry()
......
......@@ -236,6 +236,16 @@ class DatabaseManager(object):
pack state (True for packed)."""
raise NotImplementedError
def getTransactionUndoData(self, tid, undone_tid,
getObjectFromTransaction):
"""Undo transaction with "undone_tid" tid. "tid" is the tid of the
transaction in which the undo happens.
getObjectFromTransaction is a callback allowing to find object data
stored to this storage in the same transaction (it is useful for
example when undoing twice in the same transaction).
"""
raise NotImplementedError
def finishTransaction(self, tid):
"""Finish a transaction specified by a given ID, by moving
temporarily data to a finished area."""
......
This diff is collapsed.
......@@ -82,19 +82,22 @@ class BaseClientAndStorageOperationHandler(EventHandler):
t[4], t[0])
conn.answer(p)
def _askObject(self, oid, serial, tid):
raise NotImplementedError
def askObject(self, conn, oid, serial, tid):
app = self.app
if self.app.tm.loadLocked(oid):
# Delay the response.
app.queueEvent(self.askObject, conn, oid, serial, tid)
return
o = app.dm.getObject(oid, serial, tid)
o = self._askObject(oid, serial, tid)
if o is not None:
serial, next_serial, compression, checksum, data = o
serial, next_serial, compression, checksum, data, data_serial = o
logging.debug('oid = %s, serial = %s, next_serial = %s',
dump(oid), dump(serial), dump(next_serial))
p = Packets.AnswerObject(oid, serial, next_serial,
compression, checksum, data)
compression, checksum, data, data_serial)
else:
logging.debug('oid = %s not found', dump(oid))
p = Errors.OidNotFound('%s does not exist' % dump(oid))
......
......@@ -22,6 +22,9 @@ from neo.storage.transactions import ConflictError, DelayedError
class ClientOperationHandler(BaseClientAndStorageOperationHandler):
def _askObject(self, oid, serial, tid):
return self.app.dm.getObject(oid, serial, tid)
def timeoutExpired(self, conn):
self.app.tm.abortFor(conn.getUUID())
BaseClientAndStorageOperationHandler.timeoutExpired(self, conn)
......@@ -49,7 +52,7 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
uuid = conn.getUUID()
try:
self.app.tm.storeObject(uuid, tid, serial, oid, compression,
checksum, data)
checksum, data, None)
conn.answer(Packets.AnswerStoreObject(0, oid, serial))
except ConflictError, err:
# resolvable or not
......@@ -76,3 +79,38 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
app.pt.getPartitions(), partition_list)
conn.answer(Packets.AnswerTIDs(tid_list))
def askUndoTransaction(self, conn, tid, undone_tid):
app = self.app
tm = app.tm
storeObject = tm.storeObject
uuid = conn.getUUID()
oid_list = []
error_oid_list = []
conflict_oid_list = []
undo_tid_dict = app.dm.getTransactionUndoData(tid, undone_tid,
tm.getObjectFromTransaction)
for oid, (current_serial, undone_value_serial) in \
undo_tid_dict.iteritems():
if undone_value_serial == -1:
# Some data were modified by a later transaction
# This must be propagated to client, who will
# attempt a conflict resolution, and store resolved
# data.
to_append_list = error_oid_list
else:
try:
storeObject(uuid, tid, current_serial, oid, None,
None, None, undone_value_serial)
except ConflictError:
to_append_list = conflict_oid_list
except DelayedError:
app.queueEvent(self.askUndoTransaction, conn, tid,
undone_tid)
return
else:
to_append_list = oid_list
to_append_list.append(oid)
conn.answer(Packets.AnswerUndoTransaction(oid_list, error_oid_list,
conflict_oid_list))
......@@ -129,13 +129,13 @@ class ReplicationHandler(EventHandler):
conn.ask(p, timeout=300)
def answerObject(self, conn, oid, serial_start,
serial_end, compression, checksum, data):
serial_end, compression, checksum, data, data_serial):
app = self.app
if app.replicator.current_connection is not conn:
return
# Directly store the transaction.
obj = (oid, compression, checksum, data)
obj = (oid, compression, checksum, data, data_serial)
app.dm.storeTransaction(serial_start, [obj], None, False)
del obj
del data
......
......@@ -21,6 +21,9 @@ from neo.protocol import Packets
class StorageOperationHandler(BaseClientAndStorageOperationHandler):
def _askObject(self, oid, serial, tid):
return self.app.dm.getObject(oid, serial, tid, resolve_data=False)
def askLastIDs(self, conn):
app = self.app
oid = app.dm.getLastOID()
......
......@@ -71,11 +71,15 @@ class Transaction(object):
# assert self._transaction is not None
self._transaction = (oid_list, user, desc, ext, packed)
def addObject(self, oid, compression, checksum, data):
def addObject(self, oid, compression, checksum, data, value_serial):
"""
Add an object to the transaction
"""
self._object_dict[oid] = (oid, compression, checksum, data)
self._object_dict[oid] = (oid, compression, checksum, data,
value_serial)
def getObject(self, oid):
return self._object_dict.get(oid)
def getObjectList(self):
return self._object_dict.values()
......@@ -118,6 +122,16 @@ class TransactionManager(object):
self._transaction_dict[tid] = transaction
return transaction
def getObjectFromTransaction(self, tid, oid):
"""
Return object data for given running transaction.
Return None if not found.
"""
result = self._transaction_dict.get(tid)
if result is not None:
result = result.getObject(oid)
return result
def setLastOID(self, oid):
assert oid >= self._loid
self._loid = oid
......@@ -168,12 +182,17 @@ class TransactionManager(object):
transaction = self._getTransaction(tid, uuid)
transaction.prepare(oid_list, user, desc, ext, packed)
def storeObject(self, uuid, tid, serial, oid, compression, checksum, data):
def storeObject(self, uuid, tid, serial, oid, compression, checksum, data,
value_serial):
"""
Store an object received from client node
"""
# check if the object if locked
locking_tid = self._store_lock_dict.get(oid, None)
if locking_tid == tid:
logging.info('Transaction %s storing %s more than once', dump(tid),
dump(oid))
else:
if locking_tid is not None:
if locking_tid < tid:
# a previous transaction lock this object, retry later
......@@ -191,7 +210,7 @@ class TransactionManager(object):
# store object
transaction = self._getTransaction(tid, uuid)
transaction.addObject(oid, compression, checksum, data)
transaction.addObject(oid, compression, checksum, data, value_serial)
self._store_lock_dict[oid] = tid
# update loid
......
......@@ -220,7 +220,7 @@ class ClientApplicationTests(NeoTestBase):
oid = self.makeOID()
tid1 = self.makeTID(1)
tid2 = self.makeTID(2)
an_object = (1, oid, tid1, tid2, 0, makeChecksum('OBJ'), 'OBJ')
an_object = (1, oid, tid1, tid2, 0, makeChecksum('OBJ'), 'OBJ', None)
# connection to SN close
self.assertTrue(oid not in mq)
packet = Errors.OidNotFound('')
......@@ -260,7 +260,7 @@ class ClientApplicationTests(NeoTestBase):
'fakeReceived': packet,
})
app.cp = Mock({ 'getConnForCell' : conn})
app.local_var.asked_object = an_object
app.local_var.asked_object = an_object[:-1]
result = app.load(oid)
self.assertEquals(result, ('OBJ', tid1))
self.checkAskObject(conn)
......@@ -299,7 +299,8 @@ class ClientApplicationTests(NeoTestBase):
# now a cached version ewxists but should not be hit
mq.store(oid, (tid1, 'WRONG'))
self.assertTrue(oid in mq)
another_object = (1, oid, tid2, INVALID_SERIAL, 0, makeChecksum('RIGHT'), 'RIGHT')
another_object = (1, oid, tid2, INVALID_SERIAL, 0,
makeChecksum('RIGHT'), 'RIGHT', None)
packet = Packets.AnswerObject(*another_object[1:])
packet.setId(0)
conn = Mock({
......@@ -307,7 +308,7 @@ class ClientApplicationTests(NeoTestBase):
'fakeReceived': packet,
})
app.cp = Mock({ 'getConnForCell' : conn})
app.local_var.asked_object = another_object
app.local_var.asked_object = another_object[:-1]
result = app.loadSerial(oid, tid1)
self.assertEquals(result, 'RIGHT')
self.checkAskObject(conn)
......@@ -334,7 +335,8 @@ class ClientApplicationTests(NeoTestBase):
self.assertRaises(NEOStorageNotFoundError, app.loadBefore, oid, tid2)
self.checkAskObject(conn)
# no previous versions -> return None
an_object = (1, oid, tid2, INVALID_SERIAL, 0, makeChecksum(''), '')
an_object = (1, oid, tid2, INVALID_SERIAL, 0, makeChecksum(''), '',
None)
packet = Packets.AnswerObject(*an_object[1:])
packet.setId(0)
conn = Mock({
......@@ -342,7 +344,7 @@ class ClientApplicationTests(NeoTestBase):
'fakeReceived': packet,
})
app.cp = Mock({ 'getConnForCell' : conn})
app.local_var.asked_object = an_object
app.local_var.asked_object = an_object[:-1]
result = app.loadBefore(oid, tid1)
self.assertEquals(result, None)
# object should not have been cached
......@@ -350,7 +352,8 @@ class ClientApplicationTests(NeoTestBase):
# as for loadSerial, the object is cached but should be loaded from db
mq.store(oid, (tid1, 'WRONG'))
self.assertTrue(oid in mq)
another_object = (1, oid, tid1, tid2, 0, makeChecksum('RIGHT'), 'RIGHT')
another_object = (1, oid, tid1, tid2, 0, makeChecksum('RIGHT'),
'RIGHT', None)
packet = Packets.AnswerObject(*another_object[1:])
packet.setId(0)
conn = Mock({
......@@ -731,65 +734,97 @@ class ClientApplicationTests(NeoTestBase):
self.storeObject(app, oid=oid2, data='O2V2')
self.voteTransaction(app)
self.askFinishTransaction(app)
# undo 2 -> not end tid
# undo 1 -> undoing non-last TID, and conflict resolution succeeded
u1p1 = Packets.AnswerTransactionInformation(tid1, '', '', '',
False, (oid2, ))
u1p2 = Packets.AnswerUndoTransaction([], [oid2], [])
# undo 2 -> undoing non-last TID, and conflict resolution failed
u2p1 = Packets.AnswerTransactionInformation(tid2, '', '', '',
False, (oid2, ))
u2p2 = Packets.AnswerObject(oid2, tid1, tid2, 0, makeChecksum('O2V1'), 'O2V1')
u2p3 = Packets.AnswerObject(oid2, tid2, tid3, 0, makeChecksum('O2V2'), 'O2V2')
# undo 3 -> conflict
u2p2 = Packets.AnswerUndoTransaction([], [oid2], [])
# undo 3 -> "live" conflict (another transaction modifying the object
# we want to undo)
u3p1 = Packets.AnswerTransactionInformation(tid3, '', '', '',
False, (oid3, ))
u3p2 = Packets.AnswerObject(oid3, tid3, tid3, 0, makeChecksum('O3V1'), 'O3V1')
u3p3 = Packets.AnswerObject(oid3, tid3, tid3, 0, makeChecksum('O3V1'), 'O3V1')
u3p4 = Packets.AnswerObject(oid3, tid3, tid3, 0, makeChecksum('O3V1'), 'O3V1')
u3p5 = Packets.AnswerStoreObject(conflicting=1, oid=oid3, serial=tid2)
# undo 4 -> ok
u3p2 = Packets.AnswerUndoTransaction([], [], [oid3])
# undo 4 -> undoing last tid
u4p1 = Packets.AnswerTransactionInformation(tid3, '', '', '',
False, (oid1, ))
u4p2 = Packets.AnswerObject(oid1, tid3, tid3, 0, makeChecksum('O1V1'), 'O1V1')
u4p3 = Packets.AnswerObject(oid1, tid3, tid3, 0, makeChecksum('O1V1'), 'O1V1')
u4p4 = Packets.AnswerObject(oid1, tid3, tid3, 0, makeChecksum('O1V1'), 'O1V1')
u4p5 = Packets.AnswerStoreObject(conflicting=0, oid=oid1, serial=tid2)
u4p2 = Packets.AnswerUndoTransaction([oid1], [], [])
# test logic
packets = (u2p1, u2p2, u2p3, u3p1, u3p2, u3p3, u3p4, u3p5, u4p1, u4p2,
u4p3, u4p4, u4p5)
packets = (u1p1, u1p2, u2p1, u2p2, u3p1, u3p2, u4p1, u4p2)
for i, p in enumerate(packets):
p.setId(i)
storage_address = ('127.0.0.1', 10010)
conn = Mock({
'getNextId': 1,
'fakeReceived': ReturnValues(
u2p1, u2p2, u2p3,
u4p1, u4p2, u4p3, u4p4,
u3p1, u3p2, u3p3, u3p4,
u1p1,
u2p1,
u4p1,
u3p1,
),
'getAddress': storage_address,
})
cell = Mock({ 'getAddress': 'FakeServer', 'getState': 'FakeState', })
cell = Mock({
'getAddress': 'FakeServer',
'getState': 'FakeState',
})
app.pt = Mock({
'getCellListForTID': (cell, ),
'getCellListForOID': (cell, ),
})
app.cp = Mock({ 'getConnForCell': conn})
marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data):
app.cp = Mock({'getConnForCell': conn, 'getConnForNode': conn})
def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''):
marker.append(1)
return resolution_result
class Dispatcher(object):
def pending(self, queue):
return not queue.empty()
app.dispatcher = Dispatcher()
def _load(oid, tid=None, serial=None):
assert tid is not None
assert serial is None, serial
return ('dummy', oid, tid)
app._load = _load
app.nm.createStorage(address=storage_address)
txn4 = self.beginTransaction(app, tid=tid4)
# all start here
app.local_var.clear()
txn4 = self.beginTransaction(app, tid=tid4)
marker = []
resolution_result = 'solved'
app.local_var.queue.put((conn, u1p2))
app.undo(tid1, txn4, tryToResolveConflict)
self.assertEquals(marker, [1])
app.local_var.clear()
txn4 = self.beginTransaction(app, tid=tid4)
marker = []
resolution_result = None
app.local_var.queue.put((conn, u2p2))
self.assertRaises(UndoError, app.undo, tid2, txn4,
tryToResolveConflict)
app.local_var.queue.put((conn, u4p5))
self.assertEquals(marker, [1])
app.local_var.clear()
txn4 = self.beginTransaction(app, tid=tid4)
marker = []
resolution_result = None
app.local_var.queue.put((conn, u4p2))
self.assertEquals(app.undo(tid3, txn4, tryToResolveConflict),
(tid4, [oid1, ]))
app.local_var.queue.put((conn, u3p5))
self.assertEquals(marker, [])
app.local_var.clear()
txn4 = self.beginTransaction(app, tid=tid4)
marker = []
resolution_result = None
app.local_var.queue.put((conn, u3p2))
self.assertRaises(ConflictError, app.undo, tid3, txn4,
tryToResolveConflict)
self.assertEquals(marker, [1])
self.assertEquals(marker, [])
self.askFinishTransaction(app)
def test_undoLog(self):
......
......@@ -76,10 +76,14 @@ class StorageAnswerHandlerTests(NeoTestBase):
oid = self.getOID(0)
tid1 = self.getNextTID()
tid2 = self.getNextTID(tid1)
the_object = (oid, tid1, tid2, 0, '', 'DATA')
the_object = (oid, tid1, tid2, 0, '', 'DATA', None)
self.app.local_var.asked_object = None
self.handler.answerObject(conn, *the_object)
self.assertEqual(self.app.local_var.asked_object, the_object)
self.assertEqual(self.app.local_var.asked_object, the_object[:-1])
# Check handler raises on non-None data_serial.
the_object = (oid, tid1, tid2, 0, '', 'DATA', self.getNextTID())
self.app.local_var.asked_object = None
self.assertRaises(ValueError, self.handler.answerObject, conn, *the_object)
def test_answerStoreObject(self):
conn = self.getConnection()
......@@ -162,6 +166,28 @@ class StorageAnswerHandlerTests(NeoTestBase):
self.assertTrue(uuid in self.app.local_var.node_tids)
self.assertEqual(self.app.local_var.node_tids[uuid], tid_list)
def test_answerUndoTransaction(self):
local_var = self.app.local_var
undo_conflict_oid_list = local_var.undo_conflict_oid_list = []
undo_error_oid_list = local_var.undo_error_oid_list = []
data_dict = local_var.data_dict = {}
conn = None # Nothing is done on connection in this handler
# Nothing undone, check nothing changed
self.handler.answerUndoTransaction(conn, [], [], [])
self.assertEqual(undo_conflict_oid_list, [])
self.assertEqual(undo_error_oid_list, [])
self.assertEqual(data_dict, {})
# One OID for each case, check they are inserted in expected local_var
# entries.
oid_1 = self.getOID(0)
oid_2 = self.getOID(1)
oid_3 = self.getOID(2)
self.handler.answerUndoTransaction(conn, [oid_1], [oid_2], [oid_3])
self.assertEqual(undo_conflict_oid_list, [oid_3])
self.assertEqual(undo_error_oid_list, [oid_2])
self.assertEqual(data_dict, {oid_1: ''})
if __name__ == '__main__':
unittest.main()
......
......@@ -20,10 +20,11 @@ from mock import Mock
from collections import deque
from neo.tests import NeoTestBase
from neo.storage.app import Application
from neo.storage.transactions import ConflictError
from neo.storage.transactions import ConflictError, DelayedError
from neo.storage.handlers.client import ClientOperationHandler
from neo.protocol import INVALID_PARTITION
from neo.protocol import INVALID_TID, INVALID_OID, INVALID_SERIAL
from neo.protocol import Packets
class StorageClientHandlerTests(NeoTestBase):
......@@ -126,7 +127,7 @@ class StorageClientHandlerTests(NeoTestBase):
def test_24_askObject3(self):
# object found => answer
self.app.dm = Mock({'getObject': ('', '', 0, 0, '', )})
self.app.dm = Mock({'getObject': ('', '', 0, 0, '', None)})
conn = Mock({})
self.assertEquals(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=INVALID_OID,
......@@ -225,7 +226,7 @@ class StorageClientHandlerTests(NeoTestBase):
self.operation.askStoreObject(conn, oid, serial, comp, checksum,
data, tid)
self._checkStoreObjectCalled(uuid, tid, serial, oid, comp,
checksum, data)
checksum, data, None)
self.checkAnswerStoreObject(conn)
def test_askStoreObject2(self):
......@@ -250,5 +251,44 @@ class StorageClientHandlerTests(NeoTestBase):
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid)
def test_askUndoTransaction(self):
conn = self._getConnection()
tid = self.getNextTID()
undone_tid = self.getNextTID()
oid_1 = self.getNextTID()
oid_2 = self.getNextTID()
oid_3 = self.getNextTID()
oid_4 = self.getNextTID()
def getTransactionUndoData(tid, undone_tid, getObjectFromTransaction):
return {
oid_1: (1, 1),
oid_2: (1, -1),
oid_3: (1, 2),
oid_4: (1, 3),
}
self.app.dm.getTransactionUndoData = getTransactionUndoData
original_storeObject = self.app.tm.storeObject
def storeObject(uuid, tid, serial, oid, *args, **kw):
if oid == oid_3:
raise ConflictError(0)
elif oid == oid_4 and delay_store:
raise DelayedError
return original_storeObject(uuid, tid, serial, oid, *args, **kw)
self.app.tm.storeObject = storeObject
# Check if delaying a store (of oid_4) is supported
delay_store = True
self.operation.askUndoTransaction(conn, tid, undone_tid)
self.checkNoPacketSent(conn)
delay_store = False
self.operation.askUndoTransaction(conn, tid, undone_tid)
oid_list_1, oid_list_2, oid_list_3 = self.checkAnswerPacket(conn,
Packets.AnswerUndoTransaction, decode=True)
# Compare sets as order doens't matter here.
self.assertEqual(set(oid_list_1), set([oid_1, oid_4]))
self.assertEqual(oid_list_2, [oid_2])
self.assertEqual(oid_list_3, [oid_3])
if __name__ == "__main__":
unittest.main()
......@@ -93,12 +93,13 @@ class StorageStorageHandlerTests(NeoTestBase):
calls = self.app.dm.mockGetNamedCalls('getObject')
self.assertEquals(len(self.app.event_queue), 0)
self.assertEquals(len(calls), 1)
calls[0].checkArgs(INVALID_OID, INVALID_TID, INVALID_TID)
calls[0].checkArgs(INVALID_OID, INVALID_TID, INVALID_TID,
resolve_data=False)
self.checkErrorPacket(conn)
def test_24_askObject3(self):
# object found => answer
self.app.dm = Mock({'getObject': ('', '', 0, 0, '', )})
self.app.dm = Mock({'getObject': ('', '', 0, 0, '', None)})
conn = Mock({})
self.assertEquals(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=INVALID_OID,
......
This diff is collapsed.
......@@ -52,8 +52,8 @@ class TransactionTests(NeoTestBase):
def testObjects(self):
txn = Transaction(self.getNewUUID(), self.getNextTID())
oid1, oid2 = self.getOID(1), self.getOID(2)
object1 = (oid1, 1, '1', 'O1')
object2 = (oid2, 1, '2', 'O2')
object1 = (oid1, 1, '1', 'O1', None)
object2 = (oid2, 1, '2', 'O2', None)
self.assertEqual(txn.getObjectList(), [])
self.assertEqual(txn.getOIDList(), [])
txn.addObject(*object1)
......@@ -63,6 +63,14 @@ class TransactionTests(NeoTestBase):
self.assertEqual(txn.getObjectList(), [object1, object2])
self.assertEqual(txn.getOIDList(), [oid1, oid2])
def test_getObject(self):
oid_1 = self.getOID(1)
oid_2 = self.getOID(2)
txn = Transaction(self.getNewUUID(), self.getNextTID())
object_info = (oid_1, None, None, None, None)
txn.addObject(*object_info)
self.assertEqual(txn.getObject(oid_2), None)
self.assertEqual(txn.getObject(oid_1), object_info)
class TransactionManagerTests(NeoTestBase):
......@@ -81,7 +89,7 @@ class TransactionManagerTests(NeoTestBase):
def _getObject(self, value):
oid = self.getOID(value)
serial = self.getNextTID()
return (serial, (oid, 1, str(value), 'O' + str(value)))
return (serial, (oid, 1, str(value), 'O' + str(value), None))
def _checkTransactionStored(self, *args):
calls = self.app.dm.mockGetNamedCalls('storeTransaction')
......@@ -254,6 +262,19 @@ class TransactionManagerTests(NeoTestBase):
self.assertFalse(tid in self.manager)
self.assertFalse(self.manager.loadLocked(obj[0]))
def test_getObjectFromTransaction(self):
uuid = self.getNewUUID()
tid1, txn1 = self._getTransaction()
tid2, txn2 = self._getTransaction()
serial1, obj1 = self._getObject(1)
serial2, obj2 = self._getObject(2)
self.manager.storeObject(uuid, tid1, serial1, *obj1)
self.assertEqual(self.manager.getObjectFromTransaction(tid2, obj1[0]),
None)
self.assertEqual(self.manager.getObjectFromTransaction(tid1, obj2[0]),
None)
self.assertEqual(self.manager.getObjectFromTransaction(tid1, obj1[0]),
obj1)
if __name__ == "__main__":
unittest.main()
......@@ -382,14 +382,18 @@ class ProtocolTests(NeoTestBase):
oid = self.getNextTID()
serial_start = self.getNextTID()
serial_end = self.getNextTID()
p = Packets.AnswerObject(oid, serial_start, serial_end, 1, 55, "to",)
poid, pserial_start, pserial_end, compression, checksum, data= p.decode()
data_serial = self.getNextTID()
p = Packets.AnswerObject(oid, serial_start, serial_end, 1, 55, "to",
data_serial)
poid, pserial_start, pserial_end, compression, checksum, data, \
pdata_serial = p.decode()
self.assertEqual(oid, poid)
self.assertEqual(serial_start, pserial_start)
self.assertEqual(serial_end, pserial_end)
self.assertEqual(compression, 1)
self.assertEqual(checksum, 55)
self.assertEqual(data, "to")
self.assertEqual(pdata_serial, data_serial)
def test_49_askTIDs(self):
p = Packets.AskTIDs(1, 10, 5)
......@@ -474,6 +478,24 @@ class ProtocolTests(NeoTestBase):
p_offset = p.decode()[0]
self.assertEqual(p_offset, offset)
def test_askUndoTransaction(self):
tid = self.getNextTID()
undone_tid = self.getNextTID()
p = Packets.AskUndoTransaction(tid, undone_tid)
p_tid, p_undone_tid = p.decode()
self.assertEqual(p_tid, tid)
self.assertEqual(p_undone_tid, undone_tid)
def test_answerUndoTransaction(self):
oid_list_1 = [self.getNextTID()]
oid_list_2 = [self.getNextTID(), self.getNextTID()]
oid_list_3 = [self.getNextTID(), self.getNextTID(), self.getNextTID()]
p = Packets.AnswerUndoTransaction(oid_list_1, oid_list_2, oid_list_3)
p_oid_list_1, p_oid_list_2, p_oid_list_3 = p.decode()
self.assertEqual(p_oid_list_1, oid_list_1)
self.assertEqual(p_oid_list_2, oid_list_2)
self.assertEqual(p_oid_list_3, oid_list_3)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment