Commit fe487c07 authored by Julien Muchembled's avatar Julien Muchembled

ssl: fix handshaking connections being stuck when they're aborted

parent aaefaf8b
......@@ -433,9 +433,12 @@ class Connection(BaseConnection):
def abort(self):
"""Abort dealing with this connection."""
assert self.pending()
if self.connecting:
self.close()
return
logging.debug('aborting a connector for %r', self)
self.aborted = True
assert self.pending()
if self._on_close is not None:
self._on_close()
self._on_close = None
......
......@@ -278,8 +278,8 @@ class ServerNode(Node):
if not address:
address = self.newAddress()
if cluster is None:
master_nodes = kw['master_nodes']
name = kw['name']
master_nodes = kw.get('master_nodes', ())
name = kw.get('name', 'test')
else:
master_nodes = kw.get('master_nodes', cluster.master_nodes)
name = kw.get('name', cluster.name)
......@@ -292,7 +292,7 @@ class ServerNode(Node):
self.daemon = True
self.node_name = '%s_%u' % (self.node_type, port)
kw.update(getCluster=name, getBind=address,
getMasters=parseMasterList(master_nodes, address))
getMasters=master_nodes and parseMasterList(master_nodes, address))
super(ServerNode, self).__init__(Mock(kw))
def getVirtualAddress(self):
......
......@@ -15,8 +15,11 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from neo.lib.connection import ClientConnection, ListeningConnection
from neo.lib.handler import EventHandler
from neo.lib.protocol import Packets
from .. import SSL
from . import NEOCluster, test, testReplication
from . import MasterApplication, NEOCluster, test, testReplication
class SSLMixin:
......@@ -34,6 +37,30 @@ class SSLTests(SSLMixin, test.Test):
# exclude expected failures
testDeadlockAvoidance = testStorageFailureDuringTpcFinish = None
def testAbortConnection(self):
app = MasterApplication(getSSL=SSL, getReplicas=0, getPartitions=1)
handler = EventHandler(app)
app.listening_conn = ListeningConnection(app, handler, app.server)
node = app.nm.createMaster(address=app.listening_conn.getAddress(),
uuid=app.uuid)
for after_handshake in 1, 0:
conn = ClientConnection(app, handler, node)
conn.ask(Packets.Ping())
connector = conn.getConnector()
del connector.connect_limit[connector.addr]
app.em.poll(1)
self.assertTrue(isinstance(connector,
connector.SSLHandshakeConnectorClass))
self.assertNotIn(connector.getDescriptor(), app.em.writer_set)
if after_handshake:
while not isinstance(connector, connector.SSLConnectorClass):
app.em.poll(1)
conn.abort()
fd = connector.getDescriptor()
while fd in app.em.reader_set:
app.em.poll(1)
self.assertIs(conn.getConnector(), None)
class SSLReplicationTests(SSLMixin, testReplication.ReplicationTests):
# do not repeat slowest tests with SSL
testBackupNodeLost = testBackupNormalCase = None
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment