Commit 70e28473 authored by Benjamin Peterson's avatar Benjamin Peterson

document the requestline and close_connection attributes, use real booleans,...

document the requestline and close_connection attributes, use real booleans, and add tests (closes #23410)

Patch by Martin Panter.
parent e7a2f644
...@@ -64,6 +64,18 @@ of which this module provides three different variants: ...@@ -64,6 +64,18 @@ of which this module provides three different variants:
Contains the server instance. Contains the server instance.
.. attribute:: close_connection
Boolean that should be set before :meth:`handle_one_request` returns,
indicating if another request may be expected, or if the connection should
be shut down.
.. attribute:: requestline
Contains the string representation of the HTTP request line. The
terminating CRLF is stripped. This attribute should be set by
:meth:`handle_one_request`. If no valid request line was processed, it
should be set to the empty string.
.. attribute:: command .. attribute:: command
......
...@@ -273,7 +273,7 @@ class BaseHTTPRequestHandler(socketserver.StreamRequestHandler): ...@@ -273,7 +273,7 @@ class BaseHTTPRequestHandler(socketserver.StreamRequestHandler):
""" """
self.command = None # set in case of error on the first line self.command = None # set in case of error on the first line
self.request_version = version = self.default_request_version self.request_version = version = self.default_request_version
self.close_connection = 1 self.close_connection = True
requestline = str(self.raw_requestline, 'iso-8859-1') requestline = str(self.raw_requestline, 'iso-8859-1')
requestline = requestline.rstrip('\r\n') requestline = requestline.rstrip('\r\n')
self.requestline = requestline self.requestline = requestline
...@@ -299,14 +299,14 @@ class BaseHTTPRequestHandler(socketserver.StreamRequestHandler): ...@@ -299,14 +299,14 @@ class BaseHTTPRequestHandler(socketserver.StreamRequestHandler):
self.send_error(400, "Bad request version (%r)" % version) self.send_error(400, "Bad request version (%r)" % version)
return False return False
if version_number >= (1, 1) and self.protocol_version >= "HTTP/1.1": if version_number >= (1, 1) and self.protocol_version >= "HTTP/1.1":
self.close_connection = 0 self.close_connection = False
if version_number >= (2, 0): if version_number >= (2, 0):
self.send_error(505, self.send_error(505,
"Invalid HTTP Version (%s)" % base_version_number) "Invalid HTTP Version (%s)" % base_version_number)
return False return False
elif len(words) == 2: elif len(words) == 2:
command, path = words command, path = words
self.close_connection = 1 self.close_connection = True
if command != 'GET': if command != 'GET':
self.send_error(400, self.send_error(400,
"Bad HTTP/0.9 request type (%r)" % command) "Bad HTTP/0.9 request type (%r)" % command)
...@@ -328,10 +328,10 @@ class BaseHTTPRequestHandler(socketserver.StreamRequestHandler): ...@@ -328,10 +328,10 @@ class BaseHTTPRequestHandler(socketserver.StreamRequestHandler):
conntype = self.headers.get('Connection', "") conntype = self.headers.get('Connection', "")
if conntype.lower() == 'close': if conntype.lower() == 'close':
self.close_connection = 1 self.close_connection = True
elif (conntype.lower() == 'keep-alive' and elif (conntype.lower() == 'keep-alive' and
self.protocol_version >= "HTTP/1.1"): self.protocol_version >= "HTTP/1.1"):
self.close_connection = 0 self.close_connection = False
# Examine the headers and look for an Expect directive # Examine the headers and look for an Expect directive
expect = self.headers.get('Expect', "") expect = self.headers.get('Expect', "")
if (expect.lower() == "100-continue" and if (expect.lower() == "100-continue" and
...@@ -376,7 +376,7 @@ class BaseHTTPRequestHandler(socketserver.StreamRequestHandler): ...@@ -376,7 +376,7 @@ class BaseHTTPRequestHandler(socketserver.StreamRequestHandler):
self.send_error(414) self.send_error(414)
return return
if not self.raw_requestline: if not self.raw_requestline:
self.close_connection = 1 self.close_connection = True
return return
if not self.parse_request(): if not self.parse_request():
# An error code has been sent, just exit # An error code has been sent, just exit
...@@ -391,12 +391,12 @@ class BaseHTTPRequestHandler(socketserver.StreamRequestHandler): ...@@ -391,12 +391,12 @@ class BaseHTTPRequestHandler(socketserver.StreamRequestHandler):
except socket.timeout as e: except socket.timeout as e:
#a read or a write timed out. Discard this connection #a read or a write timed out. Discard this connection
self.log_error("Request timed out: %r", e) self.log_error("Request timed out: %r", e)
self.close_connection = 1 self.close_connection = True
return return
def handle(self): def handle(self):
"""Handle multiple requests if necessary.""" """Handle multiple requests if necessary."""
self.close_connection = 1 self.close_connection = True
self.handle_one_request() self.handle_one_request()
while not self.close_connection: while not self.close_connection:
...@@ -478,9 +478,9 @@ class BaseHTTPRequestHandler(socketserver.StreamRequestHandler): ...@@ -478,9 +478,9 @@ class BaseHTTPRequestHandler(socketserver.StreamRequestHandler):
if keyword.lower() == 'connection': if keyword.lower() == 'connection':
if value.lower() == 'close': if value.lower() == 'close':
self.close_connection = 1 self.close_connection = True
elif value.lower() == 'keep-alive': elif value.lower() == 'keep-alive':
self.close_connection = 0 self.close_connection = False
def end_headers(self): def end_headers(self):
"""Send the blank line ending the MIME headers.""" """Send the blank line ending the MIME headers."""
......
...@@ -616,6 +616,11 @@ class BaseHTTPRequestHandlerTestCase(unittest.TestCase): ...@@ -616,6 +616,11 @@ class BaseHTTPRequestHandlerTestCase(unittest.TestCase):
self.verify_expected_headers(result[1:-1]) self.verify_expected_headers(result[1:-1])
self.verify_get_called() self.verify_get_called()
self.assertEqual(result[-1], b'<html><body>Data</body></html>\r\n') self.assertEqual(result[-1], b'<html><body>Data</body></html>\r\n')
self.assertEqual(self.handler.requestline, 'GET / HTTP/1.1')
self.assertEqual(self.handler.command, 'GET')
self.assertEqual(self.handler.path, '/')
self.assertEqual(self.handler.request_version, 'HTTP/1.1')
self.assertSequenceEqual(self.handler.headers.items(), ())
def test_http_1_0(self): def test_http_1_0(self):
result = self.send_typical_request(b'GET / HTTP/1.0\r\n\r\n') result = self.send_typical_request(b'GET / HTTP/1.0\r\n\r\n')
...@@ -623,6 +628,11 @@ class BaseHTTPRequestHandlerTestCase(unittest.TestCase): ...@@ -623,6 +628,11 @@ class BaseHTTPRequestHandlerTestCase(unittest.TestCase):
self.verify_expected_headers(result[1:-1]) self.verify_expected_headers(result[1:-1])
self.verify_get_called() self.verify_get_called()
self.assertEqual(result[-1], b'<html><body>Data</body></html>\r\n') self.assertEqual(result[-1], b'<html><body>Data</body></html>\r\n')
self.assertEqual(self.handler.requestline, 'GET / HTTP/1.0')
self.assertEqual(self.handler.command, 'GET')
self.assertEqual(self.handler.path, '/')
self.assertEqual(self.handler.request_version, 'HTTP/1.0')
self.assertSequenceEqual(self.handler.headers.items(), ())
def test_http_0_9(self): def test_http_0_9(self):
result = self.send_typical_request(b'GET / HTTP/0.9\r\n\r\n') result = self.send_typical_request(b'GET / HTTP/0.9\r\n\r\n')
...@@ -636,6 +646,12 @@ class BaseHTTPRequestHandlerTestCase(unittest.TestCase): ...@@ -636,6 +646,12 @@ class BaseHTTPRequestHandlerTestCase(unittest.TestCase):
self.verify_expected_headers(result[1:-1]) self.verify_expected_headers(result[1:-1])
self.verify_get_called() self.verify_get_called()
self.assertEqual(result[-1], b'<html><body>Data</body></html>\r\n') self.assertEqual(result[-1], b'<html><body>Data</body></html>\r\n')
self.assertEqual(self.handler.requestline, 'GET / HTTP/1.0')
self.assertEqual(self.handler.command, 'GET')
self.assertEqual(self.handler.path, '/')
self.assertEqual(self.handler.request_version, 'HTTP/1.0')
headers = (("Expect", "100-continue"),)
self.assertSequenceEqual(self.handler.headers.items(), headers)
def test_with_continue_1_1(self): def test_with_continue_1_1(self):
result = self.send_typical_request(b'GET / HTTP/1.1\r\nExpect: 100-continue\r\n\r\n') result = self.send_typical_request(b'GET / HTTP/1.1\r\nExpect: 100-continue\r\n\r\n')
...@@ -645,6 +661,12 @@ class BaseHTTPRequestHandlerTestCase(unittest.TestCase): ...@@ -645,6 +661,12 @@ class BaseHTTPRequestHandlerTestCase(unittest.TestCase):
self.verify_expected_headers(result[2:-1]) self.verify_expected_headers(result[2:-1])
self.verify_get_called() self.verify_get_called()
self.assertEqual(result[-1], b'<html><body>Data</body></html>\r\n') self.assertEqual(result[-1], b'<html><body>Data</body></html>\r\n')
self.assertEqual(self.handler.requestline, 'GET / HTTP/1.1')
self.assertEqual(self.handler.command, 'GET')
self.assertEqual(self.handler.path, '/')
self.assertEqual(self.handler.request_version, 'HTTP/1.1')
headers = (("Expect", "100-continue"),)
self.assertSequenceEqual(self.handler.headers.items(), headers)
def test_header_buffering_of_send_error(self): def test_header_buffering_of_send_error(self):
...@@ -730,6 +752,7 @@ class BaseHTTPRequestHandlerTestCase(unittest.TestCase): ...@@ -730,6 +752,7 @@ class BaseHTTPRequestHandlerTestCase(unittest.TestCase):
result = self.send_typical_request(b'GET ' + b'x' * 65537) result = self.send_typical_request(b'GET ' + b'x' * 65537)
self.assertEqual(result[0], b'HTTP/1.1 414 Request-URI Too Long\r\n') self.assertEqual(result[0], b'HTTP/1.1 414 Request-URI Too Long\r\n')
self.assertFalse(self.handler.get_called) self.assertFalse(self.handler.get_called)
self.assertIsInstance(self.handler.requestline, str)
def test_header_length(self): def test_header_length(self):
# Issue #6791: same for headers # Issue #6791: same for headers
...@@ -737,6 +760,22 @@ class BaseHTTPRequestHandlerTestCase(unittest.TestCase): ...@@ -737,6 +760,22 @@ class BaseHTTPRequestHandlerTestCase(unittest.TestCase):
b'GET / HTTP/1.1\r\nX-Foo: bar' + b'r' * 65537 + b'\r\n\r\n') b'GET / HTTP/1.1\r\nX-Foo: bar' + b'r' * 65537 + b'\r\n\r\n')
self.assertEqual(result[0], b'HTTP/1.1 400 Line too long\r\n') self.assertEqual(result[0], b'HTTP/1.1 400 Line too long\r\n')
self.assertFalse(self.handler.get_called) self.assertFalse(self.handler.get_called)
self.assertEqual(self.handler.requestline, 'GET / HTTP/1.1')
def test_close_connection(self):
# handle_one_request() should be repeatedly called until
# it sets close_connection
def handle_one_request():
self.handler.close_connection = next(close_values)
self.handler.handle_one_request = handle_one_request
close_values = iter((True,))
self.handler.handle()
self.assertRaises(StopIteration, next, close_values)
close_values = iter((False, False, True))
self.handler.handle()
self.assertRaises(StopIteration, next, close_values)
class SimpleHTTPRequestHandlerTestCase(unittest.TestCase): class SimpleHTTPRequestHandlerTestCase(unittest.TestCase):
""" Test url parsing """ """ Test url parsing """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment