Commit 794623bd, authored by Xiang Zhang, committed by GitHub

bpo-28699: fix abnormal behaviour of pools in multiprocessing.pool (GH-693)

An exception raised at the very start of an iterable would cause pools to behave abnormally
(either swallowing the exception or hanging).
parent ec1f5df4
...@@ -118,7 +118,7 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None, ...@@ -118,7 +118,7 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None,
try: try:
result = (True, func(*args, **kwds)) result = (True, func(*args, **kwds))
except Exception as e: except Exception as e:
if wrap_exception: if wrap_exception and func is not _helper_reraises_exception:
e = ExceptionWithTraceback(e, e.__traceback__) e = ExceptionWithTraceback(e, e.__traceback__)
result = (False, e) result = (False, e)
try: try:
...@@ -133,6 +133,10 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None, ...@@ -133,6 +133,10 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None,
completed += 1 completed += 1
util.debug('worker exiting after %d tasks' % completed) util.debug('worker exiting after %d tasks' % completed)
def _helper_reraises_exception(ex):
'Pickle-able helper function for use by _guarded_task_generation.'
raise ex
# #
# Class representing a process pool # Class representing a process pool
# #
...@@ -277,6 +281,17 @@ class Pool(object): ...@@ -277,6 +281,17 @@ class Pool(object):
return self._map_async(func, iterable, starmapstar, chunksize, return self._map_async(func, iterable, starmapstar, chunksize,
callback, error_callback) callback, error_callback)
def _guarded_task_generation(self, result_job, func, iterable):
'''Provides a generator of tasks for imap and imap_unordered with
appropriate handling for iterables which throw exceptions during
iteration.'''
try:
i = -1
for i, x in enumerate(iterable):
yield (result_job, i, func, (x,), {})
except Exception as e:
yield (result_job, i+1, _helper_reraises_exception, (e,), {})
def imap(self, func, iterable, chunksize=1): def imap(self, func, iterable, chunksize=1):
''' '''
Equivalent of `map()` -- can be MUCH slower than `Pool.map()`. Equivalent of `map()` -- can be MUCH slower than `Pool.map()`.
...@@ -285,15 +300,23 @@ class Pool(object): ...@@ -285,15 +300,23 @@ class Pool(object):
raise ValueError("Pool not running") raise ValueError("Pool not running")
if chunksize == 1: if chunksize == 1:
result = IMapIterator(self._cache) result = IMapIterator(self._cache)
self._taskqueue.put((((result._job, i, func, (x,), {}) self._taskqueue.put(
for i, x in enumerate(iterable)), result._set_length)) (
self._guarded_task_generation(result._job, func, iterable),
result._set_length
))
return result return result
else: else:
assert chunksize > 1 assert chunksize > 1
task_batches = Pool._get_tasks(func, iterable, chunksize) task_batches = Pool._get_tasks(func, iterable, chunksize)
result = IMapIterator(self._cache) result = IMapIterator(self._cache)
self._taskqueue.put((((result._job, i, mapstar, (x,), {}) self._taskqueue.put(
for i, x in enumerate(task_batches)), result._set_length)) (
self._guarded_task_generation(result._job,
mapstar,
task_batches),
result._set_length
))
return (item for chunk in result for item in chunk) return (item for chunk in result for item in chunk)
def imap_unordered(self, func, iterable, chunksize=1): def imap_unordered(self, func, iterable, chunksize=1):
...@@ -304,15 +327,23 @@ class Pool(object): ...@@ -304,15 +327,23 @@ class Pool(object):
raise ValueError("Pool not running") raise ValueError("Pool not running")
if chunksize == 1: if chunksize == 1:
result = IMapUnorderedIterator(self._cache) result = IMapUnorderedIterator(self._cache)
self._taskqueue.put((((result._job, i, func, (x,), {}) self._taskqueue.put(
for i, x in enumerate(iterable)), result._set_length)) (
self._guarded_task_generation(result._job, func, iterable),
result._set_length
))
return result return result
else: else:
assert chunksize > 1 assert chunksize > 1
task_batches = Pool._get_tasks(func, iterable, chunksize) task_batches = Pool._get_tasks(func, iterable, chunksize)
result = IMapUnorderedIterator(self._cache) result = IMapUnorderedIterator(self._cache)
self._taskqueue.put((((result._job, i, mapstar, (x,), {}) self._taskqueue.put(
for i, x in enumerate(task_batches)), result._set_length)) (
self._guarded_task_generation(result._job,
mapstar,
task_batches),
result._set_length
))
return (item for chunk in result for item in chunk) return (item for chunk in result for item in chunk)
def apply_async(self, func, args=(), kwds={}, callback=None, def apply_async(self, func, args=(), kwds={}, callback=None,
...@@ -323,7 +354,7 @@ class Pool(object): ...@@ -323,7 +354,7 @@ class Pool(object):
if self._state != RUN: if self._state != RUN:
raise ValueError("Pool not running") raise ValueError("Pool not running")
result = ApplyResult(self._cache, callback, error_callback) result = ApplyResult(self._cache, callback, error_callback)
self._taskqueue.put(([(result._job, None, func, args, kwds)], None)) self._taskqueue.put(([(result._job, 0, func, args, kwds)], None))
return result return result
def map_async(self, func, iterable, chunksize=None, callback=None, def map_async(self, func, iterable, chunksize=None, callback=None,
...@@ -354,8 +385,14 @@ class Pool(object): ...@@ -354,8 +385,14 @@ class Pool(object):
task_batches = Pool._get_tasks(func, iterable, chunksize) task_batches = Pool._get_tasks(func, iterable, chunksize)
result = MapResult(self._cache, chunksize, len(iterable), callback, result = MapResult(self._cache, chunksize, len(iterable), callback,
error_callback=error_callback) error_callback=error_callback)
self._taskqueue.put((((result._job, i, mapper, (x,), {}) self._taskqueue.put(
for i, x in enumerate(task_batches)), None)) (
self._guarded_task_generation(result._job,
mapper,
task_batches),
None
)
)
return result return result
@staticmethod @staticmethod
...@@ -377,33 +414,27 @@ class Pool(object): ...@@ -377,33 +414,27 @@ class Pool(object):
for taskseq, set_length in iter(taskqueue.get, None): for taskseq, set_length in iter(taskqueue.get, None):
task = None task = None
i = -1
try: try:
for i, task in enumerate(taskseq): # iterating taskseq cannot fail
for task in taskseq:
if thread._state: if thread._state:
util.debug('task handler found thread._state != RUN') util.debug('task handler found thread._state != RUN')
break break
try: try:
put(task) put(task)
except Exception as e: except Exception as e:
job, ind = task[:2] job, idx = task[:2]
try: try:
cache[job]._set(ind, (False, e)) cache[job]._set(idx, (False, e))
except KeyError: except KeyError:
pass pass
else: else:
if set_length: if set_length:
util.debug('doing set_length()') util.debug('doing set_length()')
set_length(i+1) idx = task[1] if task else -1
set_length(idx + 1)
continue continue
break break
except Exception as ex:
job, ind = task[:2] if task else (0, 0)
if job in cache:
cache[job]._set(ind + 1, (False, ex))
if set_length:
util.debug('doing set_length()')
set_length(i+1)
finally: finally:
task = taskseq = job = None task = taskseq = job = None
else: else:
......
...@@ -1755,6 +1755,8 @@ class CountedObject(object): ...@@ -1755,6 +1755,8 @@ class CountedObject(object):
class SayWhenError(ValueError): pass class SayWhenError(ValueError): pass
def exception_throwing_generator(total, when): def exception_throwing_generator(total, when):
if when == -1:
raise SayWhenError("Somebody said when")
for i in range(total): for i in range(total):
if i == when: if i == when:
raise SayWhenError("Somebody said when") raise SayWhenError("Somebody said when")
...@@ -1833,6 +1835,32 @@ class _TestPool(BaseTestCase): ...@@ -1833,6 +1835,32 @@ class _TestPool(BaseTestCase):
except multiprocessing.TimeoutError: except multiprocessing.TimeoutError:
self.fail("pool.map_async with chunksize stalled on null list") self.fail("pool.map_async with chunksize stalled on null list")
def test_map_handle_iterable_exception(self):
    # pool.map() must propagate an exception raised while iterating its
    # input, including one raised before the first item is produced.
    if self.TYPE == 'manager':
        self.skipTest('test not appropriate for {}'.format(self.TYPE))

    class SpecialIterable:
        # Reports a length of 1 but fails on the very first __next__ call.
        def __iter__(self):
            return self

        def __next__(self):
            raise SayWhenError

        def __len__(self):
            return 1

    # Each entry builds a fresh failing iterable.  Entries that appear
    # twice in a row check that the pool is still usable (reentrant)
    # after the previous failure.
    failing_sources = (
        # SayWhenError seen at the very start of the iterable
        lambda: exception_throwing_generator(1, -1),
        lambda: exception_throwing_generator(1, -1),
        # SayWhenError seen part-way through the iterable
        lambda: exception_throwing_generator(10, 3),
        SpecialIterable,
        SpecialIterable,
    )
    for make_iterable in failing_sources:
        with self.assertRaises(SayWhenError):
            self.pool.map(sqr, make_iterable(), 1)
def test_async(self): def test_async(self):
res = self.pool.apply_async(sqr, (7, TIMEOUT1,)) res = self.pool.apply_async(sqr, (7, TIMEOUT1,))
get = TimingWrapper(res.get) get = TimingWrapper(res.get)
...@@ -1863,6 +1891,13 @@ class _TestPool(BaseTestCase): ...@@ -1863,6 +1891,13 @@ class _TestPool(BaseTestCase):
if self.TYPE == 'manager': if self.TYPE == 'manager':
self.skipTest('test not appropriate for {}'.format(self.TYPE)) self.skipTest('test not appropriate for {}'.format(self.TYPE))
# SayWhenError seen at the very first of the iterable
it = self.pool.imap(sqr, exception_throwing_generator(1, -1), 1)
self.assertRaises(SayWhenError, it.__next__)
# again, make sure it's reentrant
it = self.pool.imap(sqr, exception_throwing_generator(1, -1), 1)
self.assertRaises(SayWhenError, it.__next__)
it = self.pool.imap(sqr, exception_throwing_generator(10, 3), 1) it = self.pool.imap(sqr, exception_throwing_generator(10, 3), 1)
for i in range(3): for i in range(3):
self.assertEqual(next(it), i*i) self.assertEqual(next(it), i*i)
...@@ -1889,6 +1924,17 @@ class _TestPool(BaseTestCase): ...@@ -1889,6 +1924,17 @@ class _TestPool(BaseTestCase):
if self.TYPE == 'manager': if self.TYPE == 'manager':
self.skipTest('test not appropriate for {}'.format(self.TYPE)) self.skipTest('test not appropriate for {}'.format(self.TYPE))
# SayWhenError seen at the very first of the iterable
it = self.pool.imap_unordered(sqr,
exception_throwing_generator(1, -1),
1)
self.assertRaises(SayWhenError, it.__next__)
# again, make sure it's reentrant
it = self.pool.imap_unordered(sqr,
exception_throwing_generator(1, -1),
1)
self.assertRaises(SayWhenError, it.__next__)
it = self.pool.imap_unordered(sqr, it = self.pool.imap_unordered(sqr,
exception_throwing_generator(10, 3), exception_throwing_generator(10, 3),
1) 1)
...@@ -1970,7 +2016,7 @@ class _TestPool(BaseTestCase): ...@@ -1970,7 +2016,7 @@ class _TestPool(BaseTestCase):
except Exception as e: except Exception as e:
exc = e exc = e
else: else:
raise AssertionError('expected RuntimeError') self.fail('expected RuntimeError')
self.assertIs(type(exc), RuntimeError) self.assertIs(type(exc), RuntimeError)
self.assertEqual(exc.args, (123,)) self.assertEqual(exc.args, (123,))
cause = exc.__cause__ cause = exc.__cause__
...@@ -1984,6 +2030,17 @@ class _TestPool(BaseTestCase): ...@@ -1984,6 +2030,17 @@ class _TestPool(BaseTestCase):
sys.excepthook(*sys.exc_info()) sys.excepthook(*sys.exc_info())
self.assertIn('raise RuntimeError(123) # some comment', self.assertIn('raise RuntimeError(123) # some comment',
f1.getvalue()) f1.getvalue())
# _helper_reraises_exception should not make the error
# a remote exception
with self.Pool(1) as p:
try:
p.map(sqr, exception_throwing_generator(1, -1), 1)
except Exception as e:
exc = e
else:
self.fail('expected SayWhenError')
self.assertIs(type(exc), SayWhenError)
self.assertIs(exc.__cause__, None)
@classmethod @classmethod
def _test_wrapped_exception(cls): def _test_wrapped_exception(cls):
......
...@@ -291,6 +291,10 @@ Extension Modules ...@@ -291,6 +291,10 @@ Extension Modules
Library Library
------- -------
- bpo-28699: Fixed a bug in multiprocessing.pool where an exception raised
at the very start of an iterable could be swallowed or could make the
program hang. Patch by Davin Potts and Xiang Zhang.
- bpo-23890: unittest.TestCase.assertRaises() now manually breaks a reference - bpo-23890: unittest.TestCase.assertRaises() now manually breaks a reference
cycle to not keep objects alive longer than expected. cycle to not keep objects alive longer than expected.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment