Commit 36d83981 authored by Vincent Pelletier's avatar Vincent Pelletier

Make possible to call a function on timeout.

Such function is provided when message is queued for send when an answer
is expected, and is called when the answer is not arrived after expiration
of the timeout delay. Depending on the return value of this callback, the
timeout is ignored (True) or passed through (False).

Also, move code to refresh next timeout value to a separate function, for
reusability.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2106 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent f5deb386
......@@ -66,10 +66,23 @@ def lockCheckWrapper(func):
return func(self, *args, **kw)
return wrapper
class OnTimeout(object):
"""
Simple helper class for on_timeout parameter used in HandlerSwitcher
class.
"""
def __init__(self, func, *args, **kw):
self.func = func
self.args = args
self.kw = kw
def __call__(self, conn, msg_id):
return self.func(conn, msg_id, *self.args, **self.kw)
class HandlerSwitcher(object):
_next_timeout = None
_next_timeout_msg_id = None
_next_on_timeout = None
def __init__(self, connection, handler):
self._connection = connection
......@@ -88,7 +101,7 @@ class HandlerSwitcher(object):
return self._pending[0][1]
@profiler_decorator
def emit(self, request, timeout):
def emit(self, request, timeout, on_timeout):
# register the request in the current handler
_pending = self._pending
if self._is_handling:
......@@ -107,12 +120,26 @@ class HandlerSwitcher(object):
if next_timeout is None or timeout < next_timeout:
self._next_timeout = timeout
self._next_timeout_msg_id = msg_id
request_dict[msg_id] = (answer_class, timeout)
self._next_on_timeout = on_timeout
request_dict[msg_id] = (answer_class, timeout, on_timeout)
def checkTimeout(self, t):
next_timeout = self._next_timeout
if next_timeout is not None and next_timeout < t:
result = self._next_timeout_msg_id
msg_id = self._next_timeout_msg_id
if self._next_on_timeout is None:
result = msg_id
else:
if self._next_on_timeout(self._connection, msg_id):
# Don't notify that a timeout occured, and forget about
# this answer.
for (request_dict, _) in self._pending:
request_dict.pop(msg_id, None)
self._updateNextTimeout()
result = None
else:
# Notify that a timeout occured
result = msg_id
else:
result = None
return result
......@@ -136,7 +163,7 @@ class HandlerSwitcher(object):
handler.packetReceived(self._connection, packet)
return
# checkout the expected answer class
(klass, timeout) = request_dict.pop(msg_id, (None, None))
(klass, timeout, _) = request_dict.pop(msg_id, (None, None, None))
if klass and isinstance(packet, klass) or packet.isError():
handler.packetReceived(self._connection, packet)
else:
......@@ -151,17 +178,23 @@ class HandlerSwitcher(object):
logging.debug('Apply handler %r on %r', self._pending[0][1],
self._connection)
if timeout == self._next_timeout:
# Find next timeout and its msg_id
timeout_list = []
extend = timeout_list.extend
for (request_dict, handler) in self._pending:
extend(((timeout, msg_id) \
for msg_id, (_, timeout) in request_dict.iteritems()))
if timeout_list:
timeout_list.sort(key=lambda x: x[0])
self._next_timeout, self._next_timeout_msg_id = timeout_list[0]
else:
self._next_timeout, self._next_timeout_msg_id = None, None
self._updateNextTimeout()
def _updateNextTimeout(self):
# Find next timeout and its msg_id
timeout_list = []
extend = timeout_list.extend
for (request_dict, handler) in self._pending:
extend(((timeout, msg_id, on_timeout) \
for msg_id, (_, timeout, on_timeout) in \
request_dict.iteritems()))
if timeout_list:
timeout_list.sort(key=lambda x: x[0])
self._next_timeout, self._next_timeout_msg_id, \
self._next_on_timeout = timeout_list[0]
else:
self._next_timeout, self._next_timeout_msg_id, \
self._next_on_timeout = None, None, None
@profiler_decorator
def setHandler(self, handler):
......@@ -562,7 +595,7 @@ class Connection(BaseConnection):
@profiler_decorator
@not_closed
def ask(self, packet, timeout=CRITICAL_TIMEOUT):
def ask(self, packet, timeout=CRITICAL_TIMEOUT, on_timeout=None):
"""
Send a packet with a new ID and register the expectation of an answer
"""
......@@ -573,7 +606,7 @@ class Connection(BaseConnection):
# If there is no pending request, initialise timeout values.
if not self._handlers.isPending():
self._timeout.update(t, force=True)
self._handlers.emit(packet, t + timeout)
self._handlers.emit(packet, t + timeout, on_timeout)
return msg_id
@not_closed
......@@ -682,7 +715,7 @@ class MTClientConnection(ClientConnection):
self.unlock()
@profiler_decorator
def ask(self, packet, timeout=CRITICAL_TIMEOUT):
def ask(self, packet, timeout=CRITICAL_TIMEOUT, on_timeout=None):
self.lock()
try:
# XXX: Here, we duplicate Connection.ask because we need to call
......@@ -696,7 +729,7 @@ class MTClientConnection(ClientConnection):
# If there is no pending request, initialise timeout values.
if not self._handlers.isPending():
self._timeout.update(t)
self._handlers.emit(packet, t + timeout)
self._handlers.emit(packet, t + timeout, on_timeout)
return msg_id
finally:
self.unlock()
......
......@@ -19,7 +19,7 @@ from time import time
from mock import Mock
from neo.connection import ListeningConnection, Connection, \
ClientConnection, ServerConnection, MTClientConnection, \
HandlerSwitcher, Timeout, PING_DELAY, PING_TIMEOUT
HandlerSwitcher, Timeout, PING_DELAY, PING_TIMEOUT, OnTimeout
from neo.connector import getConnectorHandler, registerConnectorHandler
from neo.tests import DoNothingConnector
from neo.connector import ConnectorException, ConnectorTryAgainException, \
......@@ -854,7 +854,7 @@ class HandlerSwitcherTests(NeoTestBase):
# First case, emit is called outside of a handler
self.assertFalse(self._handlers.isPending())
request = self._makeRequest(1)
self._handlers.emit(request, 0)
self._handlers.emit(request, 0, None)
self.assertTrue(self._handlers.isPending())
# Second case, emit is called from inside a handler with a pending
# handler change.
......@@ -863,7 +863,7 @@ class HandlerSwitcherTests(NeoTestBase):
self._checkCurrentHandler(self._handler)
call_tracker = []
def packetReceived(conn, packet):
self._handlers.emit(self._makeRequest(2), 0)
self._handlers.emit(self._makeRequest(2), 0, None)
call_tracker.append(True)
self._handler.packetReceived = packetReceived
self._handlers.handle(self._makeAnswer(1))
......@@ -883,7 +883,7 @@ class HandlerSwitcherTests(NeoTestBase):
self._checkPacketReceived(self._handler, notif1)
# emit a request and delay an handler
request = self._makeRequest(2)
self._handlers.emit(request, 0)
self._handlers.emit(request, 0, None)
handler = self._makeHandler()
self._handlers.setHandler(handler)
# next notification fall into the current handler
......@@ -900,7 +900,7 @@ class HandlerSwitcherTests(NeoTestBase):
def testHandleAnswer1(self):
# handle with current handler
request = self._makeRequest(1)
self._handlers.emit(request, 0)
self._handlers.emit(request, 0, None)
answer = self._makeAnswer(1)
self._handlers.handle(answer)
self._checkPacketReceived(self._handler, answer)
......@@ -908,7 +908,7 @@ class HandlerSwitcherTests(NeoTestBase):
def testHandleAnswer2(self):
# handle with blocking handler
request = self._makeRequest(1)
self._handlers.emit(request, 0)
self._handlers.emit(request, 0, None)
handler = self._makeHandler()
self._handlers.setHandler(handler)
answer = self._makeAnswer(1)
......@@ -928,11 +928,11 @@ class HandlerSwitcherTests(NeoTestBase):
h2 = self._makeHandler()
h3 = self._makeHandler()
# emit all requests and setHandleres
self._handlers.emit(r1, 0)
self._handlers.emit(r1, 0, None)
self._handlers.setHandler(h1)
self._handlers.emit(r2, 0)
self._handlers.emit(r2, 0, None)
self._handlers.setHandler(h2)
self._handlers.emit(r3, 0)
self._handlers.emit(r3, 0, None)
self._handlers.setHandler(h3)
self._checkCurrentHandler(self._handler)
self.assertTrue(self._handlers.isPending())
......@@ -954,9 +954,9 @@ class HandlerSwitcherTests(NeoTestBase):
a3 = self._makeAnswer(3)
h = self._makeHandler()
# emit all requests
self._handlers.emit(r1, 0)
self._handlers.emit(r2, 0)
self._handlers.emit(r3, 0)
self._handlers.emit(r1, 0, None)
self._handlers.emit(r2, 0, None)
self._handlers.emit(r3, 0, None)
self._handlers.setHandler(h)
# process answers
self._handlers.handle(a1)
......@@ -973,9 +973,9 @@ class HandlerSwitcherTests(NeoTestBase):
a2 = self._makeAnswer(2)
h = self._makeHandler()
# emit requests aroung state setHandler
self._handlers.emit(r1, 0)
self._handlers.emit(r1, 0, None)
self._handlers.setHandler(h)
self._handlers.emit(r2, 0)
self._handlers.emit(r2, 0, None)
# process answer for next state
self._handlers.handle(a2)
self.checkAborted(self._connection)
......@@ -992,18 +992,29 @@ class HandlerSwitcherTests(NeoTestBase):
msg_id_1 = 1
msg_id_2 = 2
msg_id_3 = 3
msg_id_4 = 4
r1 = self._makeRequest(msg_id_1)
a1 = self._makeAnswer(msg_id_1)
r2 = self._makeRequest(msg_id_2)
a2 = self._makeAnswer(msg_id_2)
r3 = self._makeRequest(msg_id_3)
r4 = self._makeRequest(msg_id_4)
msg_1_time = now + 5
msg_2_time = msg_1_time + 5
msg_3_time = msg_2_time + 5
msg_4_time = msg_3_time + 5
markers = []
def msg_3_on_timeout(conn, msg_id):
markers.append((3, conn, msg_id))
return True
def msg_4_on_timeout(conn, msg_id):
markers.append((4, conn, msg_id))
return False
# Emit r3 before all other, to test that it's time parameter value
# which is used, not the registration order.
self._handlers.emit(r3, msg_3_time)
self._handlers.emit(r1, msg_1_time)
self._handlers.emit(r2, msg_2_time)
self._handlers.emit(r3, msg_3_time, OnTimeout(msg_3_on_timeout))
self._handlers.emit(r1, msg_1_time, None)
self._handlers.emit(r2, msg_2_time, None)
# No timeout before msg_1_time
self.assertEqual(self._handlers.checkTimeout(now), None)
# Timeout for msg_1 after msg_1_time
......@@ -1014,6 +1025,28 @@ class HandlerSwitcherTests(NeoTestBase):
self.assertEqual(self._handlers.checkTimeout(msg_1_time + 0.5), None)
# Next timeout is after msg_2_time
self.assertEqual(self._handlers.checkTimeout(msg_2_time + 0.5), msg_id_2)
self._handlers.handle(a2)
# Sanity check
self.assertEqual(self._handlers.checkTimeout(msg_2_time + 0.5), None)
# msg_3 timeout will fire msg_3_on_timeout callback, which causes the
# timeout to be ignored (it returns True)
self.assertEqual(self._handlers.checkTimeout(msg_3_time + 0.5), None)
# ...check that callback actually fired
self.assertEqual(len(markers), 1)
# ...with expected parameters
self.assertEqual(markers[0], (3, self._connection, msg_id_3))
# answer to msg_3 must not be expected anymore (and it was the last
# expected message)
self.assertFalse(bool(self._handlers.isPending()))
del markers[:]
self._handlers.emit(r4, msg_4_time, OnTimeout(msg_4_on_timeout))
# msg_4 timeout will fire msg_4_on_timeout callback, which lets the
# timeout be detected (it returns False)
self.assertEqual(self._handlers.checkTimeout(msg_4_time + 0.5), msg_id_4)
# ...check that callback actually fired
self.assertEqual(len(markers), 1)
# ...with expected parameters
self.assertEqual(markers[0], (4, self._connection, msg_id_4))
class TestTimeout(NeoTestBase):
def setUp(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment