Commit 4d62d0b3 authored by Guido van Rossum's avatar Guido van Rossum

asyncio: Refactor drain logic in streams.py to be reusable.

parent aaabc4fd
...@@ -94,8 +94,63 @@ def start_server(client_connected_cb, host=None, port=None, *, ...@@ -94,8 +94,63 @@ def start_server(client_connected_cb, host=None, port=None, *,
return (yield from loop.create_server(factory, host, port, **kwds)) return (yield from loop.create_server(factory, host, port, **kwds))
class StreamReaderProtocol(protocols.Protocol): class FlowControlMixin(protocols.Protocol):
"""Trivial helper class to adapt between Protocol and StreamReader. """Reusable flow control logic for StreamWriter.drain().
This implements the protocol methods pause_writing(),
resume_reading() and connection_lost(). If the subclass overrides
these it must call the super methods.
StreamWriter.drain() must check for error conditions and then call
_make_drain_waiter(), which will return either () or a Future
depending on the paused state.
"""
def __init__(self, loop=None):
self._loop = loop # May be None; we may never need it.
self._paused = False
self._drain_waiter = None
def pause_writing(self):
assert not self._paused
self._paused = True
def resume_writing(self):
assert self._paused
self._paused = False
waiter = self._drain_waiter
if waiter is not None:
self._drain_waiter = None
if not waiter.done():
waiter.set_result(None)
def connection_lost(self, exc):
# Wake up the writer if currently paused.
if not self._paused:
return
waiter = self._drain_waiter
if waiter is None:
return
self._drain_waiter = None
if waiter.done():
return
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)
def _make_drain_waiter(self):
if not self._paused:
return ()
waiter = self._drain_waiter
assert waiter is None or waiter.cancelled()
waiter = futures.Future(loop=self._loop)
self._drain_waiter = waiter
return waiter
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
"""Helper class to adapt between Protocol and StreamReader.
(This is a helper class instead of making StreamReader itself a (This is a helper class instead of making StreamReader itself a
Protocol subclass, because the StreamReader has other potential Protocol subclass, because the StreamReader has other potential
...@@ -104,12 +159,10 @@ class StreamReaderProtocol(protocols.Protocol): ...@@ -104,12 +159,10 @@ class StreamReaderProtocol(protocols.Protocol):
""" """
def __init__(self, stream_reader, client_connected_cb=None, loop=None): def __init__(self, stream_reader, client_connected_cb=None, loop=None):
super().__init__(loop=loop)
self._stream_reader = stream_reader self._stream_reader = stream_reader
self._stream_writer = None self._stream_writer = None
self._drain_waiter = None
self._paused = False
self._client_connected_cb = client_connected_cb self._client_connected_cb = client_connected_cb
self._loop = loop # May be None; we may never need it.
def connection_made(self, transport): def connection_made(self, transport):
self._stream_reader.set_transport(transport) self._stream_reader.set_transport(transport)
...@@ -127,16 +180,7 @@ class StreamReaderProtocol(protocols.Protocol): ...@@ -127,16 +180,7 @@ class StreamReaderProtocol(protocols.Protocol):
self._stream_reader.feed_eof() self._stream_reader.feed_eof()
else: else:
self._stream_reader.set_exception(exc) self._stream_reader.set_exception(exc)
# Also wake up the writing side. super().connection_lost(exc)
if self._paused:
waiter = self._drain_waiter
if waiter is not None:
self._drain_waiter = None
if not waiter.done():
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)
def data_received(self, data): def data_received(self, data):
self._stream_reader.feed_data(data) self._stream_reader.feed_data(data)
...@@ -144,19 +188,6 @@ class StreamReaderProtocol(protocols.Protocol): ...@@ -144,19 +188,6 @@ class StreamReaderProtocol(protocols.Protocol):
def eof_received(self): def eof_received(self):
self._stream_reader.feed_eof() self._stream_reader.feed_eof()
def pause_writing(self):
assert not self._paused
self._paused = True
def resume_writing(self):
assert self._paused
self._paused = False
waiter = self._drain_waiter
if waiter is not None:
self._drain_waiter = None
if not waiter.done():
waiter.set_result(None)
class StreamWriter: class StreamWriter:
"""Wraps a Transport. """Wraps a Transport.
...@@ -211,17 +242,11 @@ class StreamWriter: ...@@ -211,17 +242,11 @@ class StreamWriter:
completed, which will happen when the buffer is (partially) completed, which will happen when the buffer is (partially)
drained and the protocol is resumed. drained and the protocol is resumed.
""" """
if self._reader._exception is not None: if self._reader is not None and self._reader._exception is not None:
raise self._reader._exception raise self._reader._exception
if self._transport._conn_lost: # Uses private variable. if self._transport._conn_lost: # Uses private variable.
raise ConnectionResetError('Connection lost') raise ConnectionResetError('Connection lost')
if not self._protocol._paused: return self._protocol._make_drain_waiter()
return ()
waiter = self._protocol._drain_waiter
assert waiter is None or waiter.cancelled()
waiter = futures.Future(loop=self._loop)
self._protocol._drain_waiter = waiter
return waiter
class StreamReader: class StreamReader:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment