Commit 59004b8c authored by Julien Muchembled's avatar Julien Muchembled

qa: code cleanup in non-threaded -u tests

parent bcf4afa0
Pipeline #4623 skipped
......@@ -281,18 +281,6 @@ class NeoUnitTestBase(NeoTestBase):
def getNextTID(self, ltid=None):
return newTid(ltid)
def getPTID(self, i=None):
""" Return an integer PTID """
if i is None:
return random.randint(1, 2**64)
return i
def getOID(self, i=None):
""" Return a 8-bytes OID """
if i is None:
return os.urandom(8)
return pack('!Q', i)
def getFakeConnector(self, descriptor=None):
return Mock({
'__repr__': 'FakeConnector',
......@@ -321,18 +309,6 @@ class NeoUnitTestBase(NeoTestBase):
""" Check if the ProtocolError exception was raised """
self.assertRaises(protocol.ProtocolError, method, *args, **kwargs)
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception was raised """
self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNodeDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNodeDisallowedError exception was raised """
self.assertRaises(protocol.BrokenNodeDisallowedError, method, *args, **kwargs)
def checkNotReadyErrorRaised(self, method, *args, **kwargs):
""" Check if the NotReadyError exception was raised """
self.assertRaises(protocol.NotReadyError, method, *args, **kwargs)
......@@ -341,35 +317,18 @@ class NeoUnitTestBase(NeoTestBase):
""" Ensure the connection was aborted """
self.assertEqual(len(conn.mockGetNamedCalls('abort')), 1)
def checkNotAborted(self, conn):
""" Ensure the connection was not aborted """
self.assertEqual(len(conn.mockGetNamedCalls('abort')), 0)
def checkClosed(self, conn):
""" Ensure the connection was closed """
self.assertEqual(len(conn.mockGetNamedCalls('close')), 1)
def checkNotClosed(self, conn):
""" Ensure the connection was not closed """
self.assertEqual(len(conn.mockGetNamedCalls('close')), 0)
def _checkNoPacketSend(self, conn, method_id):
call_list = conn.mockGetNamedCalls(method_id)
self.assertEqual(len(call_list), 0, call_list)
self.assertEqual([], conn.mockGetNamedCalls(method_id))
def checkNoPacketSent(self, conn, check_notify=True, check_answer=True,
check_ask=True):
def checkNoPacketSent(self, conn):
""" check if no packet were sent """
if check_notify:
self._checkNoPacketSend(conn, 'notify')
if check_answer:
self._checkNoPacketSend(conn, 'answer')
if check_ask:
self._checkNoPacketSend(conn, 'ask')
def checkNoUUIDSet(self, conn):
""" ensure no UUID was set on the connection """
self.assertEqual(len(conn.mockGetNamedCalls('setUUID')), 0)
self._checkNoPacketSend(conn, 'notify')
self._checkNoPacketSend(conn, 'answer')
self._checkNoPacketSend(conn, 'ask')
def checkUUIDSet(self, conn, uuid=None, check_intermediate=True):
""" ensure UUID was set on the connection """
......@@ -384,151 +343,41 @@ class NeoUnitTestBase(NeoTestBase):
# in check(Ask|Answer|Notify)Packet we return the packet so it can be used
# in tests if more accurate checks are required
def checkErrorPacket(self, conn, decode=False):
def checkErrorPacket(self, conn):
""" Check if an error packet was answered """
calls = conn.mockGetNamedCalls("answer")
self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(type(packet), Packets.Error)
if decode:
return packet.decode()
return packet
def checkAskPacket(self, conn, packet_type, decode=False):
def checkAskPacket(self, conn, packet_type):
""" Check if an ask-packet with the right type is sent """
calls = conn.mockGetNamedCalls('ask')
self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(type(packet), packet_type)
if decode:
return packet.decode()
return packet
def checkAnswerPacket(self, conn, packet_type, decode=False):
def checkAnswerPacket(self, conn, packet_type):
""" Check if an answer-packet with the right type is sent """
calls = conn.mockGetNamedCalls('answer')
self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(type(packet), packet_type)
if decode:
return packet.decode()
return packet
def checkNotifyPacket(self, conn, packet_type, packet_number=0, decode=False):
def checkNotifyPacket(self, conn, packet_type, packet_number=0):
""" Check if a notify-packet with the right type is sent """
calls = conn.mockGetNamedCalls('notify')
packet = calls.pop(packet_number).getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(type(packet), packet_type)
if decode:
return packet.decode()
return packet
def checkNotify(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.Notify, **kw)
def checkNotifyNodeInformation(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.NotifyNodeInformation, **kw)
def checkSendPartitionTable(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.SendPartitionTable, **kw)
def checkStartOperation(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.StartOperation, **kw)
def checkInvalidateObjects(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.InvalidateObjects, **kw)
def checkAbortTransaction(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.AbortTransaction, **kw)
def checkNotifyLastOID(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.NotifyLastOID, **kw)
def checkAnswerTransactionFinished(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerTransactionFinished, **kw)
def checkAnswerInformationLocked(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerInformationLocked, **kw)
def checkAskLockInformation(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskLockInformation, **kw)
def checkNotifyUnlockInformation(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.NotifyUnlockInformation, **kw)
def checkNotifyTransactionFinished(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.NotifyTransactionFinished, **kw)
def checkRequestIdentification(self, conn, **kw):
return self.checkAskPacket(conn, Packets.RequestIdentification, **kw)
def checkAskPrimary(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskPrimary)
def checkAskUnfinishedTransactions(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskUnfinishedTransactions)
def checkAskTransactionInformation(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskTransactionInformation, **kw)
def checkAskObject(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskObject, **kw)
def checkAskStoreObject(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskStoreObject, **kw)
def checkAskStoreTransaction(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskStoreTransaction, **kw)
def checkAskFinishTransaction(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskFinishTransaction, **kw)
def checkAskNewTid(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskBeginTransaction, **kw)
def checkAskLastIDs(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskLastIDs, **kw)
def checkAcceptIdentification(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AcceptIdentification, **kw)
def checkAnswerPrimary(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerPrimary, **kw)
def checkAnswerLastIDs(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerLastIDs, **kw)
def checkAnswerUnfinishedTransactions(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerUnfinishedTransactions, **kw)
def checkAnswerObject(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObject, **kw)
def checkAnswerTransactionInformation(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerTransactionInformation, **kw)
def checkAnswerBeginTransaction(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerBeginTransaction, **kw)
def checkAnswerTids(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerTIDs, **kw)
def checkAnswerTidsFrom(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerTIDsFrom, **kw)
def checkAnswerObjectHistory(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObjectHistory, **kw)
def checkAnswerStoreObject(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerStoreObject, **kw)
def checkAnswerPartitionTable(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerPartitionTable, **kw)
class Patch(object):
"""
......
......@@ -68,6 +68,9 @@ class ClientApplicationTests(NeoUnitTestBase):
# some helpers
def checkAskObject(self, conn):
return self.checkAskPacket(conn, Packets.AskObject)
def _begin(self, app, txn, tid):
txn_context = app._txn_container.new(txn)
txn_context['ttid'] = tid
......
......@@ -21,6 +21,7 @@ from .. import NeoUnitTestBase
from neo.client.app import ConnectionPool
from neo.client.exception import NEOStorageError
from neo.client import pool
from neo.lib.util import p64
class ConnectionPoolTests(NeoUnitTestBase):
......@@ -54,7 +55,7 @@ class ConnectionPoolTests(NeoUnitTestBase):
def test_iterateForObject_noStorageAvailable(self):
# no node available
oid = self.getOID(1)
oid = p64(1)
app = Mock()
app.pt = Mock({'getCellList': []})
pool = ConnectionPool(app)
......
......@@ -17,6 +17,7 @@
import unittest
from mock import Mock
from .. import NeoUnitTestBase
from neo.lib.util import p64
from neo.lib.protocol import NodeTypes, NodeStates, Packets
from neo.master.handlers.client import ClientServiceHandler
from neo.master.app import Application
......@@ -62,6 +63,9 @@ class MasterClientHandlerTests(NeoUnitTestBase):
)
return uuid
def checkAnswerBeginTransaction(self, conn):
return self.checkAnswerPacket(conn, Packets.AnswerBeginTransaction)
# Tests
def test_07_askBeginTransaction(self):
tid1 = self.getNextTID()
......@@ -87,12 +91,12 @@ class MasterClientHandlerTests(NeoUnitTestBase):
calls = tm.mockGetNamedCalls('begin')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(client_node, None)
args = self.checkAnswerBeginTransaction(conn, decode=True)
self.assertEqual(args, (tid1, ))
packet = self.checkAnswerBeginTransaction(conn)
self.assertEqual(packet.decode(), (tid1, ))
def test_08_askNewOIDs(self):
service = self.service
oid1, oid2 = self.getOID(1), self.getOID(2)
oid1, oid2 = p64(1), p64(2)
self.app.tm.setLastOID(oid1)
# client call it
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
......@@ -136,7 +140,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.setStorageReady(storage_uuid)
self.assertTrue(self.app.isStorageReady(storage_uuid))
service.askFinishTransaction(conn, ttid, (), ())
self.checkAskLockInformation(storage_conn)
self.checkAskPacket(storage_conn, Packets.AskLockInformation)
self.assertEqual(len(self.app.tm.registerForNotification(storage_uuid)), 1)
txn = self.app.tm[ttid]
pending_ttid = list(self.app.tm.registerForNotification(storage_uuid))[0]
......@@ -170,8 +174,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid)
self.checkNoPacketSent(conn)
ptid = self.checkAskPacket(storage_conn, Packets.AskPack,
decode=True)[0]
ptid = self.checkAskPacket(storage_conn, Packets.AskPack).decode()[0]
self.assertEqual(ptid, tid)
self.assertTrue(self.app.packing[0] is conn)
self.assertEqual(self.app.packing[1], peer_id)
......@@ -183,8 +186,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid)
self.checkNoPacketSent(storage_conn)
status = self.checkAnswerPacket(conn, Packets.AnswerPack,
decode=True)[0]
status = self.checkAnswerPacket(conn, Packets.AnswerPack).decode()[0]
self.assertFalse(status)
if __name__ == '__main__':
......
......@@ -18,7 +18,7 @@ import unittest
from mock import Mock
from neo.lib import protocol
from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes, NodeStates
from neo.lib.protocol import NodeTypes, NodeStates, Packets
from neo.master.handlers.election import ClientElectionHandler, \
ServerElectionHandler
from neo.master.app import Application
......@@ -48,6 +48,9 @@ class MasterClientElectionTestBase(NeoUnitTestBase):
node.setConnection(conn)
return (node, conn)
def checkAcceptIdentification(self, conn):
return self.checkAnswerPacket(conn, Packets.AcceptIdentification)
class MasterClientElectionTests(MasterClientElectionTestBase):
def setUp(self):
......@@ -91,7 +94,7 @@ class MasterClientElectionTests(MasterClientElectionTestBase):
self.election.connectionCompleted(conn)
self._checkUnconnected(node)
self.assertTrue(node.isUnknown())
self.checkRequestIdentification(conn)
self.checkAskPacket(conn, Packets.RequestIdentification)
def _setNegociating(self, node):
self._checkUnconnected(node)
......@@ -252,9 +255,8 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
self.election.requestIdentification(conn,
NodeTypes.MASTER, *args)
self.checkUUIDSet(conn, node.getUUID())
args = self.checkAcceptIdentification(conn, decode=True)
(node_type, uuid, partitions, replicas, new_uuid, primary_uuid,
master_list) = args
master_list) = self.checkAcceptIdentification(conn).decode()
self.assertEqual(node.getUUID(), new_uuid)
self.assertNotEqual(node.getUUID(), uuid)
......@@ -290,7 +292,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
None,
)
node_type, uuid, partitions, replicas, _peer_uuid, primary, \
master_list = self.checkAcceptIdentification(conn, decode=True)
master_list = self.checkAcceptIdentification(conn).decode()
self.assertEqual(node_type, NodeTypes.MASTER)
self.assertEqual(uuid, self.app.uuid)
self.assertEqual(partitions, self.app.pt.getPartitions())
......
......@@ -16,6 +16,7 @@
import unittest
from .. import NeoUnitTestBase
from neo.lib.protocol import Packets
from neo.master.app import Application
class MasterAppTests(NeoUnitTestBase):
......@@ -31,6 +32,9 @@ class MasterAppTests(NeoUnitTestBase):
self.app.close()
NeoUnitTestBase._tearDown(self, success)
def checkNotifyNodeInformation(self, conn):
return self.checkNotifyPacket(conn, Packets.NotifyNodeInformation)
def test_06_broadcastNodeInformation(self):
# defined some nodes to which data will be send
master_uuid = self.getMasterUUID()
......
......@@ -71,10 +71,9 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
self.checkNoPacketSent(client_conn)
self.assertEqual(self.app.packing[2], {conn2.getUUID()})
self.service.answerPack(conn2, False)
status = self.checkAnswerPacket(client_conn, Packets.AnswerPack,
decode=True)[0]
packet = self.checkAnswerPacket(client_conn, Packets.AnswerPack)
# TODO: verify packet peer id
self.assertTrue(status)
self.assertTrue(packet.decode()[0])
self.assertEqual(self.app.packing, None)
if __name__ == '__main__':
......
......@@ -20,6 +20,7 @@ from collections import deque
from .. import NeoUnitTestBase
from neo.storage.app import Application
from neo.storage.handlers.client import ClientOperationHandler
from neo.lib.util import p64
from neo.lib.protocol import INVALID_TID, INVALID_OID, Packets, LockState
class StorageClientHandlerTests(NeoUnitTestBase):
......@@ -91,7 +92,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
calls = self.app.dm.mockGetNamedCalls('getTIDList')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(1, 1, [1, ])
self.checkAnswerTids(conn)
self.checkAnswerPacket(conn, Packets.AnswerTIDs)
def test_26_askObjectHistory1(self):
# invalid offsets => error
......@@ -108,7 +109,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
ltid = self.getNextTID()
undone_tid = self.getNextTID()
# Keep 2 entries here, so we check findUndoTID is called only once.
oid_list = [self.getOID(1), self.getOID(2)]
oid_list = map(p64, (1, 2))
obj2_data = [] # Marker
self.app.tm = Mock({
'getObjectFromTransaction': None,
......@@ -134,7 +135,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
conn = self._getConnection()
self.operation.askHasLock(conn, tid_1, oid)
p_oid, p_status = self.checkAnswerPacket(conn,
Packets.AnswerHasLock, decode=True)
Packets.AnswerHasLock).decode()
self.assertEqual(oid, p_oid)
self.assertEqual(status, p_status)
......
......@@ -103,20 +103,19 @@ class StorageDBTests(NeoUnitTestBase):
def test_15_PTID(self):
db = self.getDB()
self.checkConfigEntry(db.getPTID, db.setPTID, self.getPTID(1))
self.checkConfigEntry(db.getPTID, db.setPTID, 1)
def test_getPartitionTable(self):
db = self.getDB()
ptid = self.getPTID(1)
uuid1, uuid2 = self.getStorageUUID(), self.getStorageUUID()
cell1 = (0, uuid1, CellStates.OUT_OF_DATE)
cell2 = (1, uuid1, CellStates.UP_TO_DATE)
db.changePartitionTable(ptid, [cell1, cell2], 1)
db.changePartitionTable(1, [cell1, cell2], 1)
result = db.getPartitionTable()
self.assertEqual(set(result), {cell1, cell2})
def getOIDs(self, count):
return map(self.getOID, xrange(count))
return map(p64, xrange(count))
def getTIDs(self, count):
tid_list = [self.getNextTID()]
......@@ -198,7 +197,7 @@ class StorageDBTests(NeoUnitTestBase):
def test_setPartitionTable(self):
db = self.getDB()
ptid = self.getPTID(1)
ptid = 1
uuid = self.getStorageUUID()
cell1 = 0, uuid, CellStates.OUT_OF_DATE
cell2 = 1, uuid, CellStates.UP_TO_DATE
......@@ -220,7 +219,7 @@ class StorageDBTests(NeoUnitTestBase):
def test_changePartitionTable(self):
db = self.getDB()
ptid = self.getPTID(1)
ptid = 1
uuid = self.getStorageUUID()
cell1 = 0, uuid, CellStates.OUT_OF_DATE
cell2 = 1, uuid, CellStates.UP_TO_DATE
......@@ -301,7 +300,7 @@ class StorageDBTests(NeoUnitTestBase):
def test_deleteRange(self):
np = 4
self.setNumPartitions(np)
t1, t2, t3 = map(self.getOID, (1, 2, 3))
t1, t2, t3 = map(p64, (1, 2, 3))
oid_list = self.getOIDs(np * 2)
for tid in t1, t2, t3:
txn, objs = self.getTransaction(oid_list)
......@@ -339,7 +338,7 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(self.db.getTransaction(tid2, False), None)
def test_getObjectHistory(self):
oid = self.getOID(1)
oid = p64(1)
tid1, tid2, tid3 = self.getTIDs(3)
txn1, objs1 = self.getTransaction([oid])
txn2, objs2 = self.getTransaction([oid])
......@@ -362,7 +361,7 @@ class StorageDBTests(NeoUnitTestBase):
def _storeTransactions(self, count):
# use OID generator to know result of tid % N
tid_list = self.getOIDs(count)
oid = self.getOID(1)
oid = p64(1)
for tid in tid_list:
txn, objs = self.getTransaction([oid])
self.db.storeTransaction(tid, objs, txn, False)
......@@ -446,7 +445,7 @@ class StorageDBTests(NeoUnitTestBase):
tid3 = self.getNextTID()
tid4 = self.getNextTID()
tid5 = self.getNextTID()
oid1 = self.getOID(1)
oid1 = p64(1)
foo = db.holdData("3" * 20, 'foo', 0)
bar = db.holdData("4" * 20, 'bar', 0)
db.releaseData((foo, bar))
......
......@@ -17,6 +17,7 @@
import unittest
from mock import Mock
from .. import NeoUnitTestBase
from neo.lib.util import p64
from neo.storage.transactions import TransactionManager
......@@ -36,7 +37,7 @@ class TransactionManagerTests(NeoUnitTestBase):
def test_updateObjectDataForPack(self):
ram_serial = self.getNextTID()
oid = self.getOID(1)
oid = p64(1)
orig_serial = self.getNextTID()
uuid = self.getClientUUID()
locking_serial = self.getNextTID()
......
......@@ -18,7 +18,7 @@ import unittest
from . import NeoUnitTestBase
from neo.storage.app import Application
from neo.lib.bootstrap import BootstrapManager
from neo.lib.protocol import NodeTypes
from neo.lib.protocol import NodeTypes, Packets
class BootstrapManagerTests(NeoUnitTestBase):
......@@ -46,7 +46,7 @@ class BootstrapManagerTests(NeoUnitTestBase):
conn = self.getFakeConnection(address=address)
self.bootstrap.current = self.app.nm.createMaster(address=address)
self.bootstrap.connectionCompleted(conn)
self.checkRequestIdentification(conn)
self.checkAskPacket(conn, Packets.RequestIdentification)
def testHandleNotReady(self):
# the primary is not ready
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment