#
# Copyright (C) 2009-2019  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

17
import __builtin__
18
import errno
19
import functools
20
import gc
21
import os
22
import random
23
import socket
24
import subprocess
25
import sys
26
import tempfile
27
import unittest
28
import weakref
29
import transaction
Olivier Cros's avatar
Olivier Cros committed
30

31
from contextlib import contextmanager
32
from ConfigParser import SafeConfigParser
33
from cStringIO import StringIO
34 35 36 37
try:
    from ZODB._compat import Unpickler
except ImportError:
    from cPickle import Unpickler
38
from functools import wraps
39
from inspect import isclass
40
from .mock import Mock
41
from neo.lib import debug, logging, protocol
42
from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES
43
from neo.lib.util import cached_property
44
from time import time, sleep
45
from struct import pack, unpack
46
from unittest.case import _ExpectedFailure, _UnexpectedSuccess
47
try:
48
    from transaction.interfaces import IDataManager
49
    from ZODB.utils import newTid
50
    from ZODB.ConflictResolution import PersistentReferenceFactory
51 52
except ImportError:
    pass
53

54 55 56 57 58 59 60 61 62
def expectedFailure(exception=AssertionError):
    """Decorator marking a test function as an expected failure.

    Two forms are supported:
      @expectedFailure             any Exception counts as the expected failure
      @expectedFailure(SomeError)  only SomeError (AssertionError by default)

    When the decorated function raises the expected exception type,
    unittest's internal _ExpectedFailure is raised in its place; when it
    returns normally, _UnexpectedSuccess is raised.
    """
    bare = callable(exception) and not isinstance(exception, type)
    if bare:
        # Applied without arguments: 'exception' is really the test function.
        func, exception = exception, Exception

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kw):
            try:
                func(*args, **kw)
            except exception as e:
                # XXX: passing sys.exc_info() causes deadlocks
                raise _ExpectedFailure((type(e), None, None))
            raise _UnexpectedSuccess
        return wrapper

    return decorator(func) if bare else decorator

# Database settings of the test suite, overridable through the environment.
DB_PREFIX = os.getenv('NEO_DB_PREFIX', 'test_neo')
DB_ADMIN = os.getenv('NEO_DB_ADMIN', 'root')
DB_PASSWD = os.getenv('NEO_DB_PASSWD', '')
DB_USER = os.getenv('NEO_DB_USER', 'test')
DB_SOCKET = os.getenv('NEO_DB_SOCKET', '')
# Settings of the private mysqld pool (MySQLPool), which is only enabled when
# NEO_DB_MYCNF is set.
DB_INSTALL = os.getenv('NEO_DB_INSTALL', 'mysql_install_db')
DB_MYSQLD = os.getenv('NEO_DB_MYSQLD', '/usr/sbin/mysqld')
DB_MYCNF = os.getenv('NEO_DB_MYCNF')


# Loopback address of each supported address family.
IP_VERSION_FORMAT_DICT = {
    socket.AF_INET:  '127.0.0.1',
    socket.AF_INET6: '::1',
}

ADDRESS_TYPE = socket.AF_INET

# Paths of ca.crt, node.crt and node.key, stored next to this module.
# os.path.join also yields correct relative paths when dirname() is empty,
# whereas concatenating "dirname + os.sep" would produce '/ca.crt'.
SSL = tuple(os.path.join(os.path.dirname(__file__), name)
            for name in ("ca.crt", "node.crt", "node.key"))

# Make the default root log handler drop every record it is given.
logging.default_root_handler.handle = lambda record: None

debug.register()


def mockDefaultValue(name, function):
    """Install a method called `name` on the Mock class.

    The method returns the mocked value when one was registered under `name`
    (i.e. `name` is in self.mockReturnValues); otherwise it falls back to
    calling function(self, *args, **kw).
    """
    def method(self, *args, **kw):
        if name not in self.mockReturnValues:
            return function(self, *args, **kw)
        return self.__getattr__(name)(*args, **kw)
    method.__name__ = name
    setattr(Mock, name, method)

for _name, _default in (
        # Truth value: false only when the mock reports a zero length.
        ('__nonzero__', lambda self: self.__len__() != 0),
        ('__repr__', lambda self:
            '<%s object at 0x%x>' % (self.__class__.__name__, id(self))),
        ('__str__', repr),
        ):
    mockDefaultValue(_name, _default)
del _name, _default


def buildUrlFromString(address):
    """Return `address`, enclosed in brackets if it is an IPv6 literal.

    For example '::1' gives '[::1]'.  Anything that socket.inet_pton does not
    accept as an IPv6 address (IPv4 literals, host names, ...) is returned
    unchanged.
    """
    try:
        socket.inet_pton(socket.AF_INET6, address)
    except Exception:
        # Deliberately best-effort: not an IPv6 literal, keep it as is.
        return address
    return '[%s]' % address

def getTempDirectory():
    """Return the directory in which the tests keep their files.

    The TEMP environment variable is used when it is set.  Otherwise a new
    directory, named after the current time, is created inside
    <system temp dir>/neo_tests.  The 'last' symlink of that parent is then
    repointed to it, and TEMP is set so that later calls return the same
    directory.
    """
    temp_dir = os.environ.get('TEMP')
    if temp_dir is not None:
        return temp_dir
    neo_dir = os.path.join(tempfile.gettempdir(), 'neo_tests')
    while True:
        temp_name = repr(time())
        temp_dir = os.path.join(neo_dir, temp_name)
        try:
            os.makedirs(temp_dir)
        except OSError as e:
            # Same timestamp already used: pick another name.
            if e.errno != errno.EEXIST:
                raise
        else:
            break
    last = os.path.join(neo_dir, "last")
    try:
        os.remove(last)
    except OSError as e:
        if e.errno != errno.ENOENT:
            raise
    # Relative target: the link lives in neo_dir, next to temp_name.
    os.symlink(temp_name, last)
    os.environ['TEMP'] = temp_dir
    print('Using temp directory %r.' % temp_dir)
    return temp_dir

def setupMySQLdb(db_list, clear_databases=True):
    """Create (or reset) the MySQL databases named in db_list.

    When the private mysqld pool is enabled, the work is delegated to it.
    Otherwise, connect as DB_ADMIN (through DB_SOCKET if set) and, for each
    database:
      - if it exists and clear_databases is true, drop and recreate it;
      - if it exists and clear_databases is false, leave it alone;
      - if it does not exist, grant DB_USER all privileges on it and create
        it.

    Return a function formatting a database name into a
    'user:password@<name><DB_SOCKET>' string.

    The cursor and connection are closed even if a statement fails; the
    transaction is only committed on success.
    """
    if mysql_pool:
        return mysql_pool.setup(db_list, clear_databases)
    import MySQLdb
    from MySQLdb.constants.ER import BAD_DB_ERROR
    user = DB_USER
    password = ''
    kw = {'unix_socket': os.path.expanduser(DB_SOCKET)} if DB_SOCKET else {}
    conn = MySQLdb.connect(user=DB_ADMIN, passwd=DB_PASSWD, **kw)
    try:
        cursor = conn.cursor()
        try:
            for database in db_list:
                try:
                    conn.select_db(database)
                    if not clear_databases:
                        continue
                    cursor.execute('DROP DATABASE `%s`' % database)
                except MySQLdb.OperationalError as e:
                    code, _ = e.args
                    if code != BAD_DB_ERROR:
                        raise
                    # Unknown database: give the test user access to it.
                    cursor.execute('GRANT ALL ON `%s`.* TO "%s"@"localhost" IDENTIFIED'
                                   ' BY "%s"' % (database, user, password))
                cursor.execute('CREATE DATABASE `%s`' % database)
        finally:
            cursor.close()
        conn.commit()
    finally:
        conn.close()
    return '{}:{}@%s{}'.format(user, password, DB_SOCKET).__mod__

class MySQLPool(object):
    """Private pool of mysqld servers, one per database name.

    Each server lives in its own directory, <pool_dir>/<name>, holding its
    datadir, tmpdir, error log and UNIX socket.  pool_dir defaults to
    getTempDirectory().  All running servers are killed when the pool is
    deleted.
    """

    def __init__(self, pool_dir=None):
        self._args = {}         # name -> mysqld command line
        self._mysqld_dict = {}  # name -> running mysqld (subprocess.Popen)
        if not pool_dir:
            pool_dir = getTempDirectory()
        self._base = pool_dir + os.sep
        self._sock_template = os.path.join(pool_dir, '%s', 'mysql.sock')

    def __del__(self):
        self.kill(*self._mysqld_dict)

    def setup(self, db_list, clear_databases):
        """Make sure a mysqld runs for each name in db_list, then (re)create
        a 'neo' database in each of them (dropped first if clear_databases).

        Servers not running yet are installed with DB_INSTALL on first use,
        started, and waited for.  Return a function formatting a name into
        'root@neo<socket path of that name>'.
        """
        import MySQLdb
        start_list = sorted(set(db_list).difference(self._mysqld_dict))
        if start_list:
            self._install(start_list)
            self.start(*start_list)
            self._waitForSockets(start_list)
        for name in db_list:
            conn = MySQLdb.connect(unix_socket=self._sock_template % name,
                                   user='root')
            try:
                if clear_databases:
                    conn.query('DROP DATABASE IF EXISTS neo')
                conn.query('CREATE DATABASE IF NOT EXISTS neo')
            finally:
                conn.close()
        return ('root@neo' + self._sock_template).__mod__

    def _install(self, db_list):
        """Record the mysqld command line of each name, and run DB_INSTALL
        (in parallel) for those whose datadir does not exist yet."""
        installers = []
        with open(os.devnull, 'wb') as devnull:
            for name in db_list:
                base = self._base + name
                datadir = os.path.join(base, 'datadir')
                sock = self._sock_template % name
                tmpdir = os.path.join(base, 'tmp')
                args = [DB_INSTALL,
                    '--defaults-file=' + DB_MYCNF,
                    '--datadir=' + datadir,
                    '--socket=' + sock,
                    '--tmpdir=' + tmpdir,
                    '--log_error=' + os.path.join(base, 'error.log')]
                if os.path.exists(datadir):
                    # Already installed: only remove a possibly stale socket.
                    try:
                        os.remove(sock)
                    except OSError as e:
                        if e.errno != errno.ENOENT:
                            raise
                else:
                    os.makedirs(tmpdir)
                    installers.append(subprocess.Popen(args,
                        stdout=devnull, stderr=subprocess.STDOUT))
                # mysqld itself is started later with the same options.
                args[0] = DB_MYSQLD
                self._args[name] = args
        for installer in installers:
            status = installer.wait()
            if status:
                raise subprocess.CalledProcessError(status, DB_INSTALL)

    def _waitForSockets(self, db_list):
        """Block until the mysqld of each name has created its socket,
        failing if one of them exits before that."""
        for name in db_list:
            sock = self._sock_template % name
            process = self._mysqld_dict[name]
            while not os.path.exists(sock):
                sleep(1)
                status = process.poll()
                if status is not None:
                    raise subprocess.CalledProcessError(status, DB_MYSQLD)

    def start(self, *db, **kw):
        """Spawn the mysqld of the given names, none of which may be running
        already; kw is passed through to subprocess.Popen."""
        assert set(db).isdisjoint(self._mysqld_dict)
        for name in db:
            self._mysqld_dict[name] = subprocess.Popen(self._args[name], **kw)

    def kill(self, *db):
        """Kill the mysqld of the given names and wait for them to exit.

        A server that already exited (e.g. reaped by poll() in setup) is not
        signalled again: with Python 2, os.kill on a reaped pid raises ESRCH
        or could hit an unrelated process.
        """
        processes = []
        for name in db:
            process = self._mysqld_dict.pop(name)
            processes.append(process)
            if process.poll() is None:
                process.kill()
        for process in processes:
            process.wait()

# Private mysqld pool, only used when a my.cnf is given via NEO_DB_MYCNF.
mysql_pool = None
if DB_MYCNF:
    mysql_pool = MySQLPool()


def ImporterConfigParser(adapter, zodb, **kw):
    """Return a SafeConfigParser for the Importer backend.

    The 'neo' section gets the 'adapter' option first, then one option per
    keyword argument.  `zodb` is an iterable of (section_name, options)
    pairs, where options is a dict; each pair becomes a section of its own.
    """
    cfg = SafeConfigParser()
    cfg.add_section("neo")
    neo_options = [("adapter", adapter)]
    neo_options.extend(kw.items())
    for option, value in neo_options:
        cfg.set("neo", option, value)
    for section, section_options in zodb:
        cfg.add_section(section)
        for option, value in section_options.items():
            cfg.set(section, option, value)
    return cfg


class NeoTestBase(unittest.TestCase):
    """Common base of NEO test cases: per-test-case log file, teardown
    checks and Mock-aware equality assertions."""

    maxDiff = None

    def setUp(self):
        # setupLog() sets logging.name to the test method while the log is
        # being set up, then returns None, so logging.name ends up cleared.
        # NOTE(review): presumably intentional; confirm before changing.
        logging.name = self.setupLog()
        unittest.TestCase.setUp(self)

    def setupLog(self):
        """Log to <temp dir>/<module.Class>.log for this test."""
        test_case, method_name = self.id().rsplit('.', 1)
        logging.name = method_name
        logging.setup(os.path.join(getTempDirectory(), test_case + '.log'))

    def tearDown(self):
        # Subclasses customize _tearDown, never tearDown itself.
        assert self.tearDown.im_func is NeoTestBase.tearDown.im_func
        # The 'success' flag is a local of the calling frame
        # (presumably unittest.TestCase.run).
        caller_locals = sys._getframe(1).f_locals
        self._tearDown(caller_locals['success'])
        assert not gc.garbage, gc.garbage

    def _tearDown(self, success):
        # Kill all unfinished transactions for next test.
        # Note we don't even abort them because it may require a valid
        # connection to a master node (see Storage.sync()).
        transaction.manager.__init__()
        if logging._max_size is None:
            return
        logging.flush()

    class failureException(AssertionError):
        # Every assertion failure is also written to the NEO log.
        def __init__(self, msg=None):
            logging.error(msg)
            AssertionError.__init__(self, msg)

    # Deprecated aliases are disabled so that equality assertions always go
    # through the Mock-aware methods below.
    failIfEqual = failUnlessEqual = assertEquals = assertNotEquals = None

    def assertNotEqual(self, first, second, msg=None):
        assert not any(isinstance(x, Mock) for x in (first, second)), \
          "Mock objects can't be compared with '==' or '!='"
        return super(NeoTestBase, self).assertNotEqual(first, second, msg=msg)

    def assertEqual(self, first, second, msg=None):
        assert not any(isinstance(x, Mock) for x in (first, second)), \
          "Mock objects can't be compared with '==' or '!='"
        return super(NeoTestBase, self).assertEqual(first, second, msg=msg)

    def assertPartitionTable(self, pt, expected, key=None):
        """Compare the formatted rows of `pt` (sorted by `key`) with
        `expected`, given either as one '|'-separated string or as a
        sequence of rows."""
        if not isinstance(expected, str):
            expected = '|'.join(expected)
        rows = pt._formatRows(sorted(pt.count_dict, key=key))
        self.assertEqual(expected, '|'.join(rows))

    @contextmanager
    def expectedFailure(self, exception=AssertionError, regex=None):
        """Context manager marking the enclosed code as an expected failure.

        If the code raises `exception` (with a message matching `regex` when
        given), _ExpectedFailure is raised.  If it completes normally,
        _UnexpectedSuccess is raised, unless `exception` itself covers
        _UnexpectedSuccess and regex is None.
        """
        with self.assertRaisesRegexp(exception, regex) as cm:
            yield
            raise _UnexpectedSuccess
        # XXX: passing sys.exc_info() causes deadlocks
        raise _ExpectedFailure((type(cm.exception), None, None))

class NeoUnitTestBase(NeoTestBase):
    """ Base class for neo unit tests, implements common checks """

    # Loopback address of the configured ADDRESS_TYPE.
    local_ip = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE]

    def setUp(self):
        # Per node type counters used by getNewUUID.
        self.uuid_dict = {}
        NeoTestBase.setUp(self)

    @cached_property
    def nm(self):
        """NodeManager used by the helpers of this test, created lazily."""
        from neo.lib import node
        return node.NodeManager()

    def createStorage(self, *args):
        """Create a storage node through self.nm.createStorage.

        Positional arguments are passed, in this order, as its 'address',
        'uuid' and 'state' keyword arguments.
        """
        return self.nm.createStorage(**dict(zip(
            ('address', 'uuid', 'state'), args)))

    def prepareDatabase(self, number, prefix=DB_PREFIX):
        """Create `number` empty databases, named prefix + index.

        The backend comes from the NEO_TESTS_ADAPTER environment variable
        ('MySQL' by default, or 'SQLite').  Afterwards, self.db_template(i)
        returns the 'database' setting of the i-th one.
        """
        adapter = os.getenv('NEO_TESTS_ADAPTER', 'MySQL')
        if adapter == 'MySQL':
            db_template = setupMySQLdb(
                [prefix + str(i) for i in xrange(number)])
            self.db_template = lambda i: db_template(prefix + str(i))
        elif adapter == 'SQLite':
            self.db_template = os.path.join(getTempDirectory(),
                                            prefix + '%s.sqlite').__mod__
            for i in xrange(number):
                try:
                    os.remove(self.db_template(i))
                except OSError as e:
                    if e.errno != errno.ENOENT:
                        raise
        else:
            assert False, adapter

    def getMasterConfiguration(self, cluster='main', master_number=2,
            replicas=2, partitions=1009, uuid=None):
        """Return the configuration of a master listening on the first of
        `master_number` (1 to 10) local addresses, on ports 10010 and up."""
        assert 1 <= master_number <= 10
        masters = [(self.local_ip, 10010 + i) for i in xrange(master_number)]
        return {
                'cluster': cluster,
                'bind': masters[0],
                'masters': masters,
                'replicas': replicas,
                'partitions': partitions,
                'uuid': uuid,
        }

    def getStorageConfiguration(self, cluster='main', master_number=2,
            index=0, prefix=DB_PREFIX, uuid=None):
        """Return the configuration of the index-th storage node.

        It listens on port 10020 + index of the local address and uses the
        database self.db_template(index) (see prepareDatabase).  'prefix' is
        unused and only kept for compatibility.
        """
        assert 1 <= master_number <= 10
        masters = [(buildUrlFromString(self.local_ip), 10010 + i)
                   for i in xrange(master_number)]
        adapter = os.getenv('NEO_TESTS_ADAPTER', 'MySQL')
        return {
                'cluster': cluster,
                # A (host, port) pair, like the 'bind' of
                # getMasterConfiguration (masters[0] is itself a pair).
                'bind': (self.local_ip, 10020 + index),
                'masters': masters,
                'database': self.db_template(index),
                'uuid': uuid,
                'adapter': adapter,
                'wait': 0,
        }

    def getNewUUID(self, node_type=None):
        """Return a new integer UUID for a node of the given type.

        The result is (UUID_NAMESPACES[node_type] << 24) plus a per-type
        counter starting at 1.  A random node type is picked when node_type
        is None.
        """
        if node_type is None:
            node_type = random.choice(NodeTypes)
        uuid = self.uuid_dict[node_type] = self.uuid_dict.get(node_type, 0) + 1
        return uuid + (UUID_NAMESPACES[node_type] << 24)

    def getClientUUID(self):
        return self.getNewUUID(NodeTypes.CLIENT)

    def getMasterUUID(self):
        return self.getNewUUID(NodeTypes.MASTER)

    def getStorageUUID(self):
        return self.getNewUUID(NodeTypes.STORAGE)

    def getAdminUUID(self):
        return self.getNewUUID(NodeTypes.ADMIN)

    def getNextTID(self, ltid=None):
        return newTid(ltid)

    def getFakeConnector(self, descriptor=None):
        return Mock({
            '__repr__': 'FakeConnector',
            'getDescriptor': descriptor,
            'getAddress': ('', 0),
        })

    def getFakeConnection(self, uuid=None, address=('127.0.0.1', 10000),
            is_server=False, connector=None, peer_id=None):
        if connector is None:
            connector = self.getFakeConnector()
        conn = Mock({
            'getUUID': uuid,
            'getAddress': address,
            'isServer': is_server,
            '__repr__': 'FakeConnection',
            '__nonzero__': 0,
            'getConnector': connector,
            'getPeerId': peer_id,
        })
        conn.mockAddReturnValues(__hash__ = id(conn))
        conn.connecting = False
        return conn

    def checkProtocolErrorRaised(self, method, *args, **kwargs):
        """ Check if the ProtocolError exception was raised """
        self.assertRaises(protocol.ProtocolError, method, *args, **kwargs)

    def checkAborted(self, conn):
        """ Ensure the connection was aborted """
        self.assertEqual(len(conn.mockGetNamedCalls('abort')), 1)

    def checkClosed(self, conn):
        """ Ensure the connection was closed """
        self.assertEqual(len(conn.mockGetNamedCalls('close')), 1)

    def _checkNoPacketSend(self, conn, method_id):
        self.assertEqual([], conn.mockGetNamedCalls(method_id))

    def checkNoPacketSent(self, conn):
        """ Check that no packet was sent, answered or asked """
        for method_id in ('send', 'answer', 'ask'):
            self._checkNoPacketSend(conn, method_id)

    # in check(Ask|Answer|Notify)Packet we return the packet so it can be used
    # in tests if more accurate checks are required

    def _checkPacketType(self, packet, packet_type):
        """Assert that `packet` is a protocol.Packet whose type is exactly
        `packet_type`, and return it."""
        self.assertIsInstance(packet, protocol.Packet)
        self.assertEqual(type(packet), packet_type)
        return packet

    def _checkSinglePacket(self, conn, method_id, packet_type):
        """Assert that conn.<method_id> was called exactly once, with a
        packet of type `packet_type` as first argument, and return it."""
        calls = conn.mockGetNamedCalls(method_id)
        self.assertEqual(len(calls), 1)
        return self._checkPacketType(calls.pop().getParam(0), packet_type)

    def checkErrorPacket(self, conn):
        """ Check if an error packet was answered """
        return self._checkSinglePacket(conn, 'answer', Packets.Error)

    def checkAskPacket(self, conn, packet_type):
        """ Check if an ask-packet with the right type is sent """
        return self._checkSinglePacket(conn, 'ask', packet_type)

    def checkAnswerPacket(self, conn, packet_type):
        """ Check if an answer-packet with the right type is sent """
        return self._checkSinglePacket(conn, 'answer', packet_type)

    def checkNotifyPacket(self, conn, packet_type, packet_number=0):
        """ Check if a notify-packet with the right type is sent """
        calls = conn.mockGetNamedCalls('send')
        return self._checkPacketType(calls.pop(packet_number).getParam(0),
                                     packet_type)


class TransactionalResource(object):
    """Data manager that joins the transaction returned by txn.get().

    Every IDataManager method that is not explicitly provided is a no-op
    returning None.  Methods are provided either as keyword arguments of the
    constructor, or by using the instance as a decorator: the decorated
    function is installed under its own name.
    """

    class _sortKey(object):
        # Compares greater (last=True) or smaller (last=False) than any
        # object of another type; two _sortKey objects must not be compared.

        def __init__(self, last):
            self._last = last

        def __cmp__(self, other):
            assert type(self) is not type(other), other
            if self._last:
                return 1
            return -1

    def __init__(self, txn, last, **kw):
        # sortKey() places this resource after (last=True) or before all
        # others when the transaction sorts its resources (presumably to
        # control two-phase commit ordering).
        self.sortKey = lambda: self._sortKey(last)
        for method_name in kw:
            assert callable(IDataManager.get(method_name)), method_name
        self.__dict__.update(kw)
        txn.get().join(self)

    def __call__(self, func):
        method_name = func.__name__
        assert callable(IDataManager.get(method_name)), method_name
        setattr(self, method_name, func)
        return func

    def __getattr__(self, attr):
        # Only reached when normal lookup failed.
        if not callable(IDataManager.get(attr)):
            # Not a data manager method: raise the usual AttributeError.
            return self.__getattribute__(attr)
        return lambda *_: None


def getTransactionMetaData(txn, conn):
    """Return the metadata object associated with `txn` for `conn`.

    With ZODB >= 5, this is what txn.data(conn) returns.
    """
    return txn.data(conn)

try:
    from ZODB.Connection import TransactionMetaData
except ImportError: # BBB: ZODB < 5
    def getTransactionMetaData(txn, conn):
        """With ZODB < 5, the transaction object holds its own metadata."""
        return txn


class Patch(object):
    """
    Patch attributes and revert later automatically.

    Usage:

      with Patch(someObject, [new,] attrToPatch=newValue) as patch:
        [... code that runs with patches ...]
      [... code that runs without patch ...]

      The 'new' positional parameter defaults to False and it must be equal to
         not hasattr(someObject, 'attrToPatch')
      It is an assertion to detect when a Patch is obsolete.

      ' as patch' is optional: 'patch.revert()' can be used to revert patches
      in the middle of the 'with' clause.

    Or:

      patch = Patch(...)
      patch.apply()

      In this case, patches are automatically reverted when 'patch' is deleted.

    When the attribute already exists ('new' is false) and the new value is
    callable, the new value is called with the original value as first
    argument, followed by the arguments of the call.

    Alternative usage:

      @Patch(someObject)
      def funcToPatch(orig, ...):
        ...
      ...
      funcToPatch.revert()

      The decorator applies the patch immediately.
    """

    applied = False

    def __new__(cls, patched, *args, **patch):
        if patch:
            return object.__new__(cls)
        # Decorator form: the attribute name is taken from the function.
        def patch(func):
            self = cls(patched, *args, **{func.__name__: func})
            self.apply()
            return self
        return patch

    def __init__(self, patched, *args, **patch):
        new, = args or (0,)
        (name, value), = patch.items()
        self._patched = patched
        self._name = name
        try:
            wrapped = getattr(patched, name)
        except AttributeError:
            assert new, (patched, name)
        else:
            assert not new, (patched, name)
            if callable(value):
                value = self._wrap(value, wrapped)
        self._patch = value
        try:
            orig = patched.__dict__[name]
        except KeyError:
            if new or isclass(patched):
                # Nothing of our own to restore: drop the patched attribute.
                self._revert = lambda: delattr(patched, name)
                return
            orig = getattr(patched, name)
        self._revert = lambda: setattr(patched, name, orig)

    @staticmethod
    def _wrap(func, wrapped):
        """Return a callable invoking func(wrapped, *args, **kw), carrying
        the metadata of `wrapped` when that one is callable."""
        wrapper = lambda *args, **kw: func(wrapped, *args, **kw)
        if callable(wrapped):
            wrapper = wraps(wrapped)(wrapper)
        return wrapper

    def apply(self):
        assert not self.applied
        setattr(self._patched, self._name, self._patch)
        self.applied = True

    def revert(self):
        del self.applied  # back to the class-level False
        self._revert()

    def __del__(self):
        if self.applied:
            self.revert()

    def __enter__(self):
        self.apply()
        return weakref.proxy(self)

    def __exit__(self, t, v, tb):
        self.__del__()


def unpickle_state(data):
    """Return the object state contained in the pickled record `data`.

    `data` holds two consecutive pickles: the class tuple, which is skipped,
    then the state, which is returned.  Persistent references are handed to
    ZODB's PersistentReferenceFactory.persistent_load.
    NOTE(review): this unpickles arbitrary data; only feed it trusted test
    records.
    """
    reader = Unpickler(StringIO(data))
    reader.persistent_load = PersistentReferenceFactory().persistent_load
    reader.load()  # skip the class tuple
    return reader.load()


def _pdb(depth=0):
    """Break into the debugger returned by debug.getPdb(), in the frame of
    the caller (or `depth` frames further up).

    Installed as the builtin 'pdb' so that it is reachable without import.
    """
    return debug.getPdb().set_trace(sys._getframe(depth + 1))

__builtin__.pdb = _pdb
del _pdb