Commit 61dbc7b6 authored by Grégory Wisniewski's avatar Grégory Wisniewski

Per-connection timeout support (instead of per-packet).

- Rename IdleEvent to IdleTimeout from event.py to connection.py
- Move connection-related logic in Connection itself and keep only
time-related logic in IdleTimeout
- Clarify differences between hard and soft timeouts.
- Remove (unused) 'additional_timeout' from ask()
- Remove (now useless) event_dict attribute from Connection.
- Remove external ping support, as the answer can not be handled at
application level.
- Expectation after a new incoming connection moved from Handler to
Connection.
- Fix (and clean) related tests.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@1895 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 63f11ead
...@@ -15,11 +15,12 @@ ...@@ -15,11 +15,12 @@
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from time import time
from neo import logging from neo import logging
from neo.locking import RLock from neo.locking import RLock
from neo.protocol import PacketMalformedError, Packets from neo.protocol import PacketMalformedError, Packets
from neo.event import IdleEvent
from neo.connector import ConnectorException, ConnectorTryAgainException, \ from neo.connector import ConnectorException, ConnectorTryAgainException, \
ConnectorInProgressException, ConnectorConnectionRefusedException, \ ConnectorInProgressException, ConnectorConnectionRefusedException, \
ConnectorConnectionClosedException ConnectorConnectionClosedException
...@@ -28,6 +29,11 @@ from neo.logger import PACKET_LOGGER ...@@ -28,6 +29,11 @@ from neo.logger import PACKET_LOGGER
from neo import attributeTracker from neo import attributeTracker
PING_DELAY = 5
PING_TIMEOUT = 5
INCOMING_TIMEOUT = 10
CRITICAL_TIMEOUT = 30
APPLY_HANDLER = object() APPLY_HANDLER = object()
def not_closed(func): def not_closed(func):
...@@ -120,6 +126,34 @@ class HandlerSwitcher(object): ...@@ -120,6 +126,34 @@ class HandlerSwitcher(object):
self._pending.append([{}, handler]) self._pending.append([{}, handler])
class Timeout(object):
""" Keep track of current timeouts """
def __init__(self):
self._ping_time = None
self._critical_time = None
def update(self, t, timeout=CRITICAL_TIMEOUT):
""" Update the new critical time """
self._ping_time = t + PING_TIMEOUT
critical_time = self._ping_time + timeout
self._critical_time = max(critical_time, self._critical_time)
def refresh(self, t):
""" Refresh timeout after something received """
self._ping_time = t + PING_DELAY
def softExpired(self, t):
""" Indicate if the soft timeout (ping delay) is reached """
# hard timeout takes precedences
return self._ping_time < t < self._critical_time
def hardExpired(self, t):
""" Indicate if hard (or pong) timeout is reached """
# should be called if softExpired if False
return self._critical_time < t or self._ping_time < t
class BaseConnection(object): class BaseConnection(object):
"""A base connection.""" """A base connection."""
...@@ -129,8 +163,20 @@ class BaseConnection(object): ...@@ -129,8 +163,20 @@ class BaseConnection(object):
self.connector = connector self.connector = connector
self.addr = addr self.addr = addr
self._handlers = HandlerSwitcher(self, handler) self._handlers = HandlerSwitcher(self, handler)
self._timeout = Timeout()
event_manager.register(self) event_manager.register(self)
def checkTimeout(self, t):
if self._handlers.isPending():
if self._timeout.softExpired(t):
self._timeout.refresh(t)
self.ping()
elif self._timeout.hardExpired(t):
# critical time reach or pong not received, abort
logging.info('timeout with %s:%d', *(self.getAddress()))
self.close()
self.getHandler().timeoutExpired(self)
def lock(self): def lock(self):
return 1 return 1
...@@ -215,6 +261,8 @@ class ListeningConnection(BaseConnection): ...@@ -215,6 +261,8 @@ class ListeningConnection(BaseConnection):
handler = self.getHandler() handler = self.getHandler()
new_conn = ServerConnection(self.getEventManager(), handler, new_conn = ServerConnection(self.getEventManager(), handler,
connector=new_s, addr=addr) connector=new_s, addr=addr)
# A request for a node identification should arrive.
self._timeout.update(time(), timeout=INCOMING_TIMEOUT)
handler.connectionAccepted(new_conn) handler.connectionAccepted(new_conn)
except ConnectorTryAgainException: except ConnectorTryAgainException:
pass pass
...@@ -236,7 +284,6 @@ class Connection(BaseConnection): ...@@ -236,7 +284,6 @@ class Connection(BaseConnection):
self.write_buf = [] self.write_buf = []
self.cur_id = 0 self.cur_id = 0
self.peer_id = 0 self.peer_id = 0
self.event_dict = {}
self.aborted = False self.aborted = False
self.uuid = None self.uuid = None
self._queue = [] self._queue = []
...@@ -271,12 +318,9 @@ class Connection(BaseConnection): ...@@ -271,12 +318,9 @@ class Connection(BaseConnection):
logging.debug('closing a connector for %s (%s:%d)', logging.debug('closing a connector for %s (%s:%d)',
dump(self.uuid), *(self.addr)) dump(self.uuid), *(self.addr))
BaseConnection.close(self) BaseConnection.close(self)
for event in self.event_dict.itervalues():
self.em.removeIdleEvent(event)
if self._on_close is not None: if self._on_close is not None:
self._on_close() self._on_close()
self._on_close = None self._on_close = None
self.event_dict.clear()
del self.write_buf[:] del self.write_buf[:]
del self.read_buf[:] del self.read_buf[:]
self._handlers.clear() self._handlers.clear()
...@@ -320,24 +364,14 @@ class Connection(BaseConnection): ...@@ -320,24 +364,14 @@ class Connection(BaseConnection):
except PacketMalformedError, msg: except PacketMalformedError, msg:
self.getHandler()._packetMalformed(self, msg) self.getHandler()._packetMalformed(self, msg)
return return
self._timeout.refresh(time())
msg = msg[len(packet):] msg = msg[len(packet):]
packet_type = packet.getType() packet_type = packet.getType()
# Remove idle events, if appropriate packets were received.
for msg_id in (None, packet.getId()):
event = self.event_dict.pop(msg_id, None)
if event is not None:
if packet_type == Packets.Pong:
self.em.refreshIdleEvent(event)
self.event_dict[msg_id] = event
else:
self.em.removeIdleEvent(event)
if packet_type == Packets.Ping: if packet_type == Packets.Ping:
# Send a pong notification # Send a pong notification
self.answer(Packets.Pong(), packet.getId()) self.answer(Packets.Pong(), packet.getId())
elif packet_type != Packets.Pong: elif packet_type != Packets.Pong:
# Skip PONG packets, its only purpose is to drop IdleEvent # Skip PONG packets, its only purpose is refresh the timeout
# generated upong ping. # generated upong ping.
self._queue.append(packet) self._queue.append(packet)
self.read_buf = [msg] self.read_buf = [msg]
...@@ -434,33 +468,6 @@ class Connection(BaseConnection): ...@@ -434,33 +468,6 @@ class Connection(BaseConnection):
# enable polling for writing. # enable polling for writing.
self.em.addWriter(self) self.em.addWriter(self)
def expectMessage(self, msg_id=None, timeout=5, additional_timeout=30):
"""Expect a message for a reply to a given message ID or any message.
The purpose of this method is to define how much amount of time is
acceptable to wait for a message, thus to detect a down or broken
peer. This is important, because one error may halt a whole cluster
otherwise. Although TCP defines a keep-alive feature, the timeout
is too long generally, and it does not detect a certain type of reply,
thus it is better to probe problems at the application level.
The message ID specifies what ID is expected. Usually, this should
be identical with an ID for a request message. If it is None, any
message is acceptable, so it can be used to check idle time.
The timeout is the amount of time to wait until keep-alive messages start.
Once the timeout is expired, the connection starts to ping the peer.
The additional timeout defines the amount of time after the timeout
to invoke a timeoutExpired callback. If it is zero, no ping is sent, and
the callback is executed immediately."""
if self.connector is None:
return
event = IdleEvent(self, msg_id, timeout, additional_timeout)
self.event_dict[msg_id] = event
self.em.addIdleEvent(event)
@not_closed @not_closed
def notify(self, packet): def notify(self, packet):
""" Then a packet with a new ID """ """ Then a packet with a new ID """
...@@ -470,15 +477,15 @@ class Connection(BaseConnection): ...@@ -470,15 +477,15 @@ class Connection(BaseConnection):
return msg_id return msg_id
@not_closed @not_closed
def ask(self, packet, timeout=5, additional_timeout=30): def ask(self, packet, timeout=CRITICAL_TIMEOUT):
""" """
Send a packet with a new ID and register the expectation of an answer Send a packet with a new ID and register the expectation of an answer
""" """
msg_id = self._getNextId() msg_id = self._getNextId()
packet.setId(msg_id) packet.setId(msg_id)
self.expectMessage(msg_id, timeout=timeout,
additional_timeout=additional_timeout)
self._addPacket(packet) self._addPacket(packet)
if not self._handlers.isPending():
self._timeout.update(time(), timeout=timeout)
self._handlers.emit(packet) self._handlers.emit(packet)
return msg_id return msg_id
...@@ -491,13 +498,10 @@ class Connection(BaseConnection): ...@@ -491,13 +498,10 @@ class Connection(BaseConnection):
assert packet.isResponse(), packet assert packet.isResponse(), packet
self._addPacket(packet) self._addPacket(packet)
def ping(self, timeout=5, msg_id=None): @not_closed
""" Send a ping and expect to receive a pong notification """ def ping(self):
packet = Packets.Ping() packet = Packets.Ping()
if msg_id is None: packet.setId(self._getNextId())
msg_id = self._getNextId()
self.expectMessage(msg_id, timeout, 0)
packet.setId(msg_id)
self._addPacket(packet) self._addPacket(packet)
...@@ -582,21 +586,18 @@ class MTClientConnection(ClientConnection): ...@@ -582,21 +586,18 @@ class MTClientConnection(ClientConnection):
def analyse(self, *args, **kw): def analyse(self, *args, **kw):
return super(MTClientConnection, self).analyse(*args, **kw) return super(MTClientConnection, self).analyse(*args, **kw)
@lockCheckWrapper
def expectMessage(self, *args, **kw):
return super(MTClientConnection, self).expectMessage(*args, **kw)
@lockCheckWrapper @lockCheckWrapper
def notify(self, *args, **kw): def notify(self, *args, **kw):
return super(MTClientConnection, self).notify(*args, **kw) return super(MTClientConnection, self).notify(*args, **kw)
@lockCheckWrapper @lockCheckWrapper
def ask(self, queue, packet, timeout=5, additional_timeout=30): def ask(self, queue, packet, timeout=CRITICAL_TIMEOUT):
msg_id = self._getNextId() msg_id = self._getNextId()
packet.setId(msg_id) packet.setId(msg_id)
self.dispatcher.register(self, msg_id, queue) self.dispatcher.register(self, msg_id, queue)
self.expectMessage(msg_id)
self._addPacket(packet) self._addPacket(packet)
if not self._handlers.isPending():
self._timeout.update(time(), timeout=timeout)
self._handlers.emit(packet) self._handlers.emit(packet)
return msg_id return msg_id
...@@ -604,6 +605,10 @@ class MTClientConnection(ClientConnection): ...@@ -604,6 +605,10 @@ class MTClientConnection(ClientConnection):
def answer(self, *args, **kw): def answer(self, *args, **kw):
return super(MTClientConnection, self).answer(*args, **kw) return super(MTClientConnection, self).answer(*args, **kw)
@lockCheckWrapper
def checkTimeout(self, *args, **kw):
return super(MTClientConnection, self).checkTimeout(*args, **kw)
def close(self): def close(self):
self.lock() self.lock()
try: try:
...@@ -644,10 +649,6 @@ class MTServerConnection(ServerConnection): ...@@ -644,10 +649,6 @@ class MTServerConnection(ServerConnection):
def analyse(self, *args, **kw): def analyse(self, *args, **kw):
return super(MTServerConnection, self).analyse(*args, **kw) return super(MTServerConnection, self).analyse(*args, **kw)
@lockCheckWrapper
def expectMessage(self, *args, **kw):
return super(MTServerConnection, self).expectMessage(*args, **kw)
@lockCheckWrapper @lockCheckWrapper
def notify(self, *args, **kw): def notify(self, *args, **kw):
return super(MTServerConnection, self).notify(*args, **kw) return super(MTServerConnection, self).notify(*args, **kw)
...@@ -660,3 +661,7 @@ class MTServerConnection(ServerConnection): ...@@ -660,3 +661,7 @@ class MTServerConnection(ServerConnection):
def answer(self, *args, **kw): def answer(self, *args, **kw):
return super(MTServerConnection, self).answer(*args, **kw) return super(MTServerConnection, self).answer(*args, **kw)
@lockCheckWrapper
def checkTimeout(self, *args, **kw):
return super(MTServerConnection, self).checkTimeout(*args, **kw)
...@@ -15,74 +15,10 @@ ...@@ -15,74 +15,10 @@
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from neo import logging
from time import time from time import time
from neo.epoll import Epoll from neo.epoll import Epoll
PING_DELAY = 5
PING_TIMEOUT = 5
class IdleEvent(object):
"""
This class represents an event called when a connection is waiting for
a message too long.
"""
def __init__(self, conn, msg_id, timeout, additional_timeout):
self._conn = conn
self._id = msg_id
t = time()
self._time = t + timeout
self._critical_time = t + timeout + additional_timeout
self.refresh()
def getId(self):
return self._id
def getTime(self):
return self._time
def getCriticalTime(self):
return self._critical_time
def refresh(self):
self._next_critical_time = self._critical_time
def __call__(self, t):
conn = self._conn
if t > self._next_critical_time:
# No answer after _critical_time, close connection.
# This means that remote peer is processing the request for too
# long, although being responsive at network level.
logging.info('timeout for %r with %s:%d',
self._id, *(conn.getAddress()))
conn.lock()
try:
conn.close()
conn.getHandler().timeoutExpired(conn)
finally:
conn.unlock()
return True
elif t > self._time:
# Still no answer after _time, send a ping to see if connection is
# broken.
# XXX: This code has no meaning if the remote peer is single-
# threaded. Nevertheless, it should be kept in case it gets
# multithreaded, someday (master & storage are the only candidates
# for using this code, as other don't receive requests).
conn.lock()
try:
conn.ping(msg_id=self._id)
finally:
conn.unlock()
# Don't retry pinging after at least PING_DELAY seconds have
# passed.
self._time = t + PING_DELAY
self._next_critical_time = min(self._critical_time,
t + PING_TIMEOUT)
return False
class EpollEventManager(object): class EpollEventManager(object):
"""This class manages connections and events based on epoll(5).""" """This class manages connections and events based on epoll(5)."""
...@@ -90,7 +26,6 @@ class EpollEventManager(object): ...@@ -90,7 +26,6 @@ class EpollEventManager(object):
self.connection_dict = {} self.connection_dict = {}
self.reader_set = set([]) self.reader_set = set([])
self.writer_set = set([]) self.writer_set = set([])
self.event_list = []
self.prev_time = time() self.prev_time = time()
self.epoll = Epoll() self.epoll = Epoll()
self._pending_processing = [] self._pending_processing = []
...@@ -164,6 +99,7 @@ class EpollEventManager(object): ...@@ -164,6 +99,7 @@ class EpollEventManager(object):
self._addPendingConnection(to_process) self._addPendingConnection(to_process)
def _poll(self, timeout = 1): def _poll(self, timeout = 1):
assert timeout >= 0
rlist, wlist = self.epoll.poll(timeout) rlist, wlist = self.epoll.poll(timeout)
r_done_set = set() r_done_set = set()
for fd in rlist: for fd in rlist:
...@@ -196,32 +132,13 @@ class EpollEventManager(object): ...@@ -196,32 +132,13 @@ class EpollEventManager(object):
finally: finally:
conn.unlock() conn.unlock()
# Check idle events. Do not check them out too often, because this t = time()
# is somehow heavy. for conn in self.connection_dict.values():
event_list = self.event_list conn.lock()
if event_list: try:
t = time() conn.checkTimeout(t)
if t - self.prev_time >= 1: finally:
self.prev_time = t conn.unlock()
event_list.sort(key = lambda event: event.getTime())
while event_list:
event = event_list[0]
if event(t):
self.removeIdleEvent(event)
else:
break
def addIdleEvent(self, event):
self.event_list.append(event)
def removeIdleEvent(self, event):
try:
self.event_list.remove(event)
except ValueError:
pass
def refreshIdleEvent(self, event):
event.refresh()
def addReader(self, conn): def addReader(self, conn):
connector = conn.getConnector() connector = conn.getConnector()
......
...@@ -109,8 +109,6 @@ class EventHandler(object): ...@@ -109,8 +109,6 @@ class EventHandler(object):
def connectionAccepted(self, conn): def connectionAccepted(self, conn):
"""Called when a connection is accepted.""" """Called when a connection is accepted."""
# A request for a node identification should arrive.
conn.expectMessage(timeout = 10, additional_timeout = 0)
def timeoutExpired(self, conn): def timeoutExpired(self, conn):
"""Called when a timeout event occurs.""" """Called when a timeout event occurs."""
......
...@@ -30,10 +30,6 @@ from neo.connection import ClientConnection ...@@ -30,10 +30,6 @@ from neo.connection import ClientConnection
def _addPacket(self, packet): def _addPacket(self, packet):
if self.connector is not None: if self.connector is not None:
self.connector._addPacket(packet) self.connector._addPacket(packet)
def expectMessage(self, packet, timeout=5, additional_timeout=30):
if self.connector is not None:
self.connector.expectMessage(packet)
class MasterClientElectionTests(NeoTestBase): class MasterClientElectionTests(NeoTestBase):
...@@ -56,14 +52,11 @@ class MasterClientElectionTests(NeoTestBase): ...@@ -56,14 +52,11 @@ class MasterClientElectionTests(NeoTestBase):
self.master_port = 10011 self.master_port = 10011
# apply monkey patches # apply monkey patches
self._addPacket = ClientConnection._addPacket self._addPacket = ClientConnection._addPacket
self.expectMessage = ClientConnection.expectMessage
ClientConnection._addPacket = _addPacket ClientConnection._addPacket = _addPacket
ClientConnection.expectMessage = expectMessage
def tearDown(self): def tearDown(self):
# restore patched methods # restore patched methods
ClientConnection._addPacket = self._addPacket ClientConnection._addPacket = self._addPacket
ClientConnection.expectMessage = self.expectMessage
NeoTestBase.tearDown(self) NeoTestBase.tearDown(self)
def identifyToMasterNode(self): def identifyToMasterNode(self):
...@@ -220,15 +213,12 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -220,15 +213,12 @@ class MasterServerElectionTests(NeoTestBase):
self.master_address = ('127.0.0.1', 3000) self.master_address = ('127.0.0.1', 3000)
# apply monkey patches # apply monkey patches
self._addPacket = ClientConnection._addPacket self._addPacket = ClientConnection._addPacket
self.expectMessage = ClientConnection.expectMessage
ClientConnection._addPacket = _addPacket ClientConnection._addPacket = _addPacket
ClientConnection.expectMessage = expectMessage
def tearDown(self): def tearDown(self):
NeoTestBase.tearDown(self) NeoTestBase.tearDown(self)
# restore environnement # restore environnement
ClientConnection._addPacket = self._addPacket ClientConnection._addPacket = self._addPacket
ClientConnection.expectMessage = self.expectMessage
def identifyToMasterNode(self, uuid=True): def identifyToMasterNode(self, uuid=True):
node = self.app.nm.getMasterList()[0] node = self.app.nm.getMasterList()[0]
......
...@@ -15,10 +15,11 @@ ...@@ -15,10 +15,11 @@
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import unittest import unittest
from time import time
from mock import Mock from mock import Mock
from neo.connection import ListeningConnection, Connection, \ from neo.connection import ListeningConnection, Connection, \
ClientConnection, ServerConnection, MTClientConnection, \ ClientConnection, ServerConnection, MTClientConnection, \
MTServerConnection, HandlerSwitcher MTServerConnection, HandlerSwitcher, Timeout
from neo.connector import getConnectorHandler, registerConnectorHandler from neo.connector import getConnectorHandler, registerConnectorHandler
from neo.tests import DoNothingConnector from neo.tests import DoNothingConnector
from neo.connector import ConnectorException, ConnectorTryAgainException, \ from neo.connector import ConnectorException, ConnectorTryAgainException, \
...@@ -122,12 +123,6 @@ class ConnectionTests(NeoTestBase): ...@@ -122,12 +123,6 @@ class ConnectionTests(NeoTestBase):
self.assertEqual(len(calls), n) self.assertEqual(len(calls), n)
self.assertEqual(calls[n-1].getParam(0), self.address) self.assertEqual(calls[n-1].getParam(0), self.address)
def _checkAddIdleEvent(self, n=1):
self.assertEquals(len(self.em.mockGetNamedCalls("addIdleEvent")), n)
def _checkRemoveIdleEvent(self, n=1):
self.assertEquals(len(self.em.mockGetNamedCalls("removeIdleEvent")), n)
def _checkPacketReceived(self, n=1): def _checkPacketReceived(self, n=1):
calls = self.handler.mockGetNamedCalls('packetReceived') calls = self.handler.mockGetNamedCalls('packetReceived')
self.assertEquals(len(calls), n) self.assertEquals(len(calls), n)
...@@ -192,7 +187,6 @@ class ConnectionTests(NeoTestBase): ...@@ -192,7 +187,6 @@ class ConnectionTests(NeoTestBase):
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
self._checkWriteBuf(bc, '') self._checkWriteBuf(bc, '')
self.assertEqual(bc.cur_id, 0) self.assertEqual(bc.cur_id, 0)
self.assertEqual(bc.event_dict, {})
self.assertEqual(bc.aborted, False) self.assertEqual(bc.aborted, False)
# test uuid # test uuid
self.assertEqual(bc.uuid, None) self.assertEqual(bc.uuid, None)
...@@ -377,25 +371,14 @@ class ConnectionTests(NeoTestBase): ...@@ -377,25 +371,14 @@ class ConnectionTests(NeoTestBase):
self._checkWriteBuf(bc, 'testdata') self._checkWriteBuf(bc, 'testdata')
self._checkWriterAdded(1) self._checkWriterAdded(1)
def test_08_Connection_expectMessage(self):
# with a right connector -> event created
bc = self._makeConnection()
self.assertEqual(len(bc.event_dict), 0)
bc.expectMessage('1')
self.assertEqual(len(bc.event_dict), 1)
self._checkAddIdleEvent(1)
def test_Connection_analyse1(self): def test_Connection_analyse1(self):
# nothing to read, nothing is done # nothing to read, nothing is done
bc = self._makeConnection() bc = self._makeConnection()
bc._queue = Mock() bc._queue = Mock()
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
self.assertEqual(len(bc.event_dict), 0)
bc.analyse() bc.analyse()
self._checkRemoveIdleEvent(0)
self._checkPacketReceived(0) self._checkPacketReceived(0)
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
self.assertEqual(len(bc.event_dict), 0)
# give some data to analyse # give some data to analyse
master_list = ( master_list = (
...@@ -410,17 +393,14 @@ class ConnectionTests(NeoTestBase): ...@@ -410,17 +393,14 @@ class ConnectionTests(NeoTestBase):
p = Packets.AnswerPrimary(self.getNewUUID(), master_list) p = Packets.AnswerPrimary(self.getNewUUID(), master_list)
p.setId(1) p.setId(1)
bc.read_buf += p.encode() bc.read_buf += p.encode()
self.assertEqual(len(bc.event_dict), 0)
bc.analyse() bc.analyse()
# check packet decoded # check packet decoded
self._checkRemoveIdleEvent(0)
self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 1) self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 1)
call = bc._queue.mockGetNamedCalls("append")[0] call = bc._queue.mockGetNamedCalls("append")[0]
data = call.getParam(0) data = call.getParam(0)
self.assertEqual(data.getType(), p.getType()) self.assertEqual(data.getType(), p.getType())
self.assertEqual(data.getId(), p.getId()) self.assertEqual(data.getId(), p.getId())
self.assertEqual(data.decode(), p.decode()) self.assertEqual(data.decode(), p.decode())
self.assertEqual(len(bc.event_dict), 0)
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
def test_Connection_analyse2(self): def test_Connection_analyse2(self):
...@@ -454,10 +434,8 @@ class ConnectionTests(NeoTestBase): ...@@ -454,10 +434,8 @@ class ConnectionTests(NeoTestBase):
p2.setId(2) p2.setId(2)
bc.read_buf += p2.encode() bc.read_buf += p2.encode()
self.assertEqual(len(''.join(bc.read_buf)), len(p1) + len(p2)) self.assertEqual(len(''.join(bc.read_buf)), len(p1) + len(p2))
self.assertEqual(len(bc.event_dict), 0)
bc.analyse() bc.analyse()
# check two packets decoded # check two packets decoded
self._checkRemoveIdleEvent(0)
self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 2) self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 2)
# packet 1 # packet 1
call = bc._queue.mockGetNamedCalls("append")[0] call = bc._queue.mockGetNamedCalls("append")[0]
...@@ -471,7 +449,6 @@ class ConnectionTests(NeoTestBase): ...@@ -471,7 +449,6 @@ class ConnectionTests(NeoTestBase):
self.assertEqual(data.getType(), p2.getType()) self.assertEqual(data.getType(), p2.getType())
self.assertEqual(data.getId(), p2.getId()) self.assertEqual(data.getId(), p2.getId())
self.assertEqual(data.decode(), p2.decode()) self.assertEqual(data.decode(), p2.decode())
self.assertEqual(len(bc.event_dict), 0)
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
def test_Connection_analyse3(self): def test_Connection_analyse3(self):
...@@ -480,11 +457,9 @@ class ConnectionTests(NeoTestBase): ...@@ -480,11 +457,9 @@ class ConnectionTests(NeoTestBase):
bc._queue = Mock() bc._queue = Mock()
bc.read_buf += "datadatadatadata" bc.read_buf += "datadatadatadata"
self.assertEqual(len(bc.read_buf), 16) self.assertEqual(len(bc.read_buf), 16)
self.assertEqual(len(bc.event_dict), 0)
bc.analyse() bc.analyse()
self.assertEqual(len(bc.read_buf), 16) self.assertEqual(len(bc.read_buf), 16)
self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 0) self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 0)
self._checkRemoveIdleEvent(0)
def test_Connection_analyse4(self): def test_Connection_analyse4(self):
# give an expected packet # give an expected packet
...@@ -502,19 +477,14 @@ class ConnectionTests(NeoTestBase): ...@@ -502,19 +477,14 @@ class ConnectionTests(NeoTestBase):
p = Packets.AnswerPrimary(self.getNewUUID(), master_list) p = Packets.AnswerPrimary(self.getNewUUID(), master_list)
p.setId(1) p.setId(1)
bc.read_buf += p.encode() bc.read_buf += p.encode()
self.assertEqual(len(bc.event_dict), 0)
bc.expectMessage(1)
self.assertEqual(len(bc.event_dict), 1)
bc.analyse() bc.analyse()
# check packet decoded # check packet decoded
self._checkRemoveIdleEvent(1)
self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 1) self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 1)
call = bc._queue.mockGetNamedCalls("append")[0] call = bc._queue.mockGetNamedCalls("append")[0]
data = call.getParam(0) data = call.getParam(0)
self.assertEqual(data.getType(), p.getType()) self.assertEqual(data.getType(), p.getType())
self.assertEqual(data.getId(), p.getId()) self.assertEqual(data.getId(), p.getId())
self.assertEqual(data.decode(), p.decode()) self.assertEqual(data.decode(), p.decode())
self.assertEqual(len(bc.event_dict), 0)
self.assertEqual(''.join(bc.read_buf), '') self.assertEqual(''.join(bc.read_buf), '')
def test_Connection_writable1(self): def test_Connection_writable1(self):
...@@ -614,13 +584,11 @@ class ConnectionTests(NeoTestBase): ...@@ -614,13 +584,11 @@ class ConnectionTests(NeoTestBase):
bc.readable() bc.readable()
# check packet decoded # check packet decoded
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
self._checkRemoveIdleEvent(0)
self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 1) self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 1)
call = bc._queue.mockGetNamedCalls("append")[0] call = bc._queue.mockGetNamedCalls("append")[0]
data = call.getParam(0) data = call.getParam(0)
self.assertEqual(data.getType(), Packets.AnswerPrimary) self.assertEqual(data.getType(), Packets.AnswerPrimary)
self.assertEqual(data.getId(), 1) self.assertEqual(data.getId(), 1)
self.assertEqual(len(bc.event_dict), 0)
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
# check not aborted # check not aborted
self.assertFalse(bc.aborted) self.assertFalse(bc.aborted)
...@@ -763,7 +731,6 @@ class ConnectionTests(NeoTestBase): ...@@ -763,7 +731,6 @@ class ConnectionTests(NeoTestBase):
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
self._checkWriteBuf(bc, '') self._checkWriteBuf(bc, '')
self.assertEqual(bc.cur_id, 0) self.assertEqual(bc.cur_id, 0)
self.assertEqual(bc.event_dict, {})
self.assertEqual(bc.aborted, False) self.assertEqual(bc.aborted, False)
# test uuid # test uuid
self.assertEqual(bc.uuid, None) self.assertEqual(bc.uuid, None)
...@@ -941,5 +908,51 @@ class HandlerSwitcherTests(NeoTestBase): ...@@ -941,5 +908,51 @@ class HandlerSwitcherTests(NeoTestBase):
self.checkAborted(self._connection) self.checkAborted(self._connection)
class TestTimeout(NeoTestBase):
""" assume PING_DELAY=5 """
def setUp(self):
self.initial = time()
self.current = self.initial
self.timeout = Timeout()
def checkAfter(self, n, soft, hard):
at = self.current + n
self.assertEqual(soft, self.timeout.softExpired(at))
self.assertEqual(hard, self.timeout.hardExpired(at))
def refreshAfter(self, n):
self.current += n
self.timeout.refresh(self.current)
def testNoTimeout(self):
self.timeout.update(self.initial, 5)
self.checkAfter(1, False, False)
self.checkAfter(4, False, False)
self.refreshAfter(4) # answer received
self.checkAfter(1, False, False)
def testSoftTimeout(self):
self.timeout.update(self.initial, 5)
self.checkAfter(1, False, False)
self.checkAfter(4, False, False)
self.checkAfter(6, True, True) # ping
self.refreshAfter(8) # pong
self.checkAfter(1, False, False)
self.checkAfter(4, False, True)
def testHardTimeout(self):
self.timeout.update(self.initial, 5)
self.checkAfter(1, False, False)
self.checkAfter(4, False, False)
self.checkAfter(6, True, True) # ping
self.refreshAfter(6) # pong
self.checkAfter(1, False, False)
self.checkAfter(4, False, False)
self.checkAfter(6, False, True) # ping
self.refreshAfter(6) # pong
self.checkAfter(1, False, True) # too late
self.checkAfter(5, False, True)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -19,7 +19,7 @@ from mock import Mock ...@@ -19,7 +19,7 @@ from mock import Mock
from time import time from time import time
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
from neo.epoll import Epoll from neo.epoll import Epoll
from neo.event import EpollEventManager, IdleEvent from neo.event import EpollEventManager
class EventTests(NeoTestBase): class EventTests(NeoTestBase):
...@@ -35,7 +35,6 @@ class EventTests(NeoTestBase): ...@@ -35,7 +35,6 @@ class EventTests(NeoTestBase):
self.assertEqual(len(em.connection_dict), 0) self.assertEqual(len(em.connection_dict), 0)
self.assertEqual(len(em.reader_set), 0) self.assertEqual(len(em.reader_set), 0)
self.assertEqual(len(em.writer_set), 0) self.assertEqual(len(em.writer_set), 0)
self.assertEqual(len(em.event_list), 0)
self.assertTrue(em.prev_time <time) self.assertTrue(em.prev_time <time)
self.assertTrue(isinstance(em.epoll, Epoll)) self.assertTrue(isinstance(em.epoll, Epoll))
# use a mock object instead of epoll # use a mock object instead of epoll
...@@ -63,16 +62,6 @@ class EventTests(NeoTestBase): ...@@ -63,16 +62,6 @@ class EventTests(NeoTestBase):
self.assertEqual(data, 1014) self.assertEqual(data, 1014)
self.assertEqual(len(em.getConnectionList()), 0) self.assertEqual(len(em.getConnectionList()), 0)
# add/removeIdleEvent
event = Mock()
self.assertEqual(len(em.event_list), 0)
em.addIdleEvent(event)
self.assertEqual(len(em.event_list), 1)
em.removeIdleEvent(event)
self.assertEqual(len(em.event_list), 0)
em.removeIdleEvent(event) # must not fail
self.assertEqual(len(em.event_list), 0)
# add/removeReader # add/removeReader
connector = Mock({"getDescriptor" : 1515}) connector = Mock({"getDescriptor" : 1515})
conn = Mock({'getConnector': connector}) conn = Mock({'getConnector': connector})
...@@ -136,102 +125,6 @@ class EventTests(NeoTestBase): ...@@ -136,102 +125,6 @@ class EventTests(NeoTestBase):
#self.assertEquals(len(w_conn.mockGetNamedCalls("readable")), 0) #self.assertEquals(len(w_conn.mockGetNamedCalls("readable")), 0)
#self.assertEquals(len(w_conn.mockGetNamedCalls("writable")), 1) #self.assertEquals(len(w_conn.mockGetNamedCalls("writable")), 1)
def test_02_IdleEvent(self):
# test init
handler = Mock()
conn = Mock({"getAddress" : ("127.9.9.9", 135),
"getHandler" : handler})
event = IdleEvent(conn, 1, 10, 20)
self.assertEqual(event.getId(), 1)
self.assertNotEqual(event.getTime(), None)
time = event.getTime()
self.assertNotEqual(event.getCriticalTime(), None)
critical_time = event.getCriticalTime()
self.assertEqual(critical_time, time+20)
# call with t < time < critical_time
t = time - 10
r = event(t)
self.assertFalse(r)
self.assertEquals(len(conn.mockGetNamedCalls("lock")), 0)
self.assertEquals(len(conn.mockGetNamedCalls("getHandler")), 0)
self.assertEquals(len(conn.mockGetNamedCalls("close")), 0)
self.assertEquals(len(conn.mockGetNamedCalls("unlock")), 0)
self.checkNoPacketSent(conn)
self.assertEquals(len(handler.mockGetNamedCalls("timeoutExpired")), 0)
# call with time < t < critical_time
t = time + 5
self.assertTrue(t < critical_time)
r = event(t)
self.assertFalse(r)
self.assertEquals(len(conn.mockGetNamedCalls("lock")), 1)
self.assertEquals(len(conn.mockGetNamedCalls("getHandler")), 0)
self.assertEquals(len(conn.mockGetNamedCalls("close")), 0)
self.assertEquals(len(conn.mockGetNamedCalls("unlock")), 1)
self.assertEquals(len(conn.mockGetNamedCalls("ping")), 1)
self.assertEquals(len(handler.mockGetNamedCalls("timeoutExpired")), 0)
# call with time < critical_time < t
t = critical_time + 5
self.assertTrue(t > critical_time)
r = event(t)
self.assertTrue(r)
self.assertEquals(len(conn.mockGetNamedCalls("lock")), 2)
self.assertEquals(len(conn.mockGetNamedCalls("getHandler")), 1)
self.assertEquals(len(conn.mockGetNamedCalls("close")), 1)
self.assertEquals(len(conn.mockGetNamedCalls("unlock")), 2)
self.assertEquals(len(conn.mockGetNamedCalls("ping")), 1)
self.assertEquals(len(handler.mockGetNamedCalls("timeoutExpired")), 1)
# same test with additional time < 5
# test init
handler = Mock()
conn = Mock({"getAddress" : ("127.9.9.9", 135),
"getHandler" : handler})
event = IdleEvent(conn, 1, 10, 3)
self.assertEqual(event.getId(), 1)
self.assertNotEqual(event.getTime(), None)
time = event.getTime()
self.assertNotEqual(event.getCriticalTime(), None)
critical_time = event.getCriticalTime()
self.assertEqual(critical_time, time+3)
# call with t < time < critical_time
t = time - 10
r = event(t)
self.assertFalse(r)
self.assertEquals(len(conn.mockGetNamedCalls("lock")), 0)
self.assertEquals(len(conn.mockGetNamedCalls("getHandler")), 0)
self.assertEquals(len(conn.mockGetNamedCalls("close")), 0)
self.assertEquals(len(conn.mockGetNamedCalls("unlock")), 0)
self.checkNoPacketSent(conn)
self.assertEquals(len(handler.mockGetNamedCalls("timeoutExpired")), 0)
# call with time < t < critical_time
t = time + 1
self.assertTrue(t < critical_time)
r = event(t)
self.assertFalse(r)
self.assertEquals(len(conn.mockGetNamedCalls("lock")), 1)
self.assertEquals(len(conn.mockGetNamedCalls("getHandler")), 0)
self.assertEquals(len(conn.mockGetNamedCalls("close")), 0)
self.assertEquals(len(conn.mockGetNamedCalls("unlock")), 1)
self.checkNoPacketSent(conn)
self.assertEquals(len(handler.mockGetNamedCalls("timeoutExpired")), 0)
# call with time < critical_time < t
t = critical_time + 5
self.assertTrue(t > critical_time)
r = event(t)
self.assertTrue(r)
self.assertEquals(len(conn.mockGetNamedCalls("lock")), 2)
self.assertEquals(len(conn.mockGetNamedCalls("getHandler")), 1)
self.assertEquals(len(conn.mockGetNamedCalls("close")), 1)
self.assertEquals(len(conn.mockGetNamedCalls("unlock")), 2)
self.checkNoPacketSent(conn)
self.assertEquals(len(handler.mockGetNamedCalls("timeoutExpired")), 1)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment