Commit 3730a17a authored by Richard Oudkerk's avatar Richard Oudkerk

Issue #14059: Implement multiprocessing.Barrier

parent a6bc4b4c
...@@ -226,11 +226,11 @@ However, if you really do need to use some shared data then ...@@ -226,11 +226,11 @@ However, if you really do need to use some shared data then
holds Python objects and allows other processes to manipulate them using holds Python objects and allows other processes to manipulate them using
proxies. proxies.
A manager returned by :func:`Manager` will support types :class:`list`, A manager returned by :func:`Manager` will support types
:class:`dict`, :class:`Namespace`, :class:`Lock`, :class:`RLock`, :class:`list`, :class:`dict`, :class:`Namespace`, :class:`Lock`,
:class:`Semaphore`, :class:`BoundedSemaphore`, :class:`Condition`, :class:`RLock`, :class:`Semaphore`, :class:`BoundedSemaphore`,
:class:`Event`, :class:`Queue`, :class:`Value` and :class:`Array`. For :class:`Condition`, :class:`Event`, :class:`Barrier`,
example, :: :class:`Queue`, :class:`Value` and :class:`Array`. For example, ::
from multiprocessing import Process, Manager from multiprocessing import Process, Manager
...@@ -885,6 +885,12 @@ program as they are in a multithreaded program. See the documentation for ...@@ -885,6 +885,12 @@ program as they are in a multithreaded program. See the documentation for
Note that one can also create synchronization primitives by using a manager Note that one can also create synchronization primitives by using a manager
object -- see :ref:`multiprocessing-managers`. object -- see :ref:`multiprocessing-managers`.
.. class:: Barrier(parties[, action[, timeout]])
A barrier object: a clone of :class:`threading.Barrier`.
.. versionadded:: 3.3
.. class:: BoundedSemaphore([value]) .. class:: BoundedSemaphore([value])
A bounded semaphore object: a clone of :class:`threading.BoundedSemaphore`. A bounded semaphore object: a clone of :class:`threading.BoundedSemaphore`.
...@@ -1280,6 +1286,13 @@ their parent process exits. The manager classes are defined in the ...@@ -1280,6 +1286,13 @@ their parent process exits. The manager classes are defined in the
It also supports creation of shared lists and dictionaries. It also supports creation of shared lists and dictionaries.
.. method:: Barrier(parties[, action[, timeout]])
Create a shared :class:`threading.Barrier` object and return a
proxy for it.
.. versionadded:: 3.3
.. method:: BoundedSemaphore([value]) .. method:: BoundedSemaphore([value])
Create a shared :class:`threading.BoundedSemaphore` object and return a Create a shared :class:`threading.BoundedSemaphore` object and return a
......
...@@ -23,8 +23,8 @@ __all__ = [ ...@@ -23,8 +23,8 @@ __all__ = [
'Manager', 'Pipe', 'cpu_count', 'log_to_stderr', 'get_logger', 'Manager', 'Pipe', 'cpu_count', 'log_to_stderr', 'get_logger',
'allow_connection_pickling', 'BufferTooShort', 'TimeoutError', 'allow_connection_pickling', 'BufferTooShort', 'TimeoutError',
'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Condition', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Condition',
'Event', 'Queue', 'SimpleQueue', 'JoinableQueue', 'Pool', 'Value', 'Array', 'Event', 'Barrier', 'Queue', 'SimpleQueue', 'JoinableQueue', 'Pool',
'RawValue', 'RawArray', 'SUBDEBUG', 'SUBWARNING', 'Value', 'Array', 'RawValue', 'RawArray', 'SUBDEBUG', 'SUBWARNING',
] ]
__author__ = 'R. Oudkerk (r.m.oudkerk@gmail.com)' __author__ = 'R. Oudkerk (r.m.oudkerk@gmail.com)'
...@@ -186,6 +186,13 @@ def Event(): ...@@ -186,6 +186,13 @@ def Event():
from multiprocessing.synchronize import Event from multiprocessing.synchronize import Event
return Event() return Event()
def Barrier(parties, action=None, timeout=None):
'''
Returns a barrier object
'''
from multiprocessing.synchronize import Barrier
return Barrier(parties, action, timeout)
def Queue(maxsize=0): def Queue(maxsize=0):
''' '''
Returns a queue object Returns a queue object
......
...@@ -35,7 +35,7 @@ ...@@ -35,7 +35,7 @@
__all__ = [ __all__ = [
'Process', 'current_process', 'active_children', 'freeze_support', 'Process', 'current_process', 'active_children', 'freeze_support',
'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Condition', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Condition',
'Event', 'Queue', 'Manager', 'Pipe', 'Pool', 'JoinableQueue' 'Event', 'Barrier', 'Queue', 'Manager', 'Pipe', 'Pool', 'JoinableQueue'
] ]
# #
...@@ -49,7 +49,7 @@ import array ...@@ -49,7 +49,7 @@ import array
from multiprocessing.dummy.connection import Pipe from multiprocessing.dummy.connection import Pipe
from threading import Lock, RLock, Semaphore, BoundedSemaphore from threading import Lock, RLock, Semaphore, BoundedSemaphore
from threading import Event, Condition from threading import Event, Condition, Barrier
from queue import Queue from queue import Queue
# #
......
...@@ -993,6 +993,26 @@ class EventProxy(BaseProxy): ...@@ -993,6 +993,26 @@ class EventProxy(BaseProxy):
def wait(self, timeout=None): def wait(self, timeout=None):
return self._callmethod('wait', (timeout,)) return self._callmethod('wait', (timeout,))
class BarrierProxy(BaseProxy):
_exposed_ = ('__getattribute__', 'wait', 'abort', 'reset')
def wait(self, timeout=None):
return self._callmethod('wait', (timeout,))
def abort(self):
return self._callmethod('abort')
def reset(self):
return self._callmethod('reset')
@property
def parties(self):
return self._callmethod('__getattribute__', ('parties',))
@property
def n_waiting(self):
return self._callmethod('__getattribute__', ('n_waiting',))
@property
def broken(self):
return self._callmethod('__getattribute__', ('broken',))
class NamespaceProxy(BaseProxy): class NamespaceProxy(BaseProxy):
_exposed_ = ('__getattribute__', '__setattr__', '__delattr__') _exposed_ = ('__getattribute__', '__setattr__', '__delattr__')
def __getattr__(self, key): def __getattr__(self, key):
...@@ -1084,6 +1104,7 @@ SyncManager.register('Semaphore', threading.Semaphore, AcquirerProxy) ...@@ -1084,6 +1104,7 @@ SyncManager.register('Semaphore', threading.Semaphore, AcquirerProxy)
SyncManager.register('BoundedSemaphore', threading.BoundedSemaphore, SyncManager.register('BoundedSemaphore', threading.BoundedSemaphore,
AcquirerProxy) AcquirerProxy)
SyncManager.register('Condition', threading.Condition, ConditionProxy) SyncManager.register('Condition', threading.Condition, ConditionProxy)
SyncManager.register('Barrier', threading.Barrier, BarrierProxy)
SyncManager.register('Pool', Pool, PoolProxy) SyncManager.register('Pool', Pool, PoolProxy)
SyncManager.register('list', list, ListProxy) SyncManager.register('list', list, ListProxy)
SyncManager.register('dict', dict, DictProxy) SyncManager.register('dict', dict, DictProxy)
......
...@@ -333,3 +333,43 @@ class Event(object): ...@@ -333,3 +333,43 @@ class Event(object):
return False return False
finally: finally:
self._cond.release() self._cond.release()
#
# Barrier
#
class Barrier(threading.Barrier):
def __init__(self, parties, action=None, timeout=None):
import struct
from multiprocessing.heap import BufferWrapper
wrapper = BufferWrapper(struct.calcsize('i') * 2)
cond = Condition()
self.__setstate__((parties, action, timeout, cond, wrapper))
self._state = 0
self._count = 0
def __setstate__(self, state):
(self._parties, self._action, self._timeout,
self._cond, self._wrapper) = state
self._array = self._wrapper.create_memoryview().cast('i')
def __getstate__(self):
return (self._parties, self._action, self._timeout,
self._cond, self._wrapper)
@property
def _state(self):
return self._array[0]
@_state.setter
def _state(self, value):
self._array[0] = value
@property
def _count(self):
return self._array[1]
@_count.setter
def _count(self, value):
self._array[1] = value
...@@ -18,6 +18,7 @@ import array ...@@ -18,6 +18,7 @@ import array
import socket import socket
import random import random
import logging import logging
import struct
import test.support import test.support
...@@ -1056,6 +1057,336 @@ class _TestEvent(BaseTestCase): ...@@ -1056,6 +1057,336 @@ class _TestEvent(BaseTestCase):
p.start() p.start()
self.assertEqual(wait(), True) self.assertEqual(wait(), True)
#
# Tests for Barrier - adapted from tests in test/lock_tests.py
#
# Many of the tests for threading.Barrier use a list as an atomic
# counter: a value is appended to increment the counter, and the
# length of the list gives the value. We use the class DummyList
# for the same purpose.
class _DummyList(object):
def __init__(self):
wrapper = multiprocessing.heap.BufferWrapper(struct.calcsize('i'))
lock = multiprocessing.Lock()
self.__setstate__((wrapper, lock))
self._lengthbuf[0] = 0
def __setstate__(self, state):
(self._wrapper, self._lock) = state
self._lengthbuf = self._wrapper.create_memoryview().cast('i')
def __getstate__(self):
return (self._wrapper, self._lock)
def append(self, _):
with self._lock:
self._lengthbuf[0] += 1
def __len__(self):
with self._lock:
return self._lengthbuf[0]
def _wait():
# A crude wait/yield function not relying on synchronization primitives.
time.sleep(0.01)
class Bunch(object):
"""
A bunch of threads.
"""
def __init__(self, namespace, f, args, n, wait_before_exit=False):
"""
Construct a bunch of `n` threads running the same function `f`.
If `wait_before_exit` is True, the threads won't terminate until
do_finish() is called.
"""
self.f = f
self.args = args
self.n = n
self.started = namespace.DummyList()
self.finished = namespace.DummyList()
self._can_exit = namespace.Value('i', not wait_before_exit)
for i in range(n):
namespace.Process(target=self.task).start()
def task(self):
pid = os.getpid()
self.started.append(pid)
try:
self.f(*self.args)
finally:
self.finished.append(pid)
while not self._can_exit.value:
_wait()
def wait_for_started(self):
while len(self.started) < self.n:
_wait()
def wait_for_finished(self):
while len(self.finished) < self.n:
_wait()
def do_finish(self):
self._can_exit.value = True
class AppendTrue(object):
def __init__(self, obj):
self.obj = obj
def __call__(self):
self.obj.append(True)
class _TestBarrier(BaseTestCase):
"""
Tests for Barrier objects.
"""
N = 5
defaultTimeout = 10.0 # XXX Slow Windows buildbots need generous timeout
def setUp(self):
self.barrier = self.Barrier(self.N, timeout=self.defaultTimeout)
def tearDown(self):
self.barrier.abort()
self.barrier = None
def DummyList(self):
if self.TYPE == 'threads':
return []
elif self.TYPE == 'manager':
return self.manager.list()
else:
return _DummyList()
def run_threads(self, f, args):
b = Bunch(self, f, args, self.N-1)
f(*args)
b.wait_for_finished()
@classmethod
def multipass(cls, barrier, results, n):
m = barrier.parties
assert m == cls.N
for i in range(n):
results[0].append(True)
assert len(results[1]) == i * m
barrier.wait()
results[1].append(True)
assert len(results[0]) == (i + 1) * m
barrier.wait()
try:
assert barrier.n_waiting == 0
except NotImplementedError:
pass
assert not barrier.broken
def test_barrier(self, passes=1):
"""
Test that a barrier is passed in lockstep
"""
results = [self.DummyList(), self.DummyList()]
self.run_threads(self.multipass, (self.barrier, results, passes))
def test_barrier_10(self):
"""
Test that a barrier works for 10 consecutive runs
"""
return self.test_barrier(10)
@classmethod
def _test_wait_return_f(cls, barrier, queue):
res = barrier.wait()
queue.put(res)
def test_wait_return(self):
"""
test the return value from barrier.wait
"""
queue = self.Queue()
self.run_threads(self._test_wait_return_f, (self.barrier, queue))
results = [queue.get() for i in range(self.N)]
self.assertEqual(results.count(0), 1)
@classmethod
def _test_action_f(cls, barrier, results):
barrier.wait()
if len(results) != 1:
raise RuntimeError
def test_action(self):
"""
Test the 'action' callback
"""
results = self.DummyList()
barrier = self.Barrier(self.N, action=AppendTrue(results))
self.run_threads(self._test_action_f, (barrier, results))
self.assertEqual(len(results), 1)
@classmethod
def _test_abort_f(cls, barrier, results1, results2):
try:
i = barrier.wait()
if i == cls.N//2:
raise RuntimeError
barrier.wait()
results1.append(True)
except threading.BrokenBarrierError:
results2.append(True)
except RuntimeError:
barrier.abort()
def test_abort(self):
"""
Test that an abort will put the barrier in a broken state
"""
results1 = self.DummyList()
results2 = self.DummyList()
self.run_threads(self._test_abort_f,
(self.barrier, results1, results2))
self.assertEqual(len(results1), 0)
self.assertEqual(len(results2), self.N-1)
self.assertTrue(self.barrier.broken)
@classmethod
def _test_reset_f(cls, barrier, results1, results2, results3):
i = barrier.wait()
if i == cls.N//2:
# Wait until the other threads are all in the barrier.
while barrier.n_waiting < cls.N-1:
time.sleep(0.001)
barrier.reset()
else:
try:
barrier.wait()
results1.append(True)
except threading.BrokenBarrierError:
results2.append(True)
# Now, pass the barrier again
barrier.wait()
results3.append(True)
def test_reset(self):
"""
Test that a 'reset' on a barrier frees the waiting threads
"""
results1 = self.DummyList()
results2 = self.DummyList()
results3 = self.DummyList()
self.run_threads(self._test_reset_f,
(self.barrier, results1, results2, results3))
self.assertEqual(len(results1), 0)
self.assertEqual(len(results2), self.N-1)
self.assertEqual(len(results3), self.N)
@classmethod
def _test_abort_and_reset_f(cls, barrier, barrier2,
results1, results2, results3):
try:
i = barrier.wait()
if i == cls.N//2:
raise RuntimeError
barrier.wait()
results1.append(True)
except threading.BrokenBarrierError:
results2.append(True)
except RuntimeError:
barrier.abort()
# Synchronize and reset the barrier. Must synchronize first so
# that everyone has left it when we reset, and after so that no
# one enters it before the reset.
if barrier2.wait() == cls.N//2:
barrier.reset()
barrier2.wait()
barrier.wait()
results3.append(True)
def test_abort_and_reset(self):
"""
Test that a barrier can be reset after being broken.
"""
results1 = self.DummyList()
results2 = self.DummyList()
results3 = self.DummyList()
barrier2 = self.Barrier(self.N)
self.run_threads(self._test_abort_and_reset_f,
(self.barrier, barrier2, results1, results2, results3))
self.assertEqual(len(results1), 0)
self.assertEqual(len(results2), self.N-1)
self.assertEqual(len(results3), self.N)
@classmethod
def _test_timeout_f(cls, barrier, results):
i = barrier.wait(20)
if i == cls.N//2:
# One thread is late!
time.sleep(4.0)
try:
barrier.wait(0.5)
except threading.BrokenBarrierError:
results.append(True)
def test_timeout(self):
"""
Test wait(timeout)
"""
results = self.DummyList()
self.run_threads(self._test_timeout_f, (self.barrier, results))
self.assertEqual(len(results), self.barrier.parties)
@classmethod
def _test_default_timeout_f(cls, barrier, results):
i = barrier.wait(20)
if i == cls.N//2:
# One thread is later than the default timeout
time.sleep(4.0)
try:
barrier.wait()
except threading.BrokenBarrierError:
results.append(True)
def test_default_timeout(self):
"""
Test the barrier's default timeout
"""
barrier = self.Barrier(self.N, timeout=1.0)
results = self.DummyList()
self.run_threads(self._test_default_timeout_f, (barrier, results))
self.assertEqual(len(results), barrier.parties)
def test_single_thread(self):
b = self.Barrier(1)
b.wait()
b.wait()
@classmethod
def _test_thousand_f(cls, barrier, passes, conn, lock):
for i in range(passes):
barrier.wait()
with lock:
conn.send(i)
def test_thousand(self):
if self.TYPE == 'manager':
return
passes = 1000
lock = self.Lock()
conn, child_conn = self.Pipe(False)
for j in range(self.N):
p = self.Process(target=self._test_thousand_f,
args=(self.barrier, passes, child_conn, lock))
p.start()
for i in range(passes):
for j in range(self.N):
self.assertEqual(conn.recv(), i)
# #
# #
# #
...@@ -2532,7 +2863,7 @@ class ProcessesMixin(object): ...@@ -2532,7 +2863,7 @@ class ProcessesMixin(object):
Process = multiprocessing.Process Process = multiprocessing.Process
locals().update(get_attributes(multiprocessing, ( locals().update(get_attributes(multiprocessing, (
'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore',
'Condition', 'Event', 'Value', 'Array', 'RawValue', 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'RawValue',
'RawArray', 'current_process', 'active_children', 'Pipe', 'RawArray', 'current_process', 'active_children', 'Pipe',
'connection', 'JoinableQueue', 'Pool' 'connection', 'JoinableQueue', 'Pool'
))) )))
...@@ -2547,7 +2878,7 @@ class ManagerMixin(object): ...@@ -2547,7 +2878,7 @@ class ManagerMixin(object):
manager = object.__new__(multiprocessing.managers.SyncManager) manager = object.__new__(multiprocessing.managers.SyncManager)
locals().update(get_attributes(manager, ( locals().update(get_attributes(manager, (
'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore',
'Condition', 'Event', 'Value', 'Array', 'list', 'dict', 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'list', 'dict',
'Namespace', 'JoinableQueue', 'Pool' 'Namespace', 'JoinableQueue', 'Pool'
))) )))
...@@ -2560,7 +2891,7 @@ class ThreadsMixin(object): ...@@ -2560,7 +2891,7 @@ class ThreadsMixin(object):
Process = multiprocessing.dummy.Process Process = multiprocessing.dummy.Process
locals().update(get_attributes(multiprocessing.dummy, ( locals().update(get_attributes(multiprocessing.dummy, (
'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore',
'Condition', 'Event', 'Value', 'Array', 'current_process', 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'current_process',
'active_children', 'Pipe', 'connection', 'dict', 'list', 'active_children', 'Pipe', 'connection', 'dict', 'list',
'Namespace', 'JoinableQueue', 'Pool' 'Namespace', 'JoinableQueue', 'Pool'
))) )))
......
...@@ -21,6 +21,8 @@ Core and Builtins ...@@ -21,6 +21,8 @@ Core and Builtins
Library Library
------- -------
- Issue #14059: Implement multiprocessing.Barrier.
- Issue #15061: The inappropriately named hmac.secure_compare has been - Issue #15061: The inappropriately named hmac.secure_compare has been
renamed to hmac.compare_digest, restricted to operating on bytes inputs renamed to hmac.compare_digest, restricted to operating on bytes inputs
only and had its documentation updated to more accurately reflect both its only and had its documentation updated to more accurately reflect both its
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment