Commit 06a64d80 authored by Julien Muchembled's avatar Julien Muchembled

client: fix spurious connection timeouts

This fixes a regression caused by
commit eef52c27
parent f180b00e
...@@ -422,12 +422,18 @@ class Connection(BaseConnection): ...@@ -422,12 +422,18 @@ class Connection(BaseConnection):
def onTimeout(self): def onTimeout(self):
handlers = self._handlers handlers = self._handlers
if handlers.isPending(): if handlers.isPending():
msg_id = handlers.timeout(self) # It is possible that another thread used ask() while getting a
if msg_id is None: # timeout from epoll, so we must check again the value of
self._next_timeout = time() + self._timeout # _next_timeout (we know that _queue is still empty).
else: # Although this test is only useful for MTClientConnection,
logging.info('timeout for #0x%08x with %r', msg_id, self) # it's not worth complicating the code more.
self.close() if self._next_timeout <= time():
msg_id = handlers.timeout(self)
if msg_id is None:
self._next_timeout = time() + self._timeout
else:
logging.info('timeout for #0x%08x with %r', msg_id, self)
self.close()
else: else:
self.idle() self.idle()
......
...@@ -30,8 +30,10 @@ import neo.admin.app, neo.master.app, neo.storage.app ...@@ -30,8 +30,10 @@ import neo.admin.app, neo.master.app, neo.storage.app
import neo.client.app, neo.neoctl.app import neo.client.app, neo.neoctl.app
from neo.client import Storage from neo.client import Storage
from neo.lib import logging from neo.lib import logging
from neo.lib.connection import BaseConnection, Connection from neo.lib.connection import BaseConnection, \
ClientConnection, Connection, ListeningConnection
from neo.lib.connector import SocketConnector, ConnectorException from neo.lib.connector import SocketConnector, ConnectorException
from neo.lib.handler import EventHandler
from neo.lib.locking import SimpleQueue from neo.lib.locking import SimpleQueue
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes
from neo.lib.util import cached_property, parseMasterList, p64 from neo.lib.util import cached_property, parseMasterList, p64
...@@ -829,6 +831,21 @@ class NEOThreadedTest(NeoTestBase): ...@@ -829,6 +831,21 @@ class NEOThreadedTest(NeoTestBase):
tic = Serialized.tic tic = Serialized.tic
def getLoopbackConnection(self):
app = MasterApplication(getSSL=NEOCluster.SSL,
getReplicas=0, getPartitions=1)
handler = EventHandler(app)
app.listening_conn = ListeningConnection(app, handler, app.server)
node = app.nm.createMaster(address=app.listening_conn.getAddress(),
uuid=app.uuid)
conn = ClientConnection.__new__(ClientConnection)
def reset():
conn.__dict__.clear()
conn.__init__(app, handler, node)
conn.reset = reset
reset()
return conn
def getUnpickler(self, conn): def getUnpickler(self, conn):
reader = conn._reader reader = conn._reader
def unpickler(data, compression=False): def unpickler(data, compression=False):
......
...@@ -1176,6 +1176,19 @@ class Test(NEOThreadedTest): ...@@ -1176,6 +1176,19 @@ class Test(NEOThreadedTest):
finally: finally:
cluster.stop() cluster.stop()
def testConnectionTimeout(self):
conn = self.getLoopbackConnection()
conn.KEEP_ALIVE
with Patch(conn, KEEP_ALIVE=0):
while conn.connecting:
conn.em.poll(1)
def onTimeout(orig):
conn.idle()
orig()
with Patch(conn, onTimeout=onTimeout):
conn.em.poll(1)
self.assertFalse(conn.isClosed())
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -253,6 +253,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -253,6 +253,7 @@ class ReplicationTests(NEOThreadedTest):
def _poll(orig, self, blocking): def _poll(orig, self, blocking):
if backup.master.em is self: if backup.master.em is self:
p.revert() p.revert()
conn._next_timeout = 0
conn.onTimeout() conn.onTimeout()
else: else:
orig(self, blocking) orig(self, blocking)
......
...@@ -15,11 +15,9 @@ ...@@ -15,11 +15,9 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from neo.lib.connection import ClientConnection, ListeningConnection
from neo.lib.handler import EventHandler
from neo.lib.protocol import Packets from neo.lib.protocol import Packets
from .. import SSL from .. import SSL
from . import MasterApplication, NEOCluster, test, testReplication from . import NEOCluster, test, testReplication
class SSLMixin: class SSLMixin:
...@@ -38,27 +36,25 @@ class SSLTests(SSLMixin, test.Test): ...@@ -38,27 +36,25 @@ class SSLTests(SSLMixin, test.Test):
testDeadlockAvoidance = testStorageFailureDuringTpcFinish = None testDeadlockAvoidance = testStorageFailureDuringTpcFinish = None
def testAbortConnection(self): def testAbortConnection(self):
app = MasterApplication(getSSL=SSL, getReplicas=0, getPartitions=1)
handler = EventHandler(app)
app.listening_conn = ListeningConnection(app, handler, app.server)
node = app.nm.createMaster(address=app.listening_conn.getAddress(),
uuid=app.uuid)
for after_handshake in 1, 0: for after_handshake in 1, 0:
conn = ClientConnection(app, handler, node) try:
conn.reset()
except UnboundLocalError:
conn = self.getLoopbackConnection()
conn.ask(Packets.Ping()) conn.ask(Packets.Ping())
connector = conn.getConnector() connector = conn.getConnector()
del connector.connect_limit[connector.addr] del connector.connect_limit[connector.addr]
app.em.poll(1) conn.em.poll(1)
self.assertTrue(isinstance(connector, self.assertTrue(isinstance(connector,
connector.SSLHandshakeConnectorClass)) connector.SSLHandshakeConnectorClass))
self.assertNotIn(connector.getDescriptor(), app.em.writer_set) self.assertNotIn(connector.getDescriptor(), conn.em.writer_set)
if after_handshake: if after_handshake:
while not isinstance(connector, connector.SSLConnectorClass): while not isinstance(connector, connector.SSLConnectorClass):
app.em.poll(1) conn.em.poll(1)
conn.abort() conn.abort()
fd = connector.getDescriptor() fd = connector.getDescriptor()
while fd in app.em.reader_set: while fd in conn.em.reader_set:
app.em.poll(1) conn.em.poll(1)
self.assertIs(conn.getConnector(), None) self.assertIs(conn.getConnector(), None)
class SSLReplicationTests(SSLMixin, testReplication.ReplicationTests): class SSLReplicationTests(SSLMixin, testReplication.ReplicationTests):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment