Commit b3b5175f authored by Julien Muchembled's avatar Julien Muchembled

tests: make it possible to run several threaded clusters at the same time

parent dcbf0b02
......@@ -18,6 +18,7 @@
import os, random, socket, sys, tempfile, threading, time, types, weakref
from collections import deque
from itertools import count
from functools import wraps
from zlib import decompress
from mock import Mock
......@@ -37,12 +38,6 @@ from .. import NeoTestBase, getTempDirectory, setupMySQLdb, \
BIND = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE], 0
LOCAL_IP = socket.inet_pton(ADDRESS_TYPE, IP_VERSION_FORMAT_DICT[ADDRESS_TYPE])
SERVER_TYPE = ['master', 'storage', 'admin']
VIRTUAL_IP = [socket.inet_ntop(ADDRESS_TYPE, LOCAL_IP[:-1] + chr(2 + i))
for i in xrange(len(SERVER_TYPE))]
def getVirtualIp(server_type):
return VIRTUAL_IP[SERVER_TYPE.index(server_type)]
class Serialized(object):
......@@ -57,12 +52,12 @@ class Serialized(object):
cls.pending = 0
@classmethod
def release(cls, lock=None, wake_other=True, stop=False):
def release(cls, lock=None, wake_other=True, stop=None):
"""Suspend lock owner and resume first suspended thread"""
if lock is None:
lock = cls._global_lock
if stop: # XXX: we should fix ClusterStates.STOPPING
cls.pending = None
cls.pending = frozenset(stop)
else:
cls.pending = 0
try:
......@@ -86,10 +81,10 @@ class Serialized(object):
if lock is None:
lock = cls._global_lock
lock.acquire()
if cls.pending is None: # XXX
if type(cls.pending) is frozenset: # XXX
if lock is cls._global_lock:
cls.pending = 0
else:
elif threading.currentThread() in cls.pending:
sys.exit()
if cls._pdb:
cls._pdb = False
......@@ -143,7 +138,7 @@ class SerializedEventManager(EventManager):
self.writer_set):
return
else:
if self.writer_set and Serialized.pending is not None:
if self.writer_set and Serialized.pending == 0:
Serialized.pending = 1
# Jump to another thread before polling, so that when a message is
# sent on the network, one can debug immediately the receiving part.
......@@ -154,7 +149,7 @@ class SerializedEventManager(EventManager):
Serialized.tic(self._lock)
if timeout != 0:
timeout = self._timeout
if timeout != 0 and Serialized.pending:
if timeout != 0 and Serialized.pending == 1:
Serialized.pending = timeout = 0
EventManager._poll(self, timeout)
......@@ -173,25 +168,50 @@ class Node(object):
class ServerNode(Node):
_server_class_dict = {}
class __metaclass__(type):
def __init__(cls, name, bases, d):
type.__init__(cls, name, bases, d)
if Node not in bases and threading.Thread not in cls.__mro__:
cls.__bases__ = bases + (threading.Thread,)
cls.node_type = getattr(NodeTypes, name[:-11].upper())
cls._node_list = []
cls._virtual_ip = socket.inet_ntop(ADDRESS_TYPE,
LOCAL_IP[:-1] + chr(2 + len(cls._server_class_dict)))
cls._server_class_dict[cls._virtual_ip] = cls
@classmethod
def newAddress(cls):
address = cls._virtual_ip, len(cls._node_list)
cls._node_list.append(None)
return address
@classmethod
def resolv(cls, address):
try:
cls = cls._server_class_dict[address[0]]
except KeyError:
return address
return cls._node_list[address[1]].getListeningAddress()
@SerializedEventManager.decorate
def __init__(self, cluster, address, **kw):
self._init_args = (cluster, address), dict(kw)
def __init__(self, cluster, address=None, **kw):
if not address:
address = self.newAddress()
port = address[1]
self._node_list[port] = weakref.proxy(self)
self._init_args = (cluster, address), kw.copy()
threading.Thread.__init__(self)
self.daemon = True
h, p = address
self.node_type = getattr(NodeTypes,
SERVER_TYPE[VIRTUAL_IP.index(h)].upper())
self.node_name = '%s_%u' % (self.node_type, p)
self.node_name = '%s_%u' % (self.node_type, port)
kw.update(getCluster=cluster.name, getBind=address,
getMasters=parseMasterList(cluster.master_nodes, address))
super(ServerNode, self).__init__(Mock(kw))
def getVirtualAddress(self):
return self._init_args[0][1]
def resetNode(self):
assert not self.isAlive()
args, kw = self._init_args
......@@ -321,12 +341,12 @@ class ClientApplication(Node, neo.client.app.Application):
class NeoCTL(neo.neoctl.app.NeoCTL):
@SerializedEventManager.decorate
def __init__(self, cluster, address=(getVirtualIp('admin'), 0)):
def __init__(self, cluster):
self._cluster = cluster
super(NeoCTL, self).__init__(address)
super(NeoCTL, self).__init__(cluster.admin.getVirtualAddress())
self.em._timeout = -1
server = property(lambda self: self._cluster.resolv(self._server),
server = property(lambda self: ServerNode.resolv(self._server),
lambda self, address: setattr(self, '_server', address))
......@@ -441,16 +461,24 @@ class NEOCluster(object):
SocketConnector.makeListeningConnection)
SocketConnector_send = staticmethod(SocketConnector.send)
Storage__init__ = staticmethod(Storage.__init__)
_patch_count = 0
_resource_dict = weakref.WeakValueDictionary()
_patched = threading.Lock()
def _allocate(self, resource, new):
result = resource, new()
while result in self._resource_dict:
result = resource, new()
self._resource_dict[result] = self
return result[1]
def _patch(cluster):
cls = cluster.__class__
if not cls._patched.acquire(0):
raise RuntimeError("Can't run several cluster at the same time")
cls._patch_count += 1
if cls._patch_count > 1:
return
def makeClientConnection(self, addr):
real_addr = ServerNode.resolv(addr)
try:
real_addr = cluster.resolv(addr)
return cls.SocketConnector_makeClientConnection(self, real_addr)
finally:
self.remote_addr = addr
......@@ -468,9 +496,14 @@ class NEOCluster(object):
cls.SocketConnector_makeListeningConnection(self, BIND)
SocketConnector.send = send
Storage.setupLog = lambda *args, **kw: None
Serialized.init()
@classmethod
def _unpatch(cls):
assert cls._patch_count > 0
cls._patch_count -= 1
if cls._patch_count:
return
bootstrap.sleep = time.sleep
BaseConnection.checkTimeout = cls.BaseConnection_checkTimeout
SocketConnector.makeClientConnection = \
......@@ -479,7 +512,6 @@ class NEOCluster(object):
cls.SocketConnector_makeListeningConnection
SocketConnector.send = cls.SocketConnector_send
Storage.setupLog = setupLog
cls._patched.release()
def __init__(self, master_count=1, partitions=1, replicas=0,
adapter=os.getenv('NEO_TESTS_ADAPTER', 'BTree'),
......@@ -492,27 +524,27 @@ class NEOCluster(object):
log_file = tempfile.mkstemp('.log', '', temp_dir)[1]
print 'Logging to %r' % log_file
setupLog(LoggerThreadName(), log_file, verbose)
self.name = 'neo_%s' % random.randint(0, 100)
ip = getVirtualIp('master')
self.master_nodes = ' '.join('%s:%s' % (ip, i)
for i in xrange(master_count))
self.name = 'neo_%s' % self._allocate('name',
lambda: random.randint(0, 100))
master_list = [MasterApplication.newAddress()
for _ in xrange(master_count)]
self.master_nodes = ' '.join('%s:%s' % x for x in master_list)
weak_self = weakref.proxy(self)
kw = dict(cluster=weak_self, getReplicas=replicas, getAdapter=adapter,
getPartitions=partitions, getReset=clear_databases)
self.master_list = [MasterApplication(address=(ip, i), **kw)
for i in xrange(master_count)]
ip = getVirtualIp('storage')
self.master_list = [MasterApplication(address=x, **kw)
for x in master_list]
if db_list is None:
if storage_count is None:
storage_count = replicas + 1
db_list = ['%s%u' % (DB_PREFIX, i) for i in xrange(storage_count)]
index = count().next
db_list = ['%s%u' % (DB_PREFIX, self._allocate('db', index))
for _ in xrange(storage_count)]
setupMySQLdb(db_list, db_user, db_password, clear_databases)
db = '%s:%s@%%s' % (db_user, db_password)
self.storage_list = [StorageApplication(address=(ip, i),
getDatabase=db % x, **kw)
for i, x in enumerate(db_list)]
ip = getVirtualIp('admin')
self.admin_list = [AdminApplication(address=(ip, 0), **kw)]
self.storage_list = [StorageApplication(getDatabase=db % x, **kw)
for x in db_list]
self.admin_list = [AdminApplication(**kw)]
self.client = ClientApplication(weak_self)
self.neoctl = NeoCTL(weak_self)
......@@ -531,16 +563,8 @@ class NEOCluster(object):
return admin
###
def resolv(self, addr):
host, port = addr
try:
attr = SERVER_TYPE[VIRTUAL_IP.index(host)] + '_list'
except ValueError:
return addr
return getattr(self, attr)[port].getListeningAddress()
def reset(self, clear_database=False):
for node_type in SERVER_TYPE:
for node_type in 'master', 'storage', 'admin':
kw = {}
if node_type == 'storage':
kw['clear_database'] = clear_database
......@@ -551,7 +575,6 @@ class NEOCluster(object):
def start(self, storage_list=None, fast_startup=False):
self._patch()
Serialized.init()
for node_type in 'master', 'admin':
for node in getattr(self, node_type + '_list'):
node.start()
......@@ -566,7 +589,8 @@ class NEOCluster(object):
if not fast_startup:
self._startCluster()
self.tic()
assert self.neoctl.getClusterState() == ClusterStates.RUNNING
state = self.neoctl.getClusterState()
assert state == ClusterStates.RUNNING, state
self.enableStorageList(storage_list)
def _startCluster(self):
......@@ -598,8 +622,9 @@ class NEOCluster(object):
self.__dict__.pop('_db', self.client).close()
#self.neoctl.setClusterState(ClusterStates.STOPPING) # TODO
try:
Serialized.release(stop=1)
for node_type in SERVER_TYPE[::-1]:
Serialized.release(stop=
self.admin_list + self.storage_list + self.master_list)
for node_type in 'admin', 'storage', 'master':
for node in getattr(self, node_type + '_list'):
if node.isAlive():
node.join()
......
......@@ -411,6 +411,21 @@ class Test(NEOThreadedTest):
finally:
cluster.stop()
def test2Clusters(self):
cluster1 = NEOCluster()
cluster2 = NEOCluster()
try:
cluster1.start()
cluster2.start()
t1, c1 = cluster1.getTransaction()
t2, c2 = cluster2.getTransaction()
c1.root()['1'] = c2.root()['2'] = ''
t1.commit()
t2.commit()
finally:
cluster1.stop()
cluster2.stop()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment