Commit 5a603ade authored by Jason Madden's avatar Jason Madden Committed by GitHub

Merge pull request #1723 from gevent/issue1712

Make gevent.pywsgi stop dealing with chunks when the connection is being upgraded
parents f54fa619 8c497751
Make `gevent.pywsgi` trying to enforce the rules for reading chunked input or
``Content-Length`` terminated input when the connection is being
upgraded, for example to a websocket connection. Likewise, if the
protocol was switched by returning a ``101`` status, stop trying to
automatically chunk the responses.
Reported by Kavindu Santhusa.
...@@ -415,6 +415,9 @@ class WSGIHandler(object): ...@@ -415,6 +415,9 @@ class WSGIHandler(object):
time_finish = 0 # time.time() when done handling request time_finish = 0 # time.time() when done handling request
headers_sent = False # Have we already sent headers? headers_sent = False # Have we already sent headers?
response_use_chunked = False # Write with transfer-encoding chunked response_use_chunked = False # Write with transfer-encoding chunked
# Was the connection upgraded? We shouldn't try to chunk writes in that
# case.
connection_upgraded = False
environ = None # Dict from self.get_environ environ = None # Dict from self.get_environ
application = None # application callable from self.server.application application = None # application callable from self.server.application
requestline = None # native str 'GET / HTTP/1.1' requestline = None # native str 'GET / HTTP/1.1'
...@@ -486,6 +489,7 @@ class WSGIHandler(object): ...@@ -486,6 +489,7 @@ class WSGIHandler(object):
pass pass
self.__dict__.pop('socket', None) self.__dict__.pop('socket', None)
self.__dict__.pop('rfile', None) self.__dict__.pop('rfile', None)
self.__dict__.pop('wsgi_input', None)
def _check_http_version(self): def _check_http_version(self):
version_str = self.request_version version_str = self.request_version
...@@ -697,10 +701,19 @@ class WSGIHandler(object): ...@@ -697,10 +701,19 @@ class WSGIHandler(object):
return True # read more requests return True # read more requests
def _connection_upgrade_requested(self):
if self.headers.get('Connection', '').lower() == 'upgrade':
return True
if self.headers.get('Upgrade', '').lower() == 'websocket':
return True
return False
def finalize_headers(self): def finalize_headers(self):
if self.provided_date is None: if self.provided_date is None:
self.response_headers.append((b'Date', format_date_time(time.time()))) self.response_headers.append((b'Date', format_date_time(time.time())))
self.connection_upgraded = self.code == 101
if self.code not in (304, 204): if self.code not in (304, 204):
# the reply will include message-body; make sure we have either Content-Length or chunked # the reply will include message-body; make sure we have either Content-Length or chunked
if self.provided_content_length is None: if self.provided_content_length is None:
...@@ -711,8 +724,11 @@ class WSGIHandler(object): ...@@ -711,8 +724,11 @@ class WSGIHandler(object):
total_len_str = total_len_str.encode("latin-1") total_len_str = total_len_str.encode("latin-1")
self.response_headers.append((b'Content-Length', total_len_str)) self.response_headers.append((b'Content-Length', total_len_str))
else: else:
if self.request_version != 'HTTP/1.0': self.response_use_chunked = (
self.response_use_chunked = True not self.connection_upgraded
and self.request_version != 'HTTP/1.0'
)
if self.response_use_chunked:
self.response_headers.append((b'Transfer-Encoding', b'chunked')) self.response_headers.append((b'Transfer-Encoding', b'chunked'))
def _sendall(self, data): def _sendall(self, data):
...@@ -975,6 +991,7 @@ class WSGIHandler(object): ...@@ -975,6 +991,7 @@ class WSGIHandler(object):
self.result = None self.result = None
self.response_use_chunked = False self.response_use_chunked = False
self.connection_upgraded = False
self.response_length = 0 self.response_length = 0
try: try:
...@@ -1103,10 +1120,7 @@ class WSGIHandler(object): ...@@ -1103,10 +1120,7 @@ class WSGIHandler(object):
# See https://github.com/gevent/gevent/issues/1667 for discussion. # See https://github.com/gevent/gevent/issues/1667 for discussion.
env['SCRIPT_NAME'] = '' env['SCRIPT_NAME'] = ''
if '?' in self.path: path, query = self.path.split('?', 1) if '?' in self.path else (self.path, '')
path, query = self.path.split('?', 1)
else:
path, query = self.path, ''
# Note that self.path contains the original str object; if it contains # Note that self.path contains the original str object; if it contains
# encoded escapes, it will NOT match PATH_INFO. # encoded escapes, it will NOT match PATH_INFO.
env['PATH_INFO'] = unquote_latin1(path) env['PATH_INFO'] = unquote_latin1(path)
...@@ -1134,18 +1148,20 @@ class WSGIHandler(object): ...@@ -1134,18 +1148,20 @@ class WSGIHandler(object):
else: else:
env[key] = value env[key] = value
if env.get('HTTP_EXPECT') == '100-continue': sock = self.socket if env.get('HTTP_EXPECT') == '100-continue' else None
sock = self.socket
else:
sock = None
chunked = env.get('HTTP_TRANSFER_ENCODING', '').lower() == 'chunked' chunked = env.get('HTTP_TRANSFER_ENCODING', '').lower() == 'chunked'
# Input refuses to read if the data isn't chunked, and there is no content_length
# provided. For 'Upgrade: Websocket' requests, neither of those things is true.
handling_reads = not self._connection_upgrade_requested()
self.wsgi_input = Input(self.rfile, self.content_length, socket=sock, chunked_input=chunked) self.wsgi_input = Input(self.rfile, self.content_length, socket=sock, chunked_input=chunked)
env['wsgi.input'] = self.wsgi_input
env['wsgi.input'] = self.wsgi_input if handling_reads else self.rfile
# This is a non-standard flag indicating that our input stream is # This is a non-standard flag indicating that our input stream is
# self-terminated (returns EOF when consumed). # self-terminated (returns EOF when consumed).
# See https://github.com/gevent/gevent/issues/1308 # See https://github.com/gevent/gevent/issues/1308
env['wsgi.input_terminated'] = True env['wsgi.input_terminated'] = handling_reads
return env return env
......
...@@ -432,18 +432,35 @@ class TestNoChunks(CommonTestMixin, TestCase): ...@@ -432,18 +432,35 @@ class TestNoChunks(CommonTestMixin, TestCase):
# when returning a list of strings a shortcut is employed by the server: # when returning a list of strings a shortcut is employed by the server:
# it calculates the content-length and joins all the chunks before sending # it calculates the content-length and joins all the chunks before sending
validator = None validator = None
last_environ = None
def _check_environ(self, input_terminated=True):
if input_terminated:
self.assertTrue(self.last_environ.get('wsgi.input_terminated'))
else:
self.assertFalse(self.last_environ['wsgi.input_terminated'])
def application(self, env, start_response): def application(self, env, start_response):
self.assertTrue(env.get('wsgi.input_terminated')) self.last_environ = env
path = env['PATH_INFO'] path = env['PATH_INFO']
if path == '/': if path == '/':
start_response('200 OK', [('Content-Type', 'text/plain')]) start_response('200 OK', [('Content-Type', 'text/plain')])
return [b'hello ', b'world'] return [b'hello ', b'world']
if path == '/websocket':
write = start_response('101 Switching Protocols',
[('Content-Type', 'text/plain'),
# Con:close is to make our simple client
# happy; otherwise it wants to read data from the
# body thot's being kept open.
('Connection', 'close')])
write(b'') # Trigger finalizing the headers now.
return [b'upgrading to', b'websocket']
start_response('404 Not Found', [('Content-Type', 'text/plain')]) start_response('404 Not Found', [('Content-Type', 'text/plain')])
return [b'not ', b'found'] return [b'not ', b'found']
def test_basic(self): def test_basic(self):
response, dne_response = super(TestNoChunks, self).test_basic() response, dne_response = super(TestNoChunks, self).test_basic()
self._check_environ()
self.assertFalse(response.chunks) self.assertFalse(response.chunks)
response.assertHeader('Content-Length', '11') response.assertHeader('Content-Length', '11')
if dne_response is not None: if dne_response is not None:
...@@ -455,8 +472,28 @@ class TestNoChunks(CommonTestMixin, TestCase): ...@@ -455,8 +472,28 @@ class TestNoChunks(CommonTestMixin, TestCase):
fd.write(self.format_request(path='/notexist')) fd.write(self.format_request(path='/notexist'))
response = read_http(fd, code=404, reason='Not Found', body='not found') response = read_http(fd, code=404, reason='Not Found', body='not found')
self.assertFalse(response.chunks) self.assertFalse(response.chunks)
self._check_environ()
response.assertHeader('Content-Length', '9') response.assertHeader('Content-Length', '9')
class TestConnectionUpgrades(TestNoChunks):
def test_connection_upgrade(self):
with self.makefile() as fd:
fd.write(self.format_request(path='/websocket', Connection='upgrade'))
response = read_http(fd, code=101)
self._check_environ(input_terminated=False)
self.assertFalse(response.chunks)
def test_upgrade_websocket(self):
with self.makefile() as fd:
fd.write(self.format_request(path='/websocket', Upgrade='websocket'))
response = read_http(fd, code=101)
self._check_environ(input_terminated=False)
self.assertFalse(response.chunks)
class TestNoChunks10(TestNoChunks): class TestNoChunks10(TestNoChunks):
HTTP_CLIENT_VERSION = '1.0' HTTP_CLIENT_VERSION = '1.0'
PIPELINE_NOT_SUPPORTED_EXS = (ConnectionClosed,) PIPELINE_NOT_SUPPORTED_EXS = (ConnectionClosed,)
...@@ -475,6 +512,7 @@ class TestExplicitContentLength(TestNoChunks): # pylint:disable=too-many-ancesto ...@@ -475,6 +512,7 @@ class TestExplicitContentLength(TestNoChunks): # pylint:disable=too-many-ancesto
# server - it caculates the content-length # server - it caculates the content-length
def application(self, env, start_response): def application(self, env, start_response):
self.last_environ = env
self.assertTrue(env.get('wsgi.input_terminated')) self.assertTrue(env.get('wsgi.input_terminated'))
path = env['PATH_INFO'] path = env['PATH_INFO']
if path == '/': if path == '/':
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment