__init__.py 20.6 KB
Newer Older
1
#
Grégory Wisniewski's avatar
Grégory Wisniewski committed
2
# Copyright (C) 2009-2010  Nexedi SA
3
#
4 5 6 7
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
8
#
9 10 11 12 13 14 15
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
16
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
17

18
import __builtin__
19
import os
20
import random
21 22
import socket
import sys
23
import tempfile
24
import unittest
25
import MySQLdb
26
import neo
27
import transaction
Olivier Cros's avatar
Olivier Cros committed
28

29
from mock import Mock
30
from neo.lib import debug, logger, protocol, setupLog
31
from neo.lib.protocol import Packets
Olivier Cros's avatar
Olivier Cros committed
32
from neo.lib.util import getAddressType
33
from time import time, gmtime
34
from struct import pack, unpack
35

36
# Database settings used by the tests; each one may be overridden through
# the corresponding NEO_DB_* environment variable.
DB_PREFIX = os.environ.get('NEO_DB_PREFIX', 'test_neo')
DB_ADMIN = os.environ.get('NEO_DB_ADMIN', 'root')
DB_PASSWD = os.environ.get('NEO_DB_PASSWD', '')
DB_USER = os.environ.get('NEO_DB_USER', 'test')

# Loopback address to use for each supported address family.
IP_VERSION_FORMAT_DICT = {}
IP_VERSION_FORMAT_DICT[socket.AF_INET] = '127.0.0.1'
IP_VERSION_FORMAT_DICT[socket.AF_INET6] = '::1'

# Address family the tests run with; used as a key into
# IP_VERSION_FORMAT_DICT to pick the local address.
ADDRESS_TYPE = socket.AF_INET

48 49
def _registerDebug():
    """Call debug.register() while debug.ENABLED is True, then set
    debug.ENABLED back to False.

    NOTE(review): presumably register() only installs its hooks when
    ENABLED is true -- TODO confirm against neo.lib.debug.
    """
    debug.ENABLED = True
    debug.register()
    # prevent "signal only works in main thread" errors in subprocesses
    debug.ENABLED = False

_registerDebug()
del _registerDebug
52

53 54 55 56 57 58 59 60
def mockDefaultValue(name, function):
    """Install a default implementation for Mock.<name>.

    The installed method calls function(self, *args, **kw), unless `name` is
    listed in the instance's mockReturnValues. In that case it defers to the
    mock's own attribute lookup (self.__getattr__(name)), so an explicitly
    mocked value wins over the default.
    """
    def _default_impl(self, *args, **kw):
        if name not in self.mockReturnValues:
            return function(self, *args, **kw)
        return self.__getattr__(name)(*args, **kw)
    # Give the method the attribute name it is installed under.
    _default_impl.__name__ = name
    setattr(Mock, name, _default_impl)

61
# Default truth value, repr and str for Mock instances (each can still be
# overridden by listing the name in the mock's return values).
for _name, _default in (
        ('__nonzero__', lambda self: self.__len__() != 0),
        ('__repr__', lambda self:
            '<%s object at 0x%x>' % (self.__class__.__name__, id(self))),
        ('__str__', repr),
        ):
    mockDefaultValue(_name, _default)
del _name, _default

Olivier Cros's avatar
Olivier Cros committed
66 67 68 69
def buildUrlFromString(address):
    """Return `address` in a form usable inside a URL / host:port string.

    IPv6 literals (anything socket.inet_pton accepts as AF_INET6) are wrapped
    in square brackets; any other value is returned unchanged.
    """
    try:
        socket.inet_pton(socket.AF_INET6, address)
    except Exception:
        # Not an IPv6 literal: nothing to escape.
        return address
    return '[%s]' % address

74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
def getTempDirectory():
    """Return the directory to use for test temporary files.

    If the TEMP environment variable is set, its value is returned as is.
    Otherwise a new directory named after the current time is created under
    <system temp dir>/neo_tests. It is stored into os.environ['TEMP'], so
    later calls return the same directory, and it is announced on stdout.

    Raises OSError if the directory cannot be created for a reason other
    than the name already existing.
    """
    # errno was used below without ever being imported (NameError on any
    # makedirs failure); import it locally like other helpers in this file.
    import errno
    try:
        return os.environ['TEMP']
    except KeyError:
        pass
    neo_dir = os.path.join(tempfile.gettempdir(), 'neo_tests')
    while True:
        temp_dir = os.path.join(neo_dir, repr(time()))
        try:
            os.makedirs(temp_dir)
            break
        except OSError as e:
            # The timestamp-based name may already exist: retry with a
            # fresh timestamp. Any other error is fatal.
            if e.errno != errno.EEXIST:
                raise
    os.environ['TEMP'] = temp_dir
    print('Using temp directory %r.' % temp_dir)
    return temp_dir

92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
def setupMySQLdb(db_list, user=DB_USER, password='', clear_databases=True):
    """Make sure every database named in `db_list` exists for the tests.

    Connects as DB_ADMIN (password DB_PASSWD). For each database name:
    - if it exists and `clear_databases` is false, it is left untouched;
    - if it exists and `clear_databases` is true, it is dropped and created
      again (empty);
    - if it does not exist, all privileges on it are granted to
      `user`@localhost identified by `password`, then it is created.
    MySQL errors other than "unknown database" are re-raised. The cursor and
    connection are closed even when an error occurs.
    """
    from MySQLdb.constants.ER import BAD_DB_ERROR
    conn = MySQLdb.Connect(user=DB_ADMIN, passwd=DB_PASSWD)
    try:
        cursor = conn.cursor()
        try:
            for database in db_list:
                try:
                    conn.select_db(database)
                    if not clear_databases:
                        continue
                    cursor.execute('DROP DATABASE `%s`' % database)
                except MySQLdb.OperationalError as e:
                    # First argument of a MySQL error is its numeric code.
                    if e.args[0] != BAD_DB_ERROR:
                        raise
                    cursor.execute('GRANT ALL ON `%s`.* TO "%s"@"localhost" IDENTIFIED'
                                   ' BY "%s"' % (database, user, password))
                cursor.execute('CREATE DATABASE `%s`' % database)
        finally:
            cursor.close()
        conn.commit()
    finally:
        conn.close()

112
class NeoTestBase(unittest.TestCase):
    """Common base class of NEO test cases.

    Prints the id of each running test, sends its log to a per-test-case
    file, resets the global transaction manager after each test, and refuses
    Mock operands in equality assertions.
    """

    def setUp(self):
        logger.PACKET_LOGGER.enable(True)
        out = sys.stdout
        out.write(' * %s ' % (self.id(), ))
        out.flush()
        self.setupLog()
        unittest.TestCase.setUp(self)

    def setupLog(self):
        """Log the current test method into <temp dir>/<test case id>.log."""
        case_id, method_name = self.id().rsplit('.', 1)
        path = os.path.join(getTempDirectory(), case_id + '.log')
        setupLog(method_name, path, True)

    def tearDown(self):
        # Drop every unfinished transaction so the next test starts clean.
        # They are deliberately not aborted: aborting may require a valid
        # connection to a master node (see Storage.sync()).
        transaction.manager.__init__()
        unittest.TestCase.tearDown(self)
        out = sys.stdout
        out.write('\n')
        out.flush()

    # These alias names are set to None so that calling them fails; tests
    # must use assertEqual / assertNotEqual, which carry the Mock guard.
    failIfEqual = failUnlessEqual = assertEquals = assertNotEquals = None

    def assertNotEqual(self, first, second, msg=None):
        # Comparing Mock objects with '==' / '!=' is refused outright.
        assert not any(isinstance(x, Mock) for x in (first, second)), \
          "Mock objects can't be compared with '==' or '!='"
        return super(NeoTestBase, self).assertNotEqual(first, second, msg=msg)

    def assertEqual(self, first, second, msg=None):
        # Comparing Mock objects with '==' / '!=' is refused outright.
        assert not any(isinstance(x, Mock) for x in (first, second)), \
          "Mock objects can't be compared with '==' or '!='"
        return super(NeoTestBase, self).assertEqual(first, second, msg=msg)
145

146
class NeoUnitTestBase(NeoTestBase):
    """ Base class for neo tests, implements common checks """

    # Loopback address for the configured address family (ADDRESS_TYPE).
    local_ip = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE]

    def prepareDatabase(self, number, prefix='test_neo'):
        """ Create `number` empty databases named <prefix>0, <prefix>1, ... """
        setupMySQLdb(['%s%u' % (prefix, i) for i in xrange(number)])

    def getMasterConfiguration(self, cluster='main', master_number=2,
            replicas=2, partitions=1009, uuid=None):
        """ Return a Mock standing for a master node configuration.

        The master list holds `master_number` (1 to 10) addresses on
        local_ip, ports 10010 and up; the first one is the bind address.
        """
        assert 1 <= master_number <= 10
        masters = [(self.local_ip, 10010 + i)
                   for i in xrange(master_number)]
        return Mock({
                'getCluster': cluster,
                'getBind': masters[0],
                'getMasters': (masters, getAddressType((
                        self.local_ip, 0))),
                'getReplicas': replicas,
                'getPartitions': partitions,
                'getUUID': uuid,
        })

    def getStorageConfiguration(self, cluster='main', master_number=2,
            index=0, prefix=DB_PREFIX, uuid=None):
        """ Return a Mock standing for the configuration of storage number
        `index` (0 to 9), knowing `master_number` (1 to 10) masters on ports
        10010 and up, and using database <DB_USER>@<prefix><index>.
        """
        assert 1 <= master_number <= 10
        assert 0 <= index <= 9
        masters = [(buildUrlFromString(self.local_ip),
                     10010 + i) for i in xrange(master_number)]
        database = '%s@%s%s' % (DB_USER, prefix, index)
        return Mock({
                'getCluster': cluster,
                'getName': 'storage',
                # NOTE(review): masters[0] is itself an (ip, port) tuple, so
                # this bind address is nested; confirm whether
                # (self.local_ip, 10020 + index) was intended.
                'getBind': (masters[0], 10020 + index),
                'getMasters': (masters, getAddressType((
                        self.local_ip, 0))),
                'getDatabase': database,
                'getUUID': uuid,
                'getReset': False,
                'getAdapter': 'MySQL',
        })

    def _makeUUID(self, prefix):
        """
            Returns a 16-bytes UUID according to namespace 'prefix'
        """
        assert len(prefix) == 1
        uuid = protocol.INVALID_UUID
        # Draw again in the (unlikely) case the random part equals the one
        # of INVALID_UUID.
        while uuid[1:] == protocol.INVALID_UUID[1:]:
            uuid = prefix + os.urandom(15)
        return uuid

    def getNewUUID(self):
        return self._makeUUID('\0')

    def getClientUUID(self):
        return self._makeUUID('C')

    def getMasterUUID(self):
        return self._makeUUID('M')

    def getStorageUUID(self):
        return self._makeUUID('S')

    def getAdminUUID(self):
        return self._makeUUID('A')

    def getNextTID(self, ltid=None):
        """ Return an 8-bytes TID built from the current UTC time.

        The upper 32 bits count minutes since 1900 (12 months of 31 days per
        year), the lower 32 bits are the elapsed part of the current minute
        scaled to 2**32. If `ltid` is given and that TID is not greater than
        it, `ltid` with its lower part incremented is returned instead.
        """
        tm = time()
        gmt = gmtime(tm)
        upper = ((((gmt.tm_year - 1900) * 12 + gmt.tm_mon - 1) * 31 \
                  + gmt.tm_mday - 1) * 24 + gmt.tm_hour) * 60 + gmt.tm_min
        lower = int((gmt.tm_sec % 60 + (tm - int(tm))) / (60.0 / 65536.0 / 65536.0))
        tid = pack('!LL', upper, lower)
        if ltid is not None and tid <= ltid:
            # The clock did not move past ltid: derive the TID from ltid
            # itself (this used to read an undefined self._last_tid).
            upper, lower = unpack('!LL', ltid)
            if lower == 0xffffffff:
                # This should not happen usually.
                # NOTE(review): the next minute is computed from the current
                # time, not from ltid, so the result is only above ltid when
                # ltid lies within the current minute.
                from datetime import timedelta, datetime
                d = datetime(gmt.tm_year, gmt.tm_mon, gmt.tm_mday,
                             gmt.tm_hour, gmt.tm_min) \
                        + timedelta(0, 60)
                upper = ((((d.year - 1900) * 12 + d.month - 1) * 31 \
                          + d.day - 1) * 24 + d.hour) * 60 + d.minute
                lower = 0
            else:
                lower += 1
            tid = pack('!LL', upper, lower)
        return tid

    def getPTID(self, i=None):
        """ Return an integer PTID: `i` if given, else a random one """
        if i is None:
            # randint includes its upper bound: stay within 64 unsigned bits.
            return random.randint(1, 2**64 - 1)
        return i

    def getOID(self, i=None):
        """ Return a 8-bytes OID: packed `i` if given, else random bytes """
        if i is None:
            return os.urandom(8)
        return pack('!Q', i)

    def getTwoIDs(self):
        """ Return a tuple of two sorted UUIDs """
        # generate two UUIDs, first is lower
        uuids = self.getNewUUID(), self.getNewUUID()
        return min(uuids), max(uuids)

    def getFakeConnector(self, descriptor=None):
        return Mock({
            '__repr__': 'FakeConnector',
            'getDescriptor': descriptor,
            'getAddress': ('', 0),
        })

    def getFakeConnection(self, uuid=None, address=('127.0.0.1', 10000),
            is_server=False, connector=None, peer_id=None):
        if connector is None:
            connector = self.getFakeConnector()
        conn = Mock({
            'getUUID': uuid,
            'getAddress': address,
            'isServer': is_server,
            '__repr__': 'FakeConnection',
            '__nonzero__': 0,
            'getConnector': connector,
            'getPeerId': peer_id,
        })
        conn.connecting = False
        return conn

    def checkProtocolErrorRaised(self, method, *args, **kwargs):
        """ Check if the ProtocolError exception was raised """
        self.assertRaises(protocol.ProtocolError, method, *args, **kwargs)

    def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
        """ Check if the UnexpectedPacketError exception was raised """
        self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)

    def checkIdenficationRequired(self, method, *args, **kwargs):
        """ Check if the identification_required decorator is applied """
        self.checkUnexpectedPacketRaised(method, *args, **kwargs)

    def checkBrokenNodeDisallowedErrorRaised(self, method, *args, **kwargs):
        """ Check if the BrokenNodeDisallowedError exception was raised """
        self.assertRaises(protocol.BrokenNodeDisallowedError, method, *args, **kwargs)

    def checkNotReadyErrorRaised(self, method, *args, **kwargs):
        """ Check if the NotReadyError exception was raised """
        self.assertRaises(protocol.NotReadyError, method, *args, **kwargs)

    def checkAborted(self, conn):
        """ Ensure the connection was aborted """
        self.assertEqual(len(conn.mockGetNamedCalls('abort')), 1)

    def checkNotAborted(self, conn):
        """ Ensure the connection was not aborted """
        self.assertEqual(len(conn.mockGetNamedCalls('abort')), 0)

    def checkClosed(self, conn):
        """ Ensure the connection was closed """
        self.assertEqual(len(conn.mockGetNamedCalls('close')), 1)

    def checkNotClosed(self, conn):
        """ Ensure the connection was not closed """
        self.assertEqual(len(conn.mockGetNamedCalls('close')), 0)

    def _checkNoPacketSend(self, conn, method_id):
        call_list = conn.mockGetNamedCalls(method_id)
        self.assertEqual(len(call_list), 0, call_list)

    def checkNoPacketSent(self, conn, check_notify=True, check_answer=True,
            check_ask=True):
        """ Check that no packet was sent (for each enabled send method) """
        if check_notify:
            self._checkNoPacketSend(conn, 'notify')
        if check_answer:
            self._checkNoPacketSend(conn, 'answer')
        if check_ask:
            self._checkNoPacketSend(conn, 'ask')

    def checkNoUUIDSet(self, conn):
        """ ensure no UUID was set on the connection """
        self.assertEqual(len(conn.mockGetNamedCalls('setUUID')), 0)

    def checkUUIDSet(self, conn, uuid=None):
        """ ensure a UUID was set exactly once on the connection, and that it
        is `uuid` when given """
        calls = conn.mockGetNamedCalls('setUUID')
        self.assertEqual(len(calls), 1)
        call = calls.pop()
        if uuid is not None:
            self.assertEqual(call.getParam(0), uuid)

    # in check(Ask|Answer|Notify)Packet we return the packet (or its decoded
    # content when decode=True) so it can be used in tests if more accurate
    # checks are required

    def checkErrorPacket(self, conn, decode=False):
        """ Check if an error packet was answered """
        calls = conn.mockGetNamedCalls("answer")
        self.assertEqual(len(calls), 1)
        packet = calls.pop().getParam(0)
        self.assertTrue(isinstance(packet, protocol.Packet))
        self.assertEqual(type(packet), Packets.Error)
        if decode:
            return packet.decode()
        return packet

    def checkAskPacket(self, conn, packet_type, decode=False):
        """ Check if an ask-packet with the right type is sent """
        calls = conn.mockGetNamedCalls('ask')
        self.assertEqual(len(calls), 1)
        packet = calls.pop().getParam(0)
        self.assertTrue(isinstance(packet, protocol.Packet))
        self.assertEqual(type(packet), packet_type)
        if decode:
            return packet.decode()
        return packet

    def checkAnswerPacket(self, conn, packet_type, decode=False):
        """ Check if an answer-packet with the right type is sent """
        calls = conn.mockGetNamedCalls('answer')
        self.assertEqual(len(calls), 1)
        packet = calls.pop().getParam(0)
        self.assertTrue(isinstance(packet, protocol.Packet))
        self.assertEqual(type(packet), packet_type)
        if decode:
            return packet.decode()
        return packet

    def checkNotifyPacket(self, conn, packet_type, packet_number=0, decode=False):
        """ Check if a notify-packet with the right type is sent
        (`packet_number` selects which recorded notify call is checked) """
        calls = conn.mockGetNamedCalls('notify')
        packet = calls.pop(packet_number).getParam(0)
        self.assertTrue(isinstance(packet, protocol.Packet))
        self.assertEqual(type(packet), packet_type)
        if decode:
            return packet.decode()
        return packet

    def checkNotify(self, conn, **kw):
        return self.checkNotifyPacket(conn, Packets.Notify, **kw)

    def checkNotifyNodeInformation(self, conn, **kw):
        return self.checkNotifyPacket(conn, Packets.NotifyNodeInformation, **kw)

    def checkSendPartitionTable(self, conn, **kw):
        return self.checkNotifyPacket(conn, Packets.SendPartitionTable, **kw)

    def checkStartOperation(self, conn, **kw):
        return self.checkNotifyPacket(conn, Packets.StartOperation, **kw)

    def checkInvalidateObjects(self, conn, **kw):
        return self.checkNotifyPacket(conn, Packets.InvalidateObjects, **kw)

    def checkAbortTransaction(self, conn, **kw):
        return self.checkNotifyPacket(conn, Packets.AbortTransaction, **kw)

    def checkNotifyLastOID(self, conn, **kw):
        return self.checkNotifyPacket(conn, Packets.NotifyLastOID, **kw)

    def checkAnswerTransactionFinished(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerTransactionFinished, **kw)

    def checkAnswerInformationLocked(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerInformationLocked, **kw)

    def checkAskLockInformation(self, conn, **kw):
        return self.checkAskPacket(conn, Packets.AskLockInformation, **kw)

    def checkNotifyUnlockInformation(self, conn, **kw):
        return self.checkNotifyPacket(conn, Packets.NotifyUnlockInformation, **kw)

    def checkNotifyTransactionFinished(self, conn, **kw):
        return self.checkNotifyPacket(conn, Packets.NotifyTransactionFinished, **kw)

    def checkRequestIdentification(self, conn, **kw):
        return self.checkAskPacket(conn, Packets.RequestIdentification, **kw)

    def checkAskPrimary(self, conn, **kw):
        # **kw was previously dropped, silently ignoring decode=True.
        return self.checkAskPacket(conn, Packets.AskPrimary, **kw)

    def checkAskUnfinishedTransactions(self, conn, **kw):
        # **kw was previously dropped, silently ignoring decode=True.
        return self.checkAskPacket(conn, Packets.AskUnfinishedTransactions, **kw)

    def checkAskTransactionInformation(self, conn, **kw):
        return self.checkAskPacket(conn, Packets.AskTransactionInformation, **kw)

    def checkAskObjectPresent(self, conn, **kw):
        return self.checkAskPacket(conn, Packets.AskObjectPresent, **kw)

    def checkAskObject(self, conn, **kw):
        return self.checkAskPacket(conn, Packets.AskObject, **kw)

    def checkAskStoreObject(self, conn, **kw):
        return self.checkAskPacket(conn, Packets.AskStoreObject, **kw)

    def checkAskStoreTransaction(self, conn, **kw):
        return self.checkAskPacket(conn, Packets.AskStoreTransaction, **kw)

    def checkAskFinishTransaction(self, conn, **kw):
        return self.checkAskPacket(conn, Packets.AskFinishTransaction, **kw)

    def checkAskNewTid(self, conn, **kw):
        return self.checkAskPacket(conn, Packets.AskBeginTransaction, **kw)

    def checkAskLastIDs(self, conn, **kw):
        return self.checkAskPacket(conn, Packets.AskLastIDs, **kw)

    def checkAcceptIdentification(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AcceptIdentification, **kw)

    def checkAnswerPrimary(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerPrimary, **kw)

    def checkAnswerLastIDs(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerLastIDs, **kw)

    def checkAnswerUnfinishedTransactions(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerUnfinishedTransactions, **kw)

    def checkAnswerObject(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerObject, **kw)

    def checkAnswerTransactionInformation(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerTransactionInformation, **kw)

    def checkAnswerBeginTransaction(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerBeginTransaction, **kw)

    def checkAnswerTids(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerTIDs, **kw)

    def checkAnswerTidsFrom(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerTIDsFrom, **kw)

    def checkAnswerObjectHistory(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerObjectHistory, **kw)

    def checkAnswerObjectHistoryFrom(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerObjectHistoryFrom, **kw)

    def checkAnswerStoreTransaction(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerStoreTransaction, **kw)

    def checkAnswerStoreObject(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerStoreObject, **kw)

    def checkAnswerOids(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerOIDs, **kw)

    def checkAnswerPartitionTable(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerPartitionTable, **kw)

    def checkAnswerObjectPresent(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerObjectPresent, **kw)
503 504 505 506

# Next descriptor number handed out to a DoNothingConnector.
connector_cpt = 0

class DoNothingConnector(Mock):
    """Mock connector that only remembers the address it was given and owns
    a descriptor number, unique within the process, taken from
    connector_cpt."""

    def __init__(self, s=None):
        neo.lib.logging.info("initializing connector")
        global connector_cpt
        # Take the current counter value as our descriptor, then advance it.
        self.desc, connector_cpt = connector_cpt, connector_cpt + 1
        self.packet_cpt = 0
        Mock.__init__(self)

    def getAddress(self):
        # Only valid after makeClientConnection/makeListeningConnection.
        return self.addr

    def makeClientConnection(self, addr):
        self.addr = addr

    def makeListeningConnection(self, addr):
        self.addr = addr

    def getDescriptor(self):
        return self.desc
526

527

528 529
def _pdb(depth=0):
    """Start the debugger returned by debug.getPdb() in the frame of the
    caller of pdb(), or `depth` frames further up the stack."""
    return debug.getPdb().set_trace(sys._getframe(depth + 1))

# Installed as a builtin so that any code can call pdb() without importing.
__builtin__.pdb = _pdb
del _pdb
530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554

def _fixMockForInspect():
    """
    inspect module change broke Mock, see http://bugs.python.org/issue1785
    Monkey-patch Mock class if needed by replacing predicate parameter on 2nd
    getmembers call with isroutine (was ismethod).

    The need is detected by checking whether inspect.ismethod still reports
    a plain function defined on a class.
    """
    import inspect
    class A(object):
        def f(self):
            pass
    if inspect.getmembers(A, inspect.ismethod):
        # This inspect version still reports A.f: Mock works unpatched.
        return
    # MockCallable is used by the replacement below but is not imported at
    # module level; without this import the patch raises NameError.
    from mock import Mock, MockCallable

    # _setupSubclassMethodInterceptors is under the FreeBSD license.
    # See pyMock module for the whole license.
    def _setupSubclassMethodInterceptors(self):
        methods = inspect.getmembers(self.__class__, inspect.isroutine)
        baseMethods = dict(inspect.getmembers(Mock, inspect.isroutine))
        for name, _ in methods:
            # Don't record calls to methods of Mock base class.
            if name not in baseMethods:
                self.__dict__[name] = MockCallable(name, self, handcrafted=True)
    Mock._setupSubclassMethodInterceptors = _setupSubclassMethodInterceptors
_fixMockForInspect()