Commit 90fe4c8a authored by Grégory Wisniewski's avatar Grégory Wisniewski

Storage node check if all objects are stored before set write locks.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2102 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 2dc8f28a
......@@ -217,7 +217,7 @@ class EventHandler(object):
def answerTransactionFinished(self, conn, tid):
raise UnexpectedPacketError
def askLockInformation(self, conn, tid):
def askLockInformation(self, conn, tid, oid_list):
raise UnexpectedPacketError
def answerInformationLocked(self, conn, tid):
......
......@@ -75,7 +75,7 @@ class ClientServiceHandler(MasterHandler):
# Request locking data.
# build a new set as we may not send the message to all nodes as some
# might be not reachable at that time
p = Packets.AskLockInformation(tid)
p = Packets.AskLockInformation(tid, oid_list)
used_uuid_set = set()
for node in self.app.nm.getIdentifiedList(pool_set=uuid_set):
node.ask(p, timeout=60)
......
......@@ -783,12 +783,28 @@ class AskLockInformation(Packet):
"""
Lock information on a transaction. PM -> S.
"""
def _encode(self, tid):
return _encodeTID(tid)
# XXX: Identical to InvalidateObjects and AskFinishTransaction
_header_format = '!8sL'
_list_entry_format = '8s'
_list_entry_len = calcsize(_list_entry_format)
def _encode(self, tid, oid_list):
body = [pack(self._header_format, tid, len(oid_list))]
body.extend(oid_list)
return ''.join(body)
def _decode(self, body):
(tid, ) = unpack('8s', body)
return (_decodeTID(tid), )
offset = self._header_len
(tid, n) = unpack(self._header_format, body[:offset])
oid_list = []
list_entry_format = self._list_entry_format
list_entry_len = self._list_entry_len
for _ in xrange(n):
next_offset = offset + list_entry_len
oid = unpack(list_entry_format, body[offset:next_offset])[0]
offset = next_offset
oid_list.append(oid)
return (tid, oid_list)
class AnswerInformationLocked(Packet):
"""
......
......@@ -143,6 +143,13 @@ class PartitionTable(object):
return self.getCellList(self._getPartitionFromIndex(u64(oid)),
readable, writable)
def isAssigned(self, oid, uuid):
""" Check if the oid is assigned to the given node """
for cell in self.partition_list[u64(oid) % self.np]:
if cell.getUUID() == uuid:
return True
return False
def _getPartitionFromIndex(self, index):
return index % self.np
......
......@@ -16,7 +16,7 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from neo import logging
from neo.util import dump
from neo.protocol import CellStates, Packets, ProtocolError
from neo.storage.handlers import BaseMasterHandler
......@@ -52,10 +52,10 @@ class MasterOperationHandler(BaseMasterHandler):
elif state == CellStates.OUT_OF_DATE:
app.replicator.addPartition(offset)
def askLockInformation(self, conn, tid):
def askLockInformation(self, conn, tid, oid_list):
if not tid in self.app.tm:
raise ProtocolError('Unknown transaction')
self.app.tm.lock(tid)
self.app.tm.lock(tid, oid_list)
conn.answer(Packets.AnswerInformationLocked(tid))
def notifyUnlockInformation(self, conn, tid):
......
......@@ -154,7 +154,7 @@ class TransactionManager(object):
self._load_lock_dict.clear()
self._uuid_dict.clear()
def lock(self, tid):
def lock(self, tid, oid_list):
"""
Lock a transaction
"""
......@@ -163,6 +163,12 @@ class TransactionManager(object):
transaction.lock()
for oid in transaction.getOIDList():
self._load_lock_dict[oid] = tid
# check every object that should be locked
uuid = transaction.getUUID()
is_assigned = self._app.pt.isAssigned
for oid in oid_list:
if is_assigned(oid, uuid) and self._load_lock_dict.get(oid) != tid:
raise ValueError, 'Some locks are not held'
object_list = transaction.getObjectList()
# txn_info is None is the transaction information is not stored on
# this storage.
......
......@@ -132,19 +132,22 @@ class StorageMasterHandlerTests(NeoTestBase):
""" Unknown transaction """
self.app.tm = Mock({'__contains__': False})
conn = self._getConnection()
oid_list = [self.getOID(1), self.getOID(2)]
tid = self.getNextTID()
handler = self.operation
self.assertRaises(ProtocolError, handler.askLockInformation, conn, tid)
self.assertRaises(ProtocolError, handler.askLockInformation, conn, tid,
oid_list)
def test_askLockInformation2(self):
""" Lock transaction """
self.app.tm = Mock({'__contains__': True})
conn = self._getConnection()
tid = self.getNextTID()
self.operation.askLockInformation(conn, tid)
oid_list = [self.getOID(1), self.getOID(2)]
self.operation.askLockInformation(conn, tid, oid_list)
calls = self.app.tm.mockGetNamedCalls('lock')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid)
calls[0].checkArgs(tid, oid_list)
self.checkAnswerInformationLocked(conn)
def test_notifyUnlockInformation1(self):
......@@ -153,7 +156,7 @@ class StorageMasterHandlerTests(NeoTestBase):
conn = self._getConnection()
tid = self.getNextTID()
handler = self.operation
self.assertRaises(ProtocolError, handler.notifyUnlockInformation,
self.assertRaises(ProtocolError, handler.notifyUnlockInformation,
conn, tid)
def test_notifyUnlockInformation2(self):
......
......@@ -78,6 +78,7 @@ class TransactionManagerTests(NeoTestBase):
self.app = Mock()
# no history
self.app.dm = Mock({'getObjectHistory': []})
self.app.pt = Mock({'isAssigned': True})
self.manager = TransactionManager(self.app)
self.ltid = None
......@@ -86,6 +87,11 @@ class TransactionManagerTests(NeoTestBase):
oid_list = [self.getOID(1), self.getOID(2)]
return (tid, (oid_list, 'USER', 'DESC', 'EXT', False))
def _storeTransactionObjects(self, tid, txn):
for i, oid in enumerate(txn[0]):
self.manager.storeObject(self.getNewUUID(), tid, None,
oid, 1, str(i), '0' + str(i), None)
def _getObject(self, value):
oid = self.getOID(value)
serial = self.getNextTID()
......@@ -115,7 +121,7 @@ class TransactionManagerTests(NeoTestBase):
self.manager.storeObject(uuid, tid, serial1, *object1)
self.manager.storeObject(uuid, tid, serial2, *object2)
self.assertTrue(tid in self.manager)
self.manager.lock(tid)
self.manager.lock(tid, txn[0])
self._checkTransactionStored(tid, [object1, object2], txn)
self.manager.unlock(tid)
self.assertFalse(tid in self.manager)
......@@ -130,8 +136,8 @@ class TransactionManagerTests(NeoTestBase):
# first transaction lock the object
self.manager.storeTransaction(uuid, tid1, *txn1)
self.assertTrue(tid1 in self.manager)
self.manager.storeObject(uuid, tid1, serial, *obj)
self.manager.lock(tid1)
self._storeTransactionObjects(tid1, txn1)
self.manager.lock(tid1, txn1[0])
# the second is delayed
self.manager.storeTransaction(uuid, tid2, *txn2)
self.assertTrue(tid2 in self.manager)
......@@ -148,7 +154,8 @@ class TransactionManagerTests(NeoTestBase):
self.manager.storeTransaction(uuid, tid2, *txn2)
self.manager.storeObject(uuid, tid2, serial, *obj)
self.assertTrue(tid2 in self.manager)
self.manager.lock(tid2)
self._storeTransactionObjects(tid2, txn2)
self.manager.lock(tid2, txn2[0])
# the previous it's not using the latest version
self.manager.storeTransaction(uuid, tid1, *txn1)
self.assertTrue(tid1 in self.manager)
......@@ -167,8 +174,8 @@ class TransactionManagerTests(NeoTestBase):
self.assertRaises(ConflictError, self.manager.storeObject,
uuid, tid, serial, *obj)
def testConflictWithTwoNodes(self):
""" Ensure conflict/delaytion is working with different nodes"""
def testLockDelayed(self):
""" Check lock delaytion"""
uuid1 = self.getNewUUID()
uuid2 = self.getNewUUID()
self.assertNotEqual(uuid1, uuid2)
......@@ -176,25 +183,41 @@ class TransactionManagerTests(NeoTestBase):
tid2, txn2 = self._getTransaction()
serial1, obj1 = self._getObject(1)
serial2, obj2 = self._getObject(2)
# first transaction lock the object
# first transaction lock objects
self.manager.storeTransaction(uuid1, tid1, *txn1)
self.assertTrue(tid1 in self.manager)
self.manager.storeObject(uuid1, tid1, serial1, *obj1)
self.manager.lock(tid1)
self.manager.storeObject(uuid1, tid1, serial1, *obj2)
self.manager.lock(tid1, txn1[0])
# second transaction is delayed
self.manager.storeTransaction(uuid2, tid2, *txn2)
self.assertTrue(tid2 in self.manager)
self.assertRaises(DelayedError, self.manager.storeObject,
self.assertRaises(DelayedError, self.manager.storeObject,
uuid2, tid2, serial1, *obj1)
# the second transaction lock another object
self.assertRaises(DelayedError, self.manager.storeObject,
uuid2, tid2, serial2, *obj2)
def testLockConflict(self):
""" Check lock conflict """
uuid1 = self.getNewUUID()
uuid2 = self.getNewUUID()
self.assertNotEqual(uuid1, uuid2)
tid1, txn1 = self._getTransaction()
tid2, txn2 = self._getTransaction()
serial1, obj1 = self._getObject(1)
serial2, obj2 = self._getObject(2)
# the second transaction lock objects
self.manager.storeTransaction(uuid2, tid2, *txn2)
self.manager.storeObject(uuid2, tid2, serial1, *obj1)
self.manager.storeObject(uuid2, tid2, serial2, *obj2)
self.assertTrue(tid2 in self.manager)
self.manager.lock(tid2)
self.manager.lock(tid2, txn1[0])
# the first get a conflict
self.manager.storeTransaction(uuid1, tid1, *txn1)
self.assertTrue(tid1 in self.manager)
self.assertRaises(ConflictError, self.manager.storeObject,
self.assertRaises(ConflictError, self.manager.storeObject,
uuid1, tid1, serial1, *obj1)
self.assertRaises(ConflictError, self.manager.storeObject,
uuid1, tid1, serial2, *obj2)
def testAbortUnlocked(self):
......@@ -215,15 +238,15 @@ class TransactionManagerTests(NeoTestBase):
""" Try to abort a locked transaction """
uuid = self.getNewUUID()
tid, txn = self._getTransaction()
serial, obj = self._getObject(1)
self.manager.storeTransaction(uuid, tid, *txn)
self.manager.storeObject(uuid, tid, serial, *obj)
self._storeTransactionObjects(tid, txn)
# lock transaction
self.manager.lock(tid)
self.manager.lock(tid, txn[0])
self.assertTrue(tid in self.manager)
self.manager.abort(tid, even_if_locked=False)
self.assertTrue(tid in self.manager)
self.assertTrue(self.manager.loadLocked(obj[0]))
for oid in txn[0]:
self.assertTrue(self.manager.loadLocked(oid))
self._checkQueuedEventExecuted(number=0)
def testAbortForNode(self):
......@@ -238,7 +261,8 @@ class TransactionManagerTests(NeoTestBase):
# node 2 owns tid2 & tid3 and lock tid2 only
self.manager.storeTransaction(uuid2, tid2, *txn2)
self.manager.storeTransaction(uuid2, tid3, *txn3)
self.manager.lock(tid2)
self._storeTransactionObjects(tid2, txn2)
self.manager.lock(tid2, txn2[0])
self.assertTrue(tid1 in self.manager)
self.assertTrue(tid2 in self.manager)
self.assertTrue(tid3 in self.manager)
......@@ -253,14 +277,14 @@ class TransactionManagerTests(NeoTestBase):
""" Reset the manager """
uuid = self.getNewUUID()
tid, txn = self._getTransaction()
serial, obj = self._getObject(1)
self.manager.storeTransaction(uuid, tid, *txn)
self.manager.storeObject(uuid, tid, serial, *obj)
self.manager.lock(tid)
self._storeTransactionObjects(tid, txn)
self.manager.lock(tid, txn[0])
self.assertTrue(tid in self.manager)
self.manager.reset()
self.assertFalse(tid in self.manager)
self.assertFalse(self.manager.loadLocked(obj[0]))
for oid in txn[0]:
self.assertFalse(self.manager.loadLocked(oid))
def test_getObjectFromTransaction(self):
uuid = self.getNewUUID()
......
......@@ -290,10 +290,14 @@ class ProtocolTests(NeoTestBase):
self.assertEqual(ptid, tid)
def test_38_askLockInformation(self):
oid1 = self.getNextTID()
oid2 = self.getNextTID()
oid_list = [oid1, oid2]
tid = self.getNextTID()
p = Packets.AskLockInformation(tid)
ptid = p.decode()[0]
p = Packets.AskLockInformation(tid, oid_list)
ptid, p_oid_list = p.decode()
self.assertEqual(ptid, tid)
self.assertEqual(oid_list, p_oid_list)
def test_39_answerInformationLocked(self):
tid = self.getNextTID()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment