Commit bdc2f9b5 authored by Grégory Wisniewski's avatar Grégory Wisniewski

Master and storage tests now inherit from neo.tests.base.NeoTestBase.

Duplicate checks removed and unified. 


git-svn-id: https://svn.erp5.org/repos/neo/branches/prototype3@520 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent e443219f
...@@ -21,6 +21,7 @@ import logging ...@@ -21,6 +21,7 @@ import logging
from tempfile import mkstemp from tempfile import mkstemp
from mock import Mock from mock import Mock
from struct import pack, unpack from struct import pack, unpack
from neo.tests.base import NeoTestBase
from neo import protocol from neo import protocol
from neo.protocol import Packet, INVALID_UUID from neo.protocol import Packet, INVALID_UUID
from neo.master.election import ElectionEventHandler from neo.master.election import ElectionEventHandler
...@@ -64,7 +65,7 @@ ClientConnection._addPacket = _addPacket ...@@ -64,7 +65,7 @@ ClientConnection._addPacket = _addPacket
ClientConnection.expectMessage = expectMessage ClientConnection.expectMessage = expectMessage
class MasterElectionTests(unittest.TestCase): class MasterElectionTests(NeoTestBase):
def setUp(self): def setUp(self):
logging.basicConfig(level = logging.WARNING) logging.basicConfig(level = logging.WARNING)
...@@ -134,35 +135,6 @@ server: 127.0.0.1:10023 ...@@ -134,35 +135,6 @@ server: 127.0.0.1:10023
# Delete tmp file # Delete tmp file
os.remove(self.tmp_path) os.remove(self.tmp_path)
def checkProtocolErrorRaised(self, method, *args, **kwargs):
""" Check if the ProtocolError exception was raised """
self.assertRaises(protocol.ProtocolError, method, *args, **kwargs)
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNodeDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNodeDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNodeDisallowedError, method, *args, **kwargs)
def checkNotReadyErrorRaised(self, method, *args, **kwargs):
""" Check if the NotReadyError exception wxas raised """
self.assertRaises(protocol.NotReadyError, method, *args, **kwargs)
def checkCalledAcceptNodeIdentification(self, conn, packet_number=0):
""" Check Accept Node Identification has been send"""
self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1)
self.assertEquals(len(conn.mockGetNamedCalls("abort")), 0)
call = conn.mockGetNamedCalls("answer")[packet_number]
packet = call.getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), ACCEPT_NODE_IDENTIFICATION)
# Common methods # Common methods
def getNewUUID(self): def getNewUUID(self):
uuid = INVALID_UUID uuid = INVALID_UUID
...@@ -191,7 +163,7 @@ server: 127.0.0.1:10023 ...@@ -191,7 +163,7 @@ server: 127.0.0.1:10023
ip_address=ip, ip_address=ip,
port=port, port=port,
name=self.app.name,) name=self.app.name,)
self.checkCalledAcceptNodeIdentification(conn) self.checkAcceptNodeIdentification(conn, packet)
return uuid return uuid
# Method to test the kind of packet returned in answer # Method to test the kind of packet returned in answer
...@@ -549,7 +521,7 @@ server: 127.0.0.1:10023 ...@@ -549,7 +521,7 @@ server: 127.0.0.1:10023
self.assertEqual(len(self.app.nm.getMasterNodeList()), 1) self.assertEqual(len(self.app.nm.getMasterNodeList()), 1)
self.assertEqual(node.getUUID(), uuid) self.assertEqual(node.getUUID(), uuid)
self.assertEqual(node.getState(), RUNNING_STATE) self.assertEqual(node.getState(), RUNNING_STATE)
self.checkCalledAcceptNodeIdentification(conn) self.checkAcceptNodeIdentification(conn, packet)
# unknown node # unknown node
conn = Mock({"_addPacket" : None, "abort" : None, "expectMessage" : None, conn = Mock({"_addPacket" : None, "abort" : None, "expectMessage" : None,
"isServerConnection" : True}) "isServerConnection" : True})
...@@ -565,7 +537,7 @@ server: 127.0.0.1:10023 ...@@ -565,7 +537,7 @@ server: 127.0.0.1:10023
port=self.master_port+1, port=self.master_port+1,
name=self.app.name,) name=self.app.name,)
self.assertEqual(len(self.app.nm.getMasterNodeList()), 2) self.assertEqual(len(self.app.nm.getMasterNodeList()), 2)
self.checkCalledAcceptNodeIdentification(conn) self.checkAcceptNodeIdentification(conn, packet)
self.assertEqual(len(self.app.unconnected_master_node_set), 2) self.assertEqual(len(self.app.unconnected_master_node_set), 2)
self.assertEqual(len(self.app.negotiating_master_node_set), 0) self.assertEqual(len(self.app.negotiating_master_node_set), 0)
# broken node # broken node
......
...@@ -21,6 +21,7 @@ import logging ...@@ -21,6 +21,7 @@ import logging
from tempfile import mkstemp from tempfile import mkstemp
from mock import Mock from mock import Mock
from struct import pack, unpack from struct import pack, unpack
from neo.tests.base import NeoTestBase
from neo import protocol from neo import protocol
from neo.protocol import Packet, INVALID_UUID from neo.protocol import Packet, INVALID_UUID
from neo.master.recovery import RecoveryEventHandler from neo.master.recovery import RecoveryEventHandler
...@@ -51,7 +52,7 @@ from neo.node import MasterNode, StorageNode ...@@ -51,7 +52,7 @@ from neo.node import MasterNode, StorageNode
from neo.master.tests.connector import DoNothingConnector from neo.master.tests.connector import DoNothingConnector
from neo.connection import ClientConnection from neo.connection import ClientConnection
class MasterRecoveryTests(unittest.TestCase): class MasterRecoveryTests(NeoTestBase):
def setUp(self): def setUp(self):
logging.basicConfig(level = logging.WARNING) logging.basicConfig(level = logging.WARNING)
...@@ -121,13 +122,6 @@ server: 127.0.0.1:10023 ...@@ -121,13 +122,6 @@ server: 127.0.0.1:10023
os.remove(self.tmp_path) os.remove(self.tmp_path)
# Common methods # Common methods
def getNewUUID(self):
uuid = INVALID_UUID
while uuid == INVALID_UUID:
uuid = os.urandom(16)
self.uuid = uuid
return uuid
def getLastUUID(self): def getLastUUID(self):
return self.uuid return self.uuid
...@@ -141,38 +135,9 @@ server: 127.0.0.1:10023 ...@@ -141,38 +135,9 @@ server: 127.0.0.1:10023
# test alien cluster # test alien cluster
conn = Mock({"_addPacket" : None, "abort" : None, "expectMessage" : None}) conn = Mock({"_addPacket" : None, "abort" : None, "expectMessage" : None})
self.recovery.handleRequestNodeIdentification(conn, packet, *args) self.recovery.handleRequestNodeIdentification(conn, packet, *args)
self.checkCalledAcceptNodeIdentification(conn) self.checkAcceptNodeIdentification(conn)
return uuid return uuid
def checkProtocolErrorRaised(self, method, *args, **kwargs):
""" Check if the ProtocolError exception was raised """
self.assertRaises(protocol.ProtocolError, method, *args, **kwargs)
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNodeDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNodeDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNodeDisallowedError, method, *args, **kwargs)
def checkNotReadyErrorRaised(self, method, *args, **kwargs):
""" Check if the NotReadyError exception wxas raised """
self.assertRaises(protocol.NotReadyError, method, *args, **kwargs)
def checkCalledAcceptNodeIdentification(self, conn, packet_number=0):
""" Check Accept Node Identification has been send"""
self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1)
self.assertEquals(len(conn.mockGetNamedCalls("abort")), 0)
call = conn.mockGetNamedCalls("answer")[packet_number]
packet = call.getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), ACCEPT_NODE_IDENTIFICATION)
# Method to test the kind of packet returned in answer # Method to test the kind of packet returned in answer
def checkCalledRequestNodeIdentification(self, conn, packet_number=0): def checkCalledRequestNodeIdentification(self, conn, packet_number=0):
""" Check Request Node Identification has been send""" """ Check Request Node Identification has been send"""
...@@ -322,7 +287,7 @@ server: 127.0.0.1:10023 ...@@ -322,7 +287,7 @@ server: 127.0.0.1:10023
self.assertEqual(len(self.app.nm.getMasterNodeList()), 1) self.assertEqual(len(self.app.nm.getMasterNodeList()), 1)
self.assertEqual(node.getUUID(), uuid) self.assertEqual(node.getUUID(), uuid)
self.assertEqual(node.getState(), RUNNING_STATE) self.assertEqual(node.getState(), RUNNING_STATE)
self.checkCalledAcceptNodeIdentification(conn) self.checkAcceptNodeIdentification(conn)
# 3. unknown master node with known address but different uuid, will be replaced # 3. unknown master node with known address but different uuid, will be replaced
old_uuid = uuid old_uuid = uuid
...@@ -375,7 +340,7 @@ server: 127.0.0.1:10023 ...@@ -375,7 +340,7 @@ server: 127.0.0.1:10023
self.assertEqual(len(self.app.nm.getMasterNodeList()), 1) self.assertEqual(len(self.app.nm.getMasterNodeList()), 1)
self.assertEqual(node.getUUID(), uuid) self.assertEqual(node.getUUID(), uuid)
self.assertEqual(node.getState(), RUNNING_STATE) self.assertEqual(node.getState(), RUNNING_STATE)
self.checkCalledAcceptNodeIdentification(conn) self.checkAcceptNodeIdentification(conn)
known_uuid = uuid known_uuid = uuid
# 5. known by uuid, but different address -> conflict / new master # 5. known by uuid, but different address -> conflict / new master
...@@ -401,7 +366,7 @@ server: 127.0.0.1:10023 ...@@ -401,7 +366,7 @@ server: 127.0.0.1:10023
self.assertEqual(node.getUUID(), uuid) self.assertEqual(node.getUUID(), uuid)
self.assertEqual(node.getState(), RUNNING_STATE) self.assertEqual(node.getState(), RUNNING_STATE)
self.assertEqual(len(self.app.nm.getMasterNodeList()), 2) self.assertEqual(len(self.app.nm.getMasterNodeList()), 2)
self.checkCalledAcceptNodeIdentification(conn) self.checkAcceptNodeIdentification(conn)
# a new uuid is sent # a new uuid is sent
call = conn.mockGetNamedCalls('answer')[0] call = conn.mockGetNamedCalls('answer')[0]
body = call.getParam(0)._body body = call.getParam(0)._body
...@@ -482,7 +447,7 @@ server: 127.0.0.1:10023 ...@@ -482,7 +447,7 @@ server: 127.0.0.1:10023
self.assertEqual(len(self.app.nm.getMasterNodeList()), 2) self.assertEqual(len(self.app.nm.getMasterNodeList()), 2)
self.assertEqual(node.getUUID(), uuid) self.assertEqual(node.getUUID(), uuid)
self.assertEqual(node.getState(), RUNNING_STATE) self.assertEqual(node.getState(), RUNNING_STATE)
self.checkCalledAcceptNodeIdentification(conn) self.checkAcceptNodeIdentification(conn)
# 9. New node # 9. New node
uuid = self.getNewUUID() uuid = self.getNewUUID()
...@@ -506,7 +471,7 @@ server: 127.0.0.1:10023 ...@@ -506,7 +471,7 @@ server: 127.0.0.1:10023
self.assertEqual(len(self.app.nm.getMasterNodeList()), 3) self.assertEqual(len(self.app.nm.getMasterNodeList()), 3)
self.assertEqual(node.getUUID(), uuid) self.assertEqual(node.getUUID(), uuid)
self.assertEqual(node.getState(), RUNNING_STATE) self.assertEqual(node.getState(), RUNNING_STATE)
self.checkCalledAcceptNodeIdentification(conn) self.checkAcceptNodeIdentification(conn)
def test_05_handleAskPrimaryMaster(self): def test_05_handleAskPrimaryMaster(self):
......
...@@ -21,6 +21,7 @@ import logging ...@@ -21,6 +21,7 @@ import logging
from tempfile import mkstemp from tempfile import mkstemp
from mock import Mock from mock import Mock
from struct import pack, unpack from struct import pack, unpack
from neo.tests.base import NeoTestBase
from neo import protocol from neo import protocol
from neo.protocol import Packet, INVALID_UUID from neo.protocol import Packet, INVALID_UUID
from neo.master.service import ServiceEventHandler from neo.master.service import ServiceEventHandler
...@@ -49,7 +50,7 @@ from neo.protocol import ERROR, REQUEST_NODE_IDENTIFICATION, ACCEPT_NODE_IDENTIF ...@@ -49,7 +50,7 @@ from neo.protocol import ERROR, REQUEST_NODE_IDENTIFICATION, ACCEPT_NODE_IDENTIF
from neo.exception import OperationFailure, ElectionFailure from neo.exception import OperationFailure, ElectionFailure
from neo.node import MasterNode, StorageNode from neo.node import MasterNode, StorageNode
class MasterServiceTests(unittest.TestCase): class MasterServiceTests(NeoTestBase):
def setUp(self): def setUp(self):
logging.basicConfig(level = logging.WARNING) logging.basicConfig(level = logging.WARNING)
...@@ -116,105 +117,6 @@ server: 127.0.0.1:10023 ...@@ -116,105 +117,6 @@ server: 127.0.0.1:10023
# Delete tmp file # Delete tmp file
os.remove(self.tmp_path) os.remove(self.tmp_path)
def checkProtocolErrorRaised(self, method, *args, **kwargs):
""" Check if the ProtocolError exception was raised """
self.assertRaises(protocol.ProtocolError, method, *args, **kwargs)
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNodeDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNodeDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNodeDisallowedError, method, *args, **kwargs)
# Method to test the kind of packet returned in answer
def checkCalledAcceptNodeIdentification(self, conn, packet_number=0):
""" Check Accept Node Identification has been send"""
self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1)
self.assertEquals(len(conn.mockGetNamedCalls("abort")), 0)
call = conn.mockGetNamedCalls("answer")[packet_number]
packet = call.getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), ACCEPT_NODE_IDENTIFICATION)
def checkCalledNotifyNodeInformation(self, conn, packet_number=0):
""" Check Notify Node Information message has been send"""
call = conn.mockGetNamedCalls("notify")[packet_number]
packet = call.getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), NOTIFY_NODE_INFORMATION)
def checkCalledAnswerPrimaryMaster(self, conn, packet_number=0):
""" Check Answer primaty master message has been send"""
call = conn.mockGetNamedCalls("answer")[packet_number]
packet = call.getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), ANSWER_PRIMARY_MASTER)
def checkCalledSendPartitionTable(self, conn, packet_number=0):
""" Check partition table has been send"""
call = conn.mockGetNamedCalls("notify")[packet_number]
packet = call.getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), SEND_PARTITION_TABLE)
def checkCalledStartOperation(self, conn, packet_number=0):
""" Check start operation message has been send"""
call = conn.mockGetNamedCalls("notify")[packet_number]
packet = call.getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), START_OPERATION)
def checkCalledLockInformation(self, conn, packet_number=0):
""" Check lockInformation message has been send"""
call = conn.mockGetNamedCalls("ask")[packet_number]
packet = call.getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), LOCK_INFORMATION)
def checkCalledUnlockInformation(self, conn, packet_number=0):
""" Check unlockInformation message has been send"""
call = conn.mockGetNamedCalls("ask")[packet_number]
packet = call.getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), UNLOCK_INFORMATION)
def checkCalledNotifyTransactionFinished(self, conn, packet_number=0):
""" Check notifyTransactionFinished message has been send"""
call = conn.mockGetNamedCalls("notify")[packet_number]
packet = call.getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), NOTIFY_TRANSACTION_FINISHED)
def checkCalledAnswerLastIDs(self, conn, packet_number=0):
""" Check answerLastIDs message has been send"""
call = conn.mockGetNamedCalls("answer")[packet_number]
packet = call.getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), ANSWER_LAST_IDS)
return protocol._decodeAnswerLastIDs(packet._body)
def checkCalledAnswerUnfinishedTransactions(self, conn, packet_number=0):
""" Check answerUnfinishedTransactions message has been send"""
call = conn.mockGetNamedCalls("answer")[packet_number]
packet = call.getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), ANSWER_UNFINISHED_TRANSACTIONS)
return protocol._decodeAnswerUnfinishedTransactions(packet._body)
# Common methods
def getNewUUID(self):
uuid = INVALID_UUID
while uuid == INVALID_UUID:
uuid = os.urandom(16)
self.uuid = uuid
return uuid
def getLastUUID(self): def getLastUUID(self):
return self.uuid return self.uuid
...@@ -228,7 +130,7 @@ server: 127.0.0.1:10023 ...@@ -228,7 +130,7 @@ server: 127.0.0.1:10023
# test alien cluster # test alien cluster
conn = Mock({"_addPacket" : None, "abort" : None, "expectMessage" : None}) conn = Mock({"_addPacket" : None, "abort" : None, "expectMessage" : None})
self.service.handleRequestNodeIdentification(conn, packet, *args) self.service.handleRequestNodeIdentification(conn, packet, *args)
self.checkCalledAcceptNodeIdentification(conn) self.checkAcceptNodeIdentification(conn, packet)
return uuid return uuid
# Tests # Tests
...@@ -254,7 +156,7 @@ server: 127.0.0.1:10023 ...@@ -254,7 +156,7 @@ server: 127.0.0.1:10023
ip_address='127.0.0.1', ip_address='127.0.0.1',
port=self.storage_port, port=self.storage_port,
name=self.app.name,) name=self.app.name,)
self.checkCalledAcceptNodeIdentification(conn) self.checkAcceptNodeIdentification(conn, packet)
self.assertEquals(len(self.app.nm.getStorageNodeList()), 1) self.assertEquals(len(self.app.nm.getStorageNodeList()), 1)
sn = self.app.nm.getStorageNodeList()[0] sn = self.app.nm.getStorageNodeList()[0]
self.assertEquals(sn.getServer(), ('127.0.0.1', self.storage_port)) self.assertEquals(sn.getServer(), ('127.0.0.1', self.storage_port))
...@@ -272,7 +174,7 @@ server: 127.0.0.1:10023 ...@@ -272,7 +174,7 @@ server: 127.0.0.1:10023
ip_address='127.0.0.1', ip_address='127.0.0.1',
port=self.storage_port, port=self.storage_port,
name=self.app.name,) name=self.app.name,)
self.checkCalledAcceptNodeIdentification(conn) self.checkAcceptNodeIdentification(conn, packet)
sn = self.app.nm.getStorageNodeList()[0] sn = self.app.nm.getStorageNodeList()[0]
self.assertEquals(sn.getServer(), ('127.0.0.1', self.storage_port)) self.assertEquals(sn.getServer(), ('127.0.0.1', self.storage_port))
self.assertEquals(sn.getUUID(), uuid) self.assertEquals(sn.getUUID(), uuid)
...@@ -318,7 +220,7 @@ server: 127.0.0.1:10023 ...@@ -318,7 +220,7 @@ server: 127.0.0.1:10023
ip_address='127.0.0.1', ip_address='127.0.0.1',
port=self.storage_port, port=self.storage_port,
name=self.app.name,) name=self.app.name,)
self.checkCalledAcceptNodeIdentification(conn) self.checkAcceptNodeIdentification(conn, packet)
self.assertEquals(len(self.app.nm.getStorageNodeList()), 1) self.assertEquals(len(self.app.nm.getStorageNodeList()), 1)
sn = self.app.nm.getStorageNodeList()[0] sn = self.app.nm.getStorageNodeList()[0]
self.assertEquals(sn.getServer(), ('127.0.0.1', self.storage_port)) self.assertEquals(sn.getServer(), ('127.0.0.1', self.storage_port))
...@@ -339,7 +241,7 @@ server: 127.0.0.1:10023 ...@@ -339,7 +241,7 @@ server: 127.0.0.1:10023
ip_address='127.0.0.2', ip_address='127.0.0.2',
port=10022, port=10022,
name=self.app.name,) name=self.app.name,)
self.checkCalledAcceptNodeIdentification(conn) self.checkAcceptNodeIdentification(conn, packet)
call = conn.mockGetNamedCalls('answer')[0] call = conn.mockGetNamedCalls('answer')[0]
new_uuid = call.getParam(0)._body[-16:] new_uuid = call.getParam(0)._body[-16:]
self.assertNotEquals(uuid, new_uuid) self.assertNotEquals(uuid, new_uuid)
...@@ -394,11 +296,11 @@ server: 127.0.0.1:10023 ...@@ -394,11 +296,11 @@ server: 127.0.0.1:10023
"getAddress" : ("127.0.0.1", self.storage_port)}) "getAddress" : ("127.0.0.1", self.storage_port)})
service.handleAskPrimaryMaster(conn, packet) service.handleAskPrimaryMaster(conn, packet)
self.assertEquals(len(conn.mockGetNamedCalls("abort")), 0) self.assertEquals(len(conn.mockGetNamedCalls("abort")), 0)
self.checkCalledAnswerPrimaryMaster(conn, 0) self.checkAnswerPrimaryMaster(conn, packet)
self.checkCalledNotifyNodeInformation(conn, 0) self.checkNotifyNodeInformation(conn, 0)
self.checkCalledSendPartitionTable(conn, 1) self.checkSendPartitionTable(conn, 1)
self.checkCalledSendPartitionTable(conn, 2) self.checkSendPartitionTable(conn, 2)
self.checkCalledStartOperation(conn, 3) self.checkStartOperation(conn, 3)
# Same but identify as a client node, must not get start operation message # Same but identify as a client node, must not get start operation message
uuid = self.identifyToMasterNode(node_type=CLIENT_NODE_TYPE, port=11021) uuid = self.identifyToMasterNode(node_type=CLIENT_NODE_TYPE, port=11021)
...@@ -408,10 +310,10 @@ server: 127.0.0.1:10023 ...@@ -408,10 +310,10 @@ server: 127.0.0.1:10023
"getUUID" : uuid, "getAddress" : ("127.0.0.1", 11021)}) "getUUID" : uuid, "getAddress" : ("127.0.0.1", 11021)})
service.handleAskPrimaryMaster(conn, packet) service.handleAskPrimaryMaster(conn, packet)
self.assertEquals(len(conn.mockGetNamedCalls("abort")), 0) self.assertEquals(len(conn.mockGetNamedCalls("abort")), 0)
self.checkCalledAnswerPrimaryMaster(conn, 0) self.checkAnswerPrimaryMaster(conn, packet)
self.checkCalledNotifyNodeInformation(conn, 0) self.checkNotifyNodeInformation(conn, 0)
self.checkCalledSendPartitionTable(conn, 1) self.checkSendPartitionTable(conn, 1)
self.checkCalledSendPartitionTable(conn, 2) self.checkSendPartitionTable(conn, 2)
def test_03_handleAnnouncePrimaryMaster(self): def test_03_handleAnnouncePrimaryMaster(self):
service = self.service service = self.service
...@@ -644,7 +546,7 @@ server: 127.0.0.1:10023 ...@@ -644,7 +546,7 @@ server: 127.0.0.1:10023
"getAddress" : ("127.0.0.1", self.client_port)}) "getAddress" : ("127.0.0.1", self.client_port)})
self.app.em = Mock({"getConnectionList" : [conn, storage_conn]}) self.app.em = Mock({"getConnectionList" : [conn, storage_conn]})
service.handleFinishTransaction(conn, packet, oid_list, tid) service.handleFinishTransaction(conn, packet, oid_list, tid)
self.checkCalledLockInformation(storage_conn) self.checkLockInformation(storage_conn)
self.assertEquals(len(storage_conn.mockGetNamedCalls("ask")), 1) self.assertEquals(len(storage_conn.mockGetNamedCalls("ask")), 1)
self.assertEquals(len(self.app.finishing_transaction_dict), 1) self.assertEquals(len(self.app.finishing_transaction_dict), 1)
apptid = self.app.finishing_transaction_dict.keys()[0] apptid = self.app.finishing_transaction_dict.keys()[0]
...@@ -705,8 +607,8 @@ server: 127.0.0.1:10023 ...@@ -705,8 +607,8 @@ server: 127.0.0.1:10023
oid_list = [] oid_list = []
tid = self.app.ltid tid = self.app.ltid
service.handleFinishTransaction(conn, packet, oid_list, tid) service.handleFinishTransaction(conn, packet, oid_list, tid)
self.checkCalledLockInformation(storage_conn_1) self.checkLockInformation(storage_conn_1)
self.checkCalledLockInformation(storage_conn_2) self.checkLockInformation(storage_conn_2)
self.assertFalse(self.app.finishing_transaction_dict.values()[0].allLocked()) self.assertFalse(self.app.finishing_transaction_dict.values()[0].allLocked())
self.assertEquals(len(storage_conn_1.mockGetNamedCalls("ask")), 1) self.assertEquals(len(storage_conn_1.mockGetNamedCalls("ask")), 1)
self.assertEquals(len(storage_conn_2.mockGetNamedCalls("ask")), 1) self.assertEquals(len(storage_conn_2.mockGetNamedCalls("ask")), 1)
...@@ -717,13 +619,13 @@ server: 127.0.0.1:10023 ...@@ -717,13 +619,13 @@ server: 127.0.0.1:10023
self.assertFalse(self.app.finishing_transaction_dict.values()[0].allLocked()) self.assertFalse(self.app.finishing_transaction_dict.values()[0].allLocked())
service.handleNotifyInformationLocked(storage_conn_2, packet, tid) service.handleNotifyInformationLocked(storage_conn_2, packet, tid)
self.checkCalledNotifyTransactionFinished(conn) self.checkNotifyTransactionFinished(conn)
self.assertEquals(len(storage_conn_1.mockGetNamedCalls("ask")), 1) self.assertEquals(len(storage_conn_1.mockGetNamedCalls("ask")), 1)
self.assertEquals(len(storage_conn_1.mockGetNamedCalls("notify")), 1) self.assertEquals(len(storage_conn_1.mockGetNamedCalls("notify")), 1)
self.assertEquals(len(storage_conn_2.mockGetNamedCalls("ask")), 1) self.assertEquals(len(storage_conn_2.mockGetNamedCalls("ask")), 1)
self.assertEquals(len(storage_conn_2.mockGetNamedCalls("notify")), 1) self.assertEquals(len(storage_conn_2.mockGetNamedCalls("notify")), 1)
self.checkCalledLockInformation(storage_conn_1) self.checkLockInformation(storage_conn_1)
self.checkCalledLockInformation(storage_conn_2) self.checkLockInformation(storage_conn_2)
def test_11_handleAbortTransaction(self): def test_11_handleAbortTransaction(self):
...@@ -773,7 +675,8 @@ server: 127.0.0.1:10023 ...@@ -773,7 +675,8 @@ server: 127.0.0.1:10023
tid = self.app.ltid tid = self.app.ltid
oid = self.app.loid oid = self.app.loid
service.handleAskLastIDs(conn, packet) service.handleAskLastIDs(conn, packet)
loid, ltid, lptid = self.checkCalledAnswerLastIDs(conn) packet = self.checkAnswerLastIDs(conn, packet)
loid, ltid, lptid = protocol._decodeAnswerLastIDs(packet._body)
self.assertEqual(loid, oid) self.assertEqual(loid, oid)
self.assertEqual(ltid, tid) self.assertEqual(ltid, tid)
self.assertEqual(lptid, ptid) self.assertEqual(lptid, ptid)
...@@ -792,7 +695,8 @@ server: 127.0.0.1:10023 ...@@ -792,7 +695,8 @@ server: 127.0.0.1:10023
conn = Mock({"getUUID" : uuid, conn = Mock({"getUUID" : uuid,
"getAddress" : ("127.0.0.1", self.storage_port)}) "getAddress" : ("127.0.0.1", self.storage_port)})
service.handleAskUnfinishedTransactions(conn, packet) service.handleAskUnfinishedTransactions(conn, packet)
tid_list = self.checkCalledAnswerUnfinishedTransactions(conn)[0] packet = self.checkAnswerUnfinishedTransactions(conn, packet)
tid_list = protocol._decodeAnswerUnfinishedTransactions(packet._body)[0]
self.assertEqual(len(tid_list), 0) self.assertEqual(len(tid_list), 0)
# create some transaction # create some transaction
client_uuid = self.identifyToMasterNode(node_type=CLIENT_NODE_TYPE, client_uuid = self.identifyToMasterNode(node_type=CLIENT_NODE_TYPE,
...@@ -805,7 +709,8 @@ server: 127.0.0.1:10023 ...@@ -805,7 +709,8 @@ server: 127.0.0.1:10023
conn = Mock({"getUUID" : uuid, conn = Mock({"getUUID" : uuid,
"getAddress" : ("127.0.0.1", self.storage_port)}) "getAddress" : ("127.0.0.1", self.storage_port)})
service.handleAskUnfinishedTransactions(conn, packet) service.handleAskUnfinishedTransactions(conn, packet)
tid_list = self.checkCalledAnswerUnfinishedTransactions(conn)[0] packet = self.checkAnswerUnfinishedTransactions(conn, packet)
tid_list = protocol._decodeAnswerUnfinishedTransactions(packet._body)[0]
self.assertEqual(len(tid_list), 3) self.assertEqual(len(tid_list), 3)
......
...@@ -21,6 +21,7 @@ import logging ...@@ -21,6 +21,7 @@ import logging
from tempfile import mkstemp from tempfile import mkstemp
from mock import Mock from mock import Mock
from struct import pack, unpack from struct import pack, unpack
from neo.tests.base import NeoTestBase
from neo.protocol import Packet, INVALID_UUID from neo.protocol import Packet, INVALID_UUID
from neo.master.verification import VerificationEventHandler from neo.master.verification import VerificationEventHandler
from neo.master.app import Application from neo.master.app import Application
...@@ -52,7 +53,7 @@ from neo.master.tests.connector import DoNothingConnector ...@@ -52,7 +53,7 @@ from neo.master.tests.connector import DoNothingConnector
from neo.connection import ClientConnection from neo.connection import ClientConnection
class MasterVerificationeTests(unittest.TestCase): class MasterVerificationeTests(NeoTestBase):
def setUp(self): def setUp(self):
logging.basicConfig(level = logging.WARNING) logging.basicConfig(level = logging.WARNING)
...@@ -123,43 +124,7 @@ server: 127.0.0.1:10023 ...@@ -123,43 +124,7 @@ server: 127.0.0.1:10023
# Delete tmp file # Delete tmp file
os.remove(self.tmp_path) os.remove(self.tmp_path)
def checkProtocolErrorRaised(self, method, *args, **kwargs):
""" Check if the ProtocolError exception was raised """
self.assertRaises(protocol.ProtocolError, method, *args, **kwargs)
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNodeDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNodeDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNodeDisallowedError, method, *args, **kwargs)
def checkNotReadyErrorRaised(self, method, *args, **kwargs):
""" Check if the NotReadyError exception wxas raised """
self.assertRaises(protocol.NotReadyError, method, *args, **kwargs)
def checkCalledAcceptNodeIdentification(self, conn, packet_number=0):
""" Check Accept Node Identification has been send"""
self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1)
self.assertEquals(len(conn.mockGetNamedCalls("abort")), 0)
call = conn.mockGetNamedCalls("answer")[packet_number]
packet = call.getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), ACCEPT_NODE_IDENTIFICATION)
# Common methods # Common methods
def getNewUUID(self):
uuid = INVALID_UUID
while uuid == INVALID_UUID:
uuid = os.urandom(16)
self.uuid = uuid
return uuid
def getLastUUID(self): def getLastUUID(self):
return self.uuid return self.uuid
...@@ -173,7 +138,7 @@ server: 127.0.0.1:10023 ...@@ -173,7 +138,7 @@ server: 127.0.0.1:10023
# test alien cluster # test alien cluster
conn = Mock({"_addPacket" : None, "abort" : None, "expectMessage" : None}) conn = Mock({"_addPacket" : None, "abort" : None, "expectMessage" : None})
self.verification.handleRequestNodeIdentification(conn, packet, *args) self.verification.handleRequestNodeIdentification(conn, packet, *args)
self.checkCalledAcceptNodeIdentification(conn) self.checkAcceptNodeIdentification(conn)
return uuid return uuid
# Method to test the kind of packet returned in answer # Method to test the kind of packet returned in answer
...@@ -344,7 +309,7 @@ server: 127.0.0.1:10023 ...@@ -344,7 +309,7 @@ server: 127.0.0.1:10023
self.assertEqual(len(self.app.nm.getMasterNodeList()), 1) self.assertEqual(len(self.app.nm.getMasterNodeList()), 1)
self.assertEqual(node.getUUID(), uuid) self.assertEqual(node.getUUID(), uuid)
self.assertEqual(node.getState(), RUNNING_STATE) self.assertEqual(node.getState(), RUNNING_STATE)
self.checkCalledAcceptNodeIdentification(conn) self.checkAcceptNodeIdentification(conn)
# 3. unknown master node with known address but different uuid, will be replaced # 3. unknown master node with known address but different uuid, will be replaced
old_uuid = uuid old_uuid = uuid
...@@ -397,7 +362,7 @@ server: 127.0.0.1:10023 ...@@ -397,7 +362,7 @@ server: 127.0.0.1:10023
self.assertEqual(len(self.app.nm.getMasterNodeList()), 1) self.assertEqual(len(self.app.nm.getMasterNodeList()), 1)
self.assertEqual(node.getUUID(), uuid) self.assertEqual(node.getUUID(), uuid)
self.assertEqual(node.getState(), RUNNING_STATE) self.assertEqual(node.getState(), RUNNING_STATE)
self.checkCalledAcceptNodeIdentification(conn) self.checkAcceptNodeIdentification(conn)
# 5. known by uuid, but different address # 5. known by uuid, but different address
conn = Mock({"_addPacket" : None, conn = Mock({"_addPacket" : None,
...@@ -422,7 +387,7 @@ server: 127.0.0.1:10023 ...@@ -422,7 +387,7 @@ server: 127.0.0.1:10023
self.assertEqual(node.getUUID(), uuid) self.assertEqual(node.getUUID(), uuid)
self.assertEqual(node.getState(), RUNNING_STATE) self.assertEqual(node.getState(), RUNNING_STATE)
self.assertEqual(len(self.app.nm.getMasterNodeList()), 2) self.assertEqual(len(self.app.nm.getMasterNodeList()), 2)
self.checkCalledAcceptNodeIdentification(conn) self.checkAcceptNodeIdentification(conn)
# a new uuid is sent # a new uuid is sent
call = conn.mockGetNamedCalls('answer')[0] call = conn.mockGetNamedCalls('answer')[0]
body = call.getParam(0)._body body = call.getParam(0)._body
...@@ -503,7 +468,7 @@ server: 127.0.0.1:10023 ...@@ -503,7 +468,7 @@ server: 127.0.0.1:10023
self.assertEqual(len(self.app.nm.getMasterNodeList()), 2) self.assertEqual(len(self.app.nm.getMasterNodeList()), 2)
self.assertEqual(node.getUUID(), uuid) self.assertEqual(node.getUUID(), uuid)
self.assertEqual(node.getState(), RUNNING_STATE) self.assertEqual(node.getState(), RUNNING_STATE)
self.checkCalledAcceptNodeIdentification(conn) self.checkAcceptNodeIdentification(conn)
# 9. New node # 9. New node
uuid = self.getNewUUID() uuid = self.getNewUUID()
...@@ -527,7 +492,7 @@ server: 127.0.0.1:10023 ...@@ -527,7 +492,7 @@ server: 127.0.0.1:10023
self.assertEqual(len(self.app.nm.getMasterNodeList()), 3) self.assertEqual(len(self.app.nm.getMasterNodeList()), 3)
self.assertEqual(node.getUUID(), uuid) self.assertEqual(node.getUUID(), uuid)
self.assertEqual(node.getState(), RUNNING_STATE) self.assertEqual(node.getState(), RUNNING_STATE)
self.checkCalledAcceptNodeIdentification(conn) self.checkAcceptNodeIdentification(conn)
def test_05_handleAskPrimaryMaster(self): def test_05_handleAskPrimaryMaster(self):
......
...@@ -21,6 +21,7 @@ import logging ...@@ -21,6 +21,7 @@ import logging
import MySQLdb import MySQLdb
from tempfile import mkstemp from tempfile import mkstemp
from mock import Mock from mock import Mock
from neo.tests.base import NeoTestBase
from neo.master.app import MasterNode from neo.master.app import MasterNode
from neo.pt import PartitionTable from neo.pt import PartitionTable
from neo.storage.app import Application, StorageNode from neo.storage.app import Application, StorageNode
...@@ -38,7 +39,7 @@ SQL_ADMIN_PASSWORD = None ...@@ -38,7 +39,7 @@ SQL_ADMIN_PASSWORD = None
NEO_SQL_USER = 'test' NEO_SQL_USER = 'test'
NEO_SQL_DATABASE = 'test_storage_neo1' NEO_SQL_DATABASE = 'test_storage_neo1'
class StorageBootstrapTests(unittest.TestCase): class StorageBootstrapTests(NeoTestBase):
def setUp(self): def setUp(self):
logging.basicConfig(level = logging.ERROR) logging.basicConfig(level = logging.ERROR)
...@@ -100,56 +101,15 @@ server: 127.0.0.1:10020 ...@@ -100,56 +101,15 @@ server: 127.0.0.1:10020
os.remove(self.tmp_path) os.remove(self.tmp_path)
# Common methods # Common methods
def getNewUUID(self):
uuid = INVALID_UUID
while uuid == INVALID_UUID:
uuid = os.urandom(16)
self.uuid = uuid
return uuid
def getLastUUID(self): def getLastUUID(self):
return self.uuid return self.uuid
def checkProtocolErrorRaised(self, method, *args, **kwargs):
""" Check if the ProtocolError exception was raised """
self.assertRaises(protocol.ProtocolError, method, *args, **kwargs)
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNodeDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNodeDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNodeDisallowedError, method, *args, **kwargs)
def checkNotReadyErrorRaised(self, method, *args, **kwargs):
""" Check if the NotReadyError exception wxas raised """
self.assertRaises(protocol.NotReadyError, method, *args, **kwargs)
def checkAskPacket(self, conn, packet_type):
""" Check if an ask-packet with the right type is send """
calls = conn.mockGetNamedCalls('ask')
self.assertEquals(len(calls), 1)
packet = calls[0].getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), packet_type)
# Method to test the kind of packet returned in answer # Method to test the kind of packet returned in answer
def checkCalledRequestNodeIdentification(self, conn, packet_number=0): def checkCalledRequestNodeIdentification(self, conn, packet_number=0):
""" Check Request Node Identification has been send""" """ Check Request Node Identification has been send"""
self.assertEquals(len(conn.mockGetNamedCalls("abort")), 0) self.assertEquals(len(conn.mockGetNamedCalls("abort")), 0)
self.checkAskPacket(conn, protocol.REQUEST_NODE_IDENTIFICATION) self.checkAskPacket(conn, protocol.REQUEST_NODE_IDENTIFICATION)
def checkNoPacketSent(self, conn):
# no packet should be sent
self.assertEquals(len(conn.mockGetNamedCalls('notify')), 0)
self.assertEquals(len(conn.mockGetNamedCalls('answer')), 0)
self.assertEquals(len(conn.mockGetNamedCalls('ask')), 0)
# Tests # Tests
def test_01_connectionCompleted(self): def test_01_connectionCompleted(self):
# trying mn is None -> RuntimeError # trying mn is None -> RuntimeError
......
...@@ -23,6 +23,7 @@ from tempfile import mkstemp ...@@ -23,6 +23,7 @@ from tempfile import mkstemp
from struct import pack, unpack from struct import pack, unpack
from mock import Mock from mock import Mock
from collections import deque from collections import deque
from neo.tests.base import NeoTestBase
from neo.master.app import MasterNode from neo.master.app import MasterNode
from neo.storage.app import Application, StorageNode from neo.storage.app import Application, StorageNode
from neo.storage.operation import TransactionInformation, OperationEventHandler from neo.storage.operation import TransactionInformation, OperationEventHandler
...@@ -36,41 +37,9 @@ SQL_ADMIN_PASSWORD = None ...@@ -36,41 +37,9 @@ SQL_ADMIN_PASSWORD = None
NEO_SQL_USER = 'test' NEO_SQL_USER = 'test'
NEO_SQL_DATABASE = 'test_storage_neo1' NEO_SQL_DATABASE = 'test_storage_neo1'
class StorageOperationTests(unittest.TestCase): class StorageOperationTests(NeoTestBase):
def getNewUUID(self):
uuid = INVALID_UUID
while uuid == INVALID_UUID:
uuid = os.urandom(16)
self.uuid = uuid
return uuid
def getTwoIDs(self):
# generate two ptid, first is lower
ptids = self.getNewUUID(), self.getNewUUID()
return min(ptids), max(ptids)
ptid = min(ptids)
def checkProtocolErrorRaised(self, method, *args, **kwargs):
""" Check if the ProtocolError exception was raised """
self.assertRaises(protocol.ProtocolError, method, *args, **kwargs)
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNodeDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNodeDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNodeDisallowedError, method, *args, **kwargs)
def checkNotReadyErrorRaised(self, method, *args, **kwargs):
""" Check if the NotReadyError exception wxas raised """
self.assertRaises(protocol.NotReadyError, method, *args, **kwargs)
# TODO: move this check to base class and rename as checkAnswerPacket
def checkPacket(self, conn, packet_type=ERROR): def checkPacket(self, conn, packet_type=ERROR):
self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1) self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1)
call = conn.mockGetNamedCalls("answer")[0] call = conn.mockGetNamedCalls("answer")[0]
...@@ -88,12 +57,6 @@ class StorageOperationTests(unittest.TestCase): ...@@ -88,12 +57,6 @@ class StorageOperationTests(unittest.TestCase):
self.operation.peerBroken = lambda c: c.peerBrokendCalled() self.operation.peerBroken = lambda c: c.peerBrokendCalled()
self.checkUnexpectedPacketRaised(_call, conn=conn, packet=packet, **kwargs) self.checkUnexpectedPacketRaised(_call, conn=conn, packet=packet, **kwargs)
def checkNoPacketSent(self, conn):
# no packet should be sent
self.assertEquals(len(conn.mockGetNamedCalls('notify')), 0)
self.assertEquals(len(conn.mockGetNamedCalls('answer')), 0)
self.assertEquals(len(conn.mockGetNamedCalls('ask')), 0)
def setUp(self): def setUp(self):
logging.basicConfig(level = logging.ERROR) logging.basicConfig(level = logging.ERROR)
# create an application object # create an application object
......
...@@ -21,6 +21,7 @@ import logging ...@@ -21,6 +21,7 @@ import logging
import MySQLdb import MySQLdb
from tempfile import mkstemp from tempfile import mkstemp
from mock import Mock from mock import Mock
from neo.tests.base import NeoTestBase
from neo import protocol from neo import protocol
from neo.node import MasterNode from neo.node import MasterNode
from neo.pt import PartitionTable from neo.pt import PartitionTable
...@@ -45,7 +46,7 @@ SQL_ADMIN_PASSWORD = None ...@@ -45,7 +46,7 @@ SQL_ADMIN_PASSWORD = None
NEO_SQL_USER = 'test' NEO_SQL_USER = 'test'
NEO_SQL_DATABASE = 'test_storage_neo1' NEO_SQL_DATABASE = 'test_storage_neo1'
class StorageVerificationTests(unittest.TestCase): class StorageVerificationTests(NeoTestBase):
def setUp(self): def setUp(self):
logging.basicConfig(level = logging.ERROR) logging.basicConfig(level = logging.ERROR)
...@@ -111,48 +112,9 @@ server: 127.0.0.1:10020 ...@@ -111,48 +112,9 @@ server: 127.0.0.1:10020
os.remove(self.tmp_path) os.remove(self.tmp_path)
# Common methods # Common methods
def getNewUUID(self):
uuid = INVALID_UUID
while uuid == INVALID_UUID:
uuid = os.urandom(16)
self.uuid = uuid
return uuid
def getLastUUID(self): def getLastUUID(self):
return self.uuid return self.uuid
def getTwoIDs(self):
# generate two ptid, first is lower
ptids = self.getNewUUID(), self.getNewUUID()
return min(ptids), max(ptids)
ptid = min(ptids)
def checkProtocolErrorRaised(self, method, *args, **kwargs):
""" Check if the ProtocolError exception was raised """
self.assertRaises(protocol.ProtocolError, method, *args, **kwargs)
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNodeDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNodeDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNodeDisallowedError, method, *args, **kwargs)
def checkNotReadyErrorRaised(self, method, *args, **kwargs):
""" Check if the NotReadyError exception wxas raised """
self.assertRaises(protocol.NotReadyError, method, *args, **kwargs)
def checkNoPacketSent(self, conn):
# no packet should be sent
self.assertEquals(len(conn.mockGetNamedCalls('notify')), 0)
self.assertEquals(len(conn.mockGetNamedCalls('answer')), 0)
self.assertEquals(len(conn.mockGetNamedCalls('ask')), 0)
# Tests # Tests
def test_01_connectionAccepted(self): def test_01_connectionAccepted(self):
uuid = self.getNewUUID() uuid = self.getNewUUID()
......
...@@ -17,9 +17,132 @@ ...@@ -17,9 +17,132 @@
import unittest, os import unittest, os
from mock import Mock from mock import Mock
from neo import protocol
class NeoTestBase(unittest.TestCase): class NeoTestBase(unittest.TestCase):
""" Base class for neo tests, implements common checks """ """ Base class for neo tests, implements common checks """
pass # XXX: according to changes with namespaced UUIDs, it whould be better to
# implement get<NodeType>UUID() methods
def getNewUUID(self):
""" Return a valid UUID """
uuid = protocol.INVALID_UUID
while uuid == protocol.INVALID_UUID:
uuid = os.urandom(16)
self.uuid = uuid
return uuid
def getTwoIDs(self):
""" Return a tuple of two sorted UUIDs """
# generate two ptid, first is lower
ptids = self.getNewUUID(), self.getNewUUID()
return min(ptids), max(ptids)
ptid = min(ptids)
def checkProtocolErrorRaised(self, method, *args, **kwargs):
""" Check if the ProtocolError exception was raised """
self.assertRaises(protocol.ProtocolError, method, *args, **kwargs)
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNodeDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNodeDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNodeDisallowedError, method, *args, **kwargs)
def checkNotReadyErrorRaised(self, method, *args, **kwargs):
""" Check if the NotReadyError exception wxas raised """
self.assertRaises(protocol.NotReadyError, method, *args, **kwargs)
def checkNoPacketSent(self, conn):
# no packet should be sent
self.assertEquals(len(conn.mockGetNamedCalls('notify')), 0)
self.assertEquals(len(conn.mockGetNamedCalls('answer')), 0)
self.assertEquals(len(conn.mockGetNamedCalls('ask')), 0)
# in check(Ask|Answer|Notify)Packet we return the packet so it can be used
# in tests if more accurates checks are required
def checkAskPacket(self, conn, packet_type):
""" Check if an ask-packet with the right type is sent """
calls = conn.mockGetNamedCalls('ask')
self.assertEquals(len(calls), 1)
packet = calls[0].getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEquals(packet.getType(), packet_type)
return packet
def checkAnswerPacket(self, conn, packet_type, answered_packet=None):
""" Check if an answer-packet with the right type is sent """
calls = conn.mockGetNamedCalls('answer')
self.assertEquals(len(calls), 1)
packet = calls[0].getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEquals(packet.getType(), packet_type)
if answered_packet is not None:
a_packet = calls[0].getParam(1)
self.assertEquals(a_packet, answered_packet)
self.assertEquals(a_packet.getId(), answered_packet.getId())
return packet
def checkNotifyPacket(self, conn, packet_type, packet_number=0):
""" Check if a notify-packet with the right type is sent """
calls = conn.mockGetNamedCalls('notify')
self.assertTrue(len(calls) > packet_number)
packet = calls[packet_number].getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEquals(packet.getType(), packet_type)
return packet
def checkNotifyNodeInformation(self, conn, packet_number=0):
""" Check Notify Node Information message has been send"""
return self.checkNotifyPacket(conn, protocol.NOTIFY_NODE_INFORMATION,
packet_number)
def checkSendPartitionTable(self, conn, packet_number=0):
""" Check partition table has been send"""
return self.checkNotifyPacket(conn, protocol.SEND_PARTITION_TABLE,
packet_number)
def checkStartOperation(self, conn, packet_number=0):
""" Check start operation message has been send"""
return self.checkNotifyPacket(conn, protocol.START_OPERATION,
packet_number)
def checkNotifyTransactionFinished(self, conn, packet_number=0):
""" Check notifyTransactionFinished message has been send"""
return self.checkNotifyPacket(conn, protocol.NOTIFY_TRANSACTION_FINISHED,
packet_number)
def checkLockInformation(self, conn):
""" Check lockInformation message has been send"""
return self.checkAskPacket(conn, protocol.LOCK_INFORMATION)
def checkUnlockInformation(self, conn):
""" Check unlockInformation message has been send"""
return self.checkAskPacket(conn, protocol.UNLOCK_INFORMATION)
def checkAcceptNodeIdentification(self, conn, answered_packet=None):
""" Check Accept Node Identification has been answered """
return self.checkAnswerPacket(conn, protocol.ACCEPT_NODE_IDENTIFICATION,
answered_packet)
def checkAnswerPrimaryMaster(self, conn, answered_packet=None):
""" Check Answer primaty master message has been send"""
return self.checkAnswerPacket(conn, protocol.ANSWER_PRIMARY_MASTER,
answered_packet)
def checkAnswerLastIDs(self, conn, packet_number=0):
""" Check answerLastIDs message has been send"""
return self.checkAnswerPacket(conn, protocol.ANSWER_LAST_IDS)
def checkAnswerUnfinishedTransactions(self, conn, packet_number=0):
""" Check answerUnfinishedTransactions message has been send"""
return self.checkAnswerPacket(conn,
protocol.ANSWER_UNFINISHED_TRANSACTIONS)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment