#
# Copyright (C) 2011-2017  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

# XXX: Consider using ClusterStates.STOPPING to stop clusters

import os, random, select, socket, sys, tempfile
import thread, threading, time, traceback, weakref
from collections import deque
from ConfigParser import SafeConfigParser
from contextlib import contextmanager
from itertools import count
from functools import partial, wraps
from zlib import decompress
from ..mock import Mock
import transaction, ZODB
import neo.admin.app, neo.master.app, neo.storage.app
import neo.client.app, neo.neoctl.app
from neo.client import Storage
from neo.lib import logging
from neo.lib.connection import BaseConnection, \
    ClientConnection, Connection, ConnectionClosed, ListeningConnection
from neo.lib.connector import SocketConnector, ConnectorException
from neo.lib.handler import EventHandler
from neo.lib.locking import SimpleQueue
from neo.lib.protocol import ClusterStates, Enum, NodeStates, NodeTypes, Packets
from neo.lib.util import cached_property, parseMasterList, p64
from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \
    ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_SOCKET, DB_USER

BIND = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE], 0
LOCAL_IP = socket.inet_pton(ADDRESS_TYPE, IP_VERSION_FORMAT_DICT[ADDRESS_TYPE])
TIC_LOOP = xrange(1000)


class LockLock(object):
    """Double lock used as synchronisation point between 2 threads

    Used to wait until a slave thread has reached a specific location, and to
    keep it suspended there. It resumes on __exit__.
    """

    def __init__(self):
        self._l = threading.Lock(), threading.Lock()

    def __call__(self):
        """Define synchronisation point for both threads"""
        if self._owner == thread.get_ident():
            self._l[0].acquire()
        else:
            self._l[0].release()
            self._l[1].acquire()

    def __enter__(self):
        self._owner = thread.get_ident()
        for l in self._l:
            l.acquire(0)
        return self

    def __exit__(self, t, v, tb):
        try:
            self._l[1].release()
        except thread.error:
            pass


class FairLock(deque):
    """Same as a threading.Lock except that waiting threads are queued, so that
    the first one waiting for the lock is the first to get it. This is useful
    when several concurrent threads fight for the same resource in loop:
    the owner could give too little time for other to get a chance to acquire,
    blocking them for a long time with bad luck.
    """

    def __enter__(self, _allocate_lock=threading.Lock):
        me = _allocate_lock()
        me.acquire()
        self.append(me)
        other = self[0]
        while me is not other:
            with other:
                other = self[0]

    def __exit__(self, t, v, tb):
        self.popleft().release()


class Serialized(object):
    """
    "Threaded" tests run all nodes in the same process as the test itself,
    and threads are scheduled by this class, which mainly provides 2 features:
    - more determinism, by minimizing the number of active threads and
      switching them in a round-robin fashion;
    - a tic() method to wait only the time necessary for the cluster to become
      idle.

    The basic concept is that each thread has a lock that always gets acquired
    by itself. The following pattern is used to yield the processor to the next
    thread:
        release(); acquire()
    Note that this is not atomic: all other threads may sometimes complete
    before a thread tries to acquire its own lock. So that the previous thread
    does not fail by releasing an un-acquired lock, Semaphores are actually
    used instead of Locks.

    The epoll object of each node is hooked so that thread switching happens
    before polling for network activity. An extra epoll object is used to
    detect which node has a readable epoll object.
    """
    check_timeout = False

    @classmethod
    def init(cls):
        cls._busy = set()
        cls._busy_cond = threading.Condition(threading.Lock())
        cls._epoll = select.epoll()
        cls._pdb = None
        cls._sched_lock = threading.Semaphore(0)
        cls._tic_lock = FairLock()
        cls._fd_dict = {}

    @classmethod
    def idle(cls, owner):
        with cls._busy_cond:
            cls._busy.discard(owner)
            cls._busy_cond.notify_all()

    @classmethod
    def stop(cls):
        assert not cls._fd_dict, ("file descriptor leak (%r)\nThis may happen"
            " when a test fails, in which case you can see the real exception"
            " by disabling this one." % cls._fd_dict)
        del(cls._busy, cls._busy_cond, cls._epoll, cls._fd_dict,
            cls._pdb, cls._sched_lock, cls._tic_lock)

    @classmethod
    def _sort_key(cls, fd_event):
        return -cls._fd_dict[fd_event[0]]._last

    @classmethod
    @contextmanager
    def pdb(cls):
        try:
            cls._pdb = sys._getframe(2).f_trace.im_self
            cls._pdb.set_continue()
        except AttributeError:
            pass
        yield
        p = cls._pdb
        if p is not None:
            cls._pdb = None
            t = threading.currentThread()
            p.stdout.write(getattr(t, 'node_name', t.name))
            p.set_trace(sys._getframe(3))

    @classmethod
    def tic(cls, step=-1, check_timeout=(), quiet=False,
            # BUG: We overuse epoll as a way to know if there are pending
            #      network messages. Sometimes, and this is more visible with
            #      a single-core CPU, other threads are still busy and haven't
            #      sent anything yet on the network. This causes tic() to
            #      return prematurely. Passing a non-zero timeout is a hack to
            #      work around this.
            timeout=0):
        # If you're in a pdb here, 'n' switches to another thread
        # (the following lines are not supposed to be debugged into)
        with cls._tic_lock, cls.pdb():
            if not quiet:
                f = sys._getframe(1)
                try:
                    logging.info('tic (%s:%u) ...',
                        f.f_code.co_filename, f.f_lineno)
                finally:
                    del f
            if cls._busy:
                with cls._busy_cond:
                    while cls._busy:
                        cls._busy_cond.wait()
            for app in check_timeout:
                app.em.epoll.check_timeout = True
                app.em.wakeup()
                del app
            while step:
                event_list = cls._epoll.poll(timeout)
                if not event_list:
                    break
                step -= 1
                event_list.sort(key=cls._sort_key)
                next_lock = cls._sched_lock
                for fd, event in event_list:
                    self = cls._fd_dict[fd]
                    self._release_next = next_lock.release
                    next_lock = self._lock
                del self
                next_lock.release()
                cls._sched_lock.acquire()

    def __init__(self, app, busy=True):
        self._epoll = app.em.epoll
        app.em.epoll = self
        # XXX: It may have been initialized before the SimpleQueue is patched.
        thread_container = getattr(app, '_thread_container', None)
        thread_container is None or thread_container.__init__()
        if busy:
            self._busy.add(self) # block tic until app waits for polling

    def __getattr__(self, attr):
        if attr in ('close', 'modify', 'register', 'unregister'):
            return getattr(self._epoll, attr)
        return self.__getattribute__(attr)

    def poll(self, timeout):
        if self.check_timeout:
            assert timeout >= 0, (self, timeout)
            del self.check_timeout
        elif timeout:
            with self.pdb(): # same as in tic()
                release = self._release_next
                self._release_next = None
                release()
                self._lock.acquire()
                self._last = time.time()
        return self._epoll.poll(timeout)

    def _release_next(self):
        self._last = time.time()
        self._lock = threading.Semaphore(0)
        fd = self._epoll.fileno()
        cls = self.__class__
        cls._fd_dict[fd] = self
        cls._epoll.register(fd)
        cls.idle(self)

    def exit(self):
        fd = self._epoll.fileno()
        cls = self.__class__
        if cls._fd_dict.pop(fd, None) is None:
            cls.idle(self)
        else:
            cls._epoll.unregister(fd)
            self._release_next()

class TestSerialized(Serialized):

    def __init__(*args):
        Serialized.__init__(busy=False, *args)

    def poll(self, timeout):
        if timeout:
            for x in TIC_LOOP:
                r = self._epoll.poll(0)
                if r:
                    return r
                Serialized.tic(step=1, timeout=.001)
            raise Exception("tic is looping forever")
        return self._epoll.poll(timeout)


class Node(object):

    @staticmethod
    def convertInitArgs(**kw):
        return {'get' + k.capitalize(): v for k, v in kw.iteritems()}

    def getConnectionList(self, *peers):
        addr = lambda c: c and (c.addr if c.is_server else c.getAddress())
        addr_set = {addr(c.connector) for peer in peers
            for c in peer.em.connection_dict.itervalues()
            if isinstance(c, Connection)}
        addr_set.discard(None)
        return (c for c in self.em.connection_dict.itervalues()
            if isinstance(c, Connection) and addr(c.connector) in addr_set)

    def filterConnection(self, *peers):
        return ConnectionFilter(self.getConnectionList(*peers))

class ServerNode(Node):

    _server_class_dict = {}

    class __metaclass__(type):
        def __init__(cls, name, bases, d):
            if Node not in bases and threading.Thread not in cls.__mro__:
                cls.__bases__ = bases + (threading.Thread,)
                cls.node_type = getattr(NodeTypes, name[:-11].upper())
                cls._node_list = []
                cls._virtual_ip = socket.inet_ntop(ADDRESS_TYPE,
                    LOCAL_IP[:-1] + chr(2 + len(cls._server_class_dict)))
                cls._server_class_dict[cls._virtual_ip] = cls

    @staticmethod
    def resetPorts():
        for cls in ServerNode._server_class_dict.itervalues():
            del cls._node_list[:]

    @classmethod
    def newAddress(cls):
        address = cls._virtual_ip, len(cls._node_list)
        cls._node_list.append(None)
        return address

    @classmethod
    def resolv(cls, address):
        try:
            cls = cls._server_class_dict[address[0]]
        except KeyError:
            return address
        return cls._node_list[address[1]].getListeningAddress()

    def __init__(self, cluster=None, address=None, **kw):
        if not address:
            address = self.newAddress()
        if cluster is None:
            master_nodes = ()
            name = kw.get('name', 'test')
        else:
            master_nodes = cluster.master_nodes
            name = kw.get('name', cluster.name)
        port = address[1]
        if address is not BIND:
            self._node_list[port] = weakref.proxy(self)
        self._init_args = init_args = kw.copy()
        init_args['cluster'] = cluster
        init_args['address'] = address
        threading.Thread.__init__(self)
        self.daemon = True
        self.node_name = '%s_%u' % (self.node_type, port)
        kw.update(getCluster=name, getBind=address,
            getMasters=master_nodes and parseMasterList(master_nodes))
        super(ServerNode, self).__init__(Mock(kw))

    def getVirtualAddress(self):
        return self._init_args['address']

    def resetNode(self, **kw):
        assert not self.is_alive()
        kw = self.convertInitArgs(**kw)
        init_args = self._init_args
        init_args['getReset'] = False
        assert set(kw).issubset(init_args), (kw, init_args)
        init_args.update(kw)
        self.close()
        self.__init__(**init_args)

    def start(self):
        Serialized(self)
        threading.Thread.start(self)

    def run(self):
        try:
            super(ServerNode, self).run()
        finally:
            self._afterRun()
            logging.debug('stopping %r', self)
            self.em.epoll.exit()

    def _afterRun(self):
        try:
            self.listening_conn.close()
            self.listening_conn = None
        except AttributeError:
            pass

    def getListeningAddress(self):
        try:
            return self.listening_conn.getAddress()
        except AttributeError:
            raise ConnectorException

    def stop(self):
        self.em.wakeup(thread.exit)

class AdminApplication(ServerNode, neo.admin.app.Application):
    pass

class MasterApplication(ServerNode, neo.master.app.Application):
    pass

class StorageApplication(ServerNode, neo.storage.app.Application):

    dm = type('', (), {'close': lambda self: None})()

    def _afterRun(self):
        super(StorageApplication, self)._afterRun()
        try:
            self.dm.close()
            del self.dm
        except StandardError: # AttributeError & ProgrammingError
            pass
        if self.master_conn:
            self.master_conn.close()

    def getAdapter(self):
        return self._init_args['getAdapter']

    def getDataLockInfo(self):
        dm = self.dm
        index = tuple(dm.query("SELECT id, hash, compression FROM data"))
        assert set(dm._uncommitted_data).issubset(x[0] for x in index)
        get = dm._uncommitted_data.get
        return {(str(h), c & 0x7f): get(i, 0) for i, h, c in index}

    def sqlCount(self, table):
        (r,), = self.dm.query("SELECT COUNT(*) FROM " + table)
        return r

class ClientApplication(Node, neo.client.app.Application):

    max_reconnection_to_master = 10

    def __init__(self, master_nodes, name, **kw):
        super(ClientApplication, self).__init__(master_nodes, name, **kw)
        self.poll_thread.node_name = name
        # Smaller cache to speed up tests that check behaviour when it's too
        # small. See also NEOCluster.cache_size.
        self._cache._max_size //= 1024

    def _run(self):
        try:
            super(ClientApplication, self)._run()
        finally:
            self.em.epoll.exit()

    def start(self):
        isinstance(self.em.epoll, Serialized) or Serialized(self)
        super(ClientApplication, self).start()

    def getConnectionList(self, *peers):
        for peer in peers:
            if isinstance(peer, MasterApplication):
                conn = self._getMasterConnection()
            else:
                assert isinstance(peer, StorageApplication)
                conn = self.cp.getConnForNode(self.nm.getByUUID(peer.uuid))
            yield conn

    def extraCellSortKey(self, key):
        return Patch(self.cp, getCellSortKey=lambda orig, cell:
            (orig(cell, lambda: key(cell)), random.random()))

class NeoCTL(neo.neoctl.app.NeoCTL):

    def __init__(self, *args, **kw):
        super(NeoCTL, self).__init__(*args, **kw)
        TestSerialized(self)


class LoggerThreadName(str):

    def __new__(cls, default='TEST'):
        return str.__new__(cls, default)

    def __getattribute__(self, attr):
        return getattr(str(self), attr)

    def __hash__(self):
        return id(self)

    def __str__(self):
        try:
            return threading.currentThread().node_name
        except AttributeError:
            return str.__str__(self)


class ConnectionFilter(object):

    filtered_count = 0
    filter_list = []
    filter_queue = weakref.WeakKeyDictionary() # XXX: see the end of __new__
    lock = threading.RLock()
    _addPacket = Connection._addPacket

    @contextmanager
    def __new__(cls, conn_list=()):
        self = object.__new__(cls)
        self.filter_dict = {}
        self.conn_list = frozenset(conn_list)
        if not cls.filter_list:
            def _addPacket(conn, packet):
                with cls.lock:
                    try:
                        queue = cls.filter_queue[conn]
                    except KeyError:
                        for self in cls.filter_list:
                            if self._test(conn, packet):
                                self.filtered_count += 1
                                break
                        else:
                            return cls._addPacket(conn, packet)
                        cls.filter_queue[conn] = queue = deque()
                    p = packet.__class__
                    logging.debug("queued %s#0x%04x for %s",
                                  p.__name__, packet.getId(), conn)
                    p = packet.__new__(p)
                    p.__dict__.update(packet.__dict__)
                    queue.append(p)
            Connection._addPacket = _addPacket
        try:
            cls.filter_list.append(self)
            yield self
        finally:
            del cls.filter_list[-1:]
            if not cls.filter_list:
                Connection._addPacket = cls._addPacket.im_func
            # Retry even in case of exception, at least to avoid leaks in
            # filter_queue. Sometimes, WeakKeyDictionary only does its job
            # after an explicit call to gc.collect.
            with cls.lock:
                cls._retry()

    def _test(self, conn, packet):
        if not self.conn_list or conn in self.conn_list:
            for filter in self.filter_dict:
                if filter(conn, packet):
                    return True
        return False

    @classmethod
    def retry(cls):
        with cls.lock:
            cls._retry()

    @classmethod
    def _retry(cls):
        for conn, queue in cls.filter_queue.items():
            while queue:
                packet = queue.popleft()
                for self in cls.filter_list:
                    if self._test(conn, packet):
                        queue.appendleft(packet)
                        break
                else:
                    if conn.isClosed():
                        queue.clear()
                    else:
                        # Use the thread that created the packet to reinject it,
                        # to avoid a race condition on Connector.queued.
                        conn.em.wakeup(lambda conn=conn, packet=packet:
                            conn.isClosed() or cls._addPacket(conn, packet))
                    continue
                break
            else:
                del cls.filter_queue[conn]

    def add(self, filter, *patches):
        with self.lock:
            self.filter_dict[filter] = patches
            for p in patches:
                p.apply()

    def remove(self, *filters):
        with self.lock:
            for filter in filters:
                for p in self.filter_dict.pop(filter):
                    p.revert()
            self._retry()

    def discard(self, *filters):
        try:
            self.remove(*filters)
        except KeyError:
            pass

    def __contains__(self, filter):
        return filter in self.filter_dict

    def byPacket(self, packet_type, *args):
        patches = []
        other = []
        for x in args:
            (patches if isinstance(x, Patch) else other).append(x)
        def delay(conn, packet):
            return isinstance(packet, packet_type) and False not in (
                callback(conn) for callback in other)
        self.add(delay, *patches)
        return delay

    def __getattr__(self, attr):
        if attr.startswith('delay'):
            return partial(self.byPacket, getattr(Packets, attr[5:]))
        return self.__getattribute__(attr)
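
    # The delay* attributes are resolved by __getattr__ above. A sketch of
    # typical use in a test (hypothetical packet and cluster names):
    #
    #   with cluster.master.filterConnection(cluster.storage) as f:
    #       delayed = f.delayNotifyPartitionChanges() # queue matching packets
    #       ...                                       # act while they're held
    #       f.remove(delayed)                         # reinject queued packets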

class NEOCluster(object):

    SSL = None

    def __init__(orig, self): # temporary definition for SimpleQueue patch
        orig(self)
        lock = self._lock
        def _lock(blocking=True):
            if blocking:
                logging.info('<SimpleQueue>._lock.acquire()')
                for i in TIC_LOOP:
                    if lock(False):
                        return True
                    Serialized.tic(step=1, quiet=True, timeout=.001)
                raise Exception("tic is looping forever")
            return lock(False)
        self._lock = _lock
    _patches = (
        Patch(BaseConnection, getTimeout=lambda orig, self: None),
        Patch(SimpleQueue, __init__=__init__),
        Patch(SocketConnector, CONNECT_LIMIT=0),
        Patch(SocketConnector, _bind=lambda orig, self, addr: orig(self, BIND)),
        Patch(SocketConnector, _connect = lambda orig, self, addr:
            orig(self, ServerNode.resolv(addr))))
    _patch_count = 0
    _resource_dict = weakref.WeakValueDictionary()

    def _allocate(self, resource, new):
        result = resource, new()
        while result in self._resource_dict:
            result = resource, new()
        self._resource_dict[result] = self
        return result[1]

    @staticmethod
    def _patch():
        cls = NEOCluster
        cls._patch_count += 1
        if cls._patch_count > 1:
            return
        for patch in cls._patches:
            patch.apply()
        Serialized.init()

    @staticmethod
    def _unpatch():
        cls = NEOCluster
        assert cls._patch_count > 0
        cls._patch_count -= 1
        if cls._patch_count:
            return
        for patch in cls._patches:
            patch.revert()
        Serialized.stop()

    started = False

    def __init__(self, master_count=1, partitions=1, replicas=0, upstream=None,
                       adapter=os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
                       storage_count=None, db_list=None, clear_databases=True,
                       db_user=DB_USER, db_password='', compress=True,
                       importer=None, autostart=None):
        self.name = 'neo_%s' % self._allocate('name',
            lambda: random.randint(0, 100))
        self.compress = compress
        self.num_partitions = partitions
        master_list = [MasterApplication.newAddress()
                       for _ in xrange(master_count)]
        self.master_nodes = ' '.join('%s:%s' % x for x in master_list)
        kw = Node.convertInitArgs(replicas=replicas, adapter=adapter,
            partitions=partitions, reset=clear_databases)
        kw['cluster'] = weak_self = weakref.proxy(self)
        kw['getSSL'] = self.SSL
        if upstream is not None:
            self.upstream = weakref.proxy(upstream)
            kw.update(getUpstreamCluster=upstream.name,
                getUpstreamMasters=parseMasterList(upstream.master_nodes))
        self.master_list = [MasterApplication(getAutostart=autostart,
                                              address=x, **kw)
                            for x in master_list]
        if db_list is None:
            if storage_count is None:
                storage_count = replicas + 1
            index = count().next
            db_list = ['%s%u' % (DB_PREFIX, self._allocate('db', index))
                       for _ in xrange(storage_count)]
        if adapter == 'MySQL':
            setupMySQLdb(db_list, db_user, db_password, clear_databases)
            db = '%s:%s@%%s%s' % (db_user, db_password, DB_SOCKET)
        elif adapter == 'SQLite':
            db = os.path.join(getTempDirectory(), '%s.sqlite')
        else:
            assert False, adapter
        if importer:
            cfg = SafeConfigParser()
            cfg.add_section("neo")
            cfg.set("neo", "adapter", adapter)
            cfg.set("neo", "database", db % tuple(db_list))
            for name, zodb in importer:
                cfg.add_section(name)
                for x in zodb.iteritems():
                    cfg.set(name, *x)
            db = os.path.join(getTempDirectory(), '%s.conf')
            with open(db % tuple(db_list), "w") as f:
                cfg.write(f)
            kw["getAdapter"] = "Importer"
        self.storage_list = [StorageApplication(getDatabase=db % x, **kw)
                             for x in db_list]
        self.admin_list = [AdminApplication(**kw)]

    def __repr__(self):
        return "<%s(%s) at 0x%x>" % (self.__class__.__name__,
                                     self.name, id(self))

    # A few shortcuts that work when there's only 1 master/storage/admin
    @property
    def master(self):
        master, = self.master_list
        return master
    @property
    def storage(self):
        storage, = self.storage_list
        return storage
    @property
    def admin(self):
        admin, = self.admin_list
        return admin
    ###

    # More handy shortcuts for tests
    @property
    def backup_tid(self):
        return self.neoctl.getRecovery()[1]

    @property
    def last_tid(self):
        return self.primary_master.getLastTransaction()

    @property
    def primary_master(self):
        master, = [master for master in self.master_list if master.primary]
        return master

    @property
    def cache_size(self):
        return self.client._cache._max_size
    ###

    def __enter__(self):
        return self

    def __exit__(self, t, v, tb):
        self.stop(None)

    def start(self, storage_list=None, master_list=None, recovering=False):
        self.started = True
        self._patch()
        self.neoctl = NeoCTL(self.admin.getVirtualAddress(), ssl=self.SSL)
        for node in self.master_list if master_list is None else master_list:
            node.start()
        for node in self.admin_list:
            node.start()
        Serialized.tic()
        if storage_list is None:
            storage_list = self.storage_list
        for node in storage_list:
            node.start()
        Serialized.tic()
        if recovering:
            expected_state = ClusterStates.RECOVERING
        else:
            self.startCluster()
            Serialized.tic()
            expected_state = ClusterStates.RUNNING, ClusterStates.BACKINGUP
        self.checkStarted(expected_state, storage_list)

    def checkStarted(self, expected_state, storage_list=None):
        if isinstance(expected_state, Enum.Item):
            expected_state = expected_state,
        state = self.neoctl.getClusterState()
        assert state in expected_state, state
        expected_state = (NodeStates.PENDING
            if state == ClusterStates.RECOVERING
            else NodeStates.RUNNING)
        for node in self.storage_list if storage_list is None else storage_list:
            state = self.getNodeState(node)
            assert state == expected_state, (node, state)

    def stop(self, clear_database=False, __print_exc=traceback.print_exc, **kw):
        if self.started:
            del self.started
            logging.debug("stopping %s", self)
            client = self.__dict__.get("client")
            client is None or self.__dict__.pop("db", client).close()
            node_list = self.admin_list + self.storage_list + self.master_list
            for node in node_list:
                node.stop()
            try:
                node_list.append(client.poll_thread)
            except AttributeError: # client is None or thread is already stopped
                pass
            self.join(node_list)
            self.neoctl.close()
            del self.neoctl
            logging.debug("stopped %s", self)
            self._unpatch()
        if clear_database is None:
            try:
                for node_type in 'admin', 'storage', 'master':
                    for node in getattr(self, node_type + '_list'):
                        node.close()
            except:
                __print_exc()
                raise
        else:
            for node_type in 'master', 'storage', 'admin':
                reset_kw = kw.copy()
                if node_type == 'storage':
                    reset_kw['reset'] = clear_database
                for node in getattr(self, node_type + '_list'):
                    node.resetNode(**reset_kw)

    def _newClient(self):
        return ClientApplication(name=self.name, master_nodes=self.master_nodes,
                                 compress=self.compress, ssl=self.SSL)

    @contextmanager
    def newClient(self, with_db=False):
        x = self._newClient()
        try:
            t = x.poll_thread
            closed = []
            if with_db:
                x = ZODB.DB(storage=self.getZODBStorage(client=x))
            else:
                # XXX: Do nothing in finally if the caller already closed it.
                x.close = lambda: closed.append(x.__class__.close(x))
            yield x
        finally:
            closed or x.close()
            self.join((t,))

    @cached_property
    def client(self):
        client = self._newClient()
        # Make sure the client won't be reused after it is closed.
        def close():
            client = self.client
            del self.client, client.close
            client.close()
        client.close = close
        return client

    @cached_property
    def db(self):
        return ZODB.DB(storage=self.getZODBStorage())

    def startCluster(self):
        try:
            self.neoctl.startCluster()
        except RuntimeError:
            Serialized.tic()
            if self.neoctl.getClusterState() not in (
                      ClusterStates.BACKINGUP,
                      ClusterStates.RUNNING,
                      ClusterStates.VERIFYING,
                  ):
                raise

    def enableStorageList(self, storage_list):
        self.neoctl.enableStorageList([x.uuid for x in storage_list])
        Serialized.tic()
        for node in storage_list:
            state = self.getNodeState(node)
            assert state == NodeStates.RUNNING, state

    def join(self, thread_list, timeout=5):
        timeout += time.time()
        while thread_list:
            # Map with repr before the threads become unprintable.
            assert time.time() < timeout, map(repr, thread_list)
            Serialized.tic(timeout=.001)
            thread_list = [t for t in thread_list if t.is_alive()]

    def getNodeState(self, node):
        uuid = node.uuid
        for node in self.neoctl.getNodeList(node.node_type):
            if node[2] == uuid:
                return node[3]

    def getOutdatedCells(self):
        # Ask the admin instead of the primary master to check that it is
        # notified of every change.
        return [(i, cell.getUUID())
            for i, row in enumerate(self.admin.pt.partition_list)
            for cell in row
            if not cell.isReadable()]

    def getZODBStorage(self, **kw):
        kw['_app'] = kw.pop('client', self.client)
        return Storage.Storage(None, self.name, **kw)

    def importZODB(self, dummy_zodb=None, random=random):
        if dummy_zodb is None:
            from ..stat_zodb import PROD1
            dummy_zodb = PROD1(random)
        preindex = {}
        as_storage = dummy_zodb.as_storage
        return lambda count: self.getZODBStorage().importFrom(
            as_storage(count), preindex=preindex)

    def populate(self, transaction_list, tid=lambda i: p64(i+1),
                                         oid=lambda i: p64(i+1)):
        storage = self.getZODBStorage()
        tid_dict = {}
        for i, oid_list in enumerate(transaction_list):
            txn = transaction.Transaction()
            storage.tpc_begin(txn, tid(i))
            for o in oid_list:
                storage.store(oid(o), tid_dict.get(o), repr((i, o)), '', txn)
            storage.tpc_vote(txn)
            i = storage.tpc_finish(txn)
            for o in oid_list:
                tid_dict[o] = i

    def getTransaction(self, db=None):
        txn = transaction.TransactionManager()
        return txn, (self.db if db is None else db).open(txn)

    def moduloTID(self, partition):
        """Force generation of TIDs that will be stored in given partition"""
        partition = p64(partition)
        master = self.primary_master
        return Patch(master.tm, _nextTID=lambda orig, *args:
            orig(*args) if args else orig(partition, master.pt.getPartitions()))
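
    # A sketch of intended use of moduloTID (hypothetical test code): while
    # the returned Patch is applied, committed transactions get TIDs that are
    # stored in partition 1.
    #
    #   with cluster.moduloTID(1):
    #       t.commit()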

    def sortStorageList(self):
        """Sort storages so that storage_list[i] has partition i for all i"""
        pt = [{x.getUUID() for x in x}
            for x in self.primary_master.pt.partition_list]
        r = []
        x = [iter(pt[0])]
        try:
            while 1:
                try:
                    r.append(next(x[-1]))
                except StopIteration:
                    del r[-1], x[-1]
                else:
                    x.append(iter(pt[len(r)].difference(r)))
        except IndexError:
            assert len(r) == len(self.storage_list)
        x = {x.uuid: x for x in self.storage_list}
        self.storage_list[:] = (x[r] for r in r)
        return self.storage_list

class NEOThreadedTest(NeoTestBase):

    __run_count = {}

    def setupLog(self):
        test_id = self.id()
        i = self.__run_count.get(test_id, 0)
        self.__run_count[test_id] = 1 + i
        if i:
            test_id += '-%s' % i
        logging.setup(os.path.join(getTempDirectory(), test_id + '.log'))
        return LoggerThreadName()

    def _tearDown(self, success):
        super(NEOThreadedTest, self)._tearDown(success)
        ServerNode.resetPorts()
        if success and logging._max_size is not None:
            with logging as db:
                db.execute("UPDATE packet SET body=NULL")
                db.execute("VACUUM")

    tic = Serialized.tic

    @contextmanager
    def getLoopbackConnection(self):
        app = MasterApplication(address=BIND,
            getSSL=NEOCluster.SSL, getReplicas=0, getPartitions=1)
        try:
            handler = EventHandler(app)
            app.listening_conn = ListeningConnection(app, handler, app.server)
            yield ClientConnection(app, handler, app.nm.createMaster(
                address=app.listening_conn.getAddress(), uuid=app.uuid))
        finally:
            app.close()

    def getUnpickler(self, conn):
        reader = conn._reader
        def unpickler(data, compression=False):
            if compression:
                data = decompress(data)
            obj = reader.getGhost(data)
            reader.setGhostState(obj, data)
            return obj
        return unpickler

    class newPausedThread(threading.Thread):

        def __init__(self, func, *args, **kw):
            threading.Thread.__init__(self)
            self.__target = func, args, kw
            self.daemon = True

        def run(self):
            try:
                apply(*self.__target)
                self.__exc_info = None
            except:
                self.__exc_info = sys.exc_info()
                if self.__exc_info[0] is NEOThreadedTest.failureException:
                    traceback.print_exception(*self.__exc_info)

        def join(self, timeout=None):
            threading.Thread.join(self, timeout)
            if not self.is_alive() and self.__exc_info:
                etype, value, tb = self.__exc_info
                del self.__exc_info
                raise etype, value, tb

    class newThread(newPausedThread):

        def __init__(self, *args, **kw):
            NEOThreadedTest.newPausedThread.__init__(self, *args, **kw)
            self.start()

    def commitWithStorageFailure(self, client, txn):
        with Patch(client, _getFinalTID=lambda *_: None):
            self.assertRaises(ConnectionClosed, txn.commit)

    def assertPartitionTable(self, cluster, stats, pt_node=None):
        pt  = (pt_node or cluster.admin).pt
        index = [x.uuid for x in cluster.storage_list].index
        self.assertEqual(stats, '|'.join(pt._formatRows(sorted(
            pt.count_dict, key=lambda x: index(x.getUUID())))))

    @staticmethod
    def noConnection(jar, storage):
        return Patch(jar.db().storage.app.cp, getConnForNode=lambda orig, node:
            None if node.getUUID() == storage.uuid else orig(node))

    @staticmethod
    def readCurrent(ob):
        ob._p_activate()
        ob._p_jar.readCurrent(ob)


class ThreadId(list):

    def __call__(self):
        try:
            return self.index(thread.get_ident())
        except ValueError:
            i = len(self)
            self.append(thread.get_ident())
            return i


@apply
class RandomConflictDict(dict):
    # One must not depend on how Python iterates over dict keys, because this
    # is implementation-defined behaviour. This patch makes sure that conflict
    # resolution does not rely on any particular iteration order.

    def __new__(cls):
        from neo.client.transactions import Transaction
        def __init__(orig, self, *args):
            orig(self, *args)
            assert self.conflict_dict == {}
            self.conflict_dict = dict.__new__(cls)
        return Patch(Transaction, __init__=__init__)

    def popitem(self):
        try:
            k = random.choice(list(self))
        except IndexError:
            raise KeyError
        return k, self.pop(k)


def predictable_random(seed=None):
    # Because we have 2 running threads when the client works, we can't
    # patch neo.client.pool (and the cluster should have only 1 storage).
    from neo.master import backup_app
    from neo.master.handlers import administration
    from neo.storage import replicator
    def decorator(wrapped):
        def wrapper(*args, **kw):
            s = repr(time.time()) if seed is None else seed
            logging.info("using seed %r", s)
            r = random.Random(s)
            try:
                administration.random = backup_app.random = replicator.random \
                    = r
                return wrapped(*args, **kw)
            finally:
                administration.random = backup_app.random = replicator.random \
                    = random
        return wraps(wrapped)(wrapper)
    return decorator
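
# A sketch of how predictable_random is meant to be used (hypothetical test
# method), e.g. to replay a randomized scenario with a fixed seed:
#
#   class ReplicationTests(NEOThreadedTest):
#
#       @predictable_random(seed='1234567890.12')
#       def testBackup(self):
#           ...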

def with_cluster(start_cluster=True, **cluster_kw):
    def decorator(wrapped):
        def wrapper(self, *args, **kw):
            with NEOCluster(**cluster_kw) as cluster:
                if start_cluster:
                    cluster.start()
                return wrapped(self, cluster, *args, **kw)
        return wraps(wrapped)(wrapper)
    return decorator
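
# A sketch of typical use of with_cluster (hypothetical test method): the
# decorated method receives the started cluster as an extra argument and the
# cluster is stopped automatically on return.
#
#   class Test(NEOThreadedTest):
#
#       @with_cluster(replicas=1, partitions=2, storage_count=2)
#       def testSomething(self, cluster):
#           t, c = cluster.getTransaction()
#           c.root()['x'] = 'foo'
#           t.commit()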