Commit 31e7bfa6 authored by Victor Stinner's avatar Victor Stinner

asyncio, tulip issue 193: Convert StreamWriter.drain() to a classic coroutine

Replace also _make_drain_waiter() function with a classic _drain_helper()
coroutine.
parent 1392df96
...@@ -141,15 +141,14 @@ class FlowControlMixin(protocols.Protocol): ...@@ -141,15 +141,14 @@ class FlowControlMixin(protocols.Protocol):
resume_reading() and connection_lost(). If the subclass overrides resume_reading() and connection_lost(). If the subclass overrides
these it must call the super methods. these it must call the super methods.
StreamWriter.drain() must check for error conditions and then call StreamWriter.drain() must wait for _drain_helper() coroutine.
_make_drain_waiter(), which will return either () or a Future
depending on the paused state.
""" """
def __init__(self, loop=None): def __init__(self, loop=None):
self._loop = loop # May be None; we may never need it. self._loop = loop # May be None; we may never need it.
self._paused = False self._paused = False
self._drain_waiter = None self._drain_waiter = None
self._connection_lost = False
def pause_writing(self): def pause_writing(self):
assert not self._paused assert not self._paused
...@@ -170,6 +169,7 @@ class FlowControlMixin(protocols.Protocol): ...@@ -170,6 +169,7 @@ class FlowControlMixin(protocols.Protocol):
waiter.set_result(None) waiter.set_result(None)
def connection_lost(self, exc): def connection_lost(self, exc):
self._connection_lost = True
# Wake up the writer if currently paused. # Wake up the writer if currently paused.
if not self._paused: if not self._paused:
return return
...@@ -184,14 +184,17 @@ class FlowControlMixin(protocols.Protocol): ...@@ -184,14 +184,17 @@ class FlowControlMixin(protocols.Protocol):
else: else:
waiter.set_exception(exc) waiter.set_exception(exc)
def _make_drain_waiter(self): @coroutine
def _drain_helper(self):
if self._connection_lost:
raise ConnectionResetError('Connection lost')
if not self._paused: if not self._paused:
return () return
waiter = self._drain_waiter waiter = self._drain_waiter
assert waiter is None or waiter.cancelled() assert waiter is None or waiter.cancelled()
waiter = futures.Future(loop=self._loop) waiter = futures.Future(loop=self._loop)
self._drain_waiter = waiter self._drain_waiter = waiter
return waiter yield from waiter
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
...@@ -247,6 +250,8 @@ class StreamWriter: ...@@ -247,6 +250,8 @@ class StreamWriter:
def __init__(self, transport, protocol, reader, loop): def __init__(self, transport, protocol, reader, loop):
self._transport = transport self._transport = transport
self._protocol = protocol self._protocol = protocol
# drain() expects that the reader has a exception() method
assert reader is None or isinstance(reader, StreamReader)
self._reader = reader self._reader = reader
self._loop = loop self._loop = loop
...@@ -278,26 +283,20 @@ class StreamWriter: ...@@ -278,26 +283,20 @@ class StreamWriter:
def get_extra_info(self, name, default=None): def get_extra_info(self, name, default=None):
return self._transport.get_extra_info(name, default) return self._transport.get_extra_info(name, default)
@coroutine
def drain(self): def drain(self):
"""This method has an unusual return value. """Flush the write buffer.
The intended use is to write The intended use is to write
w.write(data) w.write(data)
yield from w.drain() yield from w.drain()
When there's nothing to wait for, drain() returns (), and the
yield-from continues immediately. When the transport buffer
is full (the protocol is paused), drain() creates and returns
a Future and the yield-from will block until that Future is
completed, which will happen when the buffer is (partially)
drained and the protocol is resumed.
""" """
if self._reader is not None and self._reader._exception is not None: if self._reader is not None:
raise self._reader._exception exc = self._reader.exception()
if self._transport._conn_lost: # Uses private variable. if exc is not None:
raise ConnectionResetError('Connection lost') raise exc
return self._protocol._make_drain_waiter() yield from self._protocol._drain_helper()
class StreamReader: class StreamReader:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment