Commit a675c8b1 authored by Benjamin Peterson's avatar Benjamin Peterson

allow a SSLContext to be given to ftplib.FTP_TLS

parent 1eef4307
...@@ -55,18 +55,26 @@ The module defines the following items: ...@@ -55,18 +55,26 @@ The module defines the following items:
*timeout* was added. *timeout* was added.
.. class:: FTP_TLS([host[, user[, passwd[, acct[, keyfile[, certfile[, timeout]]]]]]]) .. class:: FTP_TLS([host[, user[, passwd[, acct[, keyfile[, certfile[, context[, timeout]]]]]]]])
A :class:`FTP` subclass which adds TLS support to FTP as described in A :class:`FTP` subclass which adds TLS support to FTP as described in
:rfc:`4217`. :rfc:`4217`.
Connect as usual to port 21 implicitly securing the FTP control connection Connect as usual to port 21 implicitly securing the FTP control connection
before authenticating. Securing the data connection requires the user to before authenticating. Securing the data connection requires the user to
explicitly ask for it by calling the :meth:`prot_p` method. explicitly ask for it by calling the :meth:`prot_p` method. *context*
*keyfile* and *certfile* are optional -- they can contain a PEM formatted is a :class:`ssl.SSLContext` object which allows bundling SSL configuration
private key and certificate chain file name for the SSL connection. options, certificates and private keys into a single (potentially
long-lived) structure. Please read :ref:`ssl-security` for best practices.
*keyfile* and *certfile* are a legacy alternative to *context* -- they
can point to PEM-formatted private key and certificate chain files
(respectively) for the SSL connection.
.. versionadded:: 2.7 .. versionadded:: 2.7
.. versionchanged:: 2.7.10
The *context* parameter was added.
Here's a sample session using the :class:`FTP_TLS` class: Here's a sample session using the :class:`FTP_TLS` class:
>>> from ftplib import FTP_TLS >>> from ftplib import FTP_TLS
......
...@@ -641,9 +641,21 @@ else: ...@@ -641,9 +641,21 @@ else:
ssl_version = ssl.PROTOCOL_SSLv23 ssl_version = ssl.PROTOCOL_SSLv23
def __init__(self, host='', user='', passwd='', acct='', keyfile=None, def __init__(self, host='', user='', passwd='', acct='', keyfile=None,
certfile=None, timeout=_GLOBAL_DEFAULT_TIMEOUT): certfile=None, context=None,
timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=None):
if context is not None and keyfile is not None:
raise ValueError("context and keyfile arguments are mutually "
"exclusive")
if context is not None and certfile is not None:
raise ValueError("context and certfile arguments are mutually "
"exclusive")
self.keyfile = keyfile self.keyfile = keyfile
self.certfile = certfile self.certfile = certfile
if context is None:
context = ssl._create_stdlib_context(self.ssl_version,
certfile=certfile,
keyfile=keyfile)
self.context = context
self._prot_p = False self._prot_p = False
FTP.__init__(self, host, user, passwd, acct, timeout) FTP.__init__(self, host, user, passwd, acct, timeout)
...@@ -660,8 +672,8 @@ else: ...@@ -660,8 +672,8 @@ else:
resp = self.voidcmd('AUTH TLS') resp = self.voidcmd('AUTH TLS')
else: else:
resp = self.voidcmd('AUTH SSL') resp = self.voidcmd('AUTH SSL')
self.sock = ssl.wrap_socket(self.sock, self.keyfile, self.certfile, self.sock = self.context.wrap_socket(self.sock,
ssl_version=self.ssl_version) server_hostname=self.host)
self.file = self.sock.makefile(mode='rb') self.file = self.sock.makefile(mode='rb')
return resp return resp
...@@ -692,8 +704,8 @@ else: ...@@ -692,8 +704,8 @@ else:
def ntransfercmd(self, cmd, rest=None): def ntransfercmd(self, cmd, rest=None):
conn, size = FTP.ntransfercmd(self, cmd, rest) conn, size = FTP.ntransfercmd(self, cmd, rest)
if self._prot_p: if self._prot_p:
conn = ssl.wrap_socket(conn, self.keyfile, self.certfile, conn = self.context.wrap_socket(conn,
ssl_version=self.ssl_version) server_hostname=self.host)
return conn, size return conn, size
def retrbinary(self, cmd, callback, blocksize=8192, rest=None): def retrbinary(self, cmd, callback, blocksize=8192, rest=None):
......
...@@ -20,7 +20,7 @@ from test import test_support ...@@ -20,7 +20,7 @@ from test import test_support
from test.test_support import HOST, HOSTv6 from test.test_support import HOST, HOSTv6
threading = test_support.import_module('threading') threading = test_support.import_module('threading')
TIMEOUT = 3
# the dummy data returned by server over the data channel when # the dummy data returned by server over the data channel when
# RETR, LIST and NLST commands are issued # RETR, LIST and NLST commands are issued
RETR_DATA = 'abcde12345\r\n' * 1000 RETR_DATA = 'abcde12345\r\n' * 1000
...@@ -223,6 +223,7 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread): ...@@ -223,6 +223,7 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread):
self.active = False self.active = False
self.active_lock = threading.Lock() self.active_lock = threading.Lock()
self.host, self.port = self.socket.getsockname()[:2] self.host, self.port = self.socket.getsockname()[:2]
self.handler_instance = None
def start(self): def start(self):
assert not self.active assert not self.active
...@@ -246,8 +247,7 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread): ...@@ -246,8 +247,7 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread):
def handle_accept(self): def handle_accept(self):
conn, addr = self.accept() conn, addr = self.accept()
self.handler = self.handler(conn) self.handler_instance = self.handler(conn)
self.close()
def handle_connect(self): def handle_connect(self):
self.close() self.close()
...@@ -262,7 +262,8 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread): ...@@ -262,7 +262,8 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread):
if ssl is not None: if ssl is not None:
CERTFILE = os.path.join(os.path.dirname(__file__), "keycert.pem") CERTFILE = os.path.join(os.path.dirname(__file__), "keycert3.pem")
CAFILE = os.path.join(os.path.dirname(__file__), "pycacert.pem")
class SSLConnection(object, asyncore.dispatcher): class SSLConnection(object, asyncore.dispatcher):
"""An asyncore.dispatcher subclass supporting TLS/SSL.""" """An asyncore.dispatcher subclass supporting TLS/SSL."""
...@@ -271,23 +272,25 @@ if ssl is not None: ...@@ -271,23 +272,25 @@ if ssl is not None:
_ssl_closing = False _ssl_closing = False
def secure_connection(self): def secure_connection(self):
self.socket = ssl.wrap_socket(self.socket, suppress_ragged_eofs=False, socket = ssl.wrap_socket(self.socket, suppress_ragged_eofs=False,
certfile=CERTFILE, server_side=True, certfile=CERTFILE, server_side=True,
do_handshake_on_connect=False, do_handshake_on_connect=False,
ssl_version=ssl.PROTOCOL_SSLv23) ssl_version=ssl.PROTOCOL_SSLv23)
self.del_channel()
self.set_socket(socket)
self._ssl_accepting = True self._ssl_accepting = True
def _do_ssl_handshake(self): def _do_ssl_handshake(self):
try: try:
self.socket.do_handshake() self.socket.do_handshake()
except ssl.SSLError, err: except ssl.SSLError as err:
if err.args[0] in (ssl.SSL_ERROR_WANT_READ, if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE): ssl.SSL_ERROR_WANT_WRITE):
return return
elif err.args[0] == ssl.SSL_ERROR_EOF: elif err.args[0] == ssl.SSL_ERROR_EOF:
return self.handle_close() return self.handle_close()
raise raise
except socket.error, err: except socket.error as err:
if err.args[0] == errno.ECONNABORTED: if err.args[0] == errno.ECONNABORTED:
return self.handle_close() return self.handle_close()
else: else:
...@@ -297,18 +300,21 @@ if ssl is not None: ...@@ -297,18 +300,21 @@ if ssl is not None:
self._ssl_closing = True self._ssl_closing = True
try: try:
self.socket = self.socket.unwrap() self.socket = self.socket.unwrap()
except ssl.SSLError, err: except ssl.SSLError as err:
if err.args[0] in (ssl.SSL_ERROR_WANT_READ, if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE): ssl.SSL_ERROR_WANT_WRITE):
return return
except socket.error, err: except socket.error as err:
# Any "socket error" corresponds to a SSL_ERROR_SYSCALL return # Any "socket error" corresponds to a SSL_ERROR_SYSCALL return
# from OpenSSL's SSL_shutdown(), corresponding to a # from OpenSSL's SSL_shutdown(), corresponding to a
# closed socket condition. See also: # closed socket condition. See also:
# http://www.mail-archive.com/openssl-users@openssl.org/msg60710.html # http://www.mail-archive.com/openssl-users@openssl.org/msg60710.html
pass pass
self._ssl_closing = False self._ssl_closing = False
super(SSLConnection, self).close() if getattr(self, '_ccc', False) is False:
super(SSLConnection, self).close()
else:
pass
def handle_read_event(self): def handle_read_event(self):
if self._ssl_accepting: if self._ssl_accepting:
...@@ -329,7 +335,7 @@ if ssl is not None: ...@@ -329,7 +335,7 @@ if ssl is not None:
def send(self, data): def send(self, data):
try: try:
return super(SSLConnection, self).send(data) return super(SSLConnection, self).send(data)
except ssl.SSLError, err: except ssl.SSLError as err:
if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN, if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN,
ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE): ssl.SSL_ERROR_WANT_WRITE):
...@@ -339,13 +345,13 @@ if ssl is not None: ...@@ -339,13 +345,13 @@ if ssl is not None:
def recv(self, buffer_size): def recv(self, buffer_size):
try: try:
return super(SSLConnection, self).recv(buffer_size) return super(SSLConnection, self).recv(buffer_size)
except ssl.SSLError, err: except ssl.SSLError as err:
if err.args[0] in (ssl.SSL_ERROR_WANT_READ, if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE): ssl.SSL_ERROR_WANT_WRITE):
return '' return b''
if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN): if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN):
self.handle_close() self.handle_close()
return '' return b''
raise raise
def handle_error(self): def handle_error(self):
...@@ -355,6 +361,8 @@ if ssl is not None: ...@@ -355,6 +361,8 @@ if ssl is not None:
if (isinstance(self.socket, ssl.SSLSocket) and if (isinstance(self.socket, ssl.SSLSocket) and
self.socket._sslobj is not None): self.socket._sslobj is not None):
self._do_ssl_shutdown() self._do_ssl_shutdown()
else:
super(SSLConnection, self).close()
class DummyTLS_DTPHandler(SSLConnection, DummyDTPHandler): class DummyTLS_DTPHandler(SSLConnection, DummyDTPHandler):
...@@ -462,12 +470,12 @@ class TestFTPClass(TestCase): ...@@ -462,12 +470,12 @@ class TestFTPClass(TestCase):
def test_rename(self): def test_rename(self):
self.client.rename('a', 'b') self.client.rename('a', 'b')
self.server.handler.next_response = '200' self.server.handler_instance.next_response = '200'
self.assertRaises(ftplib.error_reply, self.client.rename, 'a', 'b') self.assertRaises(ftplib.error_reply, self.client.rename, 'a', 'b')
def test_delete(self): def test_delete(self):
self.client.delete('foo') self.client.delete('foo')
self.server.handler.next_response = '199' self.server.handler_instance.next_response = '199'
self.assertRaises(ftplib.error_reply, self.client.delete, 'foo') self.assertRaises(ftplib.error_reply, self.client.delete, 'foo')
def test_size(self): def test_size(self):
...@@ -515,7 +523,7 @@ class TestFTPClass(TestCase): ...@@ -515,7 +523,7 @@ class TestFTPClass(TestCase):
def test_storbinary(self): def test_storbinary(self):
f = StringIO.StringIO(RETR_DATA) f = StringIO.StringIO(RETR_DATA)
self.client.storbinary('stor', f) self.client.storbinary('stor', f)
self.assertEqual(self.server.handler.last_received_data, RETR_DATA) self.assertEqual(self.server.handler_instance.last_received_data, RETR_DATA)
# test new callback arg # test new callback arg
flag = [] flag = []
f.seek(0) f.seek(0)
...@@ -527,12 +535,12 @@ class TestFTPClass(TestCase): ...@@ -527,12 +535,12 @@ class TestFTPClass(TestCase):
for r in (30, '30'): for r in (30, '30'):
f.seek(0) f.seek(0)
self.client.storbinary('stor', f, rest=r) self.client.storbinary('stor', f, rest=r)
self.assertEqual(self.server.handler.rest, str(r)) self.assertEqual(self.server.handler_instance.rest, str(r))
def test_storlines(self): def test_storlines(self):
f = StringIO.StringIO(RETR_DATA.replace('\r\n', '\n')) f = StringIO.StringIO(RETR_DATA.replace('\r\n', '\n'))
self.client.storlines('stor', f) self.client.storlines('stor', f)
self.assertEqual(self.server.handler.last_received_data, RETR_DATA) self.assertEqual(self.server.handler_instance.last_received_data, RETR_DATA)
# test new callback arg # test new callback arg
flag = [] flag = []
f.seek(0) f.seek(0)
...@@ -551,14 +559,14 @@ class TestFTPClass(TestCase): ...@@ -551,14 +559,14 @@ class TestFTPClass(TestCase):
def test_makeport(self): def test_makeport(self):
self.client.makeport() self.client.makeport()
# IPv4 is in use, just make sure send_eprt has not been used # IPv4 is in use, just make sure send_eprt has not been used
self.assertEqual(self.server.handler.last_received_cmd, 'port') self.assertEqual(self.server.handler_instance.last_received_cmd, 'port')
def test_makepasv(self): def test_makepasv(self):
host, port = self.client.makepasv() host, port = self.client.makepasv()
conn = socket.create_connection((host, port), 10) conn = socket.create_connection((host, port), 10)
conn.close() conn.close()
# IPv4 is in use, just make sure send_epsv has not been used # IPv4 is in use, just make sure send_epsv has not been used
self.assertEqual(self.server.handler.last_received_cmd, 'pasv') self.assertEqual(self.server.handler_instance.last_received_cmd, 'pasv')
def test_line_too_long(self): def test_line_too_long(self):
self.assertRaises(ftplib.Error, self.client.sendcmd, self.assertRaises(ftplib.Error, self.client.sendcmd,
...@@ -600,13 +608,13 @@ class TestIPv6Environment(TestCase): ...@@ -600,13 +608,13 @@ class TestIPv6Environment(TestCase):
def test_makeport(self): def test_makeport(self):
self.client.makeport() self.client.makeport()
self.assertEqual(self.server.handler.last_received_cmd, 'eprt') self.assertEqual(self.server.handler_instance.last_received_cmd, 'eprt')
def test_makepasv(self): def test_makepasv(self):
host, port = self.client.makepasv() host, port = self.client.makepasv()
conn = socket.create_connection((host, port), 10) conn = socket.create_connection((host, port), 10)
conn.close() conn.close()
self.assertEqual(self.server.handler.last_received_cmd, 'epsv') self.assertEqual(self.server.handler_instance.last_received_cmd, 'epsv')
def test_transfer(self): def test_transfer(self):
def retr(): def retr():
...@@ -642,7 +650,7 @@ class TestTLS_FTPClass(TestCase): ...@@ -642,7 +650,7 @@ class TestTLS_FTPClass(TestCase):
def setUp(self): def setUp(self):
self.server = DummyTLS_FTPServer((HOST, 0)) self.server = DummyTLS_FTPServer((HOST, 0))
self.server.start() self.server.start()
self.client = ftplib.FTP_TLS(timeout=10) self.client = ftplib.FTP_TLS(timeout=TIMEOUT)
self.client.connect(self.server.host, self.server.port) self.client.connect(self.server.host, self.server.port)
def tearDown(self): def tearDown(self):
...@@ -695,6 +703,59 @@ class TestTLS_FTPClass(TestCase): ...@@ -695,6 +703,59 @@ class TestTLS_FTPClass(TestCase):
finally: finally:
self.client.ssl_version = ssl.PROTOCOL_TLSv1 self.client.ssl_version = ssl.PROTOCOL_TLSv1
def test_context(self):
self.client.quit()
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
self.assertRaises(ValueError, ftplib.FTP_TLS, keyfile=CERTFILE,
context=ctx)
self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE,
context=ctx)
self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE,
keyfile=CERTFILE, context=ctx)
self.client = ftplib.FTP_TLS(context=ctx, timeout=TIMEOUT)
self.client.connect(self.server.host, self.server.port)
self.assertNotIsInstance(self.client.sock, ssl.SSLSocket)
self.client.auth()
self.assertIs(self.client.sock.context, ctx)
self.assertIsInstance(self.client.sock, ssl.SSLSocket)
self.client.prot_p()
sock = self.client.transfercmd('list')
try:
self.assertIs(sock.context, ctx)
self.assertIsInstance(sock, ssl.SSLSocket)
finally:
sock.close()
def test_check_hostname(self):
self.client.quit()
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
ctx.verify_mode = ssl.CERT_REQUIRED
ctx.check_hostname = True
ctx.load_verify_locations(CAFILE)
self.client = ftplib.FTP_TLS(context=ctx, timeout=TIMEOUT)
# 127.0.0.1 doesn't match SAN
self.client.connect(self.server.host, self.server.port)
with self.assertRaises(ssl.CertificateError):
self.client.auth()
# exception quits connection
self.client.connect(self.server.host, self.server.port)
self.client.prot_p()
with self.assertRaises(ssl.CertificateError):
self.client.transfercmd("list").close()
self.client.quit()
self.client.connect("localhost", self.server.port)
self.client.auth()
self.client.quit()
self.client.connect("localhost", self.server.port)
self.client.prot_p()
self.client.transfercmd("list").close()
class TestTimeouts(TestCase): class TestTimeouts(TestCase):
......
...@@ -15,6 +15,8 @@ Core and Builtins ...@@ -15,6 +15,8 @@ Core and Builtins
Library Library
------- -------
- Backport the context argument to ftplib.FTP_TLS.
- Issue #23111: Maximize compatibility in protocol versions of ftplib.FTP_TLS. - Issue #23111: Maximize compatibility in protocol versions of ftplib.FTP_TLS.
- Issue #23112: Fix SimpleHTTPServer to correctly carry the query string and - Issue #23112: Fix SimpleHTTPServer to correctly carry the query string and
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment