Commit 062086d2 authored by Jason Madden's avatar Jason Madden Committed by GitHub

Merge pull request #1490 from gevent/issue1331

gevent.pywsgi: Support keep-alive in HTTP/1.0
parents 9fe8ba12 6f734773
...@@ -43,6 +43,9 @@ ...@@ -43,6 +43,9 @@
calling it, so ``unlink`` can sometimes be optimized out. See calling it, so ``unlink`` can sometimes be optimized out. See
:issue:`1487`. :issue:`1487`.
- Make ``gevent.pywsgi`` support ``Connection: keep-alive`` in
HTTP/1.0. Based on :pr:`1331` by tanchuhan.
1.5a2 (2019-10-21) 1.5a2 (2019-10-21)
================== ==================
......
...@@ -555,7 +555,11 @@ class WSGIHandler(object): ...@@ -555,7 +555,11 @@ class WSGIHandler(object):
if self.request_version == "HTTP/1.1": if self.request_version == "HTTP/1.1":
conntype = self.headers.get("Connection", "").lower() conntype = self.headers.get("Connection", "").lower()
self.close_connection = (conntype == 'close') self.close_connection = (conntype == 'close')
elif self.request_version == 'HTTP/1.0':
conntype = self.headers.get("Connection", "close").lower()
self.close_connection = (conntype != 'keep-alive')
else: else:
# XXX: HTTP 0.9. We should drop support
self.close_connection = True self.close_connection = True
return True return True
...@@ -842,7 +846,7 @@ class WSGIHandler(object): ...@@ -842,7 +846,7 @@ class WSGIHandler(object):
self.response_headers = response_headers self.response_headers = response_headers
self.code = code self.code = code
provided_connection = None provided_connection = None # Did the wsgi app give us a Connection header?
self.provided_date = None self.provided_date = None
self.provided_content_length = None self.provided_content_length = None
...@@ -856,8 +860,8 @@ class WSGIHandler(object): ...@@ -856,8 +860,8 @@ class WSGIHandler(object):
self.provided_content_length = value self.provided_content_length = value
if self.request_version == 'HTTP/1.0' and provided_connection is None: if self.request_version == 'HTTP/1.0' and provided_connection is None:
response_headers.append((b'Connection', b'close')) conntype = b'close' if self.close_connection else b'keep-alive'
self.close_connection = True response_headers.append((b'Connection', conntype))
elif provided_connection == 'close': elif provided_connection == 'close':
self.close_connection = True self.close_connection = True
......
...@@ -54,13 +54,11 @@ from gevent.pywsgi import Input ...@@ -54,13 +54,11 @@ from gevent.pywsgi import Input
CONTENT_LENGTH = 'Content-Length' CONTENT_LENGTH = 'Content-Length'
CONN_ABORTED_ERRORS = greentest.CONN_ABORTED_ERRORS CONN_ABORTED_ERRORS = greentest.CONN_ABORTED_ERRORS
server_implements_chunked = True
server_implements_pipeline = True
server_implements_100continue = True
DEBUG = '-v' in sys.argv
REASONS = {200: 'OK', REASONS = {
500: 'Internal Server Error'} 200: 'OK',
500: 'Internal Server Error'
}
class ConnectionClosed(Exception): class ConnectionClosed(Exception):
...@@ -316,25 +314,59 @@ class TestCase(greentest.TestCase): ...@@ -316,25 +314,59 @@ class TestCase(greentest.TestCase):
fd.write('GET / HTTP/1.1\r\nHost: localhost\r\n\r\n') fd.write('GET / HTTP/1.1\r\nHost: localhost\r\n\r\n')
return read_http(fd, *args, **kwargs) return read_http(fd, *args, **kwargs)
class CommonTests(TestCase): HTTP_CLIENT_VERSION = '1.1'
DEFAULT_EXTRA_CLIENT_HEADERS = {}
def format_request(self, method='GET', path='/', **headers):
def_headers = self.DEFAULT_EXTRA_CLIENT_HEADERS.copy()
def_headers.update(headers)
headers = def_headers
headers = '\r\n'.join('%s: %s' % item for item in headers.items())
headers = headers + '\r\n' if headers else headers
result = (
'%(method)s %(path)s HTTP/%(http_ver)s\r\n'
'Host: localhost\r\n'
'%(headers)s'
'\r\n'
)
result = result % dict(
method=method,
path=path,
http_ver=self.HTTP_CLIENT_VERSION,
headers=headers
)
return result
class CommonTestMixin(object):
PIPELINE_NOT_SUPPORTED_EXS = ()
EXPECT_CLOSE = False
EXPECT_KEEPALIVE = False
def test_basic(self): def test_basic(self):
with self.makefile() as fd: with self.makefile() as fd:
fd.write('GET / HTTP/1.1\r\nHost: localhost\r\n\r\n') fd.write(self.format_request())
response = read_http(fd, body='hello world') response = read_http(fd, body='hello world')
if response.headers.get('Connection') == 'close' and not server_implements_pipeline: if response.headers.get('Connection') == 'close':
return self.assertTrue(self.EXPECT_CLOSE, "Server closed connection, not expecting that")
fd.write('GET /notexist HTTP/1.1\r\nHost: localhost\r\n\r\n') return response, None
read_http(fd, code=404, reason='Not Found', body='not found')
fd.write('GET / HTTP/1.1\r\nHost: localhost\r\n\r\n') self.assertFalse(self.EXPECT_CLOSE)
read_http(fd, body='hello world') if self.EXPECT_KEEPALIVE:
response.assertHeader('Connection', 'keep-alive')
fd.write(self.format_request(path='/notexist'))
dne_response = read_http(fd, code=404, reason='Not Found', body='not found')
fd.write(self.format_request())
response = read_http(fd, body='hello world')
return response, dne_response
def test_pipeline(self): def test_pipeline(self):
if not server_implements_pipeline:
return
exception = AssertionError('HTTP pipelining not supported; the second request is thrown away') exception = AssertionError('HTTP pipelining not supported; the second request is thrown away')
with self.makefile() as fd: with self.makefile() as fd:
fd.write('GET / HTTP/1.1\r\nHost: localhost\r\n\r\n' + 'GET /notexist HTTP/1.1\r\nHost: localhost\r\n\r\n') fd.write(self.format_request() + self.format_request(path='/notexist'))
read_http(fd, body='hello world') read_http(fd, body='hello world')
try: try:
...@@ -343,19 +375,26 @@ class CommonTests(TestCase): ...@@ -343,19 +375,26 @@ class CommonTests(TestCase):
read_http(fd, code=404, reason='Not Found', body='not found') read_http(fd, code=404, reason='Not Found', body='not found')
finally: finally:
timeout.close() timeout.close()
except self.PIPELINE_NOT_SUPPORTED_EXS:
pass
except AssertionError as ex: except AssertionError as ex:
if ex is not exception: if ex is not exception:
raise raise
def test_connection_close(self): def test_connection_close(self):
with self.makefile() as fd: with self.makefile() as fd:
fd.write('GET / HTTP/1.1\r\nHost: localhost\r\n\r\n') fd.write(self.format_request())
response = read_http(fd) response = read_http(fd)
if response.headers.get('Connection') == 'close' and not server_implements_pipeline: if response.headers.get('Connection') == 'close':
self.assertTrue(self.EXPECT_CLOSE, "Server closed connection, not expecting that")
return return
fd.write('GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n') self.assertFalse(self.EXPECT_CLOSE)
if self.EXPECT_KEEPALIVE:
response.assertHeader('Connection', 'keep-alive')
fd.write(self.format_request(Connection='close'))
read_http(fd) read_http(fd)
fd.write('GET / HTTP/1.1\r\nHost: localhost\r\n\r\n') fd.write(self.format_request())
# This may either raise, or it may return an empty response, # This may either raise, or it may return an empty response,
# depend on timing and the Python version. # depend on timing and the Python version.
try: try:
...@@ -384,7 +423,7 @@ class CommonTests(TestCase): ...@@ -384,7 +423,7 @@ class CommonTests(TestCase):
self.assertEqual(status, '414') self.assertEqual(status, '414')
class TestNoChunks(CommonTests): class TestNoChunks(CommonTestMixin, TestCase):
# when returning a list of strings a shortcut is employed by the server: # when returning a list of strings a shortcut is employed by the server:
# it calculates the content-length and joins all the chunks before sending # it calculates the content-length and joins all the chunks before sending
validator = None validator = None
...@@ -398,22 +437,32 @@ class TestNoChunks(CommonTests): ...@@ -398,22 +437,32 @@ class TestNoChunks(CommonTests):
start_response('404 Not Found', [('Content-Type', 'text/plain')]) start_response('404 Not Found', [('Content-Type', 'text/plain')])
return [b'not ', b'found'] return [b'not ', b'found']
def test(self): def test_basic(self):
if not server_implements_pipeline: response, dne_response = super(TestNoChunks, self).test_basic()
raise unittest.SkipTest("No pipelines") self.assertFalse(response.chunks)
response.assertHeader('Content-Length', '11')
if dne_response is not None:
self.assertFalse(dne_response.chunks)
dne_response.assertHeader('Content-Length', '9')
def test_dne(self):
with self.makefile() as fd: with self.makefile() as fd:
fd.write('GET / HTTP/1.1\r\nHost: localhost\r\n\r\n') fd.write(self.format_request(path='/notexist'))
response = read_http(fd, body='hello world') response = read_http(fd, code=404, reason='Not Found', body='not found')
self.assertFalse(response.chunks)
self.assertFalse(response.chunks) response.assertHeader('Content-Length', '9')
response.assertHeader('Content-Length', '11')
class TestNoChunks10(TestNoChunks):
HTTP_CLIENT_VERSION = '1.0'
PIPELINE_NOT_SUPPORTED_EXS = (ConnectionClosed,)
EXPECT_CLOSE = True
fd.write('GET /not-found HTTP/1.1\r\nHost: localhost\r\n\r\n') class TestNoChunks10KeepAlive(TestNoChunks10):
response = read_http(fd, code=404, reason='Not Found', body='not found') DEFAULT_EXTRA_CLIENT_HEADERS = {
self.assertFalse(response.chunks) 'Connection': 'keep-alive',
response.assertHeader('Content-Length', '9') }
EXPECT_CLOSE = False
EXPECT_KEEPALIVE = True
class TestExplicitContentLength(TestNoChunks): # pylint:disable=too-many-ancestors class TestExplicitContentLength(TestNoChunks): # pylint:disable=too-many-ancestors
...@@ -431,7 +480,7 @@ class TestExplicitContentLength(TestNoChunks): # pylint:disable=too-many-ancesto ...@@ -431,7 +480,7 @@ class TestExplicitContentLength(TestNoChunks): # pylint:disable=too-many-ancesto
return [b'not ', b'found'] return [b'not ', b'found']
class TestYield(CommonTests): class TestYield(CommonTestMixin, TestCase):
@staticmethod @staticmethod
def application(env, start_response): def application(env, start_response):
...@@ -444,7 +493,7 @@ class TestYield(CommonTests): ...@@ -444,7 +493,7 @@ class TestYield(CommonTests):
yield b"not found" yield b"not found"
class TestBytearray(CommonTests): class TestBytearray(CommonTestMixin, TestCase):
validator = None validator = None
...@@ -458,7 +507,7 @@ class TestBytearray(CommonTests): ...@@ -458,7 +507,7 @@ class TestBytearray(CommonTests):
return [bytearray(b"not found")] return [bytearray(b"not found")]
class MultiLineHeader(TestCase): class TestMultiLineHeader(TestCase):
@staticmethod @staticmethod
def application(env, start_response): def application(env, start_response):
assert "test.submit" in env["CONTENT_TYPE"] assert "test.submit" in env["CONTENT_TYPE"]
...@@ -550,13 +599,9 @@ class TestChunkedApp(TestCase): ...@@ -550,13 +599,9 @@ class TestChunkedApp(TestCase):
with self.makefile() as fd: with self.makefile() as fd:
fd.write('GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n') fd.write('GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n')
response = read_http(fd, body=self.body(), chunks=None) response = read_http(fd, body=self.body(), chunks=None)
if server_implements_chunked:
response.assertHeader('Transfer-Encoding', 'chunked') response.assertHeader('Transfer-Encoding', 'chunked')
self.assertEqual(response.chunks, self.chunks) self.assertEqual(response.chunks, self.chunks)
else:
response.assertHeader('Transfer-Encoding', False)
response.assertHeader('Content-Length', str(len(self.body())))
self.assertEqual(response.chunks, False)
def test_no_chunked_http_1_0(self): def test_no_chunked_http_1_0(self):
with self.makefile() as fd: with self.makefile() as fd:
...@@ -743,22 +788,18 @@ class TestUseWrite(TestCase): ...@@ -743,22 +788,18 @@ class TestUseWrite(TestCase):
with self.makefile() as fd: with self.makefile() as fd:
fd.write('GET /no-content-length HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n') fd.write('GET /no-content-length HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n')
response = read_http(fd, body=self.body + self.end) response = read_http(fd, body=self.body + self.end)
if server_implements_chunked:
response.assertHeader('Content-Length', False) response.assertHeader('Content-Length', False)
response.assertHeader('Transfer-Encoding', 'chunked') response.assertHeader('Transfer-Encoding', 'chunked')
else:
response.assertHeader('Content-Length', self.content_length)
def test_no_content_length_twice(self): def test_no_content_length_twice(self):
with self.makefile() as fd: with self.makefile() as fd:
fd.write('GET /no-content-length-twice HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n') fd.write('GET /no-content-length-twice HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n')
response = read_http(fd, body=self.body + self.body + self.end) response = read_http(fd, body=self.body + self.body + self.end)
if server_implements_chunked:
response.assertHeader('Content-Length', False) response.assertHeader('Content-Length', False)
response.assertHeader('Transfer-Encoding', 'chunked') response.assertHeader('Transfer-Encoding', 'chunked')
self.assertEqual(response.chunks, [self.body, self.body, self.end]) self.assertEqual(response.chunks, [self.body, self.body, self.end])
else:
response.assertHeader('Content-Length', str(5 + 5 + 3))
class HttpsTestCase(TestCase): class HttpsTestCase(TestCase):
...@@ -1005,10 +1046,7 @@ class TestEmptyYield(TestCase): ...@@ -1005,10 +1046,7 @@ class TestEmptyYield(TestCase):
yield b"" yield b""
def test_err(self): def test_err(self):
if server_implements_chunked: chunks = []
chunks = []
else:
chunks = False
with self.makefile() as fd: with self.makefile() as fd:
fd.write('GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n') fd.write('GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n')
...@@ -1028,10 +1066,7 @@ class TestFirstEmptyYield(TestCase): ...@@ -1028,10 +1066,7 @@ class TestFirstEmptyYield(TestCase):
yield b"hello" yield b"hello"
def test_err(self): def test_err(self):
if server_implements_chunked: chunks = [b'hello']
chunks = [b'hello']
else:
chunks = False
with self.makefile() as fd: with self.makefile() as fd:
fd.write('GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n') fd.write('GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n')
...@@ -1194,13 +1229,7 @@ class ChunkedInputTests(TestCase): ...@@ -1194,13 +1229,7 @@ class ChunkedInputTests(TestCase):
read_http(fd, body="pong") read_http(fd, body="pong")
def ping_if_possible(self, fd): def ping_if_possible(self, fd):
try: self.ping(fd)
self.ping(fd)
except ConnectionClosed:
if server_implements_pipeline:
raise
with self.makefile() as fd2:
self.ping(fd2)
def test_short_read_with_content_length(self): def test_short_read_with_content_length(self):
body = self.body() body = self.body()
...@@ -1239,14 +1268,7 @@ class ChunkedInputTests(TestCase): ...@@ -1239,14 +1268,7 @@ class ChunkedInputTests(TestCase):
with self.makefile() as fd: with self.makefile() as fd:
fd.write(req) fd.write(req)
try: read_http(fd, body="pong")
read_http(fd, body="pong")
except AssertionError as ex:
if str(ex).startswith('Unexpected code: 400'):
if not server_implements_chunked:
print('ChunkedNotImplementedWarning')
return
raise
self.ping_if_possible(fd) self.ping_if_possible(fd)
...@@ -1324,16 +1346,8 @@ class Expect100ContinueTests(TestCase): ...@@ -1324,16 +1346,8 @@ class Expect100ContinueTests(TestCase):
def test_continue(self): def test_continue(self):
with self.makefile() as fd: with self.makefile() as fd:
fd.write('PUT / HTTP/1.1\r\nHost: localhost\r\nContent-length: 1025\r\nExpect: 100-continue\r\n\r\n') fd.write('PUT / HTTP/1.1\r\nHost: localhost\r\nContent-length: 1025\r\nExpect: 100-continue\r\n\r\n')
try: read_http(fd, code=417, body="failure")
read_http(fd, code=417, body="failure")
except AssertionError as ex:
if str(ex).startswith('Unexpected code: 400'):
if not server_implements_100continue:
print('100ContinueNotImplementedWarning')
return
raise
fd.write('PUT / HTTP/1.1\r\nHost: localhost\r\nContent-length: 7\r\nExpect: 100-continue\r\n\r\ntesting') fd.write('PUT / HTTP/1.1\r\nHost: localhost\r\nContent-length: 7\r\nExpect: 100-continue\r\n\r\ntesting')
read_http(fd, code=100) read_http(fd, code=100)
...@@ -1824,7 +1838,6 @@ class TestEnviron(TestCase): ...@@ -1824,7 +1838,6 @@ class TestEnviron(TestCase):
self.assertEqual(json.dumps(bltin), json.dumps(env)) self.assertEqual(json.dumps(bltin), json.dumps(env))
del CommonTests
if __name__ == '__main__': if __name__ == '__main__':
greentest.main() greentest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment