Commit 63a70d44 authored by Kirill Smelkov's avatar Kirill Smelkov

Merge remote-tracking branch 'origin/master' into t

* origin/master:
  master,client: ignore notifications before complete initialization
  Update comment that was still showing UUIDs instead of node ids
  Remove dead code found by coverage
  Remove some useless unit tests
parents 83c151e9 36b2d141
...@@ -220,6 +220,7 @@ class Application(ThreadedApplication): ...@@ -220,6 +220,7 @@ class Application(ThreadedApplication):
ask = self._ask ask = self._ask
handler = self.primary_bootstrap_handler handler = self.primary_bootstrap_handler
while 1: while 1:
self.ignore_invalidations = True
# Get network connection to primary master # Get network connection to primary master
while 1: while 1:
if self.primary_master_node is not None: if self.primary_master_node is not None:
......
...@@ -120,6 +120,7 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -120,6 +120,7 @@ class PrimaryNotificationsHandler(MTEventHandler):
db = app.getDB() db = app.getDB()
db is None or db.invalidateCache() db is None or db.invalidateCache()
app.last_tid = ltid app.last_tid = ltid
app.ignore_invalidations = False
def answerTransactionFinished(self, conn, _, tid, callback, cache_dict): def answerTransactionFinished(self, conn, _, tid, callback, cache_dict):
app = self.app app = self.app
...@@ -159,6 +160,8 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -159,6 +160,8 @@ class PrimaryNotificationsHandler(MTEventHandler):
def invalidateObjects(self, conn, tid, oid_list): def invalidateObjects(self, conn, tid, oid_list):
app = self.app app = self.app
if app.ignore_invalidations:
return
app.last_tid = tid app.last_tid = tid
app._cache_lock_acquire() app._cache_lock_acquire()
try: try:
......
...@@ -281,8 +281,8 @@ class NodeManager(object): ...@@ -281,8 +281,8 @@ class NodeManager(object):
self._address_dict.pop(node.getAddress(), None) self._address_dict.pop(node.getAddress(), None)
# - a master known by address but without UUID # - a master known by address but without UUID
self._uuid_dict.pop(node.getUUID(), None) self._uuid_dict.pop(node.getUUID(), None)
self.__dropSet(self._state_dict, node.getState(), node) self._state_dict[node.getState()].remove(node)
self.__dropSet(self._type_dict, node.getType(), node) self._type_dict[node.getType()].remove(node)
uuid = node.getUUID() uuid = node.getUUID()
if node.isMaster() and self._master_db is not None: if node.isMaster() and self._master_db is not None:
self._master_db.discard(node.getAddress()) self._master_db.discard(node.getAddress())
...@@ -305,10 +305,6 @@ class NodeManager(object): ...@@ -305,10 +305,6 @@ class NodeManager(object):
def _updateUUID(self, node, old_uuid): def _updateUUID(self, node, old_uuid):
self.__update(self._uuid_dict, old_uuid, node.getUUID(), node) self.__update(self._uuid_dict, old_uuid, node.getUUID(), node)
def __dropSet(self, set_dict, key, node):
if key in set_dict:
set_dict[key].remove(node)
def __updateSet(self, set_dict, old_key, new_key, node): def __updateSet(self, set_dict, old_key, new_key, node):
""" Update a set index from old to new key """ """ Update a set index from old to new key """
if old_key in set_dict: if old_key in set_dict:
......
...@@ -243,10 +243,10 @@ class PartitionTable(object): ...@@ -243,10 +243,10 @@ class PartitionTable(object):
"""Help debugging partition table management. """Help debugging partition table management.
Output sample: Output sample:
pt: node 0: 67ae354b4ed240a0594d042cf5c01b28, R pt: node 0: S1, R
pt: node 1: a68a01e8bf93e287bd505201c1405bc2, R pt: node 1: S2, R
pt: node 2: ad7ffe8ceef4468a0c776f3035c7a543, R pt: node 2: S3, R
pt: node 3: df57d7298678996705cd0092d84580f4, R pt: node 3: S4, R
pt: 00: .UU.|U..U|.UU.|U..U|.UU.|U..U|.UU.|U..U|.UU.|U..U|.UU. pt: 00: .UU.|U..U|.UU.|U..U|.UU.|U..U|.UU.|U..U|.UU.|U..U|.UU.
pt: 11: U..U|.UU.|U..U|.UU.|U..U|.UU.|U..U|.UU.|U..U|.UU.|U..U pt: 11: U..U|.UU.|U..U|.UU.|U..U|.UU.|U..U|.UU.|U..U|.UU.|U..U
......
...@@ -114,6 +114,7 @@ class BackupApplication(object): ...@@ -114,6 +114,7 @@ class BackupApplication(object):
del bootstrap, node del bootstrap, node
if num_partitions != pt.getPartitions(): if num_partitions != pt.getPartitions():
raise RuntimeError("inconsistent number of partitions") raise RuntimeError("inconsistent number of partitions")
self.ignore_invalidations = True
self.pt = PartitionTable(num_partitions, num_replicas) self.pt = PartitionTable(num_partitions, num_replicas)
conn.setHandler(BackupHandler(self)) conn.setHandler(BackupHandler(self))
conn.ask(Packets.AskPartitionTable()) conn.ask(Packets.AskPartitionTable())
......
...@@ -29,6 +29,7 @@ class BackupHandler(EventHandler): ...@@ -29,6 +29,7 @@ class BackupHandler(EventHandler):
self.app.pt.load(ptid, row_list, self.app.nm) self.app.pt.load(ptid, row_list, self.app.nm)
def notifyPartitionChanges(self, conn, ptid, cell_list): def notifyPartitionChanges(self, conn, ptid, cell_list):
if self.app.pt.filled():
self.app.pt.update(ptid, cell_list, self.app.nm) self.app.pt.update(ptid, cell_list, self.app.nm)
# NOTE invalidation from M -> Mb (all partitions) # NOTE invalidation from M -> Mb (all partitions)
...@@ -38,10 +39,13 @@ class BackupHandler(EventHandler): ...@@ -38,10 +39,13 @@ class BackupHandler(EventHandler):
app.invalidatePartitions(tid, set(xrange(app.pt.getPartitions()))) app.invalidatePartitions(tid, set(xrange(app.pt.getPartitions())))
else: # upstream DB is empty else: # upstream DB is empty
assert app.app.getLastTransaction() == tid assert app.app.getLastTransaction() == tid
app.ignore_invalidations = False
# NOTE invalidation from M -> Mb # NOTE invalidation from M -> Mb
def invalidateObjects(self, conn, tid, oid_list): def invalidateObjects(self, conn, tid, oid_list):
app = self.app app = self.app
if app.ignore_invalidations:
return
getPartition = app.app.pt.getPartition getPartition = app.app.pt.getPartition
partition_set = set(map(getPartition, oid_list)) partition_set = set(map(getPartition, oid_list))
partition_set.add(getPartition(tid)) partition_set.add(getPartition(tid))
......
...@@ -56,7 +56,6 @@ UNIT_TEST_MODULES = [ ...@@ -56,7 +56,6 @@ UNIT_TEST_MODULES = [
'neo.tests.master.testTransactions', 'neo.tests.master.testTransactions',
# storage application # storage application
'neo.tests.storage.testClientHandler', 'neo.tests.storage.testClientHandler',
'neo.tests.storage.testInitializationHandler',
'neo.tests.storage.testMasterHandler', 'neo.tests.storage.testMasterHandler',
'neo.tests.storage.testStorageApp', 'neo.tests.storage.testStorageApp',
'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'), 'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
......
This diff is collapsed.
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import time, unittest import time, unittest
from mock import Mock, ReturnValues from mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.client.app import ConnectionPool from neo.client.app import ConnectionPool
...@@ -24,27 +24,6 @@ from neo.client import pool ...@@ -24,27 +24,6 @@ from neo.client import pool
class ConnectionPoolTests(NeoUnitTestBase): class ConnectionPoolTests(NeoUnitTestBase):
def test_removeConnection(self):
app = None
pool = ConnectionPool(app)
test_node_uuid = self.getStorageUUID()
other_node_uuid = self.getStorageUUID()
test_node = Mock({'getUUID': test_node_uuid})
other_node = Mock({'getUUID': other_node_uuid})
# Test sanity check
self.assertEqual(getattr(pool, 'connection_dict', None), {})
# Call must not raise if node is not known
self.assertEqual(len(pool.connection_dict), 0)
pool.removeConnection(test_node)
# Test that removal with another uuid doesn't affect entry
pool.connection_dict[test_node_uuid] = None
self.assertEqual(len(pool.connection_dict), 1)
pool.removeConnection(other_node)
self.assertEqual(len(pool.connection_dict), 1)
# Test that removeConnection works
pool.removeConnection(test_node)
self.assertEqual(len(pool.connection_dict), 0)
# TODO: test getConnForNode (requires splitting complex functionalities) # TODO: test getConnForNode (requires splitting complex functionalities)
def test_CellSortKey(self): def test_CellSortKey(self):
...@@ -81,30 +60,6 @@ class ConnectionPoolTests(NeoUnitTestBase): ...@@ -81,30 +60,6 @@ class ConnectionPoolTests(NeoUnitTestBase):
pool = ConnectionPool(app) pool = ConnectionPool(app)
self.assertRaises(NEOStorageError, pool.iterateForObject(oid).next) self.assertRaises(NEOStorageError, pool.iterateForObject(oid).next)
def test_iterateForObject_connectionRefused(self):
# connection refused at the first try
oid = self.getOID(1)
node = Mock({'__repr__': 'node', 'isRunning': True})
cell = Mock({'__repr__': 'cell', 'getNode': node})
conn = Mock({'__repr__': 'conn'})
app = Mock()
app.pt = Mock({'getCellList': [cell]})
pool = ConnectionPool(app)
pool.getConnForNode = Mock({'__call__': ReturnValues(None, conn)})
self.assertEqual(list(pool.iterateForObject(oid)), [(node, conn)])
def test_iterateForObject_connectionAccepted(self):
# connection accepted
oid = self.getOID(1)
node = Mock({'__repr__': 'node', 'isRunning': True})
cell = Mock({'__repr__': 'cell', 'getNode': node})
conn = Mock({'__repr__': 'conn'})
app = Mock()
app.pt = Mock({'getCellList': [cell]})
pool = ConnectionPool(app)
pool.getConnForNode = Mock({'__call__': conn})
self.assertEqual(list(pool.iterateForObject(oid)), [(node, conn)])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -17,102 +17,16 @@ ...@@ -17,102 +17,16 @@
import unittest import unittest
from mock import Mock from mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.node import NodeManager from neo.client.handlers.master import PrimaryAnswersHandler
from neo.client.handlers.master import PrimaryNotificationsHandler, \
PrimaryAnswersHandler
from neo.client.exception import NEOStorageError from neo.client.exception import NEOStorageError
class MasterHandlerTests(NeoUnitTestBase): class MasterHandlerTests(NeoUnitTestBase):
def setUp(self):
super(MasterHandlerTests, self).setUp()
self.db = Mock()
self.app = Mock({'getDB': self.db,
'txn_contexts': ()})
self.app.nm = NodeManager()
self.app.dispatcher = Mock()
self._next_port = 3000
def getKnownMaster(self):
node = self.app.nm.createMaster(address=(
self.local_ip, self._next_port),
)
self._next_port += 1
conn = self.getFakeConnection(address=node.getAddress())
node.setConnection(conn)
return node, conn
class MasterNotificationsHandlerTests(MasterHandlerTests):
def setUp(self):
super(MasterNotificationsHandlerTests, self).setUp()
self.handler = PrimaryNotificationsHandler(self.app)
def test_connectionClosed(self):
conn = self.getFakeConnection()
node = Mock()
self.app.master_conn = conn
self.app.primary_master_node = node
self.handler.connectionClosed(conn)
self.assertEqual(self.app.master_conn, None)
self.assertEqual(self.app.primary_master_node, None)
def test_invalidateObjects(self):
conn = self.getFakeConnection()
tid = self.getNextTID()
oid1, oid2, oid3 = self.getOID(1), self.getOID(2), self.getOID(3)
self.app._cache = Mock({
'invalidate': None,
})
self.handler.invalidateObjects(conn, tid, [oid1, oid3])
cache_calls = self.app._cache.mockGetNamedCalls('invalidate')
self.assertEqual(len(cache_calls), 2)
cache_calls[0].checkArgs(oid1, tid)
cache_calls[1].checkArgs(oid3, tid)
invalidation_calls = self.db.mockGetNamedCalls('invalidate')
self.assertEqual(len(invalidation_calls), 1)
invalidation_calls[0].checkArgs(tid, [oid1, oid3])
def test_notifyPartitionChanges(self):
conn = self.getFakeConnection()
self.app.pt = Mock({'filled': True})
ptid = 0
cell_list = (Mock(), Mock())
self.handler.notifyPartitionChanges(conn, ptid, cell_list)
update_calls = self.app.pt.mockGetNamedCalls('update')
self.assertEqual(len(update_calls), 1)
update_calls[0].checkArgs(ptid, cell_list, self.app.nm)
class MasterAnswersHandlerTests(MasterHandlerTests):
def setUp(self): def setUp(self):
super(MasterAnswersHandlerTests, self).setUp() super(MasterHandlerTests, self).setUp()
self.app = Mock()
self.handler = PrimaryAnswersHandler(self.app) self.handler = PrimaryAnswersHandler(self.app)
def test_answerBeginTransaction(self):
tid = self.getNextTID()
conn = self.getFakeConnection()
self.handler.answerBeginTransaction(conn, tid)
calls = self.app.mockGetNamedCalls('setHandlerData')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid)
def test_answerNewOIDs(self):
conn = self.getFakeConnection()
oid1, oid2, oid3 = self.getOID(0), self.getOID(1), self.getOID(2)
self.handler.answerNewOIDs(conn, [oid1, oid2, oid3])
self.assertEqual(self.app.new_oid_list, [oid3, oid2, oid1])
def test_answerTransactionFinished(self):
conn = self.getFakeConnection()
ttid2 = self.getNextTID()
tid2 = self.getNextTID()
self.handler.answerTransactionFinished(conn, ttid2, tid2)
calls = self.app.mockGetNamedCalls('setHandlerData')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid2)
def test_answerPack(self): def test_answerPack(self):
self.assertRaises(NEOStorageError, self.handler.answerPack, None, False) self.assertRaises(NEOStorageError, self.handler.answerPack, None, False)
# Check it doesn't raise # Check it doesn't raise
......
...@@ -19,8 +19,6 @@ from mock import Mock ...@@ -19,8 +19,6 @@ from mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.client.handlers.storage import StorageAnswersHandler from neo.client.handlers.storage import StorageAnswersHandler
from neo.client.exception import NEOStorageError, NEOStorageNotFoundError from neo.client.exception import NEOStorageError, NEOStorageNotFoundError
from neo.client.exception import NEOStorageDoesNotExistError
from ZODB.TimeStamp import TimeStamp
class StorageAnswerHandlerTests(NeoUnitTestBase): class StorageAnswerHandlerTests(NeoUnitTestBase):
...@@ -29,20 +27,6 @@ class StorageAnswerHandlerTests(NeoUnitTestBase): ...@@ -29,20 +27,6 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
self.app = Mock() self.app = Mock()
self.handler = StorageAnswersHandler(self.app) self.handler = StorageAnswersHandler(self.app)
def _checkHandlerData(self, ref):
calls = self.app.mockGetNamedCalls('setHandlerData')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(ref)
def test_answerObject(self):
conn = self.getFakeConnection()
oid = self.getOID(0)
tid1 = self.getNextTID()
tid2 = self.getNextTID(tid1)
the_object = (oid, tid1, tid2, 0, '', 'DATA', None)
self.handler.answerObject(conn, *the_object)
self._checkHandlerData(the_object[1:])
def _getAnswerStoreObjectHandler(self, object_stored_counter_dict, def _getAnswerStoreObjectHandler(self, object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict): conflict_serial_dict, resolved_conflict_serial_dict):
app = Mock({ app = Mock({
...@@ -119,86 +103,11 @@ class StorageAnswerHandlerTests(NeoUnitTestBase): ...@@ -119,86 +103,11 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
self._getAnswerStoreObjectHandler({oid: {tid: 1}}, {}, self._getAnswerStoreObjectHandler({oid: {tid: 1}}, {},
{oid: {tid}}).answerStoreObject(conn, 1, oid, tid_2) {oid: {tid}}).answerStoreObject(conn, 1, oid, tid_2)
def test_answerStoreObject_4(self):
uuid = self.getStorageUUID()
conn = self.getFakeConnection(uuid=uuid)
oid = self.getOID(0)
tid = self.getNextTID()
# no conflict
object_stored_counter_dict = {oid: {}}
conflict_serial_dict = {}
resolved_conflict_serial_dict = {}
h = self._getAnswerStoreObjectHandler(object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict)
h.app.getHandlerData()['cache_dict'] = {oid: None}
h.answerStoreObject(conn, 0, oid, tid)
self.assertFalse(oid in conflict_serial_dict)
self.assertFalse(oid in resolved_conflict_serial_dict)
self.assertEqual(object_stored_counter_dict[oid], {tid: {uuid}})
def test_answerTransactionInformation(self):
conn = self.getFakeConnection()
tid = self.getNextTID()
user = 'USER'
desc = 'DESC'
ext = 'EXT'
packed = False
oid_list = [self.getOID(0), self.getOID(1)]
self.handler.answerTransactionInformation(conn, tid, user, desc, ext,
packed, oid_list)
self._checkHandlerData(({
'time': TimeStamp(tid).timeTime(),
'user_name': user,
'description': desc,
'id': tid,
'oids': oid_list,
'packed': packed,
}, ext))
def test_oidNotFound(self):
conn = self.getFakeConnection()
self.assertRaises(NEOStorageNotFoundError, self.handler.oidNotFound,
conn, 'message')
def test_oidDoesNotExist(self):
conn = self.getFakeConnection()
self.assertRaises(NEOStorageDoesNotExistError,
self.handler.oidDoesNotExist, conn, 'message')
def test_tidNotFound(self): def test_tidNotFound(self):
conn = self.getFakeConnection() conn = self.getFakeConnection()
self.assertRaises(NEOStorageNotFoundError, self.handler.tidNotFound, self.assertRaises(NEOStorageNotFoundError, self.handler.tidNotFound,
conn, 'message') conn, 'message')
def test_answerTIDs(self):
uuid = self.getStorageUUID()
tid1 = self.getNextTID()
tid2 = self.getNextTID(tid1)
tid_list = [tid1, tid2]
conn = self.getFakeConnection(uuid=uuid)
tid_set = set()
StorageAnswersHandler(Mock()).answerTIDs(conn, tid_list, tid_set)
self.assertEqual(tid_set, set(tid_list))
def test_answerObjectUndoSerial(self):
uuid = self.getStorageUUID()
conn = self.getFakeConnection(uuid=uuid)
oid1 = self.getOID(1)
oid2 = self.getOID(2)
tid0 = self.getNextTID()
tid1 = self.getNextTID()
tid2 = self.getNextTID()
tid3 = self.getNextTID()
undo_dict = {}
handler = StorageAnswersHandler(Mock())
handler.answerObjectUndoSerial(conn, {oid1: [tid0, tid1]}, undo_dict)
self.assertEqual(undo_dict, {oid1: [tid0, tid1]})
handler.answerObjectUndoSerial(conn, {oid2: [tid2, tid3]}, undo_dict)
self.assertEqual(undo_dict, {
oid1: [tid0, tid1],
oid2: [tid2, tid3],
})
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -258,23 +258,12 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -258,23 +258,12 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
self.assertEqual(node.getUUID(), new_uuid) self.assertEqual(node.getUUID(), new_uuid)
self.assertNotEqual(node.getUUID(), uuid) self.assertNotEqual(node.getUUID(), uuid)
def _getNodeList(self):
return [x.asTuple() for x in self.app.nm.getList()]
def __getClient(self): def __getClient(self):
uuid = self.getClientUUID() uuid = self.getClientUUID()
conn = self.getFakeConnection(uuid=uuid, address=self.client_address) conn = self.getFakeConnection(uuid=uuid, address=self.client_address)
self.app.nm.createClient(uuid=uuid, address=self.client_address) self.app.nm.createClient(uuid=uuid, address=self.client_address)
return conn return conn
def __getMaster(self, port=1000, register=True):
uuid = self.getMasterUUID()
address = ('127.0.0.1', port)
conn = self.getFakeConnection(uuid=uuid, address=address)
if register:
self.app.nm.createMaster(uuid=uuid, address=address)
return conn
def testRequestIdentification1(self): def testRequestIdentification1(self):
""" Check with a non-master node, must be refused """ """ Check with a non-master node, must be refused """
conn = self.__getClient() conn = self.__getClient()
......
...@@ -92,25 +92,6 @@ class MasterAppTests(NeoUnitTestBase): ...@@ -92,25 +92,6 @@ class MasterAppTests(NeoUnitTestBase):
self.checkNoPacketSent(master_conn) self.checkNoPacketSent(master_conn)
self.checkNotifyNodeInformation(storage_conn) self.checkNotifyNodeInformation(storage_conn)
def test_storageReadinessAPI(self):
uuid_1 = self.getStorageUUID()
uuid_2 = self.getStorageUUID()
self.assertFalse(self.app.isStorageReady(uuid_1))
self.assertFalse(self.app.isStorageReady(uuid_2))
# Must not raise, nor change readiness
self.app.setStorageNotReady(uuid_1)
self.assertFalse(self.app.isStorageReady(uuid_1))
self.assertFalse(self.app.isStorageReady(uuid_2))
# Mark as ready, only one must change
self.app.setStorageReady(uuid_1)
self.assertTrue(self.app.isStorageReady(uuid_1))
self.assertFalse(self.app.isStorageReady(uuid_2))
self.app.setStorageReady(uuid_2)
# Mark not ready, only one must change
self.app.setStorageNotReady(uuid_1)
self.assertFalse(self.app.isStorageReady(uuid_1))
self.assertTrue(self.app.isStorageReady(uuid_2))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -36,11 +36,8 @@ class MasterRecoveryTests(NeoUnitTestBase): ...@@ -36,11 +36,8 @@ class MasterRecoveryTests(NeoUnitTestBase):
node.setState(NodeStates.RUNNING) node.setState(NodeStates.RUNNING)
# define some variable to simulate client and storage node # define some variable to simulate client and storage node
self.client_port = 11022
self.storage_port = 10021 self.storage_port = 10021
self.master_port = 10011 self.master_port = 10011
self.master_address = ('127.0.0.1', self.master_port)
self.storage_address = ('127.0.0.1', self.storage_port)
def _tearDown(self, success): def _tearDown(self, success):
self.app.close() self.app.close()
...@@ -58,16 +55,11 @@ class MasterRecoveryTests(NeoUnitTestBase): ...@@ -58,16 +55,11 @@ class MasterRecoveryTests(NeoUnitTestBase):
return uuid return uuid
# Tests # Tests
def test_01_connectionClosed(self):
uuid = self.identifyToMasterNode(node_type=NodeTypes.MASTER, port=self.master_port)
conn = self.getFakeConnection(uuid, self.master_address)
self.assertEqual(self.app.nm.getByAddress(conn.getAddress()).getState(),
NodeStates.RUNNING)
self.recovery.connectionClosed(conn)
self.assertEqual(self.app.nm.getByAddress(conn.getAddress()).getState(),
NodeStates.TEMPORARILY_DOWN)
def test_10_answerPartitionTable(self): def test_10_answerPartitionTable(self):
# XXX: This test does much less that it seems, because all 'for' loops
# iterate over empty lists. Currently, only testRecovery covers
# some paths in NodeManager._createNode: apart from that, we could
# delete it entirely.
recovery = self.recovery recovery = self.recovery
uuid = self.identifyToMasterNode(NodeTypes.MASTER, port=self.master_port) uuid = self.identifyToMasterNode(NodeTypes.MASTER, port=self.master_port)
# not from target node, ignore # not from target node, ignore
......
...@@ -17,11 +17,9 @@ ...@@ -17,11 +17,9 @@
import unittest import unittest
from mock import Mock from mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes, NodeStates, Packets from neo.lib.protocol import NodeTypes, Packets
from neo.master.handlers.storage import StorageServiceHandler from neo.master.handlers.storage import StorageServiceHandler
from neo.master.handlers.client import ClientServiceHandler
from neo.master.app import Application from neo.master.app import Application
from neo.lib.exception import StoppedOperation
class MasterStorageHandlerTests(NeoUnitTestBase): class MasterStorageHandlerTests(NeoUnitTestBase):
...@@ -34,23 +32,11 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -34,23 +32,11 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
self.app.pt.clear() self.app.pt.clear()
self.app.em = Mock() self.app.em = Mock()
self.service = StorageServiceHandler(self.app) self.service = StorageServiceHandler(self.app)
self.client_handler = ClientServiceHandler(self.app)
# define some variable to simulate client and storage node
self.client_port = 11022
self.storage_port = 10021
self.master_port = 10010
self.master_address = ('127.0.0.1', self.master_port)
self.client_address = ('127.0.0.1', self.client_port)
self.storage_address = ('127.0.0.1', self.storage_port)
def _allocatePort(self): def _allocatePort(self):
self.port = getattr(self, 'port', 1000) + 1 self.port = getattr(self, 'port', 1000) + 1
return self.port return self.port
def _getClient(self):
return self.identifyToMasterNode(node_type=NodeTypes.CLIENT,
ip='127.0.0.1', port=self._allocatePort())
def _getStorage(self): def _getStorage(self):
return self.identifyToMasterNode(node_type=NodeTypes.STORAGE, return self.identifyToMasterNode(node_type=NodeTypes.STORAGE,
ip='127.0.0.1', port=self._allocatePort()) ip='127.0.0.1', port=self._allocatePort())
...@@ -67,98 +53,6 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -67,98 +53,6 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
node.setConnection(conn) node.setConnection(conn)
return (node, conn) return (node, conn)
def test_answerInformationLocked_1(self):
"""
Master must refuse to lock if the TID is greater than the last TID
"""
tid1 = self.getNextTID()
tid2 = self.getNextTID(tid1)
self.app.tm.setLastTID(tid1)
self.assertTrue(tid1 < tid2)
node, conn = self.identifyToMasterNode()
self.checkProtocolErrorRaised(self.service.answerInformationLocked,
conn, tid2)
self.checkNoPacketSent(conn)
def test_answerInformationLocked_2(self):
"""
Master must:
- lock each storage
- notify the client
- invalidate other clients
- unlock storages
"""
# one client and two storages required
client_1, client_conn_1 = self._getClient()
client_2, client_conn_2 = self._getClient()
storage_1, storage_conn_1 = self._getStorage()
storage_2, storage_conn_2 = self._getStorage()
uuid_list = storage_1.getUUID(), storage_2.getUUID()
oid_list = self.getOID(), self.getOID()
msg_id = 1
# register a transaction
ttid = self.app.tm.begin(client_1)
tid = self.app.tm.prepare(ttid, 1, oid_list, uuid_list,
msg_id)
self.assertTrue(ttid in self.app.tm)
# the first storage acknowledge the lock
self.service.answerInformationLocked(storage_conn_1, ttid)
self.checkNoPacketSent(client_conn_1)
self.checkNoPacketSent(client_conn_2)
self.checkNoPacketSent(storage_conn_1)
self.checkNoPacketSent(storage_conn_2)
# then the second
self.service.answerInformationLocked(storage_conn_2, ttid)
self.checkAnswerTransactionFinished(client_conn_1)
self.checkInvalidateObjects(client_conn_2)
self.checkNotifyUnlockInformation(storage_conn_1)
self.checkNotifyUnlockInformation(storage_conn_2)
def test_13_askUnfinishedTransactions(self):
service = self.service
node, conn = self.identifyToMasterNode()
# give a uuid
service.askUnfinishedTransactions(conn)
packet = self.checkAnswerUnfinishedTransactions(conn)
max_tid, tid_list = packet.decode()
self.assertEqual(tid_list, [])
# create some transaction
node, conn = self.identifyToMasterNode(node_type=NodeTypes.CLIENT,
port=self.client_port)
ttid = self.app.tm.begin(node)
self.app.tm.prepare(ttid, 1,
[self.getOID(1)], [node.getUUID()], 1)
conn = self.getFakeConnection(node.getUUID(), self.storage_address)
service.askUnfinishedTransactions(conn)
max_tid, tid_list = self.checkAnswerUnfinishedTransactions(conn, decode=True)
self.assertEqual(len(tid_list), 1)
def test_connectionClosed(self):
method = self.service.connectionClosed
state = NodeStates.TEMPORARILY_DOWN
# define two nodes
node1, conn1 = self.identifyToMasterNode()
node2, conn2 = self.identifyToMasterNode(port=10022)
node1.setRunning()
node2.setRunning()
self.assertEqual(node1.getState(), NodeStates.RUNNING)
self.assertEqual(node2.getState(), NodeStates.RUNNING)
# filled the pt
self.app.pt.make(self.app.nm.getStorageList())
self.assertTrue(self.app.pt.filled())
self.assertTrue(self.app.pt.operational())
# drop one node
lptid = self.app.pt.getID()
method(conn1)
self.assertEqual(node1.getState(), state)
self.assertTrue(lptid < self.app.pt.getID())
# drop the second, no storage node left
lptid = self.app.pt.getID()
self.assertEqual(node2.getState(), NodeStates.RUNNING)
self.assertRaises(StoppedOperation, method, conn2)
self.assertEqual(node2.getState(), state)
self.assertEqual(lptid, self.app.pt.getID())
def test_answerPack(self): def test_answerPack(self):
# Note: incoming status has no meaning here, so it's left to False. # Note: incoming status has no meaning here, so it's left to False.
node1, conn1 = self._getStorage() node1, conn1 = self._getStorage()
...@@ -183,13 +77,6 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -183,13 +77,6 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
self.assertTrue(status) self.assertTrue(status)
self.assertEqual(self.app.packing, None) self.assertEqual(self.app.packing, None)
def test_notifyReady(self):
node, conn = self._getStorage()
uuid = node.getUUID()
self.assertFalse(self.app.isStorageReady(uuid))
self.service.notifyReady(conn)
self.assertTrue(self.app.isStorageReady(uuid))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -20,13 +20,10 @@ from struct import pack ...@@ -20,13 +20,10 @@ from struct import pack
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes from neo.lib.protocol import NodeTypes
from neo.lib.util import packTID, unpackTID, addTID from neo.lib.util import packTID, unpackTID, addTID
from neo.master.transactions import Transaction, TransactionManager from neo.master.transactions import TransactionManager
class testTransactionManager(NeoUnitTestBase): class testTransactionManager(NeoUnitTestBase):
def makeTID(self, i):
return pack('!Q', i)
def makeOID(self, i): def makeOID(self, i):
return pack('!Q', i) return pack('!Q', i)
...@@ -35,58 +32,6 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -35,58 +32,6 @@ class testTransactionManager(NeoUnitTestBase):
node = Mock({'getUUID': uuid, '__hash__': uuid, '__repr__': 'FakeNode'}) node = Mock({'getUUID': uuid, '__hash__': uuid, '__repr__': 'FakeNode'})
return uuid, node return uuid, node
def testTransaction(self):
# test data
node = Mock({'__repr__': 'Node'})
tid = self.makeTID(1)
ttid = self.makeTID(2)
oid_list = (oid1, oid2) = [self.makeOID(1), self.makeOID(2)]
uuid_list = (uuid1, uuid2) = [self.getStorageUUID(),
self.getStorageUUID()]
msg_id = 1
# create transaction object
txn = Transaction(node, ttid)
txn.prepare(tid, oid_list, uuid_list, msg_id)
self.assertEqual(txn.getUUIDList(), uuid_list)
self.assertEqual(txn.getOIDList(), oid_list)
# lock nodes one by one
self.assertFalse(txn.lock(uuid1))
self.assertTrue(txn.lock(uuid2))
# check that repr() works
repr(txn)
def testManager(self):
# test data
node = Mock({'__hash__': 1})
msg_id = 1
oid_list = (oid1, oid2) = self.makeOID(1), self.makeOID(2)
uuid_list = uuid1, uuid2 = self.getStorageUUID(), self.getStorageUUID()
client_uuid = self.getClientUUID()
# create transaction manager
callback = Mock()
txnman = TransactionManager(on_commit=callback)
self.assertFalse(txnman.hasPending())
self.assertEqual(txnman.registerForNotification(uuid1), [])
# begin the transaction
ttid = txnman.begin(node)
self.assertTrue(ttid is not None)
self.assertEqual(len(txnman.registerForNotification(uuid1)), 1)
self.assertTrue(txnman.hasPending())
# prepare the transaction
tid = txnman.prepare(ttid, 1, oid_list, uuid_list, msg_id)
self.assertTrue(txnman.hasPending())
self.assertEqual(txnman.registerForNotification(uuid1), [ttid])
txn = txnman[ttid]
self.assertEqual(txn.getTID(), tid)
self.assertEqual(txn.getUUIDList(), list(uuid_list))
self.assertEqual(txn.getOIDList(), list(oid_list))
# lock nodes
txnman.lock(ttid, uuid1)
self.assertEqual(len(callback.getNamedCalls('__call__')), 0)
txnman.lock(ttid, uuid2)
self.assertEqual(len(callback.getNamedCalls('__call__')), 1)
self.assertEqual(txnman.registerForNotification(uuid1), [])
def test_storageLost(self): def test_storageLost(self):
client1 = Mock({'__hash__': 1}) client1 = Mock({'__hash__': 1})
client2 = Mock({'__hash__': 2}) client2 = Mock({'__hash__': 2})
...@@ -95,7 +40,7 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -95,7 +40,7 @@ class testTransactionManager(NeoUnitTestBase):
storage_2_uuid = self.getStorageUUID() storage_2_uuid = self.getStorageUUID()
oid_list = [self.makeOID(1), ] oid_list = [self.makeOID(1), ]
tm = TransactionManager(lambda tid, txn: None) tm = TransactionManager(None)
# Transaction 1: 2 storage nodes involved, one will die and the other # Transaction 1: 2 storage nodes involved, one will die and the other
# already answered node lock # already answered node lock
msg_id_1 = 1 msg_id_1 = 1
...@@ -172,7 +117,7 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -172,7 +117,7 @@ class testTransactionManager(NeoUnitTestBase):
Note: this implementation might change later, for more parallelism. Note: this implementation might change later, for more parallelism.
""" """
client_uuid, client = self.makeNode(NodeTypes.CLIENT) client_uuid, client = self.makeNode(NodeTypes.CLIENT)
tm = TransactionManager(lambda tid, txn: None) tm = TransactionManager(None)
# With a requested TID, lock spans from begin to remove # With a requested TID, lock spans from begin to remove
ttid1 = self.getNextTID() ttid1 = self.getNextTID()
ttid2 = self.getNextTID() ttid2 = self.getNextTID()
...@@ -189,7 +134,7 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -189,7 +134,7 @@ class testTransactionManager(NeoUnitTestBase):
def testClientDisconectsAfterBegin(self): def testClientDisconectsAfterBegin(self):
client_uuid1, node1 = self.makeNode(NodeTypes.CLIENT) client_uuid1, node1 = self.makeNode(NodeTypes.CLIENT)
tm = TransactionManager(lambda tid, txn: None) tm = TransactionManager(None)
tid1 = self.getNextTID() tid1 = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
tm.begin(node1, tid1) tm.begin(node1, tid1)
......
...@@ -17,23 +17,13 @@ ...@@ -17,23 +17,13 @@
import unittest import unittest
from mock import Mock, ReturnValues from mock import Mock, ReturnValues
from collections import deque from collections import deque
from neo.lib.util import makeChecksum
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.transactions import ConflictError
from neo.storage.handlers.client import ClientOperationHandler from neo.storage.handlers.client import ClientOperationHandler
from neo.lib.protocol import INVALID_PARTITION, INVALID_TID, INVALID_OID from neo.lib.protocol import INVALID_TID, INVALID_OID, Packets, LockState
from neo.lib.protocol import Packets, LockState, ZERO_HASH
class StorageClientHandlerTests(NeoUnitTestBase): class StorageClientHandlerTests(NeoUnitTestBase):
def checkHandleUnexpectedPacket(self, _call, _msg_type, _listening=True, **kwargs):
conn = self.getFakeConnection(address=("127.0.0.1", self.master_port),
is_server=_listening)
# hook
self.operation.peerBroken = lambda c: c.peerBrokendCalled()
self.checkUnexpectedPacketRaised(_call, conn=conn, **kwargs)
def setUp(self): def setUp(self):
NeoUnitTestBase.setUp(self) NeoUnitTestBase.setUp(self)
self.prepareDatabase(number=1) self.prepareDatabase(number=1)
...@@ -53,7 +43,6 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -53,7 +43,6 @@ class StorageClientHandlerTests(NeoUnitTestBase):
pmn = self.app.nm.getMasterList()[0] pmn = self.app.nm.getMasterList()[0]
pmn.setUUID(self.master_uuid) pmn.setUUID(self.master_uuid)
self.app.primary_master_node = pmn self.app.primary_master_node = pmn
self.master_port = 10010
def _tearDown(self, success): def _tearDown(self, success):
self.app.close() self.app.close()
...@@ -63,17 +52,6 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -63,17 +52,6 @@ class StorageClientHandlerTests(NeoUnitTestBase):
def _getConnection(self, uuid=None): def _getConnection(self, uuid=None):
return self.getFakeConnection(uuid=uuid, address=('127.0.0.1', 1000)) return self.getFakeConnection(uuid=uuid, address=('127.0.0.1', 1000))
def _checkTransactionsAborted(self, uuid):
calls = self.app.tm.mockGetNamedCalls('abortFor')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(uuid)
def test_connectionLost(self):
uuid = self.getClientUUID()
self.app.nm.createClient(uuid=uuid)
conn = self._getConnection(uuid=uuid)
self.operation.connectionClosed(conn)
def test_18_askTransactionInformation1(self): def test_18_askTransactionInformation1(self):
# transaction does not exists # transaction does not exists
conn = self._getConnection() conn = self._getConnection()
...@@ -81,15 +59,6 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -81,15 +59,6 @@ class StorageClientHandlerTests(NeoUnitTestBase):
self.operation.askTransactionInformation(conn, INVALID_TID) self.operation.askTransactionInformation(conn, INVALID_TID)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
def test_18_askTransactionInformation2(self):
# answer
conn = self._getConnection()
oid_list = [self.getOID(1), self.getOID(2)]
dm = Mock({ "getTransaction": (oid_list, 'user', 'desc', '', False), })
self.app.dm = dm
self.operation.askTransactionInformation(conn, INVALID_TID)
self.checkAnswerTransactionInformation(conn)
def test_24_askObject1(self): def test_24_askObject1(self):
# delayed response # delayed response
conn = self._getConnection() conn = self._getConnection()
...@@ -103,33 +72,6 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -103,33 +72,6 @@ class StorageClientHandlerTests(NeoUnitTestBase):
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
self.assertEqual(len(self.app.dm.mockGetNamedCalls('getObject')), 0) self.assertEqual(len(self.app.dm.mockGetNamedCalls('getObject')), 0)
def test_24_askObject2(self):
# invalid serial / tid / packet not found
self.app.dm = Mock({'getObject': None})
conn = self._getConnection()
self.assertEqual(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=INVALID_OID,
serial=INVALID_TID, tid=INVALID_TID)
calls = self.app.dm.mockGetNamedCalls('getObject')
self.assertEqual(len(self.app.event_queue), 0)
self.assertEqual(len(calls), 1)
calls[0].checkArgs(INVALID_OID, INVALID_TID, INVALID_TID)
self.checkErrorPacket(conn)
def test_24_askObject3(self):
# object found => answer
serial = self.getNextTID()
next_serial = self.getNextTID()
oid = self.getOID(1)
tid = self.getNextTID()
H = "0" * 20
self.app.dm = Mock({'getObject': (serial, next_serial, 0, H, '', None)})
conn = self._getConnection()
self.assertEqual(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=oid, serial=serial, tid=tid)
self.assertEqual(len(self.app.event_queue), 0)
self.checkAnswerObject(conn)
def test_25_askTIDs1(self): def test_25_askTIDs1(self):
# invalid offsets => error # invalid offsets => error
app = self.app app = self.app
...@@ -151,23 +93,6 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -151,23 +93,6 @@ class StorageClientHandlerTests(NeoUnitTestBase):
calls[0].checkArgs(1, 1, [1, ]) calls[0].checkArgs(1, 1, [1, ])
self.checkAnswerTids(conn) self.checkAnswerTids(conn)
def test_25_askTIDs3(self):
# invalid partition => answer usable partitions
conn = self._getConnection()
cell = Mock({'getUUID':self.app.uuid})
self.app.dm = Mock({'getTIDList': (INVALID_TID, )})
self.app.pt = Mock({
'getCellList': (cell, ),
'getPartitions': 1,
'getAssignedPartitionList': [0],
})
self.operation.askTIDs(conn, 1, 2, INVALID_PARTITION)
self.assertEqual(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1)
calls = self.app.dm.mockGetNamedCalls('getTIDList')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(1, 1, [0])
self.checkAnswerTids(conn)
def test_26_askObjectHistory1(self): def test_26_askObjectHistory1(self):
# invalid offsets => error # invalid offsets => error
app = self.app app = self.app
...@@ -177,87 +102,6 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -177,87 +102,6 @@ class StorageClientHandlerTests(NeoUnitTestBase):
1, 1, None) 1, 1, None)
self.assertEqual(len(app.dm.mockGetNamedCalls('getObjectHistory')), 0) self.assertEqual(len(app.dm.mockGetNamedCalls('getObjectHistory')), 0)
def test_26_askObjectHistory2(self):
oid1, oid2 = self.getOID(1), self.getOID(2)
# first case: empty history
conn = self._getConnection()
self.app.dm = Mock({'getObjectHistory': None})
self.operation.askObjectHistory(conn, oid1, 1, 2)
self.checkErrorPacket(conn)
# second case: not empty history
conn = self._getConnection()
serial = self.getNextTID()
self.app.dm = Mock({'getObjectHistory': [(serial, 0, ), ]})
self.operation.askObjectHistory(conn, oid2, 1, 2)
self.checkAnswerObjectHistory(conn)
def _getObject(self):
oid = self.getOID(0)
serial = self.getNextTID()
data = 'DATA'
return (oid, serial, 1, makeChecksum(data), data)
def _checkStoreObjectCalled(self, *args):
calls = self.app.tm.mockGetNamedCalls('storeObject')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(*args)
def test_askStoreObject1(self):
# no conflict => answer
conn = self._getConnection(uuid=self.getClientUUID())
tid = self.getNextTID()
oid, serial, comp, checksum, data = self._getObject()
self.operation.askStoreObject(conn, oid, serial, comp, checksum,
data, None, tid, False)
self._checkStoreObjectCalled(tid, serial, oid, comp,
checksum, data, None, False)
pconflicting, poid, pserial = self.checkAnswerStoreObject(conn,
decode=True)
self.assertEqual(pconflicting, 0)
self.assertEqual(poid, oid)
self.assertEqual(pserial, serial)
def test_askStoreObjectWithDataTID(self):
# same as test_askStoreObject1, but with a non-None data_tid value
conn = self._getConnection(uuid=self.getClientUUID())
tid = self.getNextTID()
oid, serial, comp, checksum, data = self._getObject()
data_tid = self.getNextTID()
self.operation.askStoreObject(conn, oid, serial, comp, ZERO_HASH,
'', data_tid, tid, False)
self._checkStoreObjectCalled(tid, serial, oid, comp,
None, None, data_tid, False)
pconflicting, poid, pserial = self.checkAnswerStoreObject(conn,
decode=True)
self.assertEqual(pconflicting, 0)
self.assertEqual(poid, oid)
self.assertEqual(pserial, serial)
def test_askStoreObject2(self):
# conflict error
conn = self._getConnection(uuid=self.getClientUUID())
tid = self.getNextTID()
locking_tid = self.getNextTID(tid)
def fakeStoreObject(*args):
raise ConflictError(locking_tid)
self.app.tm.storeObject = fakeStoreObject
oid, serial, comp, checksum, data = self._getObject()
self.operation.askStoreObject(conn, oid, serial, comp, checksum,
data, None, tid, False)
pconflicting, poid, pserial = self.checkAnswerStoreObject(conn,
decode=True)
self.assertEqual(pconflicting, 1)
self.assertEqual(poid, oid)
self.assertEqual(pserial, locking_tid)
def test_abortTransaction(self):
conn = self._getConnection()
tid = self.getNextTID()
self.operation.abortTransaction(conn, tid)
calls = self.app.tm.mockGetNamedCalls('abort')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid)
def test_askObjectUndoSerial(self): def test_askObjectUndoSerial(self):
conn = self._getConnection(uuid=self.getClientUUID()) conn = self._getConnection(uuid=self.getClientUUID())
tid = self.getNextTID() tid = self.getNextTID()
......
...@@ -15,10 +15,8 @@ ...@@ -15,10 +15,8 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.protocol import NodeStates, NodeTypes, NotReadyError, \ from neo.lib.protocol import NodeTypes, BrokenNodeDisallowedError
BrokenNodeDisallowedError
from neo.lib.pt import PartitionTable from neo.lib.pt import PartitionTable
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.handlers.identification import IdentificationHandler from neo.storage.handlers.identification import IdentificationHandler
...@@ -39,31 +37,6 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase): ...@@ -39,31 +37,6 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase):
del self.app del self.app
super(StorageIdentificationHandlerTests, self)._tearDown(success) super(StorageIdentificationHandlerTests, self)._tearDown(success)
def test_requestIdentification1(self):
""" nodes are rejected during election or if unknown storage """
self.app.ready = False
self.assertRaises(
NotReadyError,
self.identification.requestIdentification,
self.getFakeConnection(),
NodeTypes.CLIENT,
self.getClientUUID(),
None,
self.app.name,
None,
)
self.app.ready = True
self.assertRaises(
NotReadyError,
self.identification.requestIdentification,
self.getFakeConnection(),
NodeTypes.STORAGE,
self.getStorageUUID(),
None,
self.app.name,
None,
)
def test_requestIdentification3(self): def test_requestIdentification3(self):
""" broken nodes must be rejected """ """ broken nodes must be rejected """
uuid = self.getClientUUID() uuid = self.getClientUUID()
...@@ -80,28 +53,5 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase): ...@@ -80,28 +53,5 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase):
None, None,
) )
def test_requestIdentification2(self):
""" accepted client must be connected and running """
uuid = self.getClientUUID()
conn = self.getFakeConnection(uuid=uuid)
node = self.app.nm.createClient(uuid=uuid, state=NodeStates.RUNNING)
master = (self.local_ip, 3000)
self.app.master_node = Mock({
'getAddress': master,
})
self.identification.requestIdentification(conn, NodeTypes.CLIENT, uuid,
None, self.app.name, None)
self.assertTrue(node.isRunning())
self.assertTrue(node.isConnected())
self.assertEqual(node.getUUID(), uuid)
self.assertTrue(node.getConnection() is conn)
args = self.checkAcceptIdentification(conn, decode=True)
node_type, address, _np, _nr, _uuid, _master, _master_list = args
self.assertEqual(node_type, NodeTypes.STORAGE)
self.assertEqual(address, None)
self.assertEqual(_uuid, uuid)
self.assertEqual(_master, master)
# TODO: check _master_list ?
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
#
# Copyright (C) 2009-2016 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from .. import NeoUnitTestBase
from neo.lib.pt import PartitionTable
from neo.storage.app import Application
from neo.storage.handlers.initialization import InitializationHandler
from neo.lib.protocol import CellStates
from neo.lib.exception import PrimaryFailure
class StorageInitializationHandlerTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
self.prepareDatabase(number=1)
# create an application object
config = self.getStorageConfiguration(master_number=1)
self.app = Application(config)
self.verification = InitializationHandler(self.app)
# define some variable to simulate client and storage node
self.master_port = 10010
self.storage_port = 10020
self.client_port = 11011
self.num_partitions = 1009
self.num_replicas = 2
self.app.operational = False
self.app.load_lock_dict = {}
self.app.pt = PartitionTable(self.num_partitions, self.num_replicas)
def _tearDown(self, success):
self.app.close()
del self.app
super(StorageInitializationHandlerTests, self)._tearDown(success)
def getClientConnection(self):
address = ("127.0.0.1", self.client_port)
return self.getFakeConnection(uuid=self.getClientUUID(),
address=address)
def test_03_connectionClosed(self):
conn = self.getClientConnection()
self.app.listening_conn = object() # mark as running
self.assertRaises(PrimaryFailure, self.verification.connectionClosed, conn,)
# nothing happens
self.checkNoPacketSent(conn)
def test_09_answerPartitionTable(self):
# send a table
conn = self.getClientConnection()
self.app.pt = PartitionTable(3, 2)
node_1 = self.getStorageUUID()
node_2 = self.getStorageUUID()
node_3 = self.getStorageUUID()
self.app.uuid = node_1
# SN already know all nodes
self.app.nm.createStorage(uuid=node_1)
self.app.nm.createStorage(uuid=node_2)
self.app.nm.createStorage(uuid=node_3)
self.assertFalse(list(self.app.dm.getPartitionTable()))
row_list = [(0, ((node_1, CellStates.UP_TO_DATE), (node_2, CellStates.UP_TO_DATE))),
(1, ((node_3, CellStates.UP_TO_DATE), (node_1, CellStates.UP_TO_DATE))),
(2, ((node_2, CellStates.UP_TO_DATE), (node_3, CellStates.UP_TO_DATE)))]
self.assertFalse(self.app.pt.filled())
# send a complete new table and ack
self.verification.sendPartitionTable(conn, 2, row_list)
self.assertTrue(self.app.pt.filled())
self.assertEqual(self.app.pt.getID(), 2)
self.assertTrue(list(self.app.dm.getPartitionTable()))
if __name__ == "__main__":
unittest.main()
...@@ -20,9 +20,8 @@ from collections import deque ...@@ -20,9 +20,8 @@ from collections import deque
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.handlers.master import MasterOperationHandler from neo.storage.handlers.master import MasterOperationHandler
from neo.lib.exception import PrimaryFailure
from neo.lib.pt import PartitionTable from neo.lib.pt import PartitionTable
from neo.lib.protocol import CellStates, Packets from neo.lib.protocol import CellStates
class StorageMasterHandlerTests(NeoUnitTestBase): class StorageMasterHandlerTests(NeoUnitTestBase):
...@@ -54,13 +53,6 @@ class StorageMasterHandlerTests(NeoUnitTestBase): ...@@ -54,13 +53,6 @@ class StorageMasterHandlerTests(NeoUnitTestBase):
address = ("127.0.0.1", self.master_port) address = ("127.0.0.1", self.master_port)
return self.getFakeConnection(uuid=self.master_uuid, address=address) return self.getFakeConnection(uuid=self.master_uuid, address=address)
def test_07_connectionClosed2(self):
# primary has closed the connection
conn = self.getMasterConnection()
self.app.listening_conn = object() # mark as running
self.assertRaises(PrimaryFailure, self.operation.connectionClosed, conn)
self.checkNoPacketSent(conn)
def test_14_notifyPartitionChanges1(self): def test_14_notifyPartitionChanges1(self):
# old partition change -> do nothing # old partition change -> do nothing
app = self.app app = self.app
...@@ -104,19 +96,5 @@ class StorageMasterHandlerTests(NeoUnitTestBase): ...@@ -104,19 +96,5 @@ class StorageMasterHandlerTests(NeoUnitTestBase):
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(ptid2, cells) calls[0].checkArgs(ptid2, cells)
def _getConnection(self):
return self.getFakeConnection()
def test_askPack(self):
self.app.dm = Mock({'pack': None})
conn = self.getFakeConnection()
tid = self.getNextTID()
self.operation.askPack(conn, tid)
calls = self.app.dm.mockGetNamedCalls('pack')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid, self.app.tm.updateObjectDataForPack)
# Content has no meaning here, don't check.
self.checkAnswerPacket(conn, Packets.AnswerPack)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from mock import Mock, ReturnValues from mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.lib.protocol import CellStates from neo.lib.protocol import CellStates
...@@ -141,25 +141,6 @@ class StorageAppTests(NeoUnitTestBase): ...@@ -141,25 +141,6 @@ class StorageAppTests(NeoUnitTestBase):
raise_on_duplicate=False) raise_on_duplicate=False)
self.assertEqual(len(self.app.event_queue), 2) self.assertEqual(len(self.app.event_queue), 2)
def test_03_executeQueuedEvents(self):
self.assertEqual(len(self.app.event_queue), 0)
msg_id = 1325136
msg_id_2 = 1325137
event = Mock({'__repr__': 'event'})
conn = Mock({'__repr__': 'conn', 'getPeerId': ReturnValues(msg_id, msg_id_2)})
self.app.queueEvent(event, conn, ("test", ))
self.app.executeQueuedEvents()
self.assertEqual(len(event.mockGetNamedCalls("__call__")), 1)
call = event.mockGetNamedCalls("__call__")[0]
params = call.getParam(1)
self.assertEqual(params, "test")
params = call.kwparams
self.assertEqual(params, {})
calls = conn.mockGetNamedCalls("setPeerId")
self.assertEqual(len(calls), 2)
calls[0].checkArgs(msg_id)
calls[1].checkArgs(msg_id_2)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
This diff is collapsed.
...@@ -19,7 +19,7 @@ from time import time ...@@ -19,7 +19,7 @@ from time import time
from mock import Mock from mock import Mock
from neo.lib import connection, logging from neo.lib import connection, logging
from neo.lib.connection import BaseConnection, ClientConnection, \ from neo.lib.connection import BaseConnection, ClientConnection, \
MTClientConnection, HandlerSwitcher, CRITICAL_TIMEOUT MTClientConnection, CRITICAL_TIMEOUT
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import Packets from neo.lib.protocol import Packets
from . import NeoUnitTestBase, Patch from . import NeoUnitTestBase, Patch
...@@ -159,186 +159,5 @@ class MTConnectionTests(ConnectionTests): ...@@ -159,186 +159,5 @@ class MTConnectionTests(ConnectionTests):
# ... except Ping # ... except Ping
ask(Packets.Ping()) ask(Packets.Ping())
class HandlerSwitcherTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
self._handler = handler = Mock({
'__repr__': 'initial handler',
})
self._connection = Mock({
'__repr__': 'connection',
'getAddress': ('127.0.0.1', 10000),
})
self._handlers = HandlerSwitcher(handler)
def _makeNotification(self, msg_id):
packet = Packets.StartOperation()
packet.setId(msg_id)
return packet
def _makeRequest(self, msg_id):
packet = Packets.AskBeginTransaction()
packet.setId(msg_id)
return packet
def _makeAnswer(self, msg_id):
packet = Packets.AnswerBeginTransaction(self.getNextTID())
packet.setId(msg_id)
return packet
def _makeHandler(self):
return Mock({'__repr__': 'handler'})
def _checkPacketReceived(self, handler, packet, index=0):
calls = handler.mockGetNamedCalls('packetReceived')
self.assertEqual(len(calls), index + 1)
def _checkCurrentHandler(self, handler):
self.assertTrue(self._handlers.getHandler() is handler)
def testInit(self):
self._checkCurrentHandler(self._handler)
self.assertFalse(self._handlers.isPending())
def testEmit(self):
# First case, emit is called outside of a handler
self.assertFalse(self._handlers.isPending())
request = self._makeRequest(1)
self._handlers.emit(request, 0, None)
self.assertTrue(self._handlers.isPending())
# Second case, emit is called from inside a handler with a pending
# handler change.
new_handler = self._makeHandler()
applied = self._handlers.setHandler(new_handler)
self.assertFalse(applied)
self._checkCurrentHandler(self._handler)
call_tracker = []
def packetReceived(conn, packet, kw):
self._handlers.emit(self._makeRequest(2), 0, None)
call_tracker.append(True)
self._handler.packetReceived = packetReceived
self._handlers.handle(self._connection, self._makeAnswer(1))
self.assertEqual(call_tracker, [True])
# Effective handler must not have changed (new request is blocking
# it)
self._checkCurrentHandler(self._handler)
# Handling the next response will cause the handler to change
delattr(self._handler, 'packetReceived')
self._handlers.handle(self._connection, self._makeAnswer(2))
self._checkCurrentHandler(new_handler)
def testHandleNotification(self):
# handle with current handler
notif1 = self._makeNotification(1)
self._handlers.handle(self._connection, notif1)
self._checkPacketReceived(self._handler, notif1)
# emit a request and delay an handler
request = self._makeRequest(2)
self._handlers.emit(request, 0, None)
handler = self._makeHandler()
applied = self._handlers.setHandler(handler)
self.assertFalse(applied)
# next notification fall into the current handler
notif2 = self._makeNotification(3)
self._handlers.handle(self._connection, notif2)
self._checkPacketReceived(self._handler, notif2, index=1)
# handle with new handler
answer = self._makeAnswer(2)
self._handlers.handle(self._connection, answer)
notif3 = self._makeNotification(4)
self._handlers.handle(self._connection, notif3)
self._checkPacketReceived(handler, notif2)
def testHandleAnswer1(self):
# handle with current handler
request = self._makeRequest(1)
self._handlers.emit(request, 0, None)
answer = self._makeAnswer(1)
self._handlers.handle(self._connection, answer)
self._checkPacketReceived(self._handler, answer)
def testHandleAnswer2(self):
# handle with blocking handler
request = self._makeRequest(1)
self._handlers.emit(request, 0, None)
handler = self._makeHandler()
applied = self._handlers.setHandler(handler)
self.assertFalse(applied)
answer = self._makeAnswer(1)
self._handlers.handle(self._connection, answer)
self._checkPacketReceived(self._handler, answer)
self._checkCurrentHandler(handler)
def testHandleAnswer3(self):
# multiple setHandler
r1 = self._makeRequest(1)
r2 = self._makeRequest(2)
r3 = self._makeRequest(3)
a1 = self._makeAnswer(1)
a2 = self._makeAnswer(2)
a3 = self._makeAnswer(3)
h1 = self._makeHandler()
h2 = self._makeHandler()
h3 = self._makeHandler()
# emit all requests and setHandleres
self._handlers.emit(r1, 0, None)
applied = self._handlers.setHandler(h1)
self.assertFalse(applied)
self._handlers.emit(r2, 0, None)
applied = self._handlers.setHandler(h2)
self.assertFalse(applied)
self._handlers.emit(r3, 0, None)
applied = self._handlers.setHandler(h3)
self.assertFalse(applied)
self._checkCurrentHandler(self._handler)
self.assertTrue(self._handlers.isPending())
# process answers
self._handlers.handle(self._connection, a1)
self._checkCurrentHandler(h1)
self._handlers.handle(self._connection, a2)
self._checkCurrentHandler(h2)
self._handlers.handle(self._connection, a3)
self._checkCurrentHandler(h3)
def testHandleAnswer4(self):
# process out of order
r1 = self._makeRequest(1)
r2 = self._makeRequest(2)
r3 = self._makeRequest(3)
a1 = self._makeAnswer(1)
a2 = self._makeAnswer(2)
a3 = self._makeAnswer(3)
h = self._makeHandler()
# emit all requests
self._handlers.emit(r1, 0, None)
self._handlers.emit(r2, 0, None)
self._handlers.emit(r3, 0, None)
applied = self._handlers.setHandler(h)
self.assertFalse(applied)
# process answers
self._handlers.handle(self._connection, a1)
self._checkCurrentHandler(self._handler)
self._handlers.handle(self._connection, a2)
self._checkCurrentHandler(self._handler)
self._handlers.handle(self._connection, a3)
self._checkCurrentHandler(h)
def testHandleUnexpected(self):
# process out of order
r1 = self._makeRequest(1)
r2 = self._makeRequest(2)
a2 = self._makeAnswer(2)
h = self._makeHandler()
# emit requests around state setHandler
self._handlers.emit(r1, 0, None)
applied = self._handlers.setHandler(h)
self.assertFalse(applied)
self._handlers.emit(r2, 0, None)
# process answer for next state
self._handlers.handle(self._connection, a2)
self.checkAborted(self._connection)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from mock import Mock
from . import NeoTestBase from . import NeoTestBase
from neo.lib.dispatcher import Dispatcher, ForgottenPacket from neo.lib.dispatcher import Dispatcher, ForgottenPacket
from Queue import Queue from Queue import Queue
...@@ -26,88 +25,6 @@ class DispatcherTests(NeoTestBase): ...@@ -26,88 +25,6 @@ class DispatcherTests(NeoTestBase):
NeoTestBase.setUp(self) NeoTestBase.setUp(self)
self.dispatcher = Dispatcher() self.dispatcher = Dispatcher()
def testRegister(self):
conn = object()
queue = Queue()
MARKER = object()
self.dispatcher.register(conn, 1, queue)
self.assertTrue(queue.empty())
self.assertTrue(self.dispatcher.dispatch(conn, 1, MARKER, {}))
self.assertFalse(queue.empty())
self.assertEqual(queue.get(block=False), (conn, MARKER, {}))
self.assertTrue(queue.empty())
self.assertFalse(self.dispatcher.dispatch(conn, 2, None, {}))
def testUnregister(self):
conn = object()
queue = Mock()
self.dispatcher.register(conn, 2, queue)
self.dispatcher.unregister(conn)
self.assertEqual(len(queue.mockGetNamedCalls('put')), 1)
self.assertFalse(self.dispatcher.dispatch(conn, 2, None, {}))
def testRegistered(self):
conn1 = object()
conn2 = object()
self.assertFalse(self.dispatcher.registered(conn1))
self.assertFalse(self.dispatcher.registered(conn2))
self.dispatcher.register(conn1, 1, Mock())
self.assertTrue(self.dispatcher.registered(conn1))
self.assertFalse(self.dispatcher.registered(conn2))
self.dispatcher.register(conn2, 2, Mock())
self.assertTrue(self.dispatcher.registered(conn1))
self.assertTrue(self.dispatcher.registered(conn2))
self.dispatcher.unregister(conn1)
self.assertFalse(self.dispatcher.registered(conn1))
self.assertTrue(self.dispatcher.registered(conn2))
self.dispatcher.unregister(conn2)
self.assertFalse(self.dispatcher.registered(conn1))
self.assertFalse(self.dispatcher.registered(conn2))
def testPending(self):
conn1 = object()
conn2 = object()
class Queue(object):
_empty = True
def empty(self):
return self._empty
def put(self, value):
pass
queue1 = Queue()
queue2 = Queue()
self.dispatcher.register(conn1, 1, queue1)
self.assertTrue(self.dispatcher.pending(queue1))
self.dispatcher.register(conn2, 2, queue1)
self.assertTrue(self.dispatcher.pending(queue1))
self.dispatcher.register(conn2, 3, queue2)
self.assertTrue(self.dispatcher.pending(queue1))
self.assertTrue(self.dispatcher.pending(queue2))
self.dispatcher.dispatch(conn1, 1, None, {})
self.assertTrue(self.dispatcher.pending(queue1))
self.assertTrue(self.dispatcher.pending(queue2))
self.dispatcher.dispatch(conn2, 2, None, {})
self.assertFalse(self.dispatcher.pending(queue1))
self.assertTrue(self.dispatcher.pending(queue2))
queue1._empty = False
self.assertTrue(self.dispatcher.pending(queue1))
queue1._empty = True
self.dispatcher.register(conn1, 4, queue1)
self.dispatcher.register(conn2, 5, queue1)
self.assertTrue(self.dispatcher.pending(queue1))
self.assertTrue(self.dispatcher.pending(queue2))
self.dispatcher.unregister(conn2)
self.assertTrue(self.dispatcher.pending(queue1))
self.assertFalse(self.dispatcher.pending(queue2))
self.dispatcher.unregister(conn1)
self.assertFalse(self.dispatcher.pending(queue1))
self.assertFalse(self.dispatcher.pending(queue2))
def testForget(self): def testForget(self):
conn = object() conn = object()
queue = Queue() queue = Queue()
......
...@@ -14,13 +14,14 @@ ...@@ -14,13 +14,14 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import shutil
import unittest import unittest
from mock import Mock from mock import Mock
from neo.lib.protocol import NodeTypes, NodeStates from neo.lib.protocol import NodeTypes, NodeStates
from neo.lib.node import Node, MasterDB from neo.lib.node import Node, MasterDB
from . import NeoUnitTestBase, getTempDirectory from . import NeoUnitTestBase, getTempDirectory
from time import time from time import time
from os import chmod, mkdir, rmdir, unlink from os import chmod, mkdir, rmdir
from os.path import join, exists from os.path import join, exists
class NodesTests(NeoUnitTestBase): class NodesTests(NeoUnitTestBase):
...@@ -29,16 +30,6 @@ class NodesTests(NeoUnitTestBase): ...@@ -29,16 +30,6 @@ class NodesTests(NeoUnitTestBase):
NeoUnitTestBase.setUp(self) NeoUnitTestBase.setUp(self)
self.nm = Mock() self.nm = Mock()
def _updatedByAddress(self, node, index=0):
calls = self.nm.mockGetNamedCalls('_updateAddress')
self.assertEqual(len(calls), index + 1)
self.assertEqual(calls[index].getParam(0), node)
def _updatedByUUID(self, node, index=0):
calls = self.nm.mockGetNamedCalls('_updateUUID')
self.assertEqual(len(calls), index + 1)
self.assertEqual(calls[index].getParam(0), node)
def testInit(self): def testInit(self):
""" Check the node initialization """ """ Check the node initialization """
address = ('127.0.0.1', 10000) address = ('127.0.0.1', 10000)
...@@ -60,23 +51,6 @@ class NodesTests(NeoUnitTestBase): ...@@ -60,23 +51,6 @@ class NodesTests(NeoUnitTestBase):
self.assertTrue(previous_time < node.getLastStateChange()) self.assertTrue(previous_time < node.getLastStateChange())
self.assertTrue(time() - 1 < node.getLastStateChange() < time()) self.assertTrue(time() - 1 < node.getLastStateChange() < time())
def testAddress(self):
""" Check if the node is indexed by address """
node = Node(self.nm)
self.assertEqual(node.getAddress(), None)
address = ('127.0.0.1', 10000)
node.setAddress(address)
self._updatedByAddress(node)
def testUUID(self):
""" As for Address but UUID """
node = Node(self.nm)
self.assertEqual(node.getAddress(), None)
uuid = self.getNewUUID(None)
node.setUUID(uuid)
self._updatedByUUID(node)
class NodeManagerTests(NeoUnitTestBase): class NodeManagerTests(NeoUnitTestBase):
def _addStorage(self): def _addStorage(self):
...@@ -209,12 +183,6 @@ class NodeManagerTests(NeoUnitTestBase): ...@@ -209,12 +183,6 @@ class NodeManagerTests(NeoUnitTestBase):
self.assertEqual(self.admin.getState(), NodeStates.UNKNOWN) self.assertEqual(self.admin.getState(), NodeStates.UNKNOWN)
class MasterDBTests(NeoUnitTestBase): class MasterDBTests(NeoUnitTestBase):
def _checkMasterDB(self, path, expected_master_list):
db = list(MasterDB(path))
db_set = set(db)
# Generic sanity check
self.assertEqual(len(db), len(db_set))
self.assertEqual(db_set, set(expected_master_list))
def testInitialAccessRights(self): def testInitialAccessRights(self):
""" """
...@@ -254,9 +222,7 @@ class MasterDBTests(NeoUnitTestBase): ...@@ -254,9 +222,7 @@ class MasterDBTests(NeoUnitTestBase):
db2 = MasterDB(db_file) db2 = MasterDB(db_file)
self.assertFalse(address in db2, [x for x in db2]) self.assertFalse(address in db2, [x for x in db2])
finally: finally:
if exists(db_file): shutil.rmtree(directory)
unlink(db_file)
rmdir(directory)
def testPersistence(self): def testPersistence(self):
temp_dir = getTempDirectory() temp_dir = getTempDirectory()
...@@ -280,9 +246,7 @@ class MasterDBTests(NeoUnitTestBase): ...@@ -280,9 +246,7 @@ class MasterDBTests(NeoUnitTestBase):
self.assertFalse(address in db3, [x for x in db3]) self.assertFalse(address in db3, [x for x in db3])
self.assertTrue(address2 in db3, [x for x in db3]) self.assertTrue(address2 in db3, [x for x in db3])
finally: finally:
if exists(db_file): shutil.rmtree(directory)
unlink(db_file)
rmdir(directory)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
...@@ -26,6 +26,7 @@ from collections import defaultdict ...@@ -26,6 +26,7 @@ from collections import defaultdict
from functools import wraps from functools import wraps
from neo.lib import logging from neo.lib import logging
from neo.client.exception import NEOStorageError from neo.client.exception import NEOStorageError
from neo.master.handlers.backup import BackupHandler
from neo.storage.checker import CHECK_COUNT from neo.storage.checker import CHECK_COUNT
from neo.storage.replicator import Replicator from neo.storage.replicator import Replicator
from neo.lib.connector import SocketConnector from neo.lib.connector import SocketConnector
...@@ -368,6 +369,31 @@ class ReplicationTests(NEOThreadedTest): ...@@ -368,6 +369,31 @@ class ReplicationTests(NEOThreadedTest):
# TODO check tids # TODO check tids
self.assertEqual(1, self.checkBackup(backup)) self.assertEqual(1, self.checkBackup(backup))
def testBackupEarlyInvalidation(self):
"""
The backup master must ignore notification before being fully
initialized.
"""
upstream = NEOCluster()
try:
upstream.start()
backup = NEOCluster(upstream=upstream)
try:
backup.start()
with ConnectionFilter() as f:
f.add(lambda conn, packet:
isinstance(packet, Packets.AskPartitionTable) and
isinstance(conn.getHandler(), BackupHandler))
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
upstream.importZODB()(1)
self.tic()
self.tic()
self.assertTrue(backup.master.isAlive())
finally:
backup.stop()
finally:
upstream.stop()
def testSafeTweak(self): def testSafeTweak(self):
""" """
Check that tweak always tries to keep a minimum of (replicas + 1) Check that tweak always tries to keep a minimum of (replicas + 1)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment