Commit 4720ba33 authored by Julien Muchembled's avatar Julien Muchembled

Extend Connection.ask so that private data can be forwarded to response handler

Keyword arguments are kept internally by the handler switcher, which passes
them to the handler when the peer answers.

This will be required by new replicator, so that 'setUnfinishedTIDList'
receives the list of outdated partitions at the time the
'AskUnfinishedTransactions' packet was sent.
parent c810920c
...@@ -129,14 +129,14 @@ class MasterEventHandler(EventHandler): ...@@ -129,14 +129,14 @@ class MasterEventHandler(EventHandler):
def connectionClosed(self, conn): def connectionClosed(self, conn):
self._connectionLost(conn) self._connectionLost(conn)
def dispatch(self, conn, packet): def dispatch(self, conn, packet, kw={}):
if packet.isResponse() and \ if packet.isResponse() and \
self.app.dispatcher.registered(packet.getId()): self.app.dispatcher.registered(packet.getId()):
# expected answer # expected answer
self.app.request_handler.dispatch(conn, packet) self.app.request_handler.dispatch(conn, packet, kw)
else: else:
# unexpectexd answers and notifications # unexpectexd answers and notifications
super(MasterEventHandler, self).dispatch(conn, packet) super(MasterEventHandler, self).dispatch(conn, packet, kw)
def answerNodeInformation(self, conn): def answerNodeInformation(self, conn):
# XXX: This will no more exists when the initialization module will be # XXX: This will no more exists when the initialization module will be
......
...@@ -149,7 +149,7 @@ class Application(object): ...@@ -149,7 +149,7 @@ class Application(object):
self.pt.log() self.pt.log()
@profiler_decorator @profiler_decorator
def _handlePacket(self, conn, packet, handler=None): def _handlePacket(self, conn, packet, kw={}, handler=None):
""" """
conn conn
The connection which received the packet (forwarded to handler). The connection which received the packet (forwarded to handler).
...@@ -174,7 +174,7 @@ class Application(object): ...@@ -174,7 +174,7 @@ class Application(object):
raise ValueError, 'Unknown node type: %r' % (node.__class__, ) raise ValueError, 'Unknown node type: %r' % (node.__class__, )
conn.lock() conn.lock()
try: try:
handler.dispatch(conn, packet) handler.dispatch(conn, packet, kw)
finally: finally:
conn.unlock() conn.unlock()
...@@ -191,7 +191,7 @@ class Application(object): ...@@ -191,7 +191,7 @@ class Application(object):
_handlePacket = self._handlePacket _handlePacket = self._handlePacket
while pending(queue): while pending(queue):
try: try:
conn, packet = get(block) conn, packet, kw = get(block)
except Empty: except Empty:
break break
if packet is None or isinstance(packet, ForgottenPacket): if packet is None or isinstance(packet, ForgottenPacket):
...@@ -199,7 +199,7 @@ class Application(object): ...@@ -199,7 +199,7 @@ class Application(object):
continue continue
block = False block = False
try: try:
_handlePacket(conn, packet) _handlePacket(conn, packet, kw)
except ConnectionClosed: except ConnectionClosed:
pass pass
...@@ -224,7 +224,7 @@ class Application(object): ...@@ -224,7 +224,7 @@ class Application(object):
get = queue.get get = queue.get
_handlePacket = self._handlePacket _handlePacket = self._handlePacket
while True: while True:
qconn, qpacket = get(True) qconn, qpacket, kw = get(True)
is_forgotten = isinstance(qpacket, ForgottenPacket) is_forgotten = isinstance(qpacket, ForgottenPacket)
if conn is qconn: if conn is qconn:
# check fake packet # check fake packet
...@@ -234,7 +234,7 @@ class Application(object): ...@@ -234,7 +234,7 @@ class Application(object):
if is_forgotten: if is_forgotten:
raise ValueError, 'ForgottenPacket for an ' \ raise ValueError, 'ForgottenPacket for an ' \
'explicitely expected packet.' 'explicitely expected packet.'
_handlePacket(qconn, qpacket, handler=handler) _handlePacket(qconn, qpacket, kw, handler)
break break
if not is_forgotten and qpacket is not None: if not is_forgotten and qpacket is not None:
_handlePacket(qconn, qpacket) _handlePacket(qconn, qpacket)
......
...@@ -25,24 +25,24 @@ class BaseHandler(EventHandler): ...@@ -25,24 +25,24 @@ class BaseHandler(EventHandler):
super(BaseHandler, self).__init__(app) super(BaseHandler, self).__init__(app)
self.dispatcher = app.dispatcher self.dispatcher = app.dispatcher
def dispatch(self, conn, packet): def dispatch(self, conn, packet, kw={}):
# Before calling superclass's dispatch method, lock the connection. # Before calling superclass's dispatch method, lock the connection.
# This covers the case where handler sends a response to received # This covers the case where handler sends a response to received
# packet. # packet.
conn.lock() conn.lock()
try: try:
super(BaseHandler, self).dispatch(conn, packet) super(BaseHandler, self).dispatch(conn, packet, kw)
finally: finally:
conn.release() conn.release()
def packetReceived(self, conn, packet): def packetReceived(self, conn, packet, kw={}):
"""Redirect all received packet to dispatcher thread.""" """Redirect all received packet to dispatcher thread."""
if packet.isResponse() and type(packet) is not Packets.Pong: if packet.isResponse() and type(packet) is not Packets.Pong:
if not self.dispatcher.dispatch(conn, packet.getId(), packet): if not self.dispatcher.dispatch(conn, packet.getId(), packet, kw):
raise ProtocolError('Unexpected response packet from %r: %r' raise ProtocolError('Unexpected response packet from %r: %r'
% (conn, packet)) % (conn, packet))
else: else:
self.dispatch(conn, packet) self.dispatch(conn, packet, kw)
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
......
...@@ -107,7 +107,7 @@ class HandlerSwitcher(object): ...@@ -107,7 +107,7 @@ class HandlerSwitcher(object):
return self._pending[-1][1] return self._pending[-1][1]
@profiler_decorator @profiler_decorator
def emit(self, request, timeout, on_timeout): def emit(self, request, timeout, on_timeout, kw={}):
# register the request in the current handler # register the request in the current handler
_pending = self._pending _pending = self._pending
if self._is_handling: if self._is_handling:
...@@ -127,7 +127,7 @@ class HandlerSwitcher(object): ...@@ -127,7 +127,7 @@ class HandlerSwitcher(object):
self._next_timeout = timeout self._next_timeout = timeout
self._next_timeout_msg_id = msg_id self._next_timeout_msg_id = msg_id
self._next_on_timeout = on_timeout self._next_on_timeout = on_timeout
request_dict[msg_id] = (answer_class, timeout, on_timeout) request_dict[msg_id] = answer_class, timeout, on_timeout, kw
def getNextTimeout(self): def getNextTimeout(self):
return self._next_timeout return self._next_timeout
...@@ -166,9 +166,12 @@ class HandlerSwitcher(object): ...@@ -166,9 +166,12 @@ class HandlerSwitcher(object):
handler.packetReceived(connection, packet) handler.packetReceived(connection, packet)
return return
# checkout the expected answer class # checkout the expected answer class
(klass, timeout, _) = request_dict.pop(msg_id, (None, None, None)) try:
klass, timeout, _, kw = request_dict.pop(msg_id)
except KeyError:
klass = None
if klass and isinstance(packet, klass) or packet.isError(): if klass and isinstance(packet, klass) or packet.isError():
handler.packetReceived(connection, packet) handler.packetReceived(connection, packet, kw)
else: else:
neo.lib.logging.error( neo.lib.logging.error(
'Unexpected answer %r in %r', packet, connection) 'Unexpected answer %r in %r', packet, connection)
...@@ -190,7 +193,7 @@ class HandlerSwitcher(object): ...@@ -190,7 +193,7 @@ class HandlerSwitcher(object):
# Find next timeout and its msg_id # Find next timeout and its msg_id
next_timeout = None next_timeout = None
for pending in self._pending: for pending in self._pending:
for msg_id, (_, timeout, on_timeout) in pending[0].iteritems(): for msg_id, (_, timeout, on_timeout, _) in pending[0].iteritems():
if not next_timeout or timeout < next_timeout[0]: if not next_timeout or timeout < next_timeout[0]:
next_timeout = timeout, msg_id, on_timeout next_timeout = timeout, msg_id, on_timeout
self._next_timeout, self._next_timeout_msg_id, self._next_on_timeout = \ self._next_timeout, self._next_timeout_msg_id, self._next_on_timeout = \
...@@ -598,7 +601,7 @@ class Connection(BaseConnection): ...@@ -598,7 +601,7 @@ class Connection(BaseConnection):
@profiler_decorator @profiler_decorator
@not_closed @not_closed
def ask(self, packet, timeout=CRITICAL_TIMEOUT, on_timeout=None): def ask(self, packet, timeout=CRITICAL_TIMEOUT, on_timeout=None, **kw):
""" """
Send a packet with a new ID and register the expectation of an answer Send a packet with a new ID and register the expectation of an answer
""" """
...@@ -607,7 +610,7 @@ class Connection(BaseConnection): ...@@ -607,7 +610,7 @@ class Connection(BaseConnection):
self._addPacket(packet) self._addPacket(packet)
handlers = self._handlers handlers = self._handlers
t = not handlers.isPending() and time() or None t = not handlers.isPending() and time() or None
handlers.emit(packet, timeout, on_timeout) handlers.emit(packet, timeout, on_timeout, kw)
self.updateTimeout(t) self.updateTimeout(t)
return msg_id return msg_id
...@@ -728,7 +731,7 @@ class MTClientConnection(ClientConnection): ...@@ -728,7 +731,7 @@ class MTClientConnection(ClientConnection):
@profiler_decorator @profiler_decorator
def ask(self, packet, timeout=CRITICAL_TIMEOUT, on_timeout=None, def ask(self, packet, timeout=CRITICAL_TIMEOUT, on_timeout=None,
queue=None): queue=None, **kw):
self.lock() self.lock()
try: try:
if self.isClosed(): if self.isClosed():
...@@ -747,7 +750,7 @@ class MTClientConnection(ClientConnection): ...@@ -747,7 +750,7 @@ class MTClientConnection(ClientConnection):
self._addPacket(packet) self._addPacket(packet)
handlers = self._handlers handlers = self._handlers
t = not handlers.isPending() and time() or None t = not handlers.isPending() and time() or None
handlers.emit(packet, timeout, on_timeout) handlers.emit(packet, timeout, on_timeout, kw)
self.updateTimeout(t) self.updateTimeout(t)
return msg_id return msg_id
finally: finally:
......
...@@ -55,7 +55,7 @@ class Dispatcher: ...@@ -55,7 +55,7 @@ class Dispatcher:
@giant_lock @giant_lock
@profiler_decorator @profiler_decorator
def dispatch(self, conn, msg_id, packet): def dispatch(self, conn, msg_id, packet, kw):
""" """
Retrieve register-time provided queue, and put conn and packet in it. Retrieve register-time provided queue, and put conn and packet in it.
""" """
...@@ -65,7 +65,7 @@ class Dispatcher: ...@@ -65,7 +65,7 @@ class Dispatcher:
elif queue is NOBODY: elif queue is NOBODY:
return True return True
self._decrefQueue(queue) self._decrefQueue(queue)
queue.put((conn, packet)) queue.put((conn, packet, kw))
return True return True
def _decrefQueue(self, queue): def _decrefQueue(self, queue):
...@@ -112,7 +112,7 @@ class Dispatcher: ...@@ -112,7 +112,7 @@ class Dispatcher:
continue continue
queue_id = id(queue) queue_id = id(queue)
if queue_id not in notified_set: if queue_id not in notified_set:
queue.put((conn, None)) queue.put((conn, None, None))
notified_set.add(queue_id) notified_set.add(queue_id)
_decrefQueue(queue) _decrefQueue(queue)
...@@ -127,7 +127,7 @@ class Dispatcher: ...@@ -127,7 +127,7 @@ class Dispatcher:
if queue is NOBODY: if queue is NOBODY:
raise KeyError, 'Already expected by NOBODY: %r, %r' % ( raise KeyError, 'Already expected by NOBODY: %r, %r' % (
conn, msg_id) conn, msg_id)
queue.put((conn, ForgottenPacket(msg_id))) queue.put((conn, ForgottenPacket(msg_id), None))
self.queue_dict[id(queue)] -= 1 self.queue_dict[id(queue)] -= 1
message_table[msg_id] = NOBODY message_table[msg_id] = NOBODY
return queue return queue
......
...@@ -43,7 +43,7 @@ class EventHandler(object): ...@@ -43,7 +43,7 @@ class EventHandler(object):
conn.abort() conn.abort()
# self.peerBroken(conn) # self.peerBroken(conn)
def dispatch(self, conn, packet): def dispatch(self, conn, packet, kw={}):
"""This is a helper method to handle various packet types.""" """This is a helper method to handle various packet types."""
try: try:
try: try:
...@@ -52,7 +52,7 @@ class EventHandler(object): ...@@ -52,7 +52,7 @@ class EventHandler(object):
raise UnexpectedPacketError('no handler found') raise UnexpectedPacketError('no handler found')
args = packet.decode() or () args = packet.decode() or ()
conn.setPeerId(packet.getId()) conn.setPeerId(packet.getId())
method(conn, *args) method(conn, *args, **kw)
except UnexpectedPacketError, e: except UnexpectedPacketError, e:
self.__unexpectedPacket(conn, packet, *e.args) self.__unexpectedPacket(conn, packet, *e.args)
except PacketMalformedError: except PacketMalformedError:
...@@ -83,9 +83,9 @@ class EventHandler(object): ...@@ -83,9 +83,9 @@ class EventHandler(object):
# Network level handlers # Network level handlers
def packetReceived(self, conn, packet): def packetReceived(self, *args):
"""Called when a packet is received.""" """Called when a packet is received."""
self.dispatch(conn, packet) self.dispatch(*args)
def connectionStarted(self, conn): def connectionStarted(self, conn):
"""Called when a connection is started.""" """Called when a connection is started."""
......
...@@ -37,12 +37,12 @@ class SecondaryMasterHandler(MasterHandler): ...@@ -37,12 +37,12 @@ class SecondaryMasterHandler(MasterHandler):
class PrimaryHandler(MasterHandler): class PrimaryHandler(MasterHandler):
""" Handler used by secondaries to handle primary master""" """ Handler used by secondaries to handle primary master"""
def packetReceived(self, conn, packet): def packetReceived(self, conn, packet, kw):
if not conn.isServer(): if not conn.isServer():
node = self.app.nm.getByAddress(conn.getAddress()) node = self.app.nm.getByAddress(conn.getAddress())
if not node.isBroken(): if not node.isBroken():
node.setRunning() node.setRunning()
MasterHandler.packetReceived(self, conn, packet) super(PrimaryHandler, self).packetReceived(conn, packet, kw)
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
self.app.primary_master_node.setDown() self.app.primary_master_node.setDown()
......
...@@ -315,7 +315,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -315,7 +315,7 @@ class ClientApplicationTests(NeoUnitTestBase):
data_dict[oid] = 'BEFORE' data_dict[oid] = 'BEFORE'
txn_context['data_list'].append(oid) txn_context['data_list'].append(oid)
app.store(oid, tid, '', None, txn) app.store(oid, tid, '', None, txn)
txn_context['queue'].put((conn, packet)) txn_context['queue'].put((conn, packet, {}))
self.assertRaises(ConflictError, app.waitStoreResponses, txn_context, self.assertRaises(ConflictError, app.waitStoreResponses, txn_context,
failing_tryToResolveConflict) failing_tryToResolveConflict)
self.assertTrue(oid not in data_dict) self.assertTrue(oid not in data_dict)
...@@ -344,7 +344,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -344,7 +344,7 @@ class ClientApplicationTests(NeoUnitTestBase):
app.nm.createStorage(address=storage_address) app.nm.createStorage(address=storage_address)
app.store(oid, tid, 'DATA', None, txn) app.store(oid, tid, 'DATA', None, txn)
self.checkAskStoreObject(conn) self.checkAskStoreObject(conn)
txn_context['queue'].put((conn, packet)) txn_context['queue'].put((conn, packet, {}))
app.waitStoreResponses(txn_context, resolving_tryToResolveConflict) app.waitStoreResponses(txn_context, resolving_tryToResolveConflict)
self.assertEqual(txn_context['object_stored_counter_dict'][oid], self.assertEqual(txn_context['object_stored_counter_dict'][oid],
{tid: set([uuid])}) {tid: set([uuid])})
...@@ -481,8 +481,8 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -481,8 +481,8 @@ class ClientApplicationTests(NeoUnitTestBase):
app.store(oid1, tid, 'DATA', None, txn) app.store(oid1, tid, 'DATA', None, txn)
app.store(oid2, tid, 'DATA', None, txn) app.store(oid2, tid, 'DATA', None, txn)
queue = txn_context['queue'] queue = txn_context['queue']
queue.put((conn2, packet2)) queue.put((conn2, packet2, {}))
queue.put((conn3, packet3)) queue.put((conn3, packet3, {}))
# vote fails as the conflict is not resolved, nothing is sent to storage 3 # vote fails as the conflict is not resolved, nothing is sent to storage 3
self.assertRaises(ConflictError, app.tpc_vote, txn, failing_tryToResolveConflict) self.assertRaises(ConflictError, app.tpc_vote, txn, failing_tryToResolveConflict)
# abort must be sent to storage 1 and 2 # abort must be sent to storage 1 and 2
...@@ -600,7 +600,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -600,7 +600,7 @@ class ClientApplicationTests(NeoUnitTestBase):
undo_serial = Packets.AnswerObjectUndoSerial({ undo_serial = Packets.AnswerObjectUndoSerial({
oid0: (tid2, tid0, False)}) oid0: (tid2, tid0, False)})
undo_serial.setId(2) undo_serial.setId(2)
app._getThreadQueue().put((conn, undo_serial)) app._getThreadQueue().put((conn, undo_serial, {}))
marker = [] marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data, def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''): committedData=''):
...@@ -641,7 +641,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -641,7 +641,7 @@ class ClientApplicationTests(NeoUnitTestBase):
undo_serial.setId(2) undo_serial.setId(2)
app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1, app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1,
tid2) tid2)
app._getThreadQueue().put((conn, undo_serial)) app._getThreadQueue().put((conn, undo_serial, {}))
marker = [] marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data, def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''): committedData=''):
...@@ -667,7 +667,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -667,7 +667,7 @@ class ClientApplicationTests(NeoUnitTestBase):
marker.append((oid, conflict_serial, serial, data, committedData)) marker.append((oid, conflict_serial, serial, data, committedData))
raise ConflictError raise ConflictError
# The undo # The undo
app._getThreadQueue().put((conn, undo_serial)) app._getThreadQueue().put((conn, undo_serial, {}))
self.assertRaises(UndoError, app.undo, snapshot_tid, tid1, txn, self.assertRaises(UndoError, app.undo, snapshot_tid, tid1, txn,
tryToResolveConflict) tryToResolveConflict)
# Checking what happened # Checking what happened
...@@ -700,7 +700,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -700,7 +700,7 @@ class ClientApplicationTests(NeoUnitTestBase):
undo_serial.setId(2) undo_serial.setId(2)
app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1, app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1,
tid2) tid2)
app._getThreadQueue().put((conn, undo_serial)) app._getThreadQueue().put((conn, undo_serial, {}))
def tryToResolveConflict(oid, conflict_serial, serial, data, def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''): committedData=''):
raise Exception, 'Test called conflict resolution, but there ' \ raise Exception, 'Test called conflict resolution, but there ' \
......
...@@ -911,7 +911,7 @@ class HandlerSwitcherTests(NeoUnitTestBase): ...@@ -911,7 +911,7 @@ class HandlerSwitcherTests(NeoUnitTestBase):
self.assertFalse(applied) self.assertFalse(applied)
self._checkCurrentHandler(self._handler) self._checkCurrentHandler(self._handler)
call_tracker = [] call_tracker = []
def packetReceived(conn, packet): def packetReceived(conn, packet, kw):
self._handlers.emit(self._makeRequest(2), 0, None) self._handlers.emit(self._makeRequest(2), 0, None)
call_tracker.append(True) call_tracker.append(True)
self._handler.packetReceived = packetReceived self._handler.packetReceived = packetReceived
......
...@@ -34,11 +34,11 @@ class DispatcherTests(NeoTestBase): ...@@ -34,11 +34,11 @@ class DispatcherTests(NeoTestBase):
MARKER = object() MARKER = object()
self.dispatcher.register(conn, 1, queue) self.dispatcher.register(conn, 1, queue)
self.assertTrue(queue.empty()) self.assertTrue(queue.empty())
self.assertTrue(self.dispatcher.dispatch(conn, 1, MARKER)) self.assertTrue(self.dispatcher.dispatch(conn, 1, MARKER, {}))
self.assertFalse(queue.empty()) self.assertFalse(queue.empty())
self.assertEqual(queue.get(block=False), (conn, MARKER)) self.assertEqual(queue.get(block=False), (conn, MARKER, {}))
self.assertTrue(queue.empty()) self.assertTrue(queue.empty())
self.assertFalse(self.dispatcher.dispatch(conn, 2, None)) self.assertFalse(self.dispatcher.dispatch(conn, 2, None, {}))
self.assertEqual(len(self.fake_thread.mockGetNamedCalls('start')), 1) self.assertEqual(len(self.fake_thread.mockGetNamedCalls('start')), 1)
def testUnregister(self): def testUnregister(self):
...@@ -47,7 +47,7 @@ class DispatcherTests(NeoTestBase): ...@@ -47,7 +47,7 @@ class DispatcherTests(NeoTestBase):
self.dispatcher.register(conn, 2, queue) self.dispatcher.register(conn, 2, queue)
self.dispatcher.unregister(conn) self.dispatcher.unregister(conn)
self.assertEqual(len(queue.mockGetNamedCalls('put')), 1) self.assertEqual(len(queue.mockGetNamedCalls('put')), 1)
self.assertFalse(self.dispatcher.dispatch(conn, 2, None)) self.assertFalse(self.dispatcher.dispatch(conn, 2, None, {}))
def testRegistered(self): def testRegistered(self):
conn1 = object() conn1 = object()
...@@ -88,10 +88,10 @@ class DispatcherTests(NeoTestBase): ...@@ -88,10 +88,10 @@ class DispatcherTests(NeoTestBase):
self.assertTrue(self.dispatcher.pending(queue1)) self.assertTrue(self.dispatcher.pending(queue1))
self.assertTrue(self.dispatcher.pending(queue2)) self.assertTrue(self.dispatcher.pending(queue2))
self.dispatcher.dispatch(conn1, 1, None) self.dispatcher.dispatch(conn1, 1, None, {})
self.assertTrue(self.dispatcher.pending(queue1)) self.assertTrue(self.dispatcher.pending(queue1))
self.assertTrue(self.dispatcher.pending(queue2)) self.assertTrue(self.dispatcher.pending(queue2))
self.dispatcher.dispatch(conn2, 2, None) self.dispatcher.dispatch(conn2, 2, None, {})
self.assertFalse(self.dispatcher.pending(queue1)) self.assertFalse(self.dispatcher.pending(queue1))
self.assertTrue(self.dispatcher.pending(queue2)) self.assertTrue(self.dispatcher.pending(queue2))
...@@ -121,7 +121,7 @@ class DispatcherTests(NeoTestBase): ...@@ -121,7 +121,7 @@ class DispatcherTests(NeoTestBase):
forgotten_queue = self.dispatcher.forget(conn, 1) forgotten_queue = self.dispatcher.forget(conn, 1)
self.assertTrue(queue is forgotten_queue, (queue, forgotten_queue)) self.assertTrue(queue is forgotten_queue, (queue, forgotten_queue))
# A ForgottenPacket must have been put in the queue # A ForgottenPacket must have been put in the queue
queue_conn, packet = queue.get(block=False) queue_conn, packet, kw = queue.get(block=False)
self.assertTrue(isinstance(packet, ForgottenPacket), packet) self.assertTrue(isinstance(packet, ForgottenPacket), packet)
# ...with appropriate packet id # ...with appropriate packet id
self.assertEqual(packet.getId(), 1) self.assertEqual(packet.getId(), 1)
...@@ -130,7 +130,7 @@ class DispatcherTests(NeoTestBase): ...@@ -130,7 +130,7 @@ class DispatcherTests(NeoTestBase):
# If forgotten twice, it must raise a KeyError # If forgotten twice, it must raise a KeyError
self.assertRaises(KeyError, self.dispatcher.forget, conn, 1) self.assertRaises(KeyError, self.dispatcher.forget, conn, 1)
# Event arrives, return value must be True (it was expected) # Event arrives, return value must be True (it was expected)
self.assertTrue(self.dispatcher.dispatch(conn, 1, MARKER)) self.assertTrue(self.dispatcher.dispatch(conn, 1, MARKER, {}))
# ...but must not have reached the queue # ...but must not have reached the queue
self.assertTrue(queue.empty()) self.assertTrue(queue.empty())
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment