Commit 4e402dda authored by Vincent Pelletier's avatar Vincent Pelletier

Implement MVCC.

Remove round-trip to master upon "load" call.
Move load/loadBefore/loadSerial/loadEx from app.py to Storage.py.
This is required to get rid of master node round-trip upon each "load"
call.
Get rid of no-op-ish "sync" implementation.
Separate "undoing transaction ID" from "undoing transaction database
snapshot" when undoing.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2532 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 71a0de50
...@@ -20,6 +20,8 @@ from zope.interface import implements ...@@ -20,6 +20,8 @@ from zope.interface import implements
import ZODB.interfaces import ZODB.interfaces
from neo import setupLog from neo import setupLog
from neo.util import add64
from neo.protocol import ZERO_TID
from neo.client.app import Application from neo.client.app import Application
from neo.client.exception import NEOStorageNotFoundError from neo.client.exception import NEOStorageNotFoundError
from neo.client.exception import NEOStorageDoesNotExistError from neo.client.exception import NEOStorageDoesNotExistError
...@@ -42,6 +44,8 @@ class Storage(BaseStorage.BaseStorage, ...@@ -42,6 +44,8 @@ class Storage(BaseStorage.BaseStorage,
ConflictResolution.ConflictResolvingStorage): ConflictResolution.ConflictResolvingStorage):
"""Wrapper class for neoclient.""" """Wrapper class for neoclient."""
_snapshot_tid = None
implements( implements(
ZODB.interfaces.IStorage, ZODB.interfaces.IStorage,
# "restore" missing for the moment, but "store" implements this # "restore" missing for the moment, but "store" implements this
...@@ -54,19 +58,67 @@ class Storage(BaseStorage.BaseStorage, ...@@ -54,19 +58,67 @@ class Storage(BaseStorage.BaseStorage,
ZODB.interfaces.IStorageUndoable, ZODB.interfaces.IStorageUndoable,
ZODB.interfaces.IExternalGC, ZODB.interfaces.IExternalGC,
ZODB.interfaces.ReadVerifyingStorage, ZODB.interfaces.ReadVerifyingStorage,
ZODB.interfaces.IMVCCStorage,
) )
def __init__(self, master_nodes, name, connector=None, read_only=False, def __init__(self, master_nodes, name, connector=None, read_only=False,
compress=None, logfile=None, verbose=False, **kw): compress=None, logfile=None, verbose=False,
_app=None, _cache=None,
**kw):
"""
Do not pass those parameters (used internally):
_app
_cache
"""
if compress is None: if compress is None:
compress = True compress = True
setupLog('CLIENT', filename=logfile, verbose=verbose) setupLog('CLIENT', filename=logfile, verbose=verbose)
BaseStorage.BaseStorage.__init__(self, 'NEOStorage(%s)' % (name, )) BaseStorage.BaseStorage.__init__(self, 'NEOStorage(%s)' % (name, ))
# Warning: _is_read_only is used in BaseStorage, do not rename it. # Warning: _is_read_only is used in BaseStorage, do not rename it.
self._is_read_only = read_only self._is_read_only = read_only
self.app = Application(master_nodes, name, connector, if _app is None:
_app = Application(master_nodes, name, connector,
compress=compress) compress=compress)
self._cache = DummyCache(self.app) assert _cache is None
_cache = DummyCache(_app)
self.app = _app
assert _cache is not None
self._cache = _cache
# Used to clone self (see new_instance & IMVCCStorage definition).
self._init_args = (master_nodes, name)
self._init_kw = {
'connector': connector,
'read_only': read_only,
'compress': compress,
'logfile': logfile,
'verbose': verbose,
'_app': _app,
'_cache': _cache,
}
def _getSnapshotTID(self):
"""
Get the highest TID visible for current transaction.
First call sets this snapshot by asking master node most recent
committed TID.
As a (positive) side-effect, this forces us to handle all pending
invalidations, so we get a very recent view of the database (which is
good when multiple databases are used in the same program with some
amount of referential integrity).
"""
tid = self._snapshot_tid
if tid is None:
tid = self.lastTransaction()
if tid is ZERO_TID:
raise NEOStorageDoesNotExistError('No transaction in storage')
# Increment by one, as we will use this as an excluded upper
# bound (loadBefore).
tid = add64(tid, 1)
self._snapshot_tid = tid
return tid
def _load(self, *args, **kw):
return self.app.load(self._getSnapshotTID(), *args, **kw)
def load(self, oid, version=''): def load(self, oid, version=''):
# XXX: interface deifinition states that version parameter is # XXX: interface deifinition states that version parameter is
...@@ -74,7 +126,7 @@ class Storage(BaseStorage.BaseStorage, ...@@ -74,7 +126,7 @@ class Storage(BaseStorage.BaseStorage,
# it optional. # it optional.
assert version == '', 'Versions are not supported' assert version == '', 'Versions are not supported'
try: try:
return self.app.load(oid=oid) return self._load(oid)[:2]
except NEOStorageNotFoundError: except NEOStorageNotFoundError:
raise POSException.POSKeyError(oid) raise POSException.POSKeyError(oid)
...@@ -97,11 +149,14 @@ class Storage(BaseStorage.BaseStorage, ...@@ -97,11 +149,14 @@ class Storage(BaseStorage.BaseStorage,
@check_read_only @check_read_only
def tpc_abort(self, transaction): def tpc_abort(self, transaction):
self.sync()
return self.app.tpc_abort(transaction=transaction) return self.app.tpc_abort(transaction=transaction)
def tpc_finish(self, transaction, f=None): def tpc_finish(self, transaction, f=None):
return self.app.tpc_finish(transaction=transaction, result = self.app.tpc_finish(transaction=transaction,
tryToResolveConflict=self.tryToResolveConflict, f=f) tryToResolveConflict=self.tryToResolveConflict, f=f)
self.sync()
return result
@check_read_only @check_read_only
def store(self, oid, serial, data, version, transaction): def store(self, oid, serial, data, version, transaction):
...@@ -117,13 +172,13 @@ class Storage(BaseStorage.BaseStorage, ...@@ -117,13 +172,13 @@ class Storage(BaseStorage.BaseStorage,
# mutliple revisions # mutliple revisions
def loadSerial(self, oid, serial): def loadSerial(self, oid, serial):
try: try:
return self.app.loadSerial(oid=oid, serial=serial) return self._load(oid, serial=serial)[0]
except NEOStorageNotFoundError: except NEOStorageNotFoundError:
raise POSException.POSKeyError(oid) raise POSException.POSKeyError(oid)
def loadBefore(self, oid, tid): def loadBefore(self, oid, tid):
try: try:
return self.app.loadBefore(oid=oid, tid=tid) return self._load(oid, tid=tid)
except NEOStorageDoesNotExistError: except NEOStorageDoesNotExistError:
raise POSException.POSKeyError(oid) raise POSException.POSKeyError(oid)
except NEOStorageNotFoundError: except NEOStorageNotFoundError:
...@@ -135,8 +190,8 @@ class Storage(BaseStorage.BaseStorage, ...@@ -135,8 +190,8 @@ class Storage(BaseStorage.BaseStorage,
# undo # undo
@check_read_only @check_read_only
def undo(self, transaction_id, txn): def undo(self, transaction_id, txn):
return self.app.undo(undone_tid=transaction_id, txn=txn, return self.app.undo(self._getSnapshotTID(), undone_tid=transaction_id,
tryToResolveConflict=self.tryToResolveConflict) txn=txn, tryToResolveConflict=self.tryToResolveConflict)
@check_read_only @check_read_only
...@@ -159,9 +214,10 @@ class Storage(BaseStorage.BaseStorage, ...@@ -159,9 +214,10 @@ class Storage(BaseStorage.BaseStorage,
def loadEx(self, oid, version): def loadEx(self, oid, version):
try: try:
return self.app.loadEx(oid=oid, version=version) data, serial, _ = self._load(oid)
except NEOStorageNotFoundError: except NEOStorageNotFoundError:
raise POSException.POSKeyError(oid) raise POSException.POSKeyError(oid)
return data, serial, ''
def __len__(self): def __len__(self):
return self.app.getStorageSize() return self.app.getStorageSize()
...@@ -172,8 +228,8 @@ class Storage(BaseStorage.BaseStorage, ...@@ -172,8 +228,8 @@ class Storage(BaseStorage.BaseStorage,
def history(self, oid, version=None, size=1, filter=None): def history(self, oid, version=None, size=1, filter=None):
return self.app.history(oid, version, size, filter) return self.app.history(oid, version, size, filter)
def sync(self): def sync(self, force=True):
self.app.sync() self._snapshot_tid = None
def copyTransactionsFrom(self, source, verbose=False): def copyTransactionsFrom(self, source, verbose=False):
""" Zope compliant API """ """ Zope compliant API """
...@@ -217,9 +273,23 @@ class Storage(BaseStorage.BaseStorage, ...@@ -217,9 +273,23 @@ class Storage(BaseStorage.BaseStorage,
def close(self): def close(self):
self.app.close() self.app.close()
def getTID(self, oid): def getTid(self, oid):
try:
return self.app.getLastTID(oid) return self.app.getLastTID(oid)
except NEOStorageNotFoundError:
raise KeyError
def checkCurrentSerialInTransaction(self, oid, serial, transaction): def checkCurrentSerialInTransaction(self, oid, serial, transaction):
self.app.checkCurrentSerialInTransaction(oid, serial, transaction) self.app.checkCurrentSerialInTransaction(oid, serial, transaction)
def new_instance(self):
return Storage(*self._init_args, **self._init_kw)
def poll_invalidations(self):
"""
Nothing to do, NEO doesn't need any polling.
"""
pass
release = sync
...@@ -115,7 +115,6 @@ class ThreadContext(object): ...@@ -115,7 +115,6 @@ class ThreadContext(object):
'asked_object': 0, 'asked_object': 0,
'undo_object_tid_dict': {}, 'undo_object_tid_dict': {},
'involved_nodes': set(), 'involved_nodes': set(),
'barrier_done': False,
'last_transaction': None, 'last_transaction': None,
} }
...@@ -564,10 +563,13 @@ class Application(object): ...@@ -564,10 +563,13 @@ class Application(object):
return int(u64(self.last_oid)) return int(u64(self.last_oid))
@profiler_decorator @profiler_decorator
def _load(self, oid, serial=None, tid=None): def load(self, snapshot_tid, oid, serial=None, tid=None):
""" """
Internal method which manage load, loadSerial and loadBefore. Internal method which manage load, loadSerial and loadBefore.
OID and TID (serial) parameters are expected packed. OID and TID (serial) parameters are expected packed.
snapshot_tid
First TID not visible to current transaction.
Set to None for no limit.
oid oid
OID of object to get. OID of object to get.
serial serial
...@@ -595,15 +597,19 @@ class Application(object): ...@@ -595,15 +597,19 @@ class Application(object):
""" """
# TODO: # TODO:
# - rename parameters (here and in handlers & packet definitions) # - rename parameters (here and in handlers & packet definitions)
if snapshot_tid is not None:
if serial is None:
if tid is None:
tid = snapshot_tid
else:
tid = min(tid, snapshot_tid)
# XXX: we must not clamp serial with snapshot_tid, as loadSerial is
# used during conflict resolution to load object's current version,
# which is not visible to us normaly (it was committed after our
# snapshot was taken).
self._load_lock_acquire() self._load_lock_acquire()
try: try:
# Once per transaction, upon first load, trigger a barrier so we
# handle all pending invalidations, so the snapshot of the database
# is as up-to-date as possible.
if not self.local_var.barrier_done:
self.invalidationBarrier()
self.local_var.barrier_done = True
try: try:
result = self._loadFromCache(oid, serial, tid) result = self._loadFromCache(oid, serial, tid)
except KeyError: except KeyError:
...@@ -701,34 +707,6 @@ class Application(object): ...@@ -701,34 +707,6 @@ class Application(object):
finally: finally:
self._cache_lock_release() self._cache_lock_release()
@profiler_decorator
def load(self, oid, version=None):
"""Load an object for a given oid."""
result = self._load(oid)[:2]
# Start a network barrier, so we get all invalidations *after* we
# received data. This ensures we get any invalidation message that
# would have been about the version we loaded.
# Those invalidations are checked at ZODB level, so it decides if
# loaded data can be handed to current transaction or if a separate
# loadBefore call is required.
# XXX: A better implementation is required to improve performances
self.invalidationBarrier()
return result
@profiler_decorator
def loadSerial(self, oid, serial):
"""Load an object for a given oid and serial."""
neo.logging.debug('loading %s at %s', dump(oid), dump(serial))
return self._load(oid, serial=serial)[0]
@profiler_decorator
def loadBefore(self, oid, tid):
"""Load an object for a given oid before tid committed."""
neo.logging.debug('loading %s before %s', dump(oid), dump(tid))
return self._load(oid, tid=tid)
@profiler_decorator @profiler_decorator
def tpc_begin(self, transaction, tid=None, status=' '): def tpc_begin(self, transaction, tid=None, status=' '):
"""Begin a new transaction.""" """Begin a new transaction."""
...@@ -1047,7 +1025,7 @@ class Application(object): ...@@ -1047,7 +1025,7 @@ class Application(object):
finally: finally:
self._load_lock_release() self._load_lock_release()
def undo(self, undone_tid, txn, tryToResolveConflict): def undo(self, snapshot_tid, undone_tid, txn, tryToResolveConflict):
if txn is not self.local_var.txn: if txn is not self.local_var.txn:
raise StorageTransactionError(self, undone_tid) raise StorageTransactionError(self, undone_tid)
...@@ -1106,7 +1084,7 @@ class Application(object): ...@@ -1106,7 +1084,7 @@ class Application(object):
cell_list.sort(key=getCellSortKey) cell_list.sort(key=getCellSortKey)
storage_conn = getConnForCell(cell_list[0]) storage_conn = getConnForCell(cell_list[0])
storage_conn.ask(Packets.AskObjectUndoSerial(self.local_var.tid, storage_conn.ask(Packets.AskObjectUndoSerial(self.local_var.tid,
undone_tid, oid_list), queue=queue) snapshot_tid, undone_tid, oid_list), queue=queue)
# Wait for all AnswerObjectUndoSerial. We might get OidNotFoundError, # Wait for all AnswerObjectUndoSerial. We might get OidNotFoundError,
# meaning that objects in transaction's oid_list do not exist any # meaning that objects in transaction's oid_list do not exist any
...@@ -1133,9 +1111,9 @@ class Application(object): ...@@ -1133,9 +1111,9 @@ class Application(object):
# object. This is an undo conflict, try to resolve it. # object. This is an undo conflict, try to resolve it.
try: try:
# Load the latest version we are supposed to see # Load the latest version we are supposed to see
data = self.loadSerial(oid, current_serial) data = self.load(snapshot_tid, oid, serial=current_serial)[0]
# Load the version we were undoing to # Load the version we were undoing to
undo_data = self.loadSerial(oid, undo_serial) undo_data = self.load(snapshot_tid, oid, serial=undo_serial)[0]
except NEOStorageNotFoundError: except NEOStorageNotFoundError:
raise UndoError('Object not found while resolving undo ' raise UndoError('Object not found while resolving undo '
'conflict') 'conflict')
...@@ -1346,10 +1324,6 @@ class Application(object): ...@@ -1346,10 +1324,6 @@ class Application(object):
raise StorageTransactionError(self, transaction) raise StorageTransactionError(self, transaction)
return '', [] return '', []
def loadEx(self, oid, version):
data, serial = self.load(oid=oid)
return data, serial, ''
def __del__(self): def __del__(self):
"""Clear all connection.""" """Clear all connection."""
# Due to bug in ZODB, close is not always called when shutting # Due to bug in ZODB, close is not always called when shutting
...@@ -1367,9 +1341,6 @@ class Application(object): ...@@ -1367,9 +1341,6 @@ class Application(object):
def invalidationBarrier(self): def invalidationBarrier(self):
self._askPrimary(Packets.AskBarrier()) self._askPrimary(Packets.AskBarrier())
def sync(self):
self._waitAnyMessage(False)
def setNodeReady(self): def setNodeReady(self):
self.local_var.node_ready = True self.local_var.node_ready = True
...@@ -1401,7 +1372,7 @@ class Application(object): ...@@ -1401,7 +1372,7 @@ class Application(object):
self._cache_lock_release() self._cache_lock_release()
def getLastTID(self, oid): def getLastTID(self, oid):
return self._load(oid)[1] return self.load(None, oid)[1]
def checkCurrentSerialInTransaction(self, oid, serial, transaction): def checkCurrentSerialInTransaction(self, oid, serial, transaction):
local_var = self.local_var local_var = self.local_var
......
...@@ -61,7 +61,7 @@ class Transaction(BaseStorage.TransactionRecord): ...@@ -61,7 +61,7 @@ class Transaction(BaseStorage.TransactionRecord):
while oid_index < oid_len: while oid_index < oid_len:
oid = oid_list[oid_index] oid = oid_list[oid_index]
try: try:
data, _, next_tid = app._load(oid, serial=self.tid) data, _, next_tid = app.load(None, oid, serial=self.tid)
except NEOStorageCreationUndoneError: except NEOStorageCreationUndoneError:
data = next_tid = None data = next_tid = None
except NEOStorageNotFoundError: except NEOStorageNotFoundError:
......
...@@ -1507,12 +1507,12 @@ class AskObjectUndoSerial(Packet): ...@@ -1507,12 +1507,12 @@ class AskObjectUndoSerial(Packet):
for a list of OIDs. for a list of OIDs.
C -> S C -> S
""" """
_header_format = '!8s8sL' _header_format = '!8s8s8sL'
def _encode(self, tid, undone_tid, oid_list): def _encode(self, tid, ltid, undone_tid, oid_list):
body = StringIO() body = StringIO()
write = body.write write = body.write
write(pack(self._header_format, tid, undone_tid, len(oid_list))) write(pack(self._header_format, tid, ltid, undone_tid, len(oid_list)))
for oid in oid_list: for oid in oid_list:
write(oid) write(oid)
return body.getvalue() return body.getvalue()
...@@ -1520,10 +1520,10 @@ class AskObjectUndoSerial(Packet): ...@@ -1520,10 +1520,10 @@ class AskObjectUndoSerial(Packet):
def _decode(self, body): def _decode(self, body):
body = StringIO(body) body = StringIO(body)
read = body.read read = body.read
tid, undone_tid, oid_list_len = unpack(self._header_format, tid, ltid, undone_tid, oid_list_len = unpack(self._header_format,
read(self._header_len)) read(self._header_len))
oid_list = [read(8) for _ in xrange(oid_list_len)] oid_list = [read(8) for _ in xrange(oid_list_len)]
return tid, undone_tid, oid_list return tid, ltid, undone_tid, oid_list
class AnswerObjectUndoSerial(Packet): class AnswerObjectUndoSerial(Packet):
""" """
......
...@@ -324,12 +324,15 @@ class DatabaseManager(object): ...@@ -324,12 +324,15 @@ class DatabaseManager(object):
""" """
raise NotImplementedError raise NotImplementedError
def findUndoTID(self, oid, tid, undone_tid, transaction_object): def findUndoTID(self, oid, tid, ltid, undone_tid, transaction_object):
""" """
oid oid
Object OID Object OID
tid tid
Transation doing the undo Transation doing the undo
ltid
Upper (exclued) bound of transactions visible to transaction doing
the undo.
undone_tid undone_tid
Transaction to undo Transaction to undo
transaction_object transaction_object
...@@ -355,16 +358,17 @@ class DatabaseManager(object): ...@@ -355,16 +358,17 @@ class DatabaseManager(object):
p64 = util.p64 p64 = util.p64
oid = u64(oid) oid = u64(oid)
tid = u64(tid) tid = u64(tid)
ltid = u64(ltid)
undone_tid = u64(undone_tid) undone_tid = u64(undone_tid)
_getDataTID = self._getDataTID _getDataTID = self._getDataTID
if transaction_object is not None: if transaction_object is not None:
toid, tcompression, tchecksum, tdata, tvalue_serial = \ toid, tcompression, tchecksum, tdata, tvalue_serial = \
transaction_object transaction_object
current_tid, current_data_tid = self._getDataTIDFromData(oid, current_tid, current_data_tid = self._getDataTIDFromData(oid,
(tid, None, tcompression, tchecksum, tdata, (ltid, None, tcompression, tchecksum, tdata,
u64(tvalue_serial))) u64(tvalue_serial)))
else: else:
current_tid, current_data_tid = _getDataTID(oid, before_tid=tid) current_tid, current_data_tid = _getDataTID(oid, before_tid=ltid)
if current_tid is None: if current_tid is None:
return (None, None, False) return (None, None, False)
found_undone_tid, undone_data_tid = _getDataTID(oid, tid=undone_tid) found_undone_tid, undone_data_tid = _getDataTID(oid, tid=undone_tid)
......
...@@ -102,14 +102,14 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -102,14 +102,14 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
app.pt.getPartitions(), partition_list) app.pt.getPartitions(), partition_list)
conn.answer(Packets.AnswerTIDs(tid_list)) conn.answer(Packets.AnswerTIDs(tid_list))
def askObjectUndoSerial(self, conn, tid, undone_tid, oid_list): def askObjectUndoSerial(self, conn, tid, ltid, undone_tid, oid_list):
app = self.app app = self.app
findUndoTID = app.dm.findUndoTID findUndoTID = app.dm.findUndoTID
getObjectFromTransaction = app.tm.getObjectFromTransaction getObjectFromTransaction = app.tm.getObjectFromTransaction
object_tid_dict = {} object_tid_dict = {}
for oid in oid_list: for oid in oid_list:
current_serial, undo_serial, is_current = findUndoTID(oid, tid, current_serial, undo_serial, is_current = findUndoTID(oid, tid,
undone_tid, getObjectFromTransaction(tid, oid)) ltid, undone_tid, getObjectFromTransaction(tid, oid))
if current_serial is None: if current_serial is None:
p = Errors.OidNotFound(dump(oid)) p = Errors.OidNotFound(dump(oid))
break break
......
...@@ -206,6 +206,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -206,6 +206,7 @@ class ClientApplicationTests(NeoUnitTestBase):
oid = self.makeOID() oid = self.makeOID()
tid1 = self.makeTID(1) tid1 = self.makeTID(1)
tid2 = self.makeTID(2) tid2 = self.makeTID(2)
snapshot_tid = self.makeTID(3)
an_object = (1, oid, tid1, tid2, 0, makeChecksum('OBJ'), 'OBJ', None) an_object = (1, oid, tid1, tid2, 0, makeChecksum('OBJ'), 'OBJ', None)
# connection to SN close # connection to SN close
self.assertTrue((oid, tid1) not in mq) self.assertTrue((oid, tid1) not in mq)
...@@ -221,7 +222,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -221,7 +222,7 @@ class ClientApplicationTests(NeoUnitTestBase):
app.pt = Mock({ 'getCellListForOID': [cell, ], }) app.pt = Mock({ 'getCellListForOID': [cell, ], })
app.cp = Mock({ 'getConnForCell' : conn}) app.cp = Mock({ 'getConnForCell' : conn})
Application._waitMessage = self._waitMessage Application._waitMessage = self._waitMessage
self.assertRaises(NEOStorageError, app.load, oid) self.assertRaises(NEOStorageError, app.load, snapshot_tid, oid)
self.checkAskObject(conn) self.checkAskObject(conn)
Application._waitMessage = _waitMessage Application._waitMessage = _waitMessage
# object not found in NEO -> NEOStorageNotFoundError # object not found in NEO -> NEOStorageNotFoundError
...@@ -236,7 +237,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -236,7 +237,7 @@ class ClientApplicationTests(NeoUnitTestBase):
}) })
app.pt = Mock({ 'getCellListForOID': [cell, ], }) app.pt = Mock({ 'getCellListForOID': [cell, ], })
app.cp = Mock({ 'getConnForCell' : conn}) app.cp = Mock({ 'getConnForCell' : conn})
self.assertRaises(NEOStorageNotFoundError, app.load, oid) self.assertRaises(NEOStorageNotFoundError, app.load, snapshot_tid, oid)
self.checkAskObject(conn) self.checkAskObject(conn)
# object found on storage nodes and put in cache # object found on storage nodes and put in cache
packet = Packets.AnswerObject(*an_object[1:]) packet = Packets.AnswerObject(*an_object[1:])
...@@ -253,7 +254,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -253,7 +254,7 @@ class ClientApplicationTests(NeoUnitTestBase):
'getNextId': 1, 'getNextId': 1,
'fakeReceived': answer_barrier, 'fakeReceived': answer_barrier,
}) })
result = app.load(oid) result = app.load(snapshot_tid, oid)[:2]
self.assertEquals(result, ('OBJ', tid1)) self.assertEquals(result, ('OBJ', tid1))
self.checkAskObject(conn) self.checkAskObject(conn)
self.assertTrue((oid, tid1) in mq) self.assertTrue((oid, tid1) in mq)
...@@ -262,7 +263,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -262,7 +263,7 @@ class ClientApplicationTests(NeoUnitTestBase):
'getAddress': ('127.0.0.1', 0), 'getAddress': ('127.0.0.1', 0),
}) })
app.cp = Mock({ 'getConnForCell' : conn}) app.cp = Mock({ 'getConnForCell' : conn})
result = app.load(oid) result = app.load(snapshot_tid, oid)[:2]
self.assertEquals(result, ('OBJ', tid1)) self.assertEquals(result, ('OBJ', tid1))
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
...@@ -273,6 +274,9 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -273,6 +274,9 @@ class ClientApplicationTests(NeoUnitTestBase):
oid = self.makeOID() oid = self.makeOID()
tid1 = self.makeTID(1) tid1 = self.makeTID(1)
tid2 = self.makeTID(2) tid2 = self.makeTID(2)
snapshot_tid = self.makeTID(3)
def loadSerial(oid, serial):
return app.load(snapshot_tid, oid, serial=serial)[0]
# object not found in NEO -> NEOStorageNotFoundError # object not found in NEO -> NEOStorageNotFoundError
self.assertTrue((oid, tid1) not in mq) self.assertTrue((oid, tid1) not in mq)
self.assertTrue((oid, tid2) not in mq) self.assertTrue((oid, tid2) not in mq)
...@@ -285,7 +289,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -285,7 +289,7 @@ class ClientApplicationTests(NeoUnitTestBase):
}) })
app.pt = Mock({ 'getCellListForOID': [cell, ], }) app.pt = Mock({ 'getCellListForOID': [cell, ], })
app.cp = Mock({ 'getConnForCell' : conn}) app.cp = Mock({ 'getConnForCell' : conn})
self.assertRaises(NEOStorageNotFoundError, app.loadSerial, oid, tid2) self.assertRaises(NEOStorageNotFoundError, loadSerial, oid, tid2)
self.checkAskObject(conn) self.checkAskObject(conn)
# object should not have been cached # object should not have been cached
self.assertFalse((oid, tid2) in mq) self.assertFalse((oid, tid2) in mq)
...@@ -302,7 +306,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -302,7 +306,7 @@ class ClientApplicationTests(NeoUnitTestBase):
}) })
app.cp = Mock({ 'getConnForCell' : conn}) app.cp = Mock({ 'getConnForCell' : conn})
app.local_var.asked_object = another_object[:-1] app.local_var.asked_object = another_object[:-1]
result = app.loadSerial(oid, tid1) result = loadSerial(oid, tid1)
self.assertEquals(result, 'RIGHT') self.assertEquals(result, 'RIGHT')
self.checkAskObject(conn) self.checkAskObject(conn)
self.assertTrue((oid, tid2) in mq) self.assertTrue((oid, tid2) in mq)
...@@ -315,6 +319,9 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -315,6 +319,9 @@ class ClientApplicationTests(NeoUnitTestBase):
tid1 = self.makeTID(1) tid1 = self.makeTID(1)
tid2 = self.makeTID(2) tid2 = self.makeTID(2)
tid3 = self.makeTID(3) tid3 = self.makeTID(3)
snapshot_tid = self.makeTID(4)
def loadBefore(oid, tid):
return app.load(snapshot_tid, oid, tid=tid)
# object not found in NEO -> NEOStorageDoesNotExistError # object not found in NEO -> NEOStorageDoesNotExistError
self.assertTrue((oid, tid1) not in mq) self.assertTrue((oid, tid1) not in mq)
self.assertTrue((oid, tid2) not in mq) self.assertTrue((oid, tid2) not in mq)
...@@ -327,7 +334,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -327,7 +334,7 @@ class ClientApplicationTests(NeoUnitTestBase):
}) })
app.pt = Mock({ 'getCellListForOID': [cell, ], }) app.pt = Mock({ 'getCellListForOID': [cell, ], })
app.cp = Mock({ 'getConnForCell' : conn}) app.cp = Mock({ 'getConnForCell' : conn})
self.assertRaises(NEOStorageDoesNotExistError, app.loadBefore, oid, tid2) self.assertRaises(NEOStorageDoesNotExistError, loadBefore, oid, tid2)
self.checkAskObject(conn) self.checkAskObject(conn)
# no visible version -> NEOStorageNotFoundError # no visible version -> NEOStorageNotFoundError
an_object = (1, oid, INVALID_SERIAL, None, 0, 0, '', None) an_object = (1, oid, INVALID_SERIAL, None, 0, 0, '', None)
...@@ -339,7 +346,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -339,7 +346,7 @@ class ClientApplicationTests(NeoUnitTestBase):
}) })
app.cp = Mock({ 'getConnForCell' : conn}) app.cp = Mock({ 'getConnForCell' : conn})
app.local_var.asked_object = an_object[:-1] app.local_var.asked_object = an_object[:-1]
self.assertRaises(NEOStorageError, app.loadBefore, oid, tid1) self.assertRaises(NEOStorageError, loadBefore, oid, tid1)
# object should not have been cached # object should not have been cached
self.assertFalse((oid, tid1) in mq) self.assertFalse((oid, tid1) in mq)
# as for loadSerial, the object is cached but should be loaded from db # as for loadSerial, the object is cached but should be loaded from db
...@@ -356,7 +363,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -356,7 +363,7 @@ class ClientApplicationTests(NeoUnitTestBase):
}) })
app.cp = Mock({ 'getConnForCell' : conn}) app.cp = Mock({ 'getConnForCell' : conn})
app.local_var.asked_object = another_object app.local_var.asked_object = another_object
result = app.loadBefore(oid, tid3) result = loadBefore(oid, tid3)
self.assertEquals(result, ('RIGHT', tid2, tid3)) self.assertEquals(result, ('RIGHT', tid2, tid3))
self.checkAskObject(conn) self.checkAskObject(conn)
self.assertTrue((oid, tid1) in mq) self.assertTrue((oid, tid1) in mq)
...@@ -765,6 +772,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -765,6 +772,7 @@ class ClientApplicationTests(NeoUnitTestBase):
# invalid transaction # invalid transaction
app = self.getApp() app = self.getApp()
tid = self.makeTID() tid = self.makeTID()
snapshot_tid = self.getNextTID()
txn = self.makeTransactionObject() txn = self.makeTransactionObject()
marker = [] marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data): def tryToResolveConflict(oid, conflict_serial, serial, data):
...@@ -774,8 +782,8 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -774,8 +782,8 @@ class ClientApplicationTests(NeoUnitTestBase):
self.assertFalse(app.local_var.txn is txn) self.assertFalse(app.local_var.txn is txn)
conn = Mock() conn = Mock()
cell = Mock() cell = Mock()
self.assertRaises(StorageTransactionError, app.undo, tid, txn, self.assertRaises(StorageTransactionError, app.undo, snapshot_tid, tid,
tryToResolveConflict) txn, tryToResolveConflict)
# no packet sent # no packet sent
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
self.checkNoPacketSent(app.master_conn) self.checkNoPacketSent(app.master_conn)
...@@ -808,10 +816,10 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -808,10 +816,10 @@ class ClientApplicationTests(NeoUnitTestBase):
def pending(self, queue): def pending(self, queue):
return not queue.empty() return not queue.empty()
app.dispatcher = Dispatcher() app.dispatcher = Dispatcher()
def loadSerial(oid, tid): def load(snapshot_tid, oid, serial):
self.assertEqual(oid, oid0) self.assertEqual(oid, oid0)
return {tid0: 'dummy', tid2: 'cdummy'}[tid] return ({tid0: 'dummy', tid2: 'cdummy'}[serial], None, None)
app.loadSerial = loadSerial app.load = load
store_marker = [] store_marker = []
def _store(oid, serial, data, data_serial=None): def _store(oid, serial, data, data_serial=None):
store_marker.append((oid, serial, data, data_serial)) store_marker.append((oid, serial, data, data_serial))
...@@ -832,6 +840,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -832,6 +840,7 @@ class ClientApplicationTests(NeoUnitTestBase):
tid1 = self.getNextTID() tid1 = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
tid3 = self.getNextTID() tid3 = self.getNextTID()
snapshot_tid = self.getNextTID()
app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1, app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1,
tid2) tid2)
undo_serial = Packets.AnswerObjectUndoSerial({ undo_serial = Packets.AnswerObjectUndoSerial({
...@@ -845,7 +854,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -845,7 +854,7 @@ class ClientApplicationTests(NeoUnitTestBase):
return 'solved' return 'solved'
# The undo # The undo
txn = self.beginTransaction(app, tid=tid3) txn = self.beginTransaction(app, tid=tid3)
app.undo(tid1, txn, tryToResolveConflict) app.undo(snapshot_tid, tid1, txn, tryToResolveConflict)
# Checking what happened # Checking what happened
moid, mconflict_serial, mserial, mdata, mcommittedData = marker[0] moid, mconflict_serial, mserial, mdata, mcommittedData = marker[0]
self.assertEqual(moid, oid0) self.assertEqual(moid, oid0)
...@@ -872,6 +881,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -872,6 +881,7 @@ class ClientApplicationTests(NeoUnitTestBase):
tid1 = self.getNextTID() tid1 = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
tid3 = self.getNextTID() tid3 = self.getNextTID()
snapshot_tid = self.getNextTID()
undo_serial = Packets.AnswerObjectUndoSerial({ undo_serial = Packets.AnswerObjectUndoSerial({
oid0: (tid2, tid0, False)}) oid0: (tid2, tid0, False)})
undo_serial.setId(2) undo_serial.setId(2)
...@@ -885,7 +895,8 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -885,7 +895,8 @@ class ClientApplicationTests(NeoUnitTestBase):
return None return None
# The undo # The undo
txn = self.beginTransaction(app, tid=tid3) txn = self.beginTransaction(app, tid=tid3)
self.assertRaises(UndoError, app.undo, tid1, txn, tryToResolveConflict) self.assertRaises(UndoError, app.undo, snapshot_tid, tid1, txn,
tryToResolveConflict)
# Checking what happened # Checking what happened
moid, mconflict_serial, mserial, mdata, mcommittedData = marker[0] moid, mconflict_serial, mserial, mdata, mcommittedData = marker[0]
self.assertEqual(moid, oid0) self.assertEqual(moid, oid0)
...@@ -903,7 +914,8 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -903,7 +914,8 @@ class ClientApplicationTests(NeoUnitTestBase):
raise ConflictError raise ConflictError
# The undo # The undo
app.local_var.queue.put((conn, undo_serial)) app.local_var.queue.put((conn, undo_serial))
self.assertRaises(UndoError, app.undo, tid1, txn, tryToResolveConflict) self.assertRaises(UndoError, app.undo, snapshot_tid, tid1, txn,
tryToResolveConflict)
# Checking what happened # Checking what happened
moid, mconflict_serial, mserial, mdata, mcommittedData = marker[0] moid, mconflict_serial, mserial, mdata, mcommittedData = marker[0]
self.assertEqual(moid, oid0) self.assertEqual(moid, oid0)
...@@ -925,6 +937,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -925,6 +937,7 @@ class ClientApplicationTests(NeoUnitTestBase):
tid1 = self.getNextTID() tid1 = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
tid3 = self.getNextTID() tid3 = self.getNextTID()
snapshot_tid = self.getNextTID()
transaction_info = Packets.AnswerTransactionInformation(tid1, '', '', transaction_info = Packets.AnswerTransactionInformation(tid1, '', '',
'', False, (oid0, )) '', False, (oid0, ))
transaction_info.setId(1) transaction_info.setId(1)
...@@ -940,7 +953,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -940,7 +953,7 @@ class ClientApplicationTests(NeoUnitTestBase):
'is no conflict in this test !' 'is no conflict in this test !'
# The undo # The undo
txn = self.beginTransaction(app, tid=tid3) txn = self.beginTransaction(app, tid=tid3)
app.undo(tid1, txn, tryToResolveConflict) app.undo(snapshot_tid, tid1, txn, tryToResolveConflict)
# Checking what happened # Checking what happened
moid, mserial, mdata, mdata_serial = store_marker[0] moid, mserial, mdata, mdata_serial = store_marker[0]
self.assertEqual(moid, oid0) self.assertEqual(moid, oid0)
......
...@@ -145,10 +145,14 @@ class ClientTests(NEOFunctionalTest): ...@@ -145,10 +145,14 @@ class ClientTests(NEOFunctionalTest):
""" Check transaction isolation within zope connection """ """ Check transaction isolation within zope connection """
self.__setup() self.__setup()
t, c = self.makeTransaction() t, c = self.makeTransaction()
c.root()['item'] = 0 root = c.root()
root['item'] = 0
root['other'] = 'bla'
t.commit() t.commit()
t1, c1 = self.makeTransaction() t1, c1 = self.makeTransaction()
t2, c2 = self.makeTransaction() t2, c2 = self.makeTransaction()
# Makes c2 take a snapshot of database state
c2.root()['other']
c1.root()['item'] = 1 c1.root()['item'] = 1
t1.commit() t1.commit()
# load objet from zope cache # load objet from zope cache
...@@ -159,10 +163,14 @@ class ClientTests(NEOFunctionalTest): ...@@ -159,10 +163,14 @@ class ClientTests(NEOFunctionalTest):
""" Check isolation with zope cache cleared """ """ Check isolation with zope cache cleared """
self.__setup() self.__setup()
t, c = self.makeTransaction() t, c = self.makeTransaction()
c.root()['item'] = 0 root = c.root()
root['item'] = 0
root['other'] = 'bla'
t.commit() t.commit()
t1, c1 = self.makeTransaction() t1, c1 = self.makeTransaction()
t2, c2 = self.makeTransaction() t2, c2 = self.makeTransaction()
# Makes c2 take a snapshot of database state
c2.root()['other']
c1.root()['item'] = 1 c1.root()['item'] = 1
t1.commit() t1.commit()
# clear zope cache to force re-ask NEO # clear zope cache to force re-ask NEO
......
...@@ -271,6 +271,7 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -271,6 +271,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
uuid = self.getNewUUID() uuid = self.getNewUUID()
conn = self._getConnection(uuid=uuid) conn = self._getConnection(uuid=uuid)
tid = self.getNextTID() tid = self.getNextTID()
ltid = self.getNextTID()
undone_tid = self.getNextTID() undone_tid = self.getNextTID()
# Keep 2 entries here, so we check findUndoTID is called only once. # Keep 2 entries here, so we check findUndoTID is called only once.
oid_list = [self.getOID(1), self.getOID(2)] oid_list = [self.getOID(1), self.getOID(2)]
...@@ -281,7 +282,7 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -281,7 +282,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
self.app.dm = Mock({ self.app.dm = Mock({
'findUndoTID': ReturnValues((None, None, False), ) 'findUndoTID': ReturnValues((None, None, False), )
}) })
self.operation.askObjectUndoSerial(conn, tid, undone_tid, oid_list) self.operation.askObjectUndoSerial(conn, tid, ltid, undone_tid, oid_list)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
def test_askHasLock(self): def test_askHasLock(self):
......
...@@ -689,6 +689,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -689,6 +689,7 @@ class StorageDBTests(NeoUnitTestBase):
tid2 = self.getNextTID() tid2 = self.getNextTID()
tid3 = self.getNextTID() tid3 = self.getNextTID()
tid4 = self.getNextTID() tid4 = self.getNextTID()
tid5 = self.getNextTID()
oid1 = self.getOID(1) oid1 = self.getOID(1)
db.storeTransaction( db.storeTransaction(
tid1, ( tid1, (
...@@ -699,7 +700,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -699,7 +700,7 @@ class StorageDBTests(NeoUnitTestBase):
# Result: current tid is tid1, data_tid is None (undoing object # Result: current tid is tid1, data_tid is None (undoing object
# creation) # creation)
self.assertEqual( self.assertEqual(
db.findUndoTID(oid1, tid4, tid1, None), db.findUndoTID(oid1, tid5, tid4, tid1, None),
(tid1, None, True)) (tid1, None, True))
# Store a new transaction # Store a new transaction
...@@ -711,13 +712,13 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -711,13 +712,13 @@ class StorageDBTests(NeoUnitTestBase):
# Undoing oid1 tid2, OK: tid2 is latest # Undoing oid1 tid2, OK: tid2 is latest
# Result: current tid is tid2, data_tid is tid1 # Result: current tid is tid2, data_tid is tid1
self.assertEqual( self.assertEqual(
db.findUndoTID(oid1, tid4, tid2, None), db.findUndoTID(oid1, tid5, tid4, tid2, None),
(tid2, tid1, True)) (tid2, tid1, True))
# Undoing oid1 tid1, Error: tid2 is latest # Undoing oid1 tid1, Error: tid2 is latest
# Result: current tid is tid2, data_tid is -1 # Result: current tid is tid2, data_tid is -1
self.assertEqual( self.assertEqual(
db.findUndoTID(oid1, tid4, tid1, None), db.findUndoTID(oid1, tid5, tid4, tid1, None),
(tid2, None, False)) (tid2, None, False))
# Undoing oid1 tid1 with tid2 being undone in same transaction, # Undoing oid1 tid1 with tid2 being undone in same transaction,
...@@ -727,7 +728,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -727,7 +728,7 @@ class StorageDBTests(NeoUnitTestBase):
# Explanation of transaction_object: oid1, no data but a data serial # Explanation of transaction_object: oid1, no data but a data serial
# to tid1 # to tid1
self.assertEqual( self.assertEqual(
db.findUndoTID(oid1, tid4, tid1, db.findUndoTID(oid1, tid5, tid4, tid1,
(u64(oid1), None, None, None, tid1)), (u64(oid1), None, None, None, tid1)),
(tid4, None, True)) (tid4, None, True))
...@@ -741,7 +742,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -741,7 +742,7 @@ class StorageDBTests(NeoUnitTestBase):
# Result: current tid is tid2, data_tid is None (undoing object # Result: current tid is tid2, data_tid is None (undoing object
# creation) # creation)
self.assertEqual( self.assertEqual(
db.findUndoTID(oid1, tid4, tid1, None), db.findUndoTID(oid1, tid5, tid4, tid1, None),
(tid3, None, True)) (tid3, None, True))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -466,11 +466,13 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -466,11 +466,13 @@ class ProtocolTests(NeoUnitTestBase):
def test_askObjectUndoSerial(self): def test_askObjectUndoSerial(self):
tid = self.getNextTID() tid = self.getNextTID()
ltid = self.getNextTID()
undone_tid = self.getNextTID() undone_tid = self.getNextTID()
oid_list = [self.getOID(x) for x in xrange(4)] oid_list = [self.getOID(x) for x in xrange(4)]
p = Packets.AskObjectUndoSerial(tid, undone_tid, oid_list) p = Packets.AskObjectUndoSerial(tid, ltid, undone_tid, oid_list)
ptid, pundone_tid, poid_list = p.decode() ptid, pltid, pundone_tid, poid_list = p.decode()
self.assertEqual(tid, ptid) self.assertEqual(tid, ptid)
self.assertEqual(ltid, pltid)
self.assertEqual(undone_tid, pundone_tid) self.assertEqual(undone_tid, pundone_tid)
self.assertEqual(oid_list, poid_list) self.assertEqual(oid_list, poid_list)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment