Commit 47cd10d7 authored by Victor Stinner's avatar Victor Stinner

asyncio: sync with Tulip

Issue #23347: send_signal(), kill() and terminate() methods of
BaseSubprocessTransport now check if the transport was closed and if the
process exited.

Issue #23347: Refactor creation of subprocess transports. Changes on
BaseSubprocessTransport:

* Add a wait() method to wait until the child process exit
* The constructor now accepts an optional waiter parameter. The _post_init()
  coroutine must not be called explicitly anymore. It makes subprocess
  transports closer to other transports, and it gives more freedom if we want
  later to change completly how subprocess transports are created.
* close() now kills the process instead of kindly terminate it: the child
  process may ignore SIGTERM and continue to run. Call explicitly terminate()
  and wait() if you want to kindly terminate the child process.
* close() now logs a warning in debug mode if the process is still running and
  needs to be killed
* _make_subprocess_transport() is now fully asynchronous again: if the creation
  of the transport failed, wait asynchronously for the process eixt. Before the
  wait was synchronous. This change requires close() to *kill*, and not
  terminate, the child process.
* Remove the _kill_wait() method, replaced with a more agressive close()
  method. It fixes _make_subprocess_transport() on error.
  BaseSubprocessTransport.close() calls the close() method of pipe transports,
  whereas _kill_wait() closed directly pipes of the subprocess.Popen object
  without unregistering file descriptors from the selector (which caused severe
  bugs).

These changes simplifies the code of subprocess.py.
parent 978a9afc
...@@ -3,6 +3,7 @@ import subprocess ...@@ -3,6 +3,7 @@ import subprocess
import sys import sys
import warnings import warnings
from . import futures
from . import protocols from . import protocols
from . import transports from . import transports
from .coroutines import coroutine from .coroutines import coroutine
...@@ -13,27 +14,32 @@ class BaseSubprocessTransport(transports.SubprocessTransport): ...@@ -13,27 +14,32 @@ class BaseSubprocessTransport(transports.SubprocessTransport):
def __init__(self, loop, protocol, args, shell, def __init__(self, loop, protocol, args, shell,
stdin, stdout, stderr, bufsize, stdin, stdout, stderr, bufsize,
extra=None, **kwargs): waiter=None, extra=None, **kwargs):
super().__init__(extra) super().__init__(extra)
self._closed = False self._closed = False
self._protocol = protocol self._protocol = protocol
self._loop = loop self._loop = loop
self._proc = None
self._pid = None self._pid = None
self._returncode = None
self._exit_waiters = []
self._pending_calls = collections.deque()
self._pipes = {} self._pipes = {}
self._finished = False
if stdin == subprocess.PIPE: if stdin == subprocess.PIPE:
self._pipes[0] = None self._pipes[0] = None
if stdout == subprocess.PIPE: if stdout == subprocess.PIPE:
self._pipes[1] = None self._pipes[1] = None
if stderr == subprocess.PIPE: if stderr == subprocess.PIPE:
self._pipes[2] = None self._pipes[2] = None
self._pending_calls = collections.deque()
self._finished = False # Create the child process: set the _proc attribute
self._returncode = None
self._start(args=args, shell=shell, stdin=stdin, stdout=stdout, self._start(args=args, shell=shell, stdin=stdin, stdout=stdout,
stderr=stderr, bufsize=bufsize, **kwargs) stderr=stderr, bufsize=bufsize, **kwargs)
self._pid = self._proc.pid self._pid = self._proc.pid
self._extra['subprocess'] = self._proc self._extra['subprocess'] = self._proc
if self._loop.get_debug(): if self._loop.get_debug():
if isinstance(args, (bytes, str)): if isinstance(args, (bytes, str)):
program = args program = args
...@@ -42,6 +48,8 @@ class BaseSubprocessTransport(transports.SubprocessTransport): ...@@ -42,6 +48,8 @@ class BaseSubprocessTransport(transports.SubprocessTransport):
logger.debug('process %r created: pid %s', logger.debug('process %r created: pid %s',
program, self._pid) program, self._pid)
self._loop.create_task(self._connect_pipes(waiter))
def __repr__(self): def __repr__(self):
info = [self.__class__.__name__] info = [self.__class__.__name__]
if self._closed: if self._closed:
...@@ -77,12 +85,23 @@ class BaseSubprocessTransport(transports.SubprocessTransport): ...@@ -77,12 +85,23 @@ class BaseSubprocessTransport(transports.SubprocessTransport):
def close(self): def close(self):
self._closed = True self._closed = True
for proto in self._pipes.values(): for proto in self._pipes.values():
if proto is None: if proto is None:
continue continue
proto.pipe.close() proto.pipe.close()
if self._returncode is None:
self.terminate() if self._proc is not None and self._returncode is None:
if self._loop.get_debug():
logger.warning('Close running child process: kill %r', self)
try:
self._proc.kill()
except ProcessLookupError:
pass
# Don't clear the _proc reference yet because _post_init() may
# still run
# On Python 3.3 and older, objects with a destructor part of a reference # On Python 3.3 and older, objects with a destructor part of a reference
# cycle are never destroyed. It's not more the case on Python 3.4 thanks # cycle are never destroyed. It's not more the case on Python 3.4 thanks
...@@ -105,59 +124,42 @@ class BaseSubprocessTransport(transports.SubprocessTransport): ...@@ -105,59 +124,42 @@ class BaseSubprocessTransport(transports.SubprocessTransport):
else: else:
return None return None
def _check_proc(self):
if self._closed:
raise ValueError("operation on closed transport")
if self._proc is None:
raise ProcessLookupError()
def send_signal(self, signal): def send_signal(self, signal):
self._check_proc()
self._proc.send_signal(signal) self._proc.send_signal(signal)
def terminate(self): def terminate(self):
self._check_proc()
self._proc.terminate() self._proc.terminate()
def kill(self): def kill(self):
self._check_proc()
self._proc.kill() self._proc.kill()
def _kill_wait(self):
"""Close pipes, kill the subprocess and read its return status.
Function called when an exception is raised during the creation
of a subprocess.
"""
self._closed = True
if self._loop.get_debug():
logger.warning('Exception during subprocess creation, '
'kill the subprocess %r',
self,
exc_info=True)
proc = self._proc
if proc.stdout:
proc.stdout.close()
if proc.stderr:
proc.stderr.close()
if proc.stdin:
proc.stdin.close()
try:
proc.kill()
except ProcessLookupError:
pass
self._returncode = proc.wait()
self.close()
@coroutine @coroutine
def _post_init(self): def _connect_pipes(self, waiter):
try: try:
proc = self._proc proc = self._proc
loop = self._loop loop = self._loop
if proc.stdin is not None: if proc.stdin is not None:
_, pipe = yield from loop.connect_write_pipe( _, pipe = yield from loop.connect_write_pipe(
lambda: WriteSubprocessPipeProto(self, 0), lambda: WriteSubprocessPipeProto(self, 0),
proc.stdin) proc.stdin)
self._pipes[0] = pipe self._pipes[0] = pipe
if proc.stdout is not None: if proc.stdout is not None:
_, pipe = yield from loop.connect_read_pipe( _, pipe = yield from loop.connect_read_pipe(
lambda: ReadSubprocessPipeProto(self, 1), lambda: ReadSubprocessPipeProto(self, 1),
proc.stdout) proc.stdout)
self._pipes[1] = pipe self._pipes[1] = pipe
if proc.stderr is not None: if proc.stderr is not None:
_, pipe = yield from loop.connect_read_pipe( _, pipe = yield from loop.connect_read_pipe(
lambda: ReadSubprocessPipeProto(self, 2), lambda: ReadSubprocessPipeProto(self, 2),
...@@ -166,13 +168,16 @@ class BaseSubprocessTransport(transports.SubprocessTransport): ...@@ -166,13 +168,16 @@ class BaseSubprocessTransport(transports.SubprocessTransport):
assert self._pending_calls is not None assert self._pending_calls is not None
self._loop.call_soon(self._protocol.connection_made, self) loop.call_soon(self._protocol.connection_made, self)
for callback, data in self._pending_calls: for callback, data in self._pending_calls:
self._loop.call_soon(callback, *data) loop.call_soon(callback, *data)
self._pending_calls = None self._pending_calls = None
except: except Exception as exc:
self._kill_wait() if waiter is not None and not waiter.cancelled():
raise waiter.set_exception(exc)
else:
if waiter is not None and not waiter.cancelled():
waiter.set_result(None)
def _call(self, cb, *data): def _call(self, cb, *data):
if self._pending_calls is not None: if self._pending_calls is not None:
...@@ -197,6 +202,23 @@ class BaseSubprocessTransport(transports.SubprocessTransport): ...@@ -197,6 +202,23 @@ class BaseSubprocessTransport(transports.SubprocessTransport):
self._call(self._protocol.process_exited) self._call(self._protocol.process_exited)
self._try_finish() self._try_finish()
# wake up futures waiting for wait()
for waiter in self._exit_waiters:
if not waiter.cancelled():
waiter.set_result(returncode)
self._exit_waiters = None
def wait(self):
"""Wait until the process exit and return the process return code.
This method is a coroutine."""
if self._returncode is not None:
return self._returncode
waiter = futures.Future(loop=self._loop)
self._exit_waiters.append(waiter)
return (yield from waiter)
def _try_finish(self): def _try_finish(self):
assert not self._finished assert not self._finished
if self._returncode is None: if self._returncode is None:
...@@ -210,9 +232,9 @@ class BaseSubprocessTransport(transports.SubprocessTransport): ...@@ -210,9 +232,9 @@ class BaseSubprocessTransport(transports.SubprocessTransport):
try: try:
self._protocol.connection_lost(exc) self._protocol.connection_lost(exc)
finally: finally:
self._loop = None
self._proc = None self._proc = None
self._protocol = None self._protocol = None
self._loop = None
class WriteSubprocessPipeProto(protocols.BaseProtocol): class WriteSubprocessPipeProto(protocols.BaseProtocol):
......
...@@ -25,8 +25,6 @@ class SubprocessStreamProtocol(streams.FlowControlMixin, ...@@ -25,8 +25,6 @@ class SubprocessStreamProtocol(streams.FlowControlMixin,
super().__init__(loop=loop) super().__init__(loop=loop)
self._limit = limit self._limit = limit
self.stdin = self.stdout = self.stderr = None self.stdin = self.stdout = self.stderr = None
self.waiter = futures.Future(loop=loop)
self._waiters = collections.deque()
self._transport = None self._transport = None
def __repr__(self): def __repr__(self):
...@@ -61,9 +59,6 @@ class SubprocessStreamProtocol(streams.FlowControlMixin, ...@@ -61,9 +59,6 @@ class SubprocessStreamProtocol(streams.FlowControlMixin,
reader=None, reader=None,
loop=self._loop) loop=self._loop)
if not self.waiter.cancelled():
self.waiter.set_result(None)
def pipe_data_received(self, fd, data): def pipe_data_received(self, fd, data):
if fd == 1: if fd == 1:
reader = self.stdout reader = self.stdout
...@@ -94,16 +89,9 @@ class SubprocessStreamProtocol(streams.FlowControlMixin, ...@@ -94,16 +89,9 @@ class SubprocessStreamProtocol(streams.FlowControlMixin,
reader.set_exception(exc) reader.set_exception(exc)
def process_exited(self): def process_exited(self):
returncode = self._transport.get_returncode()
self._transport.close() self._transport.close()
self._transport = None self._transport = None
# wake up futures waiting for wait()
while self._waiters:
waiter = self._waiters.popleft()
if not waiter.cancelled():
waiter.set_result(returncode)
class Process: class Process:
def __init__(self, transport, protocol, loop): def __init__(self, transport, protocol, loop):
...@@ -124,30 +112,18 @@ class Process: ...@@ -124,30 +112,18 @@ class Process:
@coroutine @coroutine
def wait(self): def wait(self):
"""Wait until the process exit and return the process return code.""" """Wait until the process exit and return the process return code.
returncode = self._transport.get_returncode()
if returncode is not None:
return returncode
waiter = futures.Future(loop=self._loop)
self._protocol._waiters.append(waiter)
yield from waiter
return waiter.result()
def _check_alive(self): This method is a coroutine."""
if self._transport.get_returncode() is not None: return (yield from self._transport.wait())
raise ProcessLookupError()
def send_signal(self, signal): def send_signal(self, signal):
self._check_alive()
self._transport.send_signal(signal) self._transport.send_signal(signal)
def terminate(self): def terminate(self):
self._check_alive()
self._transport.terminate() self._transport.terminate()
def kill(self): def kill(self):
self._check_alive()
self._transport.kill() self._transport.kill()
@coroutine @coroutine
...@@ -221,11 +197,6 @@ def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None, ...@@ -221,11 +197,6 @@ def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None,
protocol_factory, protocol_factory,
cmd, stdin=stdin, stdout=stdout, cmd, stdin=stdin, stdout=stdout,
stderr=stderr, **kwds) stderr=stderr, **kwds)
try:
yield from protocol.waiter
except:
transport._kill_wait()
raise
return Process(transport, protocol, loop) return Process(transport, protocol, loop)
@coroutine @coroutine
...@@ -241,9 +212,4 @@ def create_subprocess_exec(program, *args, stdin=None, stdout=None, ...@@ -241,9 +212,4 @@ def create_subprocess_exec(program, *args, stdin=None, stdout=None,
program, *args, program, *args,
stdin=stdin, stdout=stdout, stdin=stdin, stdout=stdout,
stderr=stderr, **kwds) stderr=stderr, **kwds)
try:
yield from protocol.waiter
except:
transport._kill_wait()
raise
return Process(transport, protocol, loop) return Process(transport, protocol, loop)
...@@ -16,6 +16,7 @@ from . import base_subprocess ...@@ -16,6 +16,7 @@ from . import base_subprocess
from . import constants from . import constants
from . import coroutines from . import coroutines
from . import events from . import events
from . import futures
from . import selector_events from . import selector_events
from . import selectors from . import selectors
from . import transports from . import transports
...@@ -175,16 +176,20 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): ...@@ -175,16 +176,20 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
stdin, stdout, stderr, bufsize, stdin, stdout, stderr, bufsize,
extra=None, **kwargs): extra=None, **kwargs):
with events.get_child_watcher() as watcher: with events.get_child_watcher() as watcher:
waiter = futures.Future(loop=self)
transp = _UnixSubprocessTransport(self, protocol, args, shell, transp = _UnixSubprocessTransport(self, protocol, args, shell,
stdin, stdout, stderr, bufsize, stdin, stdout, stderr, bufsize,
extra=extra, **kwargs) waiter=waiter, extra=extra,
**kwargs)
watcher.add_child_handler(transp.get_pid(),
self._child_watcher_callback, transp)
try: try:
yield from transp._post_init() yield from waiter
except: except:
transp.close() transp.close()
yield from transp.wait()
raise raise
watcher.add_child_handler(transp.get_pid(),
self._child_watcher_callback, transp)
return transp return transp
...@@ -774,7 +779,7 @@ class SafeChildWatcher(BaseChildWatcher): ...@@ -774,7 +779,7 @@ class SafeChildWatcher(BaseChildWatcher):
pass pass
def add_child_handler(self, pid, callback, *args): def add_child_handler(self, pid, callback, *args):
self._callbacks[pid] = callback, args self._callbacks[pid] = (callback, args)
# Prevent a race condition in case the child is already terminated. # Prevent a race condition in case the child is already terminated.
self._do_waitpid(pid) self._do_waitpid(pid)
......
...@@ -366,13 +366,16 @@ class ProactorEventLoop(proactor_events.BaseProactorEventLoop): ...@@ -366,13 +366,16 @@ class ProactorEventLoop(proactor_events.BaseProactorEventLoop):
def _make_subprocess_transport(self, protocol, args, shell, def _make_subprocess_transport(self, protocol, args, shell,
stdin, stdout, stderr, bufsize, stdin, stdout, stderr, bufsize,
extra=None, **kwargs): extra=None, **kwargs):
waiter = futures.Future(loop=self)
transp = _WindowsSubprocessTransport(self, protocol, args, shell, transp = _WindowsSubprocessTransport(self, protocol, args, shell,
stdin, stdout, stderr, bufsize, stdin, stdout, stderr, bufsize,
extra=extra, **kwargs) waiter=waiter, extra=extra,
**kwargs)
try: try:
yield from transp._post_init() yield from waiter
except: except:
transp.close() transp.close()
yield from transp.wait()
raise raise
return transp return transp
......
...@@ -1551,9 +1551,10 @@ class SubprocessTestsMixin: ...@@ -1551,9 +1551,10 @@ class SubprocessTestsMixin:
stdin = transp.get_pipe_transport(0) stdin = transp.get_pipe_transport(0)
stdin.write(b'Python The Winner') stdin.write(b'Python The Winner')
self.loop.run_until_complete(proto.got_data[1].wait()) self.loop.run_until_complete(proto.got_data[1].wait())
transp.close() with test_utils.disable_logger():
transp.close()
self.loop.run_until_complete(proto.completed) self.loop.run_until_complete(proto.completed)
self.check_terminated(proto.returncode) self.check_killed(proto.returncode)
self.assertEqual(b'Python The Winner', proto.data[1]) self.assertEqual(b'Python The Winner', proto.data[1])
def test_subprocess_interactive(self): def test_subprocess_interactive(self):
...@@ -1567,21 +1568,20 @@ class SubprocessTestsMixin: ...@@ -1567,21 +1568,20 @@ class SubprocessTestsMixin:
self.loop.run_until_complete(proto.connected) self.loop.run_until_complete(proto.connected)
self.assertEqual('CONNECTED', proto.state) self.assertEqual('CONNECTED', proto.state)
try: stdin = transp.get_pipe_transport(0)
stdin = transp.get_pipe_transport(0) stdin.write(b'Python ')
stdin.write(b'Python ') self.loop.run_until_complete(proto.got_data[1].wait())
self.loop.run_until_complete(proto.got_data[1].wait()) proto.got_data[1].clear()
proto.got_data[1].clear() self.assertEqual(b'Python ', proto.data[1])
self.assertEqual(b'Python ', proto.data[1])
stdin.write(b'The Winner')
self.loop.run_until_complete(proto.got_data[1].wait())
self.assertEqual(b'Python The Winner', proto.data[1])
finally:
transp.close()
stdin.write(b'The Winner')
self.loop.run_until_complete(proto.got_data[1].wait())
self.assertEqual(b'Python The Winner', proto.data[1])
with test_utils.disable_logger():
transp.close()
self.loop.run_until_complete(proto.completed) self.loop.run_until_complete(proto.completed)
self.check_terminated(proto.returncode) self.check_killed(proto.returncode)
def test_subprocess_shell(self): def test_subprocess_shell(self):
connect = self.loop.subprocess_shell( connect = self.loop.subprocess_shell(
...@@ -1739,9 +1739,10 @@ class SubprocessTestsMixin: ...@@ -1739,9 +1739,10 @@ class SubprocessTestsMixin:
# GetLastError()==ERROR_INVALID_NAME on Windows!?! (Using # GetLastError()==ERROR_INVALID_NAME on Windows!?! (Using
# WriteFile() we get ERROR_BROKEN_PIPE as expected.) # WriteFile() we get ERROR_BROKEN_PIPE as expected.)
self.assertEqual(b'ERR:OSError', proto.data[2]) self.assertEqual(b'ERR:OSError', proto.data[2])
transp.close() with test_utils.disable_logger():
transp.close()
self.loop.run_until_complete(proto.completed) self.loop.run_until_complete(proto.completed)
self.check_terminated(proto.returncode) self.check_killed(proto.returncode)
def test_subprocess_wait_no_same_group(self): def test_subprocess_wait_no_same_group(self):
# start the new process in a new session # start the new process in a new session
......
...@@ -4,6 +4,7 @@ import unittest ...@@ -4,6 +4,7 @@ import unittest
from unittest import mock from unittest import mock
import asyncio import asyncio
from asyncio import base_subprocess
from asyncio import subprocess from asyncio import subprocess
from asyncio import test_utils from asyncio import test_utils
try: try:
...@@ -23,6 +24,70 @@ PROGRAM_CAT = [ ...@@ -23,6 +24,70 @@ PROGRAM_CAT = [
'data = sys.stdin.buffer.read()', 'data = sys.stdin.buffer.read()',
'sys.stdout.buffer.write(data)'))] 'sys.stdout.buffer.write(data)'))]
class TestSubprocessTransport(base_subprocess.BaseSubprocessTransport):
def _start(self, *args, **kwargs):
self._proc = mock.Mock()
self._proc.stdin = None
self._proc.stdout = None
self._proc.stderr = None
class SubprocessTransportTests(test_utils.TestCase):
def setUp(self):
self.loop = self.new_test_loop()
self.set_event_loop(self.loop)
def create_transport(self, waiter=None):
protocol = mock.Mock()
protocol.connection_made._is_coroutine = False
protocol.process_exited._is_coroutine = False
transport = TestSubprocessTransport(
self.loop, protocol, ['test'], False,
None, None, None, 0, waiter=waiter)
return (transport, protocol)
def test_close(self):
waiter = asyncio.Future(loop=self.loop)
transport, protocol = self.create_transport(waiter)
transport._process_exited(0)
transport.close()
# The loop didn't run yet
self.assertFalse(protocol.connection_made.called)
# methods must raise ProcessLookupError if the transport was closed
self.assertRaises(ValueError, transport.send_signal, signal.SIGTERM)
self.assertRaises(ValueError, transport.terminate)
self.assertRaises(ValueError, transport.kill)
self.loop.run_until_complete(waiter)
def test_proc_exited(self):
waiter = asyncio.Future(loop=self.loop)
transport, protocol = self.create_transport(waiter)
transport._process_exited(6)
self.loop.run_until_complete(waiter)
self.assertEqual(transport.get_returncode(), 6)
self.assertTrue(protocol.connection_made.called)
self.assertTrue(protocol.process_exited.called)
self.assertTrue(protocol.connection_lost.called)
self.assertEqual(protocol.connection_lost.call_args[0], (None,))
self.assertFalse(transport._closed)
self.assertIsNone(transport._loop)
self.assertIsNone(transport._proc)
self.assertIsNone(transport._protocol)
# methods must raise ProcessLookupError if the process exited
self.assertRaises(ProcessLookupError,
transport.send_signal, signal.SIGTERM)
self.assertRaises(ProcessLookupError, transport.terminate)
self.assertRaises(ProcessLookupError, transport.kill)
class SubprocessMixin: class SubprocessMixin:
def test_stdin_stdout(self): def test_stdin_stdout(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment