Commit 7b4b2846 authored by Benjamin Peterson's avatar Benjamin Peterson

allow a SSLContext to be given to ftplib.FTP_TLS

parent 9fe67cee
......@@ -55,18 +55,26 @@ The module defines the following items:
*timeout* was added.
.. class:: FTP_TLS([host[, user[, passwd[, acct[, keyfile[, certfile[, timeout]]]]]]])
.. class:: FTP_TLS([host[, user[, passwd[, acct[, keyfile[, certfile[, context[, timeout]]]]]]]])
A :class:`FTP` subclass which adds TLS support to FTP as described in
:rfc:`4217`.
Connect as usual to port 21 implicitly securing the FTP control connection
before authenticating. Securing the data connection requires the user to
explicitly ask for it by calling the :meth:`prot_p` method.
*keyfile* and *certfile* are optional -- they can contain a PEM formatted
private key and certificate chain file name for the SSL connection.
explicitly ask for it by calling the :meth:`prot_p` method. *context*
is a :class:`ssl.SSLContext` object which allows bundling SSL configuration
options, certificates and private keys into a single (potentially
long-lived) structure. Please read :ref:`ssl-security` for best practices.
*keyfile* and *certfile* are a legacy alternative to *context* -- they
can point to PEM-formatted private key and certificate chain files
(respectively) for the SSL connection.
.. versionadded:: 2.7
.. versionchanged:: 2.7.10
The *context* parameter was added.
Here's a sample session using the :class:`FTP_TLS` class:
>>> from ftplib import FTP_TLS
......
......@@ -641,9 +641,21 @@ else:
ssl_version = ssl.PROTOCOL_SSLv23
def __init__(self, host='', user='', passwd='', acct='', keyfile=None,
certfile=None, timeout=_GLOBAL_DEFAULT_TIMEOUT):
certfile=None, context=None,
timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=None):
if context is not None and keyfile is not None:
raise ValueError("context and keyfile arguments are mutually "
"exclusive")
if context is not None and certfile is not None:
raise ValueError("context and certfile arguments are mutually "
"exclusive")
self.keyfile = keyfile
self.certfile = certfile
if context is None:
context = ssl._create_stdlib_context(self.ssl_version,
certfile=certfile,
keyfile=keyfile)
self.context = context
self._prot_p = False
FTP.__init__(self, host, user, passwd, acct, timeout)
......@@ -660,8 +672,8 @@ else:
resp = self.voidcmd('AUTH TLS')
else:
resp = self.voidcmd('AUTH SSL')
self.sock = ssl.wrap_socket(self.sock, self.keyfile, self.certfile,
ssl_version=self.ssl_version)
self.sock = self.context.wrap_socket(self.sock,
server_hostname=self.host)
self.file = self.sock.makefile(mode='rb')
return resp
......@@ -692,8 +704,8 @@ else:
def ntransfercmd(self, cmd, rest=None):
conn, size = FTP.ntransfercmd(self, cmd, rest)
if self._prot_p:
conn = ssl.wrap_socket(conn, self.keyfile, self.certfile,
ssl_version=self.ssl_version)
conn = self.context.wrap_socket(conn,
server_hostname=self.host)
return conn, size
def retrbinary(self, cmd, callback, blocksize=8192, rest=None):
......
......@@ -20,7 +20,7 @@ from test import test_support
from test.test_support import HOST, HOSTv6
threading = test_support.import_module('threading')
TIMEOUT = 3
# the dummy data returned by server over the data channel when
# RETR, LIST and NLST commands are issued
RETR_DATA = 'abcde12345\r\n' * 1000
......@@ -223,6 +223,7 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread):
self.active = False
self.active_lock = threading.Lock()
self.host, self.port = self.socket.getsockname()[:2]
self.handler_instance = None
def start(self):
assert not self.active
......@@ -246,8 +247,7 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread):
def handle_accept(self):
conn, addr = self.accept()
self.handler = self.handler(conn)
self.close()
self.handler_instance = self.handler(conn)
def handle_connect(self):
self.close()
......@@ -262,7 +262,8 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread):
if ssl is not None:
CERTFILE = os.path.join(os.path.dirname(__file__), "keycert.pem")
CERTFILE = os.path.join(os.path.dirname(__file__), "keycert3.pem")
CAFILE = os.path.join(os.path.dirname(__file__), "pycacert.pem")
class SSLConnection(object, asyncore.dispatcher):
"""An asyncore.dispatcher subclass supporting TLS/SSL."""
......@@ -271,23 +272,25 @@ if ssl is not None:
_ssl_closing = False
def secure_connection(self):
self.socket = ssl.wrap_socket(self.socket, suppress_ragged_eofs=False,
socket = ssl.wrap_socket(self.socket, suppress_ragged_eofs=False,
certfile=CERTFILE, server_side=True,
do_handshake_on_connect=False,
ssl_version=ssl.PROTOCOL_SSLv23)
self.del_channel()
self.set_socket(socket)
self._ssl_accepting = True
def _do_ssl_handshake(self):
try:
self.socket.do_handshake()
except ssl.SSLError, err:
except ssl.SSLError as err:
if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE):
return
elif err.args[0] == ssl.SSL_ERROR_EOF:
return self.handle_close()
raise
except socket.error, err:
except socket.error as err:
if err.args[0] == errno.ECONNABORTED:
return self.handle_close()
else:
......@@ -297,18 +300,21 @@ if ssl is not None:
self._ssl_closing = True
try:
self.socket = self.socket.unwrap()
except ssl.SSLError, err:
except ssl.SSLError as err:
if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE):
return
except socket.error, err:
except socket.error as err:
# Any "socket error" corresponds to a SSL_ERROR_SYSCALL return
# from OpenSSL's SSL_shutdown(), corresponding to a
# closed socket condition. See also:
# http://www.mail-archive.com/openssl-users@openssl.org/msg60710.html
pass
self._ssl_closing = False
if getattr(self, '_ccc', False) is False:
super(SSLConnection, self).close()
else:
pass
def handle_read_event(self):
if self._ssl_accepting:
......@@ -329,7 +335,7 @@ if ssl is not None:
def send(self, data):
try:
return super(SSLConnection, self).send(data)
except ssl.SSLError, err:
except ssl.SSLError as err:
if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN,
ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE):
......@@ -339,13 +345,13 @@ if ssl is not None:
def recv(self, buffer_size):
try:
return super(SSLConnection, self).recv(buffer_size)
except ssl.SSLError, err:
except ssl.SSLError as err:
if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE):
return ''
return b''
if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN):
self.handle_close()
return ''
return b''
raise
def handle_error(self):
......@@ -355,6 +361,8 @@ if ssl is not None:
if (isinstance(self.socket, ssl.SSLSocket) and
self.socket._sslobj is not None):
self._do_ssl_shutdown()
else:
super(SSLConnection, self).close()
class DummyTLS_DTPHandler(SSLConnection, DummyDTPHandler):
......@@ -462,12 +470,12 @@ class TestFTPClass(TestCase):
def test_rename(self):
self.client.rename('a', 'b')
self.server.handler.next_response = '200'
self.server.handler_instance.next_response = '200'
self.assertRaises(ftplib.error_reply, self.client.rename, 'a', 'b')
def test_delete(self):
self.client.delete('foo')
self.server.handler.next_response = '199'
self.server.handler_instance.next_response = '199'
self.assertRaises(ftplib.error_reply, self.client.delete, 'foo')
def test_size(self):
......@@ -515,7 +523,7 @@ class TestFTPClass(TestCase):
def test_storbinary(self):
f = StringIO.StringIO(RETR_DATA)
self.client.storbinary('stor', f)
self.assertEqual(self.server.handler.last_received_data, RETR_DATA)
self.assertEqual(self.server.handler_instance.last_received_data, RETR_DATA)
# test new callback arg
flag = []
f.seek(0)
......@@ -527,12 +535,12 @@ class TestFTPClass(TestCase):
for r in (30, '30'):
f.seek(0)
self.client.storbinary('stor', f, rest=r)
self.assertEqual(self.server.handler.rest, str(r))
self.assertEqual(self.server.handler_instance.rest, str(r))
def test_storlines(self):
f = StringIO.StringIO(RETR_DATA.replace('\r\n', '\n'))
self.client.storlines('stor', f)
self.assertEqual(self.server.handler.last_received_data, RETR_DATA)
self.assertEqual(self.server.handler_instance.last_received_data, RETR_DATA)
# test new callback arg
flag = []
f.seek(0)
......@@ -551,14 +559,14 @@ class TestFTPClass(TestCase):
def test_makeport(self):
self.client.makeport()
# IPv4 is in use, just make sure send_eprt has not been used
self.assertEqual(self.server.handler.last_received_cmd, 'port')
self.assertEqual(self.server.handler_instance.last_received_cmd, 'port')
def test_makepasv(self):
host, port = self.client.makepasv()
conn = socket.create_connection((host, port), 10)
conn.close()
# IPv4 is in use, just make sure send_epsv has not been used
self.assertEqual(self.server.handler.last_received_cmd, 'pasv')
self.assertEqual(self.server.handler_instance.last_received_cmd, 'pasv')
def test_line_too_long(self):
self.assertRaises(ftplib.Error, self.client.sendcmd,
......@@ -600,13 +608,13 @@ class TestIPv6Environment(TestCase):
def test_makeport(self):
self.client.makeport()
self.assertEqual(self.server.handler.last_received_cmd, 'eprt')
self.assertEqual(self.server.handler_instance.last_received_cmd, 'eprt')
def test_makepasv(self):
host, port = self.client.makepasv()
conn = socket.create_connection((host, port), 10)
conn.close()
self.assertEqual(self.server.handler.last_received_cmd, 'epsv')
self.assertEqual(self.server.handler_instance.last_received_cmd, 'epsv')
def test_transfer(self):
def retr():
......@@ -642,7 +650,7 @@ class TestTLS_FTPClass(TestCase):
def setUp(self):
self.server = DummyTLS_FTPServer((HOST, 0))
self.server.start()
self.client = ftplib.FTP_TLS(timeout=10)
self.client = ftplib.FTP_TLS(timeout=TIMEOUT)
self.client.connect(self.server.host, self.server.port)
def tearDown(self):
......@@ -695,6 +703,59 @@ class TestTLS_FTPClass(TestCase):
finally:
self.client.ssl_version = ssl.PROTOCOL_TLSv1
def test_context(self):
self.client.quit()
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
self.assertRaises(ValueError, ftplib.FTP_TLS, keyfile=CERTFILE,
context=ctx)
self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE,
context=ctx)
self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE,
keyfile=CERTFILE, context=ctx)
self.client = ftplib.FTP_TLS(context=ctx, timeout=TIMEOUT)
self.client.connect(self.server.host, self.server.port)
self.assertNotIsInstance(self.client.sock, ssl.SSLSocket)
self.client.auth()
self.assertIs(self.client.sock.context, ctx)
self.assertIsInstance(self.client.sock, ssl.SSLSocket)
self.client.prot_p()
sock = self.client.transfercmd('list')
try:
self.assertIs(sock.context, ctx)
self.assertIsInstance(sock, ssl.SSLSocket)
finally:
sock.close()
def test_check_hostname(self):
self.client.quit()
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
ctx.verify_mode = ssl.CERT_REQUIRED
ctx.check_hostname = True
ctx.load_verify_locations(CAFILE)
self.client = ftplib.FTP_TLS(context=ctx, timeout=TIMEOUT)
# 127.0.0.1 doesn't match SAN
self.client.connect(self.server.host, self.server.port)
with self.assertRaises(ssl.CertificateError):
self.client.auth()
# exception quits connection
self.client.connect(self.server.host, self.server.port)
self.client.prot_p()
with self.assertRaises(ssl.CertificateError):
self.client.transfercmd("list").close()
self.client.quit()
self.client.connect("localhost", self.server.port)
self.client.auth()
self.client.quit()
self.client.connect("localhost", self.server.port)
self.client.prot_p()
self.client.transfercmd("list").close()
class TestTimeouts(TestCase):
......
......@@ -15,6 +15,8 @@ Core and Builtins
Library
-------
- Backport the context argument to ftplib.FTP_TLS.
- Issue #23111: Maximize compatibility in protocol versions of ftplib.FTP_TLS.
- Issue #23112: Fix SimpleHTTPServer to correctly carry the query string and
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment