#
# Copyright (C) 2011-2016  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

# XXX: Consider using ClusterStates.STOPPING to stop clusters

import os, random, select, socket, sys, tempfile
import thread, threading, time, traceback, weakref
from collections import deque
from ConfigParser import SafeConfigParser
from contextlib import contextmanager
from itertools import count
from functools import wraps
from thread import get_ident
from zlib import decompress
from mock import Mock
import transaction, ZODB
import neo.admin.app, neo.master.app, neo.storage.app
import neo.client.app, neo.neoctl.app
from neo.client import Storage
from neo.lib import logging
from neo.lib.connection import BaseConnection, \
    ClientConnection, Connection, ConnectionClosed, ListeningConnection
from neo.lib.connector import SocketConnector, ConnectorException
from neo.lib.handler import EventHandler
from neo.lib.locking import SimpleQueue
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes
from neo.lib.util import cached_property, parseMasterList, p64
from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \
    ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_SOCKET, DB_USER

BIND = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE], 0
LOCAL_IP = socket.inet_pton(ADDRESS_TYPE, IP_VERSION_FORMAT_DICT[ADDRESS_TYPE])


class LockLock(object):
    """Double lock used as synchronisation point between 2 threads

    Used to wait until a slave thread has reached a specific location, and to
    keep it suspended there. The slave resumes when __exit__ is called.
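
    Example (an illustrative sketch only; the slave function is hypothetical):

        sync = LockLock()
        def slave():
            do_something()
            sync()          # suspend here until the main thread exits 'with'
            do_something_else()
        with sync:
            t = threading.Thread(target=slave)
            t.start()
            sync()          # wait until the slave reaches its own sync() call
        t.join()            # the slave was resumed by __exit__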
    """

    def __init__(self):
        self._l = threading.Lock(), threading.Lock()

    def __call__(self):
        """Define synchronisation point for both threads"""
        if self._owner == thread.get_ident():
            self._l[0].acquire()
        else:
            self._l[0].release()
            self._l[1].acquire()

    def __enter__(self):
        self._owner = thread.get_ident()
        for l in self._l:
            l.acquire(0)
        return self

    def __exit__(self, t, v, tb):
        try:
            self._l[1].release()
        except thread.error:
            pass


class FairLock(deque):
    """Same as a threading.Lock except that waiting threads are queued, so that
    the first one waiting for the lock is the first to get it. This is useful
    when several concurrent threads fight for the same resource in loop:
    the owner could give too little time for other to get a chance to acquire,
    blocking them for a long time with bad luck.
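
    Example (an illustrative sketch; the worker function is hypothetical):

        lock = FairLock()
        def worker():
            for _ in xrange(100):
                with lock:      # waiters are served in FIFO order
                    use_shared_resource()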
    """

    def __enter__(self, _allocate_lock=threading.Lock):
        me = _allocate_lock()
        me.acquire()
        self.append(me)
        other = self[0]
        while me is not other:
            with other:
                other = self[0]

    def __exit__(self, t, v, tb):
        self.popleft().release()


class Serialized(object):
    """
    "Threaded" tests run all nodes in the same process as the test itself,
    and threads are scheduled by this class, which mainly provides 2 features:
    - more determinism, by minimizing the number of active threads and
      switching them in a round-robin fashion;
    - a tic() method that waits only as long as necessary for the cluster
      to become idle.

    The basic concept is that each thread has a lock that is always acquired
    by the thread itself. The following pattern is used to yield the processor
    to the next thread:
        release(); acquire()
    Note that this is not atomic: all other threads may complete before a
    thread tries to acquire its own lock again. So that the previous thread
    does not fail by releasing a lock that was never acquired, Semaphores are
    actually used instead of Locks.

    The epoll object of each node is hooked so that thread switching happens
    before polling for network activity. An extra epoll object is used to
    detect which node has a readable epoll object.

    XXX: It seems wrong to rely only on epoll as a way to know if there are
         pending network messages. I had rare random failures due to tic()
         returning prematurely.
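
    Example of the test-facing side (an illustrative sketch; it assumes a
    started NEOCluster, defined below):

        cluster.start()
        # ... trigger some asynchronous activity on the nodes ...
        Serialized.tic()    # run the nodes until the whole cluster is idle
        assert cluster.neoctl.getClusterState() == ClusterStates.RUNNING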
    """
    check_timeout = False

    @classmethod
    def init(cls):
        cls._busy = set()
        cls._busy_cond = threading.Condition(threading.Lock())
        cls._epoll = select.epoll()
        cls._pdb = None
        cls._sched_lock = threading.Semaphore(0)
        cls._tic_lock = FairLock()
        cls._fd_dict = {}

    @classmethod
    def idle(cls, owner):
        with cls._busy_cond:
            cls._busy.discard(owner)
            cls._busy_cond.notify_all()

    @classmethod
    def stop(cls):
        assert not cls._fd_dict, ("file descriptor leak (%r)\nThis may happen"
            " when a test fails, in which case you can see the real exception"
            " by disabling this one." % cls._fd_dict)
        del(cls._busy, cls._busy_cond, cls._epoll, cls._fd_dict,
            cls._pdb, cls._sched_lock, cls._tic_lock)

    @classmethod
    def _sort_key(cls, fd_event):
        return -cls._fd_dict[fd_event[0]]._last

    @classmethod
    @contextmanager
    def pdb(cls):
        try:
            cls._pdb = sys._getframe(2).f_trace.im_self
            cls._pdb.set_continue()
        except AttributeError:
            pass
        yield
        p = cls._pdb
        if p is not None:
            cls._pdb = None
            t = threading.currentThread()
            p.stdout.write(getattr(t, 'node_name', t.name))
            p.set_trace(sys._getframe(3))

    @classmethod
    def tic(cls, step=-1, check_timeout=(), quiet=False):
        # If you're in a pdb here, 'n' switches to another thread
        # (the following lines are not supposed to be debugged into)
        with cls._tic_lock, cls.pdb():
            if not quiet:
                f = sys._getframe(1)
                try:
                    logging.info('tic (%s:%u) ...',
                        f.f_code.co_filename, f.f_lineno)
                finally:
                    del f
            if cls._busy:
                with cls._busy_cond:
                    while cls._busy:
                        cls._busy_cond.wait()
            for app in check_timeout:
                app.em.epoll.check_timeout = True
                app.em.wakeup()
                del app
            while step:
                event_list = cls._epoll.poll(0)
                if not event_list:
                    break
                step -= 1
                event_list.sort(key=cls._sort_key)
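                # Build a wake-up chain: the last node of the sorted list is
                # woken up first; when it blocks again in poll(), it releases
                # the lock of the previous node, and so on, until the first
                # node releases _sched_lock, which lets tic() continue.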
                next_lock = cls._sched_lock
                for fd, event in event_list:
                    self = cls._fd_dict[fd]
                    self._release_next = next_lock.release
                    next_lock = self._lock
                del self
                next_lock.release()
                cls._sched_lock.acquire()

    def __init__(self, app, busy=True):
        self._epoll = app.em.epoll
        app.em.epoll = self
        # XXX: The thread container may have been initialized before
        #      SimpleQueue was patched.
        thread_container = getattr(app, '_thread_container', None)
        thread_container is None or thread_container.__init__()
        if busy:
            self._busy.add(self) # block tic until app waits for polling

    def __getattr__(self, attr):
        if attr in ('close', 'modify', 'register', 'unregister'):
            return getattr(self._epoll, attr)
        return self.__getattribute__(attr)

    def poll(self, timeout):
        if self.check_timeout:
            assert timeout >= 0, (self, timeout)
            del self.check_timeout
        elif timeout:
            with self.pdb(): # same as in tic()
                release = self._release_next
                self._release_next = None
                release()
                self._lock.acquire()
                self._last = time.time()
        return self._epoll.poll(timeout)

    def _release_next(self):
        self._last = time.time()
        self._lock = threading.Semaphore(0)
        fd = self._epoll.fileno()
        cls = self.__class__
        cls._fd_dict[fd] = self
        cls._epoll.register(fd)
        cls.idle(self)

    def exit(self):
        fd = self._epoll.fileno()
        cls = self.__class__
        if cls._fd_dict.pop(fd, None) is None:
            cls.idle(self)
        else:
            cls._epoll.unregister(fd)
            self._release_next()

class TestSerialized(Serialized):

    def __init__(*args):
        Serialized.__init__(busy=False, *args)

    def poll(self, timeout):
        if timeout:
            for x in xrange(1000):
                r = self._epoll.poll(0)
                if r:
                    return r
                Serialized.tic(step=1)
            raise Exception("tic is looping forever")
        return self._epoll.poll(timeout)


class Node(object):

    def getConnectionList(self, *peers):
        addr = lambda c: c and (c.addr if c.is_server else c.getAddress())
        addr_set = {addr(c.connector) for peer in peers
            for c in peer.em.connection_dict.itervalues()
            if isinstance(c, Connection)}
        addr_set.discard(None)
        return (c for c in self.em.connection_dict.itervalues()
            if isinstance(c, Connection) and addr(c.connector) in addr_set)

    def filterConnection(self, *peers):
        return ConnectionFilter(self.getConnectionList(*peers))

class ServerNode(Node):

    _server_class_dict = {}

    class __metaclass__(type):
        def __init__(cls, name, bases, d):
            if Node not in bases and threading.Thread not in cls.__mro__:
                cls.__bases__ = bases + (threading.Thread,)
                cls.node_type = getattr(NodeTypes, name[:-11].upper())
                cls._node_list = []
                cls._virtual_ip = socket.inet_ntop(ADDRESS_TYPE,
                    LOCAL_IP[:-1] + chr(2 + len(cls._server_class_dict)))
                cls._server_class_dict[cls._virtual_ip] = cls

    @staticmethod
    def resetPorts():
        for cls in ServerNode._server_class_dict.itervalues():
            del cls._node_list[:]

    @classmethod
    def newAddress(cls):
        address = cls._virtual_ip, len(cls._node_list)
        cls._node_list.append(None)
        return address

    @classmethod
    def resolv(cls, address):
        try:
            cls = cls._server_class_dict[address[0]]
        except KeyError:
            return address
        return cls._node_list[address[1]].getListeningAddress()
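
    # Illustrative note: each ServerNode subclass gets its own loopback IP
    # (e.g. 127.0.0.2, 127.0.0.3, ... under IPv4, in class definition order)
    # and a virtual address uses the node's index in _node_list as port:
    #   MasterApplication.newAddress()  # -> (MasterApplication._virtual_ip, 0)
    #   ServerNode.resolv((ip, 0))      # -> real address the node listens on
    # resolv() is plugged into SocketConnector._connect by NEOCluster._patches.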

    def __init__(self, cluster=None, address=None, **kw):
        if not address:
            address = self.newAddress()
        if cluster is None:
            master_nodes = kw.get('master_nodes', ())
            name = kw.get('name', 'test')
        else:
            master_nodes = kw.get('master_nodes', cluster.master_nodes)
            name = kw.get('name', cluster.name)
        port = address[1]
        self._node_list[port] = weakref.proxy(self)
        self._init_args = init_args = kw.copy()
        init_args['cluster'] = cluster
        init_args['address'] = address
        threading.Thread.__init__(self)
        self.daemon = True
        self.node_name = '%s_%u' % (self.node_type, port)
        kw.update(getCluster=name, getBind=address,
            getMasters=master_nodes and parseMasterList(master_nodes, address))
        super(ServerNode, self).__init__(Mock(kw))

    def getVirtualAddress(self):
        return self._init_args['address']

    def resetNode(self):
        assert not self.is_alive()
        kw = self._init_args
        self.close()
        self.__init__(**kw)

    def start(self):
        Serialized(self)
        threading.Thread.start(self)

    def run(self):
        try:
            super(ServerNode, self).run()
        finally:
            self._afterRun()
            logging.debug('stopping %r', self)
            self.em.epoll.exit()

    def _afterRun(self):
        try:
            self.listening_conn.close()
            self.listening_conn = None
        except AttributeError:
            pass

    def getListeningAddress(self):
        try:
            return self.listening_conn.getAddress()
        except AttributeError:
            raise ConnectorException

    def stop(self):
        self.em.wakeup(True)

class AdminApplication(ServerNode, neo.admin.app.Application):
    pass

class MasterApplication(ServerNode, neo.master.app.Application):
    pass

class StorageApplication(ServerNode, neo.storage.app.Application):

    dm = type('', (), {'close': lambda self: None})()

    def resetNode(self, clear_database=False):
        self._init_args['getReset'] = clear_database
        super(StorageApplication, self).resetNode()

    def _afterRun(self):
        super(StorageApplication, self)._afterRun()
        try:
            self.dm.close()
            del self.dm
        except StandardError: # AttributeError & ProgrammingError
            pass
        if self.master_conn:
            self.master_conn.close()

    def getAdapter(self):
        return self._init_args['getAdapter']

    def getDataLockInfo(self):
        dm = self.dm
        index = tuple(dm.query("SELECT id, hash, compression FROM data"))
        assert set(dm._uncommitted_data).issubset(x[0] for x in index)
        get = dm._uncommitted_data.get
        return {(str(h), c & 0x7f): get(i, 0) for i, h, c in index}

    def sqlCount(self, table):
        (r,), = self.dm.query("SELECT COUNT(*) FROM " + table)
        return r

class ClientApplication(Node, neo.client.app.Application):

    def __init__(self, master_nodes, name, **kw):
        super(ClientApplication, self).__init__(master_nodes, name, **kw)
        self.poll_thread.node_name = name

    def _run(self):
        try:
            super(ClientApplication, self)._run()
        finally:
            self.em.epoll.exit()

    def start(self):
        isinstance(self.em.epoll, Serialized) or Serialized(self)
        super(ClientApplication, self).start()

    def getConnectionList(self, *peers):
        for peer in peers:
            if isinstance(peer, MasterApplication):
                conn = self._getMasterConnection()
            else:
                assert isinstance(peer, StorageApplication)
                conn = self.cp.getConnForNode(self.nm.getByUUID(peer.uuid))
            yield conn

class NeoCTL(neo.neoctl.app.NeoCTL):

    def __init__(self, *args, **kw):
        super(NeoCTL, self).__init__(*args, **kw)
        TestSerialized(self)


class LoggerThreadName(str):
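    """str that dynamically renders as the node_name of the current thread

    Attribute access is delegated to the current value, which falls back to
    the default for threads that have no node_name (e.g. the main test thread).
    """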

    def __new__(cls, default='TEST'):
        return str.__new__(cls, default)

    def __getattribute__(self, attr):
        return getattr(str(self), attr)

    def __hash__(self):
        return id(self)

    def __str__(self):
        try:
            return threading.currentThread().node_name
        except AttributeError:
            return str.__str__(self)


class ConnectionFilter(object):

    filtered_count = 0
    filter_list = []
    filter_queue = weakref.WeakKeyDictionary()
    lock = threading.RLock()
    _addPacket = Connection._addPacket

    @contextmanager
    def __new__(cls, conn_list=()):
        self = object.__new__(cls)
        self.filter_dict = {}
        self.conn_list = frozenset(conn_list)
        if not cls.filter_list:
            def _addPacket(conn, packet):
                with cls.lock:
                    try:
                        queue = cls.filter_queue[conn]
                    except KeyError:
                        for self in cls.filter_list:
                            if self(conn, packet):
                                self.filtered_count += 1
                                break
                        else:
                            return cls._addPacket(conn, packet)
                        cls.filter_queue[conn] = queue = deque()
                    p = packet.__new__(packet.__class__)
                    p.__dict__.update(packet.__dict__)
                    queue.append(p)
            Connection._addPacket = _addPacket
        try:
            cls.filter_list.append(self)
            yield self
        finally:
            del cls.filter_list[-1:]
            if not cls.filter_list:
                Connection._addPacket = cls._addPacket.im_func
        with cls.lock:
            cls._retry()

    def __call__(self, conn, packet):
        if not self.conn_list or conn in self.conn_list:
            for filter in self.filter_dict:
                if filter(conn, packet):
                    return True
        return False

    @classmethod
    def _retry(cls):
        for conn, queue in cls.filter_queue.items():
            while queue:
                packet = queue.popleft()
                for self in cls.filter_list:
                    if self(conn, packet):
                        queue.appendleft(packet)
                        break
                else:
                    if conn.isClosed():
                        return
                    cls._addPacket(conn, packet)
                    continue
                break
            else:
                del cls.filter_queue[conn]

    def add(self, filter, *patches):
        with self.lock:
            self.filter_dict[filter] = patches
            for p in patches:
                p.apply()

    def remove(self, *filters):
        with self.lock:
            for filter in filters:
                del self.filter_dict[filter]
            self._retry()

    def discard(self, *filters):
        try:
            self.remove(*filters)
        except KeyError:
            pass

    def __contains__(self, filter):
        return filter in self.filter_dict
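
# Illustrative sketch of how tests typically use ConnectionFilter (the packet
# type below is only an example; Packets would come from neo.lib.protocol):
#
#   with cluster.master.filterConnection(cluster.storage) as f:
#       f.add(lambda conn, packet:
#           isinstance(packet, Packets.NotifyPartitionChanges))
#       ...     # matching packets are delayed (queued) instead of being sent
#   # on exit, the filter is removed and the delayed packets are re-sent
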

class NEOCluster(object):

    SSL = None

    def __init__(orig, self): # temporary definition for SimpleQueue patch
        orig(self)
        lock = self._lock
        def _lock(blocking=True):
            if blocking:
                logging.info('<SimpleQueue>._lock.acquire()')
                while not lock(False):
                    Serialized.tic(step=1, quiet=True)
                return True
            return lock(False)
        self._lock = _lock
    _patches = (
        Patch(BaseConnection, getTimeout=lambda orig, self: None),
        Patch(SimpleQueue, __init__=__init__),
        Patch(SocketConnector, CONNECT_LIMIT=0),
        Patch(SocketConnector, _bind=lambda orig, self, addr: orig(self, BIND)),
        Patch(SocketConnector, _connect = lambda orig, self, addr:
            orig(self, ServerNode.resolv(addr))))
    _patch_count = 0
    _resource_dict = weakref.WeakValueDictionary()

    def _allocate(self, resource, new):
        result = resource, new()
        while result in self._resource_dict:
            result = resource, new()
        self._resource_dict[result] = self
        return result[1]

    @staticmethod
    def _patch():
        cls = NEOCluster
        cls._patch_count += 1
        if cls._patch_count > 1:
            return
        for patch in cls._patches:
            patch.apply()
        Serialized.init()

    @staticmethod
    def _unpatch():
        cls = NEOCluster
        assert cls._patch_count > 0
        cls._patch_count -= 1
        if cls._patch_count:
            return
        for patch in cls._patches:
            patch.revert()
        Serialized.stop()

    def __init__(self, master_count=1, partitions=1, replicas=0, upstream=None,
                       adapter=os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
                       storage_count=None, db_list=None, clear_databases=True,
                       db_user=DB_USER, db_password='', compress=True,
                       importer=None, autostart=None):
        self.name = 'neo_%s' % self._allocate('name',
            lambda: random.randint(0, 100))
        self.compress = compress
        master_list = [MasterApplication.newAddress()
                       for _ in xrange(master_count)]
        self.master_nodes = ' '.join('%s:%s' % x for x in master_list)
        weak_self = weakref.proxy(self)
        kw = dict(cluster=weak_self, getReplicas=replicas, getAdapter=adapter,
                  getPartitions=partitions, getReset=clear_databases,
                  getSSL=self.SSL)
        if upstream is not None:
            self.upstream = weakref.proxy(upstream)
            kw.update(getUpstreamCluster=upstream.name,
                getUpstreamMasters=parseMasterList(upstream.master_nodes))
        self.master_list = [MasterApplication(getAutostart=autostart,
                                              address=x, **kw)
                            for x in master_list]
        if db_list is None:
            if storage_count is None:
                storage_count = replicas + 1
            index = count().next
            db_list = ['%s%u' % (DB_PREFIX, self._allocate('db', index))
                       for _ in xrange(storage_count)]
        if adapter == 'MySQL':
            setupMySQLdb(db_list, db_user, db_password, clear_databases)
            db = '%s:%s@%%s%s' % (db_user, db_password, DB_SOCKET)
        elif adapter == 'SQLite':
            db = os.path.join(getTempDirectory(), '%s.sqlite')
        else:
            assert False, adapter
        if importer:
            cfg = SafeConfigParser()
            cfg.add_section("neo")
            cfg.set("neo", "adapter", adapter)
            cfg.set("neo", "database", db % tuple(db_list))
            for name, zodb in importer:
                cfg.add_section(name)
                for x in zodb.iteritems():
                    cfg.set(name, *x)
            db = os.path.join(getTempDirectory(), '%s.conf')
            with open(db % tuple(db_list), "w") as f:
                cfg.write(f)
            kw["getAdapter"] = "Importer"
        self.storage_list = [StorageApplication(getDatabase=db % x, **kw)
                             for x in db_list]
        self.admin_list = [AdminApplication(**kw)]
        self.neoctl = NeoCTL(self.admin.getVirtualAddress(), ssl=self.SSL)

    def __repr__(self):
        return "<%s(%s) at 0x%x>" % (self.__class__.__name__,
                                     self.name, id(self))

    # A few shortcuts that work when there's only 1 master/storage/admin
    @property
    def master(self):
        master, = self.master_list
        return master
    @property
    def storage(self):
        storage, = self.storage_list
        return storage
    @property
    def admin(self):
        admin, = self.admin_list
        return admin
    ###

    @property
    def primary_master(self):
        master, = [master for master in self.master_list if master.primary]
        return master

    def reset(self, clear_database=False):
        for node_type in 'master', 'storage', 'admin':
            kw = {}
            if node_type == 'storage':
                kw['clear_database'] = clear_database
            for node in getattr(self, node_type + '_list'):
                node.resetNode(**kw)
        self.neoctl.close()
        self.neoctl = NeoCTL(self.admin.getVirtualAddress(), ssl=self.SSL)

    def start(self, storage_list=None, fast_startup=False):
        self._patch()
        for node_type in 'master', 'admin':
            for node in getattr(self, node_type + '_list'):
                node.start()
        Serialized.tic()
        if fast_startup:
            self.startCluster()
        if storage_list is None:
            storage_list = self.storage_list
        for node in storage_list:
            node.start()
        Serialized.tic()
        if not fast_startup:
            self.startCluster()
            Serialized.tic()
        state = self.neoctl.getClusterState()
        assert state in (ClusterStates.RUNNING, ClusterStates.BACKINGUP), state
        self.enableStorageList(storage_list)

    def newClient(self):
        return ClientApplication(name=self.name, master_nodes=self.master_nodes,
                                 compress=self.compress, ssl=self.SSL)

    @cached_property
    def client(self):
        client = self.newClient()
        # Make sure client won't be reused after it was closed.
        def close():
            client = self.client
            del self.client, client.close
            client.close()
        client.close = close
        return client

    @cached_property
    def db(self):
        return ZODB.DB(storage=self.getZODBStorage())

    def startCluster(self):
        try:
            self.neoctl.startCluster()
        except RuntimeError:
            Serialized.tic()
            if self.neoctl.getClusterState() not in (
                      ClusterStates.BACKINGUP,
                      ClusterStates.RUNNING,
                      ClusterStates.VERIFYING,
                  ):
                raise

    def enableStorageList(self, storage_list):
        self.neoctl.enableStorageList([x.uuid for x in storage_list])
        Serialized.tic()
        for node in storage_list:
            assert self.getNodeState(node) == NodeStates.RUNNING

    def join(self, thread_list, timeout=5):
        timeout += time.time()
        while thread_list:
            assert time.time() < timeout, thread_list
            Serialized.tic()
            thread_list = [t for t in thread_list if t.is_alive()]

    def stop(self):
        logging.debug("stopping %s", self)
        client = self.__dict__.get("client")
        client is None or self.__dict__.pop("db", client).close()
        node_list = self.admin_list + self.storage_list + self.master_list
        for node in node_list:
            node.stop()
        try:
            node_list.append(client.poll_thread)
        except AttributeError: # client is None or thread is already stopped
            pass
        self.join(node_list)
        logging.debug("stopped %s", self)
        self._unpatch()

    def getNodeState(self, node):
        uuid = node.uuid
        for node in self.neoctl.getNodeList(node.node_type):
            if node[2] == uuid:
                return node[3]

    def getOutdatedCells(self):
        # Ask the admin instead of the primary master to check that it is
        # notified of every change.
        return [(i, cell.getUUID())
            for i, row in enumerate(self.admin.pt.partition_list)
            for cell in row
            if not cell.isReadable()]

    def getZODBStorage(self, **kw):
        kw['_app'] = kw.pop('client', self.client)
        return Storage.Storage(None, self.name, **kw)

    def importZODB(self, dummy_zodb=None, random=random):
        if dummy_zodb is None:
            from ..stat_zodb import PROD1
            dummy_zodb = PROD1(random)
        preindex = {}
        as_storage = dummy_zodb.as_storage
        return lambda count: self.getZODBStorage().importFrom(
            as_storage(count), preindex=preindex)

    def populate(self, transaction_list, tid=lambda i: p64(i+1),
                                         oid=lambda i: p64(i+1)):
        storage = self.getZODBStorage()
        tid_dict = {}
        for i, oid_list in enumerate(transaction_list):
            txn = transaction.Transaction()
            storage.tpc_begin(txn, tid(i))
            for o in oid_list:
                storage.store(oid(o), tid_dict.get(o), repr((i, o)), '', txn)
            storage.tpc_vote(txn)
            i = storage.tpc_finish(txn)
            for o in oid_list:
                tid_dict[o] = i

    def getTransaction(self):
        txn = transaction.TransactionManager()
        return txn, self.db.open(transaction_manager=txn)

    def __del__(self, __print_exc=traceback.print_exc):
        try:
            self.neoctl.close()
            for node_type in 'admin', 'storage', 'master':
                for node in getattr(self, node_type + '_list'):
                    node.close()
        except:
            __print_exc()
            raise

    def extraCellSortKey(self, key):
        return Patch(self.client.cp, getCellSortKey=lambda orig, cell:
            (orig(cell), key(cell)))

    def moduloTID(self, partition):
        """Force generation of TIDs that will be stored in given partition"""
        partition = p64(partition)
        master = self.primary_master
        return Patch(master.tm, _nextTID=lambda orig, *args:
            orig(*args) if args else orig(partition, master.pt.getPartitions()))

    def sortStorageList(self):
        """Sort storages so that storage_list[i] has partition i for all i"""
        pt = [{x.getUUID() for x in x}
            for x in self.primary_master.pt.partition_list]
        r = []
        x = [iter(pt[0])]
        try:
            while 1:
                try:
                    r.append(next(x[-1]))
                except StopIteration:
                    del r[-1], x[-1]
                else:
                    x.append(iter(pt[len(r)].difference(r)))
        except IndexError:
            assert len(r) == len(self.storage_list)
        x = {x.uuid: x for x in self.storage_list}
        self.storage_list[:] = (x[r] for r in r)
        return self.storage_list

class NEOThreadedTest(NeoTestBase):

    def setupLog(self):
        log_file = os.path.join(getTempDirectory(), self.id() + '.log')
        logging.setup(log_file)
        return LoggerThreadName()

    def _tearDown(self, success):
        super(NEOThreadedTest, self)._tearDown(success)
        ServerNode.resetPorts()
        if success:
            with logging as db:
                db.execute("UPDATE packet SET body=NULL")
                db.execute("VACUUM")

    tic = Serialized.tic

    def getLoopbackConnection(self):
        app = MasterApplication(getSSL=NEOCluster.SSL,
            getReplicas=0, getPartitions=1)
        handler = EventHandler(app)
        app.listening_conn = ListeningConnection(app, handler, app.server)
        node = app.nm.createMaster(address=app.listening_conn.getAddress(),
                                   uuid=app.uuid)
        conn = ClientConnection.__new__(ClientConnection)
        def reset():
            conn.__dict__.clear()
            conn.__init__(app, handler, node)
            conn.reset = reset
        reset()
        return conn

    def getUnpickler(self, conn):
        reader = conn._reader
        def unpickler(data, compression=False):
            if compression:
                data = decompress(data)
            obj = reader.getGhost(data)
            reader.setGhostState(obj, data)
            return obj
        return unpickler

    class newThread(threading.Thread):

        def __init__(self, func, *args, **kw):
            threading.Thread.__init__(self)
            self.__target = func, args, kw
            self.daemon = True
            self.start()

        def run(self):
            try:
                apply(*self.__target)
                self.__exc_info = None
            except:
                self.__exc_info = sys.exc_info()

        def join(self, timeout=None):
            threading.Thread.join(self, timeout)
            if not self.is_alive() and self.__exc_info:
                etype, value, tb = self.__exc_info
                del self.__exc_info
                raise etype, value, tb
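
    # Illustrative usage of newThread (the function is hypothetical): any
    # exception raised in the background thread is re-raised by join() in
    # the test thread.
    #
    #   t = self.newThread(some_client_operation)
    #   ...
    #   t.join()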

    def commitWithStorageFailure(self, client, txn):
        with Patch(client, _getFinalTID=lambda *_: None):
            self.assertRaises(ConnectionClosed, txn.commit)
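
# Illustrative sketch of a typical threaded test built on the classes above
# (the test name and the stored value are hypothetical):
#
#   class MyTest(NEOThreadedTest):
#
#       def testCommit(self):
#           cluster = NEOCluster(partitions=2, replicas=1, storage_count=2)
#           try:
#               cluster.start()
#               t, c = cluster.getTransaction()
#               c.root()['x'] = 'foo'
#               t.commit()
#               self.tic()      # let all nodes process pending packets
#           finally:
#               cluster.stop()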


def predictable_random(seed=None):
    # Because there are 2 running threads when the client works, we can't
    # patch neo.client.pool (and the cluster should have only 1 storage).
    from neo.master import backup_app
    from neo.master.handlers import administration
    from neo.storage import replicator
    def decorator(wrapped):
        def wrapper(*args, **kw):
            s = repr(time.time()) if seed is None else seed
            logging.info("using seed %r", s)
            r = random.Random(s)
            try:
                administration.random = backup_app.random = replicator.random \
                    = r
                return wrapped(*args, **kw)
            finally:
                administration.random = backup_app.random = replicator.random \
                    = random
        return wraps(wrapped)(wrapper)
    return decorator
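
# Illustrative usage sketch (the test method is hypothetical): run a
# replication test with a reproducible sequence of random numbers in the
# master/storage code patched above.
#
#   @predictable_random()
#   def testReplicationWithSeed(self):
#       ...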