Commit 3766f18c authored by Pablo Galindo, committed by GitHub

bpo-35378: Fix multiprocessing.Pool references (GH-11627)

Changes in this commit:

1. Use a _strong_ reference between the Pool and associated iterators
2. Rework PR #8450 to eliminate a cycle in the Pool.

There is no test in this commit because any test that automatically exercises this behaviour needs to delete the pool before joining it, in order to check that the pool object is garbage collected and does not hang. But doing this will potentially leak threads and processes (see https://bugs.python.org/issue35413).
parent 4b250fc1
...@@ -151,8 +151,9 @@ class Pool(object): ...@@ -151,8 +151,9 @@ class Pool(object):
''' '''
_wrap_exception = True _wrap_exception = True
def Process(self, *args, **kwds): @staticmethod
return self._ctx.Process(*args, **kwds) def Process(ctx, *args, **kwds):
return ctx.Process(*args, **kwds)
def __init__(self, processes=None, initializer=None, initargs=(), def __init__(self, processes=None, initializer=None, initargs=(),
maxtasksperchild=None, context=None): maxtasksperchild=None, context=None):
...@@ -190,7 +191,10 @@ class Pool(object): ...@@ -190,7 +191,10 @@ class Pool(object):
self._worker_handler = threading.Thread( self._worker_handler = threading.Thread(
target=Pool._handle_workers, target=Pool._handle_workers,
args=(self, ) args=(self._cache, self._taskqueue, self._ctx, self.Process,
self._processes, self._pool, self._inqueue, self._outqueue,
self._initializer, self._initargs, self._maxtasksperchild,
self._wrap_exception)
) )
self._worker_handler.daemon = True self._worker_handler.daemon = True
self._worker_handler._state = RUN self._worker_handler._state = RUN
...@@ -236,43 +240,61 @@ class Pool(object): ...@@ -236,43 +240,61 @@ class Pool(object):
f'state={self._state} ' f'state={self._state} '
f'pool_size={len(self._pool)}>') f'pool_size={len(self._pool)}>')
def _join_exited_workers(self): @staticmethod
def _join_exited_workers(pool):
"""Cleanup after any worker processes which have exited due to reaching """Cleanup after any worker processes which have exited due to reaching
their specified lifetime. Returns True if any workers were cleaned up. their specified lifetime. Returns True if any workers were cleaned up.
""" """
cleaned = False cleaned = False
for i in reversed(range(len(self._pool))): for i in reversed(range(len(pool))):
worker = self._pool[i] worker = pool[i]
if worker.exitcode is not None: if worker.exitcode is not None:
# worker exited # worker exited
util.debug('cleaning up worker %d' % i) util.debug('cleaning up worker %d' % i)
worker.join() worker.join()
cleaned = True cleaned = True
del self._pool[i] del pool[i]
return cleaned return cleaned
def _repopulate_pool(self): def _repopulate_pool(self):
return self._repopulate_pool_static(self._ctx, self.Process,
self._processes,
self._pool, self._inqueue,
self._outqueue, self._initializer,
self._initargs,
self._maxtasksperchild,
self._wrap_exception)
@staticmethod
def _repopulate_pool_static(ctx, Process, processes, pool, inqueue,
outqueue, initializer, initargs,
maxtasksperchild, wrap_exception):
"""Bring the number of pool processes up to the specified number, """Bring the number of pool processes up to the specified number,
for use after reaping workers which have exited. for use after reaping workers which have exited.
""" """
for i in range(self._processes - len(self._pool)): for i in range(processes - len(pool)):
w = self.Process(target=worker, w = Process(ctx, target=worker,
args=(self._inqueue, self._outqueue, args=(inqueue, outqueue,
self._initializer, initializer,
self._initargs, self._maxtasksperchild, initargs, maxtasksperchild,
self._wrap_exception) wrap_exception))
)
w.name = w.name.replace('Process', 'PoolWorker') w.name = w.name.replace('Process', 'PoolWorker')
w.daemon = True w.daemon = True
w.start() w.start()
self._pool.append(w) pool.append(w)
util.debug('added worker') util.debug('added worker')
def _maintain_pool(self): @staticmethod
def _maintain_pool(ctx, Process, processes, pool, inqueue, outqueue,
initializer, initargs, maxtasksperchild,
wrap_exception):
"""Clean up any exited workers and start replacements for them. """Clean up any exited workers and start replacements for them.
""" """
if self._join_exited_workers(): if Pool._join_exited_workers(pool):
self._repopulate_pool() Pool._repopulate_pool_static(ctx, Process, processes, pool,
inqueue, outqueue, initializer,
initargs, maxtasksperchild,
wrap_exception)
def _setup_queues(self): def _setup_queues(self):
self._inqueue = self._ctx.SimpleQueue() self._inqueue = self._ctx.SimpleQueue()
...@@ -331,7 +353,7 @@ class Pool(object): ...@@ -331,7 +353,7 @@ class Pool(object):
''' '''
self._check_running() self._check_running()
if chunksize == 1: if chunksize == 1:
result = IMapIterator(self._cache) result = IMapIterator(self)
self._taskqueue.put( self._taskqueue.put(
( (
self._guarded_task_generation(result._job, func, iterable), self._guarded_task_generation(result._job, func, iterable),
...@@ -344,7 +366,7 @@ class Pool(object): ...@@ -344,7 +366,7 @@ class Pool(object):
"Chunksize must be 1+, not {0:n}".format( "Chunksize must be 1+, not {0:n}".format(
chunksize)) chunksize))
task_batches = Pool._get_tasks(func, iterable, chunksize) task_batches = Pool._get_tasks(func, iterable, chunksize)
result = IMapIterator(self._cache) result = IMapIterator(self)
self._taskqueue.put( self._taskqueue.put(
( (
self._guarded_task_generation(result._job, self._guarded_task_generation(result._job,
...@@ -360,7 +382,7 @@ class Pool(object): ...@@ -360,7 +382,7 @@ class Pool(object):
''' '''
self._check_running() self._check_running()
if chunksize == 1: if chunksize == 1:
result = IMapUnorderedIterator(self._cache) result = IMapUnorderedIterator(self)
self._taskqueue.put( self._taskqueue.put(
( (
self._guarded_task_generation(result._job, func, iterable), self._guarded_task_generation(result._job, func, iterable),
...@@ -372,7 +394,7 @@ class Pool(object): ...@@ -372,7 +394,7 @@ class Pool(object):
raise ValueError( raise ValueError(
"Chunksize must be 1+, not {0!r}".format(chunksize)) "Chunksize must be 1+, not {0!r}".format(chunksize))
task_batches = Pool._get_tasks(func, iterable, chunksize) task_batches = Pool._get_tasks(func, iterable, chunksize)
result = IMapUnorderedIterator(self._cache) result = IMapUnorderedIterator(self)
self._taskqueue.put( self._taskqueue.put(
( (
self._guarded_task_generation(result._job, self._guarded_task_generation(result._job,
...@@ -388,7 +410,7 @@ class Pool(object): ...@@ -388,7 +410,7 @@ class Pool(object):
Asynchronous version of `apply()` method. Asynchronous version of `apply()` method.
''' '''
self._check_running() self._check_running()
result = ApplyResult(self._cache, callback, error_callback) result = ApplyResult(self, callback, error_callback)
self._taskqueue.put(([(result._job, 0, func, args, kwds)], None)) self._taskqueue.put(([(result._job, 0, func, args, kwds)], None))
return result return result
...@@ -417,7 +439,7 @@ class Pool(object): ...@@ -417,7 +439,7 @@ class Pool(object):
chunksize = 0 chunksize = 0
task_batches = Pool._get_tasks(func, iterable, chunksize) task_batches = Pool._get_tasks(func, iterable, chunksize)
result = MapResult(self._cache, chunksize, len(iterable), callback, result = MapResult(self, chunksize, len(iterable), callback,
error_callback=error_callback) error_callback=error_callback)
self._taskqueue.put( self._taskqueue.put(
( (
...@@ -430,16 +452,20 @@ class Pool(object): ...@@ -430,16 +452,20 @@ class Pool(object):
return result return result
@staticmethod @staticmethod
def _handle_workers(pool): def _handle_workers(cache, taskqueue, ctx, Process, processes, pool,
inqueue, outqueue, initializer, initargs,
maxtasksperchild, wrap_exception):
thread = threading.current_thread() thread = threading.current_thread()
# Keep maintaining workers until the cache gets drained, unless the pool # Keep maintaining workers until the cache gets drained, unless the pool
# is terminated. # is terminated.
while thread._state == RUN or (pool._cache and thread._state != TERMINATE): while thread._state == RUN or (cache and thread._state != TERMINATE):
pool._maintain_pool() Pool._maintain_pool(ctx, Process, processes, pool, inqueue,
outqueue, initializer, initargs,
maxtasksperchild, wrap_exception)
time.sleep(0.1) time.sleep(0.1)
# send sentinel to stop workers # send sentinel to stop workers
pool._taskqueue.put(None) taskqueue.put(None)
util.debug('worker handler exiting') util.debug('worker handler exiting')
@staticmethod @staticmethod
...@@ -656,13 +682,14 @@ class Pool(object): ...@@ -656,13 +682,14 @@ class Pool(object):
class ApplyResult(object): class ApplyResult(object):
def __init__(self, cache, callback, error_callback): def __init__(self, pool, callback, error_callback):
self._pool = pool
self._event = threading.Event() self._event = threading.Event()
self._job = next(job_counter) self._job = next(job_counter)
self._cache = cache self._cache = pool._cache
self._callback = callback self._callback = callback
self._error_callback = error_callback self._error_callback = error_callback
cache[self._job] = self self._cache[self._job] = self
def ready(self): def ready(self):
return self._event.is_set() return self._event.is_set()
...@@ -692,6 +719,7 @@ class ApplyResult(object): ...@@ -692,6 +719,7 @@ class ApplyResult(object):
self._error_callback(self._value) self._error_callback(self._value)
self._event.set() self._event.set()
del self._cache[self._job] del self._cache[self._job]
self._pool = None
AsyncResult = ApplyResult # create alias -- see #17805 AsyncResult = ApplyResult # create alias -- see #17805
...@@ -701,8 +729,8 @@ AsyncResult = ApplyResult # create alias -- see #17805 ...@@ -701,8 +729,8 @@ AsyncResult = ApplyResult # create alias -- see #17805
class MapResult(ApplyResult): class MapResult(ApplyResult):
def __init__(self, cache, chunksize, length, callback, error_callback): def __init__(self, pool, chunksize, length, callback, error_callback):
ApplyResult.__init__(self, cache, callback, ApplyResult.__init__(self, pool, callback,
error_callback=error_callback) error_callback=error_callback)
self._success = True self._success = True
self._value = [None] * length self._value = [None] * length
...@@ -710,7 +738,7 @@ class MapResult(ApplyResult): ...@@ -710,7 +738,7 @@ class MapResult(ApplyResult):
if chunksize <= 0: if chunksize <= 0:
self._number_left = 0 self._number_left = 0
self._event.set() self._event.set()
del cache[self._job] del self._cache[self._job]
else: else:
self._number_left = length//chunksize + bool(length % chunksize) self._number_left = length//chunksize + bool(length % chunksize)
...@@ -724,6 +752,7 @@ class MapResult(ApplyResult): ...@@ -724,6 +752,7 @@ class MapResult(ApplyResult):
self._callback(self._value) self._callback(self._value)
del self._cache[self._job] del self._cache[self._job]
self._event.set() self._event.set()
self._pool = None
else: else:
if not success and self._success: if not success and self._success:
# only store first exception # only store first exception
...@@ -735,6 +764,7 @@ class MapResult(ApplyResult): ...@@ -735,6 +764,7 @@ class MapResult(ApplyResult):
self._error_callback(self._value) self._error_callback(self._value)
del self._cache[self._job] del self._cache[self._job]
self._event.set() self._event.set()
self._pool = None
# #
# Class whose instances are returned by `Pool.imap()` # Class whose instances are returned by `Pool.imap()`
...@@ -742,15 +772,16 @@ class MapResult(ApplyResult): ...@@ -742,15 +772,16 @@ class MapResult(ApplyResult):
class IMapIterator(object): class IMapIterator(object):
def __init__(self, cache): def __init__(self, pool):
self._pool = pool
self._cond = threading.Condition(threading.Lock()) self._cond = threading.Condition(threading.Lock())
self._job = next(job_counter) self._job = next(job_counter)
self._cache = cache self._cache = pool._cache
self._items = collections.deque() self._items = collections.deque()
self._index = 0 self._index = 0
self._length = None self._length = None
self._unsorted = {} self._unsorted = {}
cache[self._job] = self self._cache[self._job] = self
def __iter__(self): def __iter__(self):
return self return self
...@@ -761,12 +792,14 @@ class IMapIterator(object): ...@@ -761,12 +792,14 @@ class IMapIterator(object):
item = self._items.popleft() item = self._items.popleft()
except IndexError: except IndexError:
if self._index == self._length: if self._index == self._length:
self._pool = None
raise StopIteration from None raise StopIteration from None
self._cond.wait(timeout) self._cond.wait(timeout)
try: try:
item = self._items.popleft() item = self._items.popleft()
except IndexError: except IndexError:
if self._index == self._length: if self._index == self._length:
self._pool = None
raise StopIteration from None raise StopIteration from None
raise TimeoutError from None raise TimeoutError from None
...@@ -792,6 +825,7 @@ class IMapIterator(object): ...@@ -792,6 +825,7 @@ class IMapIterator(object):
if self._index == self._length: if self._index == self._length:
del self._cache[self._job] del self._cache[self._job]
self._pool = None
def _set_length(self, length): def _set_length(self, length):
with self._cond: with self._cond:
...@@ -799,6 +833,7 @@ class IMapIterator(object): ...@@ -799,6 +833,7 @@ class IMapIterator(object):
if self._index == self._length: if self._index == self._length:
self._cond.notify() self._cond.notify()
del self._cache[self._job] del self._cache[self._job]
self._pool = None
# #
# Class whose instances are returned by `Pool.imap_unordered()` # Class whose instances are returned by `Pool.imap_unordered()`
...@@ -813,6 +848,7 @@ class IMapUnorderedIterator(IMapIterator): ...@@ -813,6 +848,7 @@ class IMapUnorderedIterator(IMapIterator):
self._cond.notify() self._cond.notify()
if self._index == self._length: if self._index == self._length:
del self._cache[self._job] del self._cache[self._job]
self._pool = None
# #
# #
...@@ -822,7 +858,7 @@ class ThreadPool(Pool): ...@@ -822,7 +858,7 @@ class ThreadPool(Pool):
_wrap_exception = False _wrap_exception = False
@staticmethod @staticmethod
def Process(*args, **kwds): def Process(ctx, *args, **kwds):
from .dummy import Process from .dummy import Process
return Process(*args, **kwds) return Process(*args, **kwds)
......
...@@ -2593,7 +2593,6 @@ class _TestPool(BaseTestCase): ...@@ -2593,7 +2593,6 @@ class _TestPool(BaseTestCase):
pool = None pool = None
support.gc_collect() support.gc_collect()
def raising(): def raising():
raise KeyError("key") raise KeyError("key")
......
Fix a reference issue inside :class:`multiprocessing.Pool` that caused
the pool to remain alive if it was deleted without being closed or
terminated explicitly. A new strong reference is added to the pool
iterators to link the lifetime of the pool to the lifetime of its
iterators so the pool does not get destroyed if a pool iterator is
still alive.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment