Commit d5c469be authored by Julien Muchembled's avatar Julien Muchembled

Fix protocol and DB schema so that storages can handle transactions of any size

- Change protocol to use SHA1 for all checksums:
  - Use SHA1 instead of CRC32 for data checksums.
  - Use SHA1 instead of MD5 for replication.

- Change DatabaseManager API so that backends can store raw data separately from
  object metadata:
  - When processing AskStoreObject, call the backend to store the data
    immediately, instead of keeping it in RAM or in the temporary object table.
    Data is then referenced only by its checksum.
    Without such change, the storage could fail to store the transaction due to
    lack of RAM, or it could make tpc_finish step very slow.
  - Backends have to store data in a separate space, and remove entries as soon
    as they get unreferenced. So they must have an index of checksums in object
    metadata space. A new '_uncommitted_data' backend attribute keeps references
    of uncommitted data.
  - New methods: _pruneData, _storeData, storeData, unlockData
  - MySQL: change vertical partitioning of 'obj' by having data in a separate
    'data' table instead of using a shortened 'obj_short' table.
  - BTree: data is moved from '_obj' to a new '_data' btree.

- Undo is optimized so that backpointers are not required anymore to fetch data:
  - The checksum of an object is None only when creation is undone.
  - Removed DatabaseManager methods: _getObjectData, _getDataTIDFromData
  - DatabaseManager: move some code from _getDataTID to findUndoTID so that
    _getDataTID only has what's specific to backend.

- Removed because already covered by ZODB tests:
  - neo.tests.storage.testStorageDBTests.StorageDBTests.test__getDataTID
  - neo.tests.storage.testStorageDBTests.StorageDBTests.test__getDataTIDFromData
parent d90c5b83
...@@ -4,6 +4,8 @@ Change History ...@@ -4,6 +4,8 @@ Change History
0.10 (unreleased) 0.10 (unreleased)
----------------- -----------------
- Storage was unable or slow to process large-sized transactions.
This required to change protocol and MySQL tables format.
- NEO learned to store empty values (although it's useless when managed by - NEO learned to store empty values (although it's useless when managed by
a ZODB Connection). a ZODB Connection).
......
...@@ -28,7 +28,8 @@ from ZODB.ConflictResolution import ResolvedSerial ...@@ -28,7 +28,8 @@ from ZODB.ConflictResolution import ResolvedSerial
from persistent.TimeStamp import TimeStamp from persistent.TimeStamp import TimeStamp
import neo.lib import neo.lib
from neo.lib.protocol import NodeTypes, Packets, INVALID_PARTITION, ZERO_TID from neo.lib.protocol import NodeTypes, Packets, \
INVALID_PARTITION, ZERO_HASH, ZERO_TID
from neo.lib.event import EventManager from neo.lib.event import EventManager
from neo.lib.util import makeChecksum as real_makeChecksum, dump from neo.lib.util import makeChecksum as real_makeChecksum, dump
from neo.lib.locking import Lock from neo.lib.locking import Lock
...@@ -444,7 +445,7 @@ class Application(object): ...@@ -444,7 +445,7 @@ class Application(object):
except ConnectionClosed: except ConnectionClosed:
continue continue
if data or checksum: if data or checksum != ZERO_HASH:
if checksum != makeChecksum(data): if checksum != makeChecksum(data):
neo.lib.logging.error('wrong checksum from %s for oid %s', neo.lib.logging.error('wrong checksum from %s for oid %s',
conn, dump(oid)) conn, dump(oid))
...@@ -509,7 +510,7 @@ class Application(object): ...@@ -509,7 +510,7 @@ class Application(object):
# an older object revision). # an older object revision).
compressed_data = '' compressed_data = ''
compression = 0 compression = 0
checksum = 0 checksum = ZERO_HASH
else: else:
assert data_serial is None assert data_serial is None
compression = self.compress compression = self.compress
......
...@@ -66,9 +66,6 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -66,9 +66,6 @@ class StorageAnswersHandler(AnswerBaseHandler):
def answerObject(self, conn, oid, start_serial, end_serial, def answerObject(self, conn, oid, start_serial, end_serial,
compression, checksum, data, data_serial): compression, checksum, data, data_serial):
if data_serial is not None:
raise NEOStorageError, 'Storage should never send non-None ' \
'data_serial to clients, got %s' % (dump(data_serial), )
self.app.setHandlerData((oid, start_serial, end_serial, self.app.setHandlerData((oid, start_serial, end_serial,
compression, checksum, data)) compression, checksum, data))
......
...@@ -112,6 +112,7 @@ INVALID_TID = '\xff' * 8 ...@@ -112,6 +112,7 @@ INVALID_TID = '\xff' * 8
INVALID_OID = '\xff' * 8 INVALID_OID = '\xff' * 8
INVALID_PARTITION = 0xffffffff INVALID_PARTITION = 0xffffffff
INVALID_ADDRESS_TYPE = socket.AF_UNSPEC INVALID_ADDRESS_TYPE = socket.AF_UNSPEC
ZERO_HASH = '\0' * 20
ZERO_TID = '\0' * 8 ZERO_TID = '\0' * 8
ZERO_OID = '\0' * 8 ZERO_OID = '\0' * 8
OID_LEN = len(INVALID_OID) OID_LEN = len(INVALID_OID)
...@@ -527,6 +528,17 @@ class PProtocol(PStructItem): ...@@ -527,6 +528,17 @@ class PProtocol(PStructItem):
raise ProtocolError('protocol version mismatch') raise ProtocolError('protocol version mismatch')
return (major, minor) return (major, minor)
class PChecksum(PItem):
"""
A hash (SHA1)
"""
def _encode(self, writer, checksum):
assert len(checksum) == 20, (len(checksum), checksum)
writer(checksum)
def _decode(self, reader):
return reader(20)
class PUUID(PItem): class PUUID(PItem):
""" """
An UUID (node identifier) An UUID (node identifier)
...@@ -561,7 +573,6 @@ class PTID(PItem): ...@@ -561,7 +573,6 @@ class PTID(PItem):
# same definition, for now # same definition, for now
POID = PTID POID = PTID
PChecksum = PUUID # (md5 is same length as uuid)
# common definitions # common definitions
...@@ -908,7 +919,7 @@ class StoreObject(Packet): ...@@ -908,7 +919,7 @@ class StoreObject(Packet):
POID('oid'), POID('oid'),
PTID('serial'), PTID('serial'),
PBoolean('compression'), PBoolean('compression'),
PNumber('checksum'), PChecksum('checksum'),
PString('data'), PString('data'),
PTID('data_serial'), PTID('data_serial'),
PTID('tid'), PTID('tid'),
...@@ -964,7 +975,7 @@ class GetObject(Packet): ...@@ -964,7 +975,7 @@ class GetObject(Packet):
PTID('serial_start'), PTID('serial_start'),
PTID('serial_end'), PTID('serial_end'),
PBoolean('compression'), PBoolean('compression'),
PNumber('checksum'), PChecksum('checksum'),
PString('data'), PString('data'),
PTID('data_serial'), PTID('data_serial'),
) )
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
import re import re
import socket import socket
from zlib import adler32 from hashlib import sha1
from Queue import deque from Queue import deque
from struct import pack, unpack from struct import pack, unpack
...@@ -62,8 +62,8 @@ def bin(s): ...@@ -62,8 +62,8 @@ def bin(s):
def makeChecksum(s): def makeChecksum(s):
"""Return a 4-byte integer checksum against a string.""" """Return a 20-byte checksum against a string."""
return adler32(s) & 0xffffffff return sha1(s).digest()
def resolve(hostname): def resolve(hostname):
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -21,7 +21,7 @@ from neo.lib.handler import EventHandler ...@@ -21,7 +21,7 @@ from neo.lib.handler import EventHandler
from neo.lib import protocol from neo.lib import protocol
from neo.lib.util import dump from neo.lib.util import dump
from neo.lib.exception import PrimaryFailure, OperationFailure from neo.lib.exception import PrimaryFailure, OperationFailure
from neo.lib.protocol import NodeStates, NodeTypes, Packets, Errors from neo.lib.protocol import NodeStates, NodeTypes, Packets, Errors, ZERO_HASH
class BaseMasterHandler(EventHandler): class BaseMasterHandler(EventHandler):
...@@ -97,7 +97,7 @@ class BaseClientAndStorageOperationHandler(EventHandler): ...@@ -97,7 +97,7 @@ class BaseClientAndStorageOperationHandler(EventHandler):
neo.lib.logging.debug('oid = %s, serial = %s, next_serial = %s', neo.lib.logging.debug('oid = %s, serial = %s, next_serial = %s',
dump(oid), dump(serial), dump(next_serial)) dump(oid), dump(serial), dump(next_serial))
if checksum is None: if checksum is None:
checksum = 0 checksum = ZERO_HASH
data = '' data = ''
p = Packets.AnswerObject(oid, serial, next_serial, p = Packets.AnswerObject(oid, serial, next_serial,
compression, checksum, data, data_serial) compression, checksum, data, data_serial)
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
import neo.lib import neo.lib
from neo.lib import protocol from neo.lib import protocol
from neo.lib.util import dump, makeChecksum from neo.lib.util import dump, makeChecksum
from neo.lib.protocol import Packets, LockState, Errors from neo.lib.protocol import Packets, LockState, Errors, ZERO_HASH
from neo.storage.handlers import BaseClientAndStorageOperationHandler from neo.storage.handlers import BaseClientAndStorageOperationHandler
from neo.storage.transactions import ConflictError, DelayedError from neo.storage.transactions import ConflictError, DelayedError
from neo.storage.exception import AlreadyPendingError from neo.storage.exception import AlreadyPendingError
...@@ -88,7 +88,7 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -88,7 +88,7 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
compression, checksum, data, data_serial, ttid, unlock): compression, checksum, data, data_serial, ttid, unlock):
# register the transaction # register the transaction
self.app.tm.register(conn.getUUID(), ttid) self.app.tm.register(conn.getUUID(), ttid)
if data or checksum: if data or checksum != ZERO_HASH:
# TODO: return an appropriate error packet # TODO: return an appropriate error packet
assert makeChecksum(data) == checksum assert makeChecksum(data) == checksum
assert data_serial is None assert data_serial is None
......
...@@ -20,7 +20,7 @@ from functools import wraps ...@@ -20,7 +20,7 @@ from functools import wraps
import neo.lib import neo.lib
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import Packets, ZERO_TID, ZERO_OID from neo.lib.protocol import Packets, ZERO_HASH, ZERO_TID, ZERO_OID
from neo.lib.util import add64, u64 from neo.lib.util import add64, u64
# TODO: benchmark how different values behave # TODO: benchmark how different values behave
...@@ -173,12 +173,14 @@ class ReplicationHandler(EventHandler): ...@@ -173,12 +173,14 @@ class ReplicationHandler(EventHandler):
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def answerObject(self, conn, oid, serial_start, def answerObject(self, conn, oid, serial_start,
serial_end, compression, checksum, data, data_serial): serial_end, compression, checksum, data, data_serial):
app = self.app dm = self.app.dm
if data or checksum != ZERO_HASH:
dm.storeData(checksum, data, compression)
else:
checksum = None
# Directly store the transaction. # Directly store the transaction.
obj = (oid, compression, checksum, data, data_serial) obj = oid, checksum, data_serial
app.dm.storeTransaction(serial_start, [obj], None, False) dm.storeTransaction(serial_start, [obj], None, False)
del obj
del data
def _doAskCheckSerialRange(self, min_oid, min_tid, max_tid, def _doAskCheckSerialRange(self, min_oid, min_tid, max_tid,
length=RANGE_LENGTH): length=RANGE_LENGTH):
......
...@@ -21,7 +21,10 @@ from neo.lib.protocol import Packets ...@@ -21,7 +21,10 @@ from neo.lib.protocol import Packets
class StorageOperationHandler(BaseClientAndStorageOperationHandler): class StorageOperationHandler(BaseClientAndStorageOperationHandler):
def _askObject(self, oid, serial, tid): def _askObject(self, oid, serial, tid):
return self.app.dm.getObject(oid, serial, tid, resolve_data=False) result = self.app.dm.getObject(oid, serial, tid)
if result and result[5]:
return result[:2] + (None, None, None) + result[4:]
return result
def askLastIDs(self, conn): def askLastIDs(self, conn):
app = self.app app = self.app
......
...@@ -98,22 +98,21 @@ class Transaction(object): ...@@ -98,22 +98,21 @@ class Transaction(object):
# assert self._transaction is not None # assert self._transaction is not None
self._transaction = (oid_list, user, desc, ext, packed) self._transaction = (oid_list, user, desc, ext, packed)
def addObject(self, oid, compression, checksum, data, value_serial): def addObject(self, oid, checksum, value_serial):
""" """
Add an object to the transaction Add an object to the transaction
""" """
assert oid not in self._checked_set, dump(oid) assert oid not in self._checked_set, dump(oid)
self._object_dict[oid] = (oid, compression, checksum, data, self._object_dict[oid] = oid, checksum, value_serial
value_serial)
def delObject(self, oid): def delObject(self, oid):
try: try:
del self._object_dict[oid] return self._object_dict.pop(oid)[1]
except KeyError: except KeyError:
self._checked_set.remove(oid) self._checked_set.remove(oid)
def getObject(self, oid): def getObject(self, oid):
return self._object_dict.get(oid) return self._object_dict[oid]
def getObjectList(self): def getObjectList(self):
return self._object_dict.values() return self._object_dict.values()
...@@ -163,10 +162,10 @@ class TransactionManager(object): ...@@ -163,10 +162,10 @@ class TransactionManager(object):
Return object data for given running transaction. Return object data for given running transaction.
Return None if not found. Return None if not found.
""" """
result = self._transaction_dict.get(ttid) try:
if result is not None: return self._transaction_dict[ttid].getObject(oid)
result = result.getObject(oid) except KeyError:
return result return None
def reset(self): def reset(self):
""" """
...@@ -242,7 +241,9 @@ class TransactionManager(object): ...@@ -242,7 +241,9 @@ class TransactionManager(object):
# drop the lock it held on this object, and drop object data for # drop the lock it held on this object, and drop object data for
# consistency. # consistency.
del self._store_lock_dict[oid] del self._store_lock_dict[oid]
self._transaction_dict[ttid].delObject(oid) checksum = self._transaction_dict[ttid].delObject(oid)
if checksum:
self._app.dm.pruneData((checksum,))
# Give a chance to pending events to take that lock now. # Give a chance to pending events to take that lock now.
self._app.executeQueuedEvents() self._app.executeQueuedEvents()
# Attemp to acquire lock again. # Attemp to acquire lock again.
...@@ -252,7 +253,7 @@ class TransactionManager(object): ...@@ -252,7 +253,7 @@ class TransactionManager(object):
elif locking_tid == ttid: elif locking_tid == ttid:
# If previous store was an undo, next store must be based on # If previous store was an undo, next store must be based on
# undo target. # undo target.
previous_serial = self._transaction_dict[ttid].getObject(oid)[4] previous_serial = self._transaction_dict[ttid].getObject(oid)[2]
if previous_serial is None: if previous_serial is None:
# XXX: use some special serial when previous store was not # XXX: use some special serial when previous store was not
# an undo ? Maybe it should just not happen. # an undo ? Maybe it should just not happen.
...@@ -301,8 +302,11 @@ class TransactionManager(object): ...@@ -301,8 +302,11 @@ class TransactionManager(object):
self.lockObject(ttid, serial, oid, unlock=unlock) self.lockObject(ttid, serial, oid, unlock=unlock)
# store object # store object
assert ttid in self, "Transaction not registered" assert ttid in self, "Transaction not registered"
transaction = self._transaction_dict[ttid] if data is None:
transaction.addObject(oid, compression, checksum, data, value_serial) checksum = None
else:
self._app.dm.storeData(checksum, data, compression)
self._transaction_dict[ttid].addObject(oid, checksum, value_serial)
def abort(self, ttid, even_if_locked=False): def abort(self, ttid, even_if_locked=False):
""" """
...@@ -320,8 +324,13 @@ class TransactionManager(object): ...@@ -320,8 +324,13 @@ class TransactionManager(object):
transaction = self._transaction_dict[ttid] transaction = self._transaction_dict[ttid]
has_load_lock = transaction.isLocked() has_load_lock = transaction.isLocked()
# if the transaction is locked, ensure we can drop it # if the transaction is locked, ensure we can drop it
if not even_if_locked and has_load_lock: if has_load_lock:
return if not even_if_locked:
return
else:
self._app.dm.unlockData([checksum
for oid, checksum, value_serial in transaction.getObjectList()
if checksum], True)
# unlock any object # unlock any object
for oid in transaction.getLockedOIDList(): for oid in transaction.getLockedOIDList():
if has_load_lock: if has_load_lock:
...@@ -370,19 +379,13 @@ class TransactionManager(object): ...@@ -370,19 +379,13 @@ class TransactionManager(object):
for oid, ttid in self._store_lock_dict.items(): for oid, ttid in self._store_lock_dict.items():
neo.lib.logging.info(' %r by %r', dump(oid), dump(ttid)) neo.lib.logging.info(' %r by %r', dump(oid), dump(ttid))
def updateObjectDataForPack(self, oid, orig_serial, new_serial, def updateObjectDataForPack(self, oid, orig_serial, new_serial, checksum):
getObjectData):
lock_tid = self.getLockingTID(oid) lock_tid = self.getLockingTID(oid)
if lock_tid is not None: if lock_tid is not None:
transaction = self._transaction_dict[lock_tid] transaction = self._transaction_dict[lock_tid]
oid, compression, checksum, data, value_serial = \ if transaction.getObject(oid)[2] == orig_serial:
transaction.getObject(oid)
if value_serial == orig_serial:
if new_serial: if new_serial:
value_serial = new_serial checksum = None
else: else:
compression, checksum, data = getObjectData() self._app.dm.storeData(checksum)
value_serial = None transaction.addObject(oid, checksum, new_serial)
transaction.addObject(oid, compression, checksum, data,
value_serial)
...@@ -88,10 +88,6 @@ class StorageAnswerHandlerTests(NeoUnitTestBase): ...@@ -88,10 +88,6 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
the_object = (oid, tid1, tid2, 0, '', 'DATA', None) the_object = (oid, tid1, tid2, 0, '', 'DATA', None)
self.handler.answerObject(conn, *the_object) self.handler.answerObject(conn, *the_object)
self._checkHandlerData(the_object[:-1]) self._checkHandlerData(the_object[:-1])
# Check handler raises on non-None data_serial.
the_object = (oid, tid1, tid2, 0, '', 'DATA', self.getNextTID())
self.assertRaises(NEOStorageError, self.handler.answerObject, conn,
*the_object)
def _getAnswerStoreObjectHandler(self, object_stored_counter_dict, def _getAnswerStoreObjectHandler(self, object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict): conflict_serial_dict, resolved_conflict_serial_dict):
......
...@@ -23,9 +23,8 @@ from neo.tests import NeoUnitTestBase ...@@ -23,9 +23,8 @@ from neo.tests import NeoUnitTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.transactions import ConflictError, DelayedError from neo.storage.transactions import ConflictError, DelayedError
from neo.storage.handlers.client import ClientOperationHandler from neo.storage.handlers.client import ClientOperationHandler
from neo.lib.protocol import INVALID_PARTITION from neo.lib.protocol import INVALID_PARTITION, INVALID_TID, INVALID_OID
from neo.lib.protocol import INVALID_TID, INVALID_OID from neo.lib.protocol import Packets, LockState, ZERO_HASH
from neo.lib.protocol import Packets, LockState
class StorageClientHandlerTests(NeoUnitTestBase): class StorageClientHandlerTests(NeoUnitTestBase):
...@@ -124,7 +123,8 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -124,7 +123,8 @@ class StorageClientHandlerTests(NeoUnitTestBase):
next_serial = self.getNextTID() next_serial = self.getNextTID()
oid = self.getOID(1) oid = self.getOID(1)
tid = self.getNextTID() tid = self.getNextTID()
self.app.dm = Mock({'getObject': (serial, next_serial, 0, 0, '', None)}) H = "0" * 20
self.app.dm = Mock({'getObject': (serial, next_serial, 0, H, '', None)})
conn = self._getConnection() conn = self._getConnection()
self.assertEqual(len(self.app.event_queue), 0) self.assertEqual(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=oid, serial=serial, tid=tid) self.operation.askObject(conn, oid=oid, serial=serial, tid=tid)
...@@ -239,7 +239,7 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -239,7 +239,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
tid = self.getNextTID() tid = self.getNextTID()
oid, serial, comp, checksum, data = self._getObject() oid, serial, comp, checksum, data = self._getObject()
data_tid = self.getNextTID() data_tid = self.getNextTID()
self.operation.askStoreObject(conn, oid, serial, comp, 0, self.operation.askStoreObject(conn, oid, serial, comp, ZERO_HASH,
'', data_tid, tid, False) '', data_tid, tid, False)
self._checkStoreObjectCalled(tid, serial, oid, comp, self._checkStoreObjectCalled(tid, serial, oid, comp,
None, None, data_tid, False) None, None, data_tid, False)
......
...@@ -128,8 +128,11 @@ class ReplicationTests(NeoUnitTestBase): ...@@ -128,8 +128,11 @@ class ReplicationTests(NeoUnitTestBase):
transaction = ([ZERO_OID], 'user', 'desc', '', False) transaction = ([ZERO_OID], 'user', 'desc', '', False)
storage.storeTransaction(makeid(tid), [], transaction, False) storage.storeTransaction(makeid(tid), [], transaction, False)
# store object history # store object history
H = "0" * 20
storage.storeData(H, '', 0)
storage.unlockData((H,))
for tid, oid_list in objects.iteritems(): for tid, oid_list in objects.iteritems():
object_list = [(makeid(oid), False, 0, '', None) for oid in oid_list] object_list = [(makeid(oid), H, None) for oid in oid_list]
storage.storeTransaction(makeid(tid), object_list, None, False) storage.storeTransaction(makeid(tid), object_list, None, False)
return storage return storage
......
...@@ -268,15 +268,15 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -268,15 +268,15 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
serial_start = self.getNextTID() serial_start = self.getNextTID()
serial_end = self.getNextTID() serial_end = self.getNextTID()
compression = 1 compression = 1
checksum = 2 checksum = "0" * 20
data = 'foo' data = 'foo'
data_serial = None data_serial = None
ReplicationHandler(app).answerObject(conn, oid, serial_start, ReplicationHandler(app).answerObject(conn, oid, serial_start,
serial_end, compression, checksum, data, data_serial) serial_end, compression, checksum, data, data_serial)
calls = app.dm.mockGetNamedCalls('storeTransaction') calls = app.dm.mockGetNamedCalls('storeTransaction')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(serial_start, [(oid, compression, checksum, data, calls[0].checkArgs(serial_start, [(oid, checksum, data_serial)],
data_serial)], None, False) None, False)
# CheckTIDRange # CheckTIDRange
def test_answerCheckTIDFullRangeIdenticalChunkWithNext(self): def test_answerCheckTIDFullRangeIdenticalChunkWithNext(self):
......
...@@ -121,7 +121,10 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -121,7 +121,10 @@ class StorageDBTests(NeoUnitTestBase):
def getTransaction(self, oid_list): def getTransaction(self, oid_list):
transaction = (oid_list, 'user', 'desc', 'ext', False) transaction = (oid_list, 'user', 'desc', 'ext', False)
object_list = [(oid, 1, 0, '', None) for oid in oid_list] H = "0" * 20
for _ in oid_list:
self.db.storeData(H, '', 1)
object_list = [(oid, H, None) for oid in oid_list]
return (transaction, object_list) return (transaction, object_list)
def checkSet(self, list1, list2): def checkSet(self, list1, list2):
...@@ -180,9 +183,9 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -180,9 +183,9 @@ class StorageDBTests(NeoUnitTestBase):
oid1, = self.getOIDs(1) oid1, = self.getOIDs(1)
tid1, tid2 = self.getTIDs(2) tid1, tid2 = self.getTIDs(2)
FOUND_BUT_NOT_VISIBLE = False FOUND_BUT_NOT_VISIBLE = False
OBJECT_T1_NO_NEXT = (tid1, None, 1, 0, '', None) OBJECT_T1_NO_NEXT = (tid1, None, 1, "0"*20, '', None)
OBJECT_T1_NEXT = (tid1, tid2, 1, 0, '', None) OBJECT_T1_NEXT = (tid1, tid2, 1, "0"*20, '', None)
OBJECT_T2 = (tid2, None, 1, 0, '', None) OBJECT_T2 = (tid2, None, 1, "0"*20, '', None)
txn1, objs1 = self.getTransaction([oid1]) txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid1]) txn2, objs2 = self.getTransaction([oid1])
# non-present # non-present
...@@ -277,14 +280,14 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -277,14 +280,14 @@ class StorageDBTests(NeoUnitTestBase):
self.db.storeTransaction(tid2, objs2, txn2) self.db.storeTransaction(tid2, objs2, txn2)
self.db.finishTransaction(tid1) self.db.finishTransaction(tid1)
result = self.db.getObject(oid1) result = self.db.getObject(oid1)
self.assertEqual(result, (tid1, None, 1, 0, '', None)) self.assertEqual(result, (tid1, None, 1, "0"*20, '', None))
self.assertEqual(self.db.getObject(oid2), None) self.assertEqual(self.db.getObject(oid2), None)
self.assertEqual(self.db.getUnfinishedTIDList(), [tid2]) self.assertEqual(self.db.getUnfinishedTIDList(), [tid2])
# drop it # drop it
self.db.dropUnfinishedData() self.db.dropUnfinishedData()
self.assertEqual(self.db.getUnfinishedTIDList(), []) self.assertEqual(self.db.getUnfinishedTIDList(), [])
result = self.db.getObject(oid1) result = self.db.getObject(oid1)
self.assertEqual(result, (tid1, None, 1, 0, '', None)) self.assertEqual(result, (tid1, None, 1, "0"*20, '', None))
self.assertEqual(self.db.getObject(oid2), None) self.assertEqual(self.db.getObject(oid2), None)
def test_storeTransaction(self): def test_storeTransaction(self):
...@@ -393,8 +396,8 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -393,8 +396,8 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(self.db.getObject(oid1, tid=tid2), None) self.assertEqual(self.db.getObject(oid1, tid=tid2), None)
self.db.deleteObject(oid2, serial=tid1) self.db.deleteObject(oid2, serial=tid1)
self.assertFalse(self.db.getObject(oid2, tid=tid1)) self.assertFalse(self.db.getObject(oid2, tid=tid1))
self.assertEqual(self.db.getObject(oid2, tid=tid2), (tid2, None) + \ self.assertEqual(self.db.getObject(oid2, tid=tid2),
objs2[1][1:]) (tid2, None, 1, "0" * 20, '', None))
def test_deleteObjectsAbove(self): def test_deleteObjectsAbove(self):
self.setNumPartitions(2) self.setNumPartitions(2)
...@@ -574,138 +577,6 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -574,138 +577,6 @@ class StorageDBTests(NeoUnitTestBase):
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 1, 2, 0) result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 1, 2, 0)
self.checkSet(result, [tid1]) self.checkSet(result, [tid1])
def test__getObjectData(self):
self.setNumPartitions(4, True)
db = self.db
tid0 = self.getNextTID()
tid1 = self.getNextTID()
tid2 = self.getNextTID()
tid3 = self.getNextTID()
assert tid0 < tid1 < tid2 < tid3
oid1 = self.getOID(1)
oid2 = self.getOID(2)
oid3 = self.getOID(3)
db.storeTransaction(
tid1, (
(oid1, 0, 0, 'foo', None),
(oid2, None, None, None, tid0),
(oid3, None, None, None, tid2),
), None, temporary=False)
db.storeTransaction(
tid2, (
(oid1, None, None, None, tid1),
(oid2, None, None, None, tid1),
(oid3, 0, 0, 'bar', None),
), None, temporary=False)
original_getObjectData = db._getObjectData
def _getObjectData(*args, **kw):
call_counter.append(1)
return original_getObjectData(*args, **kw)
db._getObjectData = _getObjectData
# NOTE: all tests are done as if values were fetched by _getObject, so
# there is already one indirection level.
# oid1 at tid1: data is immediately found
call_counter = []
self.assertEqual(
db._getObjectData(u64(oid1), u64(tid1), u64(tid3)),
(u64(tid1), 0, 0, 'foo'))
self.assertEqual(sum(call_counter), 1)
# oid2 at tid1: missing data in table, raise IndexError on next
# recursive call
call_counter = []
self.assertRaises(IndexError, db._getObjectData, u64(oid2), u64(tid1),
u64(tid3))
self.assertEqual(sum(call_counter), 2)
# oid3 at tid1: data_serial grater than row's tid, raise ValueError
# on next recursive call - even if data does exist at that tid (see
# "oid3 at tid2" case below)
call_counter = []
self.assertRaises(ValueError, db._getObjectData, u64(oid3), u64(tid1),
u64(tid3))
self.assertEqual(sum(call_counter), 2)
# Same with wrong parameters (tid0 < tid1)
call_counter = []
self.assertRaises(ValueError, db._getObjectData, u64(oid3), u64(tid1),
u64(tid0))
self.assertEqual(sum(call_counter), 1)
# Same with wrong parameters (tid1 == tid1)
call_counter = []
self.assertRaises(ValueError, db._getObjectData, u64(oid3), u64(tid1),
u64(tid1))
self.assertEqual(sum(call_counter), 1)
# oid1 at tid2: data is found after ons recursive call
call_counter = []
self.assertEqual(
db._getObjectData(u64(oid1), u64(tid2), u64(tid3)),
(u64(tid1), 0, 0, 'foo'))
self.assertEqual(sum(call_counter), 2)
# oid2 at tid2: missing data in table, raise IndexError after two
# recursive calls
call_counter = []
self.assertRaises(IndexError, db._getObjectData, u64(oid2), u64(tid2),
u64(tid3))
self.assertEqual(sum(call_counter), 3)
# oid3 at tid2: data is immediately found
call_counter = []
self.assertEqual(
db._getObjectData(u64(oid3), u64(tid2), u64(tid3)),
(u64(tid2), 0, 0, 'bar'))
self.assertEqual(sum(call_counter), 1)
def test__getDataTIDFromData(self):
self.setNumPartitions(4, True)
db = self.db
tid1 = self.getNextTID()
tid2 = self.getNextTID()
oid1 = self.getOID(1)
db.storeTransaction(
tid1, (
(oid1, 0, 0, 'foo', None),
), None, temporary=False)
db.storeTransaction(
tid2, (
(oid1, None, None, None, tid1),
), None, temporary=False)
self.assertEqual(
db._getDataTIDFromData(u64(oid1),
db._getObject(u64(oid1), tid=u64(tid1))),
(u64(tid1), u64(tid1)))
self.assertEqual(
db._getDataTIDFromData(u64(oid1),
db._getObject(u64(oid1), tid=u64(tid2))),
(u64(tid2), u64(tid1)))
def test__getDataTID(self):
self.setNumPartitions(4, True)
db = self.db
tid1 = self.getNextTID()
tid2 = self.getNextTID()
oid1 = self.getOID(1)
db.storeTransaction(
tid1, (
(oid1, 0, 0, 'foo', None),
), None, temporary=False)
db.storeTransaction(
tid2, (
(oid1, None, None, None, tid1),
), None, temporary=False)
self.assertEqual(
db._getDataTID(u64(oid1), tid=u64(tid1)),
(u64(tid1), u64(tid1)))
self.assertEqual(
db._getDataTID(u64(oid1), tid=u64(tid2)),
(u64(tid2), u64(tid1)))
def test_findUndoTID(self): def test_findUndoTID(self):
self.setNumPartitions(4, True) self.setNumPartitions(4, True)
db = self.db db = self.db
...@@ -715,9 +586,14 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -715,9 +586,14 @@ class StorageDBTests(NeoUnitTestBase):
tid4 = self.getNextTID() tid4 = self.getNextTID()
tid5 = self.getNextTID() tid5 = self.getNextTID()
oid1 = self.getOID(1) oid1 = self.getOID(1)
foo = "3" * 20
bar = "4" * 20
db.storeData(foo, 'foo', 0)
db.storeData(bar, 'bar', 0)
db.unlockData((foo, bar))
db.storeTransaction( db.storeTransaction(
tid1, ( tid1, (
(oid1, 0, 0, 'foo', None), (oid1, foo, None),
), None, temporary=False) ), None, temporary=False)
# Undoing oid1 tid1, OK: tid1 is latest # Undoing oid1 tid1, OK: tid1 is latest
...@@ -730,7 +606,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -730,7 +606,7 @@ class StorageDBTests(NeoUnitTestBase):
# Store a new transaction # Store a new transaction
db.storeTransaction( db.storeTransaction(
tid2, ( tid2, (
(oid1, 0, 0, 'bar', None), (oid1, bar, None),
), None, temporary=False) ), None, temporary=False)
# Undoing oid1 tid2, OK: tid2 is latest # Undoing oid1 tid2, OK: tid2 is latest
...@@ -753,13 +629,13 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -753,13 +629,13 @@ class StorageDBTests(NeoUnitTestBase):
# to tid1 # to tid1
self.assertEqual( self.assertEqual(
db.findUndoTID(oid1, tid5, tid4, tid1, db.findUndoTID(oid1, tid5, tid4, tid1,
(u64(oid1), None, None, None, tid1)), (u64(oid1), None, tid1)),
(tid1, None, True)) (tid1, None, True))
# Store a new transaction # Store a new transaction
db.storeTransaction( db.storeTransaction(
tid3, ( tid3, (
(oid1, None, None, None, tid1), (oid1, None, tid1),
), None, temporary=False) ), None, temporary=False)
# Undoing oid1 tid1, OK: tid3 is latest with tid1 data # Undoing oid1 tid1, OK: tid3 is latest with tid1 data
......
...@@ -97,7 +97,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase): ...@@ -97,7 +97,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
calls = self.app.dm.mockGetNamedCalls('getObject') calls = self.app.dm.mockGetNamedCalls('getObject')
self.assertEqual(len(self.app.event_queue), 0) self.assertEqual(len(self.app.event_queue), 0)
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(oid, serial, tid, resolve_data=False) calls[0].checkArgs(oid, serial, tid)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
def test_24_askObject3(self): def test_24_askObject3(self):
...@@ -105,8 +105,9 @@ class StorageStorageHandlerTests(NeoUnitTestBase): ...@@ -105,8 +105,9 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
tid = self.getNextTID() tid = self.getNextTID()
serial = self.getNextTID() serial = self.getNextTID()
next_serial = self.getNextTID() next_serial = self.getNextTID()
H = "0" * 20
# object found => answer # object found => answer
self.app.dm = Mock({'getObject': (serial, next_serial, 0, 0, '', None)}) self.app.dm = Mock({'getObject': (serial, next_serial, 0, H, '', None)})
conn = self.getFakeConnection() conn = self.getFakeConnection()
self.assertEqual(len(self.app.event_queue), 0) self.assertEqual(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=oid, serial=serial, tid=tid) self.operation.askObject(conn, oid=oid, serial=serial, tid=tid)
...@@ -149,7 +150,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase): ...@@ -149,7 +150,7 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
def test_askCheckTIDRange(self): def test_askCheckTIDRange(self):
count = 1 count = 1
tid_checksum = self.getNewUUID() tid_checksum = "1" * 20
min_tid = self.getNextTID() min_tid = self.getNextTID()
num_partitions = 4 num_partitions = 4
length = 5 length = 5
...@@ -173,12 +174,12 @@ class StorageStorageHandlerTests(NeoUnitTestBase): ...@@ -173,12 +174,12 @@ class StorageStorageHandlerTests(NeoUnitTestBase):
def test_askCheckSerialRange(self): def test_askCheckSerialRange(self):
count = 1 count = 1
oid_checksum = self.getNewUUID() oid_checksum = "2" * 20
min_oid = self.getOID(1) min_oid = self.getOID(1)
num_partitions = 4 num_partitions = 4
length = 5 length = 5
partition = 6 partition = 6
serial_checksum = self.getNewUUID() serial_checksum = "3" * 20
min_serial = self.getNextTID() min_serial = self.getNextTID()
max_serial = self.getNextTID() max_serial = self.getNextTID()
max_oid = self.getOID(2) max_oid = self.getOID(2)
......
...@@ -125,23 +125,6 @@ class StorageMySQSLdbTests(StorageDBTests): ...@@ -125,23 +125,6 @@ class StorageMySQSLdbTests(StorageDBTests):
self.assertEqual(self.db.escape('a"b'), 'a\\"b') self.assertEqual(self.db.escape('a"b'), 'a\\"b')
self.assertEqual(self.db.escape("a'b"), "a\\'b") self.assertEqual(self.db.escape("a'b"), "a\\'b")
def test_setup(self):
# XXX: this test verifies irrelevant symptoms. It should instead check that
# - setup, store, setup, load -> data still there
# - setup, store, setup(reset=True), load -> data not found
# Then, it should be moved to generic test class.
# create all tables
self.db.conn = Mock()
self.db.setup()
calls = self.db.conn.mockGetNamedCalls('query')
self.assertEqual(len(calls), 7)
# create all tables but drop them first
self.db.conn = Mock()
self.db.setup(reset=True)
calls = self.db.conn.mockGetNamedCalls('query')
self.assertEqual(len(calls), 8)
del StorageDBTests del StorageDBTests
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -63,8 +63,8 @@ class TransactionTests(NeoUnitTestBase): ...@@ -63,8 +63,8 @@ class TransactionTests(NeoUnitTestBase):
def testObjects(self): def testObjects(self):
txn = Transaction(self.getNewUUID(), self.getNextTID()) txn = Transaction(self.getNewUUID(), self.getNextTID())
oid1, oid2 = self.getOID(1), self.getOID(2) oid1, oid2 = self.getOID(1), self.getOID(2)
object1 = (oid1, 1, '1', 'O1', None) object1 = oid1, "0" * 20, None
object2 = (oid2, 1, '2', 'O2', None) object2 = oid2, "1" * 20, None
self.assertEqual(txn.getObjectList(), []) self.assertEqual(txn.getObjectList(), [])
self.assertEqual(txn.getOIDList(), []) self.assertEqual(txn.getOIDList(), [])
txn.addObject(*object1) txn.addObject(*object1)
...@@ -78,9 +78,9 @@ class TransactionTests(NeoUnitTestBase): ...@@ -78,9 +78,9 @@ class TransactionTests(NeoUnitTestBase):
oid_1 = self.getOID(1) oid_1 = self.getOID(1)
oid_2 = self.getOID(2) oid_2 = self.getOID(2)
txn = Transaction(self.getNewUUID(), self.getNextTID()) txn = Transaction(self.getNewUUID(), self.getNextTID())
object_info = (oid_1, None, None, None, None) object_info = oid_1, None, None
txn.addObject(*object_info) txn.addObject(*object_info)
self.assertEqual(txn.getObject(oid_2), None) self.assertRaises(KeyError, txn.getObject, oid_2)
self.assertEqual(txn.getObject(oid_1), object_info) self.assertEqual(txn.getObject(oid_1), object_info)
class TransactionManagerTests(NeoUnitTestBase): class TransactionManagerTests(NeoUnitTestBase):
...@@ -102,12 +102,12 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -102,12 +102,12 @@ class TransactionManagerTests(NeoUnitTestBase):
def _storeTransactionObjects(self, tid, txn): def _storeTransactionObjects(self, tid, txn):
for i, oid in enumerate(txn[0]): for i, oid in enumerate(txn[0]):
self.manager.storeObject(tid, None, self.manager.storeObject(tid, None,
oid, 1, str(i), '0' + str(i), None) oid, 1, '%020d' % i, '0' + str(i), None)
def _getObject(self, value): def _getObject(self, value):
oid = self.getOID(value) oid = self.getOID(value)
serial = self.getNextTID() serial = self.getNextTID()
return (serial, (oid, 1, str(value), 'O' + str(value), None)) return (serial, (oid, 1, '%020d' % value, 'O' + str(value), None))
def _checkTransactionStored(self, *args): def _checkTransactionStored(self, *args):
calls = self.app.dm.mockGetNamedCalls('storeTransaction') calls = self.app.dm.mockGetNamedCalls('storeTransaction')
...@@ -136,7 +136,10 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -136,7 +136,10 @@ class TransactionManagerTests(NeoUnitTestBase):
self.manager.storeObject(ttid, serial2, *object2) self.manager.storeObject(ttid, serial2, *object2)
self.assertTrue(ttid in self.manager) self.assertTrue(ttid in self.manager)
self.manager.lock(ttid, tid, txn[0]) self.manager.lock(ttid, tid, txn[0])
self._checkTransactionStored(tid, [object1, object2], txn) self._checkTransactionStored(tid, [
(object1[0], object1[2], object1[4]),
(object2[0], object2[2], object2[4]),
], txn)
self.manager.unlock(ttid) self.manager.unlock(ttid)
self.assertFalse(ttid in self.manager) self.assertFalse(ttid in self.manager)
self._checkTransactionFinished(tid) self._checkTransactionFinished(tid)
...@@ -340,7 +343,7 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -340,7 +343,7 @@ class TransactionManagerTests(NeoUnitTestBase):
self.assertEqual(self.manager.getObjectFromTransaction(tid1, obj2[0]), self.assertEqual(self.manager.getObjectFromTransaction(tid1, obj2[0]),
None) None)
self.assertEqual(self.manager.getObjectFromTransaction(tid1, obj1[0]), self.assertEqual(self.manager.getObjectFromTransaction(tid1, obj1[0]),
obj1) (obj1[0], obj1[2], obj1[4]))
def test_getLockingTID(self): def test_getLockingTID(self):
uuid = self.getNewUUID() uuid = self.getNewUUID()
...@@ -360,26 +363,24 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -360,26 +363,24 @@ class TransactionManagerTests(NeoUnitTestBase):
locking_serial = self.getNextTID() locking_serial = self.getNextTID()
other_serial = self.getNextTID() other_serial = self.getNextTID()
new_serial = self.getNextTID() new_serial = self.getNextTID()
compression = 1 checksum = "2" * 20
checksum = 42
value = 'foo'
self.manager.register(uuid, locking_serial) self.manager.register(uuid, locking_serial)
def getObjectData():
return (compression, checksum, value)
# Object not known, nothing happens # Object not known, nothing happens
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial, self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), None) oid), None)
self.manager.updateObjectDataForPack(oid, orig_serial, None, None) self.manager.updateObjectDataForPack(oid, orig_serial, None, checksum)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial, self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), None) oid), None)
self.manager.abort(locking_serial, even_if_locked=True) self.manager.abort(locking_serial, even_if_locked=True)
# Object known, but doesn't point at orig_serial, it is not updated # Object known, but doesn't point at orig_serial, it is not updated
self.manager.register(uuid, locking_serial) self.manager.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, 0, 512, self.manager.storeObject(locking_serial, ram_serial, oid, 0, "3" * 20,
'bar', None) 'bar', None)
storeData = self.app.dm.mockGetNamedCalls('storeData')
self.assertEqual(storeData.pop(0).params, ("3" * 20, 'bar', 0))
orig_object = self.manager.getObjectFromTransaction(locking_serial, orig_object = self.manager.getObjectFromTransaction(locking_serial,
oid) oid)
self.manager.updateObjectDataForPack(oid, orig_serial, None, None) self.manager.updateObjectDataForPack(oid, orig_serial, None, checksum)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial, self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), orig_object) oid), orig_object)
self.manager.abort(locking_serial, even_if_locked=True) self.manager.abort(locking_serial, even_if_locked=True)
...@@ -389,29 +390,29 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -389,29 +390,29 @@ class TransactionManagerTests(NeoUnitTestBase):
None, other_serial) None, other_serial)
orig_object = self.manager.getObjectFromTransaction(locking_serial, orig_object = self.manager.getObjectFromTransaction(locking_serial,
oid) oid)
self.manager.updateObjectDataForPack(oid, orig_serial, None, None) self.manager.updateObjectDataForPack(oid, orig_serial, None, checksum)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial, self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), orig_object) oid), orig_object)
self.manager.abort(locking_serial, even_if_locked=True) self.manager.abort(locking_serial, even_if_locked=True)
# Object known and points at undone data it gets updated # Object known and points at undone data it gets updated
# ...with data_serial: getObjectData must not be called
self.manager.register(uuid, locking_serial) self.manager.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, None, None, self.manager.storeObject(locking_serial, ram_serial, oid, None, None,
None, orig_serial) None, orig_serial)
self.manager.updateObjectDataForPack(oid, orig_serial, new_serial, self.manager.updateObjectDataForPack(oid, orig_serial, new_serial,
None) checksum)
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial, self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), (oid, None, None, None, new_serial)) oid), (oid, None, new_serial))
self.manager.abort(locking_serial, even_if_locked=True) self.manager.abort(locking_serial, even_if_locked=True)
# with data
self.manager.register(uuid, locking_serial) self.manager.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, None, None, self.manager.storeObject(locking_serial, ram_serial, oid, None, None,
None, orig_serial) None, orig_serial)
self.manager.updateObjectDataForPack(oid, orig_serial, None, self.manager.updateObjectDataForPack(oid, orig_serial, None, checksum)
getObjectData) self.assertEqual(storeData.pop(0).params, (checksum,))
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial, self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), (oid, compression, checksum, value, None)) oid), (oid, checksum, None))
self.manager.abort(locking_serial, even_if_locked=True) self.manager.abort(locking_serial, even_if_locked=True)
self.assertFalse(storeData)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -387,7 +387,8 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -387,7 +387,8 @@ class ProtocolTests(NeoUnitTestBase):
tid = self.getNextTID() tid = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
unlock = False unlock = False
p = Packets.AskStoreObject(oid, serial, 1, 55, "to", tid2, tid, unlock) H = "1" * 20
p = Packets.AskStoreObject(oid, serial, 1, H, "to", tid2, tid, unlock)
poid, pserial, compression, checksum, data, ptid2, ptid, punlock = \ poid, pserial, compression, checksum, data, ptid2, ptid, punlock = \
p.decode() p.decode()
self.assertEqual(oid, poid) self.assertEqual(oid, poid)
...@@ -395,7 +396,7 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -395,7 +396,7 @@ class ProtocolTests(NeoUnitTestBase):
self.assertEqual(tid, ptid) self.assertEqual(tid, ptid)
self.assertEqual(tid2, ptid2) self.assertEqual(tid2, ptid2)
self.assertEqual(compression, 1) self.assertEqual(compression, 1)
self.assertEqual(checksum, 55) self.assertEqual(checksum, H)
self.assertEqual(data, "to") self.assertEqual(data, "to")
self.assertEqual(unlock, punlock) self.assertEqual(unlock, punlock)
...@@ -423,7 +424,8 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -423,7 +424,8 @@ class ProtocolTests(NeoUnitTestBase):
serial_start = self.getNextTID() serial_start = self.getNextTID()
serial_end = self.getNextTID() serial_end = self.getNextTID()
data_serial = self.getNextTID() data_serial = self.getNextTID()
p = Packets.AnswerObject(oid, serial_start, serial_end, 1, 55, "to", H = "2" * 20
p = Packets.AnswerObject(oid, serial_start, serial_end, 1, H, "to",
data_serial) data_serial)
poid, pserial_start, pserial_end, compression, checksum, data, \ poid, pserial_start, pserial_end, compression, checksum, data, \
pdata_serial = p.decode() pdata_serial = p.decode()
...@@ -431,7 +433,7 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -431,7 +433,7 @@ class ProtocolTests(NeoUnitTestBase):
self.assertEqual(serial_start, pserial_start) self.assertEqual(serial_start, pserial_start)
self.assertEqual(serial_end, pserial_end) self.assertEqual(serial_end, pserial_end)
self.assertEqual(compression, 1) self.assertEqual(compression, 1)
self.assertEqual(checksum, 55) self.assertEqual(checksum, H)
self.assertEqual(data, "to") self.assertEqual(data, "to")
self.assertEqual(pdata_serial, data_serial) self.assertEqual(pdata_serial, data_serial)
...@@ -686,7 +688,7 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -686,7 +688,7 @@ class ProtocolTests(NeoUnitTestBase):
min_tid = self.getNextTID() min_tid = self.getNextTID()
length = 2 length = 2
count = 1 count = 1
tid_checksum = self.getNewUUID() tid_checksum = "3" * 20
max_tid = self.getNextTID() max_tid = self.getNextTID()
p = Packets.AnswerCheckTIDRange(min_tid, length, count, tid_checksum, p = Packets.AnswerCheckTIDRange(min_tid, length, count, tid_checksum,
max_tid) max_tid)
...@@ -717,9 +719,9 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -717,9 +719,9 @@ class ProtocolTests(NeoUnitTestBase):
min_serial = self.getNextTID() min_serial = self.getNextTID()
length = 2 length = 2
count = 1 count = 1
oid_checksum = self.getNewUUID() oid_checksum = "4" * 20
max_oid = self.getOID(5) max_oid = self.getOID(5)
tid_checksum = self.getNewUUID() tid_checksum = "5" * 20
max_serial = self.getNextTID() max_serial = self.getNextTID()
p = Packets.AnswerCheckSerialRange(min_oid, min_serial, length, count, p = Packets.AnswerCheckSerialRange(min_oid, min_serial, length, count,
oid_checksum, max_oid, tid_checksum, max_serial) oid_checksum, max_oid, tid_checksum, max_serial)
......
...@@ -259,6 +259,10 @@ class StorageApplication(ServerNode, neo.storage.app.Application): ...@@ -259,6 +259,10 @@ class StorageApplication(ServerNode, neo.storage.app.Application):
if adapter == 'BTree': if adapter == 'BTree':
dm._obj, dm._tobj = dm._tobj, dm._obj dm._obj, dm._tobj = dm._tobj, dm._obj
dm._trans, dm._ttrans = dm._ttrans, dm._trans dm._trans, dm._ttrans = dm._ttrans, dm._trans
uncommitted_data = dm._uncommitted_data
for checksum, (_, _, index) in dm._data.iteritems():
uncommitted_data[checksum] = len(index)
index.clear()
elif adapter == 'MySQL': elif adapter == 'MySQL':
q = dm.query q = dm.query
dm.begin() dm.begin()
...@@ -266,11 +270,22 @@ class StorageApplication(ServerNode, neo.storage.app.Application): ...@@ -266,11 +270,22 @@ class StorageApplication(ServerNode, neo.storage.app.Application):
q('RENAME TABLE %s to tmp' % table) q('RENAME TABLE %s to tmp' % table)
q('RENAME TABLE t%s to %s' % (table, table)) q('RENAME TABLE t%s to %s' % (table, table))
q('RENAME TABLE tmp to t%s' % table) q('RENAME TABLE tmp to t%s' % table)
q('TRUNCATE obj_short')
dm.commit() dm.commit()
else: else:
assert False assert False
def getDataLockInfo(self):
adapter = self._init_args[1]['getAdapter']
dm = self.dm
if adapter == 'BTree':
checksum_list = dm._data
elif adapter == 'MySQL':
checksum_list = [x for x, in dm.query("SELECT hash FROM data")]
else:
assert False
assert set(dm._uncommitted_data).issubset(checksum_list)
return dict((x, dm._uncommitted_data.get(x, 0)) for x in checksum_list)
class ClientApplication(Node, neo.client.app.Application): class ClientApplication(Node, neo.client.app.Application):
@SerializedEventManager.decorate @SerializedEventManager.decorate
......
...@@ -26,6 +26,7 @@ from neo.lib.connection import MTClientConnection ...@@ -26,6 +26,7 @@ from neo.lib.connection import MTClientConnection
from neo.lib.protocol import NodeStates, Packets, ZERO_TID from neo.lib.protocol import NodeStates, Packets, ZERO_TID
from neo.tests.threaded import NEOCluster, NEOThreadedTest, \ from neo.tests.threaded import NEOCluster, NEOThreadedTest, \
Patch, ConnectionFilter Patch, ConnectionFilter
from neo.lib.util import makeChecksum
from neo.client.pool import CELL_CONNECTED, CELL_GOOD from neo.client.pool import CELL_CONNECTED, CELL_GOOD
class PCounter(Persistent): class PCounter(Persistent):
...@@ -43,13 +44,19 @@ class Test(NEOThreadedTest): ...@@ -43,13 +44,19 @@ class Test(NEOThreadedTest):
try: try:
cluster.start() cluster.start()
storage = cluster.getZODBStorage() storage = cluster.getZODBStorage()
for data in 'foo', '': data_info = {}
for data in 'foo', '', 'foo':
checksum = makeChecksum(data)
oid = storage.new_oid() oid = storage.new_oid()
txn = transaction.Transaction() txn = transaction.Transaction()
storage.tpc_begin(txn) storage.tpc_begin(txn)
r1 = storage.store(oid, None, data, '', txn) r1 = storage.store(oid, None, data, '', txn)
r2 = storage.tpc_vote(txn) r2 = storage.tpc_vote(txn)
data_info[checksum] = 1
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
serial = storage.tpc_finish(txn) serial = storage.tpc_finish(txn)
data_info[checksum] = 0
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
self.assertEqual((data, serial), storage.load(oid, '')) self.assertEqual((data, serial), storage.load(oid, ''))
storage._cache.clear() storage._cache.clear()
self.assertEqual((data, serial), storage.load(oid, '')) self.assertEqual((data, serial), storage.load(oid, ''))
...@@ -57,6 +64,51 @@ class Test(NEOThreadedTest): ...@@ -57,6 +64,51 @@ class Test(NEOThreadedTest):
finally: finally:
cluster.stop() cluster.stop()
def testStorageDataLock(self):
cluster = NEOCluster()
try:
cluster.start()
storage = cluster.getZODBStorage()
data_info = {}
data = 'foo'
checksum = makeChecksum(data)
oid = storage.new_oid()
txn = transaction.Transaction()
storage.tpc_begin(txn)
r1 = storage.store(oid, None, data, '', txn)
r2 = storage.tpc_vote(txn)
tid = storage.tpc_finish(txn)
data_info[checksum] = 0
storage.sync()
txn = [transaction.Transaction() for x in xrange(3)]
for t in txn:
storage.tpc_begin(t)
storage.store(tid and oid or storage.new_oid(),
tid, data, '', t)
tid = None
for t in txn:
storage.tpc_vote(t)
data_info[checksum] = 3
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
storage.tpc_abort(txn[1])
storage.sync()
data_info[checksum] -= 1
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
tid1 = storage.tpc_finish(txn[2])
data_info[checksum] -= 1
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
storage.tpc_abort(txn[0])
storage.sync()
data_info[checksum] -= 1
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
finally:
cluster.stop()
def testDelayedUnlockInformation(self): def testDelayedUnlockInformation(self):
except_list = [] except_list = []
def delayUnlockInformation(conn, packet): def delayUnlockInformation(conn, packet):
...@@ -273,16 +325,21 @@ class Test(NEOThreadedTest): ...@@ -273,16 +325,21 @@ class Test(NEOThreadedTest):
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
c.root()[0] = 'ok' c.root()[0] = 'ok'
t.commit() t.commit()
data_info = cluster.storage.getDataLockInfo()
self.assertEqual(data_info.values(), [0, 0])
# (obj|trans) become t(obj|trans)
cluster.storage.switchTables()
finally: finally:
cluster.stop() cluster.stop()
cluster.reset() cluster.reset()
# XXX: (obj|trans) become t(obj|trans) self.assertEqual(dict.fromkeys(data_info, 1),
cluster.storage.switchTables() cluster.storage.getDataLockInfo())
try: try:
cluster.start(fast_startup=fast_startup) cluster.start(fast_startup=fast_startup)
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
# transaction should be verified and commited # transaction should be verified and commited
self.assertEqual(c.root()[0], 'ok') self.assertEqual(c.root()[0], 'ok')
self.assertEqual(data_info, cluster.storage.getDataLockInfo())
finally: finally:
cluster.stop() cluster.stop()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment