Commit 36d83981 authored by Vincent Pelletier's avatar Vincent Pelletier

Make possible to call a function on timeout.

Such function is provided when message is queued for send when an answer
is expected, and is called when the answer is not arrived after expiration
of the timeout delay. Depending on the return value of this callback, the
timeout is ignored (True) or passed through (False).

Also, move code to refresh next timeout value to a separate function, for
reusability.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2106 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent f5deb386
...@@ -66,10 +66,23 @@ def lockCheckWrapper(func): ...@@ -66,10 +66,23 @@ def lockCheckWrapper(func):
return func(self, *args, **kw) return func(self, *args, **kw)
return wrapper return wrapper
class OnTimeout(object):
"""
Simple helper class for on_timeout parameter used in HandlerSwitcher
class.
"""
def __init__(self, func, *args, **kw):
self.func = func
self.args = args
self.kw = kw
def __call__(self, conn, msg_id):
return self.func(conn, msg_id, *self.args, **self.kw)
class HandlerSwitcher(object): class HandlerSwitcher(object):
_next_timeout = None _next_timeout = None
_next_timeout_msg_id = None _next_timeout_msg_id = None
_next_on_timeout = None
def __init__(self, connection, handler): def __init__(self, connection, handler):
self._connection = connection self._connection = connection
...@@ -88,7 +101,7 @@ class HandlerSwitcher(object): ...@@ -88,7 +101,7 @@ class HandlerSwitcher(object):
return self._pending[0][1] return self._pending[0][1]
@profiler_decorator @profiler_decorator
def emit(self, request, timeout): def emit(self, request, timeout, on_timeout):
# register the request in the current handler # register the request in the current handler
_pending = self._pending _pending = self._pending
if self._is_handling: if self._is_handling:
...@@ -107,12 +120,26 @@ class HandlerSwitcher(object): ...@@ -107,12 +120,26 @@ class HandlerSwitcher(object):
if next_timeout is None or timeout < next_timeout: if next_timeout is None or timeout < next_timeout:
self._next_timeout = timeout self._next_timeout = timeout
self._next_timeout_msg_id = msg_id self._next_timeout_msg_id = msg_id
request_dict[msg_id] = (answer_class, timeout) self._next_on_timeout = on_timeout
request_dict[msg_id] = (answer_class, timeout, on_timeout)
def checkTimeout(self, t): def checkTimeout(self, t):
next_timeout = self._next_timeout next_timeout = self._next_timeout
if next_timeout is not None and next_timeout < t: if next_timeout is not None and next_timeout < t:
result = self._next_timeout_msg_id msg_id = self._next_timeout_msg_id
if self._next_on_timeout is None:
result = msg_id
else:
if self._next_on_timeout(self._connection, msg_id):
# Don't notify that a timeout occured, and forget about
# this answer.
for (request_dict, _) in self._pending:
request_dict.pop(msg_id, None)
self._updateNextTimeout()
result = None
else:
# Notify that a timeout occured
result = msg_id
else: else:
result = None result = None
return result return result
...@@ -136,7 +163,7 @@ class HandlerSwitcher(object): ...@@ -136,7 +163,7 @@ class HandlerSwitcher(object):
handler.packetReceived(self._connection, packet) handler.packetReceived(self._connection, packet)
return return
# checkout the expected answer class # checkout the expected answer class
(klass, timeout) = request_dict.pop(msg_id, (None, None)) (klass, timeout, _) = request_dict.pop(msg_id, (None, None, None))
if klass and isinstance(packet, klass) or packet.isError(): if klass and isinstance(packet, klass) or packet.isError():
handler.packetReceived(self._connection, packet) handler.packetReceived(self._connection, packet)
else: else:
...@@ -151,17 +178,23 @@ class HandlerSwitcher(object): ...@@ -151,17 +178,23 @@ class HandlerSwitcher(object):
logging.debug('Apply handler %r on %r', self._pending[0][1], logging.debug('Apply handler %r on %r', self._pending[0][1],
self._connection) self._connection)
if timeout == self._next_timeout: if timeout == self._next_timeout:
self._updateNextTimeout()
def _updateNextTimeout(self):
# Find next timeout and its msg_id # Find next timeout and its msg_id
timeout_list = [] timeout_list = []
extend = timeout_list.extend extend = timeout_list.extend
for (request_dict, handler) in self._pending: for (request_dict, handler) in self._pending:
extend(((timeout, msg_id) \ extend(((timeout, msg_id, on_timeout) \
for msg_id, (_, timeout) in request_dict.iteritems())) for msg_id, (_, timeout, on_timeout) in \
request_dict.iteritems()))
if timeout_list: if timeout_list:
timeout_list.sort(key=lambda x: x[0]) timeout_list.sort(key=lambda x: x[0])
self._next_timeout, self._next_timeout_msg_id = timeout_list[0] self._next_timeout, self._next_timeout_msg_id, \
self._next_on_timeout = timeout_list[0]
else: else:
self._next_timeout, self._next_timeout_msg_id = None, None self._next_timeout, self._next_timeout_msg_id, \
self._next_on_timeout = None, None, None
@profiler_decorator @profiler_decorator
def setHandler(self, handler): def setHandler(self, handler):
...@@ -562,7 +595,7 @@ class Connection(BaseConnection): ...@@ -562,7 +595,7 @@ class Connection(BaseConnection):
@profiler_decorator @profiler_decorator
@not_closed @not_closed
def ask(self, packet, timeout=CRITICAL_TIMEOUT): def ask(self, packet, timeout=CRITICAL_TIMEOUT, on_timeout=None):
""" """
Send a packet with a new ID and register the expectation of an answer Send a packet with a new ID and register the expectation of an answer
""" """
...@@ -573,7 +606,7 @@ class Connection(BaseConnection): ...@@ -573,7 +606,7 @@ class Connection(BaseConnection):
# If there is no pending request, initialise timeout values. # If there is no pending request, initialise timeout values.
if not self._handlers.isPending(): if not self._handlers.isPending():
self._timeout.update(t, force=True) self._timeout.update(t, force=True)
self._handlers.emit(packet, t + timeout) self._handlers.emit(packet, t + timeout, on_timeout)
return msg_id return msg_id
@not_closed @not_closed
...@@ -682,7 +715,7 @@ class MTClientConnection(ClientConnection): ...@@ -682,7 +715,7 @@ class MTClientConnection(ClientConnection):
self.unlock() self.unlock()
@profiler_decorator @profiler_decorator
def ask(self, packet, timeout=CRITICAL_TIMEOUT): def ask(self, packet, timeout=CRITICAL_TIMEOUT, on_timeout=None):
self.lock() self.lock()
try: try:
# XXX: Here, we duplicate Connection.ask because we need to call # XXX: Here, we duplicate Connection.ask because we need to call
...@@ -696,7 +729,7 @@ class MTClientConnection(ClientConnection): ...@@ -696,7 +729,7 @@ class MTClientConnection(ClientConnection):
# If there is no pending request, initialise timeout values. # If there is no pending request, initialise timeout values.
if not self._handlers.isPending(): if not self._handlers.isPending():
self._timeout.update(t) self._timeout.update(t)
self._handlers.emit(packet, t + timeout) self._handlers.emit(packet, t + timeout, on_timeout)
return msg_id return msg_id
finally: finally:
self.unlock() self.unlock()
......
...@@ -19,7 +19,7 @@ from time import time ...@@ -19,7 +19,7 @@ from time import time
from mock import Mock from mock import Mock
from neo.connection import ListeningConnection, Connection, \ from neo.connection import ListeningConnection, Connection, \
ClientConnection, ServerConnection, MTClientConnection, \ ClientConnection, ServerConnection, MTClientConnection, \
HandlerSwitcher, Timeout, PING_DELAY, PING_TIMEOUT HandlerSwitcher, Timeout, PING_DELAY, PING_TIMEOUT, OnTimeout
from neo.connector import getConnectorHandler, registerConnectorHandler from neo.connector import getConnectorHandler, registerConnectorHandler
from neo.tests import DoNothingConnector from neo.tests import DoNothingConnector
from neo.connector import ConnectorException, ConnectorTryAgainException, \ from neo.connector import ConnectorException, ConnectorTryAgainException, \
...@@ -854,7 +854,7 @@ class HandlerSwitcherTests(NeoTestBase): ...@@ -854,7 +854,7 @@ class HandlerSwitcherTests(NeoTestBase):
# First case, emit is called outside of a handler # First case, emit is called outside of a handler
self.assertFalse(self._handlers.isPending()) self.assertFalse(self._handlers.isPending())
request = self._makeRequest(1) request = self._makeRequest(1)
self._handlers.emit(request, 0) self._handlers.emit(request, 0, None)
self.assertTrue(self._handlers.isPending()) self.assertTrue(self._handlers.isPending())
# Second case, emit is called from inside a handler with a pending # Second case, emit is called from inside a handler with a pending
# handler change. # handler change.
...@@ -863,7 +863,7 @@ class HandlerSwitcherTests(NeoTestBase): ...@@ -863,7 +863,7 @@ class HandlerSwitcherTests(NeoTestBase):
self._checkCurrentHandler(self._handler) self._checkCurrentHandler(self._handler)
call_tracker = [] call_tracker = []
def packetReceived(conn, packet): def packetReceived(conn, packet):
self._handlers.emit(self._makeRequest(2), 0) self._handlers.emit(self._makeRequest(2), 0, None)
call_tracker.append(True) call_tracker.append(True)
self._handler.packetReceived = packetReceived self._handler.packetReceived = packetReceived
self._handlers.handle(self._makeAnswer(1)) self._handlers.handle(self._makeAnswer(1))
...@@ -883,7 +883,7 @@ class HandlerSwitcherTests(NeoTestBase): ...@@ -883,7 +883,7 @@ class HandlerSwitcherTests(NeoTestBase):
self._checkPacketReceived(self._handler, notif1) self._checkPacketReceived(self._handler, notif1)
# emit a request and delay an handler # emit a request and delay an handler
request = self._makeRequest(2) request = self._makeRequest(2)
self._handlers.emit(request, 0) self._handlers.emit(request, 0, None)
handler = self._makeHandler() handler = self._makeHandler()
self._handlers.setHandler(handler) self._handlers.setHandler(handler)
# next notification fall into the current handler # next notification fall into the current handler
...@@ -900,7 +900,7 @@ class HandlerSwitcherTests(NeoTestBase): ...@@ -900,7 +900,7 @@ class HandlerSwitcherTests(NeoTestBase):
def testHandleAnswer1(self): def testHandleAnswer1(self):
# handle with current handler # handle with current handler
request = self._makeRequest(1) request = self._makeRequest(1)
self._handlers.emit(request, 0) self._handlers.emit(request, 0, None)
answer = self._makeAnswer(1) answer = self._makeAnswer(1)
self._handlers.handle(answer) self._handlers.handle(answer)
self._checkPacketReceived(self._handler, answer) self._checkPacketReceived(self._handler, answer)
...@@ -908,7 +908,7 @@ class HandlerSwitcherTests(NeoTestBase): ...@@ -908,7 +908,7 @@ class HandlerSwitcherTests(NeoTestBase):
def testHandleAnswer2(self): def testHandleAnswer2(self):
# handle with blocking handler # handle with blocking handler
request = self._makeRequest(1) request = self._makeRequest(1)
self._handlers.emit(request, 0) self._handlers.emit(request, 0, None)
handler = self._makeHandler() handler = self._makeHandler()
self._handlers.setHandler(handler) self._handlers.setHandler(handler)
answer = self._makeAnswer(1) answer = self._makeAnswer(1)
...@@ -928,11 +928,11 @@ class HandlerSwitcherTests(NeoTestBase): ...@@ -928,11 +928,11 @@ class HandlerSwitcherTests(NeoTestBase):
h2 = self._makeHandler() h2 = self._makeHandler()
h3 = self._makeHandler() h3 = self._makeHandler()
# emit all requests and setHandleres # emit all requests and setHandleres
self._handlers.emit(r1, 0) self._handlers.emit(r1, 0, None)
self._handlers.setHandler(h1) self._handlers.setHandler(h1)
self._handlers.emit(r2, 0) self._handlers.emit(r2, 0, None)
self._handlers.setHandler(h2) self._handlers.setHandler(h2)
self._handlers.emit(r3, 0) self._handlers.emit(r3, 0, None)
self._handlers.setHandler(h3) self._handlers.setHandler(h3)
self._checkCurrentHandler(self._handler) self._checkCurrentHandler(self._handler)
self.assertTrue(self._handlers.isPending()) self.assertTrue(self._handlers.isPending())
...@@ -954,9 +954,9 @@ class HandlerSwitcherTests(NeoTestBase): ...@@ -954,9 +954,9 @@ class HandlerSwitcherTests(NeoTestBase):
a3 = self._makeAnswer(3) a3 = self._makeAnswer(3)
h = self._makeHandler() h = self._makeHandler()
# emit all requests # emit all requests
self._handlers.emit(r1, 0) self._handlers.emit(r1, 0, None)
self._handlers.emit(r2, 0) self._handlers.emit(r2, 0, None)
self._handlers.emit(r3, 0) self._handlers.emit(r3, 0, None)
self._handlers.setHandler(h) self._handlers.setHandler(h)
# process answers # process answers
self._handlers.handle(a1) self._handlers.handle(a1)
...@@ -973,9 +973,9 @@ class HandlerSwitcherTests(NeoTestBase): ...@@ -973,9 +973,9 @@ class HandlerSwitcherTests(NeoTestBase):
a2 = self._makeAnswer(2) a2 = self._makeAnswer(2)
h = self._makeHandler() h = self._makeHandler()
# emit requests aroung state setHandler # emit requests aroung state setHandler
self._handlers.emit(r1, 0) self._handlers.emit(r1, 0, None)
self._handlers.setHandler(h) self._handlers.setHandler(h)
self._handlers.emit(r2, 0) self._handlers.emit(r2, 0, None)
# process answer for next state # process answer for next state
self._handlers.handle(a2) self._handlers.handle(a2)
self.checkAborted(self._connection) self.checkAborted(self._connection)
...@@ -992,18 +992,29 @@ class HandlerSwitcherTests(NeoTestBase): ...@@ -992,18 +992,29 @@ class HandlerSwitcherTests(NeoTestBase):
msg_id_1 = 1 msg_id_1 = 1
msg_id_2 = 2 msg_id_2 = 2
msg_id_3 = 3 msg_id_3 = 3
msg_id_4 = 4
r1 = self._makeRequest(msg_id_1) r1 = self._makeRequest(msg_id_1)
a1 = self._makeAnswer(msg_id_1) a1 = self._makeAnswer(msg_id_1)
r2 = self._makeRequest(msg_id_2) r2 = self._makeRequest(msg_id_2)
a2 = self._makeAnswer(msg_id_2)
r3 = self._makeRequest(msg_id_3) r3 = self._makeRequest(msg_id_3)
r4 = self._makeRequest(msg_id_4)
msg_1_time = now + 5 msg_1_time = now + 5
msg_2_time = msg_1_time + 5 msg_2_time = msg_1_time + 5
msg_3_time = msg_2_time + 5 msg_3_time = msg_2_time + 5
msg_4_time = msg_3_time + 5
markers = []
def msg_3_on_timeout(conn, msg_id):
markers.append((3, conn, msg_id))
return True
def msg_4_on_timeout(conn, msg_id):
markers.append((4, conn, msg_id))
return False
# Emit r3 before all other, to test that it's time parameter value # Emit r3 before all other, to test that it's time parameter value
# which is used, not the registration order. # which is used, not the registration order.
self._handlers.emit(r3, msg_3_time) self._handlers.emit(r3, msg_3_time, OnTimeout(msg_3_on_timeout))
self._handlers.emit(r1, msg_1_time) self._handlers.emit(r1, msg_1_time, None)
self._handlers.emit(r2, msg_2_time) self._handlers.emit(r2, msg_2_time, None)
# No timeout before msg_1_time # No timeout before msg_1_time
self.assertEqual(self._handlers.checkTimeout(now), None) self.assertEqual(self._handlers.checkTimeout(now), None)
# Timeout for msg_1 after msg_1_time # Timeout for msg_1 after msg_1_time
...@@ -1014,6 +1025,28 @@ class HandlerSwitcherTests(NeoTestBase): ...@@ -1014,6 +1025,28 @@ class HandlerSwitcherTests(NeoTestBase):
self.assertEqual(self._handlers.checkTimeout(msg_1_time + 0.5), None) self.assertEqual(self._handlers.checkTimeout(msg_1_time + 0.5), None)
# Next timeout is after msg_2_time # Next timeout is after msg_2_time
self.assertEqual(self._handlers.checkTimeout(msg_2_time + 0.5), msg_id_2) self.assertEqual(self._handlers.checkTimeout(msg_2_time + 0.5), msg_id_2)
self._handlers.handle(a2)
# Sanity check
self.assertEqual(self._handlers.checkTimeout(msg_2_time + 0.5), None)
# msg_3 timeout will fire msg_3_on_timeout callback, which causes the
# timeout to be ignored (it returns True)
self.assertEqual(self._handlers.checkTimeout(msg_3_time + 0.5), None)
# ...check that callback actually fired
self.assertEqual(len(markers), 1)
# ...with expected parameters
self.assertEqual(markers[0], (3, self._connection, msg_id_3))
# answer to msg_3 must not be expected anymore (and it was the last
# expected message)
self.assertFalse(bool(self._handlers.isPending()))
del markers[:]
self._handlers.emit(r4, msg_4_time, OnTimeout(msg_4_on_timeout))
# msg_4 timeout will fire msg_4_on_timeout callback, which lets the
# timeout be detected (it returns False)
self.assertEqual(self._handlers.checkTimeout(msg_4_time + 0.5), msg_id_4)
# ...check that callback actually fired
self.assertEqual(len(markers), 1)
# ...with expected parameters
self.assertEqual(markers[0], (4, self._connection, msg_id_4))
class TestTimeout(NeoTestBase): class TestTimeout(NeoTestBase):
def setUp(self): def setUp(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment