Commit f07801bb authored by Victor Stinner's avatar Victor Stinner

asyncio: SSL transports now clear their reference to the waiter

* Rephrase also the comment explaining why the waiter is not awaken immediatly.
* SSLProtocol.eof_received() doesn't instanciate ConnectionResetError exception
  directly, it will be done by Future.set_exception(). The exception is not
  used if the waiter was cancelled or if there is no waiter.
parent b507cbaa
...@@ -38,7 +38,7 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin, ...@@ -38,7 +38,7 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin,
self._server._attach() self._server._attach()
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
if waiter is not None: if waiter is not None:
# wait until protocol.connection_made() has been called # only wake up the waiter when connection_made() has been called
self._loop.call_soon(waiter._set_result_unless_cancelled, None) self._loop.call_soon(waiter._set_result_unless_cancelled, None)
def __repr__(self): def __repr__(self):
......
...@@ -581,7 +581,7 @@ class _SelectorSocketTransport(_SelectorTransport): ...@@ -581,7 +581,7 @@ class _SelectorSocketTransport(_SelectorTransport):
self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.add_reader(self._sock_fd, self._read_ready)
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
if waiter is not None: if waiter is not None:
# wait until protocol.connection_made() has been called # only wake up the waiter when connection_made() has been called
self._loop.call_soon(waiter._set_result_unless_cancelled, None) self._loop.call_soon(waiter._set_result_unless_cancelled, None)
def pause_reading(self): def pause_reading(self):
...@@ -732,6 +732,16 @@ class _SelectorSslTransport(_SelectorTransport): ...@@ -732,6 +732,16 @@ class _SelectorSslTransport(_SelectorTransport):
start_time = None start_time = None
self._on_handshake(start_time) self._on_handshake(start_time)
def _wakeup_waiter(self, exc=None):
if self._waiter is None:
return
if not self._waiter.cancelled():
if exc is not None:
self._waiter.set_exception(exc)
else:
self._waiter.set_result(None)
self._waiter = None
def _on_handshake(self, start_time): def _on_handshake(self, start_time):
try: try:
self._sock.do_handshake() self._sock.do_handshake()
...@@ -750,8 +760,7 @@ class _SelectorSslTransport(_SelectorTransport): ...@@ -750,8 +760,7 @@ class _SelectorSslTransport(_SelectorTransport):
self._loop.remove_reader(self._sock_fd) self._loop.remove_reader(self._sock_fd)
self._loop.remove_writer(self._sock_fd) self._loop.remove_writer(self._sock_fd)
self._sock.close() self._sock.close()
if self._waiter is not None and not self._waiter.cancelled(): self._wakeup_waiter(exc)
self._waiter.set_exception(exc)
if isinstance(exc, Exception): if isinstance(exc, Exception):
return return
else: else:
...@@ -774,9 +783,7 @@ class _SelectorSslTransport(_SelectorTransport): ...@@ -774,9 +783,7 @@ class _SelectorSslTransport(_SelectorTransport):
"on matching the hostname", "on matching the hostname",
self, exc_info=True) self, exc_info=True)
self._sock.close() self._sock.close()
if (self._waiter is not None self._wakeup_waiter(exc)
and not self._waiter.cancelled()):
self._waiter.set_exception(exc)
return return
# Add extra info that becomes available after handshake. # Add extra info that becomes available after handshake.
...@@ -789,10 +796,8 @@ class _SelectorSslTransport(_SelectorTransport): ...@@ -789,10 +796,8 @@ class _SelectorSslTransport(_SelectorTransport):
self._write_wants_read = False self._write_wants_read = False
self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.add_reader(self._sock_fd, self._read_ready)
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
if self._waiter is not None: # only wake up the waiter when connection_made() has been called
# wait until protocol.connection_made() has been called self._loop.call_soon(self._wakeup_waiter)
self._loop.call_soon(self._waiter._set_result_unless_cancelled,
None)
if self._loop.get_debug(): if self._loop.get_debug():
dt = self._loop.time() - start_time dt = self._loop.time() - start_time
...@@ -924,7 +929,7 @@ class _SelectorDatagramTransport(_SelectorTransport): ...@@ -924,7 +929,7 @@ class _SelectorDatagramTransport(_SelectorTransport):
self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.add_reader(self._sock_fd, self._read_ready)
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
if waiter is not None: if waiter is not None:
# wait until protocol.connection_made() has been called # only wake up the waiter when connection_made() has been called
self._loop.call_soon(waiter._set_result_unless_cancelled, None) self._loop.call_soon(waiter._set_result_unless_cancelled, None)
def get_write_buffer_size(self): def get_write_buffer_size(self):
......
...@@ -418,6 +418,16 @@ class SSLProtocol(protocols.Protocol): ...@@ -418,6 +418,16 @@ class SSLProtocol(protocols.Protocol):
self._in_shutdown = False self._in_shutdown = False
self._transport = None self._transport = None
def _wakeup_waiter(self, exc=None):
if self._waiter is None:
return
if not self._waiter.cancelled():
if exc is not None:
self._waiter.set_exception(exc)
else:
self._waiter.set_result(None)
self._waiter = None
def connection_made(self, transport): def connection_made(self, transport):
"""Called when the low-level connection is made. """Called when the low-level connection is made.
...@@ -490,8 +500,7 @@ class SSLProtocol(protocols.Protocol): ...@@ -490,8 +500,7 @@ class SSLProtocol(protocols.Protocol):
if self._loop.get_debug(): if self._loop.get_debug():
logger.debug("%r received EOF", self) logger.debug("%r received EOF", self)
if self._waiter is not None and not self._waiter.done(): self._wakeup_waiter(ConnectionResetError)
self._waiter.set_exception(ConnectionResetError())
if not self._in_handshake: if not self._in_handshake:
keep_open = self._app_protocol.eof_received() keep_open = self._app_protocol.eof_received()
...@@ -556,8 +565,7 @@ class SSLProtocol(protocols.Protocol): ...@@ -556,8 +565,7 @@ class SSLProtocol(protocols.Protocol):
self, exc_info=True) self, exc_info=True)
self._transport.close() self._transport.close()
if isinstance(exc, Exception): if isinstance(exc, Exception):
if self._waiter is not None and not self._waiter.cancelled(): self._wakeup_waiter(exc)
self._waiter.set_exception(exc)
return return
else: else:
raise raise
...@@ -572,9 +580,7 @@ class SSLProtocol(protocols.Protocol): ...@@ -572,9 +580,7 @@ class SSLProtocol(protocols.Protocol):
compression=sslobj.compression(), compression=sslobj.compression(),
) )
self._app_protocol.connection_made(self._app_transport) self._app_protocol.connection_made(self._app_transport)
if self._waiter is not None: self._wakeup_waiter()
# wait until protocol.connection_made() has been called
self._waiter._set_result_unless_cancelled(None)
self._session_established = True self._session_established = True
# In case transport.write() was already called. Don't call # In case transport.write() was already called. Don't call
# immediatly _process_write_backlog(), but schedule it: # immediatly _process_write_backlog(), but schedule it:
......
...@@ -301,7 +301,7 @@ class _UnixReadPipeTransport(transports.ReadTransport): ...@@ -301,7 +301,7 @@ class _UnixReadPipeTransport(transports.ReadTransport):
self._loop.add_reader(self._fileno, self._read_ready) self._loop.add_reader(self._fileno, self._read_ready)
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
if waiter is not None: if waiter is not None:
# wait until protocol.connection_made() has been called # only wake up the waiter when connection_made() has been called
self._loop.call_soon(waiter._set_result_unless_cancelled, None) self._loop.call_soon(waiter._set_result_unless_cancelled, None)
def __repr__(self): def __repr__(self):
...@@ -409,7 +409,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin, ...@@ -409,7 +409,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin,
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
if waiter is not None: if waiter is not None:
# wait until protocol.connection_made() has been called # only wake up the waiter when connection_made() has been called
self._loop.call_soon(waiter._set_result_unless_cancelled, None) self._loop.call_soon(waiter._set_result_unless_cancelled, None)
def __repr__(self): def __repr__(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment