Commit e8c368df authored by Thomas Moreau's avatar Thomas Moreau Committed by Antoine Pitrou

bpo-31540: Allow passing multiprocessing context to ProcessPoolExecutor (#3682)

parent efb560ee
...@@ -56,7 +56,7 @@ matrix: ...@@ -56,7 +56,7 @@ matrix:
./venv/bin/python -m test.pythoninfo ./venv/bin/python -m test.pythoninfo
script: script:
# Skip tests that re-run the entire test suite. # Skip tests that re-run the entire test suite.
- ./venv/bin/python -m coverage run --pylib -m test --fail-env-changed -uall,-cpu -x test_multiprocessing_fork -x test_multiprocessing_forkserver -x test_multiprocessing_spawn - ./venv/bin/python -m coverage run --pylib -m test --fail-env-changed -uall,-cpu -x test_multiprocessing_fork -x test_multiprocessing_forkserver -x test_multiprocessing_spawn -x test_concurrent_futures
after_script: # Probably should be after_success once test suite updated to run under coverage.py. after_script: # Probably should be after_success once test suite updated to run under coverage.py.
# Make the `coverage` command available to Codecov w/ a version of Python that can parse all source files. # Make the `coverage` command available to Codecov w/ a version of Python that can parse all source files.
- source ./venv/bin/activate - source ./venv/bin/activate
......
...@@ -191,13 +191,16 @@ that :class:`ProcessPoolExecutor` will not work in the interactive interpreter. ...@@ -191,13 +191,16 @@ that :class:`ProcessPoolExecutor` will not work in the interactive interpreter.
Calling :class:`Executor` or :class:`Future` methods from a callable submitted Calling :class:`Executor` or :class:`Future` methods from a callable submitted
to a :class:`ProcessPoolExecutor` will result in deadlock. to a :class:`ProcessPoolExecutor` will result in deadlock.
.. class:: ProcessPoolExecutor(max_workers=None) .. class:: ProcessPoolExecutor(max_workers=None, mp_context=None)
An :class:`Executor` subclass that executes calls asynchronously using a pool An :class:`Executor` subclass that executes calls asynchronously using a pool
of at most *max_workers* processes. If *max_workers* is ``None`` or not of at most *max_workers* processes. If *max_workers* is ``None`` or not
given, it will default to the number of processors on the machine. given, it will default to the number of processors on the machine.
If *max_workers* is lower or equal to ``0``, then a :exc:`ValueError` If *max_workers* is lower or equal to ``0``, then a :exc:`ValueError`
will be raised. will be raised.
*mp_context* can be a multiprocessing context or None. It will be used to
launch the workers. If *mp_context* is ``None`` or not given, the default
multiprocessing context is used.
.. versionchanged:: 3.3 .. versionchanged:: 3.3
When one of the worker processes terminates abruptly, a When one of the worker processes terminates abruptly, a
...@@ -205,6 +208,10 @@ to a :class:`ProcessPoolExecutor` will result in deadlock. ...@@ -205,6 +208,10 @@ to a :class:`ProcessPoolExecutor` will result in deadlock.
was undefined but operations on the executor or its futures would often was undefined but operations on the executor or its futures would often
freeze or deadlock. freeze or deadlock.
.. versionchanged:: 3.7
The *mp_context* argument was added to allow users to control the
start_method for worker processes created by the pool.
.. _processpoolexecutor-example: .. _processpoolexecutor-example:
......
...@@ -50,8 +50,7 @@ import os ...@@ -50,8 +50,7 @@ import os
from concurrent.futures import _base from concurrent.futures import _base
import queue import queue
from queue import Full from queue import Full
import multiprocessing import multiprocessing as mp
from multiprocessing import SimpleQueue
from multiprocessing.connection import wait from multiprocessing.connection import wait
import threading import threading
import weakref import weakref
...@@ -74,11 +73,11 @@ import traceback ...@@ -74,11 +73,11 @@ import traceback
# threads/processes finish. # threads/processes finish.
_threads_queues = weakref.WeakKeyDictionary() _threads_queues = weakref.WeakKeyDictionary()
_shutdown = False _global_shutdown = False
def _python_exit(): def _python_exit():
global _shutdown global _global_shutdown
_shutdown = True _global_shutdown = True
items = list(_threads_queues.items()) items = list(_threads_queues.items())
for t, q in items: for t, q in items:
q.put(None) q.put(None)
...@@ -158,12 +157,10 @@ def _process_worker(call_queue, result_queue): ...@@ -158,12 +157,10 @@ def _process_worker(call_queue, result_queue):
This worker is run in a separate process. This worker is run in a separate process.
Args: Args:
call_queue: A multiprocessing.Queue of _CallItems that will be read and call_queue: A ctx.Queue of _CallItems that will be read and
evaluated by the worker. evaluated by the worker.
result_queue: A multiprocessing.Queue of _ResultItems that will written result_queue: A ctx.Queue of _ResultItems that will written
to by the worker. to by the worker.
shutdown: A multiprocessing.Event that will be set as a signal to the
worker that it should exit when call_queue is empty.
""" """
while True: while True:
call_item = call_queue.get(block=True) call_item = call_queue.get(block=True)
...@@ -180,6 +177,11 @@ def _process_worker(call_queue, result_queue): ...@@ -180,6 +177,11 @@ def _process_worker(call_queue, result_queue):
result_queue.put(_ResultItem(call_item.work_id, result_queue.put(_ResultItem(call_item.work_id,
result=r)) result=r))
# Liberate the resource as soon as possible, to avoid holding onto
# open files or shared memory that is not needed anymore
del call_item
def _add_call_item_to_queue(pending_work_items, def _add_call_item_to_queue(pending_work_items,
work_ids, work_ids,
call_queue): call_queue):
...@@ -231,20 +233,21 @@ def _queue_management_worker(executor_reference, ...@@ -231,20 +233,21 @@ def _queue_management_worker(executor_reference,
executor_reference: A weakref.ref to the ProcessPoolExecutor that owns executor_reference: A weakref.ref to the ProcessPoolExecutor that owns
this thread. Used to determine if the ProcessPoolExecutor has been this thread. Used to determine if the ProcessPoolExecutor has been
garbage collected and that this function can exit. garbage collected and that this function can exit.
process: A list of the multiprocessing.Process instances used as process: A list of the ctx.Process instances used as
workers. workers.
pending_work_items: A dict mapping work ids to _WorkItems e.g. pending_work_items: A dict mapping work ids to _WorkItems e.g.
{5: <_WorkItem...>, 6: <_WorkItem...>, ...} {5: <_WorkItem...>, 6: <_WorkItem...>, ...}
work_ids_queue: A queue.Queue of work ids e.g. Queue([5, 6, ...]). work_ids_queue: A queue.Queue of work ids e.g. Queue([5, 6, ...]).
call_queue: A multiprocessing.Queue that will be filled with _CallItems call_queue: A ctx.Queue that will be filled with _CallItems
derived from _WorkItems for processing by the process workers. derived from _WorkItems for processing by the process workers.
result_queue: A multiprocessing.Queue of _ResultItems generated by the result_queue: A ctx.SimpleQueue of _ResultItems generated by the
process workers. process workers.
""" """
executor = None executor = None
def shutting_down(): def shutting_down():
return _shutdown or executor is None or executor._shutdown_thread return (_global_shutdown or executor is None
or executor._shutdown_thread)
def shutdown_worker(): def shutdown_worker():
# This is an upper bound # This is an upper bound
...@@ -254,7 +257,7 @@ def _queue_management_worker(executor_reference, ...@@ -254,7 +257,7 @@ def _queue_management_worker(executor_reference,
# Release the queue's resources as soon as possible. # Release the queue's resources as soon as possible.
call_queue.close() call_queue.close()
# If .join() is not called on the created processes then # If .join() is not called on the created processes then
# some multiprocessing.Queue methods may deadlock on Mac OS X. # some ctx.Queue methods may deadlock on Mac OS X.
for p in processes.values(): for p in processes.values():
p.join() p.join()
...@@ -377,13 +380,15 @@ class BrokenProcessPool(RuntimeError): ...@@ -377,13 +380,15 @@ class BrokenProcessPool(RuntimeError):
class ProcessPoolExecutor(_base.Executor): class ProcessPoolExecutor(_base.Executor):
def __init__(self, max_workers=None): def __init__(self, max_workers=None, mp_context=None):
"""Initializes a new ProcessPoolExecutor instance. """Initializes a new ProcessPoolExecutor instance.
Args: Args:
max_workers: The maximum number of processes that can be used to max_workers: The maximum number of processes that can be used to
execute the given calls. If None or not given then as many execute the given calls. If None or not given then as many
worker processes will be created as the machine has processors. worker processes will be created as the machine has processors.
mp_context: A multiprocessing context to launch the workers. This
object should provide SimpleQueue, Queue and Process.
""" """
_check_system_limits() _check_system_limits()
...@@ -394,17 +399,20 @@ class ProcessPoolExecutor(_base.Executor): ...@@ -394,17 +399,20 @@ class ProcessPoolExecutor(_base.Executor):
raise ValueError("max_workers must be greater than 0") raise ValueError("max_workers must be greater than 0")
self._max_workers = max_workers self._max_workers = max_workers
if mp_context is None:
mp_context = mp.get_context()
self._mp_context = mp_context
# Make the call queue slightly larger than the number of processes to # Make the call queue slightly larger than the number of processes to
# prevent the worker processes from idling. But don't make it too big # prevent the worker processes from idling. But don't make it too big
# because futures in the call queue cannot be cancelled. # because futures in the call queue cannot be cancelled.
self._call_queue = multiprocessing.Queue(self._max_workers + queue_size = self._max_workers + EXTRA_QUEUED_CALLS
EXTRA_QUEUED_CALLS) self._call_queue = mp_context.Queue(queue_size)
# Killed worker processes can produce spurious "broken pipe" # Killed worker processes can produce spurious "broken pipe"
# tracebacks in the queue's own worker thread. But we detect killed # tracebacks in the queue's own worker thread. But we detect killed
# processes anyway, so silence the tracebacks. # processes anyway, so silence the tracebacks.
self._call_queue._ignore_epipe = True self._call_queue._ignore_epipe = True
self._result_queue = SimpleQueue() self._result_queue = mp_context.SimpleQueue()
self._work_ids = queue.Queue() self._work_ids = queue.Queue()
self._queue_management_thread = None self._queue_management_thread = None
# Map of pids to processes # Map of pids to processes
...@@ -426,23 +434,23 @@ class ProcessPoolExecutor(_base.Executor): ...@@ -426,23 +434,23 @@ class ProcessPoolExecutor(_base.Executor):
# Start the processes so that their sentinels are known. # Start the processes so that their sentinels are known.
self._adjust_process_count() self._adjust_process_count()
self._queue_management_thread = threading.Thread( self._queue_management_thread = threading.Thread(
target=_queue_management_worker, target=_queue_management_worker,
args=(weakref.ref(self, weakref_cb), args=(weakref.ref(self, weakref_cb),
self._processes, self._processes,
self._pending_work_items, self._pending_work_items,
self._work_ids, self._work_ids,
self._call_queue, self._call_queue,
self._result_queue)) self._result_queue))
self._queue_management_thread.daemon = True self._queue_management_thread.daemon = True
self._queue_management_thread.start() self._queue_management_thread.start()
_threads_queues[self._queue_management_thread] = self._result_queue _threads_queues[self._queue_management_thread] = self._result_queue
def _adjust_process_count(self): def _adjust_process_count(self):
for _ in range(len(self._processes), self._max_workers): for _ in range(len(self._processes), self._max_workers):
p = multiprocessing.Process( p = self._mp_context.Process(
target=_process_worker, target=_process_worker,
args=(self._call_queue, args=(self._call_queue,
self._result_queue)) self._result_queue))
p.start() p.start()
self._processes[p.pid] = p self._processes[p.pid] = p
......
...@@ -19,6 +19,7 @@ from concurrent import futures ...@@ -19,6 +19,7 @@ from concurrent import futures
from concurrent.futures._base import ( from concurrent.futures._base import (
PENDING, RUNNING, CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED, Future) PENDING, RUNNING, CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED, Future)
from concurrent.futures.process import BrokenProcessPool from concurrent.futures.process import BrokenProcessPool
from multiprocessing import get_context
def create_future(state=PENDING, exception=None, result=None): def create_future(state=PENDING, exception=None, result=None):
...@@ -56,6 +57,15 @@ class MyObject(object): ...@@ -56,6 +57,15 @@ class MyObject(object):
pass pass
class EventfulGCObj():
def __init__(self, ctx):
mgr = get_context(ctx).Manager()
self.event = mgr.Event()
def __del__(self):
self.event.set()
def make_dummy_object(_): def make_dummy_object(_):
return MyObject() return MyObject()
...@@ -77,7 +87,13 @@ class ExecutorMixin: ...@@ -77,7 +87,13 @@ class ExecutorMixin:
self.t1 = time.time() self.t1 = time.time()
try: try:
self.executor = self.executor_type(max_workers=self.worker_count) if hasattr(self, "ctx"):
self.executor = self.executor_type(
max_workers=self.worker_count,
mp_context=get_context(self.ctx))
else:
self.executor = self.executor_type(
max_workers=self.worker_count)
except NotImplementedError as e: except NotImplementedError as e:
self.skipTest(str(e)) self.skipTest(str(e))
self._prime_executor() self._prime_executor()
...@@ -107,8 +123,29 @@ class ThreadPoolMixin(ExecutorMixin): ...@@ -107,8 +123,29 @@ class ThreadPoolMixin(ExecutorMixin):
executor_type = futures.ThreadPoolExecutor executor_type = futures.ThreadPoolExecutor
class ProcessPoolMixin(ExecutorMixin): class ProcessPoolForkMixin(ExecutorMixin):
executor_type = futures.ProcessPoolExecutor
ctx = "fork"
def setUp(self):
if sys.platform == "win32":
self.skipTest("require unix system")
super().setUp()
class ProcessPoolSpawnMixin(ExecutorMixin):
executor_type = futures.ProcessPoolExecutor
ctx = "spawn"
class ProcessPoolForkserverMixin(ExecutorMixin):
executor_type = futures.ProcessPoolExecutor executor_type = futures.ProcessPoolExecutor
ctx = "forkserver"
def setUp(self):
if sys.platform == "win32":
self.skipTest("require unix system")
super().setUp()
class ExecutorShutdownTest: class ExecutorShutdownTest:
...@@ -124,9 +161,17 @@ class ExecutorShutdownTest: ...@@ -124,9 +161,17 @@ class ExecutorShutdownTest:
from concurrent.futures import {executor_type} from concurrent.futures import {executor_type}
from time import sleep from time import sleep
from test.test_concurrent_futures import sleep_and_print from test.test_concurrent_futures import sleep_and_print
t = {executor_type}(5) if __name__ == "__main__":
t.submit(sleep_and_print, 1.0, "apple") context = '{context}'
""".format(executor_type=self.executor_type.__name__)) if context == "":
t = {executor_type}(5)
else:
from multiprocessing import get_context
context = get_context(context)
t = {executor_type}(5, mp_context=context)
t.submit(sleep_and_print, 1.0, "apple")
""".format(executor_type=self.executor_type.__name__,
context=getattr(self, "ctx", "")))
# Errors in atexit hooks don't change the process exit code, check # Errors in atexit hooks don't change the process exit code, check
# stderr manually. # stderr manually.
self.assertFalse(err) self.assertFalse(err)
...@@ -194,7 +239,7 @@ class ThreadPoolShutdownTest(ThreadPoolMixin, ExecutorShutdownTest, BaseTestCase ...@@ -194,7 +239,7 @@ class ThreadPoolShutdownTest(ThreadPoolMixin, ExecutorShutdownTest, BaseTestCase
t.join() t.join()
class ProcessPoolShutdownTest(ProcessPoolMixin, ExecutorShutdownTest, BaseTestCase): class ProcessPoolShutdownTest(ExecutorShutdownTest):
def _prime_executor(self): def _prime_executor(self):
pass pass
...@@ -233,6 +278,22 @@ class ProcessPoolShutdownTest(ProcessPoolMixin, ExecutorShutdownTest, BaseTestCa ...@@ -233,6 +278,22 @@ class ProcessPoolShutdownTest(ProcessPoolMixin, ExecutorShutdownTest, BaseTestCa
call_queue.join_thread() call_queue.join_thread()
class ProcessPoolForkShutdownTest(ProcessPoolForkMixin, BaseTestCase,
ProcessPoolShutdownTest):
pass
class ProcessPoolForkserverShutdownTest(ProcessPoolForkserverMixin,
BaseTestCase,
ProcessPoolShutdownTest):
pass
class ProcessPoolSpawnShutdownTest(ProcessPoolSpawnMixin, BaseTestCase,
ProcessPoolShutdownTest):
pass
class WaitTests: class WaitTests:
def test_first_completed(self): def test_first_completed(self):
...@@ -352,7 +413,17 @@ class ThreadPoolWaitTests(ThreadPoolMixin, WaitTests, BaseTestCase): ...@@ -352,7 +413,17 @@ class ThreadPoolWaitTests(ThreadPoolMixin, WaitTests, BaseTestCase):
sys.setswitchinterval(oldswitchinterval) sys.setswitchinterval(oldswitchinterval)
class ProcessPoolWaitTests(ProcessPoolMixin, WaitTests, BaseTestCase): class ProcessPoolForkWaitTests(ProcessPoolForkMixin, WaitTests, BaseTestCase):
pass
class ProcessPoolForkserverWaitTests(ProcessPoolForkserverMixin, WaitTests,
BaseTestCase):
pass
class ProcessPoolSpawnWaitTests(ProcessPoolSpawnMixin, BaseTestCase,
WaitTests):
pass pass
...@@ -440,7 +511,19 @@ class ThreadPoolAsCompletedTests(ThreadPoolMixin, AsCompletedTests, BaseTestCase ...@@ -440,7 +511,19 @@ class ThreadPoolAsCompletedTests(ThreadPoolMixin, AsCompletedTests, BaseTestCase
pass pass
class ProcessPoolAsCompletedTests(ProcessPoolMixin, AsCompletedTests, BaseTestCase): class ProcessPoolForkAsCompletedTests(ProcessPoolForkMixin, AsCompletedTests,
BaseTestCase):
pass
class ProcessPoolForkserverAsCompletedTests(ProcessPoolForkserverMixin,
AsCompletedTests,
BaseTestCase):
pass
class ProcessPoolSpawnAsCompletedTests(ProcessPoolSpawnMixin, AsCompletedTests,
BaseTestCase):
pass pass
...@@ -540,7 +623,7 @@ class ThreadPoolExecutorTest(ThreadPoolMixin, ExecutorTest, BaseTestCase): ...@@ -540,7 +623,7 @@ class ThreadPoolExecutorTest(ThreadPoolMixin, ExecutorTest, BaseTestCase):
(os.cpu_count() or 1) * 5) (os.cpu_count() or 1) * 5)
class ProcessPoolExecutorTest(ProcessPoolMixin, ExecutorTest, BaseTestCase): class ProcessPoolExecutorTest(ExecutorTest):
def test_killed_child(self): def test_killed_child(self):
# When a child process is abruptly terminated, the whole pool gets # When a child process is abruptly terminated, the whole pool gets
# "broken". # "broken".
...@@ -595,6 +678,34 @@ class ProcessPoolExecutorTest(ProcessPoolMixin, ExecutorTest, BaseTestCase): ...@@ -595,6 +678,34 @@ class ProcessPoolExecutorTest(ProcessPoolMixin, ExecutorTest, BaseTestCase):
self.assertIn('raise RuntimeError(123) # some comment', self.assertIn('raise RuntimeError(123) # some comment',
f1.getvalue()) f1.getvalue())
def test_ressources_gced_in_workers(self):
# Ensure that argument for a job are correctly gc-ed after the job
# is finished
obj = EventfulGCObj(self.ctx)
future = self.executor.submit(id, obj)
future.result()
self.assertTrue(obj.event.wait(timeout=1))
class ProcessPoolForkExecutorTest(ProcessPoolForkMixin,
ProcessPoolExecutorTest,
BaseTestCase):
pass
class ProcessPoolForkserverExecutorTest(ProcessPoolForkserverMixin,
ProcessPoolExecutorTest,
BaseTestCase):
pass
class ProcessPoolSpawnExecutorTest(ProcessPoolSpawnMixin,
ProcessPoolExecutorTest,
BaseTestCase):
pass
class FutureTests(BaseTestCase): class FutureTests(BaseTestCase):
def test_done_callback_with_result(self): def test_done_callback_with_result(self):
......
Allow passing a context object in
:class:`concurrent.futures.ProcessPoolExecutor` constructor.
Also, free job ressources in :class:`concurrent.futures.ProcessPoolExecutor`
earlier to improve memory usage when a worker waits for new jobs.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment