Commit 794623bd authored by Xiang Zhang's avatar Xiang Zhang Committed by GitHub

bpo-28699: fix abnormal behaviour of pools in multiprocessing.pool (GH-693)

an exception raised at the very first of an iterable would cause pools behave abnormally
(swallow the exception or hang)
parent ec1f5df4
......@@ -118,7 +118,7 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None,
try:
result = (True, func(*args, **kwds))
except Exception as e:
if wrap_exception:
if wrap_exception and func is not _helper_reraises_exception:
e = ExceptionWithTraceback(e, e.__traceback__)
result = (False, e)
try:
......@@ -133,6 +133,10 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None,
completed += 1
util.debug('worker exiting after %d tasks' % completed)
def _helper_reraises_exception(ex):
'Pickle-able helper function for use by _guarded_task_generation.'
raise ex
#
# Class representing a process pool
#
......@@ -277,6 +281,17 @@ class Pool(object):
return self._map_async(func, iterable, starmapstar, chunksize,
callback, error_callback)
def _guarded_task_generation(self, result_job, func, iterable):
'''Provides a generator of tasks for imap and imap_unordered with
appropriate handling for iterables which throw exceptions during
iteration.'''
try:
i = -1
for i, x in enumerate(iterable):
yield (result_job, i, func, (x,), {})
except Exception as e:
yield (result_job, i+1, _helper_reraises_exception, (e,), {})
def imap(self, func, iterable, chunksize=1):
'''
Equivalent of `map()` -- can be MUCH slower than `Pool.map()`.
......@@ -285,15 +300,23 @@ class Pool(object):
raise ValueError("Pool not running")
if chunksize == 1:
result = IMapIterator(self._cache)
self._taskqueue.put((((result._job, i, func, (x,), {})
for i, x in enumerate(iterable)), result._set_length))
self._taskqueue.put(
(
self._guarded_task_generation(result._job, func, iterable),
result._set_length
))
return result
else:
assert chunksize > 1
task_batches = Pool._get_tasks(func, iterable, chunksize)
result = IMapIterator(self._cache)
self._taskqueue.put((((result._job, i, mapstar, (x,), {})
for i, x in enumerate(task_batches)), result._set_length))
self._taskqueue.put(
(
self._guarded_task_generation(result._job,
mapstar,
task_batches),
result._set_length
))
return (item for chunk in result for item in chunk)
def imap_unordered(self, func, iterable, chunksize=1):
......@@ -304,15 +327,23 @@ class Pool(object):
raise ValueError("Pool not running")
if chunksize == 1:
result = IMapUnorderedIterator(self._cache)
self._taskqueue.put((((result._job, i, func, (x,), {})
for i, x in enumerate(iterable)), result._set_length))
self._taskqueue.put(
(
self._guarded_task_generation(result._job, func, iterable),
result._set_length
))
return result
else:
assert chunksize > 1
task_batches = Pool._get_tasks(func, iterable, chunksize)
result = IMapUnorderedIterator(self._cache)
self._taskqueue.put((((result._job, i, mapstar, (x,), {})
for i, x in enumerate(task_batches)), result._set_length))
self._taskqueue.put(
(
self._guarded_task_generation(result._job,
mapstar,
task_batches),
result._set_length
))
return (item for chunk in result for item in chunk)
def apply_async(self, func, args=(), kwds={}, callback=None,
......@@ -323,7 +354,7 @@ class Pool(object):
if self._state != RUN:
raise ValueError("Pool not running")
result = ApplyResult(self._cache, callback, error_callback)
self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
self._taskqueue.put(([(result._job, 0, func, args, kwds)], None))
return result
def map_async(self, func, iterable, chunksize=None, callback=None,
......@@ -354,8 +385,14 @@ class Pool(object):
task_batches = Pool._get_tasks(func, iterable, chunksize)
result = MapResult(self._cache, chunksize, len(iterable), callback,
error_callback=error_callback)
self._taskqueue.put((((result._job, i, mapper, (x,), {})
for i, x in enumerate(task_batches)), None))
self._taskqueue.put(
(
self._guarded_task_generation(result._job,
mapper,
task_batches),
None
)
)
return result
@staticmethod
......@@ -377,33 +414,27 @@ class Pool(object):
for taskseq, set_length in iter(taskqueue.get, None):
task = None
i = -1
try:
for i, task in enumerate(taskseq):
# iterating taskseq cannot fail
for task in taskseq:
if thread._state:
util.debug('task handler found thread._state != RUN')
break
try:
put(task)
except Exception as e:
job, ind = task[:2]
job, idx = task[:2]
try:
cache[job]._set(ind, (False, e))
cache[job]._set(idx, (False, e))
except KeyError:
pass
else:
if set_length:
util.debug('doing set_length()')
set_length(i+1)
idx = task[1] if task else -1
set_length(idx + 1)
continue
break
except Exception as ex:
job, ind = task[:2] if task else (0, 0)
if job in cache:
cache[job]._set(ind + 1, (False, ex))
if set_length:
util.debug('doing set_length()')
set_length(i+1)
finally:
task = taskseq = job = None
else:
......
......@@ -1755,6 +1755,8 @@ class CountedObject(object):
class SayWhenError(ValueError): pass
def exception_throwing_generator(total, when):
if when == -1:
raise SayWhenError("Somebody said when")
for i in range(total):
if i == when:
raise SayWhenError("Somebody said when")
......@@ -1833,6 +1835,32 @@ class _TestPool(BaseTestCase):
except multiprocessing.TimeoutError:
self.fail("pool.map_async with chunksize stalled on null list")
def test_map_handle_iterable_exception(self):
if self.TYPE == 'manager':
self.skipTest('test not appropriate for {}'.format(self.TYPE))
# SayWhenError seen at the very first of the iterable
with self.assertRaises(SayWhenError):
self.pool.map(sqr, exception_throwing_generator(1, -1), 1)
# again, make sure it's reentrant
with self.assertRaises(SayWhenError):
self.pool.map(sqr, exception_throwing_generator(1, -1), 1)
with self.assertRaises(SayWhenError):
self.pool.map(sqr, exception_throwing_generator(10, 3), 1)
class SpecialIterable:
def __iter__(self):
return self
def __next__(self):
raise SayWhenError
def __len__(self):
return 1
with self.assertRaises(SayWhenError):
self.pool.map(sqr, SpecialIterable(), 1)
with self.assertRaises(SayWhenError):
self.pool.map(sqr, SpecialIterable(), 1)
def test_async(self):
res = self.pool.apply_async(sqr, (7, TIMEOUT1,))
get = TimingWrapper(res.get)
......@@ -1863,6 +1891,13 @@ class _TestPool(BaseTestCase):
if self.TYPE == 'manager':
self.skipTest('test not appropriate for {}'.format(self.TYPE))
# SayWhenError seen at the very first of the iterable
it = self.pool.imap(sqr, exception_throwing_generator(1, -1), 1)
self.assertRaises(SayWhenError, it.__next__)
# again, make sure it's reentrant
it = self.pool.imap(sqr, exception_throwing_generator(1, -1), 1)
self.assertRaises(SayWhenError, it.__next__)
it = self.pool.imap(sqr, exception_throwing_generator(10, 3), 1)
for i in range(3):
self.assertEqual(next(it), i*i)
......@@ -1889,6 +1924,17 @@ class _TestPool(BaseTestCase):
if self.TYPE == 'manager':
self.skipTest('test not appropriate for {}'.format(self.TYPE))
# SayWhenError seen at the very first of the iterable
it = self.pool.imap_unordered(sqr,
exception_throwing_generator(1, -1),
1)
self.assertRaises(SayWhenError, it.__next__)
# again, make sure it's reentrant
it = self.pool.imap_unordered(sqr,
exception_throwing_generator(1, -1),
1)
self.assertRaises(SayWhenError, it.__next__)
it = self.pool.imap_unordered(sqr,
exception_throwing_generator(10, 3),
1)
......@@ -1970,7 +2016,7 @@ class _TestPool(BaseTestCase):
except Exception as e:
exc = e
else:
raise AssertionError('expected RuntimeError')
self.fail('expected RuntimeError')
self.assertIs(type(exc), RuntimeError)
self.assertEqual(exc.args, (123,))
cause = exc.__cause__
......@@ -1984,6 +2030,17 @@ class _TestPool(BaseTestCase):
sys.excepthook(*sys.exc_info())
self.assertIn('raise RuntimeError(123) # some comment',
f1.getvalue())
# _helper_reraises_exception should not make the error
# a remote exception
with self.Pool(1) as p:
try:
p.map(sqr, exception_throwing_generator(1, -1), 1)
except Exception as e:
exc = e
else:
self.fail('expected SayWhenError')
self.assertIs(type(exc), SayWhenError)
self.assertIs(exc.__cause__, None)
@classmethod
def _test_wrapped_exception(cls):
......
......@@ -291,6 +291,10 @@ Extension Modules
Library
-------
- bpo-28699: Fixed a bug in pools in multiprocessing.pool that raising an
exception at the very first of an iterable may swallow the exception or
make the program hang. Patch by Davin Potts and Xiang Zhang.
- bpo-23890: unittest.TestCase.assertRaises() now manually breaks a reference
cycle to not keep objects alive longer than expected.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment