Commit 849ccebe authored by David Wilson's avatar David Wilson

receiver: only permit one notify callback

There is no point spamming a list for every function call, there is no
use case where multiple notify callbacks would be useful.
parent 265d9f02
......@@ -283,9 +283,10 @@ def _queue_interruptible_get(queue, timeout=None, block=True):
class Receiver(object):
notify = None
def __init__(self, router, handle=None, persist=True, respondent=None):
self.router = router
self.notify = []
self.handle = handle # Avoid __repr__ crash in add_handler()
self.handle = router.add_handler(self._on_receive, handle,
persist, respondent)
......@@ -298,8 +299,8 @@ class Receiver(object):
"""Callback from the Stream; appends data to the internal queue."""
IOLOG.debug('%r._on_receive(%r)', self, msg)
self._queue.put(msg)
for func in self.notify:
func(self)
if self.notify:
self.notify(self)
def close(self):
self._queue.put(_DEAD)
......
......@@ -237,18 +237,19 @@ class SelectError(mitogen.core.Error):
class Select(object):
notify = None
def __init__(self, receivers=(), oneshot=True):
self._receivers = []
self._oneshot = oneshot
self._queue = Queue.Queue()
self.notify = []
for recv in receivers:
self.add(recv)
def _put(self, value):
self._queue.put(value)
for func in self.notify:
func(self)
if self.notify:
self.notify(self)
def __bool__(self):
return bool(self._receivers)
......@@ -276,12 +277,17 @@ class Select(object):
if isinstance(recv_, Select):
recv_._check_no_loop(recv)
owned_msg = 'Cannot add: Receiver is already owned by another Select'
def add(self, recv):
if isinstance(recv, Select):
recv._check_no_loop(self)
self._receivers.append(recv)
recv.notify.append(self._put)
if recv.notify is not None:
raise SelectError(self.owned_msg)
recv.notify = self._put
# Avoid race by polling once after installation.
if not recv.empty():
self._put(recv)
......@@ -290,8 +296,10 @@ class Select(object):
def remove(self, recv):
try:
if recv.notify != self._put:
raise ValueError
self._receivers.remove(recv)
recv.notify.remove(self._put)
recv.notify = None
except (IndexError, ValueError):
raise SelectError(self.not_present_msg)
......
......@@ -14,8 +14,7 @@ class AddTest(testlib.RouterMixin, testlib.TestCase):
select.add(recv)
self.assertEquals(1, len(select._receivers))
self.assertEquals(recv, select._receivers[0])
self.assertEquals(1, len(recv.notify))
self.assertEquals(select._put, recv.notify[0])
self.assertEquals(select._put, recv.notify)
def test_channel(self):
context = self.router.local()
......@@ -24,8 +23,7 @@ class AddTest(testlib.RouterMixin, testlib.TestCase):
select.add(chan)
self.assertEquals(1, len(select._receivers))
self.assertEquals(chan, select._receivers[0])
self.assertEquals(1, len(chan.notify))
self.assertEquals(select._put, chan.notify[0])
self.assertEquals(select._put, chan.notify)
def test_subselect_empty(self):
select = self.klass()
......@@ -33,8 +31,7 @@ class AddTest(testlib.RouterMixin, testlib.TestCase):
select.add(subselect)
self.assertEquals(1, len(select._receivers))
self.assertEquals(subselect, select._receivers[0])
self.assertEquals(1, len(subselect.notify))
self.assertEquals(select._put, subselect.notify[0])
self.assertEquals(select._put, subselect.notify)
def test_subselect_nonempty(self):
recv = mitogen.core.Receiver(self.router)
......@@ -45,8 +42,7 @@ class AddTest(testlib.RouterMixin, testlib.TestCase):
select.add(subselect)
self.assertEquals(1, len(select._receivers))
self.assertEquals(subselect, select._receivers[0])
self.assertEquals(1, len(subselect.notify))
self.assertEquals(select._put, subselect.notify[0])
self.assertEquals(select._put, subselect.notify)
def test_subselect_loop_direct(self):
select = self.klass()
......@@ -65,6 +61,22 @@ class AddTest(testlib.RouterMixin, testlib.TestCase):
lambda: s2.add(s0))
self.assertEquals(str(exc), self.klass.loop_msg)
def test_double_add_receiver(self):
select = self.klass()
recv = mitogen.core.Receiver(self.router)
select.add(recv)
exc = self.assertRaises(mitogen.master.SelectError,
lambda: select.add(recv))
self.assertEquals(str(exc), self.klass.owned_msg)
def test_double_add_subselect(self):
select = self.klass()
select2 = self.klass()
select.add(select2)
exc = self.assertRaises(mitogen.master.SelectError,
lambda: select.add(select2))
self.assertEquals(str(exc), self.klass.owned_msg)
class RemoveTest(testlib.RouterMixin, testlib.TestCase):
klass = mitogen.master.Select
......@@ -91,7 +103,7 @@ class RemoveTest(testlib.RouterMixin, testlib.TestCase):
select.add(recv)
select.remove(recv)
self.assertEquals(0, len(select._receivers))
self.assertEquals(0, len(recv.notify))
self.assertEquals(None, recv.notify)
class CloseTest(testlib.RouterMixin, testlib.TestCase):
......@@ -107,12 +119,11 @@ class CloseTest(testlib.RouterMixin, testlib.TestCase):
select.add(recv)
self.assertEquals(1, len(select._receivers))
self.assertEquals(1, len(recv.notify))
self.assertEquals(select._put, recv.notify[0])
self.assertEquals(select._put, recv.notify)
select.close()
self.assertEquals(0, len(select._receivers))
self.assertEquals(0, len(recv.notify))
self.assertEquals(None, recv.notify)
def test_one_subselect(self):
select = self.klass()
......@@ -123,16 +134,15 @@ class CloseTest(testlib.RouterMixin, testlib.TestCase):
subselect.add(recv)
self.assertEquals(1, len(select._receivers))
self.assertEquals(1, len(recv.notify))
self.assertEquals(subselect._put, recv.notify[0])
self.assertEquals(subselect._put, recv.notify)
select.close()
self.assertEquals(0, len(select._receivers))
self.assertEquals(1, len(recv.notify))
self.assertEquals(subselect._put, recv.notify)
subselect.close()
self.assertEquals(0, len(recv.notify))
self.assertEquals(None, recv.notify)
class EmptyTest(testlib.RouterMixin, testlib.TestCase):
......@@ -186,7 +196,7 @@ class OneShotTest(testlib.RouterMixin, testlib.TestCase):
recv, (msg_, data) = select.get()
self.assertEquals(msg, msg_)
self.assertEquals(0, len(select._receivers))
self.assertEquals(0, len(recv.notify))
self.assertEquals(None, recv.notify)
def test_false_persists_after_get(self):
recv = mitogen.core.Receiver(self.router)
......@@ -196,8 +206,7 @@ class OneShotTest(testlib.RouterMixin, testlib.TestCase):
self.assertEquals((recv, (msg, '123')), select.get())
self.assertEquals(1, len(select._receivers))
self.assertEquals(recv, select._receivers[0])
self.assertEquals(1, len(recv.notify))
self.assertEquals(select._put, recv.notify[0])
self.assertEquals(select._put, recv.notify)
class GetTest(testlib.RouterMixin, testlib.TestCase):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment