Commit 64213d90 authored by Grégory Wisniewski's avatar Grégory Wisniewski

Update tests according to remove of 'packet' parameter from handler methods.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@1571 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 8030b635
...@@ -228,16 +228,13 @@ class NeoTestBase(unittest.TestCase): ...@@ -228,16 +228,13 @@ class NeoTestBase(unittest.TestCase):
return packet.decode() return packet.decode()
return packet return packet
def checkAnswerPacket(self, conn, packet_type, answered_packet=None, decode=False): def checkAnswerPacket(self, conn, packet_type, decode=False):
""" Check if an answer-packet with the right type is sent """ """ Check if an answer-packet with the right type is sent """
calls = conn.mockGetNamedCalls('answer') calls = conn.mockGetNamedCalls('answer')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
packet = calls[0].getParam(0) packet = calls[0].getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEquals(packet.getType(), packet_type) self.assertEquals(packet.getType(), packet_type)
if answered_packet is not None:
msg_id = calls[0].getParam(1)
self.assertEqual(msg_id, answered_packet.getId())
if decode: if decode:
return packet.decode() return packet.decode()
return packet return packet
......
...@@ -43,6 +43,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -43,6 +43,7 @@ class ClientHandlerTests(NeoTestBase):
'getUUID': uuid, 'getUUID': uuid,
'getAddress': (ip, port), 'getAddress': (ip, port),
'getNextId': next_id, 'getNextId': next_id,
'getPeerId': 0,
'lock': None, 'lock': None,
'unlock': None}) 'unlock': None})
...@@ -64,7 +65,8 @@ class ClientHandlerTests(NeoTestBase): ...@@ -64,7 +65,8 @@ class ClientHandlerTests(NeoTestBase):
dispatcher = self.getDispatcher() dispatcher = self.getDispatcher()
client_handler = BaseHandler(None, dispatcher) client_handler = BaseHandler(None, dispatcher)
conn = self.getConnection() conn = self.getConnection()
client_handler.packetReceived(conn, Packets.Ping()) packet = protocol.Ping()
client_handler.packetReceived(conn, packet)
self.checkAnswerPacket(conn, protocol.PONG) self.checkAnswerPacket(conn, protocol.PONG)
def _testInitialMasterWithMethod(self, method): def _testInitialMasterWithMethod(self, method):
...@@ -200,7 +202,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -200,7 +202,7 @@ class ClientHandlerTests(NeoTestBase):
dispatcher = self.getDispatcher() dispatcher = self.getDispatcher()
conn = self.getConnection() conn = self.getConnection()
client_handler = StorageBootstrapHandler(app) client_handler = StorageBootstrapHandler(app)
client_handler.notReady(conn, None, None) client_handler.notReady(conn, None)
self.assertEquals(len(app.mockGetNamedCalls('setNodeNotReady')), 1) self.assertEquals(len(app.mockGetNamedCalls('setNodeNotReady')), 1)
def test_clientAcceptIdentification(self): def test_clientAcceptIdentification(self):
...@@ -214,12 +216,8 @@ class ClientHandlerTests(NeoTestBase): ...@@ -214,12 +216,8 @@ class ClientHandlerTests(NeoTestBase):
conn = self.getConnection() conn = self.getConnection()
uuid = self.getNewUUID() uuid = self.getNewUUID()
app.uuid = 'C' * 16 app.uuid = 'C' * 16
client_handler.acceptIdentification( client_handler.acceptIdentification(conn, NodeTypes.CLIENT,
conn, None, uuid, 0, 0, INVALID_UUID)
NodeTypes.CLIENT,
uuid, ('127.0.0.1', 10010),
0, 0, INVALID_UUID
)
self.checkClosed(conn) self.checkClosed(conn)
self.assertEquals(app.storage_node, None) self.assertEquals(app.storage_node, None)
self.assertEquals(app.pt, None) self.assertEquals(app.pt, None)
...@@ -242,8 +240,8 @@ class ClientHandlerTests(NeoTestBase): ...@@ -242,8 +240,8 @@ class ClientHandlerTests(NeoTestBase):
uuid = self.getNewUUID() uuid = self.getNewUUID()
your_uuid = 'C' * 16 your_uuid = 'C' * 16
app.uuid = INVALID_UUID app.uuid = INVALID_UUID
client_handler.acceptIdentification(conn, None, client_handler.acceptIdentification(conn, NodeTypes.MASTER,
NodeTypes.MASTER, uuid, ('127.0.0.1', 10010), 10, 2, your_uuid) uuid, 10, 2, your_uuid)
self.checkNotClosed(conn) self.checkNotClosed(conn)
self.checkUUIDSet(conn, uuid) self.checkUUIDSet(conn, uuid)
self.assertEquals(app.storage_node, None) self.assertEquals(app.storage_node, None)
...@@ -262,8 +260,8 @@ class ClientHandlerTests(NeoTestBase): ...@@ -262,8 +260,8 @@ class ClientHandlerTests(NeoTestBase):
conn = self.getConnection() conn = self.getConnection()
uuid = self.getNewUUID() uuid = self.getNewUUID()
app.uuid = 'C' * 16 app.uuid = 'C' * 16
client_handler.acceptIdentification(conn, None, client_handler.acceptIdentification(conn, NodeTypes.STORAGE,
NodeTypes.STORAGE, uuid, ('127.0.0.1', 10010), 0, 0, INVALID_UUID) uuid, 0, 0, INVALID_UUID)
self.checkNotClosed(conn) self.checkNotClosed(conn)
self.checkUUIDSet(conn, uuid) self.checkUUIDSet(conn, uuid)
self.assertEquals(app.pt, None) self.assertEquals(app.pt, None)
...@@ -282,7 +280,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -282,7 +280,7 @@ class ClientHandlerTests(NeoTestBase):
app = App() app = App()
client_handler = PrimaryBootstrapHandler(app) client_handler = PrimaryBootstrapHandler(app)
conn = self.getConnection() conn = self.getConnection()
client_handler.answerPrimary(conn, None, 0, []) client_handler.answerPrimary(conn, 0, [])
# Check that nothing happened # Check that nothing happened
self.assertEqual(len(app.nm.mockGetNamedCalls('getByAddress')), 0) self.assertEqual(len(app.nm.mockGetNamedCalls('getByAddress')), 0)
self.assertEqual(len(app.nm.mockGetNamedCalls('add')), 0) self.assertEqual(len(app.nm.mockGetNamedCalls('add')), 0)
...@@ -296,7 +294,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -296,7 +294,7 @@ class ClientHandlerTests(NeoTestBase):
client_handler = PrimaryBootstrapHandler(app) client_handler = PrimaryBootstrapHandler(app)
conn = self.getConnection() conn = self.getConnection()
test_master_list = [(('127.0.0.1', 10010), self.getNewUUID())] test_master_list = [(('127.0.0.1', 10010), self.getNewUUID())]
client_handler.answerPrimary(conn, None, INVALID_UUID, test_master_list) client_handler.answerPrimary(conn, INVALID_UUID, test_master_list)
# Check that yet-unknown master node got added # Check that yet-unknown master node got added
getByAddress_call_list = app.nm.mockGetNamedCalls('getByAddress') getByAddress_call_list = app.nm.mockGetNamedCalls('getByAddress')
add_call_list = app.nm.mockGetNamedCalls('add') add_call_list = app.nm.mockGetNamedCalls('add')
...@@ -322,7 +320,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -322,7 +320,7 @@ class ClientHandlerTests(NeoTestBase):
conn = self.getConnection() conn = self.getConnection()
test_node_uuid = self.getNewUUID() test_node_uuid = self.getNewUUID()
test_master_list = [(('127.0.0.1', 10010), test_node_uuid)] test_master_list = [(('127.0.0.1', 10010), test_node_uuid)]
client_handler.answerPrimary(conn, None, INVALID_UUID, test_master_list) client_handler.answerPrimary(conn, INVALID_UUID, test_master_list)
# Test sanity checks # Test sanity checks
getByAddress_call_list = app.nm.mockGetNamedCalls('getByAddress') getByAddress_call_list = app.nm.mockGetNamedCalls('getByAddress')
self.assertEqual(len(getByAddress_call_list), 1) self.assertEqual(len(getByAddress_call_list), 1)
...@@ -348,7 +346,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -348,7 +346,7 @@ class ClientHandlerTests(NeoTestBase):
client_handler = PrimaryBootstrapHandler(app) client_handler = PrimaryBootstrapHandler(app)
conn = self.getConnection() conn = self.getConnection()
test_master_list = [(('127.0.0.1', 10010), test_node_uuid)] test_master_list = [(('127.0.0.1', 10010), test_node_uuid)]
client_handler.answerPrimary(conn, None, INVALID_UUID, test_master_list) client_handler.answerPrimary(conn, INVALID_UUID, test_master_list)
# Test sanity checks # Test sanity checks
getByAddress_call_list = app.nm.mockGetNamedCalls('getByAddress') getByAddress_call_list = app.nm.mockGetNamedCalls('getByAddress')
self.assertEqual(len(getByAddress_call_list), 1) self.assertEqual(len(getByAddress_call_list), 1)
...@@ -384,7 +382,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -384,7 +382,7 @@ class ClientHandlerTests(NeoTestBase):
# If primary master is already set *and* is not given primary master # If primary master is already set *and* is not given primary master
# handle call raises. # handle call raises.
# Check that the call doesn't raise # Check that the call doesn't raise
client_handler.answerPrimary(conn, None, test_node_uuid, []) client_handler.answerPrimary(conn, test_node_uuid, [])
# Check that the primary master changed # Check that the primary master changed
self.assertTrue(app.primary_master_node is node) self.assertTrue(app.primary_master_node is node)
# Test sanity checks # Test sanity checks
...@@ -404,7 +402,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -404,7 +402,7 @@ class ClientHandlerTests(NeoTestBase):
app = App() app = App()
client_handler = PrimaryBootstrapHandler(app) client_handler = PrimaryBootstrapHandler(app)
conn = self.getConnection() conn = self.getConnection()
client_handler.answerPrimary(conn, None, test_node_uuid, []) client_handler.answerPrimary(conn, test_node_uuid, [])
# Check that primary node is (still) node. # Check that primary node is (still) node.
self.assertTrue(app.primary_master_node is node) self.assertTrue(app.primary_master_node is node)
...@@ -421,7 +419,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -421,7 +419,7 @@ class ClientHandlerTests(NeoTestBase):
app = App() app = App()
client_handler = PrimaryBootstrapHandler(app) client_handler = PrimaryBootstrapHandler(app)
conn = self.getConnection() conn = self.getConnection()
client_handler.answerPrimary(conn, None, test_primary_node_uuid, []) client_handler.answerPrimary(conn, test_primary_node_uuid, [])
# Test sanity checks # Test sanity checks
getByUUID_call_list = app.nm.mockGetNamedCalls('getByUUID') getByUUID_call_list = app.nm.mockGetNamedCalls('getByUUID')
self.assertEqual(len(getByUUID_call_list), 1) self.assertEqual(len(getByUUID_call_list), 1)
...@@ -440,7 +438,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -440,7 +438,7 @@ class ClientHandlerTests(NeoTestBase):
client_handler = PrimaryBootstrapHandler(app) client_handler = PrimaryBootstrapHandler(app)
conn = self.getConnection() conn = self.getConnection()
test_master_list = [(('127.0.0.1', 10010), test_node_uuid)] test_master_list = [(('127.0.0.1', 10010), test_node_uuid)]
client_handler.answerPrimary(conn, None, test_node_uuid, test_master_list) client_handler.answerPrimary(conn, test_node_uuid, test_master_list)
# Test sanity checks # Test sanity checks
getByUUID_call_list = app.nm.mockGetNamedCalls('getByUUID') getByUUID_call_list = app.nm.mockGetNamedCalls('getByUUID')
self.assertEqual(len(getByUUID_call_list), 1) self.assertEqual(len(getByUUID_call_list), 1)
...@@ -460,7 +458,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -460,7 +458,7 @@ class ClientHandlerTests(NeoTestBase):
app = App() app = App()
client_handler = PrimaryNotificationsHandler(app, Mock()) client_handler = PrimaryNotificationsHandler(app, Mock())
conn = self.getConnection() conn = self.getConnection()
client_handler.sendPartitionTable(conn, None, test_ptid + 1, []) client_handler.sendPartitionTable(conn, test_ptid + 1, [])
# Check that partition table got cleared and ptid got updated # Check that partition table got cleared and ptid got updated
self.assertEquals(app.pt.getID(), 1) self.assertEquals(app.pt.getID(), 1)
...@@ -474,7 +472,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -474,7 +472,7 @@ class ClientHandlerTests(NeoTestBase):
app = App() app = App()
client_handler = PrimaryNotificationsHandler(app, self.getDispatcher()) client_handler = PrimaryNotificationsHandler(app, self.getDispatcher())
conn = self.getConnection(uuid=test_master_uuid) conn = self.getConnection(uuid=test_master_uuid)
client_handler.notifyNodeInformation(conn, None, ()) client_handler.notifyNodeInformation(conn, ())
def test_nonIterableParameterRaisesNotifyNodeInformation(self): def test_nonIterableParameterRaisesNotifyNodeInformation(self):
# XXX: this test is here for sanity self-check: it verifies the # XXX: this test is here for sanity self-check: it verifies the
...@@ -489,7 +487,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -489,7 +487,7 @@ class ClientHandlerTests(NeoTestBase):
client_handler = PrimaryNotificationsHandler(app, self.getDispatcher()) client_handler = PrimaryNotificationsHandler(app, self.getDispatcher())
conn = self.getConnection(uuid=test_master_uuid) conn = self.getConnection(uuid=test_master_uuid)
self.assertRaises(TypeError, client_handler.notifyNodeInformation, self.assertRaises(TypeError, client_handler.notifyNodeInformation,
conn, None, None) conn, None)
def _testNotifyNodeInformation(self, test_node, getByAddress=None, getByUUID=MARKER): def _testNotifyNodeInformation(self, test_node, getByAddress=None, getByUUID=MARKER):
invalid_uid_test_node = (test_node[0], (test_node[1][0], invalid_uid_test_node = (test_node[0], (test_node[1][0],
...@@ -508,7 +506,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -508,7 +506,7 @@ class ClientHandlerTests(NeoTestBase):
dispatcher = self.getDispatcher() dispatcher = self.getDispatcher()
client_handler = PrimaryNotificationsHandler(app, dispatcher) client_handler = PrimaryNotificationsHandler(app, dispatcher)
conn = self.getConnection(uuid=test_master_uuid) conn = self.getConnection(uuid=test_master_uuid)
client_handler.notifyNodeInformation(conn, None, test_node_list) client_handler.notifyNodeInformation(conn, test_node_list)
# Return nm so caller can check handler actions. # Return nm so caller can check handler actions.
return app.nm return app.nm
...@@ -582,7 +580,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -582,7 +580,7 @@ class ClientHandlerTests(NeoTestBase):
app = App() app = App()
client_handler = PrimaryNotificationsHandler(app, self.getDispatcher()) client_handler = PrimaryNotificationsHandler(app, self.getDispatcher())
conn = self.getConnection(uuid=test_master_uuid) conn = self.getConnection(uuid=test_master_uuid)
client_handler.notifyPartitionChanges(conn, None, 0, []) client_handler.notifyPartitionChanges(conn, 0, [])
# Check that nothing happened # Check that nothing happened
self.assertEquals(len(app.pt.mockGetNamedCalls('setCell')), 0) self.assertEquals(len(app.pt.mockGetNamedCalls('setCell')), 0)
self.assertEquals(len(app.pt.mockGetNamedCalls('removeCell')), 0) self.assertEquals(len(app.pt.mockGetNamedCalls('removeCell')), 0)
...@@ -597,7 +595,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -597,7 +595,7 @@ class ClientHandlerTests(NeoTestBase):
app = App() app = App()
client_handler = PrimaryNotificationsHandler(app, self.getDispatcher()) client_handler = PrimaryNotificationsHandler(app, self.getDispatcher())
conn = self.getConnection() conn = self.getConnection()
client_handler.notifyPartitionChanges(conn, None, 0, []) client_handler.notifyPartitionChanges(conn, 0, [])
# Check that nothing happened # Check that nothing happened
self.assertEquals(len(app.pt.mockGetNamedCalls('setCell')), 0) self.assertEquals(len(app.pt.mockGetNamedCalls('setCell')), 0)
self.assertEquals(len(app.pt.mockGetNamedCalls('removeCell')), 0) self.assertEquals(len(app.pt.mockGetNamedCalls('removeCell')), 0)
...@@ -617,7 +615,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -617,7 +615,7 @@ class ClientHandlerTests(NeoTestBase):
app = App() app = App()
client_handler = PrimaryNotificationsHandler(app, self.getDispatcher()) client_handler = PrimaryNotificationsHandler(app, self.getDispatcher())
conn = self.getConnection(uuid=test_sender_uuid) conn = self.getConnection(uuid=test_sender_uuid)
client_handler.notifyPartitionChanges(conn, None, 0, []) client_handler.notifyPartitionChanges(conn, 0, [])
# Check that nothing happened # Check that nothing happened
self.assertEquals(len(app.pt.mockGetNamedCalls('setCell')), 0) self.assertEquals(len(app.pt.mockGetNamedCalls('setCell')), 0)
self.assertEquals(len(app.pt.mockGetNamedCalls('removeCell')), 0) self.assertEquals(len(app.pt.mockGetNamedCalls('removeCell')), 0)
...@@ -634,7 +632,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -634,7 +632,7 @@ class ClientHandlerTests(NeoTestBase):
app = App() app = App()
client_handler = PrimaryNotificationsHandler(app, self.getDispatcher()) client_handler = PrimaryNotificationsHandler(app, self.getDispatcher())
conn = self.getConnection(uuid=test_master_uuid) conn = self.getConnection(uuid=test_master_uuid)
client_handler.notifyPartitionChanges(conn, None, test_ptid, []) client_handler.notifyPartitionChanges(conn, test_ptid, [])
# Check that nothing happened # Check that nothing happened
self.assertEquals(len(app.pt.mockGetNamedCalls('setCell')), 0) self.assertEquals(len(app.pt.mockGetNamedCalls('setCell')), 0)
self.assertEquals(len(app.pt.mockGetNamedCalls('removeCell')), 0) self.assertEquals(len(app.pt.mockGetNamedCalls('removeCell')), 0)
...@@ -686,7 +684,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -686,7 +684,7 @@ class ClientHandlerTests(NeoTestBase):
client_handler = PrimaryAnswersHandler(app) client_handler = PrimaryAnswersHandler(app)
conn = self.getConnection() conn = self.getConnection()
test_tid = 1 test_tid = 1
client_handler.answerBeginTransaction(conn, None, test_tid) client_handler.answerBeginTransaction(conn, test_tid)
setTID_call_list = app.mockGetNamedCalls('setTID') setTID_call_list = app.mockGetNamedCalls('setTID')
self.assertEquals(len(setTID_call_list), 1) self.assertEquals(len(setTID_call_list), 1)
self.assertEquals(setTID_call_list[0].getParam(0), test_tid) self.assertEquals(setTID_call_list[0].getParam(0), test_tid)
...@@ -697,7 +695,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -697,7 +695,7 @@ class ClientHandlerTests(NeoTestBase):
dispatcher = self.getDispatcher() dispatcher = self.getDispatcher()
client_handler = PrimaryAnswersHandler(app) client_handler = PrimaryAnswersHandler(app)
conn = self.getConnection() conn = self.getConnection()
client_handler.answerTransactionFinished(conn, None, test_tid) client_handler.answerTransactionFinished(conn, test_tid)
self.assertEquals(len(app.mockGetNamedCalls('setTransactionFinished')), 1) self.assertEquals(len(app.mockGetNamedCalls('setTransactionFinished')), 1)
# TODO: decide what to do when non-current transaction is notified as finished, and test that behaviour # TODO: decide what to do when non-current transaction is notified as finished, and test that behaviour
...@@ -724,7 +722,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -724,7 +722,7 @@ class ClientHandlerTests(NeoTestBase):
test_oid_list = ['\x00\x00\x00\x00\x00\x00\x00\x01', '\x00\x00\x00\x00\x00\x00\x00\x02'] test_oid_list = ['\x00\x00\x00\x00\x00\x00\x00\x01', '\x00\x00\x00\x00\x00\x00\x00\x02']
test_db = Mock({'invalidate': None}) test_db = Mock({'invalidate': None})
app.registerDB(test_db, None) app.registerDB(test_db, None)
client_handler.invalidateObjects(conn, None, test_oid_list[:], test_tid) client_handler.invalidateObjects(conn, test_oid_list[:], test_tid)
# 'invalidate' is called just once # 'invalidate' is called just once
db = app.getDB() db = app.getDB()
self.assertTrue(db is test_db) self.assertTrue(db is test_db)
...@@ -751,7 +749,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -751,7 +749,7 @@ class ClientHandlerTests(NeoTestBase):
client_handler = PrimaryAnswersHandler(app) client_handler = PrimaryAnswersHandler(app)
conn = self.getConnection() conn = self.getConnection()
test_oid_list = ['\x00\x00\x00\x00\x00\x00\x00\x01', '\x00\x00\x00\x00\x00\x00\x00\x02'] test_oid_list = ['\x00\x00\x00\x00\x00\x00\x00\x01', '\x00\x00\x00\x00\x00\x00\x00\x02']
client_handler.answerNewOIDs(conn, None, test_oid_list[:]) client_handler.answerNewOIDs(conn, test_oid_list[:])
self.assertEquals(set(app.new_oid_list), set(test_oid_list)) self.assertEquals(set(app.new_oid_list), set(test_oid_list))
def test_StopOperation(self): def test_StopOperation(self):
...@@ -770,14 +768,14 @@ class ClientHandlerTests(NeoTestBase): ...@@ -770,14 +768,14 @@ class ClientHandlerTests(NeoTestBase):
conn = self.getConnection() conn = self.getConnection()
# XXX: use realistic values # XXX: use realistic values
test_object_data = ('\x00\x00\x00\x00\x00\x00\x00\x01', 0, 0, 0, 0, 'test') test_object_data = ('\x00\x00\x00\x00\x00\x00\x00\x01', 0, 0, 0, 0, 'test')
client_handler.answerObject(conn, None, *test_object_data) client_handler.answerObject(conn, *test_object_data)
self.assertEquals(app.local_var.asked_object, test_object_data) self.assertEquals(app.local_var.asked_object, test_object_data)
def _testAnswerStoreObject(self, app, conflicting, oid, serial): def _testAnswerStoreObject(self, app, conflicting, oid, serial):
dispatcher = self.getDispatcher() dispatcher = self.getDispatcher()
client_handler = StorageAnswersHandler(app) client_handler = StorageAnswersHandler(app)
conn = self.getConnection() conn = self.getConnection()
client_handler.answerStoreObject(conn, None, conflicting, oid, serial) client_handler.answerStoreObject(conn, conflicting, oid, serial)
def test_conflictingAnswerStoreObject(self): def test_conflictingAnswerStoreObject(self):
class App: class App:
...@@ -805,7 +803,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -805,7 +803,7 @@ class ClientHandlerTests(NeoTestBase):
dispatcher = self.getDispatcher() dispatcher = self.getDispatcher()
client_handler = StorageAnswersHandler(app) client_handler = StorageAnswersHandler(app)
conn = self.getConnection() conn = self.getConnection()
client_handler.answerStoreTransaction(conn, None, test_tid) client_handler.answerStoreTransaction(conn, test_tid)
self.assertEquals(len(app.mockGetNamedCalls('setTransactionVoted')), 1) self.assertEquals(len(app.mockGetNamedCalls('setTransactionVoted')), 1)
# XXX: test answerObject with test_tid not matching app.tid (not handled in program) # XXX: test answerObject with test_tid not matching app.tid (not handled in program)
...@@ -823,7 +821,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -823,7 +821,7 @@ class ClientHandlerTests(NeoTestBase):
desc = 'foo' desc = 'foo'
ext = 0 # XXX: unused in implementation ext = 0 # XXX: unused in implementation
oid_list = ['\x00\x00\x00\x00\x00\x00\x00\x01', '\x00\x00\x00\x00\x00\x00\x00\x02'] oid_list = ['\x00\x00\x00\x00\x00\x00\x00\x01', '\x00\x00\x00\x00\x00\x00\x00\x02']
client_handler.answerTransactionInformation(conn, None, tid, user, desc, ext, oid_list[:]) client_handler.answerTransactionInformation(conn, tid, user, desc, ext, oid_list[:])
stored_dict = app.local_var.txn_info stored_dict = app.local_var.txn_info
# XXX: test 'time' value ? # XXX: test 'time' value ?
self.assertEquals(stored_dict['user_name'], user) self.assertEquals(stored_dict['user_name'], user)
...@@ -843,7 +841,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -843,7 +841,7 @@ class ClientHandlerTests(NeoTestBase):
test_oid = '\x00\x00\x00\x00\x00\x00\x00\x01' test_oid = '\x00\x00\x00\x00\x00\x00\x00\x01'
# XXX: use realistic values # XXX: use realistic values
test_history_list = [(1, 2), (3, 4)] test_history_list = [(1, 2), (3, 4)]
client_handler.answerObjectHistory(conn, None, test_oid, test_history_list[:]) client_handler.answerObjectHistory(conn, test_oid, test_history_list[:])
oid, history = app.local_var.history oid, history = app.local_var.history
self.assertEquals(oid, test_oid) self.assertEquals(oid, test_oid)
self.assertEquals(len(history), len(test_history_list)) self.assertEquals(len(history), len(test_history_list))
...@@ -859,7 +857,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -859,7 +857,7 @@ class ClientHandlerTests(NeoTestBase):
dispatcher = self.getDispatcher() dispatcher = self.getDispatcher()
client_handler = StorageAnswersHandler(app) client_handler = StorageAnswersHandler(app)
conn = self.getConnection() conn = self.getConnection()
client_handler.oidNotFound(conn, None, None) client_handler.oidNotFound(conn, None)
self.assertEquals(app.local_var.asked_object, -1) self.assertEquals(app.local_var.asked_object, -1)
self.assertEquals(app.local_var.history, -1) self.assertEquals(app.local_var.history, -1)
...@@ -872,7 +870,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -872,7 +870,7 @@ class ClientHandlerTests(NeoTestBase):
dispatcher = self.getDispatcher() dispatcher = self.getDispatcher()
client_handler = StorageAnswersHandler(app) client_handler = StorageAnswersHandler(app)
conn = self.getConnection() conn = self.getConnection()
client_handler.tidNotFound(conn, None, None) client_handler.tidNotFound(conn, None)
self.assertEquals(app.local_var.txn_info, -1) self.assertEquals(app.local_var.txn_info, -1)
def test_AnswerTIDs(self): def test_AnswerTIDs(self):
...@@ -885,7 +883,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -885,7 +883,7 @@ class ClientHandlerTests(NeoTestBase):
client_handler = StorageAnswersHandler(app) client_handler = StorageAnswersHandler(app)
conn = self.getConnection() conn = self.getConnection()
test_tid_list = ['\x00\x00\x00\x00\x00\x00\x00\x01', '\x00\x00\x00\x00\x00\x00\x00\x02'] test_tid_list = ['\x00\x00\x00\x00\x00\x00\x00\x01', '\x00\x00\x00\x00\x00\x00\x00\x02']
client_handler.answerTIDs(conn, None, test_tid_list[:]) client_handler.answerTIDs(conn, test_tid_list[:])
stored_tid_list = [] stored_tid_list = []
for tid_list in app.local_var.node_tids.itervalues(): for tid_list in app.local_var.node_tids.itervalues():
stored_tid_list.extend(tid_list) stored_tid_list.extend(tid_list)
......
...@@ -69,13 +69,11 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -69,13 +69,11 @@ class MasterClientHandlerTests(NeoTestBase):
def test_07_askBeginTransaction(self): def test_07_askBeginTransaction(self):
service = self.service service = self.service
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packets.AskBeginTransaction()
packet.setId(0)
ltid = self.app.tm.getLastTID() ltid = self.app.tm.getLastTID()
# client call it # client call it
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port) client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
service.askBeginTransaction(conn, packet, None) service.askBeginTransaction(conn, None)
self.assertTrue(ltid < self.app.tm.getLastTID()) self.assertTrue(ltid < self.app.tm.getLastTID())
self.assertEqual(len(self.app.tm.getPendingList()), 1) self.assertEqual(len(self.app.tm.getPendingList()), 1)
tid = self.app.tm.getPendingList()[0] tid = self.app.tm.getPendingList()[0]
...@@ -84,27 +82,23 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -84,27 +82,23 @@ class MasterClientHandlerTests(NeoTestBase):
def test_08_askNewOIDs(self): def test_08_askNewOIDs(self):
service = self.service service = self.service
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packets.AskNewOIDs()
packet.setId(0)
loid = self.app.loid loid = self.app.loid
# client call it # client call it
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port) client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
service.askNewOIDs(conn, packet, 1) service.askNewOIDs(conn, 1)
self.assertTrue(loid < self.app.loid) self.assertTrue(loid < self.app.loid)
def test_09_finishTransaction(self): def test_09_finishTransaction(self):
service = self.service service = self.service
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packets.FinishTransaction()
packet.setId(9)
# give an older tid than the PMN known, must abort # give an older tid than the PMN known, must abort
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port) client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
oid_list = [] oid_list = []
upper, lower = unpack('!LL', self.app.tm.getLastTID()) upper, lower = unpack('!LL', self.app.tm.getLastTID())
new_tid = pack('!LL', upper, lower + 10) new_tid = pack('!LL', upper, lower + 10)
self.checkProtocolErrorRaised(service.finishTransaction, conn, packet, oid_list, new_tid) self.checkProtocolErrorRaised(service.finishTransaction, conn, oid_list, new_tid)
old_node = self.app.nm.getByUUID(uuid) old_node = self.app.nm.getByUUID(uuid)
self.app.nm.remove(old_node) self.app.nm.remove(old_node)
self.app.pt.dropNode(old_node) self.app.pt.dropNode(old_node)
...@@ -119,12 +113,12 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -119,12 +113,12 @@ class MasterClientHandlerTests(NeoTestBase):
'getPartition': 0, 'getPartition': 0,
'getCellList': [Mock({'getUUID': storage_uuid})], 'getCellList': [Mock({'getUUID': storage_uuid})],
}) })
service.askBeginTransaction(conn, packet, None) service.askBeginTransaction(conn, None)
oid_list = [] oid_list = []
tid = self.app.tm.getLastTID() tid = self.app.tm.getLastTID()
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
self.app.em = Mock({"getConnectionList" : [conn, storage_conn]}) self.app.em = Mock({"getConnectionList" : [conn, storage_conn]})
service.finishTransaction(conn, packet, oid_list, tid) service.finishTransaction(conn, oid_list, tid)
self.checkLockInformation(storage_conn) self.checkLockInformation(storage_conn)
self.assertEquals(len(self.app.tm.getPendingList()), 1) self.assertEquals(len(self.app.tm.getPendingList()), 1)
apptid = self.app.tm.getPendingList()[0] apptid = self.app.tm.getPendingList()[0]
...@@ -132,18 +126,16 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -132,18 +126,16 @@ class MasterClientHandlerTests(NeoTestBase):
txn = self.app.tm[tid] txn = self.app.tm[tid]
self.assertEquals(len(txn.getOIDList()), 0) self.assertEquals(len(txn.getOIDList()), 0)
self.assertEquals(len(txn.getUUIDList()), 1) self.assertEquals(len(txn.getUUIDList()), 1)
self.assertEquals(txn.getMessageId(), 9)
def test_11_abortTransaction(self): def test_11_abortTransaction(self):
service = self.service service = self.service
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packets.AbortTransaction()
# give a bad tid, must not failed, just ignored it # give a bad tid, must not failed, just ignored it
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port) client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
self.assertFalse(self.app.tm.hasPending()) self.assertFalse(self.app.tm.hasPending())
service.abortTransaction(conn, packet, None) service.abortTransaction(conn, None)
self.assertFalse(self.app.tm.hasPending()) self.assertFalse(self.app.tm.hasPending())
# give a known tid # give a known tid
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
...@@ -151,7 +143,7 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -151,7 +143,7 @@ class MasterClientHandlerTests(NeoTestBase):
self.app.tm.remove(tid) self.app.tm.remove(tid)
self.app.tm.begin(Mock({'__hash__': 1}), tid) self.app.tm.begin(Mock({'__hash__': 1}), tid)
self.assertTrue(self.app.tm.hasPending()) self.assertTrue(self.app.tm.hasPending())
service.abortTransaction(conn, packet, tid) service.abortTransaction(conn, tid)
self.assertFalse(self.app.tm.hasPending()) self.assertFalse(self.app.tm.hasPending())
def __testWithMethod(self, method, state): def __testWithMethod(self, method, state):
...@@ -160,11 +152,9 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -160,11 +152,9 @@ class MasterClientHandlerTests(NeoTestBase):
port = self.client_port) port = self.client_port)
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
lptid = self.app.pt.getID() lptid = self.app.pt.getID()
packet = Packets.AskBeginTransaction() self.service.askBeginTransaction(conn, None)
packet.setId(0) self.service.askBeginTransaction(conn, None)
self.service.askBeginTransaction(conn, packet, None) self.service.askBeginTransaction(conn, None)
self.service.askBeginTransaction(conn, packet, None)
self.service.askBeginTransaction(conn, packet, None)
self.assertEquals(self.app.nm.getByUUID(client_uuid).getState(), self.assertEquals(self.app.nm.getByUUID(client_uuid).getState(),
NodeStates.RUNNING) NodeStates.RUNNING)
self.assertEquals(len(self.app.tm.getPendingList()), 3) self.assertEquals(len(self.app.tm.getPendingList()), 3)
......
...@@ -141,10 +141,8 @@ class MasterClientElectionTests(NeoTestBase): ...@@ -141,10 +141,8 @@ class MasterClientElectionTests(NeoTestBase):
def test_acceptIdentification1(self): def test_acceptIdentification1(self):
""" A non-master node accept identification """ """ A non-master node accept identification """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
packet = protocol.AcceptIdentification()
packet.setId(0)
args = (node.getUUID(), 0, 10, self.app.uuid) args = (node.getUUID(), 0, 10, self.app.uuid)
self.election.acceptIdentification(conn, packet, self.election.acceptIdentification(conn,
NodeTypes.CLIENT, *args) NodeTypes.CLIENT, *args)
self.assertFalse(node in self.app.unconnected_master_node_set) self.assertFalse(node in self.app.unconnected_master_node_set)
self.assertFalse(node in self.app.negotiating_master_node_set) self.assertFalse(node in self.app.negotiating_master_node_set)
...@@ -153,22 +151,17 @@ class MasterClientElectionTests(NeoTestBase): ...@@ -153,22 +151,17 @@ class MasterClientElectionTests(NeoTestBase):
def test_acceptIdentification2(self): def test_acceptIdentification2(self):
""" UUID conflict """ """ UUID conflict """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
packet = protocol.AcceptIdentification()
packet.setId(0)
new_uuid = self._makeUUID('M') new_uuid = self._makeUUID('M')
args = (node.getUUID(), 0, 10, new_uuid) args = (node.getUUID(), 0, 10, new_uuid)
self.assertRaises(ElectionFailure, self.election.acceptIdentification, self.assertRaises(ElectionFailure, self.election.acceptIdentification,
conn, packet, NodeTypes.MASTER, *args) conn, NodeTypes.MASTER, *args)
self.assertEqual(self.app.uuid, new_uuid) self.assertEqual(self.app.uuid, new_uuid)
def test_acceptIdentification3(self): def test_acceptIdentification3(self):
""" Identification accepted """ """ Identification accepted """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
packet = protocol.AcceptIdentification()
packet.setId(0)
args = (node.getUUID(), 0, 10, self.app.uuid) args = (node.getUUID(), 0, 10, self.app.uuid)
self.election.acceptIdentification(conn, packet, self.election.acceptIdentification(conn, NodeTypes.MASTER, *args)
NodeTypes.MASTER, *args)
self.checkUUIDSet(conn, node.getUUID()) self.checkUUIDSet(conn, node.getUUID())
self.assertTrue(self.app.primary or node.getUUID() < self.app.uuid) self.assertTrue(self.app.primary or node.getUUID() < self.app.uuid)
self.assertFalse(node in self.app.negotiating_master_node_set) self.assertFalse(node in self.app.negotiating_master_node_set)
...@@ -180,21 +173,17 @@ class MasterClientElectionTests(NeoTestBase): ...@@ -180,21 +173,17 @@ class MasterClientElectionTests(NeoTestBase):
def test_answerPrimary1(self): def test_answerPrimary1(self):
""" Multiple primary masters -> election failure raised """ """ Multiple primary masters -> election failure raised """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
packet = protocol.AnswerPrimary()
packet.setId(0)
self.app.primary = True self.app.primary = True
self.app.primary_master_node = node self.app.primary_master_node = node
master_list = self._getMasterList() master_list = self._getMasterList()
self.assertRaises(ElectionFailure, self.election.answerPrimary, self.assertRaises(ElectionFailure, self.election.answerPrimary,
conn, packet, self.app.uuid, master_list) conn, self.app.uuid, master_list)
def test_answerPrimary2(self): def test_answerPrimary2(self):
""" Don't known who's the primary """ """ Don't known who's the primary """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
packet = protocol.AnswerPrimary()
packet.setId(0)
master_list = self._getMasterList() master_list = self._getMasterList()
self.election.answerPrimary(conn, packet, None, master_list) self.election.answerPrimary(conn, None, master_list)
self.assertFalse(self.app.primary) self.assertFalse(self.app.primary)
self.assertEqual(self.app.primary_master_node, None) self.assertEqual(self.app.primary_master_node, None)
self.checkRequestIdentification(conn) self.checkRequestIdentification(conn)
...@@ -202,10 +191,8 @@ class MasterClientElectionTests(NeoTestBase): ...@@ -202,10 +191,8 @@ class MasterClientElectionTests(NeoTestBase):
def test_answerPrimary3(self): def test_answerPrimary3(self):
""" Answer who's the primary """ """ Answer who's the primary """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
packet = protocol.AnswerPrimary()
packet.setId(0)
master_list = self._getMasterList() master_list = self._getMasterList()
self.election.answerPrimary(conn, packet, node.getUUID(), master_list) self.election.answerPrimary(conn, node.getUUID(), master_list)
addr = conn.getAddress() addr = conn.getAddress()
self.assertTrue(addr in self.app.unconnected_master_node_set) self.assertTrue(addr in self.app.unconnected_master_node_set)
self.assertTrue(addr in self.app.negotiating_master_node_set) self.assertTrue(addr in self.app.negotiating_master_node_set)
...@@ -271,40 +258,32 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -271,40 +258,32 @@ class MasterServerElectionTests(NeoTestBase):
def test_requestIdentification1(self): def test_requestIdentification1(self):
""" A non-master node request identification """ """ A non-master node request identification """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
packet = protocol.RequestIdentification()
packet.setId(0)
args = (node.getUUID(), node.getAddress(), self.app.name) args = (node.getUUID(), node.getAddress(), self.app.name)
self.assertRaises(protocol.NotReadyError, self.assertRaises(protocol.NotReadyError,
self.election.requestIdentification, self.election.requestIdentification,
conn, packet, NodeTypes.CLIENT, *args) conn, NodeTypes.CLIENT, *args)
def test_requestIdentification2(self): def test_requestIdentification2(self):
""" A unknown master node request identification """ """ A unknown master node request identification """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
packet = protocol.RequestIdentification()
packet.setId(0)
args = (node.getUUID(), ('127.0.0.1', 1000), self.app.name) args = (node.getUUID(), ('127.0.0.1', 1000), self.app.name)
self.checkProtocolErrorRaised(self.election.requestIdentification, self.checkProtocolErrorRaised(self.election.requestIdentification,
conn, packet, NodeTypes.MASTER, *args) conn, NodeTypes.MASTER, *args)
def test_requestIdentification3(self): def test_requestIdentification3(self):
""" A broken master node request identification """ """ A broken master node request identification """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
node.setBroken() node.setBroken()
packet = protocol.RequestIdentification()
packet.setId(0)
args = (node.getUUID(), node.getAddress(), self.app.name) args = (node.getUUID(), node.getAddress(), self.app.name)
self.assertRaises(protocol.BrokenNodeDisallowedError, self.assertRaises(protocol.BrokenNodeDisallowedError,
self.election.requestIdentification, self.election.requestIdentification,
conn, packet, NodeTypes.MASTER, *args) conn, NodeTypes.MASTER, *args)
def test_requestIdentification4(self): def test_requestIdentification4(self):
""" No conflict """ """ No conflict """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
packet = protocol.RequestIdentification()
packet.setId(0)
args = (node.getUUID(), node.getAddress(), self.app.name) args = (node.getUUID(), node.getAddress(), self.app.name)
self.election.requestIdentification(conn, packet, self.election.requestIdentification(conn,
NodeTypes.MASTER, *args) NodeTypes.MASTER, *args)
self.checkUUIDSet(conn, node.getUUID()) self.checkUUIDSet(conn, node.getUUID())
args = self.checkAcceptIdentification(conn, decode=True) args = self.checkAcceptIdentification(conn, decode=True)
...@@ -315,10 +294,8 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -315,10 +294,8 @@ class MasterServerElectionTests(NeoTestBase):
def test_requestIdentification5(self): def test_requestIdentification5(self):
""" UUID conflict """ """ UUID conflict """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
packet = protocol.RequestIdentification()
packet.setId(0)
args = (self.app.uuid, node.getAddress(), self.app.name) args = (self.app.uuid, node.getAddress(), self.app.name)
self.election.requestIdentification(conn, packet, self.election.requestIdentification(conn,
NodeTypes.MASTER, *args) NodeTypes.MASTER, *args)
self.checkUUIDSet(conn) self.checkUUIDSet(conn)
args = self.checkAcceptIdentification(conn, decode=True) args = self.checkAcceptIdentification(conn, decode=True)
...@@ -334,11 +311,9 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -334,11 +311,9 @@ class MasterServerElectionTests(NeoTestBase):
def test_notifyNodeInformation1(self): def test_notifyNodeInformation1(self):
""" Not identified """ """ Not identified """
node, conn = self.identifyToMasterNode(uuid=None) node, conn = self.identifyToMasterNode(uuid=None)
packet = protocol.NotifyNodeInformation()
packet.setId(0)
node_list = self._getNodeList() node_list = self._getNodeList()
self.assertRaises(protocol.ProtocolError, self.assertRaises(protocol.ProtocolError,
self.election.notifyNodeInformation, conn, packet, node_list) self.election.notifyNodeInformation, conn, node_list)
# TODO: build a full notifyNodeInformation test # TODO: build a full notifyNodeInformation test
...@@ -367,16 +342,9 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -367,16 +342,9 @@ class MasterServerElectionTests(NeoTestBase):
def testRequestIdentification1(self): def testRequestIdentification1(self):
""" Check with a non-master node, must be refused """ """ Check with a non-master node, must be refused """
conn = self.__getClient() conn = self.__getClient()
packet = protocol.RequestIdentification(
NodeTypes.CLIENT,
conn.getUUID(),
conn.getAddress(),
name=self.app.name,
)
self.checkNotReadyErrorRaised( self.checkNotReadyErrorRaised(
self.election.requestIdentification, self.election.requestIdentification,
conn=conn, conn=conn,
packet=packet,
node_type=NodeTypes.CLIENT, node_type=NodeTypes.CLIENT,
uuid=conn.getUUID(), uuid=conn.getUUID(),
address=conn.getAddress(), address=conn.getAddress(),
...@@ -386,16 +354,9 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -386,16 +354,9 @@ class MasterServerElectionTests(NeoTestBase):
def testRequestIdentification2(self): def testRequestIdentification2(self):
""" Check with an unknown master node """ """ Check with an unknown master node """
conn = self.__getMaster(register=False) conn = self.__getMaster(register=False)
packet = protocol.RequestIdentification(
NodeTypes.MASTER,
conn.getUUID(),
conn.getAddress(),
name=self.app.name,
)
self.checkProtocolErrorRaised( self.checkProtocolErrorRaised(
self.election.requestIdentification, self.election.requestIdentification,
conn=conn, conn=conn,
packet=packet,
node_type=NodeTypes.MASTER, node_type=NodeTypes.MASTER,
uuid=conn.getUUID(), uuid=conn.getUUID(),
address=conn.getAddress(), address=conn.getAddress(),
...@@ -405,37 +366,33 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -405,37 +366,33 @@ class MasterServerElectionTests(NeoTestBase):
def testAnnouncePrimary1(self): def testAnnouncePrimary1(self):
""" check the wrong cases """ """ check the wrong cases """
announce = self.election.announcePrimary announce = self.election.announcePrimary
packet = Packets.AnnouncePrimary()
# No uuid # No uuid
node, conn = self.identifyToMasterNode(uuid=None) node, conn = self.identifyToMasterNode(uuid=None)
self.checkProtocolErrorRaised(announce, conn, packet) self.checkProtocolErrorRaised(announce, conn)
# Announce to a primary, raise # Announce to a primary, raise
self.app.primary = True self.app.primary = True
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
self.assertTrue(self.app.primary) self.assertTrue(self.app.primary)
self.assertEqual(self.app.primary_master_node, None) self.assertEqual(self.app.primary_master_node, None)
self.assertRaises(ElectionFailure, announce, conn, packet) self.assertRaises(ElectionFailure, announce, conn)
def testAnnouncePrimary2(self): def testAnnouncePrimary2(self):
""" Check the good case """ """ Check the good case """
announce = self.election.announcePrimary announce = self.election.announcePrimary
packet = Packets.AnnouncePrimary()
# Announce, must set the primary # Announce, must set the primary
self.app.primary = False self.app.primary = False
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
self.assertFalse(self.app.primary) self.assertFalse(self.app.primary)
self.assertFalse(self.app.primary_master_node) self.assertFalse(self.app.primary_master_node)
announce(conn, packet) announce(conn)
self.assertFalse(self.app.primary) self.assertFalse(self.app.primary)
self.assertEqual(self.app.primary_master_node, node) self.assertEqual(self.app.primary_master_node, node)
def test_askPrimary1(self): def test_askPrimary1(self):
""" Ask the primary to the primary """ """ Ask the primary to the primary """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
packet = protocol.AskPrimary()
packet.setId(0)
self.app.primary = True self.app.primary = True
self.election.askPrimary(conn, packet) self.election.askPrimary(conn)
uuid, master_list = self.checkAnswerPrimary(conn, decode=True) uuid, master_list = self.checkAnswerPrimary(conn, decode=True)
self.assertEqual(uuid, self.app.uuid) self.assertEqual(uuid, self.app.uuid)
self.assertEqual(len(master_list), 2) self.assertEqual(len(master_list), 2)
...@@ -447,12 +404,10 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -447,12 +404,10 @@ class MasterServerElectionTests(NeoTestBase):
def test_askPrimary2(self): def test_askPrimary2(self):
""" Ask the primary to a secondary that known who's te primary """ """ Ask the primary to a secondary that known who's te primary """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
packet = protocol.AskPrimary()
packet.setId(0)
self.app.primary = False self.app.primary = False
# it will answer ourself as primary # it will answer ourself as primary
self.app.primary_master_node = node self.app.primary_master_node = node
self.election.askPrimary(conn, packet) self.election.askPrimary(conn)
uuid, master_list = self.checkAnswerPrimary(conn, decode=True) uuid, master_list = self.checkAnswerPrimary(conn, decode=True)
self.assertEqual(uuid, node.getUUID()) self.assertEqual(uuid, node.getUUID())
self.assertEqual(len(master_list), 2) self.assertEqual(len(master_list), 2)
...@@ -463,11 +418,9 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -463,11 +418,9 @@ class MasterServerElectionTests(NeoTestBase):
def test_askPrimary3(self): def test_askPrimary3(self):
""" Ask the primary to a master that don't known who's the primary """ """ Ask the primary to a master that don't known who's the primary """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
packet = protocol.AskPrimary()
packet.setId(0)
self.app.primary = False self.app.primary = False
self.app.primary_master_node = None self.app.primary_master_node = None
self.election.askPrimary(conn, packet) self.election.askPrimary(conn)
uuid, master_list = self.checkAnswerPrimary(conn, decode=True) uuid, master_list = self.checkAnswerPrimary(conn, decode=True)
self.assertEqual(uuid, None) self.assertEqual(uuid, None)
self.assertEqual(len(master_list), 2) self.assertEqual(len(master_list), 2)
...@@ -478,10 +431,7 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -478,10 +431,7 @@ class MasterServerElectionTests(NeoTestBase):
def test_reelectPrimary(self): def test_reelectPrimary(self):
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
packet = Packets.AskPrimary() self.assertRaises(ElectionFailure, self.election.reelectPrimary, conn)
packet.setId(0)
self.assertRaises(ElectionFailure, self.election.reelectPrimary,
conn, packet)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -94,7 +94,6 @@ class MasterRecoveryTests(NeoTestBase): ...@@ -94,7 +94,6 @@ class MasterRecoveryTests(NeoTestBase):
def test_09_answerLastIDs(self): def test_09_answerLastIDs(self):
recovery = self.recovery recovery = self.recovery
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packets.AnswerLastIDs()
loid = self.app.loid = '\1' * 8 loid = self.app.loid = '\1' * 8
self.app.tm.setLastTID('\1' * 8) self.app.tm.setLastTID('\1' * 8)
ltid = self.app.tm.getLastTID() ltid = self.app.tm.getLastTID()
...@@ -113,7 +112,7 @@ class MasterRecoveryTests(NeoTestBase): ...@@ -113,7 +112,7 @@ class MasterRecoveryTests(NeoTestBase):
self.assertTrue(new_oid > self.app.loid) self.assertTrue(new_oid > self.app.loid)
self.assertTrue(new_tid > self.app.tm.getLastTID()) self.assertTrue(new_tid > self.app.tm.getLastTID())
self.assertEquals(self.app.target_uuid, None) self.assertEquals(self.app.target_uuid, None)
recovery.answerLastIDs(conn, packet, new_oid, new_tid, new_ptid) recovery.answerLastIDs(conn, new_oid, new_tid, new_ptid)
self.assertEquals(new_oid, self.app.loid) self.assertEquals(new_oid, self.app.loid)
self.assertEquals(new_tid, self.app.tm.getLastTID()) self.assertEquals(new_tid, self.app.tm.getLastTID())
self.assertEquals(new_ptid, self.app.pt.getID()) self.assertEquals(new_ptid, self.app.pt.getID())
...@@ -123,7 +122,6 @@ class MasterRecoveryTests(NeoTestBase): ...@@ -123,7 +122,6 @@ class MasterRecoveryTests(NeoTestBase):
def test_10_answerPartitionTable(self): def test_10_answerPartitionTable(self):
recovery = self.recovery recovery = self.recovery
uuid = self.identifyToMasterNode(NodeTypes.MASTER, port=self.master_port) uuid = self.identifyToMasterNode(NodeTypes.MASTER, port=self.master_port)
packet = Packets.AnswerPartitionTable()
# not from target node, ignore # not from target node, ignore
uuid = self.identifyToMasterNode(NodeTypes.STORAGE, port=self.storage_port) uuid = self.identifyToMasterNode(NodeTypes.STORAGE, port=self.storage_port)
conn = self.getFakeConnection(uuid, self.storage_port) conn = self.getFakeConnection(uuid, self.storage_port)
...@@ -133,7 +131,7 @@ class MasterRecoveryTests(NeoTestBase): ...@@ -133,7 +131,7 @@ class MasterRecoveryTests(NeoTestBase):
cells = self.app.pt.getRow(offset) cells = self.app.pt.getRow(offset)
for cell, state in cells: for cell, state in cells:
self.assertEquals(state, CellStates.OUT_OF_DATE) self.assertEquals(state, CellStates.OUT_OF_DATE)
recovery.answerPartitionTable(conn, packet, None, cell_list) recovery.answerPartitionTable(conn, None, cell_list)
cells = self.app.pt.getRow(offset) cells = self.app.pt.getRow(offset)
for cell, state in cells: for cell, state in cells:
self.assertEquals(state, CellStates.OUT_OF_DATE) self.assertEquals(state, CellStates.OUT_OF_DATE)
...@@ -147,7 +145,7 @@ class MasterRecoveryTests(NeoTestBase): ...@@ -147,7 +145,7 @@ class MasterRecoveryTests(NeoTestBase):
cells = self.app.pt.getRow(offset) cells = self.app.pt.getRow(offset)
for cell, state in cells: for cell, state in cells:
self.assertEquals(state, CellStates.OUT_OF_DATE) self.assertEquals(state, CellStates.OUT_OF_DATE)
recovery.answerPartitionTable(conn, packet, None, cell_list) recovery.answerPartitionTable(conn, None, cell_list)
cells = self.app.pt.getRow(offset) cells = self.app.pt.getRow(offset)
for cell, state in cells: for cell, state in cells:
self.assertEquals(state, CellStates.UP_TO_DATE) self.assertEquals(state, CellStates.UP_TO_DATE)
...@@ -158,7 +156,7 @@ class MasterRecoveryTests(NeoTestBase): ...@@ -158,7 +156,7 @@ class MasterRecoveryTests(NeoTestBase):
self.assertFalse(self.app.pt.hasOffset(offset)) self.assertFalse(self.app.pt.hasOffset(offset))
cell_list = [(offset, ((uuid, NodeStates.DOWN,),),)] cell_list = [(offset, ((uuid, NodeStates.DOWN,),),)]
self.checkProtocolErrorRaised(recovery.answerPartitionTable, conn, self.checkProtocolErrorRaised(recovery.answerPartitionTable, conn,
packet, None, cell_list) None, cell_list)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -83,9 +83,8 @@ class MasterStorageHandlerTests(NeoTestBase): ...@@ -83,9 +83,8 @@ class MasterStorageHandlerTests(NeoTestBase):
self.app.tm.setLastTID(tid1) self.app.tm.setLastTID(tid1)
self.assertTrue(tid1 < tid2) self.assertTrue(tid1 < tid2)
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
packet = Packets.NotifyInformationLocked(tid2)
self.checkProtocolErrorRaised(self.service.notifyInformationLocked, self.checkProtocolErrorRaised(self.service.notifyInformationLocked,
conn, packet, tid2) conn, tid2)
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
def test_notifyInformationLocked_2(self): def test_notifyInformationLocked_2(self):
...@@ -112,15 +111,14 @@ class MasterStorageHandlerTests(NeoTestBase): ...@@ -112,15 +111,14 @@ class MasterStorageHandlerTests(NeoTestBase):
tid = self.app.tm.begin(client_1, None) tid = self.app.tm.begin(client_1, None)
self.app.tm.prepare(tid, oid_list, uuid_list, msg_id) self.app.tm.prepare(tid, oid_list, uuid_list, msg_id)
self.assertTrue(tid in self.app.tm) self.assertTrue(tid in self.app.tm)
packet = Packets.NotifyInformationLocked(tid)
# the first storage acknowledge the lock # the first storage acknowledge the lock
self.service.notifyInformationLocked(storage_conn_1, packet, tid) self.service.notifyInformationLocked(storage_conn_1, tid)
self.checkNoPacketSent(client_conn_1) self.checkNoPacketSent(client_conn_1)
self.checkNoPacketSent(client_conn_2) self.checkNoPacketSent(client_conn_2)
self.checkNoPacketSent(storage_conn_1) self.checkNoPacketSent(storage_conn_1)
self.checkNoPacketSent(storage_conn_2) self.checkNoPacketSent(storage_conn_2)
# then the second # then the second
self.service.notifyInformationLocked(storage_conn_2, packet, tid) self.service.notifyInformationLocked(storage_conn_2, tid)
self.checkAnswerTransactionFinished(client_conn_1) self.checkAnswerTransactionFinished(client_conn_1)
self.checkInvalidateObjects(client_conn_2) self.checkInvalidateObjects(client_conn_2)
self.checkNotifyUnlockInformation(storage_conn_1) self.checkNotifyUnlockInformation(storage_conn_1)
...@@ -129,16 +127,14 @@ class MasterStorageHandlerTests(NeoTestBase): ...@@ -129,16 +127,14 @@ class MasterStorageHandlerTests(NeoTestBase):
def test_12_askLastIDs(self): def test_12_askLastIDs(self):
service = self.service service = self.service
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
packet = Packets.AskLastIDs()
packet.setId(0)
# give a uuid # give a uuid
conn = self.getFakeConnection(node.getUUID(), self.storage_address) conn = self.getFakeConnection(node.getUUID(), self.storage_address)
ptid = self.app.pt.getID() ptid = self.app.pt.getID()
oid = self.app.loid = '\1' * 8 oid = self.app.loid = '\1' * 8
tid = '\1' * 8 tid = '\1' * 8
self.app.tm.setLastTID(tid) self.app.tm.setLastTID(tid)
service.askLastIDs(conn, packet) service.askLastIDs(conn)
packet = self.checkAnswerLastIDs(conn, answered_packet=packet) packet = self.checkAnswerLastIDs(conn)
loid, ltid, lptid = packet.decode() loid, ltid, lptid = packet.decode()
self.assertEqual(loid, oid) self.assertEqual(loid, oid)
self.assertEqual(ltid, tid) self.assertEqual(ltid, tid)
...@@ -148,24 +144,21 @@ class MasterStorageHandlerTests(NeoTestBase): ...@@ -148,24 +144,21 @@ class MasterStorageHandlerTests(NeoTestBase):
def test_13_askUnfinishedTransactions(self): def test_13_askUnfinishedTransactions(self):
service = self.service service = self.service
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
packet = Packets.AskUnfinishedTransactions()
packet.setId(0)
# give a uuid # give a uuid
service.askUnfinishedTransactions(conn, packet) service.askUnfinishedTransactions(conn)
packet = self.checkAnswerUnfinishedTransactions(conn, answered_packet=packet) packet = self.checkAnswerUnfinishedTransactions(conn)
packet.setId(0)
tid_list, = packet.decode() tid_list, = packet.decode()
self.assertEqual(tid_list, []) self.assertEqual(tid_list, [])
# create some transaction # create some transaction
node, conn = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, node, conn = self.identifyToMasterNode(node_type=NodeTypes.CLIENT,
port=self.client_port) port=self.client_port)
client_uuid = node.getUUID() client_uuid = node.getUUID()
self.client_handler.askBeginTransaction(conn, packet, None) self.client_handler.askBeginTransaction(conn, None)
self.client_handler.askBeginTransaction(conn, packet, None) self.client_handler.askBeginTransaction(conn, None)
self.client_handler.askBeginTransaction(conn, packet, None) self.client_handler.askBeginTransaction(conn, None)
conn = self.getFakeConnection(node.getUUID(), self.storage_address) conn = self.getFakeConnection(node.getUUID(), self.storage_address)
service.askUnfinishedTransactions(conn, packet) service.askUnfinishedTransactions(conn)
packet = self.checkAnswerUnfinishedTransactions(conn, answered_packet=packet) packet = self.checkAnswerUnfinishedTransactions(conn)
(tid_list, ) = packet.decode() (tid_list, ) = packet.decode()
self.assertEqual(len(tid_list), 3) self.assertEqual(len(tid_list), 3)
......
...@@ -104,7 +104,6 @@ class MasterVerificationTests(NeoTestBase): ...@@ -104,7 +104,6 @@ class MasterVerificationTests(NeoTestBase):
def test_09_answerLastIDs(self): def test_09_answerLastIDs(self):
verification = self.verification verification = self.verification
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packets.AnswerLastIDs()
loid = self.app.loid loid = self.app.loid
ltid = self.app.tm.getLastTID() ltid = self.app.tm.getLastTID()
lptid = '\0' * 8 lptid = '\0' * 8
...@@ -120,7 +119,7 @@ class MasterVerificationTests(NeoTestBase): ...@@ -120,7 +119,7 @@ class MasterVerificationTests(NeoTestBase):
self.assertTrue(new_ptid > self.app.pt.getID()) self.assertTrue(new_ptid > self.app.pt.getID())
self.assertTrue(new_oid > self.app.loid) self.assertTrue(new_oid > self.app.loid)
self.assertTrue(new_tid > self.app.tm.getLastTID()) self.assertTrue(new_tid > self.app.tm.getLastTID())
self.assertRaises(VerificationFailure, verification.answerLastIDs, conn, packet, new_oid, new_tid, new_ptid) self.assertRaises(VerificationFailure, verification.answerLastIDs, conn, new_oid, new_tid, new_ptid)
self.assertNotEquals(new_oid, self.app.loid) self.assertNotEquals(new_oid, self.app.loid)
self.assertNotEquals(new_tid, self.app.tm.getLastTID()) self.assertNotEquals(new_tid, self.app.tm.getLastTID())
self.assertNotEquals(new_ptid, self.app.pt.getID()) self.assertNotEquals(new_ptid, self.app.pt.getID())
...@@ -128,7 +127,6 @@ class MasterVerificationTests(NeoTestBase): ...@@ -128,7 +127,6 @@ class MasterVerificationTests(NeoTestBase):
def test_11_answerUnfinishedTransactions(self): def test_11_answerUnfinishedTransactions(self):
verification = self.verification verification = self.verification
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packets.AnswerUnfinishedTransactions()
# do nothing # do nothing
conn = self.getFakeConnection(uuid, self.storage_address) conn = self.getFakeConnection(uuid, self.storage_address)
self.assertEquals(len(self.app.asking_uuid_dict), 0) self.assertEquals(len(self.app.asking_uuid_dict), 0)
...@@ -137,7 +135,7 @@ class MasterVerificationTests(NeoTestBase): ...@@ -137,7 +135,7 @@ class MasterVerificationTests(NeoTestBase):
self.assertEquals(len(self.app.unfinished_tid_set), 0) self.assertEquals(len(self.app.unfinished_tid_set), 0)
upper, lower = unpack('!LL', self.app.tm.getLastTID()) upper, lower = unpack('!LL', self.app.tm.getLastTID())
new_tid = pack('!LL', upper, lower + 10) new_tid = pack('!LL', upper, lower + 10)
verification.answerUnfinishedTransactions(conn, packet, [new_tid]) verification.answerUnfinishedTransactions(conn, [new_tid])
self.assertEquals(len(self.app.unfinished_tid_set), 0) self.assertEquals(len(self.app.unfinished_tid_set), 0)
# update dict # update dict
conn = self.getFakeConnection(uuid, self.storage_address) conn = self.getFakeConnection(uuid, self.storage_address)
...@@ -146,7 +144,7 @@ class MasterVerificationTests(NeoTestBase): ...@@ -146,7 +144,7 @@ class MasterVerificationTests(NeoTestBase):
self.assertEquals(len(self.app.unfinished_tid_set), 0) self.assertEquals(len(self.app.unfinished_tid_set), 0)
upper, lower = unpack('!LL', self.app.tm.getLastTID()) upper, lower = unpack('!LL', self.app.tm.getLastTID())
new_tid = pack('!LL', upper, lower + 10) new_tid = pack('!LL', upper, lower + 10)
verification.answerUnfinishedTransactions(conn, packet, [new_tid,]) verification.answerUnfinishedTransactions(conn, [new_tid,])
self.assertTrue(self.app.asking_uuid_dict[uuid]) self.assertTrue(self.app.asking_uuid_dict[uuid])
self.assertEquals(len(self.app.unfinished_tid_set), 1) self.assertEquals(len(self.app.unfinished_tid_set), 1)
self.assertTrue(new_tid in self.app.unfinished_tid_set) self.assertTrue(new_tid in self.app.unfinished_tid_set)
...@@ -155,7 +153,6 @@ class MasterVerificationTests(NeoTestBase): ...@@ -155,7 +153,6 @@ class MasterVerificationTests(NeoTestBase):
def test_12_answerTransactionInformation(self): def test_12_answerTransactionInformation(self):
verification = self.verification verification = self.verification
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packets.AnswerTransactionInformation()
# do nothing, as unfinished_oid_set is None # do nothing, as unfinished_oid_set is None
conn = self.getFakeConnection(uuid, self.storage_address) conn = self.getFakeConnection(uuid, self.storage_address)
self.assertEquals(len(self.app.asking_uuid_dict), 0) self.assertEquals(len(self.app.asking_uuid_dict), 0)
...@@ -166,7 +163,7 @@ class MasterVerificationTests(NeoTestBase): ...@@ -166,7 +163,7 @@ class MasterVerificationTests(NeoTestBase):
new_tid = pack('!LL', upper, lower + 10) new_tid = pack('!LL', upper, lower + 10)
oid = unpack('!Q', self.app.loid)[0] oid = unpack('!Q', self.app.loid)[0]
new_oid = pack('!Q', oid + 1) new_oid = pack('!Q', oid + 1)
verification.answerTransactionInformation(conn, packet, new_tid, verification.answerTransactionInformation(conn, new_tid,
"user", "desc", "ext", [new_oid,]) "user", "desc", "ext", [new_oid,])
self.assertEquals(self.app.unfinished_oid_set, None) self.assertEquals(self.app.unfinished_oid_set, None)
# do nothing as asking_uuid_dict is True # do nothing as asking_uuid_dict is True
...@@ -176,7 +173,7 @@ class MasterVerificationTests(NeoTestBase): ...@@ -176,7 +173,7 @@ class MasterVerificationTests(NeoTestBase):
self.app.unfinished_oid_set = set() self.app.unfinished_oid_set = set()
self.assertTrue(self.app.asking_uuid_dict.has_key(uuid)) self.assertTrue(self.app.asking_uuid_dict.has_key(uuid))
self.assertEquals(len(self.app.unfinished_oid_set), 0) self.assertEquals(len(self.app.unfinished_oid_set), 0)
verification.answerTransactionInformation(conn, packet, new_tid, verification.answerTransactionInformation(conn, new_tid,
"user", "desc", "ext", [new_oid,]) "user", "desc", "ext", [new_oid,])
self.assertEquals(len(self.app.unfinished_oid_set), 0) self.assertEquals(len(self.app.unfinished_oid_set), 0)
# do work # do work
...@@ -185,7 +182,7 @@ class MasterVerificationTests(NeoTestBase): ...@@ -185,7 +182,7 @@ class MasterVerificationTests(NeoTestBase):
self.app.asking_uuid_dict[uuid] = False self.app.asking_uuid_dict[uuid] = False
self.assertTrue(self.app.asking_uuid_dict.has_key(uuid)) self.assertTrue(self.app.asking_uuid_dict.has_key(uuid))
self.assertEquals(len(self.app.unfinished_oid_set), 0) self.assertEquals(len(self.app.unfinished_oid_set), 0)
verification.answerTransactionInformation(conn, packet, new_tid, verification.answerTransactionInformation(conn, new_tid,
"user", "desc", "ext", [new_oid,]) "user", "desc", "ext", [new_oid,])
self.assertEquals(len(self.app.unfinished_oid_set), 1) self.assertEquals(len(self.app.unfinished_oid_set), 1)
self.assertTrue(new_oid in self.app.unfinished_oid_set) self.assertTrue(new_oid in self.app.unfinished_oid_set)
...@@ -199,21 +196,20 @@ class MasterVerificationTests(NeoTestBase): ...@@ -199,21 +196,20 @@ class MasterVerificationTests(NeoTestBase):
oid = unpack('!Q', old_oid)[0] oid = unpack('!Q', old_oid)[0]
new_oid = pack('!Q', oid + 1) new_oid = pack('!Q', oid + 1)
self.assertNotEqual(new_oid, old_oid) self.assertNotEqual(new_oid, old_oid)
verification.answerTransactionInformation(conn, packet, new_tid, verification.answerTransactionInformation(conn, new_tid,
"user", "desc", "ext", [new_oid,]) "user", "desc", "ext", [new_oid,])
self.assertEquals(self.app.unfinished_oid_set, None) self.assertEquals(self.app.unfinished_oid_set, None)
def test_13_tidNotFound(self): def test_13_tidNotFound(self):
verification = self.verification verification = self.verification
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = protocol.tidNotFound('')
# do nothing as asking_uuid_dict is True # do nothing as asking_uuid_dict is True
conn = self.getFakeConnection(uuid, self.storage_address) conn = self.getFakeConnection(uuid, self.storage_address)
self.assertEquals(len(self.app.asking_uuid_dict), 0) self.assertEquals(len(self.app.asking_uuid_dict), 0)
self.app.asking_uuid_dict[uuid] = True self.app.asking_uuid_dict[uuid] = True
self.app.unfinished_oid_set = [] self.app.unfinished_oid_set = []
self.assertTrue(self.app.asking_uuid_dict.has_key(uuid)) self.assertTrue(self.app.asking_uuid_dict.has_key(uuid))
verification.tidNotFound(conn, packet, "msg") verification.tidNotFound(conn, "msg")
self.assertNotEqual(self.app.unfinished_oid_set, None) self.assertNotEqual(self.app.unfinished_oid_set, None)
# do work as asking_uuid_dict is False # do work as asking_uuid_dict is False
conn = self.getFakeConnection(uuid, self.storage_address) conn = self.getFakeConnection(uuid, self.storage_address)
...@@ -221,13 +217,12 @@ class MasterVerificationTests(NeoTestBase): ...@@ -221,13 +217,12 @@ class MasterVerificationTests(NeoTestBase):
self.app.asking_uuid_dict[uuid] = False self.app.asking_uuid_dict[uuid] = False
self.app.unfinished_oid_set = [] self.app.unfinished_oid_set = []
self.assertTrue(self.app.asking_uuid_dict.has_key(uuid)) self.assertTrue(self.app.asking_uuid_dict.has_key(uuid))
verification.tidNotFound(conn, packet, "msg") verification.tidNotFound(conn, "msg")
self.assertEqual(self.app.unfinished_oid_set, None) self.assertEqual(self.app.unfinished_oid_set, None)
def test_14_answerObjectPresent(self): def test_14_answerObjectPresent(self):
verification = self.verification verification = self.verification
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packets.AnswerObjectPresent()
# do nothing as asking_uuid_dict is True # do nothing as asking_uuid_dict is True
upper, lower = unpack('!LL', self.app.tm.getLastTID()) upper, lower = unpack('!LL', self.app.tm.getLastTID())
new_tid = pack('!LL', upper, lower + 10) new_tid = pack('!LL', upper, lower + 10)
...@@ -237,26 +232,25 @@ class MasterVerificationTests(NeoTestBase): ...@@ -237,26 +232,25 @@ class MasterVerificationTests(NeoTestBase):
self.assertEquals(len(self.app.asking_uuid_dict), 0) self.assertEquals(len(self.app.asking_uuid_dict), 0)
self.app.asking_uuid_dict[uuid] = True self.app.asking_uuid_dict[uuid] = True
self.assertTrue(self.app.asking_uuid_dict.has_key(uuid)) self.assertTrue(self.app.asking_uuid_dict.has_key(uuid))
verification.answerObjectPresent(conn, packet, new_oid, new_tid) verification.answerObjectPresent(conn, new_oid, new_tid)
# do work # do work
conn = self.getFakeConnection(uuid, self.storage_address) conn = self.getFakeConnection(uuid, self.storage_address)
self.assertEquals(len(self.app.asking_uuid_dict), 1) self.assertEquals(len(self.app.asking_uuid_dict), 1)
self.app.asking_uuid_dict[uuid] = False self.app.asking_uuid_dict[uuid] = False
self.assertFalse(self.app.asking_uuid_dict[uuid]) self.assertFalse(self.app.asking_uuid_dict[uuid])
verification.answerObjectPresent(conn, packet, new_oid, new_tid) verification.answerObjectPresent(conn, new_oid, new_tid)
self.assertTrue(self.app.asking_uuid_dict[uuid]) self.assertTrue(self.app.asking_uuid_dict[uuid])
def test_15_oidNotFound(self): def test_15_oidNotFound(self):
verification = self.verification verification = self.verification
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = protocol.oidNotFound('')
# do nothing as asking_uuid_dict is True # do nothing as asking_uuid_dict is True
conn = self.getFakeConnection(uuid, self.storage_address) conn = self.getFakeConnection(uuid, self.storage_address)
self.assertEquals(len(self.app.asking_uuid_dict), 0) self.assertEquals(len(self.app.asking_uuid_dict), 0)
self.app.asking_uuid_dict[uuid] = True self.app.asking_uuid_dict[uuid] = True
self.app.object_present = True self.app.object_present = True
self.assertTrue(self.app.object_present) self.assertTrue(self.app.object_present)
verification.oidNotFound(conn, packet, "msg") verification.oidNotFound(conn, "msg")
self.assertTrue(self.app.object_present) self.assertTrue(self.app.object_present)
# do work as asking_uuid_dict is False # do work as asking_uuid_dict is False
conn = self.getFakeConnection(uuid, self.storage_address) conn = self.getFakeConnection(uuid, self.storage_address)
...@@ -264,7 +258,7 @@ class MasterVerificationTests(NeoTestBase): ...@@ -264,7 +258,7 @@ class MasterVerificationTests(NeoTestBase):
self.app.asking_uuid_dict[uuid] = False self.app.asking_uuid_dict[uuid] = False
self.assertFalse(self.app.asking_uuid_dict[uuid ]) self.assertFalse(self.app.asking_uuid_dict[uuid ])
self.assertTrue(self.app.object_present) self.assertTrue(self.app.object_present)
verification.oidNotFound(conn, packet, "msg") verification.oidNotFound(conn, "msg")
self.assertFalse(self.app.object_present) self.assertFalse(self.app.object_present)
self.assertTrue(self.app.asking_uuid_dict[uuid ]) self.assertTrue(self.app.asking_uuid_dict[uuid ])
......
...@@ -33,10 +33,9 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -33,10 +33,9 @@ class StorageClientHandlerTests(NeoTestBase):
"getAddress" : ("127.0.0.1", self.master_port), "getAddress" : ("127.0.0.1", self.master_port),
"isServer": _listening, "isServer": _listening,
}) })
packet = Packet(msg_type=_msg_type)
# hook # hook
self.operation.peerBroken = lambda c: c.peerBrokendCalled() self.operation.peerBroken = lambda c: c.peerBrokendCalled()
self.checkUnexpectedPacketRaised(_call, conn=conn, packet=packet, **kwargs) self.checkUnexpectedPacketRaised(_call, conn=conn, **kwargs)
def setUp(self): def setUp(self):
self.prepareDatabase(number=1) self.prepareDatabase(number=1)
...@@ -101,33 +100,25 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -101,33 +100,25 @@ class StorageClientHandlerTests(NeoTestBase):
def test_18_askTransactionInformation1(self): def test_18_askTransactionInformation1(self):
# transaction does not exists # transaction does not exists
conn = Mock({ }) conn = Mock({ })
packet = Packets.AskTransactionInformation() self.operation.askTransactionInformation(conn, INVALID_TID)
packet.setId(0)
self.operation.askTransactionInformation(conn, packet, INVALID_TID)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
def test_18_askTransactionInformation2(self): def test_18_askTransactionInformation2(self):
# answer # answer
conn = Mock({ }) conn = Mock({ })
packet = Packets.AskTransactionInformation()
packet.setId(0)
dm = Mock({ "getTransaction": (INVALID_TID, 'user', 'desc', '', ), }) dm = Mock({ "getTransaction": (INVALID_TID, 'user', 'desc', '', ), })
self.app.dm = dm self.app.dm = dm
self.operation.askTransactionInformation(conn, packet, INVALID_TID) self.operation.askTransactionInformation(conn, INVALID_TID)
self.checkAnswerTransactionInformation(conn) self.checkAnswerTransactionInformation(conn)
def test_24_askObject1(self): def test_24_askObject1(self):
# delayed response # delayed response
conn = Mock({}) conn = Mock({})
self.app.dm = Mock() self.app.dm = Mock()
packet = Packets.AskObject()
packet.setId(0)
self.app.load_lock_dict[INVALID_OID] = object() self.app.load_lock_dict[INVALID_OID] = object()
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.operation.askObject(conn, packet, self.operation.askObject(conn, oid=INVALID_OID,
oid=INVALID_OID, serial=INVALID_SERIAL, tid=INVALID_TID)
serial=INVALID_SERIAL,
tid=INVALID_TID)
self.assertEquals(len(self.app.event_queue), 1) self.assertEquals(len(self.app.event_queue), 1)
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
self.assertEquals(len(self.app.dm.mockGetNamedCalls('getObject')), 0) self.assertEquals(len(self.app.dm.mockGetNamedCalls('getObject')), 0)
...@@ -136,13 +127,9 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -136,13 +127,9 @@ class StorageClientHandlerTests(NeoTestBase):
# invalid serial / tid / packet not found # invalid serial / tid / packet not found
self.app.dm = Mock({'getObject': None}) self.app.dm = Mock({'getObject': None})
conn = Mock({}) conn = Mock({})
packet = Packets.AskObject()
packet.setId(0)
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.operation.askObject(conn, packet, self.operation.askObject(conn, oid=INVALID_OID,
oid=INVALID_OID, serial=INVALID_SERIAL, tid=INVALID_TID)
serial=INVALID_SERIAL,
tid=INVALID_TID)
calls = self.app.dm.mockGetNamedCalls('getObject') calls = self.app.dm.mockGetNamedCalls('getObject')
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
...@@ -153,13 +140,9 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -153,13 +140,9 @@ class StorageClientHandlerTests(NeoTestBase):
# object found => answer # object found => answer
self.app.dm = Mock({'getObject': ('', '', 0, 0, '', )}) self.app.dm = Mock({'getObject': ('', '', 0, 0, '', )})
conn = Mock({}) conn = Mock({})
packet = Packets.AskObject()
packet.setId(0)
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.operation.askObject(conn, packet, self.operation.askObject(conn, oid=INVALID_OID,
oid=INVALID_OID, serial=INVALID_SERIAL, tid=INVALID_TID)
serial=INVALID_SERIAL,
tid=INVALID_TID)
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.checkAnswerObject(conn) self.checkAnswerObject(conn)
...@@ -169,20 +152,16 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -169,20 +152,16 @@ class StorageClientHandlerTests(NeoTestBase):
app.pt = Mock() app.pt = Mock()
app.dm = Mock() app.dm = Mock()
conn = Mock({}) conn = Mock({})
packet = Packets.AskTIDs() self.checkProtocolErrorRaised(self.operation.askTIDs, conn, 1, 1, None)
packet.setId(0)
self.checkProtocolErrorRaised(self.operation.askTIDs, conn, packet, 1, 1, None)
self.assertEquals(len(app.pt.mockGetNamedCalls('getCellList')), 0) self.assertEquals(len(app.pt.mockGetNamedCalls('getCellList')), 0)
self.assertEquals(len(app.dm.mockGetNamedCalls('getTIDList')), 0) self.assertEquals(len(app.dm.mockGetNamedCalls('getTIDList')), 0)
def test_25_askTIDs2(self): def test_25_askTIDs2(self):
# well case => answer # well case => answer
conn = Mock({}) conn = Mock({})
packet = Packets.AskTIDs()
packet.setId(0)
self.app.pt = Mock({'getPartitions': 1}) self.app.pt = Mock({'getPartitions': 1})
self.app.dm = Mock({'getTIDList': (INVALID_TID, )}) self.app.dm = Mock({'getTIDList': (INVALID_TID, )})
self.operation.askTIDs(conn, packet, 1, 2, 1) self.operation.askTIDs(conn, 1, 2, 1)
calls = self.app.dm.mockGetNamedCalls('getTIDList') calls = self.app.dm.mockGetNamedCalls('getTIDList')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
calls[0].checkArgs(1, 1, 1, [1, ]) calls[0].checkArgs(1, 1, 1, [1, ])
...@@ -191,8 +170,6 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -191,8 +170,6 @@ class StorageClientHandlerTests(NeoTestBase):
def test_25_askTIDs3(self): def test_25_askTIDs3(self):
# invalid partition => answer usable partitions # invalid partition => answer usable partitions
conn = Mock({}) conn = Mock({})
packet = Packets.AskTIDs()
packet.setId(0)
cell = Mock({'getUUID':self.app.uuid}) cell = Mock({'getUUID':self.app.uuid})
self.app.dm = Mock({'getTIDList': (INVALID_TID, )}) self.app.dm = Mock({'getTIDList': (INVALID_TID, )})
self.app.pt = Mock({ self.app.pt = Mock({
...@@ -200,7 +177,7 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -200,7 +177,7 @@ class StorageClientHandlerTests(NeoTestBase):
'getPartitions': 1, 'getPartitions': 1,
'getAssignedPartitionList': [0], 'getAssignedPartitionList': [0],
}) })
self.operation.askTIDs(conn, packet, 1, 2, INVALID_PARTITION) self.operation.askTIDs(conn, 1, 2, INVALID_PARTITION)
self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1) self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1)
calls = self.app.dm.mockGetNamedCalls('getTIDList') calls = self.app.dm.mockGetNamedCalls('getTIDList')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
...@@ -212,32 +189,26 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -212,32 +189,26 @@ class StorageClientHandlerTests(NeoTestBase):
app = self.app app = self.app
app.dm = Mock() app.dm = Mock()
conn = Mock({}) conn = Mock({})
packet = Packets.AskObjectHistory() self.checkProtocolErrorRaised(self.operation.askObjectHistory, conn,
packet.setId(0) 1, 1, None)
self.checkProtocolErrorRaised(self.operation.askObjectHistory, conn, packet, 1, 1, None)
self.assertEquals(len(app.dm.mockGetNamedCalls('getObjectHistory')), 0) self.assertEquals(len(app.dm.mockGetNamedCalls('getObjectHistory')), 0)
def test_26_askObjectHistory2(self): def test_26_askObjectHistory2(self):
# first case: empty history # first case: empty history
packet = Packets.AskObjectHistory()
packet.setId(0)
conn = Mock({}) conn = Mock({})
self.app.dm = Mock({'getObjectHistory': None}) self.app.dm = Mock({'getObjectHistory': None})
self.operation.askObjectHistory(conn, packet, INVALID_OID, 1, 2) self.operation.askObjectHistory(conn, INVALID_OID, 1, 2)
self.checkAnswerObjectHistory(conn) self.checkAnswerObjectHistory(conn)
# second case: not empty history # second case: not empty history
conn = Mock({}) conn = Mock({})
self.app.dm = Mock({'getObjectHistory': [('', 0, ), ]}) self.app.dm = Mock({'getObjectHistory': [('', 0, ), ]})
self.operation.askObjectHistory(conn, packet, INVALID_OID, 1, 2) self.operation.askObjectHistory(conn, INVALID_OID, 1, 2)
self.checkAnswerObjectHistory(conn) self.checkAnswerObjectHistory(conn)
def test_27_askStoreTransaction2(self): def test_27_askStoreTransaction2(self):
# add transaction entry # add transaction entry
packet = Packets.AskStoreTransaction()
packet.setId(0)
conn = Mock({'getUUID': self.getNewUUID()}) conn = Mock({'getUUID': self.getNewUUID()})
self.operation.askStoreTransaction(conn, packet, self.operation.askStoreTransaction(conn, INVALID_TID, '', '', '', ())
INVALID_TID, '', '', '', ())
t = self.app.transaction_dict.get(INVALID_TID, None) t = self.app.transaction_dict.get(INVALID_TID, None)
self.assertNotEquals(t, None) self.assertNotEquals(t, None)
self.assertTrue(isinstance(t, TransactionInformation)) self.assertTrue(isinstance(t, TransactionInformation))
...@@ -246,16 +217,13 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -246,16 +217,13 @@ class StorageClientHandlerTests(NeoTestBase):
def test_28_askStoreObject2(self): def test_28_askStoreObject2(self):
# locked => delayed response # locked => delayed response
packet = Packets.AskStoreObject()
packet.setId(0)
conn = Mock({'getUUID': self.app.uuid}) conn = Mock({'getUUID': self.app.uuid})
oid = '\x02' * 8 oid = '\x02' * 8
tid1, tid2 = self.getTwoIDs() tid1, tid2 = self.getTwoIDs()
self.app.store_lock_dict[oid] = tid1 self.app.store_lock_dict[oid] = tid1
self.assertTrue(oid in self.app.store_lock_dict) self.assertTrue(oid in self.app.store_lock_dict)
t_before = self.app.transaction_dict.items()[:] t_before = self.app.transaction_dict.items()[:]
self.operation.askStoreObject(conn, packet, oid, self.operation.askStoreObject(conn, oid, INVALID_SERIAL, 0, 0, '', tid2)
INVALID_SERIAL, 0, 0, '', tid2)
self.assertEquals(len(self.app.event_queue), 1) self.assertEquals(len(self.app.event_queue), 1)
t_after = self.app.transaction_dict.items()[:] t_after = self.app.transaction_dict.items()[:]
self.assertEquals(t_before, t_after) self.assertEquals(t_before, t_after)
...@@ -264,12 +232,10 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -264,12 +232,10 @@ class StorageClientHandlerTests(NeoTestBase):
def test_28_askStoreObject3(self): def test_28_askStoreObject3(self):
# locked => unresolvable conflict => answer # locked => unresolvable conflict => answer
packet = Packets.AskStoreObject()
packet.setId(0)
conn = Mock({'getUUID': self.app.uuid}) conn = Mock({'getUUID': self.app.uuid})
tid1, tid2 = self.getTwoIDs() tid1, tid2 = self.getTwoIDs()
self.app.store_lock_dict[INVALID_OID] = tid2 self.app.store_lock_dict[INVALID_OID] = tid2
self.operation.askStoreObject(conn, packet, INVALID_OID, self.operation.askStoreObject(conn, INVALID_OID,
INVALID_SERIAL, 0, 0, '', tid1) INVALID_SERIAL, 0, 0, '', tid1)
self.checkAnswerStoreObject(conn) self.checkAnswerStoreObject(conn)
self.assertEquals(self.app.store_lock_dict[INVALID_OID], tid2) self.assertEquals(self.app.store_lock_dict[INVALID_OID], tid2)
...@@ -279,12 +245,10 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -279,12 +245,10 @@ class StorageClientHandlerTests(NeoTestBase):
def test_28_askStoreObject4(self): def test_28_askStoreObject4(self):
# resolvable conflict => answer # resolvable conflict => answer
packet = Packets.AskStoreObject()
packet.setId(0)
conn = Mock({'getUUID': self.app.uuid}) conn = Mock({'getUUID': self.app.uuid})
self.app.dm = Mock({'getObjectHistory':((self.getNewUUID(), ), )}) self.app.dm = Mock({'getObjectHistory':((self.getNewUUID(), ), )})
self.assertEquals(self.app.store_lock_dict.get(INVALID_OID, None), None) self.assertEquals(self.app.store_lock_dict.get(INVALID_OID, None), None)
self.operation.askStoreObject(conn, packet, INVALID_OID, self.operation.askStoreObject(conn, INVALID_OID,
INVALID_SERIAL, 0, 0, '', INVALID_TID) INVALID_SERIAL, 0, 0, '', INVALID_TID)
self.checkAnswerStoreObject(conn) self.checkAnswerStoreObject(conn)
self.assertEquals(self.app.store_lock_dict.get(INVALID_OID, None), None) self.assertEquals(self.app.store_lock_dict.get(INVALID_OID, None), None)
...@@ -294,10 +258,8 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -294,10 +258,8 @@ class StorageClientHandlerTests(NeoTestBase):
def test_28_askStoreObject5(self): def test_28_askStoreObject5(self):
# no conflict => answer # no conflict => answer
packet = Packets.AskStoreObject()
packet.setId(0)
conn = Mock({'getUUID': self.app.uuid}) conn = Mock({'getUUID': self.app.uuid})
self.operation.askStoreObject(conn, packet, INVALID_OID, self.operation.askStoreObject(conn, INVALID_OID,
INVALID_SERIAL, 0, 0, '', INVALID_TID) INVALID_SERIAL, 0, 0, '', INVALID_TID)
t = self.app.transaction_dict.get(INVALID_TID, None) t = self.app.transaction_dict.get(INVALID_TID, None)
self.assertNotEquals(t, None) self.assertNotEquals(t, None)
...@@ -310,8 +272,6 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -310,8 +272,6 @@ class StorageClientHandlerTests(NeoTestBase):
def test_29_abortTransaction(self): def test_29_abortTransaction(self):
# remove transaction # remove transaction
packet = Packets.AbortTransaction()
packet.setId(0)
conn = Mock({'getUUID': self.app.uuid}) conn = Mock({'getUUID': self.app.uuid})
transaction = Mock({ 'getObjectList': ((0, ), ), }) transaction = Mock({ 'getObjectList': ((0, ), ), })
self.called = False self.called = False
...@@ -321,7 +281,7 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -321,7 +281,7 @@ class StorageClientHandlerTests(NeoTestBase):
self.app.load_lock_dict[0] = object() self.app.load_lock_dict[0] = object()
self.app.store_lock_dict[0] = object() self.app.store_lock_dict[0] = object()
self.app.transaction_dict[INVALID_TID] = transaction self.app.transaction_dict[INVALID_TID] = transaction
self.operation.abortTransaction(conn, packet, INVALID_TID) self.operation.abortTransaction(conn, INVALID_TID)
self.assertTrue(self.called) self.assertTrue(self.called)
self.assertEquals(len(self.app.load_lock_dict), 0) self.assertEquals(len(self.app.load_lock_dict), 0)
self.assertEquals(len(self.app.store_lock_dict), 0) self.assertEquals(len(self.app.store_lock_dict), 0)
......
...@@ -80,7 +80,6 @@ class StorageInitializationHandlerTests(NeoTestBase): ...@@ -80,7 +80,6 @@ class StorageInitializationHandlerTests(NeoTestBase):
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
def test_09_sendPartitionTable(self): def test_09_sendPartitionTable(self):
packet = Packets.SendPartitionTable()
uuid = self.getNewUUID() uuid = self.getNewUUID()
# send a table # send a table
conn = Mock({"getUUID" : uuid, conn = Mock({"getUUID" : uuid,
...@@ -101,19 +100,19 @@ class StorageInitializationHandlerTests(NeoTestBase): ...@@ -101,19 +100,19 @@ class StorageInitializationHandlerTests(NeoTestBase):
(2, ((node_2, CellStates.UP_TO_DATE), (node_3, CellStates.UP_TO_DATE)))] (2, ((node_2, CellStates.UP_TO_DATE), (node_3, CellStates.UP_TO_DATE)))]
self.assertFalse(self.app.pt.filled()) self.assertFalse(self.app.pt.filled())
# send part of the table, won't be filled # send part of the table, won't be filled
self.verification.sendPartitionTable(conn, packet, 1, row_list[:1]) self.verification.sendPartitionTable(conn, 1, row_list[:1])
self.assertFalse(self.app.pt.filled()) self.assertFalse(self.app.pt.filled())
self.assertEqual(self.app.pt.getID(), 1) self.assertEqual(self.app.pt.getID(), 1)
self.assertEqual(self.app.dm.getPartitionTable(), []) self.assertEqual(self.app.dm.getPartitionTable(), [])
# send remaining of the table (ack with AnswerPartitionTable) # send remaining of the table (ack with AnswerPartitionTable)
self.verification.sendPartitionTable(conn, packet, 1, row_list[1:]) self.verification.sendPartitionTable(conn, 1, row_list[1:])
self.verification.answerPartitionTable(conn, packet, 1, []) self.verification.answerPartitionTable(conn, 1, [])
self.assertTrue(self.app.pt.filled()) self.assertTrue(self.app.pt.filled())
self.assertEqual(self.app.pt.getID(), 1) self.assertEqual(self.app.pt.getID(), 1)
self.assertNotEqual(self.app.dm.getPartitionTable(), []) self.assertNotEqual(self.app.dm.getPartitionTable(), [])
# send a complete new table and ack # send a complete new table and ack
self.verification.sendPartitionTable(conn, packet, 2, row_list) self.verification.sendPartitionTable(conn, 2, row_list)
self.verification.answerPartitionTable(conn, packet, 2, []) self.verification.answerPartitionTable(conn, 2, [])
self.assertTrue(self.app.pt.filled()) self.assertTrue(self.app.pt.filled())
self.assertEqual(self.app.pt.getID(), 2) self.assertEqual(self.app.pt.getID(), 2)
self.assertNotEqual(self.app.dm.getPartitionTable(), []) self.assertNotEqual(self.app.dm.getPartitionTable(), [])
......
...@@ -33,10 +33,9 @@ class StorageMasterHandlerTests(NeoTestBase): ...@@ -33,10 +33,9 @@ class StorageMasterHandlerTests(NeoTestBase):
"getAddress" : ("127.0.0.1", self.master_port), "getAddress" : ("127.0.0.1", self.master_port),
"isServer": _listening, "isServer": _listening,
}) })
packet = Packet(msg_type=_msg_type)
# hook # hook
self.operation.peerBroken = lambda c: c.peerBrokendCalled() self.operation.peerBroken = lambda c: c.peerBrokendCalled()
self.checkUnexpectedPacketRaised(_call, conn=conn, packet=packet, **kwargs) self.checkUnexpectedPacketRaised(_call, conn=conn, **kwargs)
def setUp(self): def setUp(self):
self.prepareDatabase(number=1) self.prepareDatabase(number=1)
...@@ -95,10 +94,9 @@ class StorageMasterHandlerTests(NeoTestBase): ...@@ -95,10 +94,9 @@ class StorageMasterHandlerTests(NeoTestBase):
"getAddress" : ("127.0.0.1", self.master_port), "getAddress" : ("127.0.0.1", self.master_port),
}) })
app.replicator = Mock({}) app.replicator = Mock({})
packet = Packets.NotifyPartitionChanges()
self.app.pt = Mock({'getID': 1}) self.app.pt = Mock({'getID': 1})
count = len(self.app.nm.getList()) count = len(self.app.nm.getList())
self.operation.notifyPartitionChanges(conn, packet, 0, ()) self.operation.notifyPartitionChanges(conn, 0, ())
self.assertEquals(self.app.pt.getID(), 1) self.assertEquals(self.app.pt.getID(), 1)
self.assertEquals(len(self.app.nm.getList()), count) self.assertEquals(len(self.app.nm.getList()), count)
calls = self.app.replicator.mockGetNamedCalls('removePartition') calls = self.app.replicator.mockGetNamedCalls('removePartition')
...@@ -119,7 +117,6 @@ class StorageMasterHandlerTests(NeoTestBase): ...@@ -119,7 +117,6 @@ class StorageMasterHandlerTests(NeoTestBase):
"isServer": False, "isServer": False,
"getAddress" : ("127.0.0.1", self.master_port), "getAddress" : ("127.0.0.1", self.master_port),
}) })
packet = Packets.NotifyPartitionChanges()
app = self.app app = self.app
# register nodes # register nodes
app.nm.createStorage(uuid=uuid1) app.nm.createStorage(uuid=uuid1)
...@@ -131,7 +128,7 @@ class StorageMasterHandlerTests(NeoTestBase): ...@@ -131,7 +128,7 @@ class StorageMasterHandlerTests(NeoTestBase):
app.dm = Mock({ }) app.dm = Mock({ })
app.replicator = Mock({}) app.replicator = Mock({})
count = len(app.nm.getList()) count = len(app.nm.getList())
self.operation.notifyPartitionChanges(conn, packet, ptid2, cells) self.operation.notifyPartitionChanges(conn, ptid2, cells)
# ptid set # ptid set
self.assertEquals(app.pt.getID(), ptid2) self.assertEquals(app.pt.getID(), ptid2)
# dm call # dm call
...@@ -142,39 +139,34 @@ class StorageMasterHandlerTests(NeoTestBase): ...@@ -142,39 +139,34 @@ class StorageMasterHandlerTests(NeoTestBase):
def test_16_stopOperation1(self): def test_16_stopOperation1(self):
# OperationFailure # OperationFailure
conn = Mock({ 'isServer': False }) conn = Mock({ 'isServer': False })
packet = Packets.StopOperation() self.assertRaises(OperationFailure, self.operation.stopOperation, conn)
self.assertRaises(OperationFailure, self.operation.stopOperation, conn, packet)
def test_22_lockInformation2(self): def test_22_lockInformation2(self):
# load transaction informations # load transaction informations
conn = Mock({ 'isServer': False, }) conn = Mock({ 'isServer': False, })
self.app.dm = Mock({ }) self.app.dm = Mock({ })
packet = Packets.LockInformation()
packet.setId(1)
transaction = Mock({ 'getObjectList': ((0, ), ), }) transaction = Mock({ 'getObjectList': ((0, ), ), })
self.app.transaction_dict[INVALID_TID] = transaction self.app.transaction_dict[INVALID_TID] = transaction
self.operation.lockInformation(conn, packet, INVALID_TID) self.operation.lockInformation(conn, INVALID_TID)
self.assertEquals(self.app.load_lock_dict[0], INVALID_TID) self.assertEquals(self.app.load_lock_dict[0], INVALID_TID)
calls = self.app.dm.mockGetNamedCalls('storeTransaction') calls = self.app.dm.mockGetNamedCalls('storeTransaction')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
self.checkNotifyInformationLocked(conn, answered_packet=packet) self.checkNotifyInformationLocked(conn)
# transaction not in transaction_dict -> KeyError # transaction not in transaction_dict -> KeyError
transaction = Mock({ 'getObjectList': ((0, ), ), }) transaction = Mock({ 'getObjectList': ((0, ), ), })
conn = Mock({ 'isServer': False, }) conn = Mock({ 'isServer': False, })
self.operation.lockInformation(conn, packet, '\x01' * 8) self.operation.lockInformation(conn, '\x01' * 8)
self.checkNotifyInformationLocked(conn, answered_packet=packet) self.checkNotifyInformationLocked(conn)
def test_23_notifyUnlockInformation2(self): def test_23_notifyUnlockInformation2(self):
# delete transaction informations # delete transaction informations
conn = Mock({ 'isServer': False, }) conn = Mock({ 'isServer': False, })
self.app.dm = Mock({ }) self.app.dm = Mock({ })
packet = Packets.LockInformation()
packet.setId(1)
transaction = Mock({ 'getObjectList': ((0, ), ), }) transaction = Mock({ 'getObjectList': ((0, ), ), })
self.app.transaction_dict[INVALID_TID] = transaction self.app.transaction_dict[INVALID_TID] = transaction
self.app.load_lock_dict[0] = transaction self.app.load_lock_dict[0] = transaction
self.app.store_lock_dict[0] = transaction self.app.store_lock_dict[0] = transaction
self.operation.notifyUnlockInformation(conn, packet, INVALID_TID) self.operation.notifyUnlockInformation(conn, INVALID_TID)
self.assertEquals(len(self.app.load_lock_dict), 0) self.assertEquals(len(self.app.load_lock_dict), 0)
self.assertEquals(len(self.app.store_lock_dict), 0) self.assertEquals(len(self.app.store_lock_dict), 0)
self.assertEquals(len(self.app.store_lock_dict), 0) self.assertEquals(len(self.app.store_lock_dict), 0)
...@@ -184,33 +176,29 @@ class StorageMasterHandlerTests(NeoTestBase): ...@@ -184,33 +176,29 @@ class StorageMasterHandlerTests(NeoTestBase):
# transaction not in transaction_dict -> KeyError # transaction not in transaction_dict -> KeyError
transaction = Mock({ 'getObjectList': ((0, ), ), }) transaction = Mock({ 'getObjectList': ((0, ), ), })
conn = Mock({ 'isServer': False, }) conn = Mock({ 'isServer': False, })
self.operation.lockInformation(conn, packet, '\x01' * 8) self.operation.lockInformation(conn, '\x01' * 8)
self.checkNotifyInformationLocked(conn, answered_packet=packet) self.checkNotifyInformationLocked(conn)
def test_30_answerLastIDs(self): def test_30_answerLastIDs(self):
# set critical TID on replicator # set critical TID on replicator
conn = Mock() conn = Mock()
packet = Packets.AnswerLastIDs()
self.app.replicator = Mock() self.app.replicator = Mock()
self.operation.answerLastIDs( self.operation.answerLastIDs(
conn=conn, conn=conn,
packet=packet,
loid=INVALID_OID, loid=INVALID_OID,
ltid=INVALID_TID, ltid=INVALID_TID,
lptid=INVALID_TID, lptid=INVALID_TID,
) )
calls = self.app.replicator.mockGetNamedCalls('setCriticalTID') calls = self.app.replicator.mockGetNamedCalls('setCriticalTID')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
calls[0].checkArgs(packet, INVALID_TID) calls[0].checkArgs(conn.getUUID(), INVALID_TID)
def test_31_answerUnfinishedTransactions(self): def test_31_answerUnfinishedTransactions(self):
# set unfinished TID on replicator # set unfinished TID on replicator
conn = Mock() conn = Mock()
packet = Packets.AnswerUnfinishedTransactions()
self.app.replicator = Mock() self.app.replicator = Mock()
self.operation.answerUnfinishedTransactions( self.operation.answerUnfinishedTransactions(
conn=conn, conn=conn,
packet=packet,
tid_list=(INVALID_TID, ), tid_list=(INVALID_TID, ),
) )
calls = self.app.replicator.mockGetNamedCalls('setUnfinishedTIDList') calls = self.app.replicator.mockGetNamedCalls('setUnfinishedTIDList')
......
...@@ -31,10 +31,9 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -31,10 +31,9 @@ class StorageStorageHandlerTests(NeoTestBase):
"getAddress" : ("127.0.0.1", self.master_port), "getAddress" : ("127.0.0.1", self.master_port),
"isServer": _listening, "isServer": _listening,
}) })
packet = Packet(msg_type=_msg_type)
# hook # hook
self.operation.peerBroken = lambda c: c.peerBrokendCalled() self.operation.peerBroken = lambda c: c.peerBrokendCalled()
self.checkUnexpectedPacketRaised(_call, conn=conn, packet=packet, **kwargs) self.checkUnexpectedPacketRaised(_call, conn=conn, **kwargs)
def setUp(self): def setUp(self):
self.prepareDatabase(number=1) self.prepareDatabase(number=1)
...@@ -60,19 +59,15 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -60,19 +59,15 @@ class StorageStorageHandlerTests(NeoTestBase):
def test_18_askTransactionInformation1(self): def test_18_askTransactionInformation1(self):
# transaction does not exists # transaction does not exists
conn = Mock({ }) conn = Mock({ })
packet = Packets.AskTransactionInformation() self.operation.askTransactionInformation(conn, INVALID_TID)
packet.setId(0)
self.operation.askTransactionInformation(conn, packet, INVALID_TID)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
def test_18_askTransactionInformation2(self): def test_18_askTransactionInformation2(self):
# answer # answer
conn = Mock({ }) conn = Mock({ })
packet = Packets.AskTransactionInformation()
packet.setId(0)
dm = Mock({ "getTransaction": (INVALID_TID, 'user', 'desc', '', ), }) dm = Mock({ "getTransaction": (INVALID_TID, 'user', 'desc', '', ), })
self.app.dm = dm self.app.dm = dm
self.operation.askTransactionInformation(conn, packet, INVALID_TID) self.operation.askTransactionInformation(conn, INVALID_TID)
self.checkAnswerTransactionInformation(conn) self.checkAnswerTransactionInformation(conn)
def test_24_askObject1(self): def test_24_askObject1(self):
...@@ -82,10 +77,8 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -82,10 +77,8 @@ class StorageStorageHandlerTests(NeoTestBase):
packet = Packets.AskObject() packet = Packets.AskObject()
self.app.load_lock_dict[INVALID_OID] = object() self.app.load_lock_dict[INVALID_OID] = object()
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.operation.askObject(conn, packet, self.operation.askObject(conn, oid=INVALID_OID,
oid=INVALID_OID, serial=INVALID_SERIAL, tid=INVALID_TID)
serial=INVALID_SERIAL,
tid=INVALID_TID)
self.assertEquals(len(self.app.event_queue), 1) self.assertEquals(len(self.app.event_queue), 1)
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
self.assertEquals(len(self.app.dm.mockGetNamedCalls('getObject')), 0) self.assertEquals(len(self.app.dm.mockGetNamedCalls('getObject')), 0)
...@@ -94,13 +87,9 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -94,13 +87,9 @@ class StorageStorageHandlerTests(NeoTestBase):
# invalid serial / tid / packet not found # invalid serial / tid / packet not found
self.app.dm = Mock({'getObject': None}) self.app.dm = Mock({'getObject': None})
conn = Mock({}) conn = Mock({})
packet = Packets.AskObject()
packet.setId(0)
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.operation.askObject(conn, packet, self.operation.askObject(conn, oid=INVALID_OID,
oid=INVALID_OID, serial=INVALID_SERIAL, tid=INVALID_TID)
serial=INVALID_SERIAL,
tid=INVALID_TID)
calls = self.app.dm.mockGetNamedCalls('getObject') calls = self.app.dm.mockGetNamedCalls('getObject')
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
...@@ -111,13 +100,9 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -111,13 +100,9 @@ class StorageStorageHandlerTests(NeoTestBase):
# object found => answer # object found => answer
self.app.dm = Mock({'getObject': ('', '', 0, 0, '', )}) self.app.dm = Mock({'getObject': ('', '', 0, 0, '', )})
conn = Mock({}) conn = Mock({})
packet = Packets.AskObject()
packet.setId(0)
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.operation.askObject(conn, packet, self.operation.askObject(conn, oid=INVALID_OID,
oid=INVALID_OID, serial=INVALID_SERIAL, tid=INVALID_TID)
serial=INVALID_SERIAL,
tid=INVALID_TID)
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.checkAnswerObject(conn) self.checkAnswerObject(conn)
...@@ -127,19 +112,16 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -127,19 +112,16 @@ class StorageStorageHandlerTests(NeoTestBase):
app.pt = Mock() app.pt = Mock()
app.dm = Mock() app.dm = Mock()
conn = Mock({}) conn = Mock({})
packet = Packets.AskTIDs() self.checkProtocolErrorRaised(self.operation.askTIDs, conn, 1, 1, None)
self.checkProtocolErrorRaised(self.operation.askTIDs, conn, packet, 1, 1, None)
self.assertEquals(len(app.pt.mockGetNamedCalls('getCellList')), 0) self.assertEquals(len(app.pt.mockGetNamedCalls('getCellList')), 0)
self.assertEquals(len(app.dm.mockGetNamedCalls('getTIDList')), 0) self.assertEquals(len(app.dm.mockGetNamedCalls('getTIDList')), 0)
def test_25_askTIDs2(self): def test_25_askTIDs2(self):
# well case => answer # well case => answer
conn = Mock({}) conn = Mock({})
packet = Packets.AskTIDs()
packet.setId(0)
self.app.dm = Mock({'getTIDList': (INVALID_TID, )}) self.app.dm = Mock({'getTIDList': (INVALID_TID, )})
self.app.pt = Mock({'getPartitions': 1}) self.app.pt = Mock({'getPartitions': 1})
self.operation.askTIDs(conn, packet, 1, 2, 1) self.operation.askTIDs(conn, 1, 2, 1)
calls = self.app.dm.mockGetNamedCalls('getTIDList') calls = self.app.dm.mockGetNamedCalls('getTIDList')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
calls[0].checkArgs(1, 1, 1, [1, ]) calls[0].checkArgs(1, 1, 1, [1, ])
...@@ -148,8 +130,6 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -148,8 +130,6 @@ class StorageStorageHandlerTests(NeoTestBase):
def test_25_askTIDs3(self): def test_25_askTIDs3(self):
# invalid partition => answer usable partitions # invalid partition => answer usable partitions
conn = Mock({}) conn = Mock({})
packet = Packets.AskTIDs()
packet.setId(0)
cell = Mock({'getUUID':self.app.uuid}) cell = Mock({'getUUID':self.app.uuid})
self.app.dm = Mock({'getTIDList': (INVALID_TID, )}) self.app.dm = Mock({'getTIDList': (INVALID_TID, )})
self.app.pt = Mock({ self.app.pt = Mock({
...@@ -157,7 +137,7 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -157,7 +137,7 @@ class StorageStorageHandlerTests(NeoTestBase):
'getPartitions': 1, 'getPartitions': 1,
'getAssignedPartitionList': [0], 'getAssignedPartitionList': [0],
}) })
self.operation.askTIDs(conn, packet, 1, 2, INVALID_PARTITION) self.operation.askTIDs(conn, 1, 2, INVALID_PARTITION)
self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1) self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1)
calls = self.app.dm.mockGetNamedCalls('getTIDList') calls = self.app.dm.mockGetNamedCalls('getTIDList')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
...@@ -169,23 +149,20 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -169,23 +149,20 @@ class StorageStorageHandlerTests(NeoTestBase):
app = self.app app = self.app
app.dm = Mock() app.dm = Mock()
conn = Mock({}) conn = Mock({})
packet = Packets.AskObjectHistory() self.checkProtocolErrorRaised(self.operation.askObjectHistory, conn,
packet.setId(0) 1, 1, None)
self.checkProtocolErrorRaised(self.operation.askObjectHistory, conn, packet, 1, 1, None)
self.assertEquals(len(app.dm.mockGetNamedCalls('getObjectHistory')), 0) self.assertEquals(len(app.dm.mockGetNamedCalls('getObjectHistory')), 0)
def test_26_askObjectHistory2(self): def test_26_askObjectHistory2(self):
# first case: empty history # first case: empty history
packet = Packets.AskObjectHistory()
packet.setId(0)
conn = Mock({}) conn = Mock({})
self.app.dm = Mock({'getObjectHistory': None}) self.app.dm = Mock({'getObjectHistory': None})
self.operation.askObjectHistory(conn, packet, INVALID_OID, 1, 2) self.operation.askObjectHistory(conn, INVALID_OID, 1, 2)
self.checkAnswerObjectHistory(conn) self.checkAnswerObjectHistory(conn)
# second case: not empty history # second case: not empty history
conn = Mock({}) conn = Mock({})
self.app.dm = Mock({'getObjectHistory': [('', 0, ), ]}) self.app.dm = Mock({'getObjectHistory': [('', 0, ), ]})
self.operation.askObjectHistory(conn, packet, INVALID_OID, 1, 2) self.operation.askObjectHistory(conn, INVALID_OID, 1, 2)
self.checkAnswerObjectHistory(conn) self.checkAnswerObjectHistory(conn)
def test_25_askOIDs1(self): def test_25_askOIDs1(self):
...@@ -194,20 +171,16 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -194,20 +171,16 @@ class StorageStorageHandlerTests(NeoTestBase):
app.pt = Mock() app.pt = Mock()
app.dm = Mock() app.dm = Mock()
conn = Mock({}) conn = Mock({})
packet = Packets.AskOIDs() self.checkProtocolErrorRaised(self.operation.askOIDs, conn, 1, 1, None)
packet.setId(0)
self.checkProtocolErrorRaised(self.operation.askOIDs, conn, packet, 1, 1, None)
self.assertEquals(len(app.pt.mockGetNamedCalls('getCellList')), 0) self.assertEquals(len(app.pt.mockGetNamedCalls('getCellList')), 0)
self.assertEquals(len(app.dm.mockGetNamedCalls('getOIDList')), 0) self.assertEquals(len(app.dm.mockGetNamedCalls('getOIDList')), 0)
def test_25_askOIDs2(self): def test_25_askOIDs2(self):
# well case > answer OIDs # well case > answer OIDs
conn = Mock({}) conn = Mock({})
packet = Packets.AskOIDs()
packet.setId(0)
self.app.pt = Mock({'getPartitions': 1}) self.app.pt = Mock({'getPartitions': 1})
self.app.dm = Mock({'getOIDList': (INVALID_OID, )}) self.app.dm = Mock({'getOIDList': (INVALID_OID, )})
self.operation.askOIDs(conn, packet, 1, 2, 1) self.operation.askOIDs(conn, 1, 2, 1)
calls = self.app.dm.mockGetNamedCalls('getOIDList') calls = self.app.dm.mockGetNamedCalls('getOIDList')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
calls[0].checkArgs(1, 1, 1, [1, ]) calls[0].checkArgs(1, 1, 1, [1, ])
...@@ -216,8 +189,6 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -216,8 +189,6 @@ class StorageStorageHandlerTests(NeoTestBase):
def test_25_askOIDs3(self): def test_25_askOIDs3(self):
# invalid partition => answer usable partitions # invalid partition => answer usable partitions
conn = Mock({}) conn = Mock({})
packet = Packets.AskOIDs()
packet.setId(0)
cell = Mock({'getUUID':self.app.uuid}) cell = Mock({'getUUID':self.app.uuid})
self.app.dm = Mock({'getOIDList': (INVALID_OID, )}) self.app.dm = Mock({'getOIDList': (INVALID_OID, )})
self.app.pt = Mock({ self.app.pt = Mock({
...@@ -225,7 +196,7 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -225,7 +196,7 @@ class StorageStorageHandlerTests(NeoTestBase):
'getPartitions': 1, 'getPartitions': 1,
'getAssignedPartitionList': [0], 'getAssignedPartitionList': [0],
}) })
self.operation.askOIDs(conn, packet, 1, 2, INVALID_PARTITION) self.operation.askOIDs(conn, 1, 2, INVALID_PARTITION)
self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1) self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1)
calls = self.app.dm.mockGetNamedCalls('getOIDList') calls = self.app.dm.mockGetNamedCalls('getOIDList')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
......
...@@ -85,7 +85,6 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -85,7 +85,6 @@ class StorageVerificationHandlerTests(NeoTestBase):
def test_07_askLastIDs(self): def test_07_askLastIDs(self):
uuid = self.getNewUUID() uuid = self.getNewUUID()
packet = Mock()
# return invalid if db store nothing # return invalid if db store nothing
conn = Mock({"getUUID" : uuid, conn = Mock({"getUUID" : uuid,
"getAddress" : ("127.0.0.1", self.client_port), "getAddress" : ("127.0.0.1", self.client_port),
...@@ -93,7 +92,7 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -93,7 +92,7 @@ class StorageVerificationHandlerTests(NeoTestBase):
last_ptid = '\x01' * 8 last_ptid = '\x01' * 8
last_oid = '\x02' * 8 last_oid = '\x02' * 8
self.app.pt = Mock({'getID': last_ptid}) self.app.pt = Mock({'getID': last_ptid})
self.verification.askLastIDs(conn, packet) self.verification.askLastIDs(conn)
oid, tid, ptid = self.checkAnswerLastIDs(conn, decode=True) oid, tid, ptid = self.checkAnswerLastIDs(conn, decode=True)
self.assertEqual(oid, INVALID_OID) self.assertEqual(oid, INVALID_OID)
self.assertEqual(tid, INVALID_TID) self.assertEqual(tid, INVALID_TID)
...@@ -125,7 +124,7 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -125,7 +124,7 @@ class StorageVerificationHandlerTests(NeoTestBase):
checksum, value) values (0, 4, 0, 0, '')""") checksum, value) values (0, 4, 0, 0, '')""")
self.app.dm.commit() self.app.dm.commit()
self.app.dm.setLastOID(last_oid) self.app.dm.setLastOID(last_oid)
self.verification.askLastIDs(conn, packet) self.verification.askLastIDs(conn)
self.checkAnswerLastIDs(conn) self.checkAnswerLastIDs(conn)
oid, tid, ptid = self.checkAnswerLastIDs(conn, decode=True) oid, tid, ptid = self.checkAnswerLastIDs(conn, decode=True)
self.assertEqual(oid, last_oid) self.assertEqual(oid, last_oid)
...@@ -134,7 +133,6 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -134,7 +133,6 @@ class StorageVerificationHandlerTests(NeoTestBase):
def test_08_askPartitionTable(self): def test_08_askPartitionTable(self):
uuid = self.getNewUUID() uuid = self.getNewUUID()
packet = Mock()
# try to get unknown offset # try to get unknown offset
self.assertEqual(len(self.app.pt.getNodeList()), 0) self.assertEqual(len(self.app.pt.getNodeList()), 0)
self.assertFalse(self.app.pt.hasOffset(1)) self.assertFalse(self.app.pt.hasOffset(1))
...@@ -142,7 +140,7 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -142,7 +140,7 @@ class StorageVerificationHandlerTests(NeoTestBase):
conn = Mock({"getUUID" : uuid, conn = Mock({"getUUID" : uuid,
"getAddress" : ("127.0.0.1", self.client_port), "getAddress" : ("127.0.0.1", self.client_port),
"isServer" : False}) "isServer" : False})
self.verification.askPartitionTable(conn, packet, [1,]) self.verification.askPartitionTable(conn, [1])
ptid, row_list = self.checkAnswerPartitionTable(conn, decode=True) ptid, row_list = self.checkAnswerPartitionTable(conn, decode=True)
self.assertEqual(len(row_list), 1) self.assertEqual(len(row_list), 1)
offset, rows = row_list[0] offset, rows = row_list[0]
...@@ -159,7 +157,7 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -159,7 +157,7 @@ class StorageVerificationHandlerTests(NeoTestBase):
conn = Mock({"getUUID" : uuid, conn = Mock({"getUUID" : uuid,
"getAddress" : ("127.0.0.1", self.client_port), "getAddress" : ("127.0.0.1", self.client_port),
"isServer" : False}) "isServer" : False})
self.verification.askPartitionTable(conn, packet, [1,]) self.verification.askPartitionTable(conn, [1])
ptid, row_list = self.checkAnswerPartitionTable(conn, decode=True) ptid, row_list = self.checkAnswerPartitionTable(conn, decode=True)
self.assertEqual(len(row_list), 1) self.assertEqual(len(row_list), 1)
offset, rows = row_list[0] offset, rows = row_list[0]
...@@ -172,9 +170,8 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -172,9 +170,8 @@ class StorageVerificationHandlerTests(NeoTestBase):
"isServer": False, "isServer": False,
"getAddress" : ("127.0.0.1", self.master_port), "getAddress" : ("127.0.0.1", self.master_port),
}) })
packet = Packets.NotifyPartitionChanges() self.verification.notifyPartitionChanges(conn, 1, ())
self.verification.notifyPartitionChanges(conn, packet, 1, ()) self.verification.notifyPartitionChanges(conn, 0, ())
self.verification.notifyPartitionChanges(conn, packet, 0, ())
self.assertEqual(self.app.pt.getID(), 1) self.assertEqual(self.app.pt.getID(), 1)
# new node # new node
...@@ -182,7 +179,6 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -182,7 +179,6 @@ class StorageVerificationHandlerTests(NeoTestBase):
"isServer": False, "isServer": False,
"getAddress" : ("127.0.0.1", self.master_port), "getAddress" : ("127.0.0.1", self.master_port),
}) })
packet = Packets.NotifyPartitionChanges()
new_uuid = self.getNewUUID() new_uuid = self.getNewUUID()
cell = (0, new_uuid, CellStates.UP_TO_DATE) cell = (0, new_uuid, CellStates.UP_TO_DATE)
self.app.nm.createStorage(uuid=new_uuid) self.app.nm.createStorage(uuid=new_uuid)
...@@ -190,7 +186,7 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -190,7 +186,7 @@ class StorageVerificationHandlerTests(NeoTestBase):
self.app.dm = Mock({ }) self.app.dm = Mock({ })
ptid, self.ptid = self.getTwoIDs() ptid, self.ptid = self.getTwoIDs()
# pt updated # pt updated
self.verification.notifyPartitionChanges(conn, packet, ptid, (cell, )) self.verification.notifyPartitionChanges(conn, ptid, (cell, ))
# check db update # check db update
calls = self.app.dm.mockGetNamedCalls('changePartitionTable') calls = self.app.dm.mockGetNamedCalls('changePartitionTable')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
...@@ -201,23 +197,19 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -201,23 +197,19 @@ class StorageVerificationHandlerTests(NeoTestBase):
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': False }) 'isServer': False })
self.assertFalse(self.app.operational) self.assertFalse(self.app.operational)
packet = Packets.StopOperation() self.verification.startOperation(conn)
self.verification.startOperation(conn, packet)
self.assertTrue(self.app.operational) self.assertTrue(self.app.operational)
def test_12_stopOperation(self): def test_12_stopOperation(self):
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': False }) 'isServer': False })
packet = Packets.StopOperation() self.assertRaises(OperationFailure, self.verification.stopOperation, conn)
self.assertRaises(OperationFailure, self.verification.stopOperation, conn, packet)
def test_13_askUnfinishedTransactions(self): def test_13_askUnfinishedTransactions(self):
# client connection with no data # client connection with no data
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': False}) 'isServer': False})
packet = Packets.AskUnfinishedTransactions() self.verification.askUnfinishedTransactions(conn)
packet.setId(0)
self.verification.askUnfinishedTransactions(conn, packet)
(tid_list, ) = self.checkAnswerUnfinishedTransactions(conn, decode=True) (tid_list, ) = self.checkAnswerUnfinishedTransactions(conn, decode=True)
self.assertEqual(len(tid_list), 0) self.assertEqual(len(tid_list), 0)
...@@ -228,9 +220,7 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -228,9 +220,7 @@ class StorageVerificationHandlerTests(NeoTestBase):
self.app.dm.commit() self.app.dm.commit()
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': False}) 'isServer': False})
packet = Packets.AskUnfinishedTransactions() self.verification.askUnfinishedTransactions(conn)
packet.setId(0)
self.verification.askUnfinishedTransactions(conn, packet)
(tid_list, ) = self.checkAnswerUnfinishedTransactions(conn, decode=True) (tid_list, ) = self.checkAnswerUnfinishedTransactions(conn, decode=True)
self.assertEqual(len(tid_list), 1) self.assertEqual(len(tid_list), 1)
self.assertEqual(u64(tid_list[0]), 4) self.assertEqual(u64(tid_list[0]), 4)
...@@ -239,9 +229,7 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -239,9 +229,7 @@ class StorageVerificationHandlerTests(NeoTestBase):
# ask from client conn with no data # ask from client conn with no data
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': False }) 'isServer': False })
packet = Packets.AskTransactionInformation() self.verification.askTransactionInformation(conn, p64(1))
packet.setId(0)
self.verification.askTransactionInformation(conn, packet, p64(1))
code, message = self.checkErrorPacket(conn, decode=True) code, message = self.checkErrorPacket(conn, decode=True)
self.assertEqual(code, ErrorCodes.TID_NOT_FOUND) self.assertEqual(code, ErrorCodes.TID_NOT_FOUND)
...@@ -255,9 +243,7 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -255,9 +243,7 @@ class StorageVerificationHandlerTests(NeoTestBase):
# object from trans # object from trans
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': False }) 'isServer': False })
packet = Packets.AskTransactionInformation() self.verification.askTransactionInformation(conn, p64(1))
packet.setId(0)
self.verification.askTransactionInformation(conn, packet, p64(1))
tid, user, desc, ext, oid_list = self.checkAnswerTransactionInformation(conn, decode=True) tid, user, desc, ext, oid_list = self.checkAnswerTransactionInformation(conn, decode=True)
self.assertEqual(u64(tid), 1) self.assertEqual(u64(tid), 1)
self.assertEqual(user, 'u2') self.assertEqual(user, 'u2')
...@@ -268,9 +254,7 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -268,9 +254,7 @@ class StorageVerificationHandlerTests(NeoTestBase):
# object from ttrans # object from ttrans
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': False }) 'isServer': False })
packet = Packets.AskTransactionInformation() self.verification.askTransactionInformation(conn, p64(3))
packet.setId(0)
self.verification.askTransactionInformation(conn, packet, p64(3))
tid, user, desc, ext, oid_list = self.checkAnswerTransactionInformation(conn, decode=True) tid, user, desc, ext, oid_list = self.checkAnswerTransactionInformation(conn, decode=True)
self.assertEqual(u64(tid), 3) self.assertEqual(u64(tid), 3)
self.assertEqual(user, 'u1') self.assertEqual(user, 'u1')
...@@ -283,9 +267,7 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -283,9 +267,7 @@ class StorageVerificationHandlerTests(NeoTestBase):
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': True }) 'isServer': True })
# find the one in trans # find the one in trans
packet = Packets.AskTransactionInformation() self.verification.askTransactionInformation(conn, p64(1))
packet.setId(0)
self.verification.askTransactionInformation(conn, packet, p64(1))
tid, user, desc, ext, oid_list = self.checkAnswerTransactionInformation(conn, decode=True) tid, user, desc, ext, oid_list = self.checkAnswerTransactionInformation(conn, decode=True)
self.assertEqual(u64(tid), 1) self.assertEqual(u64(tid), 1)
self.assertEqual(user, 'u2') self.assertEqual(user, 'u2')
...@@ -296,9 +278,7 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -296,9 +278,7 @@ class StorageVerificationHandlerTests(NeoTestBase):
# do not find the one in ttrans # do not find the one in ttrans
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': True }) 'isServer': True })
packet = Packets.AskTransactionInformation() self.verification.askTransactionInformation(conn, p64(2))
packet.setId(0)
self.verification.askTransactionInformation(conn, packet, p64(2))
code, message = self.checkErrorPacket(conn, decode=True) code, message = self.checkErrorPacket(conn, decode=True)
self.assertEqual(code, ErrorCodes.TID_NOT_FOUND) self.assertEqual(code, ErrorCodes.TID_NOT_FOUND)
...@@ -306,9 +286,7 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -306,9 +286,7 @@ class StorageVerificationHandlerTests(NeoTestBase):
# client connection with no data # client connection with no data
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': False}) 'isServer': False})
packet = Packets.AskObjectPresent() self.verification.askObjectPresent(conn, p64(1), p64(2))
packet.setId(0)
self.verification.askObjectPresent(conn, packet, p64(1), p64(2))
code, message = self.checkErrorPacket(conn, decode=True) code, message = self.checkErrorPacket(conn, decode=True)
self.assertEqual(code, ErrorCodes.OID_NOT_FOUND) self.assertEqual(code, ErrorCodes.OID_NOT_FOUND)
...@@ -319,9 +297,7 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -319,9 +297,7 @@ class StorageVerificationHandlerTests(NeoTestBase):
self.app.dm.commit() self.app.dm.commit()
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': False}) 'isServer': False})
packet = Packets.AskObjectPresent() self.verification.askObjectPresent(conn, p64(1), p64(2))
packet.setId(0)
self.verification.askObjectPresent(conn, packet, p64(1), p64(2))
oid, tid = self.checkAnswerObjectPresent(conn, decode=True) oid, tid = self.checkAnswerObjectPresent(conn, decode=True)
self.assertEqual(u64(tid), 2) self.assertEqual(u64(tid), 2)
self.assertEqual(u64(oid), 1) self.assertEqual(u64(oid), 1)
...@@ -330,14 +306,13 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -330,14 +306,13 @@ class StorageVerificationHandlerTests(NeoTestBase):
# client connection with no data # client connection with no data
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': False}) 'isServer': False})
packet = Packets.AskObjectPresent() self.verification.deleteTransaction(conn, p64(1))
self.verification.deleteTransaction(conn, packet, p64(1))
# client connection with data # client connection with data
self.app.dm.begin() self.app.dm.begin()
self.app.dm.query("""insert into tobj (oid, serial, compression, self.app.dm.query("""insert into tobj (oid, serial, compression,
checksum, value) values (1, 2, 0, 0, '')""") checksum, value) values (1, 2, 0, 0, '')""")
self.app.dm.commit() self.app.dm.commit()
self.verification.deleteTransaction(conn, packet, p64(2)) self.verification.deleteTransaction(conn, p64(2))
result = self.app.dm.query('select * from tobj') result = self.app.dm.query('select * from tobj')
self.assertEquals(len(result), 0) self.assertEquals(len(result), 0)
...@@ -347,8 +322,7 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -347,8 +322,7 @@ class StorageVerificationHandlerTests(NeoTestBase):
'isServer': False }) 'isServer': False })
dm = Mock() dm = Mock()
self.app.dm = dm self.app.dm = dm
packet = Packets.CommitTransaction() self.verification.commitTransaction(conn, p64(1))
self.verification.commitTransaction(conn, packet, p64(1))
self.assertEqual(len(dm.mockGetNamedCalls("finishTransaction")), 1) self.assertEqual(len(dm.mockGetNamedCalls("finishTransaction")), 1)
call = dm.mockGetNamedCalls("finishTransaction")[0] call = dm.mockGetNamedCalls("finishTransaction")[0]
tid = call.getParam(0) tid = call.getParam(0)
......
...@@ -54,8 +54,7 @@ class BootstrapManagerTests(NeoTestBase): ...@@ -54,8 +54,7 @@ class BootstrapManagerTests(NeoTestBase):
def testHandleNotReady(self): def testHandleNotReady(self):
# the primary is not ready # the primary is not ready
conn = Mock({}) conn = Mock({})
packet = Mock({}) self.bootstrap.notReady(conn, '')
self.bootstrap.notReady(conn, packet, '')
self.checkClosed(conn) self.checkClosed(conn)
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
......
...@@ -44,39 +44,39 @@ class HandlerTests(NeoTestBase): ...@@ -44,39 +44,39 @@ class HandlerTests(NeoTestBase):
conn = Mock({'getAddress': ('127.0.0.1', 10000)}) conn = Mock({'getAddress': ('127.0.0.1', 10000)})
packet = self.getFakePacket() packet = self.getFakePacket()
# all is ok # all is ok
self.setFakeMethod(lambda c, p: None) self.setFakeMethod(lambda c: None)
self.handler.dispatch(conn, packet) self.handler.dispatch(conn, packet)
# raise UnexpectedPacketError # raise UnexpectedPacketError
conn.mockCalledMethods = {} conn.mockCalledMethods = {}
def fake(c, p): raise UnexpectedPacketError('fake packet') def fake(c): raise UnexpectedPacketError('fake packet')
self.setFakeMethod(fake) self.setFakeMethod(fake)
self.handler.dispatch(conn, packet) self.handler.dispatch(conn, packet)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
self.checkAborted(conn) self.checkAborted(conn)
# raise PacketMalformedError # raise PacketMalformedError
conn.mockCalledMethods = {} conn.mockCalledMethods = {}
def fake(c, p): raise PacketMalformedError('message') def fake(c): raise PacketMalformedError('message')
self.setFakeMethod(fake) self.setFakeMethod(fake)
self.handler.dispatch(conn, packet) self.handler.dispatch(conn, packet)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
self.checkAborted(conn) self.checkAborted(conn)
# raise BrokenNodeDisallowedError # raise BrokenNodeDisallowedError
conn.mockCalledMethods = {} conn.mockCalledMethods = {}
def fake(c, p): raise BrokenNodeDisallowedError def fake(c): raise BrokenNodeDisallowedError
self.setFakeMethod(fake) self.setFakeMethod(fake)
self.handler.dispatch(conn, packet) self.handler.dispatch(conn, packet)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
self.checkAborted(conn) self.checkAborted(conn)
# raise NotReadyError # raise NotReadyError
conn.mockCalledMethods = {} conn.mockCalledMethods = {}
def fake(c, p): raise NotReadyError def fake(c): raise NotReadyError
self.setFakeMethod(fake) self.setFakeMethod(fake)
self.handler.dispatch(conn, packet) self.handler.dispatch(conn, packet)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
self.checkAborted(conn) self.checkAborted(conn)
# raise ProtocolError # raise ProtocolError
conn.mockCalledMethods = {} conn.mockCalledMethods = {}
def fake(c, p): raise ProtocolError def fake(c): raise ProtocolError
self.setFakeMethod(fake) self.setFakeMethod(fake)
self.handler.dispatch(conn, packet) self.handler.dispatch(conn, packet)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment