Commit ec6ddd6b authored by Grégory Wisniewski's avatar Grégory Wisniewski

Update client tests.


git-svn-id: https://svn.erp5.org/repos/neo/branches/prototype3@522 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent ed0cb88e
...@@ -19,6 +19,7 @@ import unittest ...@@ -19,6 +19,7 @@ import unittest
import logging import logging
from mock import Mock, ReturnValues, ReturnIterator from mock import Mock, ReturnValues, ReturnIterator
from ZODB.POSException import StorageTransactionError, UndoError, ConflictError from ZODB.POSException import StorageTransactionError, UndoError, ConflictError
from neo.tests.base import NeoTestBase
from neo.client.app import Application from neo.client.app import Application
from neo.client.exception import NEOStorageError, NEOStorageNotFoundError, \ from neo.client.exception import NEOStorageError, NEOStorageNotFoundError, \
NEOStorageConflictError NEOStorageConflictError
...@@ -75,7 +76,7 @@ class TestSocketConnector(object): ...@@ -75,7 +76,7 @@ class TestSocketConnector(object):
def send(self, msg): def send(self, msg):
raise NotImplementedError raise NotImplementedError
class ClientApplicationTest(unittest.TestCase): class ClientApplicationTest(NeoTestBase):
def setUp(self): def setUp(self):
logging.basicConfig(level = logging.WARNING) logging.basicConfig(level = logging.WARNING)
...@@ -90,13 +91,6 @@ class ClientApplicationTest(unittest.TestCase): ...@@ -90,13 +91,6 @@ class ClientApplicationTest(unittest.TestCase):
app.num_replicas = 2 app.num_replicas = 2
return app return app
def getUUID(self):
uuid = INVALID_UUID
while uuid == INVALID_UUID:
uuid = os.urandom(16)
self.uuid = uuid
return uuid
def makeOID(self, value=None): def makeOID(self, value=None):
from random import randint from random import randint
if value is None: if value is None:
...@@ -165,9 +159,6 @@ class ClientApplicationTest(unittest.TestCase): ...@@ -165,9 +159,6 @@ class ClientApplicationTest(unittest.TestCase):
self.assertTrue(isinstance(packet, Packet)) self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet._type, packet_type) self.assertEquals(packet._type, packet_type)
def checkAnswer(self, conn, packet_type):
self.checkPacketSent(conn, packet_type, 'answer')
def checkAsk(self, conn, packet_type): def checkAsk(self, conn, packet_type):
self.checkPacketSent(conn, packet_type, 'ask') self.checkPacketSent(conn, packet_type, 'ask')
...@@ -251,7 +242,7 @@ class ClientApplicationTest(unittest.TestCase): ...@@ -251,7 +242,7 @@ class ClientApplicationTest(unittest.TestCase):
app.cp = Mock({ 'getConnForNode' : conn}) app.cp = Mock({ 'getConnForNode' : conn})
app.local_var.asked_object = -1 app.local_var.asked_object = -1
self.assertRaises(NEOStorageNotFoundError, app.load, oid) self.assertRaises(NEOStorageNotFoundError, app.load, oid)
self.checkAsk(conn, ASK_OBJECT) self.checkAskObject(conn)
# object found on storage nodes and put in cache # object found on storage nodes and put in cache
packet = protocol.answerObject(*an_object[1:]) packet = protocol.answerObject(*an_object[1:])
conn = Mock({ conn = Mock({
...@@ -262,7 +253,7 @@ class ClientApplicationTest(unittest.TestCase): ...@@ -262,7 +253,7 @@ class ClientApplicationTest(unittest.TestCase):
app.local_var.asked_object = an_object app.local_var.asked_object = an_object
result = app.load(oid) result = app.load(oid)
self.assertEquals(result, ('', tid1)) self.assertEquals(result, ('', tid1))
self.checkAsk(conn, ASK_OBJECT) self.checkAskObject(conn)
self.assertTrue(oid in mq) self.assertTrue(oid in mq)
# object is now cached, try to reload it # object is now cached, try to reload it
conn = Mock({ conn = Mock({
...@@ -291,7 +282,7 @@ class ClientApplicationTest(unittest.TestCase): ...@@ -291,7 +282,7 @@ class ClientApplicationTest(unittest.TestCase):
app.cp = Mock({ 'getConnForNode' : conn}) app.cp = Mock({ 'getConnForNode' : conn})
app.local_var.asked_object = -1 app.local_var.asked_object = -1
self.assertRaises(NEOStorageNotFoundError, app.loadSerial, oid, tid2) self.assertRaises(NEOStorageNotFoundError, app.loadSerial, oid, tid2)
self.checkAsk(conn, ASK_OBJECT) self.checkAskObject(conn)
# object should not have been cached # object should not have been cached
self.assertFalse(oid in mq) self.assertFalse(oid in mq)
# now a cached version ewxists but should not be hit # now a cached version ewxists but should not be hit
...@@ -307,7 +298,7 @@ class ClientApplicationTest(unittest.TestCase): ...@@ -307,7 +298,7 @@ class ClientApplicationTest(unittest.TestCase):
app.local_var.asked_object = another_object app.local_var.asked_object = another_object
result = app.loadSerial(oid, tid1) result = app.loadSerial(oid, tid1)
self.assertEquals(result, 'RIGHT') self.assertEquals(result, 'RIGHT')
self.checkAsk(conn, ASK_OBJECT) self.checkAskObject(conn)
self.assertTrue(oid in mq) self.assertTrue(oid in mq)
def test_loadBefore(self): def test_loadBefore(self):
...@@ -328,7 +319,7 @@ class ClientApplicationTest(unittest.TestCase): ...@@ -328,7 +319,7 @@ class ClientApplicationTest(unittest.TestCase):
app.cp = Mock({ 'getConnForNode' : conn}) app.cp = Mock({ 'getConnForNode' : conn})
app.local_var.asked_object = -1 app.local_var.asked_object = -1
self.assertRaises(NEOStorageNotFoundError, app.loadBefore, oid, tid2) self.assertRaises(NEOStorageNotFoundError, app.loadBefore, oid, tid2)
self.checkAsk(conn, ASK_OBJECT) self.checkAskObject(conn)
# no previous versions -> return None # no previous versions -> return None
an_object = (1, oid, tid2, INVALID_SERIAL, 0, 0, '') an_object = (1, oid, tid2, INVALID_SERIAL, 0, 0, '')
packet = protocol.answerObject(*an_object[1:]) packet = protocol.answerObject(*an_object[1:])
...@@ -355,7 +346,7 @@ class ClientApplicationTest(unittest.TestCase): ...@@ -355,7 +346,7 @@ class ClientApplicationTest(unittest.TestCase):
app.local_var.asked_object = another_object app.local_var.asked_object = another_object
result = app.loadBefore(oid, tid1) result = app.loadBefore(oid, tid1)
self.assertEquals(result, ('RIGHT', tid1, tid2)) self.assertEquals(result, ('RIGHT', tid1, tid2))
self.checkAsk(conn, ASK_OBJECT) self.checkAskObject(conn)
self.assertTrue(oid in mq) self.assertTrue(oid in mq)
def test_tpc_begin(self): def test_tpc_begin(self):
...@@ -389,7 +380,7 @@ class ClientApplicationTest(unittest.TestCase): ...@@ -389,7 +380,7 @@ class ClientApplicationTest(unittest.TestCase):
app.dispatcher = Mock({ app.dispatcher = Mock({
}) })
app.tpc_begin(transaction=txn, tid=None) app.tpc_begin(transaction=txn, tid=None)
self.checkAsk(app.master_conn, ASK_NEW_TID) self.checkAskNewTid(app.master_conn)
self.checkDispatcherRegisterCalled(app, app.master_conn) self.checkDispatcherRegisterCalled(app, app.master_conn)
# check attributes # check attributes
self.assertTrue(app.local_var.txn is txn) self.assertTrue(app.local_var.txn is txn)
...@@ -441,7 +432,7 @@ class ClientApplicationTest(unittest.TestCase): ...@@ -441,7 +432,7 @@ class ClientApplicationTest(unittest.TestCase):
self.assertTrue(oid not in app.local_var.data_dict) self.assertTrue(oid not in app.local_var.data_dict)
self.assertEquals(app.conflict_serial, tid) self.assertEquals(app.conflict_serial, tid)
self.assertEquals(app.local_var.object_stored, (-1, tid)) self.assertEquals(app.local_var.object_stored, (-1, tid))
self.checkAsk(conn, ASK_STORE_OBJECT) self.checkAskStoreObject(conn)
self.checkDispatcherRegisterCalled(app, conn) self.checkDispatcherRegisterCalled(app, conn)
def test_store3(self): def test_store3(self):
...@@ -470,7 +461,7 @@ class ClientApplicationTest(unittest.TestCase): ...@@ -470,7 +461,7 @@ class ClientApplicationTest(unittest.TestCase):
self.assertEquals(app.local_var.object_stored, (oid, tid)) self.assertEquals(app.local_var.object_stored, (oid, tid))
self.assertEquals(app.local_var.data_dict.get(oid, None), 'DATA') self.assertEquals(app.local_var.data_dict.get(oid, None), 'DATA')
self.assertNotEquals(app.conflict_serial, tid) self.assertNotEquals(app.conflict_serial, tid)
self.checkAsk(conn, ASK_STORE_OBJECT) self.checkAskStoreObject(conn)
self.checkDispatcherRegisterCalled(app, conn) self.checkDispatcherRegisterCalled(app, conn)
def test_tpc_vote1(self): def test_tpc_vote1(self):
...@@ -534,7 +525,7 @@ class ClientApplicationTest(unittest.TestCase): ...@@ -534,7 +525,7 @@ class ClientApplicationTest(unittest.TestCase):
app.dispatcher = Mock() app.dispatcher = Mock()
app.tpc_begin(txn, tid) app.tpc_begin(txn, tid)
app.tpc_vote(txn) app.tpc_vote(txn)
self.checkAsk(conn, ASK_STORE_TRANSACTION) self.checkAskStoreTransaction(conn)
self.checkDispatcherRegisterCalled(app, conn) self.checkDispatcherRegisterCalled(app, conn)
def test_tpc_abort1(self): def test_tpc_abort1(self):
...@@ -627,7 +618,7 @@ class ClientApplicationTest(unittest.TestCase): ...@@ -627,7 +618,7 @@ class ClientApplicationTest(unittest.TestCase):
self.assertRaises(NEOStorageError, app.tpc_finish, txn, hook) self.assertRaises(NEOStorageError, app.tpc_finish, txn, hook)
self.assertTrue(self.f_called) self.assertTrue(self.f_called)
self.assertEquals(self.f_called_with_tid, tid) self.assertEquals(self.f_called_with_tid, tid)
self.checkAsk(app.master_conn, FINISH_TRANSACTION) self.checkFinishTransaction(app.master_conn)
self.checkDispatcherRegisterCalled(app, app.master_conn) self.checkDispatcherRegisterCalled(app, app.master_conn)
def test_tpc_finish3(self): def test_tpc_finish3(self):
...@@ -652,7 +643,7 @@ class ClientApplicationTest(unittest.TestCase): ...@@ -652,7 +643,7 @@ class ClientApplicationTest(unittest.TestCase):
app.tpc_finish(txn, hook) app.tpc_finish(txn, hook)
self.assertTrue(self.f_called) self.assertTrue(self.f_called)
self.assertEquals(self.f_called_with_tid, tid) self.assertEquals(self.f_called_with_tid, tid)
self.checkAsk(app.master_conn, FINISH_TRANSACTION) self.checkFinishTransaction(app.master_conn)
self.checkDispatcherRegisterCalled(app, app.master_conn) self.checkDispatcherRegisterCalled(app, app.master_conn)
self.assertEquals(app.local_var.tid, None) self.assertEquals(app.local_var.tid, None)
self.assertEquals(app.local_var.txn, None) self.assertEquals(app.local_var.txn, None)
...@@ -864,13 +855,13 @@ class ClientApplicationTest(unittest.TestCase): ...@@ -864,13 +855,13 @@ class ClientApplicationTest(unittest.TestCase):
self.test_ok = True self.test_ok = True
_waitMessage_old = Application._waitMessage _waitMessage_old = Application._waitMessage
Application._waitMessage = _waitMessage_hook Application._waitMessage = _waitMessage_hook
packet = protocol.askNewOIDs(10) packet = protocol.askNewTID()
try: try:
app._askStorage(conn, packet) app._askStorage(conn, packet)
finally: finally:
Application._waitMessage = _waitMessage_old Application._waitMessage = _waitMessage_old
# check packet sent, connection unlocked and dispatcher updated # check packet sent, connection unlocked and dispatcher updated
self.assertEquals(len(conn.mockGetNamedCalls('ask')), 1) self.checkAskNewTid(conn)
self.assertEquals(len(conn.mockGetNamedCalls('unlock')), 1) self.assertEquals(len(conn.mockGetNamedCalls('unlock')), 1)
self.assertEquals(len(app.dispatcher.mockGetNamedCalls('register')), 1) self.assertEquals(len(app.dispatcher.mockGetNamedCalls('register')), 1)
# and _waitMessage called # and _waitMessage called
...@@ -889,13 +880,13 @@ class ClientApplicationTest(unittest.TestCase): ...@@ -889,13 +880,13 @@ class ClientApplicationTest(unittest.TestCase):
self.test_ok = True self.test_ok = True
_waitMessage_old = Application._waitMessage _waitMessage_old = Application._waitMessage
Application._waitMessage = _waitMessage_hook Application._waitMessage = _waitMessage_hook
packet = protocol.askNewOIDs(10) packet = protocol.askNewTID()
try: try:
app._askPrimary(packet) app._askPrimary(packet)
finally: finally:
Application._waitMessage = _waitMessage_old Application._waitMessage = _waitMessage_old
# check packet sent, connection locked during process and dispatcher updated # check packet sent, connection locked during process and dispatcher updated
self.assertEquals(len(conn.mockGetNamedCalls('ask')), 1) self.checkAskNewTid(conn)
self.assertEquals(len(conn.mockGetNamedCalls('lock')), 1) self.assertEquals(len(conn.mockGetNamedCalls('lock')), 1)
self.assertEquals(len(conn.mockGetNamedCalls('unlock')), 1) self.assertEquals(len(conn.mockGetNamedCalls('unlock')), 1)
self.assertEquals(len(app.dispatcher.mockGetNamedCalls('register')), 1) self.assertEquals(len(app.dispatcher.mockGetNamedCalls('register')), 1)
......
This diff is collapsed.
...@@ -179,6 +179,21 @@ class NeoTestBase(unittest.TestCase): ...@@ -179,6 +179,21 @@ class NeoTestBase(unittest.TestCase):
def checkAskObjectPresent(self, conn, **kw): def checkAskObjectPresent(self, conn, **kw):
return self.checkAskPacket(conn, protocol.ASK_OBJECT_PRESENT, **kw) return self.checkAskPacket(conn, protocol.ASK_OBJECT_PRESENT, **kw)
def checkAskObject(self, conn, **kw):
return self.checkAskPacket(conn, protocol.ASK_OBJECT, **kw)
def checkAskStoreObject(self, conn, **kw):
return self.checkAskPacket(conn, protocol.ASK_STORE_OBJECT, **kw)
def checkAskStoreTransaction(self, conn, **kw):
return self.checkAskPacket(conn, protocol.ASK_STORE_TRANSACTION, **kw)
def checkFinishTransaction(self, conn, **kw):
return self.checkAskPacket(conn, protocol.FINISH_TRANSACTION, **kw)
def checkAskNewTid(self, conn, **kw):
return self.checkAskPacket(conn, protocol.ASK_NEW_TID, **kw)
def checkAcceptNodeIdentification(self, conn, **kw): def checkAcceptNodeIdentification(self, conn, **kw):
return self.checkAnswerPacket(conn, protocol.ACCEPT_NODE_IDENTIFICATION, **kw) return self.checkAnswerPacket(conn, protocol.ACCEPT_NODE_IDENTIFICATION, **kw)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment