Commit 23cdfd26 authored by Bill Janssen's avatar Bill Janssen

This contains a number of things:

1) Improve the documentation of the SSL module, with a fuller
   explanation of certificate usage, another reference, proper
   formatting of this and that.

2) Fix Windows bug in ssl.py, and general bug in sslsocket.close().
   Remove some unused code from ssl.py.  Allow accept() to be called on
   sslsocket sockets.

3) Use try-except-else in import of ssl in socket.py.  Deprecate use of
   socket.ssl().

4) Remove use of socket.ssl() in every library module, except for
   test_socket_ssl.py and test_ssl.py.
parent 690ebdd8
This diff is collapsed.
...@@ -940,205 +940,6 @@ class HTTPConnection: ...@@ -940,205 +940,6 @@ class HTTPConnection:
return response return response
# The next several classes are used to define FakeSocket, a socket-like
# interface to an SSL connection.
# The primary complexity comes from faking a makefile() method. The
# standard socket makefile() implementation calls dup() on the socket
# file descriptor. As a consequence, clients can call close() on the
# parent socket and its makefile children in any order. The underlying
# socket isn't closed until they are all closed.
# The implementation uses reference counting to keep the socket open
# until the last client calls close(). SharedSocket keeps track of
# the reference counting and SharedSocketClient provides an constructor
# and close() method that call incref() and decref() correctly.
class SharedSocket:
def __init__(self, sock):
self.sock = sock
self._refcnt = 0
def incref(self):
self._refcnt += 1
def decref(self):
self._refcnt -= 1
assert self._refcnt >= 0
if self._refcnt == 0:
self.sock.close()
def __del__(self):
self.sock.close()
class SharedSocketClient:
def __init__(self, shared):
self._closed = 0
self._shared = shared
self._shared.incref()
self._sock = shared.sock
def close(self):
if not self._closed:
self._shared.decref()
self._closed = 1
self._shared = None
class SSLFile(SharedSocketClient):
"""File-like object wrapping an SSL socket."""
BUFSIZE = 8192
def __init__(self, sock, ssl, bufsize=None):
SharedSocketClient.__init__(self, sock)
self._ssl = ssl
self._buf = ''
self._bufsize = bufsize or self.__class__.BUFSIZE
def _read(self):
buf = ''
# put in a loop so that we retry on transient errors
while True:
try:
buf = self._ssl.read(self._bufsize)
except socket.sslerror, err:
if (err[0] == socket.SSL_ERROR_WANT_READ
or err[0] == socket.SSL_ERROR_WANT_WRITE):
continue
if (err[0] == socket.SSL_ERROR_ZERO_RETURN
or err[0] == socket.SSL_ERROR_EOF):
break
raise
except socket.error, err:
if err[0] == errno.EINTR:
continue
if err[0] == errno.EBADF:
# XXX socket was closed?
break
raise
else:
break
return buf
def read(self, size=None):
L = [self._buf]
avail = len(self._buf)
while size is None or avail < size:
s = self._read()
if s == '':
break
L.append(s)
avail += len(s)
all = "".join(L)
if size is None:
self._buf = ''
return all
else:
self._buf = all[size:]
return all[:size]
def readline(self):
L = [self._buf]
self._buf = ''
while 1:
i = L[-1].find("\n")
if i >= 0:
break
s = self._read()
if s == '':
break
L.append(s)
if i == -1:
# loop exited because there is no more data
return "".join(L)
else:
all = "".join(L)
# XXX could do enough bookkeeping not to do a 2nd search
i = all.find("\n") + 1
line = all[:i]
self._buf = all[i:]
return line
def readlines(self, sizehint=0):
total = 0
list = []
while True:
line = self.readline()
if not line:
break
list.append(line)
total += len(line)
if sizehint and total >= sizehint:
break
return list
def fileno(self):
return self._sock.fileno()
def __iter__(self):
return self
def next(self):
line = self.readline()
if not line:
raise StopIteration
return line
class FakeSocket(SharedSocketClient):
class _closedsocket:
def __getattr__(self, name):
raise error(9, 'Bad file descriptor')
def __init__(self, sock, ssl):
sock = SharedSocket(sock)
SharedSocketClient.__init__(self, sock)
self._ssl = ssl
def close(self):
SharedSocketClient.close(self)
self._sock = self.__class__._closedsocket()
def makefile(self, mode, bufsize=None):
if mode != 'r' and mode != 'rb':
raise UnimplementedFileMode()
return SSLFile(self._shared, self._ssl, bufsize)
def send(self, stuff, flags = 0):
return self._ssl.write(stuff)
sendall = send
def recv(self, len = 1024, flags = 0):
return self._ssl.read(len)
def __getattr__(self, attr):
return getattr(self._sock, attr)
def close(self):
SharedSocketClient.close(self)
self._ssl = None
class HTTPSConnection(HTTPConnection):
"This class allows communication via SSL."
default_port = HTTPS_PORT
def __init__(self, host, port=None, key_file=None, cert_file=None,
strict=None, timeout=None):
HTTPConnection.__init__(self, host, port, strict, timeout)
self.key_file = key_file
self.cert_file = cert_file
def connect(self):
"Connect to a host on a given (SSL) port."
sock = socket.create_connection((self.host, self.port), self.timeout)
ssl = socket.ssl(sock, self.key_file, self.cert_file)
self.sock = FakeSocket(sock, ssl)
class HTTP: class HTTP:
"Compatibility class with httplib.py from 1.5." "Compatibility class with httplib.py from 1.5."
...@@ -1229,7 +1030,29 @@ class HTTP: ...@@ -1229,7 +1030,29 @@ class HTTP:
### do it ### do it
self.file = None self.file = None
if hasattr(socket, 'ssl'): try:
import ssl
except ImportError:
pass
else:
class HTTPSConnection(HTTPConnection):
"This class allows communication via SSL."
default_port = HTTPS_PORT
def __init__(self, host, port=None, key_file=None, cert_file=None,
strict=None, timeout=None):
HTTPConnection.__init__(self, host, port, strict, timeout)
self.key_file = key_file
self.cert_file = cert_file
def connect(self):
"Connect to a host on a given (SSL) port."
sock = socket.create_connection((self.host, self.port), self.timeout)
self.sock = ssl.sslsocket(sock, self.key_file, self.cert_file)
class HTTPS(HTTP): class HTTPS(HTTP):
"""Compatibility with 1.5 httplib interface """Compatibility with 1.5 httplib interface
...@@ -1256,6 +1079,10 @@ if hasattr(socket, 'ssl'): ...@@ -1256,6 +1079,10 @@ if hasattr(socket, 'ssl'):
self.cert_file = cert_file self.cert_file = cert_file
def FakeSocket (sock, sslobj):
return sslobj
class HTTPException(Exception): class HTTPException(Exception):
# Subclasses that define an __init__ must call Exception.__init__ # Subclasses that define an __init__ must call Exception.__init__
# or define self.args. Otherwise, str() will fail. # or define self.args. Otherwise, str() will fail.
...@@ -1413,7 +1240,11 @@ def test(): ...@@ -1413,7 +1240,11 @@ def test():
h.getreply() h.getreply()
h.close() h.close()
if hasattr(socket, 'ssl'): try:
import ssl
except ImportError:
pass
else:
for host, selector in (('sourceforge.net', '/projects/python'), for host, selector in (('sourceforge.net', '/projects/python'),
): ):
......
...@@ -1111,94 +1111,99 @@ class IMAP4: ...@@ -1111,94 +1111,99 @@ class IMAP4:
class IMAP4_SSL(IMAP4): try:
import ssl
except ImportError:
pass
else:
class IMAP4_SSL(IMAP4):
"""IMAP4 client class over SSL connection """IMAP4 client class over SSL connection
Instantiate with: IMAP4_SSL([host[, port[, keyfile[, certfile]]]]) Instantiate with: IMAP4_SSL([host[, port[, keyfile[, certfile]]]])
host - host's name (default: localhost); host - host's name (default: localhost);
port - port number (default: standard IMAP4 SSL port). port - port number (default: standard IMAP4 SSL port).
keyfile - PEM formatted file that contains your private key (default: None); keyfile - PEM formatted file that contains your private key (default: None);
certfile - PEM formatted certificate chain file (default: None); certfile - PEM formatted certificate chain file (default: None);
for more documentation see the docstring of the parent class IMAP4. for more documentation see the docstring of the parent class IMAP4.
""" """
def __init__(self, host = '', port = IMAP4_SSL_PORT, keyfile = None, certfile = None): def __init__(self, host = '', port = IMAP4_SSL_PORT, keyfile = None, certfile = None):
self.keyfile = keyfile self.keyfile = keyfile
self.certfile = certfile self.certfile = certfile
IMAP4.__init__(self, host, port) IMAP4.__init__(self, host, port)
def open(self, host = '', port = IMAP4_SSL_PORT): def open(self, host = '', port = IMAP4_SSL_PORT):
"""Setup connection to remote server on "host:port". """Setup connection to remote server on "host:port".
(default: localhost:standard IMAP4 SSL port). (default: localhost:standard IMAP4 SSL port).
This connection will be used by the routines: This connection will be used by the routines:
read, readline, send, shutdown. read, readline, send, shutdown.
""" """
self.host = host self.host = host
self.port = port self.port = port
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.connect((host, port)) self.sock.connect((host, port))
self.sslobj = socket.ssl(self.sock, self.keyfile, self.certfile) self.sslobj = ssl.sslsocket(self.sock, self.keyfile, self.certfile)
def read(self, size): def read(self, size):
"""Read 'size' bytes from remote.""" """Read 'size' bytes from remote."""
# sslobj.read() sometimes returns < size bytes # sslobj.read() sometimes returns < size bytes
chunks = [] chunks = []
read = 0 read = 0
while read < size: while read < size:
data = self.sslobj.read(size-read) data = self.sslobj.read(size-read)
read += len(data) read += len(data)
chunks.append(data) chunks.append(data)
return ''.join(chunks) return ''.join(chunks)
def readline(self): def readline(self):
"""Read line from remote.""" """Read line from remote."""
# NB: socket.ssl needs a "readline" method, or perhaps a "makefile" method. # NB: socket.ssl needs a "readline" method, or perhaps a "makefile" method.
line = [] line = []
while 1: while 1:
char = self.sslobj.read(1) char = self.sslobj.read(1)
line.append(char) line.append(char)
if char == "\n": return ''.join(line) if char == "\n": return ''.join(line)
def send(self, data): def send(self, data):
"""Send data to remote.""" """Send data to remote."""
# NB: socket.ssl needs a "sendall" method to match socket objects. # NB: socket.ssl needs a "sendall" method to match socket objects.
bytes = len(data) bytes = len(data)
while bytes > 0: while bytes > 0:
sent = self.sslobj.write(data) sent = self.sslobj.write(data)
if sent == bytes: if sent == bytes:
break # avoid copy break # avoid copy
data = data[sent:] data = data[sent:]
bytes = bytes - sent bytes = bytes - sent
def shutdown(self): def shutdown(self):
"""Close I/O established in "open".""" """Close I/O established in "open"."""
self.sock.close() self.sock.close()
def socket(self): def socket(self):
"""Return socket instance used to connect to IMAP4 server. """Return socket instance used to connect to IMAP4 server.
socket = <instance>.socket() socket = <instance>.socket()
""" """
return self.sock return self.sock
def ssl(self): def ssl(self):
"""Return SSLObject instance used to communicate with the IMAP4 server. """Return SSLObject instance used to communicate with the IMAP4 server.
ssl = <instance>.socket.ssl() ssl = <instance>.socket.ssl()
""" """
return self.sslobj return self.sslobj
......
...@@ -307,89 +307,95 @@ class POP3: ...@@ -307,89 +307,95 @@ class POP3:
return self._shortcmd('UIDL %s' % which) return self._shortcmd('UIDL %s' % which)
return self._longcmd('UIDL') return self._longcmd('UIDL')
class POP3_SSL(POP3): try:
"""POP3 client class over SSL connection import ssl
except ImportError:
pass
else:
Instantiate with: POP3_SSL(hostname, port=995, keyfile=None, certfile=None) class POP3_SSL(POP3):
"""POP3 client class over SSL connection
hostname - the hostname of the pop3 over ssl server Instantiate with: POP3_SSL(hostname, port=995, keyfile=None, certfile=None)
port - port number
keyfile - PEM formatted file that countains your private key
certfile - PEM formatted certificate chain file
See the methods of the parent class POP3 for more documentation. hostname - the hostname of the pop3 over ssl server
""" port - port number
keyfile - PEM formatted file that countains your private key
def __init__(self, host, port = POP3_SSL_PORT, keyfile = None, certfile = None): certfile - PEM formatted certificate chain file
self.host = host
self.port = port
self.keyfile = keyfile
self.certfile = certfile
self.buffer = ""
msg = "getaddrinfo returns an empty list"
self.sock = None
for res in socket.getaddrinfo(self.host, self.port, 0, socket.SOCK_STREAM):
af, socktype, proto, canonname, sa = res
try:
self.sock = socket.socket(af, socktype, proto)
self.sock.connect(sa)
except socket.error, msg:
if self.sock:
self.sock.close()
self.sock = None
continue
break
if not self.sock:
raise socket.error, msg
self.file = self.sock.makefile('rb')
self.sslobj = socket.ssl(self.sock, self.keyfile, self.certfile)
self._debugging = 0
self.welcome = self._getresp()
def _fillBuffer(self): See the methods of the parent class POP3 for more documentation.
localbuf = self.sslobj.read() """
if len(localbuf) == 0:
raise error_proto('-ERR EOF')
self.buffer += localbuf
def _getline(self): def __init__(self, host, port = POP3_SSL_PORT, keyfile = None, certfile = None):
line = "" self.host = host
renewline = re.compile(r'.*?\n') self.port = port
match = renewline.match(self.buffer) self.keyfile = keyfile
while not match: self.certfile = certfile
self._fillBuffer() self.buffer = ""
msg = "getaddrinfo returns an empty list"
self.sock = None
for res in socket.getaddrinfo(self.host, self.port, 0, socket.SOCK_STREAM):
af, socktype, proto, canonname, sa = res
try:
self.sock = socket.socket(af, socktype, proto)
self.sock.connect(sa)
except socket.error, msg:
if self.sock:
self.sock.close()
self.sock = None
continue
break
if not self.sock:
raise socket.error, msg
self.file = self.sock.makefile('rb')
self.sslobj = ssl.sslsocket(self.sock, self.keyfile, self.certfile)
self._debugging = 0
self.welcome = self._getresp()
def _fillBuffer(self):
localbuf = self.sslobj.read()
if len(localbuf) == 0:
raise error_proto('-ERR EOF')
self.buffer += localbuf
def _getline(self):
line = ""
renewline = re.compile(r'.*?\n')
match = renewline.match(self.buffer) match = renewline.match(self.buffer)
line = match.group(0) while not match:
self.buffer = renewline.sub('' ,self.buffer, 1) self._fillBuffer()
if self._debugging > 1: print '*get*', repr(line) match = renewline.match(self.buffer)
line = match.group(0)
octets = len(line) self.buffer = renewline.sub('' ,self.buffer, 1)
if line[-2:] == CRLF: if self._debugging > 1: print '*get*', repr(line)
return line[:-2], octets
if line[0] == CR: octets = len(line)
return line[1:-1], octets if line[-2:] == CRLF:
return line[:-1], octets return line[:-2], octets
if line[0] == CR:
def _putline(self, line): return line[1:-1], octets
if self._debugging > 1: print '*put*', repr(line) return line[:-1], octets
line += CRLF
bytes = len(line) def _putline(self, line):
while bytes > 0: if self._debugging > 1: print '*put*', repr(line)
sent = self.sslobj.write(line) line += CRLF
if sent == bytes: bytes = len(line)
break # avoid copy while bytes > 0:
line = line[sent:] sent = self.sslobj.write(line)
bytes = bytes - sent if sent == bytes:
break # avoid copy
def quit(self): line = line[sent:]
"""Signoff: commit changes on server, unlock mailbox, close connection.""" bytes = bytes - sent
try:
resp = self._shortcmd('QUIT') def quit(self):
except error_proto, val: """Signoff: commit changes on server, unlock mailbox, close connection."""
resp = val try:
self.sock.close() resp = self._shortcmd('QUIT')
del self.sslobj, self.sock except error_proto, val:
return resp resp = val
self.sock.close()
del self.sslobj, self.sock
return resp
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -128,43 +128,6 @@ class SMTPAuthenticationError(SMTPResponseException): ...@@ -128,43 +128,6 @@ class SMTPAuthenticationError(SMTPResponseException):
combination provided. combination provided.
""" """
class SSLFakeSocket:
"""A fake socket object that really wraps a SSLObject.
It only supports what is needed in smtplib.
"""
def __init__(self, realsock, sslobj):
self.realsock = realsock
self.sslobj = sslobj
def send(self, str):
self.sslobj.write(str)
return len(str)
sendall = send
def close(self):
self.realsock.close()
class SSLFakeFile:
"""A fake file like object that really wraps a SSLObject.
It only supports what is needed in smtplib.
"""
def __init__(self, sslobj):
self.sslobj = sslobj
def readline(self):
str = ""
chr = None
while chr != "\n":
chr = self.sslobj.read(1)
str += chr
return str
def close(self):
pass
def quoteaddr(addr): def quoteaddr(addr):
"""Quote a subset of the email addresses defined by RFC 821. """Quote a subset of the email addresses defined by RFC 821.
...@@ -194,6 +157,33 @@ def quotedata(data): ...@@ -194,6 +157,33 @@ def quotedata(data):
re.sub(r'(?:\r\n|\n|\r(?!\n))', CRLF, data)) re.sub(r'(?:\r\n|\n|\r(?!\n))', CRLF, data))
try:
import ssl
except ImportError:
_have_ssl = False
else:
class SSLFakeFile:
"""A fake file like object that really wraps a SSLObject.
It only supports what is needed in smtplib.
"""
def __init__(self, sslobj):
self.sslobj = sslobj
def readline(self):
str = ""
chr = None
while chr != "\n":
chr = self.sslobj.read(1)
str += chr
return str
def close(self):
pass
_have_ssl = True
class SMTP: class SMTP:
"""This class manages a connection to an SMTP or ESMTP server. """This class manages a connection to an SMTP or ESMTP server.
SMTP Objects: SMTP Objects:
...@@ -596,9 +586,10 @@ class SMTP: ...@@ -596,9 +586,10 @@ class SMTP:
""" """
(resp, reply) = self.docmd("STARTTLS") (resp, reply) = self.docmd("STARTTLS")
if resp == 220: if resp == 220:
sslobj = socket.ssl(self.sock, keyfile, certfile) if not _have_ssl:
self.sock = SSLFakeSocket(self.sock, sslobj) raise RuntimeError("No SSL support included in this Python")
self.file = SSLFakeFile(sslobj) self.sock = ssl.sslsocket(self.sock, keyfile, certfile)
self.file = SSLFakeFile(self.sock)
return (resp, reply) return (resp, reply)
def sendmail(self, from_addr, to_addrs, msg, mail_options=[], def sendmail(self, from_addr, to_addrs, msg, mail_options=[],
...@@ -710,27 +701,29 @@ class SMTP: ...@@ -710,27 +701,29 @@ class SMTP:
self.docmd("quit") self.docmd("quit")
self.close() self.close()
class SMTP_SSL(SMTP): if _have_ssl:
""" This is a subclass derived from SMTP that connects over an SSL encrypted
socket (to use this class you need a socket module that was compiled with SSL class SMTP_SSL(SMTP):
support). If host is not specified, '' (the local host) is used. If port is """ This is a subclass derived from SMTP that connects over an SSL encrypted
omitted, the standard SMTP-over-SSL port (465) is used. keyfile and certfile socket (to use this class you need a socket module that was compiled with SSL
are also optional - they can contain a PEM formatted private key and support). If host is not specified, '' (the local host) is used. If port is
certificate chain file for the SSL connection. omitted, the standard SMTP-over-SSL port (465) is used. keyfile and certfile
""" are also optional - they can contain a PEM formatted private key and
def __init__(self, host='', port=0, local_hostname=None, certificate chain file for the SSL connection.
keyfile=None, certfile=None, timeout=None): """
self.keyfile = keyfile def __init__(self, host='', port=0, local_hostname=None,
self.certfile = certfile keyfile=None, certfile=None, timeout=None):
SMTP.__init__(self, host, port, local_hostname, timeout) self.keyfile = keyfile
self.default_port = SMTP_SSL_PORT self.certfile = certfile
SMTP.__init__(self, host, port, local_hostname, timeout)
def _get_socket(self, host, port, timeout): self.default_port = SMTP_SSL_PORT
if self.debuglevel > 0: print>>stderr, 'connect:', (host, port)
self.sock = socket.create_connection((host, port), timeout) def _get_socket(self, host, port, timeout):
sslobj = socket.ssl(self.sock, self.keyfile, self.certfile) if self.debuglevel > 0: print>>stderr, 'connect:', (host, port)
self.sock = SSLFakeSocket(self.sock, sslobj) self.sock = socket.create_connection((host, port), timeout)
self.file = SSLFakeFile(sslobj) sslobj = socket.ssl(self.sock, self.keyfile, self.certfile)
self.sock = SSLFakeSocket(self.sock, sslobj)
self.file = SSLFakeFile(sslobj)
# #
# LMTP extension # LMTP extension
......
...@@ -46,15 +46,37 @@ the setsockopt() and getsockopt() methods. ...@@ -46,15 +46,37 @@ the setsockopt() and getsockopt() methods.
import _socket import _socket
from _socket import * from _socket import *
_have_ssl = False
try: try:
import _ssl import _ssl
from _ssl import *
_have_ssl = True
except ImportError: except ImportError:
# no SSL support
pass pass
else:
import os, sys def ssl(sock, keyfile=None, certfile=None):
# we do an internal import here because the ssl
# module imports the socket module
import ssl as _realssl
warnings.warn("socket.ssl() is deprecated. Use ssl.sslsocket() instead.",
DeprecationWarning, stacklevel=2)
return _realssl.sslwrap_simple(sock, keyfile, certfile)
# we need to import the same constants we used to...
from _ssl import \
sslerror, \
RAND_add, \
RAND_egd, \
RAND_status, \
SSL_ERROR_ZERO_RETURN, \
SSL_ERROR_WANT_READ, \
SSL_ERROR_WANT_WRITE, \
SSL_ERROR_WANT_X509_LOOKUP, \
SSL_ERROR_SYSCALL, \
SSL_ERROR_SSL, \
SSL_ERROR_WANT_CONNECT, \
SSL_ERROR_EOF, \
SSL_ERROR_INVALID_ERROR_CODE
import os, sys, warnings
try: try:
from errno import EBADF from errno import EBADF
...@@ -63,15 +85,9 @@ except ImportError: ...@@ -63,15 +85,9 @@ except ImportError:
__all__ = ["getfqdn"] __all__ = ["getfqdn"]
__all__.extend(os._get_exports_list(_socket)) __all__.extend(os._get_exports_list(_socket))
if _have_ssl:
__all__.extend(os._get_exports_list(_ssl))
_realsocket = socket _realsocket = socket
if _have_ssl:
def ssl(sock, keyfile=None, certfile=None):
import ssl as realssl
return realssl.sslwrap_simple(sock, keyfile, certfile)
__all__.append("ssl")
# WSA error codes # WSA error codes
if sys.platform.lower().startswith("win"): if sys.platform.lower().startswith("win"):
......
...@@ -58,55 +58,47 @@ PROTOCOL_TLSv1 ...@@ -58,55 +58,47 @@ PROTOCOL_TLSv1
import os, sys import os, sys
import _ssl # if we can't import it, let the error propagate import _ssl # if we can't import it, let the error propagate
from socket import socket
from _ssl import sslerror from _ssl import sslerror
from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
from _ssl import PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1 from _ssl import PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1
from _ssl import \
SSL_ERROR_ZERO_RETURN, \
SSL_ERROR_WANT_READ, \
SSL_ERROR_WANT_WRITE, \
SSL_ERROR_WANT_X509_LOOKUP, \
SSL_ERROR_SYSCALL, \
SSL_ERROR_SSL, \
SSL_ERROR_WANT_CONNECT, \
SSL_ERROR_EOF, \
SSL_ERROR_INVALID_ERROR_CODE
from socket import socket
from socket import getnameinfo as _getnameinfo
# Root certs:
#
# The "ca_certs" argument to sslsocket() expects a file containing one or more
# certificates that are roots of various certificate signing chains. This file
# contains the certificates in PEM format (RFC ) where each certificate is
# encoded in base64 encoding and surrounded with a header and footer:
# -----BEGIN CERTIFICATE-----
# ... (CA certificate in base64 encoding) ...
# -----END CERTIFICATE-----
# The various certificates in the file are just concatenated together:
# -----BEGIN CERTIFICATE-----
# ... (CA certificate in base64 encoding) ...
# -----END CERTIFICATE-----
# -----BEGIN CERTIFICATE-----
# ... (a second CA certificate in base64 encoding) ...
# -----END CERTIFICATE-----
#
# Some "standard" root certificates are available at
#
# http://www.thawte.com/roots/ (for Thawte roots)
# http://www.verisign.com/support/roots.html (for Verisign)
class sslsocket (socket): class sslsocket (socket):
"""This class implements a subtype of socket.socket that wraps
the underlying OS socket in an SSL context when necessary, and
provides read and write methods over that channel."""
def __init__(self, sock, keyfile=None, certfile=None, def __init__(self, sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE, server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None): ssl_version=PROTOCOL_SSLv23, ca_certs=None):
socket.__init__(self, _sock=sock._sock) socket.__init__(self, _sock=sock._sock)
if certfile and not keyfile: if certfile and not keyfile:
keyfile = certfile keyfile = certfile
if server_side: # see if it's connected
self._sslobj = _ssl.sslwrap(self._sock, 1, keyfile, certfile, try:
cert_reqs, ssl_version, ca_certs) socket.getpeername(self)
except:
# no, no connection yet
self._sslobj = None
else: else:
# see if it's connected # yes, create the SSL object
try: self._sslobj = _ssl.sslwrap(self._sock, server_side,
socket.getpeername(self) keyfile, certfile,
except: cert_reqs, ssl_version, ca_certs)
# no, no connection yet
self._sslobj = None
else:
# yes, create the SSL object
self._sslobj = _ssl.sslwrap(self._sock, 0, keyfile, certfile,
cert_reqs, ssl_version, ca_certs)
self.keyfile = keyfile self.keyfile = keyfile
self.certfile = certfile self.certfile = certfile
self.cert_reqs = cert_reqs self.cert_reqs = cert_reqs
...@@ -123,59 +115,77 @@ class sslsocket (socket): ...@@ -123,59 +115,77 @@ class sslsocket (socket):
return self._sslobj.peer_certificate() return self._sslobj.peer_certificate()
def send (self, data, flags=0): def send (self, data, flags=0):
if flags != 0: if self._sslobj:
raise ValueError( if flags != 0:
"non-zero flags not allowed in calls to send() on %s" % raise ValueError(
self.__class__) "non-zero flags not allowed in calls to send() on %s" %
return self._sslobj.write(data) self.__class__)
return self._sslobj.write(data)
else:
return socket.send(self, data, flags)
def send_to (self, data, addr, flags=0): def send_to (self, data, addr, flags=0):
raise ValueError("send_to not allowed on instances of %s" % if self._sslobj:
self.__class__) raise ValueError("send_to not allowed on instances of %s" %
self.__class__)
else:
return socket.send_to(self, data, addr, flags)
def sendall (self, data, flags=0): def sendall (self, data, flags=0):
if flags != 0: if self._sslobj:
raise ValueError( if flags != 0:
"non-zero flags not allowed in calls to sendall() on %s" % raise ValueError(
self.__class__) "non-zero flags not allowed in calls to sendall() on %s" %
return self._sslobj.write(data) self.__class__)
return self._sslobj.write(data)
else:
return socket.sendall(self, data, flags)
def recv (self, buflen=1024, flags=0): def recv (self, buflen=1024, flags=0):
if flags != 0: if self._sslobj:
raise ValueError( if flags != 0:
"non-zero flags not allowed in calls to sendall() on %s" % raise ValueError(
self.__class__) "non-zero flags not allowed in calls to sendall() on %s" %
return self._sslobj.read(data, buflen) self.__class__)
return self._sslobj.read(data, buflen)
else:
return socket.recv(self, buflen, flags)
def recv_from (self, addr, buflen=1024, flags=0): def recv_from (self, addr, buflen=1024, flags=0):
raise ValueError("recv_from not allowed on instances of %s" % if self._sslobj:
self.__class__) raise ValueError("recv_from not allowed on instances of %s" %
self.__class__)
else:
return socket.recv_from(self, addr, buflen, flags)
def shutdown(self): def ssl_shutdown(self):
if self._sslobj: if self._sslobj:
self._sslobj.shutdown() self._sslobj.shutdown()
self._sslobj = None self._sslobj = None
else:
socket.shutdown(self) def shutdown(self, how):
self.ssl_shutdown()
socket.shutdown(self, how)
def close(self): def close(self):
if self._sslobj: self.ssl_shutdown()
self.shutdown() socket.close(self)
else:
socket.close(self)
def connect(self, addr): def connect(self, addr):
# Here we assume that the socket is client-side, and not # Here we assume that the socket is client-side, and not
# connected at the time of the call. We connect it, then wrap it. # connected at the time of the call. We connect it, then wrap it.
if self._sslobj or (self.getsockname()[1] != 0): if self._sslobj:
raise ValueError("attempt to connect already-connected sslsocket!") raise ValueError("attempt to connect already-connected sslsocket!")
socket.connect(self, addr) socket.connect(self, addr)
self._sslobj = _ssl.sslwrap(self._sock, 0, self.keyfile, self.certfile, self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile,
self.cert_reqs, self.ssl_version, self.cert_reqs, self.ssl_version,
self.ca_certs) self.ca_certs)
def accept(self): def accept(self):
raise ValueError("accept() not supported on an sslsocket") newsock, addr = socket.accept(self)
return (sslsocket(newsock, True, self.keyfile, self.certfile,
self.cert_reqs, self.ssl_version,
self.ca_certs), addr)
# some utility functions # some utility functions
...@@ -190,64 +200,3 @@ def sslwrap_simple (sock, keyfile=None, certfile=None): ...@@ -190,64 +200,3 @@ def sslwrap_simple (sock, keyfile=None, certfile=None):
return _ssl.sslwrap(sock._sock, 0, keyfile, certfile, CERT_NONE, return _ssl.sslwrap(sock._sock, 0, keyfile, certfile, CERT_NONE,
PROTOCOL_SSLv23, None) PROTOCOL_SSLv23, None)
# fetch the certificate that the server is providing in PEM form
def fetch_server_certificate (host, port):
import re, tempfile, os
def subproc(cmd):
from subprocess import Popen, PIPE, STDOUT
proc = Popen(cmd, stdout=PIPE, stderr=STDOUT, shell=True)
status = proc.wait()
output = proc.stdout.read()
return status, output
def strip_to_x509_cert(certfile_contents, outfile=None):
m = re.search(r"^([-]+BEGIN CERTIFICATE[-]+[\r]*\n"
r".*[\r]*^[-]+END CERTIFICATE[-]+)$",
certfile_contents, re.MULTILINE | re.DOTALL)
if not m:
return None
else:
tn = tempfile.mktemp()
fp = open(tn, "w")
fp.write(m.group(1) + "\n")
fp.close()
try:
tn2 = (outfile or tempfile.mktemp())
status, output = subproc(r'openssl x509 -in "%s" -out "%s"' %
(tn, tn2))
if status != 0:
raise OperationError(status, tsig, output)
fp = open(tn2, 'rb')
data = fp.read()
fp.close()
os.unlink(tn2)
return data
finally:
os.unlink(tn)
if sys.platform.startswith("win"):
tfile = tempfile.mktemp()
fp = open(tfile, "w")
fp.write("quit\n")
fp.close()
try:
status, output = subproc(
'openssl s_client -connect "%s:%s" -showcerts < "%s"' %
(host, port, tfile))
finally:
os.unlink(tfile)
else:
status, output = subproc(
'openssl s_client -connect "%s:%s" -showcerts < /dev/null' %
(host, port))
if status != 0:
raise OSError(status)
certtext = strip_to_x509_cert(output)
if not certtext:
raise ValueError("Invalid response received from server at %s:%s" %
(host, port))
return certtext
...@@ -110,12 +110,12 @@ class BasicTests(unittest.TestCase): ...@@ -110,12 +110,12 @@ class BasicTests(unittest.TestCase):
if test_support.verbose: if test_support.verbose:
print "test_978833 ..." print "test_978833 ..."
import os, httplib import os, httplib, ssl
with test_support.transient_internet(): with test_support.transient_internet():
s = socket.socket(socket.AF_INET) s = socket.socket(socket.AF_INET)
s.connect(("www.sf.net", 443)) s.connect(("www.sf.net", 443))
fd = s._sock.fileno() fd = s._sock.fileno()
sock = httplib.FakeSocket(s, socket.ssl(s)) sock = ssl.sslsocket(s)
s = None s = None
sock.close() sock.close()
try: try:
......
...@@ -91,6 +91,14 @@ def urlcleanup(): ...@@ -91,6 +91,14 @@ def urlcleanup():
if _urlopener: if _urlopener:
_urlopener.cleanup() _urlopener.cleanup()
# check for SSL
try:
import ssl
except:
_have_ssl = False
else:
_have_ssl = True
# exception raised when downloaded size does not match content-length # exception raised when downloaded size does not match content-length
class ContentTooShortError(IOError): class ContentTooShortError(IOError):
def __init__(self, message, content): def __init__(self, message, content):
...@@ -361,9 +369,10 @@ class URLopener: ...@@ -361,9 +369,10 @@ class URLopener:
fp.close() fp.close()
raise IOError, ('http error', errcode, errmsg, headers) raise IOError, ('http error', errcode, errmsg, headers)
if hasattr(socket, "ssl"): if _have_ssl:
def open_https(self, url, data=None): def open_https(self, url, data=None):
"""Use HTTPS protocol.""" """Use HTTPS protocol."""
import httplib import httplib
user_passwd = None user_passwd = None
proxy_passwd = None proxy_passwd = None
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment