Commit 51b206b9 authored by Julien Muchembled's avatar Julien Muchembled

client: fix conflicts between AskTIDs(From) requests and reconnections

This fixes testStorageReconnectDuring{Transaction,Undo}Log unit tests.

The change in testStorageReconnectDuringTransactionLog fixes a bug in the test
itself. 'c.root()._p_serial' returned ZERO_TID.

Application.undo is also updated so that 'waitResponses' do not use
'setHandlerData' anymore.
parent e48201dd
...@@ -649,12 +649,11 @@ class Application(object): ...@@ -649,12 +649,11 @@ class Application(object):
return result return result
@profiler_decorator @profiler_decorator
def waitResponses(self, queue, handler_data): def waitResponses(self, queue):
"""Wait for all requests to be answered (or their connection to be """Wait for all requests to be answered (or their connection to be
detected as closed)""" detected as closed)"""
pending = self.dispatcher.pending pending = self.dispatcher.pending
_waitAnyMessage = self._waitAnyMessage _waitAnyMessage = self._waitAnyMessage
self.setHandlerData(handler_data)
while pending(queue): while pending(queue):
_waitAnyMessage(queue) _waitAnyMessage(queue)
...@@ -829,6 +828,7 @@ class Application(object): ...@@ -829,6 +828,7 @@ class Application(object):
getConnForCell = self.cp.getConnForCell getConnForCell = self.cp.getConnForCell
queue = self._getThreadQueue() queue = self._getThreadQueue()
ttid = txn_context['ttid'] ttid = txn_context['ttid']
undo_object_tid_dict = {}
for partition, oid_list in partition_oid_dict.iteritems(): for partition, oid_list in partition_oid_dict.iteritems():
cell_list = getCellList(partition, readable=True) cell_list = getCellList(partition, readable=True)
# We do want to shuffle before getting one with the smallest # We do want to shuffle before getting one with the smallest
...@@ -837,15 +837,15 @@ class Application(object): ...@@ -837,15 +837,15 @@ class Application(object):
shuffle(cell_list) shuffle(cell_list)
storage_conn = getConnForCell(min(cell_list, key=getCellSortKey)) storage_conn = getConnForCell(min(cell_list, key=getCellSortKey))
storage_conn.ask(Packets.AskObjectUndoSerial(ttid, storage_conn.ask(Packets.AskObjectUndoSerial(ttid,
snapshot_tid, undone_tid, oid_list), queue=queue) snapshot_tid, undone_tid, oid_list),
queue=queue, undo_object_tid_dict=undo_object_tid_dict)
# Wait for all AnswerObjectUndoSerial. We might get OidNotFoundError, # Wait for all AnswerObjectUndoSerial. We might get OidNotFoundError,
# meaning that objects in transaction's oid_list do not exist any # meaning that objects in transaction's oid_list do not exist any
# longer. This is the symptom of a pack, so forbid undoing transaction # longer. This is the symptom of a pack, so forbid undoing transaction
# when it happens. # when it happens.
undo_object_tid_dict = {}
try: try:
self.waitResponses(queue, undo_object_tid_dict) self.waitResponses(queue)
except NEOStorageNotFoundError: except NEOStorageNotFoundError:
self.dispatcher.forget_queue(queue) self.dispatcher.forget_queue(queue)
raise UndoError('non-undoable transaction') raise UndoError('non-undoable transaction')
...@@ -899,10 +899,6 @@ class Application(object): ...@@ -899,10 +899,6 @@ class Application(object):
raise NEOStorageError('Transaction %r not found' % (tid, )) raise NEOStorageError('Transaction %r not found' % (tid, ))
return (txn_info, txn_ext) return (txn_info, txn_ext)
# XXX: The following 2 methods fail when they reconnect to a storage after
# they already sent a request to a previous storage.
# See also testStorageReconnectDuringXxx
def undoLog(self, first, last, filter=None, block=0): def undoLog(self, first, last, filter=None, block=0):
# XXX: undoLog is broken # XXX: undoLog is broken
if last < 0: if last < 0:
...@@ -917,15 +913,15 @@ class Application(object): ...@@ -917,15 +913,15 @@ class Application(object):
queue = self._getThreadQueue() queue = self._getThreadQueue()
packet = Packets.AskTIDs(first, last, INVALID_PARTITION) packet = Packets.AskTIDs(first, last, INVALID_PARTITION)
tid_set = set()
for storage_node in storage_node_list: for storage_node in storage_node_list:
conn = self.cp.getConnForNode(storage_node) conn = self.cp.getConnForNode(storage_node)
if conn is None: if conn is None:
continue continue
conn.ask(packet, queue=queue) conn.ask(packet, queue=queue, tid_set=tid_set)
# Wait for answers from all storages. # Wait for answers from all storages.
tid_set = set() self.waitResponses(queue)
self.waitResponses(queue, tid_set)
# Reorder tids # Reorder tids
ordered_tids = sorted(tid_set, reverse=True) ordered_tids = sorted(tid_set, reverse=True)
...@@ -955,6 +951,7 @@ class Application(object): ...@@ -955,6 +951,7 @@ class Application(object):
node_list.sort(key=self.cp.getCellSortKey) node_list.sort(key=self.cp.getCellSortKey)
partition_set = set(range(self.pt.getPartitions())) partition_set = set(range(self.pt.getPartitions()))
queue = self._getThreadQueue() queue = self._getThreadQueue()
tid_set = set()
# request a tid list for each partition # request a tid list for each partition
for node in node_list: for node in node_list:
conn = self.cp.getConnForNode(node) conn = self.cp.getConnForNode(node)
...@@ -963,12 +960,11 @@ class Application(object): ...@@ -963,12 +960,11 @@ class Application(object):
continue continue
partition_set -= set(request_set) partition_set -= set(request_set)
packet = Packets.AskTIDsFrom(start, stop, limit, request_set) packet = Packets.AskTIDsFrom(start, stop, limit, request_set)
conn.ask(packet, queue=queue) conn.ask(packet, queue=queue, tid_set=tid_set)
if not partition_set: if not partition_set:
break break
assert not partition_set assert not partition_set
tid_set = set() self.waitResponses(queue)
self.waitResponses(queue, tid_set)
# request transactions informations # request transactions informations
txn_list = [] txn_list = []
append = txn_list.append append = txn_list.append
......
...@@ -100,11 +100,10 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -100,11 +100,10 @@ class StorageAnswersHandler(AnswerBaseHandler):
def answerStoreTransaction(self, conn, _): def answerStoreTransaction(self, conn, _):
pass pass
def answerTIDsFrom(self, conn, tid_list): def answerTIDsFrom(self, conn, tid_list, tid_set):
neo.lib.logging.debug('Get %d TIDs from %r', len(tid_list), conn) neo.lib.logging.debug('Get %d TIDs from %r', len(tid_list), conn)
tids_from = self.app.getHandlerData() assert not tid_set.intersection(tid_list)
assert not tids_from.intersection(set(tid_list)) tid_set.update(tid_list)
tids_from.update(tid_list)
def answerTransactionInformation(self, conn, tid, def answerTransactionInformation(self, conn, tid,
user, desc, ext, packed, oid_list): user, desc, ext, packed, oid_list):
...@@ -134,11 +133,12 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -134,11 +133,12 @@ class StorageAnswersHandler(AnswerBaseHandler):
# This can happen when requiring txn informations # This can happen when requiring txn informations
raise NEOStorageNotFoundError(message) raise NEOStorageNotFoundError(message)
def answerTIDs(self, conn, tid_list): def answerTIDs(self, conn, tid_list, tid_set):
self.app.getHandlerData().update(tid_list) tid_set.update(tid_list)
def answerObjectUndoSerial(self, conn, object_tid_dict): def answerObjectUndoSerial(self, conn, object_tid_dict,
self.app.getHandlerData().update(object_tid_dict) undo_object_tid_dict):
undo_object_tid_dict.update(object_tid_dict)
def answerHasLock(self, conn, oid, status): def answerHasLock(self, conn, oid, status):
store_msg_id = self.app.getHandlerData()['timeout_dict'].pop(oid) store_msg_id = self.app.getHandlerData()['timeout_dict'].pop(oid)
......
...@@ -28,6 +28,14 @@ from neo.lib.protocol import Packet, Packets, Errors, INVALID_TID, \ ...@@ -28,6 +28,14 @@ from neo.lib.protocol import Packet, Packets, Errors, INVALID_TID, \
from neo.lib.util import makeChecksum, SOCKET_CONNECTORS_DICT from neo.lib.util import makeChecksum, SOCKET_CONNECTORS_DICT
import time import time
class Dispatcher(object):
def pending(self, queue):
return not queue.empty()
def forget_queue(self, queue, flush_queue=True):
pass
def _getMasterConnection(self): def _getMasterConnection(self):
if self.master_conn is None: if self.master_conn is None:
self.uuid = 'C' * 16 self.uuid = 'C' * 16
...@@ -306,9 +314,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -306,9 +314,6 @@ class ClientApplicationTests(NeoUnitTestBase):
node, cell, conn = self.getNodeCellConn(address=storage_address) node, cell, conn = self.getNodeCellConn(address=storage_address)
app.pt = Mock({ 'getCellListForOID': (cell, cell)}) app.pt = Mock({ 'getCellListForOID': (cell, cell)})
app.cp = self.getConnectionPool([(node, conn)]) app.cp = self.getConnectionPool([(node, conn)])
class Dispatcher(object):
def pending(self, queue):
return not queue.empty()
app.dispatcher = Dispatcher() app.dispatcher = Dispatcher()
app.nm.createStorage(address=storage_address) app.nm.createStorage(address=storage_address)
data_dict = txn_context['data_dict'] data_dict = txn_context['data_dict']
...@@ -337,9 +342,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -337,9 +342,6 @@ class ClientApplicationTests(NeoUnitTestBase):
uuid=uuid) uuid=uuid)
app.cp = self.getConnectionPool([(node, conn)]) app.cp = self.getConnectionPool([(node, conn)])
app.pt = Mock({ 'getCellListForOID': (cell, cell, ) }) app.pt = Mock({ 'getCellListForOID': (cell, cell, ) })
class Dispatcher(object):
def pending(self, queue):
return not queue.empty()
app.dispatcher = Dispatcher() app.dispatcher = Dispatcher()
app.nm.createStorage(address=storage_address) app.nm.createStorage(address=storage_address)
app.store(oid, tid, 'DATA', None, txn) app.store(oid, tid, 'DATA', None, txn)
...@@ -470,12 +472,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -470,12 +472,6 @@ class ClientApplicationTests(NeoUnitTestBase):
app.master_conn = Mock({'__hash__': 0}) app.master_conn = Mock({'__hash__': 0})
txn = self.makeTransactionObject() txn = self.makeTransactionObject()
txn_context = self._begin(app, txn, tid) txn_context = self._begin(app, txn, tid)
class Dispatcher(object):
def pending(self, queue):
return not queue.empty()
def forget_queue(self, queue, flush_queue=True):
pass
app.dispatcher = Dispatcher() app.dispatcher = Dispatcher()
# conflict occurs on storage 2 # conflict occurs on storage 2
app.store(oid1, tid, 'DATA', None, txn) app.store(oid1, tid, 'DATA', None, txn)
...@@ -566,9 +562,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -566,9 +562,6 @@ class ClientApplicationTests(NeoUnitTestBase):
'iterateForObject': [(node, conn)], 'iterateForObject': [(node, conn)],
'getConnForCell': conn, 'getConnForCell': conn,
}) })
class Dispatcher(object):
def pending(self, queue):
return not queue.empty()
app.dispatcher = Dispatcher() app.dispatcher = Dispatcher()
def load(oid, tid=None, before_tid=None): def load(oid, tid=None, before_tid=None):
self.assertEqual(oid, oid0) self.assertEqual(oid, oid0)
...@@ -599,8 +592,10 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -599,8 +592,10 @@ class ClientApplicationTests(NeoUnitTestBase):
tid2) tid2)
undo_serial = Packets.AnswerObjectUndoSerial({ undo_serial = Packets.AnswerObjectUndoSerial({
oid0: (tid2, tid0, False)}) oid0: (tid2, tid0, False)})
conn.ask = lambda p, queue=None, **kw: \
isinstance(p, Packets.AskObjectUndoSerial) and \
queue.put((conn, undo_serial, kw))
undo_serial.setId(2) undo_serial.setId(2)
app._getThreadQueue().put((conn, undo_serial, {}))
marker = [] marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data, def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''): committedData=''):
...@@ -641,7 +636,9 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -641,7 +636,9 @@ class ClientApplicationTests(NeoUnitTestBase):
undo_serial.setId(2) undo_serial.setId(2)
app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1, app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1,
tid2) tid2)
app._getThreadQueue().put((conn, undo_serial, {})) conn.ask = lambda p, queue=None, **kw: \
type(p) is Packets.AskObjectUndoSerial and \
queue.put((conn, undo_serial, kw))
marker = [] marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data, def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''): committedData=''):
...@@ -667,7 +664,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -667,7 +664,6 @@ class ClientApplicationTests(NeoUnitTestBase):
marker.append((oid, conflict_serial, serial, data, committedData)) marker.append((oid, conflict_serial, serial, data, committedData))
raise ConflictError raise ConflictError
# The undo # The undo
app._getThreadQueue().put((conn, undo_serial, {}))
self.assertRaises(UndoError, app.undo, snapshot_tid, tid1, txn, self.assertRaises(UndoError, app.undo, snapshot_tid, tid1, txn,
tryToResolveConflict) tryToResolveConflict)
# Checking what happened # Checking what happened
...@@ -700,7 +696,9 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -700,7 +696,9 @@ class ClientApplicationTests(NeoUnitTestBase):
undo_serial.setId(2) undo_serial.setId(2)
app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1, app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1,
tid2) tid2)
app._getThreadQueue().put((conn, undo_serial, {})) conn.ask = lambda p, queue=None, **kw: \
type(p) is Packets.AskObjectUndoSerial and \
queue.put((conn, undo_serial, kw))
def tryToResolveConflict(oid, conflict_serial, serial, data, def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''): committedData=''):
raise Exception, 'Test called conflict resolution, but there ' \ raise Exception, 'Test called conflict resolution, but there ' \
...@@ -720,8 +718,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -720,8 +718,6 @@ class ClientApplicationTests(NeoUnitTestBase):
app.num_partitions = 2 app.num_partitions = 2
uuid1, uuid2 = '\x00' * 15 + '\x01', '\x00' * 15 + '\x02' uuid1, uuid2 = '\x00' * 15 + '\x01', '\x00' * 15 + '\x02'
# two nodes, two partition, two transaction, two objects : # two nodes, two partition, two transaction, two objects :
node1, node2 = Mock({}), Mock({})
cell1, cell2 = Mock({}), Mock({})
tid1, tid2 = self.makeTID(1), self.makeTID(2) tid1, tid2 = self.makeTID(1), self.makeTID(2)
oid1, oid2 = self.makeOID(1), self.makeOID(2) oid1, oid2 = self.makeOID(1), self.makeOID(2)
# TIDs packets supplied by _ask hook # TIDs packets supplied by _ask hook
...@@ -744,38 +740,40 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -744,38 +740,40 @@ class ClientApplicationTests(NeoUnitTestBase):
'fakeReceived': ReturnValues(p3, p4), 'fakeReceived': ReturnValues(p3, p4),
'getAddress': ('127.0.0.1', 10021), 'getAddress': ('127.0.0.1', 10021),
}) })
storage_1_conn = Mock() asked = []
storage_2_conn = Mock() def answerTIDs(packet):
conn = Mock({'getAddress': packet})
app.nm.createStorage(address=conn.getAddress())
def ask(p, queue, **kw):
asked.append(p)
queue.put((conn, packet, kw))
conn.ask = ask
return conn
app.dispatcher = Dispatcher()
app.pt = Mock({ app.pt = Mock({
'getNodeList': (node1, node2, ), 'getNodeList': (Mock(), Mock()),
'getCellListForTID': ReturnValues([cell1], [cell2]), 'getCellListForTID': ReturnValues([Mock()], [Mock()]),
}) })
app.cp = Mock({ app.cp = Mock({
'getConnForNode': ReturnValues(storage_1_conn, storage_2_conn), 'getConnForNode': ReturnValues(answerTIDs(p1), answerTIDs(p2)),
'iterateForObject': [(Mock(), conn)] 'iterateForObject': [(Mock(), conn)]
}) })
def waitResponses(queue, handler_data):
app.setHandlerData(handler_data)
for p in (p1, p2):
app._handlePacket(Mock(), p, handler=app.storage_handler)
app.waitResponses = waitResponses
def txn_filter(info): def txn_filter(info):
return info['id'] > '\x00' * 8 return info['id'] > '\x00' * 8
first = 0 first = 0
last = 4 last = 4
result = app.undoLog(first, last, filter=txn_filter) result = app.undoLog(first, last, filter=txn_filter)
pfirst, plast, ppartition = self.checkAskPacket(storage_1_conn, pfirst, plast, ppartition = asked.pop().decode()
Packets.AskTIDs, decode=True)
self.assertEqual(pfirst, first) self.assertEqual(pfirst, first)
self.assertEqual(plast, last) self.assertEqual(plast, last)
self.assertEqual(ppartition, INVALID_PARTITION) self.assertEqual(ppartition, INVALID_PARTITION)
pfirst, plast, ppartition = self.checkAskPacket(storage_2_conn, pfirst, plast, ppartition = asked.pop().decode()
Packets.AskTIDs, decode=True)
self.assertEqual(pfirst, first) self.assertEqual(pfirst, first)
self.assertEqual(plast, last) self.assertEqual(plast, last)
self.assertEqual(ppartition, INVALID_PARTITION) self.assertEqual(ppartition, INVALID_PARTITION)
self.assertEqual(result[0]['id'], tid1) self.assertEqual(result[0]['id'], tid1)
self.assertEqual(result[1]['id'], tid2) self.assertEqual(result[1]['id'], tid2)
self.assertFalse(asked)
def test_connectToPrimaryNode(self): def test_connectToPrimaryNode(self):
# here we have three master nodes : # here we have three master nodes :
......
...@@ -223,12 +223,7 @@ class StorageAnswerHandlerTests(NeoUnitTestBase): ...@@ -223,12 +223,7 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
tid_list = [tid1, tid2] tid_list = [tid1, tid2]
conn = self.getFakeConnection(uuid=uuid) conn = self.getFakeConnection(uuid=uuid)
tid_set = set() tid_set = set()
app = Mock({ StorageAnswersHandler(Mock()).answerTIDs(conn, tid_list, tid_set)
'getHandlerData': tid_set,
})
handler = StorageAnswersHandler(app)
handler.answerTIDs(conn, tid_list)
self.assertEqual(tid_set, set(tid_list)) self.assertEqual(tid_set, set(tid_list))
def test_answerObjectUndoSerial(self): def test_answerObjectUndoSerial(self):
...@@ -241,13 +236,10 @@ class StorageAnswerHandlerTests(NeoUnitTestBase): ...@@ -241,13 +236,10 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
tid2 = self.getNextTID() tid2 = self.getNextTID()
tid3 = self.getNextTID() tid3 = self.getNextTID()
undo_dict = {} undo_dict = {}
app = Mock({ handler = StorageAnswersHandler(Mock())
'getHandlerData': undo_dict, handler.answerObjectUndoSerial(conn, {oid1: [tid0, tid1]}, undo_dict)
})
handler = StorageAnswersHandler(app)
handler.answerObjectUndoSerial(conn, {oid1: [tid0, tid1]})
self.assertEqual(undo_dict, {oid1: [tid0, tid1]}) self.assertEqual(undo_dict, {oid1: [tid0, tid1]})
handler.answerObjectUndoSerial(conn, {oid2: [tid2, tid3]}) handler.answerObjectUndoSerial(conn, {oid2: [tid2, tid3]}, undo_dict)
self.assertEqual(undo_dict, { self.assertEqual(undo_dict, {
oid1: [tid0, tid1], oid1: [tid0, tid1],
oid2: [tid2, tid3], oid2: [tid2, tid3],
......
...@@ -353,10 +353,6 @@ class Test(NEOThreadedTest): ...@@ -353,10 +353,6 @@ class Test(NEOThreadedTest):
finally: finally:
cluster.stop() cluster.stop()
# The following 2 tests fail because the same queue is used for
# AskTIDs(From) requests and reconnections. The same bug affected
# history() before df47e5b1df8eabbff1383348b6b8c476bca0c328
def testStorageReconnectDuringTransactionLog(self): def testStorageReconnectDuringTransactionLog(self):
cluster = NEOCluster(storage_count=2, partitions=2) cluster = NEOCluster(storage_count=2, partitions=2)
try: try:
...@@ -365,7 +361,7 @@ class Test(NEOThreadedTest): ...@@ -365,7 +361,7 @@ class Test(NEOThreadedTest):
while cluster.client.cp.connection_dict: while cluster.client.cp.connection_dict:
cluster.client.cp._dropConnections() cluster.client.cp._dropConnections()
tid, (t1,) = cluster.client.transactionLog( tid, (t1,) = cluster.client.transactionLog(
ZERO_TID, c.root()._p_serial, 10) ZERO_TID, c.db().lastTransaction(), 10)
finally: finally:
cluster.stop() cluster.stop()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment