Commit 4720ba33 authored by Julien Muchembled's avatar Julien Muchembled

Extend Connection.ask so that private data can be forwarded to response handler

Keyword arguments are kept internally by the handler switcher, which passes
them to the handler when the peer answers.

This will be required by new replicator, so that 'setUnfinishedTIDList'
receives the list of outdated partitions at the time the
'AskUnfinishedTransactions' packet was sent.
parent c810920c
......@@ -129,14 +129,14 @@ class MasterEventHandler(EventHandler):
def connectionClosed(self, conn):
self._connectionLost(conn)
def dispatch(self, conn, packet):
def dispatch(self, conn, packet, kw={}):
if packet.isResponse() and \
self.app.dispatcher.registered(packet.getId()):
# expected answer
self.app.request_handler.dispatch(conn, packet)
self.app.request_handler.dispatch(conn, packet, kw)
else:
# unexpectexd answers and notifications
super(MasterEventHandler, self).dispatch(conn, packet)
super(MasterEventHandler, self).dispatch(conn, packet, kw)
def answerNodeInformation(self, conn):
# XXX: This will no more exists when the initialization module will be
......
......@@ -149,7 +149,7 @@ class Application(object):
self.pt.log()
@profiler_decorator
def _handlePacket(self, conn, packet, handler=None):
def _handlePacket(self, conn, packet, kw={}, handler=None):
"""
conn
The connection which received the packet (forwarded to handler).
......@@ -174,7 +174,7 @@ class Application(object):
raise ValueError, 'Unknown node type: %r' % (node.__class__, )
conn.lock()
try:
handler.dispatch(conn, packet)
handler.dispatch(conn, packet, kw)
finally:
conn.unlock()
......@@ -191,7 +191,7 @@ class Application(object):
_handlePacket = self._handlePacket
while pending(queue):
try:
conn, packet = get(block)
conn, packet, kw = get(block)
except Empty:
break
if packet is None or isinstance(packet, ForgottenPacket):
......@@ -199,7 +199,7 @@ class Application(object):
continue
block = False
try:
_handlePacket(conn, packet)
_handlePacket(conn, packet, kw)
except ConnectionClosed:
pass
......@@ -224,7 +224,7 @@ class Application(object):
get = queue.get
_handlePacket = self._handlePacket
while True:
qconn, qpacket = get(True)
qconn, qpacket, kw = get(True)
is_forgotten = isinstance(qpacket, ForgottenPacket)
if conn is qconn:
# check fake packet
......@@ -234,7 +234,7 @@ class Application(object):
if is_forgotten:
raise ValueError, 'ForgottenPacket for an ' \
'explicitely expected packet.'
_handlePacket(qconn, qpacket, handler=handler)
_handlePacket(qconn, qpacket, kw, handler)
break
if not is_forgotten and qpacket is not None:
_handlePacket(qconn, qpacket)
......
......@@ -25,24 +25,24 @@ class BaseHandler(EventHandler):
super(BaseHandler, self).__init__(app)
self.dispatcher = app.dispatcher
def dispatch(self, conn, packet):
def dispatch(self, conn, packet, kw={}):
# Before calling superclass's dispatch method, lock the connection.
# This covers the case where handler sends a response to received
# packet.
conn.lock()
try:
super(BaseHandler, self).dispatch(conn, packet)
super(BaseHandler, self).dispatch(conn, packet, kw)
finally:
conn.release()
def packetReceived(self, conn, packet):
def packetReceived(self, conn, packet, kw={}):
"""Redirect all received packet to dispatcher thread."""
if packet.isResponse() and type(packet) is not Packets.Pong:
if not self.dispatcher.dispatch(conn, packet.getId(), packet):
if not self.dispatcher.dispatch(conn, packet.getId(), packet, kw):
raise ProtocolError('Unexpected response packet from %r: %r'
% (conn, packet))
else:
self.dispatch(conn, packet)
self.dispatch(conn, packet, kw)
def connectionLost(self, conn, new_state):
......
......@@ -107,7 +107,7 @@ class HandlerSwitcher(object):
return self._pending[-1][1]
@profiler_decorator
def emit(self, request, timeout, on_timeout):
def emit(self, request, timeout, on_timeout, kw={}):
# register the request in the current handler
_pending = self._pending
if self._is_handling:
......@@ -127,7 +127,7 @@ class HandlerSwitcher(object):
self._next_timeout = timeout
self._next_timeout_msg_id = msg_id
self._next_on_timeout = on_timeout
request_dict[msg_id] = (answer_class, timeout, on_timeout)
request_dict[msg_id] = answer_class, timeout, on_timeout, kw
def getNextTimeout(self):
return self._next_timeout
......@@ -166,9 +166,12 @@ class HandlerSwitcher(object):
handler.packetReceived(connection, packet)
return
# checkout the expected answer class
(klass, timeout, _) = request_dict.pop(msg_id, (None, None, None))
try:
klass, timeout, _, kw = request_dict.pop(msg_id)
except KeyError:
klass = None
if klass and isinstance(packet, klass) or packet.isError():
handler.packetReceived(connection, packet)
handler.packetReceived(connection, packet, kw)
else:
neo.lib.logging.error(
'Unexpected answer %r in %r', packet, connection)
......@@ -190,7 +193,7 @@ class HandlerSwitcher(object):
# Find next timeout and its msg_id
next_timeout = None
for pending in self._pending:
for msg_id, (_, timeout, on_timeout) in pending[0].iteritems():
for msg_id, (_, timeout, on_timeout, _) in pending[0].iteritems():
if not next_timeout or timeout < next_timeout[0]:
next_timeout = timeout, msg_id, on_timeout
self._next_timeout, self._next_timeout_msg_id, self._next_on_timeout = \
......@@ -598,7 +601,7 @@ class Connection(BaseConnection):
@profiler_decorator
@not_closed
def ask(self, packet, timeout=CRITICAL_TIMEOUT, on_timeout=None):
def ask(self, packet, timeout=CRITICAL_TIMEOUT, on_timeout=None, **kw):
"""
Send a packet with a new ID and register the expectation of an answer
"""
......@@ -607,7 +610,7 @@ class Connection(BaseConnection):
self._addPacket(packet)
handlers = self._handlers
t = not handlers.isPending() and time() or None
handlers.emit(packet, timeout, on_timeout)
handlers.emit(packet, timeout, on_timeout, kw)
self.updateTimeout(t)
return msg_id
......@@ -728,7 +731,7 @@ class MTClientConnection(ClientConnection):
@profiler_decorator
def ask(self, packet, timeout=CRITICAL_TIMEOUT, on_timeout=None,
queue=None):
queue=None, **kw):
self.lock()
try:
if self.isClosed():
......@@ -747,7 +750,7 @@ class MTClientConnection(ClientConnection):
self._addPacket(packet)
handlers = self._handlers
t = not handlers.isPending() and time() or None
handlers.emit(packet, timeout, on_timeout)
handlers.emit(packet, timeout, on_timeout, kw)
self.updateTimeout(t)
return msg_id
finally:
......
......@@ -55,7 +55,7 @@ class Dispatcher:
@giant_lock
@profiler_decorator
def dispatch(self, conn, msg_id, packet):
def dispatch(self, conn, msg_id, packet, kw):
"""
Retrieve register-time provided queue, and put conn and packet in it.
"""
......@@ -65,7 +65,7 @@ class Dispatcher:
elif queue is NOBODY:
return True
self._decrefQueue(queue)
queue.put((conn, packet))
queue.put((conn, packet, kw))
return True
def _decrefQueue(self, queue):
......@@ -112,7 +112,7 @@ class Dispatcher:
continue
queue_id = id(queue)
if queue_id not in notified_set:
queue.put((conn, None))
queue.put((conn, None, None))
notified_set.add(queue_id)
_decrefQueue(queue)
......@@ -127,7 +127,7 @@ class Dispatcher:
if queue is NOBODY:
raise KeyError, 'Already expected by NOBODY: %r, %r' % (
conn, msg_id)
queue.put((conn, ForgottenPacket(msg_id)))
queue.put((conn, ForgottenPacket(msg_id), None))
self.queue_dict[id(queue)] -= 1
message_table[msg_id] = NOBODY
return queue
......
......@@ -43,7 +43,7 @@ class EventHandler(object):
conn.abort()
# self.peerBroken(conn)
def dispatch(self, conn, packet):
def dispatch(self, conn, packet, kw={}):
"""This is a helper method to handle various packet types."""
try:
try:
......@@ -52,7 +52,7 @@ class EventHandler(object):
raise UnexpectedPacketError('no handler found')
args = packet.decode() or ()
conn.setPeerId(packet.getId())
method(conn, *args)
method(conn, *args, **kw)
except UnexpectedPacketError, e:
self.__unexpectedPacket(conn, packet, *e.args)
except PacketMalformedError:
......@@ -83,9 +83,9 @@ class EventHandler(object):
# Network level handlers
def packetReceived(self, conn, packet):
def packetReceived(self, *args):
"""Called when a packet is received."""
self.dispatch(conn, packet)
self.dispatch(*args)
def connectionStarted(self, conn):
"""Called when a connection is started."""
......
......@@ -37,12 +37,12 @@ class SecondaryMasterHandler(MasterHandler):
class PrimaryHandler(MasterHandler):
""" Handler used by secondaries to handle primary master"""
def packetReceived(self, conn, packet):
def packetReceived(self, conn, packet, kw):
if not conn.isServer():
node = self.app.nm.getByAddress(conn.getAddress())
if not node.isBroken():
node.setRunning()
MasterHandler.packetReceived(self, conn, packet)
super(PrimaryHandler, self).packetReceived(conn, packet, kw)
def connectionLost(self, conn, new_state):
self.app.primary_master_node.setDown()
......
......@@ -315,7 +315,7 @@ class ClientApplicationTests(NeoUnitTestBase):
data_dict[oid] = 'BEFORE'
txn_context['data_list'].append(oid)
app.store(oid, tid, '', None, txn)
txn_context['queue'].put((conn, packet))
txn_context['queue'].put((conn, packet, {}))
self.assertRaises(ConflictError, app.waitStoreResponses, txn_context,
failing_tryToResolveConflict)
self.assertTrue(oid not in data_dict)
......@@ -344,7 +344,7 @@ class ClientApplicationTests(NeoUnitTestBase):
app.nm.createStorage(address=storage_address)
app.store(oid, tid, 'DATA', None, txn)
self.checkAskStoreObject(conn)
txn_context['queue'].put((conn, packet))
txn_context['queue'].put((conn, packet, {}))
app.waitStoreResponses(txn_context, resolving_tryToResolveConflict)
self.assertEqual(txn_context['object_stored_counter_dict'][oid],
{tid: set([uuid])})
......@@ -481,8 +481,8 @@ class ClientApplicationTests(NeoUnitTestBase):
app.store(oid1, tid, 'DATA', None, txn)
app.store(oid2, tid, 'DATA', None, txn)
queue = txn_context['queue']
queue.put((conn2, packet2))
queue.put((conn3, packet3))
queue.put((conn2, packet2, {}))
queue.put((conn3, packet3, {}))
# vote fails as the conflict is not resolved, nothing is sent to storage 3
self.assertRaises(ConflictError, app.tpc_vote, txn, failing_tryToResolveConflict)
# abort must be sent to storage 1 and 2
......@@ -600,7 +600,7 @@ class ClientApplicationTests(NeoUnitTestBase):
undo_serial = Packets.AnswerObjectUndoSerial({
oid0: (tid2, tid0, False)})
undo_serial.setId(2)
app._getThreadQueue().put((conn, undo_serial))
app._getThreadQueue().put((conn, undo_serial, {}))
marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''):
......@@ -641,7 +641,7 @@ class ClientApplicationTests(NeoUnitTestBase):
undo_serial.setId(2)
app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1,
tid2)
app._getThreadQueue().put((conn, undo_serial))
app._getThreadQueue().put((conn, undo_serial, {}))
marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''):
......@@ -667,7 +667,7 @@ class ClientApplicationTests(NeoUnitTestBase):
marker.append((oid, conflict_serial, serial, data, committedData))
raise ConflictError
# The undo
app._getThreadQueue().put((conn, undo_serial))
app._getThreadQueue().put((conn, undo_serial, {}))
self.assertRaises(UndoError, app.undo, snapshot_tid, tid1, txn,
tryToResolveConflict)
# Checking what happened
......@@ -700,7 +700,7 @@ class ClientApplicationTests(NeoUnitTestBase):
undo_serial.setId(2)
app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1,
tid2)
app._getThreadQueue().put((conn, undo_serial))
app._getThreadQueue().put((conn, undo_serial, {}))
def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''):
raise Exception, 'Test called conflict resolution, but there ' \
......
......@@ -911,7 +911,7 @@ class HandlerSwitcherTests(NeoUnitTestBase):
self.assertFalse(applied)
self._checkCurrentHandler(self._handler)
call_tracker = []
def packetReceived(conn, packet):
def packetReceived(conn, packet, kw):
self._handlers.emit(self._makeRequest(2), 0, None)
call_tracker.append(True)
self._handler.packetReceived = packetReceived
......
......@@ -34,11 +34,11 @@ class DispatcherTests(NeoTestBase):
MARKER = object()
self.dispatcher.register(conn, 1, queue)
self.assertTrue(queue.empty())
self.assertTrue(self.dispatcher.dispatch(conn, 1, MARKER))
self.assertTrue(self.dispatcher.dispatch(conn, 1, MARKER, {}))
self.assertFalse(queue.empty())
self.assertEqual(queue.get(block=False), (conn, MARKER))
self.assertEqual(queue.get(block=False), (conn, MARKER, {}))
self.assertTrue(queue.empty())
self.assertFalse(self.dispatcher.dispatch(conn, 2, None))
self.assertFalse(self.dispatcher.dispatch(conn, 2, None, {}))
self.assertEqual(len(self.fake_thread.mockGetNamedCalls('start')), 1)
def testUnregister(self):
......@@ -47,7 +47,7 @@ class DispatcherTests(NeoTestBase):
self.dispatcher.register(conn, 2, queue)
self.dispatcher.unregister(conn)
self.assertEqual(len(queue.mockGetNamedCalls('put')), 1)
self.assertFalse(self.dispatcher.dispatch(conn, 2, None))
self.assertFalse(self.dispatcher.dispatch(conn, 2, None, {}))
def testRegistered(self):
conn1 = object()
......@@ -88,10 +88,10 @@ class DispatcherTests(NeoTestBase):
self.assertTrue(self.dispatcher.pending(queue1))
self.assertTrue(self.dispatcher.pending(queue2))
self.dispatcher.dispatch(conn1, 1, None)
self.dispatcher.dispatch(conn1, 1, None, {})
self.assertTrue(self.dispatcher.pending(queue1))
self.assertTrue(self.dispatcher.pending(queue2))
self.dispatcher.dispatch(conn2, 2, None)
self.dispatcher.dispatch(conn2, 2, None, {})
self.assertFalse(self.dispatcher.pending(queue1))
self.assertTrue(self.dispatcher.pending(queue2))
......@@ -121,7 +121,7 @@ class DispatcherTests(NeoTestBase):
forgotten_queue = self.dispatcher.forget(conn, 1)
self.assertTrue(queue is forgotten_queue, (queue, forgotten_queue))
# A ForgottenPacket must have been put in the queue
queue_conn, packet = queue.get(block=False)
queue_conn, packet, kw = queue.get(block=False)
self.assertTrue(isinstance(packet, ForgottenPacket), packet)
# ...with appropriate packet id
self.assertEqual(packet.getId(), 1)
......@@ -130,7 +130,7 @@ class DispatcherTests(NeoTestBase):
# If forgotten twice, it must raise a KeyError
self.assertRaises(KeyError, self.dispatcher.forget, conn, 1)
# Event arrives, return value must be True (it was expected)
self.assertTrue(self.dispatcher.dispatch(conn, 1, MARKER))
self.assertTrue(self.dispatcher.dispatch(conn, 1, MARKER, {}))
# ...but must not have reached the queue
self.assertTrue(queue.empty())
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment