Commit 2f33456e authored by Benjamin Peterson's avatar Benjamin Peterson

fix sslwrap_simple (closes #22523)

Thanks Alex Gaynor.
parent 0d377b37
...@@ -969,16 +969,16 @@ def get_protocol_name(protocol_code): ...@@ -969,16 +969,16 @@ def get_protocol_name(protocol_code):
# a replacement for the old socket.ssl function # a replacement for the old socket.ssl function
def sslwrap_simple(sock, keyfile=None, certfile=None): def sslwrap_simple(sock, keyfile=None, certfile=None):
"""A replacement for the old socket.ssl function. Designed """A replacement for the old socket.ssl function. Designed
for compability with Python 2.5 and earlier. Will disappear in for compability with Python 2.5 and earlier. Will disappear in
Python 3.0.""" Python 3.0."""
if hasattr(sock, "_sock"): if hasattr(sock, "_sock"):
sock = sock._sock sock = sock._sock
ssl_sock = _ssl.sslwrap(sock, 0, keyfile, certfile, CERT_NONE, ctx = SSLContext(PROTOCOL_SSLv23)
PROTOCOL_SSLv23, None) if keyfile or certfile:
ctx.load_cert_chain(certfile, keyfile)
ssl_sock = ctx._wrap_socket(sock, server_side=False)
try: try:
sock.getpeername() sock.getpeername()
except socket_error: except socket_error:
......
...@@ -94,6 +94,8 @@ class BasicTests(unittest.TestCase): ...@@ -94,6 +94,8 @@ class BasicTests(unittest.TestCase):
pass pass
else: else:
raise raise
def can_clear_options(): def can_clear_options():
# 0.9.8m or higher # 0.9.8m or higher
return ssl._OPENSSL_API_VERSION >= (0, 9, 8, 13, 15) return ssl._OPENSSL_API_VERSION >= (0, 9, 8, 13, 15)
...@@ -2944,7 +2946,7 @@ def test_main(verbose=False): ...@@ -2944,7 +2946,7 @@ def test_main(verbose=False):
if not os.path.exists(filename): if not os.path.exists(filename):
raise support.TestFailed("Can't read certificate file %r" % filename) raise support.TestFailed("Can't read certificate file %r" % filename)
tests = [ContextTests, BasicSocketTests, SSLErrorTests] tests = [ContextTests, BasicTests, BasicSocketTests, SSLErrorTests]
if support.is_resource_enabled('network'): if support.is_resource_enabled('network'):
tests.append(NetworkedTests) tests.append(NetworkedTests)
......
...@@ -517,10 +517,12 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock, ...@@ -517,10 +517,12 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock,
self->socket_type = socket_type; self->socket_type = socket_type;
self->Socket = sock; self->Socket = sock;
Py_INCREF(self->Socket); Py_INCREF(self->Socket);
self->ssl_sock = PyWeakref_NewRef(ssl_sock, NULL); if (ssl_sock != Py_None) {
if (self->ssl_sock == NULL) { self->ssl_sock = PyWeakref_NewRef(ssl_sock, NULL);
Py_DECREF(self); if (self->ssl_sock == NULL) {
return NULL; Py_DECREF(self);
return NULL;
}
} }
return self; return self;
} }
...@@ -2931,8 +2933,12 @@ _servername_callback(SSL *s, int *al, void *args) ...@@ -2931,8 +2933,12 @@ _servername_callback(SSL *s, int *al, void *args)
ssl = SSL_get_app_data(s); ssl = SSL_get_app_data(s);
assert(PySSLSocket_Check(ssl)); assert(PySSLSocket_Check(ssl));
ssl_socket = PyWeakref_GetObject(ssl->ssl_sock); if (ssl->ssl_sock == NULL) {
Py_INCREF(ssl_socket); ssl_socket = Py_None;
} else {
ssl_socket = PyWeakref_GetObject(ssl->ssl_sock);
Py_INCREF(ssl_socket);
}
if (ssl_socket == Py_None) { if (ssl_socket == Py_None) {
goto error; goto error;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment