Commit d455a507 authored by Guido van Rossum's avatar Guido van Rossum

Also rewrote the guts of asyncio.Semaphore (patch by manipopopo).

parent d94c1b92
......@@ -411,6 +411,13 @@ class Semaphore(_ContextManagerMixin):
extra = '{},waiters:{}'.format(extra, len(self._waiters))
return '<{} [{}]>'.format(res[1:-1], extra)
def _wake_up_next(self):
while self._waiters:
waiter = self._waiters.popleft()
if not waiter.done():
waiter.set_result(None)
return
def locked(self):
"""Returns True if semaphore can not be acquired immediately."""
return self._value == 0
......@@ -425,18 +432,19 @@ class Semaphore(_ContextManagerMixin):
called release() to make it larger than 0, and then return
True.
"""
if not self._waiters and self._value > 0:
self._value -= 1
return True
fut = futures.Future(loop=self._loop)
self._waiters.append(fut)
try:
yield from fut
self._value -= 1
return True
finally:
self._waiters.remove(fut)
while self._value <= 0:
fut = futures.Future(loop=self._loop)
self._waiters.append(fut)
try:
yield from fut
except:
# See the similar code in Queue.get.
fut.cancel()
if self._value > 0 and not fut.cancelled():
self._wake_up_next()
raise
self._value -= 1
return True
def release(self):
"""Release a semaphore, incrementing the internal counter by one.
......@@ -444,10 +452,7 @@ class Semaphore(_ContextManagerMixin):
become larger than zero again, wake up that coroutine.
"""
self._value += 1
for waiter in self._waiters:
if not waiter.done():
waiter.set_result(True)
break
self._wake_up_next()
class BoundedSemaphore(Semaphore):
......
......@@ -7,7 +7,6 @@ import re
import asyncio
from asyncio import test_utils
STR_RGX_REPR = (
r'^<(?P<class>.*?) object at (?P<address>.*?)'
r'\[(?P<extras>'
......@@ -783,22 +782,20 @@ class SemaphoreTests(test_utils.TestCase):
test_utils.run_briefly(self.loop)
self.assertEqual(0, sem._value)
self.assertEqual([1, 2, 3], result)
self.assertEqual(3, len(result))
self.assertTrue(sem.locked())
self.assertEqual(1, len(sem._waiters))
self.assertEqual(0, sem._value)
self.assertTrue(t1.done())
self.assertTrue(t1.result())
self.assertTrue(t2.done())
self.assertTrue(t2.result())
self.assertTrue(t3.done())
self.assertTrue(t3.result())
self.assertFalse(t4.done())
race_tasks = [t2, t3, t4]
done_tasks = [t for t in race_tasks if t.done() and t.result()]
self.assertTrue(2, len(done_tasks))
# cleanup locked semaphore
sem.release()
self.loop.run_until_complete(t4)
self.loop.run_until_complete(asyncio.gather(*race_tasks))
def test_acquire_cancel(self):
sem = asyncio.Semaphore(loop=self.loop)
......@@ -809,7 +806,44 @@ class SemaphoreTests(test_utils.TestCase):
self.assertRaises(
asyncio.CancelledError,
self.loop.run_until_complete, acquire)
self.assertFalse(sem._waiters)
self.assertTrue((not sem._waiters) or
all(waiter.done() for waiter in sem._waiters))
def test_acquire_cancel_before_awoken(self):
sem = asyncio.Semaphore(value=0, loop=self.loop)
t1 = asyncio.Task(sem.acquire(), loop=self.loop)
t2 = asyncio.Task(sem.acquire(), loop=self.loop)
t3 = asyncio.Task(sem.acquire(), loop=self.loop)
t4 = asyncio.Task(sem.acquire(), loop=self.loop)
test_utils.run_briefly(self.loop)
sem.release()
t1.cancel()
t2.cancel()
test_utils.run_briefly(self.loop)
num_done = sum(t.done() for t in [t3, t4])
self.assertEqual(num_done, 1)
t3.cancel()
t4.cancel()
test_utils.run_briefly(self.loop)
def test_acquire_hang(self):
sem = asyncio.Semaphore(value=0, loop=self.loop)
t1 = asyncio.Task(sem.acquire(), loop=self.loop)
t2 = asyncio.Task(sem.acquire(), loop=self.loop)
test_utils.run_briefly(self.loop)
sem.release()
t1.cancel()
test_utils.run_briefly(self.loop)
self.assertTrue(sem.locked())
def test_release_not_acquired(self):
sem = asyncio.BoundedSemaphore(loop=self.loop)
......
......@@ -81,7 +81,8 @@ Library
- Issue #25034: Fix string.Formatter problem with auto-numbering and
nested format_specs. Patch by Anthon van der Neut.
- Issue #25233: Rewrite the guts of asyncio.Queue to be more understandable and correct.
- Issue #25233: Rewrite the guts of asyncio.Queue and
asyncio.Semaphore to be more understandable and correct.
- Issue #23600: Default implementation of tzinfo.fromutc() was returning
wrong results in some cases.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment