Commit 36f28f7a authored by Senthil Kumaran's avatar Senthil Kumaran

Backport Fix for Issue #7776: Fix ``Host:'' header and reconnection when using...

Backport Fix for Issue #7776: Fix ``Host:'' header and reconnection when using http.client.HTTPConnection.set_tunnel().

Patch by Nikolaus Rath.
parent fb371afa
...@@ -700,17 +700,33 @@ class HTTPConnection: ...@@ -700,17 +700,33 @@ class HTTPConnection:
self._tunnel_host = None self._tunnel_host = None
self._tunnel_port = None self._tunnel_port = None
self._tunnel_headers = {} self._tunnel_headers = {}
self._set_hostport(host, port)
if strict is not None: if strict is not None:
self.strict = strict self.strict = strict
(self.host, self.port) = self._get_hostport(host, port)
# This is stored as an instance variable to allow unittests
# to replace with a suitable mock
self._create_connection = socket.create_connection
def set_tunnel(self, host, port=None, headers=None): def set_tunnel(self, host, port=None, headers=None):
""" Sets up the host and the port for the HTTP CONNECT Tunnelling. """ Set up host and port for HTTP CONNECT tunnelling.
In a connection that uses HTTP Connect tunneling, the host passed to the
constructor is used as proxy server that relays all communication to the
endpoint passed to set_tunnel. This is done by sending a HTTP CONNECT
request to the proxy server when the connection is established.
This method must be called before the HTML connection has been
established.
The headers argument should be a mapping of extra HTTP headers The headers argument should be a mapping of extra HTTP headers
to send with the CONNECT request. to send with the CONNECT request.
""" """
# Verify if this is required.
if self.sock:
raise RuntimeError("Can't setup tunnel for established connection.")
self._tunnel_host = host self._tunnel_host = host
self._tunnel_port = port self._tunnel_port = port
if headers: if headers:
...@@ -718,7 +734,7 @@ class HTTPConnection: ...@@ -718,7 +734,7 @@ class HTTPConnection:
else: else:
self._tunnel_headers.clear() self._tunnel_headers.clear()
def _set_hostport(self, host, port): def _get_hostport(self, host, port):
if port is None: if port is None:
i = host.rfind(':') i = host.rfind(':')
j = host.rfind(']') # ipv6 addresses have [...] j = host.rfind(']') # ipv6 addresses have [...]
...@@ -735,15 +751,14 @@ class HTTPConnection: ...@@ -735,15 +751,14 @@ class HTTPConnection:
port = self.default_port port = self.default_port
if host and host[0] == '[' and host[-1] == ']': if host and host[0] == '[' and host[-1] == ']':
host = host[1:-1] host = host[1:-1]
self.host = host return (host, port)
self.port = port
def set_debuglevel(self, level): def set_debuglevel(self, level):
self.debuglevel = level self.debuglevel = level
def _tunnel(self): def _tunnel(self):
self._set_hostport(self._tunnel_host, self._tunnel_port) (host, port) = self._get_hostport(self._tunnel_host, self._tunnel_port)
self.send("CONNECT %s:%d HTTP/1.0\r\n" % (self.host, self.port)) self.send("CONNECT %s:%d HTTP/1.0\r\n" % (host, port))
for header, value in self._tunnel_headers.iteritems(): for header, value in self._tunnel_headers.iteritems():
self.send("%s: %s\r\n" % (header, value)) self.send("%s: %s\r\n" % (header, value))
self.send("\r\n") self.send("\r\n")
...@@ -768,8 +783,8 @@ class HTTPConnection: ...@@ -768,8 +783,8 @@ class HTTPConnection:
def connect(self): def connect(self):
"""Connect to the host and port specified in __init__.""" """Connect to the host and port specified in __init__."""
self.sock = socket.create_connection((self.host,self.port), self.sock = self._create_connection((self.host,self.port),
self.timeout, self.source_address) self.timeout, self.source_address)
if self._tunnel_host: if self._tunnel_host:
self._tunnel() self._tunnel()
...@@ -907,17 +922,24 @@ class HTTPConnection: ...@@ -907,17 +922,24 @@ class HTTPConnection:
netloc_enc = netloc.encode("idna") netloc_enc = netloc.encode("idna")
self.putheader('Host', netloc_enc) self.putheader('Host', netloc_enc)
else: else:
if self._tunnel_host:
host = self._tunnel_host
port = self._tunnel_port
else:
host = self.host
port = self.port
try: try:
host_enc = self.host.encode("ascii") host_enc = host.encode("ascii")
except UnicodeEncodeError: except UnicodeEncodeError:
host_enc = self.host.encode("idna") host_enc = host.encode("idna")
# Wrap the IPv6 Host Header with [] (RFC 2732) # Wrap the IPv6 Host Header with [] (RFC 2732)
if host_enc.find(':') >= 0: if host_enc.find(':') >= 0:
host_enc = "[" + host_enc + "]" host_enc = "[" + host_enc + "]"
if self.port == self.default_port: if port == self.default_port:
self.putheader('Host', host_enc) self.putheader('Host', host_enc)
else: else:
self.putheader('Host', "%s:%s" % (host_enc, self.port)) self.putheader('Host', "%s:%s" % (host_enc, port))
# note: we are assuming that clients will not attempt to set these # note: we are assuming that clients will not attempt to set these
# headers since *this* library must deal with the # headers since *this* library must deal with the
...@@ -1168,8 +1190,8 @@ else: ...@@ -1168,8 +1190,8 @@ else:
def connect(self): def connect(self):
"Connect to a host on a given (SSL) port." "Connect to a host on a given (SSL) port."
sock = socket.create_connection((self.host, self.port), sock = self._create_connection((self.host, self.port),
self.timeout, self.source_address) self.timeout, self.source_address)
if self._tunnel_host: if self._tunnel_host:
self.sock = sock self.sock = sock
self._tunnel() self._tunnel()
......
...@@ -13,10 +13,12 @@ from test import test_support ...@@ -13,10 +13,12 @@ from test import test_support
HOST = test_support.HOST HOST = test_support.HOST
class FakeSocket: class FakeSocket:
def __init__(self, text, fileclass=StringIO.StringIO): def __init__(self, text, fileclass=StringIO.StringIO, host=None, port=None):
self.text = text self.text = text
self.fileclass = fileclass self.fileclass = fileclass
self.data = '' self.data = ''
self.host = host
self.port = port
def sendall(self, data): def sendall(self, data):
self.data += ''.join(data) self.data += ''.join(data)
...@@ -26,6 +28,9 @@ class FakeSocket: ...@@ -26,6 +28,9 @@ class FakeSocket:
raise httplib.UnimplementedFileMode() raise httplib.UnimplementedFileMode()
return self.fileclass(self.text) return self.fileclass(self.text)
def close(self):
pass
class EPipeSocket(FakeSocket): class EPipeSocket(FakeSocket):
def __init__(self, text, pipe_trigger): def __init__(self, text, pipe_trigger):
...@@ -526,9 +531,48 @@ class HTTPSTimeoutTest(TestCase): ...@@ -526,9 +531,48 @@ class HTTPSTimeoutTest(TestCase):
self.fail("Port incorrectly parsed: %s != %s" % (p, c.host)) self.fail("Port incorrectly parsed: %s != %s" % (p, c.host))
class TunnelTests(TestCase):
def test_connect(self):
response_text = (
'HTTP/1.0 200 OK\r\n\r\n' # Reply to CONNECT
'HTTP/1.1 200 OK\r\n' # Reply to HEAD
'Content-Length: 42\r\n\r\n'
)
def create_connection(address, timeout=None, source_address=None):
return FakeSocket(response_text, host=address[0], port=address[1])
conn = httplib.HTTPConnection('proxy.com')
conn._create_connection = create_connection
# Once connected, we should not be able to tunnel anymore
conn.connect()
self.assertRaises(RuntimeError, conn.set_tunnel, 'destination.com')
# But if close the connection, we are good.
conn.close()
conn.set_tunnel('destination.com')
conn.request('HEAD', '/', '')
self.assertEqual(conn.sock.host, 'proxy.com')
self.assertEqual(conn.sock.port, 80)
self.assertTrue('CONNECT destination.com' in conn.sock.data)
self.assertTrue('Host: destination.com' in conn.sock.data)
self.assertTrue('Host: proxy.com' not in conn.sock.data)
conn.close()
conn.request('PUT', '/', '')
self.assertEqual(conn.sock.host, 'proxy.com')
self.assertEqual(conn.sock.port, 80)
self.assertTrue('CONNECT destination.com' in conn.sock.data)
self.assertTrue('Host: destination.com' in conn.sock.data)
def test_main(verbose=None): def test_main(verbose=None):
test_support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest, test_support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest,
HTTPSTimeoutTest, SourceAddressTest) HTTPSTimeoutTest, SourceAddressTest, TunnelTests)
if __name__ == '__main__': if __name__ == '__main__':
test_main() test_main()
...@@ -49,6 +49,10 @@ Core and Builtins ...@@ -49,6 +49,10 @@ Core and Builtins
Library Library
------- -------
- Issue #7776: Backport Fix ``Host:'' header and reconnection when using
http.client.HTTPConnection.set_tunnel() from Python 3.
Patch by Nikolaus Rath.
- Issue #21306: Backport hmac.compare_digest from Python 3. This is part of PEP - Issue #21306: Backport hmac.compare_digest from Python 3. This is part of PEP
466. 466.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment