Commit 48350412 authored by Antoine Pitrou's avatar Antoine Pitrou Committed by GitHub

bpo-29293: multiprocessing.Condition.notify() lacks parameter `n` (#2480)

* bpo-29293: multiprocessing.Condition.notify() lacks parameter `n`

* Add NEWS blurb
parent d3ed2877
...@@ -999,8 +999,8 @@ class ConditionProxy(AcquirerProxy): ...@@ -999,8 +999,8 @@ class ConditionProxy(AcquirerProxy):
_exposed_ = ('acquire', 'release', 'wait', 'notify', 'notify_all') _exposed_ = ('acquire', 'release', 'wait', 'notify', 'notify_all')
def wait(self, timeout=None): def wait(self, timeout=None):
return self._callmethod('wait', (timeout,)) return self._callmethod('wait', (timeout,))
def notify(self): def notify(self, n=1):
return self._callmethod('notify') return self._callmethod('notify', (n,))
def notify_all(self): def notify_all(self):
return self._callmethod('notify_all') return self._callmethod('notify_all')
def wait_for(self, predicate, timeout=None): def wait_for(self, predicate, timeout=None):
......
...@@ -268,24 +268,7 @@ class Condition(object): ...@@ -268,24 +268,7 @@ class Condition(object):
for i in range(count): for i in range(count):
self._lock.acquire() self._lock.acquire()
def notify(self): def notify(self, n=1):
assert self._lock._semlock._is_mine(), 'lock is not owned'
assert not self._wait_semaphore.acquire(False)
# to take account of timeouts since last notify() we subtract
# woken_count from sleeping_count and rezero woken_count
while self._woken_count.acquire(False):
res = self._sleeping_count.acquire(False)
assert res
if self._sleeping_count.acquire(False): # try grabbing a sleeper
self._wait_semaphore.release() # wake up one sleeper
self._woken_count.acquire() # wait for the sleeper to wake
# rezero _wait_semaphore in case a timeout just happened
self._wait_semaphore.acquire(False)
def notify_all(self):
assert self._lock._semlock._is_mine(), 'lock is not owned' assert self._lock._semlock._is_mine(), 'lock is not owned'
assert not self._wait_semaphore.acquire(False) assert not self._wait_semaphore.acquire(False)
...@@ -296,7 +279,7 @@ class Condition(object): ...@@ -296,7 +279,7 @@ class Condition(object):
assert res assert res
sleepers = 0 sleepers = 0
while self._sleeping_count.acquire(False): while sleepers < n and self._sleeping_count.acquire(False):
self._wait_semaphore.release() # wake up one sleeper self._wait_semaphore.release() # wake up one sleeper
sleepers += 1 sleepers += 1
...@@ -308,6 +291,9 @@ class Condition(object): ...@@ -308,6 +291,9 @@ class Condition(object):
while self._wait_semaphore.acquire(False): while self._wait_semaphore.acquire(False):
pass pass
def notify_all(self):
self.notify(n=sys.maxsize)
def wait_for(self, predicate, timeout=None): def wait_for(self, predicate, timeout=None):
result = predicate() result = predicate()
if result: if result:
......
...@@ -948,6 +948,17 @@ class _TestCondition(BaseTestCase): ...@@ -948,6 +948,17 @@ class _TestCondition(BaseTestCase):
woken.release() woken.release()
cond.release() cond.release()
def assertReachesEventually(self, func, value):
for i in range(10):
try:
if func() == value:
break
except NotImplementedError:
break
time.sleep(DELTA)
time.sleep(DELTA)
self.assertReturnsIfImplemented(value, func)
def check_invariant(self, cond): def check_invariant(self, cond):
# this is only supposed to succeed when there are no sleepers # this is only supposed to succeed when there are no sleepers
if self.TYPE == 'processes': if self.TYPE == 'processes':
...@@ -1055,13 +1066,54 @@ class _TestCondition(BaseTestCase): ...@@ -1055,13 +1066,54 @@ class _TestCondition(BaseTestCase):
cond.release() cond.release()
# check they have all woken # check they have all woken
for i in range(10): self.assertReachesEventually(lambda: get_value(woken), 6)
try:
if get_value(woken) == 6: # check state is not mucked up
break self.check_invariant(cond)
except NotImplementedError:
break def test_notify_n(self):
time.sleep(DELTA) cond = self.Condition()
sleeping = self.Semaphore(0)
woken = self.Semaphore(0)
# start some threads/processes
for i in range(3):
p = self.Process(target=self.f, args=(cond, sleeping, woken))
p.daemon = True
p.start()
t = threading.Thread(target=self.f, args=(cond, sleeping, woken))
t.daemon = True
t.start()
# wait for them to all sleep
for i in range(6):
sleeping.acquire()
# check no process/thread has woken up
time.sleep(DELTA)
self.assertReturnsIfImplemented(0, get_value, woken)
# wake some of them up
cond.acquire()
cond.notify(n=2)
cond.release()
# check 2 have woken
self.assertReachesEventually(lambda: get_value(woken), 2)
# wake the rest of them
cond.acquire()
cond.notify(n=4)
cond.release()
self.assertReachesEventually(lambda: get_value(woken), 6)
# doesn't do anything more
cond.acquire()
cond.notify(n=3)
cond.release()
self.assertReturnsIfImplemented(6, get_value, woken) self.assertReturnsIfImplemented(6, get_value, woken)
# check state is not mucked up # check state is not mucked up
......
Add missing parameter "n" on multiprocessing.Condition.notify().
The doc claims multiprocessing.Condition behaves like threading.Condition,
but its notify() method lacked the optional "n" argument (to specify the
number of sleepers to wake up) that threading.Condition.notify() accepts.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment