Commit 51b206b9 authored by Julien Muchembled

client: fix conflicts between AskTIDs(From) requests and reconnections

This fixes the testStorageReconnectDuring{Transaction,Undo}Log unit tests.

The change in testStorageReconnectDuringTransactionLog fixes a bug in the test
itself: 'c.root()._p_serial' returned ZERO_TID, so the test now uses
'c.db().lastTransaction()' instead.

Application.undo is also updated so that 'waitResponses' does not use
'setHandlerData' anymore.
parent e48201dd
......@@ -649,12 +649,11 @@ class Application(object):
return result
@profiler_decorator
def waitResponses(self, queue, handler_data):
def waitResponses(self, queue):
"""Wait for all requests to be answered (or their connection to be
detected as closed)"""
pending = self.dispatcher.pending
_waitAnyMessage = self._waitAnyMessage
self.setHandlerData(handler_data)
while pending(queue):
_waitAnyMessage(queue)
......@@ -829,6 +828,7 @@ class Application(object):
getConnForCell = self.cp.getConnForCell
queue = self._getThreadQueue()
ttid = txn_context['ttid']
undo_object_tid_dict = {}
for partition, oid_list in partition_oid_dict.iteritems():
cell_list = getCellList(partition, readable=True)
# We do want to shuffle before getting one with the smallest
......@@ -837,15 +837,15 @@ class Application(object):
shuffle(cell_list)
storage_conn = getConnForCell(min(cell_list, key=getCellSortKey))
storage_conn.ask(Packets.AskObjectUndoSerial(ttid,
snapshot_tid, undone_tid, oid_list), queue=queue)
snapshot_tid, undone_tid, oid_list),
queue=queue, undo_object_tid_dict=undo_object_tid_dict)
# Wait for all AnswerObjectUndoSerial. We might get OidNotFoundError,
# meaning that objects in transaction's oid_list do not exist any
# longer. This is the symptom of a pack, so forbid undoing transaction
# when it happens.
undo_object_tid_dict = {}
try:
self.waitResponses(queue, undo_object_tid_dict)
self.waitResponses(queue)
except NEOStorageNotFoundError:
self.dispatcher.forget_queue(queue)
raise UndoError('non-undoable transaction')
......@@ -899,10 +899,6 @@ class Application(object):
raise NEOStorageError('Transaction %r not found' % (tid, ))
return (txn_info, txn_ext)
# XXX: The following 2 methods fail when they reconnect to a storage after
# they already sent a request to a previous storage.
# See also testStorageReconnectDuringXxx
def undoLog(self, first, last, filter=None, block=0):
# XXX: undoLog is broken
if last < 0:
......@@ -917,15 +913,15 @@ class Application(object):
queue = self._getThreadQueue()
packet = Packets.AskTIDs(first, last, INVALID_PARTITION)
tid_set = set()
for storage_node in storage_node_list:
conn = self.cp.getConnForNode(storage_node)
if conn is None:
continue
conn.ask(packet, queue=queue)
conn.ask(packet, queue=queue, tid_set=tid_set)
# Wait for answers from all storages.
tid_set = set()
self.waitResponses(queue, tid_set)
self.waitResponses(queue)
# Reorder tids
ordered_tids = sorted(tid_set, reverse=True)
......@@ -955,6 +951,7 @@ class Application(object):
node_list.sort(key=self.cp.getCellSortKey)
partition_set = set(range(self.pt.getPartitions()))
queue = self._getThreadQueue()
tid_set = set()
# request a tid list for each partition
for node in node_list:
conn = self.cp.getConnForNode(node)
......@@ -963,12 +960,11 @@ class Application(object):
continue
partition_set -= set(request_set)
packet = Packets.AskTIDsFrom(start, stop, limit, request_set)
conn.ask(packet, queue=queue)
conn.ask(packet, queue=queue, tid_set=tid_set)
if not partition_set:
break
assert not partition_set
tid_set = set()
self.waitResponses(queue, tid_set)
self.waitResponses(queue)
# request transactions informations
txn_list = []
append = txn_list.append
......
......@@ -100,11 +100,10 @@ class StorageAnswersHandler(AnswerBaseHandler):
def answerStoreTransaction(self, conn, _):
pass
def answerTIDsFrom(self, conn, tid_list):
def answerTIDsFrom(self, conn, tid_list, tid_set):
neo.lib.logging.debug('Get %d TIDs from %r', len(tid_list), conn)
tids_from = self.app.getHandlerData()
assert not tids_from.intersection(set(tid_list))
tids_from.update(tid_list)
assert not tid_set.intersection(tid_list)
tid_set.update(tid_list)
def answerTransactionInformation(self, conn, tid,
user, desc, ext, packed, oid_list):
......@@ -134,11 +133,12 @@ class StorageAnswersHandler(AnswerBaseHandler):
# This can happen when requiring txn informations
raise NEOStorageNotFoundError(message)
def answerTIDs(self, conn, tid_list):
self.app.getHandlerData().update(tid_list)
def answerTIDs(self, conn, tid_list, tid_set):
tid_set.update(tid_list)
def answerObjectUndoSerial(self, conn, object_tid_dict):
self.app.getHandlerData().update(object_tid_dict)
def answerObjectUndoSerial(self, conn, object_tid_dict,
undo_object_tid_dict):
undo_object_tid_dict.update(object_tid_dict)
def answerHasLock(self, conn, oid, status):
store_msg_id = self.app.getHandlerData()['timeout_dict'].pop(oid)
......
......@@ -28,6 +28,14 @@ from neo.lib.protocol import Packet, Packets, Errors, INVALID_TID, \
from neo.lib.util import makeChecksum, SOCKET_CONNECTORS_DICT
import time
class Dispatcher(object):
    """Minimal dispatcher stand-in shared by the client application tests.

    It only offers the two methods the tests in this file assign a
    dispatcher for: ``pending`` and ``forget_queue``.
    """

    def pending(self, queue):
        """Return True while ``queue`` still holds at least one item."""
        is_empty = queue.empty()
        return not is_empty

    def forget_queue(self, queue, flush_queue=True):
        """Accept the call and do nothing; both arguments are ignored."""
        return None
def _getMasterConnection(self):
if self.master_conn is None:
self.uuid = 'C' * 16
......@@ -306,9 +314,6 @@ class ClientApplicationTests(NeoUnitTestBase):
node, cell, conn = self.getNodeCellConn(address=storage_address)
app.pt = Mock({ 'getCellListForOID': (cell, cell)})
app.cp = self.getConnectionPool([(node, conn)])
class Dispatcher(object):
def pending(self, queue):
return not queue.empty()
app.dispatcher = Dispatcher()
app.nm.createStorage(address=storage_address)
data_dict = txn_context['data_dict']
......@@ -337,9 +342,6 @@ class ClientApplicationTests(NeoUnitTestBase):
uuid=uuid)
app.cp = self.getConnectionPool([(node, conn)])
app.pt = Mock({ 'getCellListForOID': (cell, cell, ) })
class Dispatcher(object):
def pending(self, queue):
return not queue.empty()
app.dispatcher = Dispatcher()
app.nm.createStorage(address=storage_address)
app.store(oid, tid, 'DATA', None, txn)
......@@ -470,12 +472,6 @@ class ClientApplicationTests(NeoUnitTestBase):
app.master_conn = Mock({'__hash__': 0})
txn = self.makeTransactionObject()
txn_context = self._begin(app, txn, tid)
class Dispatcher(object):
def pending(self, queue):
return not queue.empty()
def forget_queue(self, queue, flush_queue=True):
pass
app.dispatcher = Dispatcher()
# conflict occurs on storage 2
app.store(oid1, tid, 'DATA', None, txn)
......@@ -566,9 +562,6 @@ class ClientApplicationTests(NeoUnitTestBase):
'iterateForObject': [(node, conn)],
'getConnForCell': conn,
})
class Dispatcher(object):
def pending(self, queue):
return not queue.empty()
app.dispatcher = Dispatcher()
def load(oid, tid=None, before_tid=None):
self.assertEqual(oid, oid0)
......@@ -599,8 +592,10 @@ class ClientApplicationTests(NeoUnitTestBase):
tid2)
undo_serial = Packets.AnswerObjectUndoSerial({
oid0: (tid2, tid0, False)})
conn.ask = lambda p, queue=None, **kw: \
isinstance(p, Packets.AskObjectUndoSerial) and \
queue.put((conn, undo_serial, kw))
undo_serial.setId(2)
app._getThreadQueue().put((conn, undo_serial, {}))
marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''):
......@@ -641,7 +636,9 @@ class ClientApplicationTests(NeoUnitTestBase):
undo_serial.setId(2)
app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1,
tid2)
app._getThreadQueue().put((conn, undo_serial, {}))
conn.ask = lambda p, queue=None, **kw: \
type(p) is Packets.AskObjectUndoSerial and \
queue.put((conn, undo_serial, kw))
marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''):
......@@ -667,7 +664,6 @@ class ClientApplicationTests(NeoUnitTestBase):
marker.append((oid, conflict_serial, serial, data, committedData))
raise ConflictError
# The undo
app._getThreadQueue().put((conn, undo_serial, {}))
self.assertRaises(UndoError, app.undo, snapshot_tid, tid1, txn,
tryToResolveConflict)
# Checking what happened
......@@ -700,7 +696,9 @@ class ClientApplicationTests(NeoUnitTestBase):
undo_serial.setId(2)
app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1,
tid2)
app._getThreadQueue().put((conn, undo_serial, {}))
conn.ask = lambda p, queue=None, **kw: \
type(p) is Packets.AskObjectUndoSerial and \
queue.put((conn, undo_serial, kw))
def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''):
raise Exception, 'Test called conflict resolution, but there ' \
......@@ -720,8 +718,6 @@ class ClientApplicationTests(NeoUnitTestBase):
app.num_partitions = 2
uuid1, uuid2 = '\x00' * 15 + '\x01', '\x00' * 15 + '\x02'
# two nodes, two partition, two transaction, two objects :
node1, node2 = Mock({}), Mock({})
cell1, cell2 = Mock({}), Mock({})
tid1, tid2 = self.makeTID(1), self.makeTID(2)
oid1, oid2 = self.makeOID(1), self.makeOID(2)
# TIDs packets supplied by _ask hook
......@@ -744,38 +740,40 @@ class ClientApplicationTests(NeoUnitTestBase):
'fakeReceived': ReturnValues(p3, p4),
'getAddress': ('127.0.0.1', 10021),
})
storage_1_conn = Mock()
storage_2_conn = Mock()
asked = []
def answerTIDs(packet):
conn = Mock({'getAddress': packet})
app.nm.createStorage(address=conn.getAddress())
def ask(p, queue, **kw):
asked.append(p)
queue.put((conn, packet, kw))
conn.ask = ask
return conn
app.dispatcher = Dispatcher()
app.pt = Mock({
'getNodeList': (node1, node2, ),
'getCellListForTID': ReturnValues([cell1], [cell2]),
'getNodeList': (Mock(), Mock()),
'getCellListForTID': ReturnValues([Mock()], [Mock()]),
})
app.cp = Mock({
'getConnForNode': ReturnValues(storage_1_conn, storage_2_conn),
'getConnForNode': ReturnValues(answerTIDs(p1), answerTIDs(p2)),
'iterateForObject': [(Mock(), conn)]
})
def waitResponses(queue, handler_data):
app.setHandlerData(handler_data)
for p in (p1, p2):
app._handlePacket(Mock(), p, handler=app.storage_handler)
app.waitResponses = waitResponses
def txn_filter(info):
return info['id'] > '\x00' * 8
first = 0
last = 4
result = app.undoLog(first, last, filter=txn_filter)
pfirst, plast, ppartition = self.checkAskPacket(storage_1_conn,
Packets.AskTIDs, decode=True)
pfirst, plast, ppartition = asked.pop().decode()
self.assertEqual(pfirst, first)
self.assertEqual(plast, last)
self.assertEqual(ppartition, INVALID_PARTITION)
pfirst, plast, ppartition = self.checkAskPacket(storage_2_conn,
Packets.AskTIDs, decode=True)
pfirst, plast, ppartition = asked.pop().decode()
self.assertEqual(pfirst, first)
self.assertEqual(plast, last)
self.assertEqual(ppartition, INVALID_PARTITION)
self.assertEqual(result[0]['id'], tid1)
self.assertEqual(result[1]['id'], tid2)
self.assertFalse(asked)
def test_connectToPrimaryNode(self):
# here we have three master nodes :
......
......@@ -223,12 +223,7 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
tid_list = [tid1, tid2]
conn = self.getFakeConnection(uuid=uuid)
tid_set = set()
app = Mock({
'getHandlerData': tid_set,
})
handler = StorageAnswersHandler(app)
handler.answerTIDs(conn, tid_list)
StorageAnswersHandler(Mock()).answerTIDs(conn, tid_list, tid_set)
self.assertEqual(tid_set, set(tid_list))
def test_answerObjectUndoSerial(self):
......@@ -241,13 +236,10 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
tid2 = self.getNextTID()
tid3 = self.getNextTID()
undo_dict = {}
app = Mock({
'getHandlerData': undo_dict,
})
handler = StorageAnswersHandler(app)
handler.answerObjectUndoSerial(conn, {oid1: [tid0, tid1]})
handler = StorageAnswersHandler(Mock())
handler.answerObjectUndoSerial(conn, {oid1: [tid0, tid1]}, undo_dict)
self.assertEqual(undo_dict, {oid1: [tid0, tid1]})
handler.answerObjectUndoSerial(conn, {oid2: [tid2, tid3]})
handler.answerObjectUndoSerial(conn, {oid2: [tid2, tid3]}, undo_dict)
self.assertEqual(undo_dict, {
oid1: [tid0, tid1],
oid2: [tid2, tid3],
......
......@@ -353,10 +353,6 @@ class Test(NEOThreadedTest):
finally:
cluster.stop()
# The following 2 tests fail because the same queue is used for
# AskTIDs(From) requests and reconnections. The same bug affected
# history() before df47e5b1df8eabbff1383348b6b8c476bca0c328
def testStorageReconnectDuringTransactionLog(self):
cluster = NEOCluster(storage_count=2, partitions=2)
try:
......@@ -365,7 +361,7 @@ class Test(NEOThreadedTest):
while cluster.client.cp.connection_dict:
cluster.client.cp._dropConnections()
tid, (t1,) = cluster.client.transactionLog(
ZERO_TID, c.root()._p_serial, 10)
ZERO_TID, c.db().lastTransaction(), 10)
finally:
cluster.stop()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment