Commit 7fac1696 authored by Julien Muchembled's avatar Julien Muchembled

tests: some cleanup in threaded.__init__

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2787 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent e0aa8ef3
......@@ -20,7 +20,6 @@ import os, random, socket, sys, tempfile, threading, time, types
from collections import deque
from functools import wraps
from Queue import Queue, Empty
from weakref import ref as weak_ref
from mock import Mock
import transaction, ZODB
import neo.admin.app, neo.master.app, neo.storage.app
......@@ -48,66 +47,68 @@ def getVirtualIp(server_type):
class Serialized(object):
_global_lock = threading.Lock()
_global_lock.acquire()
@classmethod
def init(cls):
cls._global_lock = threading.Lock()
cls._global_lock.acquire()
# TODO: use something else than Queue, for inspection or editing
# (e.g. we'd like to suspend nodes temporarily)
_lock_list = Queue()
_pdb = False
pending = 0
cls._lock_list = Queue()
cls._pdb = False
cls.pending = 0
@staticmethod
def release(lock=None, wake_other=True, stop=False):
@classmethod
def release(cls, lock=None, wake_other=True, stop=False):
"""Suspend lock owner and resume first suspended thread"""
if lock is None:
lock = Serialized._global_lock
lock = cls._global_lock
if stop: # XXX: we should fix ClusterStates.STOPPING
Serialized.pending = None
cls.pending = None
else:
Serialized.pending = 0
cls.pending = 0
try:
sys._getframe(1).f_trace.im_self.set_continue()
Serialized._pdb = True
cls._pdb = True
except AttributeError:
pass
q = Serialized._lock_list
q = cls._lock_list
q.put(lock)
if wake_other:
q.get().release()
@staticmethod
def acquire(lock=None):
@classmethod
def acquire(cls, lock=None):
"""Suspend all threads except lock owner"""
if lock is None:
lock = Serialized._global_lock
lock = cls._global_lock
lock.acquire()
if Serialized.pending is None: # XXX
if lock is Serialized._global_lock:
Serialized.pending = 0
if cls.pending is None: # XXX
if lock is cls._global_lock:
cls.pending = 0
else:
sys.exit()
if Serialized._pdb:
Serialized._pdb = False
if cls._pdb:
cls._pdb = False
try:
sys.stdout.write(threading.currentThread().node_name)
except AttributeError:
pass
pdb(1)
@staticmethod
def tic(lock=None):
@classmethod
def tic(cls, lock=None):
# switch to another thread
# (the following calls are not supposed to be debugged into)
Serialized.release(lock); Serialized.acquire(lock)
cls.release(lock); cls.acquire(lock)
@staticmethod
def background():
@classmethod
def background(cls):
try:
Serialized._lock_list.get(0).release()
cls._lock_list.get(0).release()
except Empty:
pass
class SerializedEventManager(Serialized, EventManager):
class SerializedEventManager(EventManager):
_lock = None
_timeout = 0
......@@ -147,7 +148,7 @@ class SerializedEventManager(Serialized, EventManager):
# before the first message is sent.
# TODO: Detect where a message is sent to jump immediately to nodes
# that will do something.
self.tic(self._lock)
Serialized.tic(self._lock)
if timeout != 0:
timeout = self._timeout
if timeout != 0 and Serialized.pending:
......@@ -294,15 +295,13 @@ class NEOCluster(object):
SocketConnector_send = staticmethod(SocketConnector.send)
Storage__init__ = staticmethod(Storage.__init__)
_cluster = None
_patched = threading.Lock()
@classmethod
def patch(cls):
def _patch(cluster):
cls = cluster.__class__
if not cls._patched.acquire(0):
raise RuntimeError("Can't run several cluster at the same time")
def makeClientConnection(self, addr):
# XXX: 'threading.currentThread()._cluster'
# does not work for client. We could monkey-patch
# ClientConnection instead of using a global variable.
cluster = cls._cluster()
try:
real_addr = cluster.resolv(addr)
return cls.SocketConnector_makeClientConnection(self, real_addr)
......@@ -314,11 +313,6 @@ class NEOCluster(object):
return result
# TODO: 'sleep' should 'tic' in a smart way, so that storages can be
# safely started even if the cluster isn't.
def sleep(seconds):
l = threading.currentThread().em._lock
while Serialized.pending:
Serialized.tic(l)
Serialized.tic(l)
bootstrap.sleep = lambda seconds: None
BaseConnection.checkTimeout = lambda self, t: None
SocketConnector.makeClientConnection = makeClientConnection
......@@ -328,7 +322,7 @@ class NEOCluster(object):
Storage.setupLog = lambda *args, **kw: None
@classmethod
def unpatch(cls):
def _unpatch(cls):
bootstrap.sleep = time.sleep
BaseConnection.checkTimeout = cls.BaseConnection_checkTimeout
SocketConnector.makeClientConnection = \
......@@ -337,6 +331,7 @@ class NEOCluster(object):
cls.SocketConnector_makeListeningConnection
SocketConnector.send = cls.SocketConnector_send
Storage.setupLog = setupLog
cls._patched.release()
def __init__(self, master_count=1, partitions=1, replicas=0,
adapter=os.getenv('NEO_TESTS_ADAPTER', 'BTree'),
......@@ -405,7 +400,8 @@ class NEOCluster(object):
self.neoctl = NeoCTL(self)
def start(self, storage_list=None, fast_startup=True):
self.__class__._cluster = weak_ref(self)
self._patch()
Serialized.init()
for node_type in 'master', 'admin':
for node in getattr(self, node_type + '_list'):
node.start()
......@@ -448,7 +444,7 @@ class NEOCluster(object):
node.join()
finally:
Serialized.acquire()
self.__class__._cluster = None
self._unpatch()
def tic(self, force=False):
if force:
......@@ -483,11 +479,3 @@ class NEOThreadedTest(NeoUnitTestBase):
def setupLog(self):
log_file = os.path.join(getTempDirectory(), self.id() + '.log')
setupLog(LoggerThreadName(), log_file, True)
def setUp(self):
NeoUnitTestBase.setUp(self)
NEOCluster.patch()
def tearDown(self):
NEOCluster.unpatch()
NeoUnitTestBase.tearDown(self)
......@@ -45,16 +45,8 @@ class MatrixImportBenchmark(BenchmarkRunner):
if storages[-1] < max_s:
storages.append(max_s)
replicas = range(min_r, max_r + 1)
if self._config.threaded:
from neo.tests.threaded import NEOCluster
NEOCluster.patch() # XXX ugly
try:
result_list = [self.runMatrix(storages, replicas)
for x in xrange(self._config.repeat)]
finally:
if self._config.threaded:
from neo.tests.threaded import NEOCluster
NEOCluster.unpatch()# XXX ugly
results = {}
for s in storages:
results[s] = z = {}
......@@ -84,7 +76,7 @@ class MatrixImportBenchmark(BenchmarkRunner):
datafs = 'PROD1'
import random, neo.tests.stat_zodb
dfs_storage = getattr(neo.tests.stat_zodb, datafs)(
random.Random(0)).as_storage(10000)
random.Random(0)).as_storage(100)
print "Import of %s with m=%s, s=%s, r=%s, p=%s" % (
datafs, masters, storages, replicas, partitions)
# cluster
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment