Commit 3e1ed6a4 authored by Julien Muchembled's avatar Julien Muchembled

Simplify polling thread in threaded apps

It's been a long time that the polling thread never ends and don't need to be
restarted. On the other side, there will be a need for the admin to define a
different polling loop, hence the move from threaded_poll to threaded_app.
parent f5f42522
......@@ -215,6 +215,7 @@ class Application(ThreadedApplication):
Lookup for the current primary master node
"""
logging.debug('connecting to primary master...')
self.start()
index = -1
ask = self._ask
handler = self.primary_bootstrap_handler
......
......@@ -764,7 +764,6 @@ class MTClientConnection(ClientConnection):
def __init__(self, *args, **kwargs):
self.lock = lock = RLock()
self.dispatcher = kwargs.pop('dispatcher')
self.dispatcher.needPollThread()
with lock:
super(MTClientConnection, self).__init__(*args, **kwargs)
......
......@@ -43,13 +43,12 @@ def giant_lock(func):
class Dispatcher:
"""Register a packet, connection pair as expecting a response packet."""
def __init__(self, poll_thread=None):
def __init__(self):
self.message_table = {}
self.queue_dict = {}
lock = Lock()
self.lock_acquire = lock.acquire
self.lock_release = lock.release
self.poll_thread = poll_thread
@giant_lock
def dispatch(self, conn, msg_id, packet, kw):
......@@ -81,14 +80,9 @@ class Dispatcher:
except KeyError:
queue_dict[queue_id] = 1
def needPollThread(self):
self.poll_thread.start()
@giant_lock
def register(self, conn, msg_id, queue):
"""Register an expectation for a reply."""
if self.poll_thread is not None:
self.needPollThread()
self.message_table.setdefault(id(conn), {})[msg_id] = queue
self._increfQueue(queue)
......
......@@ -23,7 +23,6 @@ from .event import EventManager
from .locking import SimpleQueue
from .node import NodeManager
from .protocol import Packets
from .threaded_poll import ThreadedPoll, psThreadedPoll
class app_set(weakref.WeakSet):
......@@ -48,11 +47,11 @@ class ThreadedApplication(object):
def __init__(self, master_nodes, name, dynamic_master_list=None):
# Start polling thread
self.em = EventManager()
self.poll_thread = ThreadedPoll(self.em, name=name)
psThreadedPoll()
self.poll_thread = threading.Thread(target=self.run, name=name)
self.poll_thread.daemon = True
# Internal Attributes common to all thread
self.name = name
self.dispatcher = Dispatcher(self.poll_thread)
self.dispatcher = Dispatcher()
self.nm = NodeManager(dynamic_master_list)
self.master_conn = None
......@@ -78,8 +77,27 @@ class ThreadedApplication(object):
conn.close()
# Stop polling thread
logging.debug('Stopping %s', self.poll_thread)
self.poll_thread.stop()
psThreadedPoll()
self.em.wakeup(True)
def start(self):
self.poll_thread.is_alive() or self.poll_thread.start()
def run(self):
logging.debug("Started %s", self.poll_thread)
try:
self._run()
finally:
logging.debug("Poll thread stopped")
def _run(self):
poll = self.em.poll
while 1:
try:
while 1:
poll(1)
except Exception:
self.log()
logging.error("poll raised, retrying", exc_info=1)
def getHandlerData(self):
return self._thread_container.answer
......
#
# Copyright (C) 2006-2015 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from threading import Thread, enumerate as thread_enum
from . import logging
from .locking import Lock
class _ThreadedPoll(Thread):
"""Polling thread."""
stopping = False
def __init__(self, em, **kw):
Thread.__init__(self, **kw)
self.em = em
self.daemon = True
def run(self):
logging.debug('Started %s', self)
try:
while 1:
try:
self.em.poll(1)
except Exception:
logging.error('poll raised, retrying', exc_info=1)
finally:
logging.debug('Threaded poll stopped')
def stop(self):
self.stopping = True
self.em.wakeup(True)
class ThreadedPoll(object):
"""
Wrapper for polloing thread, just to be able to start it again when
it stopped.
"""
_thread = None
_started = False
def __init__(self, *args, **kw):
lock = Lock()
self._status_lock_acquire = lock.acquire
self._status_lock_release = lock.release
self._args = args
self._kw = kw
self.newThread()
def newThread(self):
self._thread = _ThreadedPoll(*self._args, **self._kw)
def start(self):
"""
Start thread if not started or restart it if it's shutting down.
"""
# TODO: a refcount-based approach would be better, but more intrusive.
self._status_lock_acquire()
try:
thread = self._thread
if thread.stopping:
# XXX: ideally, we should wake thread up here, to be sure not
# to wait forever.
thread.join()
if not thread.is_alive():
if self._started:
self.newThread()
else:
self._started = True
self._thread.start()
finally:
self._status_lock_release()
def stop(self):
self._status_lock_acquire()
try:
self._thread.stop()
finally:
self._status_lock_release()
def __getattr__(self, key):
return getattr(self._thread, key)
def __repr__(self):
return repr(self._thread)
def psThreadedPoll(log=None):
"""
Logs alive ThreadedPoll threads.
"""
if log is None:
log = logging.debug
for thread in thread_enum():
if not isinstance(thread, ThreadedPoll):
continue
log('Thread %s at 0x%x, %s', thread.getName(), id(thread),
thread._stop.isSet() and 'stopping' or 'running')
......@@ -797,6 +797,7 @@ class ClientApplicationTests(NeoUnitTestBase):
# faked environnement
app.em = Mock({'getConnectionList': []})
app.pt = Mock({ 'operational': False})
app.start = lambda: None
app.master_conn = app._connectToPrimaryNode()
self.assertEqual(len(all_passed), 1)
self.assertTrue(app.master_conn is not None)
......
......@@ -24,8 +24,7 @@ class DispatcherTests(NeoTestBase):
def setUp(self):
NeoTestBase.setUp(self)
self.fake_thread = Mock({'stopping': True})
self.dispatcher = Dispatcher(self.fake_thread)
self.dispatcher = Dispatcher()
def testRegister(self):
conn = object()
......@@ -38,7 +37,6 @@ class DispatcherTests(NeoTestBase):
self.assertEqual(queue.get(block=False), (conn, MARKER, {}))
self.assertTrue(queue.empty())
self.assertFalse(self.dispatcher.dispatch(conn, 2, None, {}))
self.assertEqual(len(self.fake_thread.mockGetNamedCalls('start')), 1)
def testUnregister(self):
conn = object()
......
......@@ -29,7 +29,6 @@ import transaction, ZODB
import neo.admin.app, neo.master.app, neo.storage.app
import neo.client.app, neo.neoctl.app
from neo.client import Storage
from neo.lib.threaded_poll import _ThreadedPoll
from neo.lib import logging
from neo.lib.connection import BaseConnection, Connection
from neo.lib.connector import SocketConnector, \
......@@ -167,6 +166,9 @@ class Serialized(object):
def __init__(self, app, busy=True):
self._epoll = app.em.epoll
app.em.epoll = self
# XXX: It may have been initialized before the SimpleQueue is patched.
thread_container = getattr(app, '_thread_container', None)
thread_container is None or thread_container.__init__()
if busy:
self._busy.add(self) # block tic until app waits for polling
......@@ -370,6 +372,17 @@ class ClientApplication(Node, neo.client.app.Application):
def __init__(self, master_nodes, name, **kw):
super(ClientApplication, self).__init__(master_nodes, name, **kw)
self.poll_thread.node_name = name
def run(self):
try:
super(ClientApplication, self).run()
finally:
self.em.epoll.exit()
def start(self):
isinstance(self.em.epoll, Serialized) or Serialized(self)
super(ClientApplication, self).start()
def getConnectionList(self, *peers):
for peer in peers:
......@@ -399,9 +412,8 @@ class LoggerThreadName(str):
return id(self)
def __str__(self):
t = threading.currentThread()
try:
return t.name if isinstance(t, _ThreadedPoll) else t.node_name
return threading.currentThread().node_name
except AttributeError:
return str.__str__(self)
......@@ -497,8 +509,6 @@ class NEOCluster(object):
SimpleQueue__init__ = staticmethod(SimpleQueue.__init__)
SocketConnector_bind = staticmethod(SocketConnector._bind)
SocketConnector_connect = staticmethod(SocketConnector._connect)
_ThreadedPoll_run = staticmethod(_ThreadedPoll.run)
_ThreadedPoll_start = staticmethod(_ThreadedPoll.start)
_patch_count = 0
_resource_dict = weakref.WeakValueDictionary()
......@@ -525,14 +535,6 @@ class NEOCluster(object):
return True
return lock(False)
self._lock = _lock
def start(self):
Serialized(self)
cls._ThreadedPoll_start(self)
def run(self):
try:
cls._ThreadedPoll_run(self)
finally:
self.em.epoll.exit()
BaseConnection.getTimeout = lambda self: None
SimpleQueue.__init__ = __init__
SocketConnector.CONNECT_LIMIT = 0
......@@ -540,8 +542,6 @@ class NEOCluster(object):
cls.SocketConnector_bind(self, BIND)
SocketConnector._connect = lambda self, addr: \
cls.SocketConnector_connect(self, ServerNode.resolv(addr))
_ThreadedPoll.run = run
_ThreadedPoll.start = start
Serialized.init()
@staticmethod
......@@ -556,8 +556,6 @@ class NEOCluster(object):
SocketConnector.CONNECT_LIMIT = cls.CONNECT_LIMIT
SocketConnector._bind = cls.SocketConnector_bind
SocketConnector._connect = cls.SocketConnector_connect
_ThreadedPoll.run = cls._ThreadedPoll_run
_ThreadedPoll.start = cls._ThreadedPoll_start
Serialized.stop()
def __init__(self, master_count=1, partitions=1, replicas=0, upstream=None,
......@@ -650,7 +648,6 @@ class NEOCluster(object):
def start(self, storage_list=None, fast_startup=False):
self._patch()
self.client._thread_container.__init__()
for node_type in 'master', 'admin':
for node in getattr(self, node_type + '_list'):
node.start()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment