Commit 4cb814c7 authored by Victor Stinner's avatar Victor Stinner

asyncio, Tulip issue 220: Merge JoinableQueue with Queue.

Merge JoinableQueue with Queue. To more closely match the standard Queue,
asyncio.Queue has "join" and "task_done". JoinableQueue is deleted.

Docstring for Queue.join shouldn't mention threads.

Restore JoinableQueue as a deprecated alias for Queue. To more closely match
the standard Queue, asyncio.Queue has "join" and "task_done".  JoinableQueue
remains as a deprecated alias for Queue to avoid needlessly breaking too much
code that depended on it.

Patch written by A. Jesse Jiryu Davis <jesse@mongodb.com>.
parent 4e82fb99
"""Queues"""
__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue',
'QueueFull', 'QueueEmpty']
__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty',
'JoinableQueue']
import collections
import heapq
......@@ -49,6 +49,9 @@ class Queue:
self._getters = collections.deque()
# Pairs of (item, Future).
self._putters = collections.deque()
self._unfinished_tasks = 0
self._finished = locks.Event(loop=self._loop)
self._finished.set()
self._init(maxsize)
def _init(self, maxsize):
......@@ -59,6 +62,8 @@ class Queue:
def _put(self, item):
self._queue.append(item)
self._unfinished_tasks += 1
self._finished.clear()
def __repr__(self):
return '<{} at {:#x} {}>'.format(
......@@ -75,6 +80,8 @@ class Queue:
result += ' _getters[{}]'.format(len(self._getters))
if self._putters:
result += ' _putters[{}]'.format(len(self._putters))
if self._unfinished_tasks:
result += ' tasks={}'.format(self._unfinished_tasks)
return result
def _consume_done_getters(self):
......@@ -126,9 +133,6 @@ class Queue:
'queue non-empty, why are getters waiting?')
getter = self._getters.popleft()
# Use _put and _get instead of passing item straight to getter, in
# case a subclass has logic that must run (e.g. JoinableQueue).
self._put(item)
# getter cannot be cancelled, we just removed done getters
......@@ -154,9 +158,6 @@ class Queue:
'queue non-empty, why are getters waiting?')
getter = self._getters.popleft()
# Use _put and _get instead of passing item straight to getter, in
# case a subclass has logic that must run (e.g. JoinableQueue).
self._put(item)
# getter cannot be cancelled, we just removed done getters
......@@ -219,6 +220,38 @@ class Queue:
else:
raise QueueEmpty
def task_done(self):
"""Indicate that a formerly enqueued task is complete.
Used by queue consumers. For each get() used to fetch a task,
a subsequent call to task_done() tells the queue that the processing
on the task is complete.
If a join() is currently blocking, it will resume when all items have
been processed (meaning that a task_done() call was received for every
item that had been put() into the queue).
Raises ValueError if called more times than there were items placed in
the queue.
"""
if self._unfinished_tasks <= 0:
raise ValueError('task_done() called too many times')
self._unfinished_tasks -= 1
if self._unfinished_tasks == 0:
self._finished.set()
@coroutine
def join(self):
"""Block until all items in the queue have been gotten and processed.
The count of unfinished tasks goes up whenever an item is added to the
queue. The count goes down whenever a consumer calls task_done() to
indicate that the item was retrieved and all work on it is complete.
When the count of unfinished tasks drops to zero, join() unblocks.
"""
if self._unfinished_tasks > 0:
yield from self._finished.wait()
class PriorityQueue(Queue):
"""A subclass of Queue; retrieves entries in priority order (lowest first).
......@@ -249,54 +282,5 @@ class LifoQueue(Queue):
return self._queue.pop()
class JoinableQueue(Queue):
"""A subclass of Queue with task_done() and join() methods."""
def __init__(self, maxsize=0, *, loop=None):
super().__init__(maxsize=maxsize, loop=loop)
self._unfinished_tasks = 0
self._finished = locks.Event(loop=self._loop)
self._finished.set()
def _format(self):
result = Queue._format(self)
if self._unfinished_tasks:
result += ' tasks={}'.format(self._unfinished_tasks)
return result
def _put(self, item):
super()._put(item)
self._unfinished_tasks += 1
self._finished.clear()
def task_done(self):
"""Indicate that a formerly enqueued task is complete.
Used by queue consumers. For each get() used to fetch a task,
a subsequent call to task_done() tells the queue that the processing
on the task is complete.
If a join() is currently blocking, it will resume when all items have
been processed (meaning that a task_done() call was received for every
item that had been put() into the queue).
Raises ValueError if called more times than there were items placed in
the queue.
"""
if self._unfinished_tasks <= 0:
raise ValueError('task_done() called too many times')
self._unfinished_tasks -= 1
if self._unfinished_tasks == 0:
self._finished.set()
@coroutine
def join(self):
"""Block until all items in the queue have been gotten and processed.
The count of unfinished tasks goes up whenever an item is added to the
queue. The count goes down whenever a consumer thread calls task_done()
to indicate that the item was retrieved and all work on it is complete.
When the count of unfinished tasks drops to zero, join() unblocks.
"""
if self._unfinished_tasks > 0:
yield from self._finished.wait()
JoinableQueue = Queue
"""Deprecated alias for Queue."""
......@@ -408,14 +408,14 @@ class PriorityQueueTests(_QueueTestBase):
self.assertEqual([1, 2, 3], items)
class JoinableQueueTests(_QueueTestBase):
class QueueJoinTests(_QueueTestBase):
def test_task_done_underflow(self):
q = asyncio.JoinableQueue(loop=self.loop)
q = asyncio.Queue(loop=self.loop)
self.assertRaises(ValueError, q.task_done)
def test_task_done(self):
q = asyncio.JoinableQueue(loop=self.loop)
q = asyncio.Queue(loop=self.loop)
for i in range(100):
q.put_nowait(i)
......@@ -452,7 +452,7 @@ class JoinableQueueTests(_QueueTestBase):
self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop))
def test_join_empty_queue(self):
q = asyncio.JoinableQueue(loop=self.loop)
q = asyncio.Queue(loop=self.loop)
# Test that a queue join()s successfully, and before anything else
# (done twice for insurance).
......@@ -465,7 +465,7 @@ class JoinableQueueTests(_QueueTestBase):
self.loop.run_until_complete(join())
def test_format(self):
q = asyncio.JoinableQueue(loop=self.loop)
q = asyncio.Queue(loop=self.loop)
self.assertEqual(q._format(), 'maxsize=0')
q._unfinished_tasks = 2
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment