Commit 5b95eb90 authored by Antoine Pitrou's avatar Antoine Pitrou

Use context managers in test_ssl to simplify test writing.

parent 17c07134
...@@ -532,6 +532,14 @@ else: ...@@ -532,6 +532,14 @@ else:
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.daemon = True self.daemon = True
def __enter__(self):
self.start(threading.Event())
self.flag.wait()
def __exit__(self, *args):
self.stop()
self.join()
def start(self, flag=None): def start(self, flag=None):
self.flag = flag self.flag = flag
threading.Thread.start(self) threading.Thread.start(self)
...@@ -638,6 +646,20 @@ else: ...@@ -638,6 +646,20 @@ else:
def __str__(self): def __str__(self):
return "<%s %s>" % (self.__class__.__name__, self.server) return "<%s %s>" % (self.__class__.__name__, self.server)
def __enter__(self):
self.start(threading.Event())
self.flag.wait()
def __exit__(self, *args):
if test_support.verbose:
sys.stdout.write(" cleanup: stopping server.\n")
self.stop()
if test_support.verbose:
sys.stdout.write(" cleanup: joining server thread.\n")
self.join()
if test_support.verbose:
sys.stdout.write(" cleanup: successfully joined.\n")
def start(self, flag=None): def start(self, flag=None):
self.flag = flag self.flag = flag
threading.Thread.start(self) threading.Thread.start(self)
...@@ -752,12 +774,7 @@ else: ...@@ -752,12 +774,7 @@ else:
server = ThreadedEchoServer(CERTFILE, server = ThreadedEchoServer(CERTFILE,
certreqs=ssl.CERT_REQUIRED, certreqs=ssl.CERT_REQUIRED,
cacerts=CERTFILE, chatty=False) cacerts=CERTFILE, chatty=False)
flag = threading.Event() with server:
server.start(flag)
# wait for it to start
flag.wait()
# try to connect
try:
try: try:
s = ssl.wrap_socket(socket.socket(), s = ssl.wrap_socket(socket.socket(),
certfile=certfile, certfile=certfile,
...@@ -771,9 +788,6 @@ else: ...@@ -771,9 +788,6 @@ else:
sys.stdout.write("\nsocket.error is %s\n" % x[1]) sys.stdout.write("\nsocket.error is %s\n" % x[1])
else: else:
raise AssertionError("Use of invalid cert should have failed!") raise AssertionError("Use of invalid cert should have failed!")
finally:
server.stop()
server.join()
def server_params_test(certfile, protocol, certreqs, cacertsfile, def server_params_test(certfile, protocol, certreqs, cacertsfile,
client_certfile, client_protocol=None, indata="FOO\n", client_certfile, client_protocol=None, indata="FOO\n",
...@@ -791,14 +805,10 @@ else: ...@@ -791,14 +805,10 @@ else:
chatty=chatty, chatty=chatty,
connectionchatty=connectionchatty, connectionchatty=connectionchatty,
wrap_accepting_socket=wrap_accepting_socket) wrap_accepting_socket=wrap_accepting_socket)
flag = threading.Event() with server:
server.start(flag) # try to connect
# wait for it to start if client_protocol is None:
flag.wait() client_protocol = protocol
# try to connect
if client_protocol is None:
client_protocol = protocol
try:
s = ssl.wrap_socket(socket.socket(), s = ssl.wrap_socket(socket.socket(),
certfile=client_certfile, certfile=client_certfile,
ca_certs=cacertsfile, ca_certs=cacertsfile,
...@@ -826,9 +836,6 @@ else: ...@@ -826,9 +836,6 @@ else:
if test_support.verbose: if test_support.verbose:
sys.stdout.write(" client: closing connection.\n") sys.stdout.write(" client: closing connection.\n")
s.close() s.close()
finally:
server.stop()
server.join()
def try_protocol_combo(server_protocol, def try_protocol_combo(server_protocol,
client_protocol, client_protocol,
...@@ -930,12 +937,7 @@ else: ...@@ -930,12 +937,7 @@ else:
ssl_version=ssl.PROTOCOL_SSLv23, ssl_version=ssl.PROTOCOL_SSLv23,
cacerts=CERTFILE, cacerts=CERTFILE,
chatty=False) chatty=False)
flag = threading.Event() with server:
server.start(flag)
# wait for it to start
flag.wait()
# try to connect
try:
s = ssl.wrap_socket(socket.socket(), s = ssl.wrap_socket(socket.socket(),
certfile=CERTFILE, certfile=CERTFILE,
ca_certs=CERTFILE, ca_certs=CERTFILE,
...@@ -957,9 +959,6 @@ else: ...@@ -957,9 +959,6 @@ else:
"Missing or invalid 'organizationName' field in certificate subject; " "Missing or invalid 'organizationName' field in certificate subject; "
"should be 'Python Software Foundation'.") "should be 'Python Software Foundation'.")
s.close() s.close()
finally:
server.stop()
server.join()
def test_empty_cert(self): def test_empty_cert(self):
"""Connecting with an empty cert file""" """Connecting with an empty cert file"""
...@@ -1042,13 +1041,8 @@ else: ...@@ -1042,13 +1041,8 @@ else:
starttls_server=True, starttls_server=True,
chatty=True, chatty=True,
connectionchatty=True) connectionchatty=True)
flag = threading.Event()
server.start(flag)
# wait for it to start
flag.wait()
# try to connect
wrapped = False wrapped = False
try: with server:
s = socket.socket() s = socket.socket()
s.setblocking(1) s.setblocking(1)
s.connect((HOST, server.port)) s.connect((HOST, server.port))
...@@ -1093,9 +1087,6 @@ else: ...@@ -1093,9 +1087,6 @@ else:
else: else:
s.send("over\n") s.send("over\n")
s.close() s.close()
finally:
server.stop()
server.join()
def test_socketserver(self): def test_socketserver(self):
"""Using a SocketServer to create and manage SSL connections.""" """Using a SocketServer to create and manage SSL connections."""
...@@ -1145,12 +1136,7 @@ else: ...@@ -1145,12 +1136,7 @@ else:
if test_support.verbose: if test_support.verbose:
sys.stdout.write("\n") sys.stdout.write("\n")
server = AsyncoreEchoServer(CERTFILE) server = AsyncoreEchoServer(CERTFILE)
flag = threading.Event() with server:
server.start(flag)
# wait for it to start
flag.wait()
# try to connect
try:
s = ssl.wrap_socket(socket.socket()) s = ssl.wrap_socket(socket.socket())
s.connect(('127.0.0.1', server.port)) s.connect(('127.0.0.1', server.port))
if test_support.verbose: if test_support.verbose:
...@@ -1169,10 +1155,6 @@ else: ...@@ -1169,10 +1155,6 @@ else:
if test_support.verbose: if test_support.verbose:
sys.stdout.write(" client: closing connection.\n") sys.stdout.write(" client: closing connection.\n")
s.close() s.close()
finally:
server.stop()
# wait for server thread to end
server.join()
def test_recv_send(self): def test_recv_send(self):
"""Test recv(), send() and friends.""" """Test recv(), send() and friends."""
...@@ -1185,19 +1167,14 @@ else: ...@@ -1185,19 +1167,14 @@ else:
cacerts=CERTFILE, cacerts=CERTFILE,
chatty=True, chatty=True,
connectionchatty=False) connectionchatty=False)
flag = threading.Event() with server:
server.start(flag) s = ssl.wrap_socket(socket.socket(),
# wait for it to start server_side=False,
flag.wait() certfile=CERTFILE,
# try to connect ca_certs=CERTFILE,
s = ssl.wrap_socket(socket.socket(), cert_reqs=ssl.CERT_NONE,
server_side=False, ssl_version=ssl.PROTOCOL_TLSv1)
certfile=CERTFILE, s.connect((HOST, server.port))
ca_certs=CERTFILE,
cert_reqs=ssl.CERT_NONE,
ssl_version=ssl.PROTOCOL_TLSv1)
s.connect((HOST, server.port))
try:
# helper methods for standardising recv* method signatures # helper methods for standardising recv* method signatures
def _recv_into(): def _recv_into():
b = bytearray("\0"*100) b = bytearray("\0"*100)
...@@ -1285,9 +1262,6 @@ else: ...@@ -1285,9 +1262,6 @@ else:
s.write("over\n".encode("ASCII", "strict")) s.write("over\n".encode("ASCII", "strict"))
s.close() s.close()
finally:
server.stop()
server.join()
def test_handshake_timeout(self): def test_handshake_timeout(self):
# Issue #5103: SSL handshake must respect the socket timeout # Issue #5103: SSL handshake must respect the socket timeout
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment