Commit d68cfb1c authored by Vincent Pelletier's avatar Vincent Pelletier

Add a method on dispatcher to know if a given queue is registered to it.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@1775 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 49781cbb
...@@ -32,6 +32,7 @@ class Dispatcher: ...@@ -32,6 +32,7 @@ class Dispatcher:
def __init__(self): def __init__(self):
self.message_table = {} self.message_table = {}
self.queue_dict = {}
lock = Lock() lock = Lock()
self.lock_acquire = lock.acquire self.lock_acquire = lock.acquire
self.lock_release = lock.release self.lock_release = lock.release
...@@ -42,6 +43,7 @@ class Dispatcher: ...@@ -42,6 +43,7 @@ class Dispatcher:
queue = self.message_table.get(id(conn), EMPTY).pop(msg_id, None) queue = self.message_table.get(id(conn), EMPTY).pop(msg_id, None)
if queue is None: if queue is None:
return False return False
self.queue_dict[id(queue)] -= 1
queue.put(data) queue.put(data)
return True return True
...@@ -49,6 +51,12 @@ class Dispatcher: ...@@ -49,6 +51,12 @@ class Dispatcher:
def register(self, conn, msg_id, queue): def register(self, conn, msg_id, queue):
"""Register an expectation for a reply.""" """Register an expectation for a reply."""
self.message_table.setdefault(id(conn), {})[msg_id] = queue self.message_table.setdefault(id(conn), {})[msg_id] = queue
queue_dict = self.queue_dict
key = id(queue)
try:
queue_dict[key] += 1
except KeyError:
queue_dict[key] = 1
def unregister(self, conn): def unregister(self, conn):
""" Unregister a connection and put fake packet in queues to unlock """ Unregister a connection and put fake packet in queues to unlock
...@@ -59,13 +67,19 @@ class Dispatcher: ...@@ -59,13 +67,19 @@ class Dispatcher:
finally: finally:
self.lock_release() self.lock_release()
notified_set = set() notified_set = set()
queue_dict = self.queue_dict
for queue in message_table.itervalues(): for queue in message_table.itervalues():
queue_id = id(queue) queue_id = id(queue)
if queue_id not in notified_set: if queue_id not in notified_set:
queue.put((conn, None)) queue.put((conn, None))
notified_set.add(queue_id) notified_set.add(queue_id)
queue_dict[queue_id] -= 1
def registered(self, conn): def registered(self, conn):
"""Check if a connection is registered into message table.""" """Check if a connection is registered into message table."""
return len(self.message_table.get(id(conn), EMPTY)) != 0 return len(self.message_table.get(id(conn), EMPTY)) != 0
@giant_lock
def pending(self, queue):
return not queue.empty() or self.queue_dict[id(queue)] > 0
...@@ -64,6 +64,49 @@ class DispatcherTests(unittest.TestCase): ...@@ -64,6 +64,49 @@ class DispatcherTests(unittest.TestCase):
self.assertFalse(self.dispatcher.registered(conn1)) self.assertFalse(self.dispatcher.registered(conn1))
self.assertFalse(self.dispatcher.registered(conn2)) self.assertFalse(self.dispatcher.registered(conn2))
def testPending(self):
conn1 = object()
conn2 = object()
class Queue(object):
_empty = True
def empty(self):
return self._empty
def put(self, value):
pass
queue1 = Queue()
queue2 = Queue()
self.dispatcher.register(conn1, 1, queue1)
self.assertTrue(self.dispatcher.pending(queue1))
self.dispatcher.register(conn2, 2, queue1)
self.assertTrue(self.dispatcher.pending(queue1))
self.dispatcher.register(conn2, 3, queue2)
self.assertTrue(self.dispatcher.pending(queue1))
self.assertTrue(self.dispatcher.pending(queue2))
self.dispatcher.dispatch(conn1, 1, None)
self.assertTrue(self.dispatcher.pending(queue1))
self.assertTrue(self.dispatcher.pending(queue2))
self.dispatcher.dispatch(conn2, 2, None)
self.assertFalse(self.dispatcher.pending(queue1))
self.assertTrue(self.dispatcher.pending(queue2))
queue1._empty = False
self.assertTrue(self.dispatcher.pending(queue1))
queue1._empty = True
self.dispatcher.register(conn1, 4, queue1)
self.dispatcher.register(conn2, 5, queue1)
self.assertTrue(self.dispatcher.pending(queue1))
self.assertTrue(self.dispatcher.pending(queue2))
self.dispatcher.unregister(conn2)
self.assertTrue(self.dispatcher.pending(queue1))
self.assertFalse(self.dispatcher.pending(queue2))
self.dispatcher.unregister(conn1)
self.assertFalse(self.dispatcher.pending(queue1))
self.assertFalse(self.dispatcher.pending(queue2))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment