Commit 63a70d44 authored by Kirill Smelkov's avatar Kirill Smelkov

Merge remote-tracking branch 'origin/master' into t

* origin/master:
  master,client: ignore notifications before complete initialization
  Update comment that was still showing UUIDs instead of node ids
  Remove dead code found by coverage
  Remove some useless unit tests
parents 83c151e9 36b2d141
...@@ -220,6 +220,7 @@ class Application(ThreadedApplication): ...@@ -220,6 +220,7 @@ class Application(ThreadedApplication):
ask = self._ask ask = self._ask
handler = self.primary_bootstrap_handler handler = self.primary_bootstrap_handler
while 1: while 1:
self.ignore_invalidations = True
# Get network connection to primary master # Get network connection to primary master
while 1: while 1:
if self.primary_master_node is not None: if self.primary_master_node is not None:
......
...@@ -120,6 +120,7 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -120,6 +120,7 @@ class PrimaryNotificationsHandler(MTEventHandler):
db = app.getDB() db = app.getDB()
db is None or db.invalidateCache() db is None or db.invalidateCache()
app.last_tid = ltid app.last_tid = ltid
app.ignore_invalidations = False
def answerTransactionFinished(self, conn, _, tid, callback, cache_dict): def answerTransactionFinished(self, conn, _, tid, callback, cache_dict):
app = self.app app = self.app
...@@ -159,6 +160,8 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -159,6 +160,8 @@ class PrimaryNotificationsHandler(MTEventHandler):
def invalidateObjects(self, conn, tid, oid_list): def invalidateObjects(self, conn, tid, oid_list):
app = self.app app = self.app
if app.ignore_invalidations:
return
app.last_tid = tid app.last_tid = tid
app._cache_lock_acquire() app._cache_lock_acquire()
try: try:
......
...@@ -281,8 +281,8 @@ class NodeManager(object): ...@@ -281,8 +281,8 @@ class NodeManager(object):
self._address_dict.pop(node.getAddress(), None) self._address_dict.pop(node.getAddress(), None)
# - a master known by address but without UUID # - a master known by address but without UUID
self._uuid_dict.pop(node.getUUID(), None) self._uuid_dict.pop(node.getUUID(), None)
self.__dropSet(self._state_dict, node.getState(), node) self._state_dict[node.getState()].remove(node)
self.__dropSet(self._type_dict, node.getType(), node) self._type_dict[node.getType()].remove(node)
uuid = node.getUUID() uuid = node.getUUID()
if node.isMaster() and self._master_db is not None: if node.isMaster() and self._master_db is not None:
self._master_db.discard(node.getAddress()) self._master_db.discard(node.getAddress())
...@@ -305,10 +305,6 @@ class NodeManager(object): ...@@ -305,10 +305,6 @@ class NodeManager(object):
def _updateUUID(self, node, old_uuid): def _updateUUID(self, node, old_uuid):
self.__update(self._uuid_dict, old_uuid, node.getUUID(), node) self.__update(self._uuid_dict, old_uuid, node.getUUID(), node)
def __dropSet(self, set_dict, key, node):
if key in set_dict:
set_dict[key].remove(node)
def __updateSet(self, set_dict, old_key, new_key, node): def __updateSet(self, set_dict, old_key, new_key, node):
""" Update a set index from old to new key """ """ Update a set index from old to new key """
if old_key in set_dict: if old_key in set_dict:
......
...@@ -243,10 +243,10 @@ class PartitionTable(object): ...@@ -243,10 +243,10 @@ class PartitionTable(object):
"""Help debugging partition table management. """Help debugging partition table management.
Output sample: Output sample:
pt: node 0: 67ae354b4ed240a0594d042cf5c01b28, R pt: node 0: S1, R
pt: node 1: a68a01e8bf93e287bd505201c1405bc2, R pt: node 1: S2, R
pt: node 2: ad7ffe8ceef4468a0c776f3035c7a543, R pt: node 2: S3, R
pt: node 3: df57d7298678996705cd0092d84580f4, R pt: node 3: S4, R
pt: 00: .UU.|U..U|.UU.|U..U|.UU.|U..U|.UU.|U..U|.UU.|U..U|.UU. pt: 00: .UU.|U..U|.UU.|U..U|.UU.|U..U|.UU.|U..U|.UU.|U..U|.UU.
pt: 11: U..U|.UU.|U..U|.UU.|U..U|.UU.|U..U|.UU.|U..U|.UU.|U..U pt: 11: U..U|.UU.|U..U|.UU.|U..U|.UU.|U..U|.UU.|U..U|.UU.|U..U
......
...@@ -114,6 +114,7 @@ class BackupApplication(object): ...@@ -114,6 +114,7 @@ class BackupApplication(object):
del bootstrap, node del bootstrap, node
if num_partitions != pt.getPartitions(): if num_partitions != pt.getPartitions():
raise RuntimeError("inconsistent number of partitions") raise RuntimeError("inconsistent number of partitions")
self.ignore_invalidations = True
self.pt = PartitionTable(num_partitions, num_replicas) self.pt = PartitionTable(num_partitions, num_replicas)
conn.setHandler(BackupHandler(self)) conn.setHandler(BackupHandler(self))
conn.ask(Packets.AskPartitionTable()) conn.ask(Packets.AskPartitionTable())
......
...@@ -29,7 +29,8 @@ class BackupHandler(EventHandler): ...@@ -29,7 +29,8 @@ class BackupHandler(EventHandler):
self.app.pt.load(ptid, row_list, self.app.nm) self.app.pt.load(ptid, row_list, self.app.nm)
def notifyPartitionChanges(self, conn, ptid, cell_list): def notifyPartitionChanges(self, conn, ptid, cell_list):
self.app.pt.update(ptid, cell_list, self.app.nm) if self.app.pt.filled():
self.app.pt.update(ptid, cell_list, self.app.nm)
# NOTE invalidation from M -> Mb (all partitions) # NOTE invalidation from M -> Mb (all partitions)
def answerLastTransaction(self, conn, tid): def answerLastTransaction(self, conn, tid):
...@@ -38,10 +39,13 @@ class BackupHandler(EventHandler): ...@@ -38,10 +39,13 @@ class BackupHandler(EventHandler):
app.invalidatePartitions(tid, set(xrange(app.pt.getPartitions()))) app.invalidatePartitions(tid, set(xrange(app.pt.getPartitions())))
else: # upstream DB is empty else: # upstream DB is empty
assert app.app.getLastTransaction() == tid assert app.app.getLastTransaction() == tid
app.ignore_invalidations = False
# NOTE invalidation from M -> Mb # NOTE invalidation from M -> Mb
def invalidateObjects(self, conn, tid, oid_list): def invalidateObjects(self, conn, tid, oid_list):
app = self.app app = self.app
if app.ignore_invalidations:
return
getPartition = app.app.pt.getPartition getPartition = app.app.pt.getPartition
partition_set = set(map(getPartition, oid_list)) partition_set = set(map(getPartition, oid_list))
partition_set.add(getPartition(tid)) partition_set.add(getPartition(tid))
......
...@@ -56,7 +56,6 @@ UNIT_TEST_MODULES = [ ...@@ -56,7 +56,6 @@ UNIT_TEST_MODULES = [
'neo.tests.master.testTransactions', 'neo.tests.master.testTransactions',
# storage application # storage application
'neo.tests.storage.testClientHandler', 'neo.tests.storage.testClientHandler',
'neo.tests.storage.testInitializationHandler',
'neo.tests.storage.testMasterHandler', 'neo.tests.storage.testMasterHandler',
'neo.tests.storage.testStorageApp', 'neo.tests.storage.testStorageApp',
'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'), 'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
......
...@@ -16,17 +16,14 @@ ...@@ -16,17 +16,14 @@
import threading import threading
import unittest import unittest
from cPickle import dumps
from mock import Mock, ReturnValues from mock import Mock, ReturnValues
from ZODB.POSException import StorageTransactionError, UndoError, ConflictError from ZODB.POSException import StorageTransactionError, UndoError, ConflictError
from .. import NeoUnitTestBase, buildUrlFromString from .. import NeoUnitTestBase, buildUrlFromString
from neo.client.app import Application from neo.client.app import Application
from neo.client.cache import test as testCache from neo.client.cache import test as testCache
from neo.client.exception import NEOStorageError, NEOStorageNotFoundError from neo.client.exception import NEOStorageError, NEOStorageNotFoundError
from neo.lib.protocol import NodeTypes, Packets, Errors, \ from neo.lib.protocol import NodeTypes, Packets, Errors, UUID_NAMESPACES
INVALID_PARTITION, UUID_NAMESPACES
from neo.lib.util import makeChecksum from neo.lib.util import makeChecksum
import time
class Dispatcher(object): class Dispatcher(object):
...@@ -60,9 +57,6 @@ def _ask(self, conn, packet, handler=None, **kw): ...@@ -60,9 +57,6 @@ def _ask(self, conn, packet, handler=None, **kw):
handler.dispatch(conn, conn.fakeReceived()) handler.dispatch(conn, conn.fakeReceived())
return self.getHandlerData() return self.getHandlerData()
def resolving_tryToResolveConflict(oid, conflict_serial, serial, data):
return data
def failing_tryToResolveConflict(oid, conflict_serial, serial, data): def failing_tryToResolveConflict(oid, conflict_serial, serial, data):
raise ConflictError raise ConflictError
...@@ -88,10 +82,8 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -88,10 +82,8 @@ class ClientApplicationTests(NeoUnitTestBase):
# some helpers # some helpers
def _begin(self, app, txn, tid=None): def _begin(self, app, txn, tid):
txn_context = app._txn_container.new(txn) txn_context = app._txn_container.new(txn)
if tid is None:
tid = self.makeTID()
txn_context['ttid'] = tid txn_context['ttid'] = tid
return txn_context return txn_context
...@@ -160,28 +152,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -160,28 +152,6 @@ class ClientApplicationTests(NeoUnitTestBase):
testCache = testCache testCache = testCache
def test_registerDB(self):
app = self.getApp()
dummy_db = []
app.registerDB(dummy_db, None)
self.assertTrue(app.getDB() is dummy_db)
def test_new_oid(self):
app = self.getApp()
test_msg_id = 50
test_oid_list = ['\x00\x00\x00\x00\x00\x00\x00\x01', '\x00\x00\x00\x00\x00\x00\x00\x02']
response_packet = Packets.AnswerNewOIDs(test_oid_list[:])
response_packet.setId(0)
app.master_conn = Mock({'getNextId': test_msg_id, '_addPacket': None,
'expectMessage': None,
# Test-specific method
'fakeReceived': response_packet})
new_oid = app.new_oid()
self.assertTrue(new_oid in test_oid_list)
self.assertEqual(len(app.new_oid_list), 1)
self.assertTrue(app.new_oid_list[0] in test_oid_list)
self.assertNotEqual(app.new_oid_list[0], new_oid)
def test_load(self): def test_load(self):
app = self.getApp() app = self.getApp()
cache = app._cache cache = app._cache
...@@ -340,84 +310,13 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -340,84 +310,13 @@ class ClientApplicationTests(NeoUnitTestBase):
app.store(oid, tid, 'DATA', None, txn) app.store(oid, tid, 'DATA', None, txn)
self.checkAskStoreObject(conn) self.checkAskStoreObject(conn)
txn_context['queue'].put((conn, packet, {})) txn_context['queue'].put((conn, packet, {}))
app.waitStoreResponses(txn_context, resolving_tryToResolveConflict) app.waitStoreResponses(txn_context, None) # no conflict in this test
self.assertEqual(txn_context['object_stored_counter_dict'][oid], self.assertEqual(txn_context['object_stored_counter_dict'][oid],
{tid: {uuid}}) {tid: {uuid}})
self.assertEqual(txn_context['cache_dict'][oid], 'DATA') self.assertEqual(txn_context['cache_dict'][oid], 'DATA')
self.assertFalse(oid in txn_context['data_dict']) self.assertFalse(oid in txn_context['data_dict'])
self.assertFalse(oid in txn_context['conflict_serial_dict']) self.assertFalse(oid in txn_context['conflict_serial_dict'])
def test_tpc_vote1(self):
app = self.getApp()
txn = self.makeTransactionObject()
# invalid transaction > StorageTransactionError
self.assertRaises(StorageTransactionError, app.tpc_vote, txn,
resolving_tryToResolveConflict)
def test_tpc_vote3(self):
app = self.getApp()
tid = self.makeTID()
txn = self.makeTransactionObject()
self._begin(app, txn, tid)
# response -> OK
packet = Packets.AnswerStoreTransaction(tid=tid)
packet.setId(0)
conn = Mock({
'getNextId': 1,
'fakeReceived': packet,
})
node = Mock({
'__hash__': 1,
'__repr__': 'FakeNode',
})
app.cp = self.getConnectionPool([(node, conn)])
app.tpc_vote(txn, resolving_tryToResolveConflict)
self.checkAskStoreTransaction(conn)
self.checkDispatcherRegisterCalled(app, conn)
def test_tpc_abort1(self):
# ignore mismatch transaction
app = self.getApp()
tid = self.makeTID()
txn = self.makeTransactionObject()
old_txn = object()
self._begin(app, old_txn, tid)
app.master_conn = Mock()
conn = Mock()
cell = Mock()
app.cp = Mock({'getConnForCell': ReturnValues(None, cell)})
app.tpc_abort(txn)
# no packet sent
self.checkNoPacketSent(conn)
self.checkNoPacketSent(app.master_conn)
txn_context = app._txn_container.get(old_txn)
self.assertTrue(txn_context['txn'] is old_txn)
self.assertEqual(txn_context['ttid'], tid)
def test_tpc_abort2(self):
# 2 nodes : 1 transaction in the first, 2 objects in the second
# connections to each node should received only one packet to abort
# and transaction must also be aborted on the master node
# for simplicity, just one cell per partition
oid1, oid2 = self.makeOID(2), self.makeOID(4) # on partition 0
app, tid = self.getApp(), self.makeTID(1) # on partition 1
txn = self.makeTransactionObject()
txn_context = self._begin(app, txn, tid)
app.master_conn = Mock({'__hash__': 0})
app.num_partitions = 2
cell1 = Mock({ 'getNode': 'NODE1', '__hash__': 1 })
cell2 = Mock({ 'getNode': 'NODE2', '__hash__': 2 })
conn1, conn2 = Mock({ 'getNextId': 1, }), Mock({ 'getNextId': 2, })
app.cp = Mock({ 'getConnForNode': ReturnValues(conn1, conn2), })
# fake data
txn_context['involved_nodes'].update([cell1, cell2])
app.tpc_abort(txn)
# will check if there was just one call/packet :
self.checkNotifyPacket(conn1, Packets.AbortTransaction)
self.checkNotifyPacket(conn2, Packets.AbortTransaction)
self.checkNotifyPacket(app.master_conn, Packets.AbortTransaction)
self.assertRaises(StorageTransactionError, app._txn_container.get, txn)
def test_tpc_abort3(self): def test_tpc_abort3(self):
""" check that abort is sent to all nodes involved in the transaction """ """ check that abort is sent to all nodes involved in the transaction """
app = self.getApp() app = self.getApp()
...@@ -471,37 +370,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -471,37 +370,6 @@ class ClientApplicationTests(NeoUnitTestBase):
self.checkAbortTransaction(conn2) self.checkAbortTransaction(conn2)
self.checkAbortTransaction(conn3) self.checkAbortTransaction(conn3)
def test_tpc_finish1(self):
# transaction mismatch: raise
app = self.getApp()
txn = self.makeTransactionObject()
app.master_conn = Mock()
self.assertRaises(StorageTransactionError, app.tpc_finish, txn, None)
# no packet sent
self.checkNoPacketSent(app.master_conn)
def test_tpc_finish3(self):
# transaction is finished
app = self.getApp()
tid = self.makeTID()
ttid = self.makeTID()
txn = self.makeTransactionObject()
txn_context = self._begin(app, txn, tid)
self.f_called = False
self.f_called_with_tid = None
packet = Packets.AnswerTransactionFinished(ttid, tid)
packet.setId(0)
app.master_conn = Mock({
'getNextId': 1,
'getAddress': ('127.0.0.1', 10010),
'fakeReceived': packet,
})
txn_context['voted'] = None
app.tpc_finish(txn, None)
self.checkAskFinishTransaction(app.master_conn)
#self.checkDispatcherRegisterCalled(app, app.master_conn)
self.assertRaises(StorageTransactionError, app._txn_container.get, txn)
def test_undo1(self): def test_undo1(self):
# invalid transaction # invalid transaction
app = self.getApp() app = self.getApp()
...@@ -668,13 +536,9 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -668,13 +536,9 @@ class ClientApplicationTests(NeoUnitTestBase):
conn.ask = lambda p, queue=None, **kw: \ conn.ask = lambda p, queue=None, **kw: \
type(p) is Packets.AskObjectUndoSerial and \ type(p) is Packets.AskObjectUndoSerial and \
queue.put((conn, undo_serial, kw)) queue.put((conn, undo_serial, kw))
def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''):
raise Exception, 'Test called conflict resolution, but there ' \
'is no conflict in this test !'
# The undo # The undo
txn = self.beginTransaction(app, tid=tid3) txn = self.beginTransaction(app, tid=tid3)
app.undo(tid1, txn, tryToResolveConflict) app.undo(tid1, txn, None) # no conflict resolution in this test
# Checking what happened # Checking what happened
moid, mserial, mdata, mdata_serial = store_marker[0] moid, mserial, mdata, mdata_serial = store_marker[0]
self.assertEqual(moid, oid0) self.assertEqual(moid, oid0)
...@@ -682,67 +546,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -682,67 +546,6 @@ class ClientApplicationTests(NeoUnitTestBase):
self.assertEqual(mdata, None) self.assertEqual(mdata, None)
self.assertEqual(mdata_serial, tid0) self.assertEqual(mdata_serial, tid0)
def test_undoLog(self):
app = self.getApp()
app.num_partitions = 2
uuid1, uuid2 = self.getStorageUUID(), self.getStorageUUID()
# two nodes, two partition, two transaction, two objects :
tid1, tid2 = self.makeTID(1), self.makeTID(2)
oid1, oid2 = self.makeOID(1), self.makeOID(2)
# TIDs packets supplied by _ask hook
# TXN info packets
extension = dumps({})
p1 = Packets.AnswerTIDs([tid1])
p2 = Packets.AnswerTIDs([tid2])
p3 = Packets.AnswerTransactionInformation(tid1, '', '',
extension, False, (oid1, ))
p4 = Packets.AnswerTransactionInformation(tid2, '', '',
extension, False, (oid2, ))
p1.setId(0)
p2.setId(1)
p3.setId(2)
p4.setId(3)
conn = Mock({
'getNextId': 1,
'getUUID': ReturnValues(uuid1, uuid2),
'fakeGetApp': app,
'fakeReceived': ReturnValues(p3, p4),
'getAddress': ('127.0.0.1', 10021),
})
asked = []
def answerTIDs(packet):
conn = getConnection({'getAddress': packet})
app.nm.createStorage(address=conn.getAddress())
def ask(p, queue, **kw):
asked.append(p)
queue.put((conn, packet, kw))
conn.ask = ask
return conn
app.dispatcher = Dispatcher()
app.pt = Mock({
'getNodeSet': (Mock(), Mock()),
})
app.cp = Mock({
'getConnForNode': ReturnValues(answerTIDs(p1), answerTIDs(p2)),
'iterateForObject': [(Mock(), conn)]
})
def txn_filter(info):
return info['id'] > '\x00' * 8
first = 0
last = 4
result = app.undoLog(first, last, filter=txn_filter)
pfirst, plast, ppartition = asked.pop().decode()
self.assertEqual(pfirst, first)
self.assertEqual(plast, last)
self.assertEqual(ppartition, INVALID_PARTITION)
pfirst, plast, ppartition = asked.pop().decode()
self.assertEqual(pfirst, first)
self.assertEqual(plast, last)
self.assertEqual(ppartition, INVALID_PARTITION)
self.assertEqual(result[0]['id'], tid1)
self.assertEqual(result[1]['id'], tid2)
self.assertFalse(asked)
def test_connectToPrimaryNode(self): def test_connectToPrimaryNode(self):
# here we have three master nodes : # here we have three master nodes :
# the connection to the first will fail # the connection to the first will fail
...@@ -800,65 +603,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -800,65 +603,6 @@ class ClientApplicationTests(NeoUnitTestBase):
self.assertTrue(app.master_conn is not None) self.assertTrue(app.master_conn is not None)
self.assertTrue(app.pt.operational()) self.assertTrue(app.pt.operational())
def test_askPrimary(self):
""" _askPrimary is private but test it anyway """
app = self.getApp()
conn = Mock()
app.master_conn = conn
app.primary_handler = Mock()
self.test_ok = False
def _ask_hook(app, conn, packet, handler=None):
conn.ask(packet)
self.assertTrue(handler is app.primary_handler)
self.test_ok = True
_ask_old = Application._ask
Application._ask = _ask_hook
packet = Packets.AskBeginTransaction()
packet.setId(0)
try:
app._askPrimary(packet)
finally:
Application._ask = _ask_old
# check packet sent, connection locked during process and dispatcher updated
self.checkAskNewTid(conn)
self.checkDispatcherRegisterCalled(app, conn)
# and _ask called
self.assertTrue(self.test_ok)
# check NEOStorageError is raised when the primary connection is lost
app.master_conn = None
# check disabled since we reconnect to pmn
#self.assertRaises(NEOStorageError, app._askPrimary, packet)
def test_threadContextIsolation(self):
""" Thread context properties must not be visible across instances
while remaining in the same thread """
app1 = self.getApp()
app1_local = app1._thread_container
app2 = self.getApp()
app2_local = app2._thread_container
property_id = 'thread_context_test'
value = 'value'
self.assertFalse(hasattr(app1_local, property_id))
self.assertFalse(hasattr(app2_local, property_id))
setattr(app1_local, property_id, value)
self.assertEqual(getattr(app1_local, property_id), value)
self.assertFalse(hasattr(app2_local, property_id))
def test_pack(self):
app = self.getApp()
marker = []
def askPrimary(packet):
marker.append(packet)
app._askPrimary = askPrimary
# XXX: could not identify a value causing TimeStamp to return ZERO_TID
#self.assertRaises(NEOStorageError, app.pack, )
self.assertEqual(len(marker), 0)
now = time.time()
app.pack(now)
self.assertEqual(len(marker), 1)
self.assertEqual(type(marker[0]), Packets.AskPack)
# XXX: how to validate packet content ?
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import time, unittest import time, unittest
from mock import Mock, ReturnValues from mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.client.app import ConnectionPool from neo.client.app import ConnectionPool
...@@ -24,27 +24,6 @@ from neo.client import pool ...@@ -24,27 +24,6 @@ from neo.client import pool
class ConnectionPoolTests(NeoUnitTestBase): class ConnectionPoolTests(NeoUnitTestBase):
def test_removeConnection(self):
app = None
pool = ConnectionPool(app)
test_node_uuid = self.getStorageUUID()
other_node_uuid = self.getStorageUUID()
test_node = Mock({'getUUID': test_node_uuid})
other_node = Mock({'getUUID': other_node_uuid})
# Test sanity check
self.assertEqual(getattr(pool, 'connection_dict', None), {})
# Call must not raise if node is not known
self.assertEqual(len(pool.connection_dict), 0)
pool.removeConnection(test_node)
# Test that removal with another uuid doesn't affect entry
pool.connection_dict[test_node_uuid] = None
self.assertEqual(len(pool.connection_dict), 1)
pool.removeConnection(other_node)
self.assertEqual(len(pool.connection_dict), 1)
# Test that removeConnection works
pool.removeConnection(test_node)
self.assertEqual(len(pool.connection_dict), 0)
# TODO: test getConnForNode (requires splitting complex functionalities) # TODO: test getConnForNode (requires splitting complex functionalities)
def test_CellSortKey(self): def test_CellSortKey(self):
...@@ -81,30 +60,6 @@ class ConnectionPoolTests(NeoUnitTestBase): ...@@ -81,30 +60,6 @@ class ConnectionPoolTests(NeoUnitTestBase):
pool = ConnectionPool(app) pool = ConnectionPool(app)
self.assertRaises(NEOStorageError, pool.iterateForObject(oid).next) self.assertRaises(NEOStorageError, pool.iterateForObject(oid).next)
def test_iterateForObject_connectionRefused(self):
# connection refused at the first try
oid = self.getOID(1)
node = Mock({'__repr__': 'node', 'isRunning': True})
cell = Mock({'__repr__': 'cell', 'getNode': node})
conn = Mock({'__repr__': 'conn'})
app = Mock()
app.pt = Mock({'getCellList': [cell]})
pool = ConnectionPool(app)
pool.getConnForNode = Mock({'__call__': ReturnValues(None, conn)})
self.assertEqual(list(pool.iterateForObject(oid)), [(node, conn)])
def test_iterateForObject_connectionAccepted(self):
# connection accepted
oid = self.getOID(1)
node = Mock({'__repr__': 'node', 'isRunning': True})
cell = Mock({'__repr__': 'cell', 'getNode': node})
conn = Mock({'__repr__': 'conn'})
app = Mock()
app.pt = Mock({'getCellList': [cell]})
pool = ConnectionPool(app)
pool.getConnForNode = Mock({'__call__': conn})
self.assertEqual(list(pool.iterateForObject(oid)), [(node, conn)])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -17,102 +17,16 @@ ...@@ -17,102 +17,16 @@
import unittest import unittest
from mock import Mock from mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.node import NodeManager from neo.client.handlers.master import PrimaryAnswersHandler
from neo.client.handlers.master import PrimaryNotificationsHandler, \
PrimaryAnswersHandler
from neo.client.exception import NEOStorageError from neo.client.exception import NEOStorageError
class MasterHandlerTests(NeoUnitTestBase): class MasterHandlerTests(NeoUnitTestBase):
def setUp(self):
super(MasterHandlerTests, self).setUp()
self.db = Mock()
self.app = Mock({'getDB': self.db,
'txn_contexts': ()})
self.app.nm = NodeManager()
self.app.dispatcher = Mock()
self._next_port = 3000
def getKnownMaster(self):
node = self.app.nm.createMaster(address=(
self.local_ip, self._next_port),
)
self._next_port += 1
conn = self.getFakeConnection(address=node.getAddress())
node.setConnection(conn)
return node, conn
class MasterNotificationsHandlerTests(MasterHandlerTests):
def setUp(self):
super(MasterNotificationsHandlerTests, self).setUp()
self.handler = PrimaryNotificationsHandler(self.app)
def test_connectionClosed(self):
conn = self.getFakeConnection()
node = Mock()
self.app.master_conn = conn
self.app.primary_master_node = node
self.handler.connectionClosed(conn)
self.assertEqual(self.app.master_conn, None)
self.assertEqual(self.app.primary_master_node, None)
def test_invalidateObjects(self):
conn = self.getFakeConnection()
tid = self.getNextTID()
oid1, oid2, oid3 = self.getOID(1), self.getOID(2), self.getOID(3)
self.app._cache = Mock({
'invalidate': None,
})
self.handler.invalidateObjects(conn, tid, [oid1, oid3])
cache_calls = self.app._cache.mockGetNamedCalls('invalidate')
self.assertEqual(len(cache_calls), 2)
cache_calls[0].checkArgs(oid1, tid)
cache_calls[1].checkArgs(oid3, tid)
invalidation_calls = self.db.mockGetNamedCalls('invalidate')
self.assertEqual(len(invalidation_calls), 1)
invalidation_calls[0].checkArgs(tid, [oid1, oid3])
def test_notifyPartitionChanges(self):
conn = self.getFakeConnection()
self.app.pt = Mock({'filled': True})
ptid = 0
cell_list = (Mock(), Mock())
self.handler.notifyPartitionChanges(conn, ptid, cell_list)
update_calls = self.app.pt.mockGetNamedCalls('update')
self.assertEqual(len(update_calls), 1)
update_calls[0].checkArgs(ptid, cell_list, self.app.nm)
class MasterAnswersHandlerTests(MasterHandlerTests):
def setUp(self): def setUp(self):
super(MasterAnswersHandlerTests, self).setUp() super(MasterHandlerTests, self).setUp()
self.app = Mock()
self.handler = PrimaryAnswersHandler(self.app) self.handler = PrimaryAnswersHandler(self.app)
def test_answerBeginTransaction(self):
tid = self.getNextTID()
conn = self.getFakeConnection()
self.handler.answerBeginTransaction(conn, tid)
calls = self.app.mockGetNamedCalls('setHandlerData')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid)
def test_answerNewOIDs(self):
conn = self.getFakeConnection()
oid1, oid2, oid3 = self.getOID(0), self.getOID(1), self.getOID(2)
self.handler.answerNewOIDs(conn, [oid1, oid2, oid3])
self.assertEqual(self.app.new_oid_list, [oid3, oid2, oid1])
def test_answerTransactionFinished(self):
conn = self.getFakeConnection()
ttid2 = self.getNextTID()
tid2 = self.getNextTID()
self.handler.answerTransactionFinished(conn, ttid2, tid2)
calls = self.app.mockGetNamedCalls('setHandlerData')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid2)
def test_answerPack(self): def test_answerPack(self):
self.assertRaises(NEOStorageError, self.handler.answerPack, None, False) self.assertRaises(NEOStorageError, self.handler.answerPack, None, False)
# Check it doesn't raise # Check it doesn't raise
......
...@@ -19,8 +19,6 @@ from mock import Mock ...@@ -19,8 +19,6 @@ from mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.client.handlers.storage import StorageAnswersHandler from neo.client.handlers.storage import StorageAnswersHandler
from neo.client.exception import NEOStorageError, NEOStorageNotFoundError from neo.client.exception import NEOStorageError, NEOStorageNotFoundError
from neo.client.exception import NEOStorageDoesNotExistError
from ZODB.TimeStamp import TimeStamp
class StorageAnswerHandlerTests(NeoUnitTestBase): class StorageAnswerHandlerTests(NeoUnitTestBase):
...@@ -29,20 +27,6 @@ class StorageAnswerHandlerTests(NeoUnitTestBase): ...@@ -29,20 +27,6 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
self.app = Mock() self.app = Mock()
self.handler = StorageAnswersHandler(self.app) self.handler = StorageAnswersHandler(self.app)
def _checkHandlerData(self, ref):
calls = self.app.mockGetNamedCalls('setHandlerData')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(ref)
def test_answerObject(self):
conn = self.getFakeConnection()
oid = self.getOID(0)
tid1 = self.getNextTID()
tid2 = self.getNextTID(tid1)
the_object = (oid, tid1, tid2, 0, '', 'DATA', None)
self.handler.answerObject(conn, *the_object)
self._checkHandlerData(the_object[1:])
def _getAnswerStoreObjectHandler(self, object_stored_counter_dict, def _getAnswerStoreObjectHandler(self, object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict): conflict_serial_dict, resolved_conflict_serial_dict):
app = Mock({ app = Mock({
...@@ -119,86 +103,11 @@ class StorageAnswerHandlerTests(NeoUnitTestBase): ...@@ -119,86 +103,11 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
self._getAnswerStoreObjectHandler({oid: {tid: 1}}, {}, self._getAnswerStoreObjectHandler({oid: {tid: 1}}, {},
{oid: {tid}}).answerStoreObject(conn, 1, oid, tid_2) {oid: {tid}}).answerStoreObject(conn, 1, oid, tid_2)
def test_answerStoreObject_4(self):
uuid = self.getStorageUUID()
conn = self.getFakeConnection(uuid=uuid)
oid = self.getOID(0)
tid = self.getNextTID()
# no conflict
object_stored_counter_dict = {oid: {}}
conflict_serial_dict = {}
resolved_conflict_serial_dict = {}
h = self._getAnswerStoreObjectHandler(object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict)
h.app.getHandlerData()['cache_dict'] = {oid: None}
h.answerStoreObject(conn, 0, oid, tid)
self.assertFalse(oid in conflict_serial_dict)
self.assertFalse(oid in resolved_conflict_serial_dict)
self.assertEqual(object_stored_counter_dict[oid], {tid: {uuid}})
def test_answerTransactionInformation(self):
conn = self.getFakeConnection()
tid = self.getNextTID()
user = 'USER'
desc = 'DESC'
ext = 'EXT'
packed = False
oid_list = [self.getOID(0), self.getOID(1)]
self.handler.answerTransactionInformation(conn, tid, user, desc, ext,
packed, oid_list)
self._checkHandlerData(({
'time': TimeStamp(tid).timeTime(),
'user_name': user,
'description': desc,
'id': tid,
'oids': oid_list,
'packed': packed,
}, ext))
def test_oidNotFound(self):
conn = self.getFakeConnection()
self.assertRaises(NEOStorageNotFoundError, self.handler.oidNotFound,
conn, 'message')
def test_oidDoesNotExist(self):
conn = self.getFakeConnection()
self.assertRaises(NEOStorageDoesNotExistError,
self.handler.oidDoesNotExist, conn, 'message')
def test_tidNotFound(self): def test_tidNotFound(self):
conn = self.getFakeConnection() conn = self.getFakeConnection()
self.assertRaises(NEOStorageNotFoundError, self.handler.tidNotFound, self.assertRaises(NEOStorageNotFoundError, self.handler.tidNotFound,
conn, 'message') conn, 'message')
def test_answerTIDs(self):
uuid = self.getStorageUUID()
tid1 = self.getNextTID()
tid2 = self.getNextTID(tid1)
tid_list = [tid1, tid2]
conn = self.getFakeConnection(uuid=uuid)
tid_set = set()
StorageAnswersHandler(Mock()).answerTIDs(conn, tid_list, tid_set)
self.assertEqual(tid_set, set(tid_list))
def test_answerObjectUndoSerial(self):
uuid = self.getStorageUUID()
conn = self.getFakeConnection(uuid=uuid)
oid1 = self.getOID(1)
oid2 = self.getOID(2)
tid0 = self.getNextTID()
tid1 = self.getNextTID()
tid2 = self.getNextTID()
tid3 = self.getNextTID()
undo_dict = {}
handler = StorageAnswersHandler(Mock())
handler.answerObjectUndoSerial(conn, {oid1: [tid0, tid1]}, undo_dict)
self.assertEqual(undo_dict, {oid1: [tid0, tid1]})
handler.answerObjectUndoSerial(conn, {oid2: [tid2, tid3]}, undo_dict)
self.assertEqual(undo_dict, {
oid1: [tid0, tid1],
oid2: [tid2, tid3],
})
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -258,23 +258,12 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -258,23 +258,12 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
self.assertEqual(node.getUUID(), new_uuid) self.assertEqual(node.getUUID(), new_uuid)
self.assertNotEqual(node.getUUID(), uuid) self.assertNotEqual(node.getUUID(), uuid)
def _getNodeList(self):
return [x.asTuple() for x in self.app.nm.getList()]
def __getClient(self): def __getClient(self):
uuid = self.getClientUUID() uuid = self.getClientUUID()
conn = self.getFakeConnection(uuid=uuid, address=self.client_address) conn = self.getFakeConnection(uuid=uuid, address=self.client_address)
self.app.nm.createClient(uuid=uuid, address=self.client_address) self.app.nm.createClient(uuid=uuid, address=self.client_address)
return conn return conn
def __getMaster(self, port=1000, register=True):
uuid = self.getMasterUUID()
address = ('127.0.0.1', port)
conn = self.getFakeConnection(uuid=uuid, address=address)
if register:
self.app.nm.createMaster(uuid=uuid, address=address)
return conn
def testRequestIdentification1(self): def testRequestIdentification1(self):
""" Check with a non-master node, must be refused """ """ Check with a non-master node, must be refused """
conn = self.__getClient() conn = self.__getClient()
......
...@@ -92,25 +92,6 @@ class MasterAppTests(NeoUnitTestBase): ...@@ -92,25 +92,6 @@ class MasterAppTests(NeoUnitTestBase):
self.checkNoPacketSent(master_conn) self.checkNoPacketSent(master_conn)
self.checkNotifyNodeInformation(storage_conn) self.checkNotifyNodeInformation(storage_conn)
def test_storageReadinessAPI(self):
uuid_1 = self.getStorageUUID()
uuid_2 = self.getStorageUUID()
self.assertFalse(self.app.isStorageReady(uuid_1))
self.assertFalse(self.app.isStorageReady(uuid_2))
# Must not raise, nor change readiness
self.app.setStorageNotReady(uuid_1)
self.assertFalse(self.app.isStorageReady(uuid_1))
self.assertFalse(self.app.isStorageReady(uuid_2))
# Mark as ready, only one must change
self.app.setStorageReady(uuid_1)
self.assertTrue(self.app.isStorageReady(uuid_1))
self.assertFalse(self.app.isStorageReady(uuid_2))
self.app.setStorageReady(uuid_2)
# Mark not ready, only one must change
self.app.setStorageNotReady(uuid_1)
self.assertFalse(self.app.isStorageReady(uuid_1))
self.assertTrue(self.app.isStorageReady(uuid_2))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -36,11 +36,8 @@ class MasterRecoveryTests(NeoUnitTestBase): ...@@ -36,11 +36,8 @@ class MasterRecoveryTests(NeoUnitTestBase):
node.setState(NodeStates.RUNNING) node.setState(NodeStates.RUNNING)
# define some variable to simulate client and storage node # define some variable to simulate client and storage node
self.client_port = 11022
self.storage_port = 10021 self.storage_port = 10021
self.master_port = 10011 self.master_port = 10011
self.master_address = ('127.0.0.1', self.master_port)
self.storage_address = ('127.0.0.1', self.storage_port)
def _tearDown(self, success): def _tearDown(self, success):
self.app.close() self.app.close()
...@@ -58,16 +55,11 @@ class MasterRecoveryTests(NeoUnitTestBase): ...@@ -58,16 +55,11 @@ class MasterRecoveryTests(NeoUnitTestBase):
return uuid return uuid
# Tests # Tests
def test_01_connectionClosed(self):
uuid = self.identifyToMasterNode(node_type=NodeTypes.MASTER, port=self.master_port)
conn = self.getFakeConnection(uuid, self.master_address)
self.assertEqual(self.app.nm.getByAddress(conn.getAddress()).getState(),
NodeStates.RUNNING)
self.recovery.connectionClosed(conn)
self.assertEqual(self.app.nm.getByAddress(conn.getAddress()).getState(),
NodeStates.TEMPORARILY_DOWN)
def test_10_answerPartitionTable(self): def test_10_answerPartitionTable(self):
# XXX: This test does much less that it seems, because all 'for' loops
# iterate over empty lists. Currently, only testRecovery covers
# some paths in NodeManager._createNode: apart from that, we could
# delete it entirely.
recovery = self.recovery recovery = self.recovery
uuid = self.identifyToMasterNode(NodeTypes.MASTER, port=self.master_port) uuid = self.identifyToMasterNode(NodeTypes.MASTER, port=self.master_port)
# not from target node, ignore # not from target node, ignore
......
...@@ -17,11 +17,9 @@ ...@@ -17,11 +17,9 @@
import unittest import unittest
from mock import Mock from mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes, NodeStates, Packets from neo.lib.protocol import NodeTypes, Packets
from neo.master.handlers.storage import StorageServiceHandler from neo.master.handlers.storage import StorageServiceHandler
from neo.master.handlers.client import ClientServiceHandler
from neo.master.app import Application from neo.master.app import Application
from neo.lib.exception import StoppedOperation
class MasterStorageHandlerTests(NeoUnitTestBase): class MasterStorageHandlerTests(NeoUnitTestBase):
...@@ -34,23 +32,11 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -34,23 +32,11 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
self.app.pt.clear() self.app.pt.clear()
self.app.em = Mock() self.app.em = Mock()
self.service = StorageServiceHandler(self.app) self.service = StorageServiceHandler(self.app)
self.client_handler = ClientServiceHandler(self.app)
# define some variable to simulate client and storage node
self.client_port = 11022
self.storage_port = 10021
self.master_port = 10010
self.master_address = ('127.0.0.1', self.master_port)
self.client_address = ('127.0.0.1', self.client_port)
self.storage_address = ('127.0.0.1', self.storage_port)
def _allocatePort(self): def _allocatePort(self):
self.port = getattr(self, 'port', 1000) + 1 self.port = getattr(self, 'port', 1000) + 1
return self.port return self.port
def _getClient(self):
return self.identifyToMasterNode(node_type=NodeTypes.CLIENT,
ip='127.0.0.1', port=self._allocatePort())
def _getStorage(self): def _getStorage(self):
return self.identifyToMasterNode(node_type=NodeTypes.STORAGE, return self.identifyToMasterNode(node_type=NodeTypes.STORAGE,
ip='127.0.0.1', port=self._allocatePort()) ip='127.0.0.1', port=self._allocatePort())
...@@ -67,98 +53,6 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -67,98 +53,6 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
node.setConnection(conn) node.setConnection(conn)
return (node, conn) return (node, conn)
def test_answerInformationLocked_1(self):
"""
Master must refuse to lock if the TID is greater than the last TID
"""
tid1 = self.getNextTID()
tid2 = self.getNextTID(tid1)
self.app.tm.setLastTID(tid1)
self.assertTrue(tid1 < tid2)
node, conn = self.identifyToMasterNode()
self.checkProtocolErrorRaised(self.service.answerInformationLocked,
conn, tid2)
self.checkNoPacketSent(conn)
def test_answerInformationLocked_2(self):
"""
Master must:
- lock each storage
- notify the client
- invalidate other clients
- unlock storages
"""
# one client and two storages required
client_1, client_conn_1 = self._getClient()
client_2, client_conn_2 = self._getClient()
storage_1, storage_conn_1 = self._getStorage()
storage_2, storage_conn_2 = self._getStorage()
uuid_list = storage_1.getUUID(), storage_2.getUUID()
oid_list = self.getOID(), self.getOID()
msg_id = 1
# register a transaction
ttid = self.app.tm.begin(client_1)
tid = self.app.tm.prepare(ttid, 1, oid_list, uuid_list,
msg_id)
self.assertTrue(ttid in self.app.tm)
# the first storage acknowledge the lock
self.service.answerInformationLocked(storage_conn_1, ttid)
self.checkNoPacketSent(client_conn_1)
self.checkNoPacketSent(client_conn_2)
self.checkNoPacketSent(storage_conn_1)
self.checkNoPacketSent(storage_conn_2)
# then the second
self.service.answerInformationLocked(storage_conn_2, ttid)
self.checkAnswerTransactionFinished(client_conn_1)
self.checkInvalidateObjects(client_conn_2)
self.checkNotifyUnlockInformation(storage_conn_1)
self.checkNotifyUnlockInformation(storage_conn_2)
def test_13_askUnfinishedTransactions(self):
service = self.service
node, conn = self.identifyToMasterNode()
# give a uuid
service.askUnfinishedTransactions(conn)
packet = self.checkAnswerUnfinishedTransactions(conn)
max_tid, tid_list = packet.decode()
self.assertEqual(tid_list, [])
# create some transaction
node, conn = self.identifyToMasterNode(node_type=NodeTypes.CLIENT,
port=self.client_port)
ttid = self.app.tm.begin(node)
self.app.tm.prepare(ttid, 1,
[self.getOID(1)], [node.getUUID()], 1)
conn = self.getFakeConnection(node.getUUID(), self.storage_address)
service.askUnfinishedTransactions(conn)
max_tid, tid_list = self.checkAnswerUnfinishedTransactions(conn, decode=True)
self.assertEqual(len(tid_list), 1)
def test_connectionClosed(self):
method = self.service.connectionClosed
state = NodeStates.TEMPORARILY_DOWN
# define two nodes
node1, conn1 = self.identifyToMasterNode()
node2, conn2 = self.identifyToMasterNode(port=10022)
node1.setRunning()
node2.setRunning()
self.assertEqual(node1.getState(), NodeStates.RUNNING)
self.assertEqual(node2.getState(), NodeStates.RUNNING)
# filled the pt
self.app.pt.make(self.app.nm.getStorageList())
self.assertTrue(self.app.pt.filled())
self.assertTrue(self.app.pt.operational())
# drop one node
lptid = self.app.pt.getID()
method(conn1)
self.assertEqual(node1.getState(), state)
self.assertTrue(lptid < self.app.pt.getID())
# drop the second, no storage node left
lptid = self.app.pt.getID()
self.assertEqual(node2.getState(), NodeStates.RUNNING)
self.assertRaises(StoppedOperation, method, conn2)
self.assertEqual(node2.getState(), state)
self.assertEqual(lptid, self.app.pt.getID())
def test_answerPack(self): def test_answerPack(self):
# Note: incoming status has no meaning here, so it's left to False. # Note: incoming status has no meaning here, so it's left to False.
node1, conn1 = self._getStorage() node1, conn1 = self._getStorage()
...@@ -183,13 +77,6 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -183,13 +77,6 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
self.assertTrue(status) self.assertTrue(status)
self.assertEqual(self.app.packing, None) self.assertEqual(self.app.packing, None)
def test_notifyReady(self):
node, conn = self._getStorage()
uuid = node.getUUID()
self.assertFalse(self.app.isStorageReady(uuid))
self.service.notifyReady(conn)
self.assertTrue(self.app.isStorageReady(uuid))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -20,13 +20,10 @@ from struct import pack ...@@ -20,13 +20,10 @@ from struct import pack
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes from neo.lib.protocol import NodeTypes
from neo.lib.util import packTID, unpackTID, addTID from neo.lib.util import packTID, unpackTID, addTID
from neo.master.transactions import Transaction, TransactionManager from neo.master.transactions import TransactionManager
class testTransactionManager(NeoUnitTestBase): class testTransactionManager(NeoUnitTestBase):
def makeTID(self, i):
return pack('!Q', i)
def makeOID(self, i): def makeOID(self, i):
return pack('!Q', i) return pack('!Q', i)
...@@ -35,58 +32,6 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -35,58 +32,6 @@ class testTransactionManager(NeoUnitTestBase):
node = Mock({'getUUID': uuid, '__hash__': uuid, '__repr__': 'FakeNode'}) node = Mock({'getUUID': uuid, '__hash__': uuid, '__repr__': 'FakeNode'})
return uuid, node return uuid, node
def testTransaction(self):
# test data
node = Mock({'__repr__': 'Node'})
tid = self.makeTID(1)
ttid = self.makeTID(2)
oid_list = (oid1, oid2) = [self.makeOID(1), self.makeOID(2)]
uuid_list = (uuid1, uuid2) = [self.getStorageUUID(),
self.getStorageUUID()]
msg_id = 1
# create transaction object
txn = Transaction(node, ttid)
txn.prepare(tid, oid_list, uuid_list, msg_id)
self.assertEqual(txn.getUUIDList(), uuid_list)
self.assertEqual(txn.getOIDList(), oid_list)
# lock nodes one by one
self.assertFalse(txn.lock(uuid1))
self.assertTrue(txn.lock(uuid2))
# check that repr() works
repr(txn)
def testManager(self):
# test data
node = Mock({'__hash__': 1})
msg_id = 1
oid_list = (oid1, oid2) = self.makeOID(1), self.makeOID(2)
uuid_list = uuid1, uuid2 = self.getStorageUUID(), self.getStorageUUID()
client_uuid = self.getClientUUID()
# create transaction manager
callback = Mock()
txnman = TransactionManager(on_commit=callback)
self.assertFalse(txnman.hasPending())
self.assertEqual(txnman.registerForNotification(uuid1), [])
# begin the transaction
ttid = txnman.begin(node)
self.assertTrue(ttid is not None)
self.assertEqual(len(txnman.registerForNotification(uuid1)), 1)
self.assertTrue(txnman.hasPending())
# prepare the transaction
tid = txnman.prepare(ttid, 1, oid_list, uuid_list, msg_id)
self.assertTrue(txnman.hasPending())
self.assertEqual(txnman.registerForNotification(uuid1), [ttid])
txn = txnman[ttid]
self.assertEqual(txn.getTID(), tid)
self.assertEqual(txn.getUUIDList(), list(uuid_list))
self.assertEqual(txn.getOIDList(), list(oid_list))
# lock nodes
txnman.lock(ttid, uuid1)
self.assertEqual(len(callback.getNamedCalls('__call__')), 0)
txnman.lock(ttid, uuid2)
self.assertEqual(len(callback.getNamedCalls('__call__')), 1)
self.assertEqual(txnman.registerForNotification(uuid1), [])
def test_storageLost(self): def test_storageLost(self):
client1 = Mock({'__hash__': 1}) client1 = Mock({'__hash__': 1})
client2 = Mock({'__hash__': 2}) client2 = Mock({'__hash__': 2})
...@@ -95,7 +40,7 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -95,7 +40,7 @@ class testTransactionManager(NeoUnitTestBase):
storage_2_uuid = self.getStorageUUID() storage_2_uuid = self.getStorageUUID()
oid_list = [self.makeOID(1), ] oid_list = [self.makeOID(1), ]
tm = TransactionManager(lambda tid, txn: None) tm = TransactionManager(None)
# Transaction 1: 2 storage nodes involved, one will die and the other # Transaction 1: 2 storage nodes involved, one will die and the other
# already answered node lock # already answered node lock
msg_id_1 = 1 msg_id_1 = 1
...@@ -172,7 +117,7 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -172,7 +117,7 @@ class testTransactionManager(NeoUnitTestBase):
Note: this implementation might change later, for more parallelism. Note: this implementation might change later, for more parallelism.
""" """
client_uuid, client = self.makeNode(NodeTypes.CLIENT) client_uuid, client = self.makeNode(NodeTypes.CLIENT)
tm = TransactionManager(lambda tid, txn: None) tm = TransactionManager(None)
# With a requested TID, lock spans from begin to remove # With a requested TID, lock spans from begin to remove
ttid1 = self.getNextTID() ttid1 = self.getNextTID()
ttid2 = self.getNextTID() ttid2 = self.getNextTID()
...@@ -189,7 +134,7 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -189,7 +134,7 @@ class testTransactionManager(NeoUnitTestBase):
def testClientDisconectsAfterBegin(self): def testClientDisconectsAfterBegin(self):
client_uuid1, node1 = self.makeNode(NodeTypes.CLIENT) client_uuid1, node1 = self.makeNode(NodeTypes.CLIENT)
tm = TransactionManager(lambda tid, txn: None) tm = TransactionManager(None)
tid1 = self.getNextTID() tid1 = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
tm.begin(node1, tid1) tm.begin(node1, tid1)
......
...@@ -17,23 +17,13 @@ ...@@ -17,23 +17,13 @@
import unittest import unittest
from mock import Mock, ReturnValues from mock import Mock, ReturnValues
from collections import deque from collections import deque
from neo.lib.util import makeChecksum
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.transactions import ConflictError
from neo.storage.handlers.client import ClientOperationHandler from neo.storage.handlers.client import ClientOperationHandler
from neo.lib.protocol import INVALID_PARTITION, INVALID_TID, INVALID_OID from neo.lib.protocol import INVALID_TID, INVALID_OID, Packets, LockState
from neo.lib.protocol import Packets, LockState, ZERO_HASH
class StorageClientHandlerTests(NeoUnitTestBase): class StorageClientHandlerTests(NeoUnitTestBase):
def checkHandleUnexpectedPacket(self, _call, _msg_type, _listening=True, **kwargs):
conn = self.getFakeConnection(address=("127.0.0.1", self.master_port),
is_server=_listening)
# hook
self.operation.peerBroken = lambda c: c.peerBrokendCalled()
self.checkUnexpectedPacketRaised(_call, conn=conn, **kwargs)
def setUp(self): def setUp(self):
NeoUnitTestBase.setUp(self) NeoUnitTestBase.setUp(self)
self.prepareDatabase(number=1) self.prepareDatabase(number=1)
...@@ -53,7 +43,6 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -53,7 +43,6 @@ class StorageClientHandlerTests(NeoUnitTestBase):
pmn = self.app.nm.getMasterList()[0] pmn = self.app.nm.getMasterList()[0]
pmn.setUUID(self.master_uuid) pmn.setUUID(self.master_uuid)
self.app.primary_master_node = pmn self.app.primary_master_node = pmn
self.master_port = 10010
def _tearDown(self, success): def _tearDown(self, success):
self.app.close() self.app.close()
...@@ -63,17 +52,6 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -63,17 +52,6 @@ class StorageClientHandlerTests(NeoUnitTestBase):
def _getConnection(self, uuid=None): def _getConnection(self, uuid=None):
return self.getFakeConnection(uuid=uuid, address=('127.0.0.1', 1000)) return self.getFakeConnection(uuid=uuid, address=('127.0.0.1', 1000))
def _checkTransactionsAborted(self, uuid):
calls = self.app.tm.mockGetNamedCalls('abortFor')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(uuid)
def test_connectionLost(self):
uuid = self.getClientUUID()
self.app.nm.createClient(uuid=uuid)
conn = self._getConnection(uuid=uuid)
self.operation.connectionClosed(conn)
def test_18_askTransactionInformation1(self): def test_18_askTransactionInformation1(self):
# transaction does not exists # transaction does not exists
conn = self._getConnection() conn = self._getConnection()
...@@ -81,15 +59,6 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -81,15 +59,6 @@ class StorageClientHandlerTests(NeoUnitTestBase):
self.operation.askTransactionInformation(conn, INVALID_TID) self.operation.askTransactionInformation(conn, INVALID_TID)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
def test_18_askTransactionInformation2(self):
# answer
conn = self._getConnection()
oid_list = [self.getOID(1), self.getOID(2)]
dm = Mock({ "getTransaction": (oid_list, 'user', 'desc', '', False), })
self.app.dm = dm
self.operation.askTransactionInformation(conn, INVALID_TID)
self.checkAnswerTransactionInformation(conn)
def test_24_askObject1(self): def test_24_askObject1(self):
# delayed response # delayed response
conn = self._getConnection() conn = self._getConnection()
...@@ -103,33 +72,6 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -103,33 +72,6 @@ class StorageClientHandlerTests(NeoUnitTestBase):
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
self.assertEqual(len(self.app.dm.mockGetNamedCalls('getObject')), 0) self.assertEqual(len(self.app.dm.mockGetNamedCalls('getObject')), 0)
def test_24_askObject2(self):
# invalid serial / tid / packet not found
self.app.dm = Mock({'getObject': None})
conn = self._getConnection()
self.assertEqual(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=INVALID_OID,
serial=INVALID_TID, tid=INVALID_TID)
calls = self.app.dm.mockGetNamedCalls('getObject')
self.assertEqual(len(self.app.event_queue), 0)
self.assertEqual(len(calls), 1)
calls[0].checkArgs(INVALID_OID, INVALID_TID, INVALID_TID)
self.checkErrorPacket(conn)
def test_24_askObject3(self):
# object found => answer
serial = self.getNextTID()
next_serial = self.getNextTID()
oid = self.getOID(1)
tid = self.getNextTID()
H = "0" * 20
self.app.dm = Mock({'getObject': (serial, next_serial, 0, H, '', None)})
conn = self._getConnection()
self.assertEqual(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=oid, serial=serial, tid=tid)
self.assertEqual(len(self.app.event_queue), 0)
self.checkAnswerObject(conn)
def test_25_askTIDs1(self): def test_25_askTIDs1(self):
# invalid offsets => error # invalid offsets => error
app = self.app app = self.app
...@@ -151,23 +93,6 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -151,23 +93,6 @@ class StorageClientHandlerTests(NeoUnitTestBase):
calls[0].checkArgs(1, 1, [1, ]) calls[0].checkArgs(1, 1, [1, ])
self.checkAnswerTids(conn) self.checkAnswerTids(conn)
def test_25_askTIDs3(self):
# invalid partition => answer usable partitions
conn = self._getConnection()
cell = Mock({'getUUID':self.app.uuid})
self.app.dm = Mock({'getTIDList': (INVALID_TID, )})
self.app.pt = Mock({
'getCellList': (cell, ),
'getPartitions': 1,
'getAssignedPartitionList': [0],
})
self.operation.askTIDs(conn, 1, 2, INVALID_PARTITION)
self.assertEqual(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1)
calls = self.app.dm.mockGetNamedCalls('getTIDList')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(1, 1, [0])
self.checkAnswerTids(conn)
def test_26_askObjectHistory1(self): def test_26_askObjectHistory1(self):
# invalid offsets => error # invalid offsets => error
app = self.app app = self.app
...@@ -177,87 +102,6 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -177,87 +102,6 @@ class StorageClientHandlerTests(NeoUnitTestBase):
1, 1, None) 1, 1, None)
self.assertEqual(len(app.dm.mockGetNamedCalls('getObjectHistory')), 0) self.assertEqual(len(app.dm.mockGetNamedCalls('getObjectHistory')), 0)
def test_26_askObjectHistory2(self):
oid1, oid2 = self.getOID(1), self.getOID(2)
# first case: empty history
conn = self._getConnection()
self.app.dm = Mock({'getObjectHistory': None})
self.operation.askObjectHistory(conn, oid1, 1, 2)
self.checkErrorPacket(conn)
# second case: not empty history
conn = self._getConnection()
serial = self.getNextTID()
self.app.dm = Mock({'getObjectHistory': [(serial, 0, ), ]})
self.operation.askObjectHistory(conn, oid2, 1, 2)
self.checkAnswerObjectHistory(conn)
def _getObject(self):
oid = self.getOID(0)
serial = self.getNextTID()
data = 'DATA'
return (oid, serial, 1, makeChecksum(data), data)
def _checkStoreObjectCalled(self, *args):
calls = self.app.tm.mockGetNamedCalls('storeObject')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(*args)
def test_askStoreObject1(self):
# no conflict => answer
conn = self._getConnection(uuid=self.getClientUUID())
tid = self.getNextTID()
oid, serial, comp, checksum, data = self._getObject()
self.operation.askStoreObject(conn, oid, serial, comp, checksum,
data, None, tid, False)
self._checkStoreObjectCalled(tid, serial, oid, comp,
checksum, data, None, False)
pconflicting, poid, pserial = self.checkAnswerStoreObject(conn,
decode=True)
self.assertEqual(pconflicting, 0)
self.assertEqual(poid, oid)
self.assertEqual(pserial, serial)
def test_askStoreObjectWithDataTID(self):
# same as test_askStoreObject1, but with a non-None data_tid value
conn = self._getConnection(uuid=self.getClientUUID())
tid = self.getNextTID()
oid, serial, comp, checksum, data = self._getObject()
data_tid = self.getNextTID()
self.operation.askStoreObject(conn, oid, serial, comp, ZERO_HASH,
'', data_tid, tid, False)
self._checkStoreObjectCalled(tid, serial, oid, comp,
None, None, data_tid, False)
pconflicting, poid, pserial = self.checkAnswerStoreObject(conn,
decode=True)
self.assertEqual(pconflicting, 0)
self.assertEqual(poid, oid)
self.assertEqual(pserial, serial)
def test_askStoreObject2(self):
# conflict error
conn = self._getConnection(uuid=self.getClientUUID())
tid = self.getNextTID()
locking_tid = self.getNextTID(tid)
def fakeStoreObject(*args):
raise ConflictError(locking_tid)
self.app.tm.storeObject = fakeStoreObject
oid, serial, comp, checksum, data = self._getObject()
self.operation.askStoreObject(conn, oid, serial, comp, checksum,
data, None, tid, False)
pconflicting, poid, pserial = self.checkAnswerStoreObject(conn,
decode=True)
self.assertEqual(pconflicting, 1)
self.assertEqual(poid, oid)
self.assertEqual(pserial, locking_tid)
def test_abortTransaction(self):
conn = self._getConnection()
tid = self.getNextTID()
self.operation.abortTransaction(conn, tid)
calls = self.app.tm.mockGetNamedCalls('abort')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid)
def test_askObjectUndoSerial(self): def test_askObjectUndoSerial(self):
conn = self._getConnection(uuid=self.getClientUUID()) conn = self._getConnection(uuid=self.getClientUUID())
tid = self.getNextTID() tid = self.getNextTID()
......
...@@ -15,10 +15,8 @@ ...@@ -15,10 +15,8 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.protocol import NodeStates, NodeTypes, NotReadyError, \ from neo.lib.protocol import NodeTypes, BrokenNodeDisallowedError
BrokenNodeDisallowedError
from neo.lib.pt import PartitionTable from neo.lib.pt import PartitionTable
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.handlers.identification import IdentificationHandler from neo.storage.handlers.identification import IdentificationHandler
...@@ -39,31 +37,6 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase): ...@@ -39,31 +37,6 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase):
del self.app del self.app
super(StorageIdentificationHandlerTests, self)._tearDown(success) super(StorageIdentificationHandlerTests, self)._tearDown(success)
def test_requestIdentification1(self):
""" nodes are rejected during election or if unknown storage """
self.app.ready = False
self.assertRaises(
NotReadyError,
self.identification.requestIdentification,
self.getFakeConnection(),
NodeTypes.CLIENT,
self.getClientUUID(),
None,
self.app.name,
None,
)
self.app.ready = True
self.assertRaises(
NotReadyError,
self.identification.requestIdentification,
self.getFakeConnection(),
NodeTypes.STORAGE,
self.getStorageUUID(),
None,
self.app.name,
None,
)
def test_requestIdentification3(self): def test_requestIdentification3(self):
""" broken nodes must be rejected """ """ broken nodes must be rejected """
uuid = self.getClientUUID() uuid = self.getClientUUID()
...@@ -80,28 +53,5 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase): ...@@ -80,28 +53,5 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase):
None, None,
) )
def test_requestIdentification2(self):
""" accepted client must be connected and running """
uuid = self.getClientUUID()
conn = self.getFakeConnection(uuid=uuid)
node = self.app.nm.createClient(uuid=uuid, state=NodeStates.RUNNING)
master = (self.local_ip, 3000)
self.app.master_node = Mock({
'getAddress': master,
})
self.identification.requestIdentification(conn, NodeTypes.CLIENT, uuid,
None, self.app.name, None)
self.assertTrue(node.isRunning())
self.assertTrue(node.isConnected())
self.assertEqual(node.getUUID(), uuid)
self.assertTrue(node.getConnection() is conn)
args = self.checkAcceptIdentification(conn, decode=True)
node_type, address, _np, _nr, _uuid, _master, _master_list = args
self.assertEqual(node_type, NodeTypes.STORAGE)
self.assertEqual(address, None)
self.assertEqual(_uuid, uuid)
self.assertEqual(_master, master)
# TODO: check _master_list ?
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
#
# Copyright (C) 2009-2016 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from .. import NeoUnitTestBase
from neo.lib.pt import PartitionTable
from neo.storage.app import Application
from neo.storage.handlers.initialization import InitializationHandler
from neo.lib.protocol import CellStates
from neo.lib.exception import PrimaryFailure
class StorageInitializationHandlerTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
self.prepareDatabase(number=1)
# create an application object
config = self.getStorageConfiguration(master_number=1)
self.app = Application(config)
self.verification = InitializationHandler(self.app)
# define some variable to simulate client and storage node
self.master_port = 10010
self.storage_port = 10020
self.client_port = 11011
self.num_partitions = 1009
self.num_replicas = 2
self.app.operational = False
self.app.load_lock_dict = {}
self.app.pt = PartitionTable(self.num_partitions, self.num_replicas)
def _tearDown(self, success):
self.app.close()
del self.app
super(StorageInitializationHandlerTests, self)._tearDown(success)
def getClientConnection(self):
address = ("127.0.0.1", self.client_port)
return self.getFakeConnection(uuid=self.getClientUUID(),
address=address)
def test_03_connectionClosed(self):
conn = self.getClientConnection()
self.app.listening_conn = object() # mark as running
self.assertRaises(PrimaryFailure, self.verification.connectionClosed, conn,)
# nothing happens
self.checkNoPacketSent(conn)
def test_09_answerPartitionTable(self):
# send a table
conn = self.getClientConnection()
self.app.pt = PartitionTable(3, 2)
node_1 = self.getStorageUUID()
node_2 = self.getStorageUUID()
node_3 = self.getStorageUUID()
self.app.uuid = node_1
# SN already know all nodes
self.app.nm.createStorage(uuid=node_1)
self.app.nm.createStorage(uuid=node_2)
self.app.nm.createStorage(uuid=node_3)
self.assertFalse(list(self.app.dm.getPartitionTable()))
row_list = [(0, ((node_1, CellStates.UP_TO_DATE), (node_2, CellStates.UP_TO_DATE))),
(1, ((node_3, CellStates.UP_TO_DATE), (node_1, CellStates.UP_TO_DATE))),
(2, ((node_2, CellStates.UP_TO_DATE), (node_3, CellStates.UP_TO_DATE)))]
self.assertFalse(self.app.pt.filled())
# send a complete new table and ack
self.verification.sendPartitionTable(conn, 2, row_list)
self.assertTrue(self.app.pt.filled())
self.assertEqual(self.app.pt.getID(), 2)
self.assertTrue(list(self.app.dm.getPartitionTable()))
if __name__ == "__main__":
unittest.main()
...@@ -20,9 +20,8 @@ from collections import deque ...@@ -20,9 +20,8 @@ from collections import deque
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.handlers.master import MasterOperationHandler from neo.storage.handlers.master import MasterOperationHandler
from neo.lib.exception import PrimaryFailure
from neo.lib.pt import PartitionTable from neo.lib.pt import PartitionTable
from neo.lib.protocol import CellStates, Packets from neo.lib.protocol import CellStates
class StorageMasterHandlerTests(NeoUnitTestBase): class StorageMasterHandlerTests(NeoUnitTestBase):
...@@ -54,13 +53,6 @@ class StorageMasterHandlerTests(NeoUnitTestBase): ...@@ -54,13 +53,6 @@ class StorageMasterHandlerTests(NeoUnitTestBase):
address = ("127.0.0.1", self.master_port) address = ("127.0.0.1", self.master_port)
return self.getFakeConnection(uuid=self.master_uuid, address=address) return self.getFakeConnection(uuid=self.master_uuid, address=address)
def test_07_connectionClosed2(self):
# primary has closed the connection
conn = self.getMasterConnection()
self.app.listening_conn = object() # mark as running
self.assertRaises(PrimaryFailure, self.operation.connectionClosed, conn)
self.checkNoPacketSent(conn)
def test_14_notifyPartitionChanges1(self): def test_14_notifyPartitionChanges1(self):
# old partition change -> do nothing # old partition change -> do nothing
app = self.app app = self.app
...@@ -104,19 +96,5 @@ class StorageMasterHandlerTests(NeoUnitTestBase): ...@@ -104,19 +96,5 @@ class StorageMasterHandlerTests(NeoUnitTestBase):
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(ptid2, cells) calls[0].checkArgs(ptid2, cells)
def _getConnection(self):
return self.getFakeConnection()
def test_askPack(self):
self.app.dm = Mock({'pack': None})
conn = self.getFakeConnection()
tid = self.getNextTID()
self.operation.askPack(conn, tid)
calls = self.app.dm.mockGetNamedCalls('pack')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid, self.app.tm.updateObjectDataForPack)
# Content has no meaning here, don't check.
self.checkAnswerPacket(conn, Packets.AnswerPack)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from mock import Mock, ReturnValues from mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.lib.protocol import CellStates from neo.lib.protocol import CellStates
...@@ -141,25 +141,6 @@ class StorageAppTests(NeoUnitTestBase): ...@@ -141,25 +141,6 @@ class StorageAppTests(NeoUnitTestBase):
raise_on_duplicate=False) raise_on_duplicate=False)
self.assertEqual(len(self.app.event_queue), 2) self.assertEqual(len(self.app.event_queue), 2)
def test_03_executeQueuedEvents(self):
self.assertEqual(len(self.app.event_queue), 0)
msg_id = 1325136
msg_id_2 = 1325137
event = Mock({'__repr__': 'event'})
conn = Mock({'__repr__': 'conn', 'getPeerId': ReturnValues(msg_id, msg_id_2)})
self.app.queueEvent(event, conn, ("test", ))
self.app.executeQueuedEvents()
self.assertEqual(len(event.mockGetNamedCalls("__call__")), 1)
call = event.mockGetNamedCalls("__call__")[0]
params = call.getParam(1)
self.assertEqual(params, "test")
params = call.kwparams
self.assertEqual(params, {})
calls = conn.mockGetNamedCalls("setPeerId")
self.assertEqual(len(calls), 2)
calls[0].checkArgs(msg_id)
calls[1].checkArgs(msg_id_2)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -14,29 +14,14 @@ ...@@ -14,29 +14,14 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import random
import unittest import unittest
from mock import Mock, ReturnValues from mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.storage.transactions import Transaction, TransactionManager from neo.storage.transactions import Transaction, TransactionManager
from neo.storage.transactions import ConflictError, DelayedError
class TransactionTests(NeoUnitTestBase): class TransactionTests(NeoUnitTestBase):
def testInit(self):
uuid = self.getClientUUID()
ttid = self.getNextTID()
tid = self.getNextTID()
txn = Transaction(uuid, ttid)
self.assertEqual(txn.getUUID(), uuid)
self.assertEqual(txn.getTTID(), ttid)
self.assertEqual(txn.getTID(), None)
txn.setTID(tid)
self.assertEqual(txn.getTID(), tid)
self.assertEqual(txn.getObjectList(), [])
self.assertEqual(txn.getOIDList(), [])
def testLock(self): def testLock(self):
txn = Transaction(self.getClientUUID(), self.getNextTID()) txn = Transaction(self.getClientUUID(), self.getNextTID())
self.assertFalse(txn.isLocked()) self.assertFalse(txn.isLocked())
...@@ -45,29 +30,6 @@ class TransactionTests(NeoUnitTestBase): ...@@ -45,29 +30,6 @@ class TransactionTests(NeoUnitTestBase):
# disallow lock more than once # disallow lock more than once
self.assertRaises(AssertionError, txn.lock) self.assertRaises(AssertionError, txn.lock)
def testObjects(self):
txn = Transaction(self.getClientUUID(), self.getNextTID())
oid1, oid2 = self.getOID(1), self.getOID(2)
object1 = oid1, "0" * 20, None
object2 = oid2, "1" * 20, None
self.assertEqual(txn.getObjectList(), [])
self.assertEqual(txn.getOIDList(), [])
txn.addObject(*object1)
self.assertEqual(txn.getObjectList(), [object1])
self.assertEqual(txn.getOIDList(), [oid1])
txn.addObject(*object2)
self.assertEqual(txn.getObjectList(), [object1, object2])
self.assertEqual(txn.getOIDList(), [oid1, oid2])
def test_getObject(self):
oid_1 = self.getOID(1)
oid_2 = self.getOID(2)
txn = Transaction(self.getClientUUID(), self.getNextTID())
object_info = oid_1, None, None
txn.addObject(*object_info)
self.assertRaises(KeyError, txn.getObject, oid_2)
self.assertEqual(txn.getObject(oid_1), object_info)
class TransactionManagerTests(NeoUnitTestBase): class TransactionManagerTests(NeoUnitTestBase):
def setUp(self): def setUp(self):
...@@ -78,282 +40,10 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -78,282 +40,10 @@ class TransactionManagerTests(NeoUnitTestBase):
self.app.pt = Mock({'isAssigned': True}) self.app.pt = Mock({'isAssigned': True})
self.app.em = Mock({'setTimeout': None}) self.app.em = Mock({'setTimeout': None})
self.manager = TransactionManager(self.app) self.manager = TransactionManager(self.app)
self.ltid = None
def register(self, uuid, ttid): def register(self, uuid, ttid):
self.manager.register(Mock({'getUUID': uuid}), ttid) self.manager.register(Mock({'getUUID': uuid}), ttid)
def _getTransaction(self):
tid = self.getNextTID(self.ltid)
oid_list = [self.getOID(1), self.getOID(2)]
return (tid, ('USER', 'DESC', 'EXT', oid_list))
def _storeTransactionObjects(self, tid, txn):
for i, oid in enumerate(txn[3]):
self.manager.storeObject(tid, None,
oid, 1, '%020d' % i, '0' + str(i), None)
def _getObject(self, value):
oid = self.getOID(value)
serial = self.getNextTID()
return (serial, (oid, 1, '%020d' % value, 'O' + str(value), None))
def _checkTransactionStored(self, *args):
calls = self.app.dm.mockGetNamedCalls('storeTransaction')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(*args)
def _checkTransactionFinished(self, *args):
calls = self.app.dm.mockGetNamedCalls('unlockTransaction')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(*args)
def _checkQueuedEventExecuted(self, number=1):
calls = self.app.mockGetNamedCalls('executeQueuedEvents')
self.assertEqual(len(calls), number)
def assertRegistered(self, ttid):
self.assertIn(ttid, self.manager._transaction_dict)
def assertNotRegistered(self, ttid):
self.assertNotIn(ttid, self.manager._transaction_dict)
def testSimpleCase(self):
""" One node, one transaction, not abort """
data_id_list = random.random(), random.random()
self.app.dm.mockAddReturnValues(holdData=ReturnValues(*data_id_list))
uuid = self.getClientUUID()
ttid = self.getNextTID()
tid, txn = self._getTransaction()
serial1, object1 = self._getObject(1)
serial2, object2 = self._getObject(2)
self.register(uuid, ttid)
self.manager.storeObject(ttid, serial1, *object1)
self.manager.storeObject(ttid, serial2, *object2)
self.assertRegistered(ttid)
self.manager.vote(ttid, txn)
user, desc, ext, oid_list = txn
call, = self.app.dm.mockGetNamedCalls('storeTransaction')
call.checkArgs(ttid, [
(object1[0], data_id_list[0], object1[4]),
(object2[0], data_id_list[1], object2[4]),
], (oid_list, user, desc, ext, False, ttid))
self.manager.lock(ttid, tid)
call, = self.app.dm.mockGetNamedCalls('lockTransaction')
call.checkArgs(tid, ttid)
self.manager.unlock(ttid)
self.assertNotRegistered(ttid)
call, = self.app.dm.mockGetNamedCalls('unlockTransaction')
call.checkArgs(tid, ttid)
def testDelayed(self):
""" Two transactions, the first cause the second to be delayed """
uuid = self.getClientUUID()
ttid1 = self.getNextTID()
ttid2 = self.getNextTID()
tid1, txn1 = self._getTransaction()
tid2, txn2 = self._getTransaction()
serial, obj = self._getObject(1)
# first transaction lock the object
self.register(uuid, ttid1)
self.assertRegistered(ttid1)
self._storeTransactionObjects(ttid1, txn1)
self.manager.vote(ttid1, txn1)
self.manager.lock(ttid1, tid1)
# the second is delayed
self.register(uuid, ttid2)
self.assertRegistered(ttid2)
self.assertRaises(DelayedError, self.manager.storeObject,
ttid2, serial, *obj)
def testUnresolvableConflict(self):
""" A newer transaction has already modified an object """
uuid = self.getClientUUID()
ttid1 = self.getNextTID()
ttid2 = self.getNextTID()
tid1, txn1 = self._getTransaction()
tid2, txn2 = self._getTransaction()
serial, obj = self._getObject(1)
# the (later) transaction lock (change) the object
self.register(uuid, ttid2)
self.assertRegistered(ttid2)
self._storeTransactionObjects(ttid2, txn2)
self.manager.vote(ttid2, txn2)
self.manager.lock(ttid2, tid2)
# the previous it's not using the latest version
self.register(uuid, ttid1)
self.assertRegistered(ttid1)
self.assertRaises(ConflictError, self.manager.storeObject,
ttid1, serial, *obj)
def testResolvableConflict(self):
""" Try to store an object with the latest revision """
uuid = self.getClientUUID()
tid, txn = self._getTransaction()
serial, obj = self._getObject(1)
next_serial = self.getNextTID(serial)
# try to store without the last revision
self.app.dm = Mock({'getLastObjectTID': next_serial})
self.register(uuid, tid)
self.assertRaises(ConflictError, self.manager.storeObject,
tid, serial, *obj)
def testLockDelayed(self):
""" Check lock delay """
uuid1 = self.getClientUUID()
uuid2 = self.getClientUUID()
self.assertNotEqual(uuid1, uuid2)
ttid1 = self.getNextTID()
ttid2 = self.getNextTID()
tid1, txn1 = self._getTransaction()
tid2, txn2 = self._getTransaction()
serial1, obj1 = self._getObject(1)
serial2, obj2 = self._getObject(2)
# first transaction lock objects
self.register(uuid1, ttid1)
self.assertRegistered(ttid1)
self.manager.storeObject(ttid1, serial1, *obj1)
self.manager.storeObject(ttid1, serial1, *obj2)
self.manager.vote(ttid1, txn1)
self.manager.lock(ttid1, tid1)
# second transaction is delayed
self.register(uuid2, ttid2)
self.assertRegistered(ttid2)
self.assertRaises(DelayedError, self.manager.storeObject,
ttid2, serial1, *obj1)
self.assertRaises(DelayedError, self.manager.storeObject,
ttid2, serial2, *obj2)
def testLockConflict(self):
""" Check lock conflict """
uuid1 = self.getClientUUID()
uuid2 = self.getClientUUID()
self.assertNotEqual(uuid1, uuid2)
ttid1 = self.getNextTID()
ttid2 = self.getNextTID()
tid1, txn1 = self._getTransaction()
tid2, txn2 = self._getTransaction()
serial1, obj1 = self._getObject(1)
serial2, obj2 = self._getObject(2)
# the second transaction lock objects
self.register(uuid2, ttid2)
self.manager.storeObject(ttid2, serial1, *obj1)
self.manager.storeObject(ttid2, serial2, *obj2)
self.assertRegistered(ttid2)
self.manager.vote(ttid2, txn2)
self.manager.lock(ttid2, tid2)
# the first get a conflict
self.register(uuid1, ttid1)
self.assertRegistered(ttid1)
self.assertRaises(ConflictError, self.manager.storeObject,
ttid1, serial1, *obj1)
self.assertRaises(ConflictError, self.manager.storeObject,
ttid1, serial2, *obj2)
def testAbortUnlocked(self):
""" Abort a non-locked transaction """
uuid = self.getClientUUID()
tid, txn = self._getTransaction()
serial, obj = self._getObject(1)
self.register(uuid, tid)
self.manager.storeObject(tid, serial, *obj)
self.assertRegistered(tid)
self.manager.vote(tid, txn)
# transaction is not locked
self.manager.abort(tid)
self.assertNotRegistered(tid)
self.assertFalse(self.manager.loadLocked(obj[0]))
self._checkQueuedEventExecuted()
def testAbortLockedDoNothing(self):
""" Try to abort a locked transaction """
uuid = self.getClientUUID()
ttid = self.getNextTID()
tid, txn = self._getTransaction()
self.register(uuid, ttid)
self._storeTransactionObjects(ttid, txn)
self.manager.vote(ttid, txn)
# lock transaction
self.manager.lock(ttid, tid)
self.assertRegistered(ttid)
self.manager.abort(ttid)
self.assertRegistered(ttid)
for oid in txn[-1]:
self.assertTrue(self.manager.loadLocked(oid))
self._checkQueuedEventExecuted(number=0)
def testAbortForNode(self):
""" Abort transaction for a node """
uuid1 = self.getClientUUID()
uuid2 = self.getClientUUID()
self.assertNotEqual(uuid1, uuid2)
ttid1 = self.getNextTID()
ttid2 = self.getNextTID()
ttid3 = self.getNextTID()
tid1, txn1 = self._getTransaction()
tid2, txn2 = self._getTransaction()
tid3, txn3 = self._getTransaction()
self.register(uuid1, ttid1)
self.register(uuid2, ttid2)
self.register(uuid2, ttid3)
self.manager.vote(ttid1, txn1)
# node 2 owns tid2 & tid3 and lock tid2 only
self._storeTransactionObjects(ttid2, txn2)
self.manager.vote(ttid2, txn2)
self.manager.vote(ttid3, txn3)
self.manager.lock(ttid2, tid2)
self.assertRegistered(ttid1)
self.assertRegistered(ttid2)
self.assertRegistered(ttid3)
self.manager.abortFor(uuid2)
# only tid3 is aborted
self.assertRegistered(ttid1)
self.assertRegistered(ttid2)
self.assertNotRegistered(ttid3)
self._checkQueuedEventExecuted(number=1)
def testReset(self):
""" Reset the manager """
uuid = self.getClientUUID()
tid, txn = self._getTransaction()
ttid = self.getNextTID()
self.register(uuid, ttid)
self._storeTransactionObjects(ttid, txn)
self.manager.vote(ttid, txn)
self.manager.lock(ttid, tid)
self.assertRegistered(ttid)
self.manager.reset()
self.assertNotRegistered(ttid)
for oid in txn[0]:
self.assertFalse(self.manager.loadLocked(oid))
def test_getObjectFromTransaction(self):
data_id = random.random()
self.app.dm.mockAddReturnValues(holdData=ReturnValues(data_id))
uuid = self.getClientUUID()
tid1, txn1 = self._getTransaction()
tid2, txn2 = self._getTransaction()
serial1, obj1 = self._getObject(1)
serial2, obj2 = self._getObject(2)
self.register(uuid, tid1)
self.manager.storeObject(tid1, serial1, *obj1)
self.assertEqual(self.manager.getObjectFromTransaction(tid2, obj1[0]),
None)
self.assertEqual(self.manager.getObjectFromTransaction(tid1, obj2[0]),
None)
self.assertEqual(self.manager.getObjectFromTransaction(tid1, obj1[0]),
(obj1[0], data_id, obj1[4]))
def test_getLockingTID(self):
uuid = self.getClientUUID()
serial1, obj1 = self._getObject(1)
oid1 = obj1[0]
tid1, txn1 = self._getTransaction()
self.assertEqual(self.manager.getLockingTID(oid1), None)
self.register(uuid, tid1)
self.manager.storeObject(tid1, serial1, *obj1)
self.assertEqual(self.manager.getLockingTID(oid1), tid1)
def test_updateObjectDataForPack(self): def test_updateObjectDataForPack(self):
ram_serial = self.getNextTID() ram_serial = self.getNextTID()
oid = self.getOID(1) oid = self.getOID(1)
......
...@@ -19,7 +19,7 @@ from time import time ...@@ -19,7 +19,7 @@ from time import time
from mock import Mock from mock import Mock
from neo.lib import connection, logging from neo.lib import connection, logging
from neo.lib.connection import BaseConnection, ClientConnection, \ from neo.lib.connection import BaseConnection, ClientConnection, \
MTClientConnection, HandlerSwitcher, CRITICAL_TIMEOUT MTClientConnection, CRITICAL_TIMEOUT
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import Packets from neo.lib.protocol import Packets
from . import NeoUnitTestBase, Patch from . import NeoUnitTestBase, Patch
...@@ -159,186 +159,5 @@ class MTConnectionTests(ConnectionTests): ...@@ -159,186 +159,5 @@ class MTConnectionTests(ConnectionTests):
# ... except Ping # ... except Ping
ask(Packets.Ping()) ask(Packets.Ping())
class HandlerSwitcherTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
self._handler = handler = Mock({
'__repr__': 'initial handler',
})
self._connection = Mock({
'__repr__': 'connection',
'getAddress': ('127.0.0.1', 10000),
})
self._handlers = HandlerSwitcher(handler)
def _makeNotification(self, msg_id):
packet = Packets.StartOperation()
packet.setId(msg_id)
return packet
def _makeRequest(self, msg_id):
packet = Packets.AskBeginTransaction()
packet.setId(msg_id)
return packet
def _makeAnswer(self, msg_id):
packet = Packets.AnswerBeginTransaction(self.getNextTID())
packet.setId(msg_id)
return packet
def _makeHandler(self):
return Mock({'__repr__': 'handler'})
def _checkPacketReceived(self, handler, packet, index=0):
calls = handler.mockGetNamedCalls('packetReceived')
self.assertEqual(len(calls), index + 1)
def _checkCurrentHandler(self, handler):
self.assertTrue(self._handlers.getHandler() is handler)
def testInit(self):
self._checkCurrentHandler(self._handler)
self.assertFalse(self._handlers.isPending())
def testEmit(self):
# First case, emit is called outside of a handler
self.assertFalse(self._handlers.isPending())
request = self._makeRequest(1)
self._handlers.emit(request, 0, None)
self.assertTrue(self._handlers.isPending())
# Second case, emit is called from inside a handler with a pending
# handler change.
new_handler = self._makeHandler()
applied = self._handlers.setHandler(new_handler)
self.assertFalse(applied)
self._checkCurrentHandler(self._handler)
call_tracker = []
def packetReceived(conn, packet, kw):
self._handlers.emit(self._makeRequest(2), 0, None)
call_tracker.append(True)
self._handler.packetReceived = packetReceived
self._handlers.handle(self._connection, self._makeAnswer(1))
self.assertEqual(call_tracker, [True])
# Effective handler must not have changed (new request is blocking
# it)
self._checkCurrentHandler(self._handler)
# Handling the next response will cause the handler to change
delattr(self._handler, 'packetReceived')
self._handlers.handle(self._connection, self._makeAnswer(2))
self._checkCurrentHandler(new_handler)
def testHandleNotification(self):
# handle with current handler
notif1 = self._makeNotification(1)
self._handlers.handle(self._connection, notif1)
self._checkPacketReceived(self._handler, notif1)
# emit a request and delay an handler
request = self._makeRequest(2)
self._handlers.emit(request, 0, None)
handler = self._makeHandler()
applied = self._handlers.setHandler(handler)
self.assertFalse(applied)
# next notification fall into the current handler
notif2 = self._makeNotification(3)
self._handlers.handle(self._connection, notif2)
self._checkPacketReceived(self._handler, notif2, index=1)
# handle with new handler
answer = self._makeAnswer(2)
self._handlers.handle(self._connection, answer)
notif3 = self._makeNotification(4)
self._handlers.handle(self._connection, notif3)
self._checkPacketReceived(handler, notif2)
def testHandleAnswer1(self):
# handle with current handler
request = self._makeRequest(1)
self._handlers.emit(request, 0, None)
answer = self._makeAnswer(1)
self._handlers.handle(self._connection, answer)
self._checkPacketReceived(self._handler, answer)
def testHandleAnswer2(self):
# handle with blocking handler
request = self._makeRequest(1)
self._handlers.emit(request, 0, None)
handler = self._makeHandler()
applied = self._handlers.setHandler(handler)
self.assertFalse(applied)
answer = self._makeAnswer(1)
self._handlers.handle(self._connection, answer)
self._checkPacketReceived(self._handler, answer)
self._checkCurrentHandler(handler)
def testHandleAnswer3(self):
# multiple setHandler
r1 = self._makeRequest(1)
r2 = self._makeRequest(2)
r3 = self._makeRequest(3)
a1 = self._makeAnswer(1)
a2 = self._makeAnswer(2)
a3 = self._makeAnswer(3)
h1 = self._makeHandler()
h2 = self._makeHandler()
h3 = self._makeHandler()
# emit all requests and setHandleres
self._handlers.emit(r1, 0, None)
applied = self._handlers.setHandler(h1)
self.assertFalse(applied)
self._handlers.emit(r2, 0, None)
applied = self._handlers.setHandler(h2)
self.assertFalse(applied)
self._handlers.emit(r3, 0, None)
applied = self._handlers.setHandler(h3)
self.assertFalse(applied)
self._checkCurrentHandler(self._handler)
self.assertTrue(self._handlers.isPending())
# process answers
self._handlers.handle(self._connection, a1)
self._checkCurrentHandler(h1)
self._handlers.handle(self._connection, a2)
self._checkCurrentHandler(h2)
self._handlers.handle(self._connection, a3)
self._checkCurrentHandler(h3)
def testHandleAnswer4(self):
# process out of order
r1 = self._makeRequest(1)
r2 = self._makeRequest(2)
r3 = self._makeRequest(3)
a1 = self._makeAnswer(1)
a2 = self._makeAnswer(2)
a3 = self._makeAnswer(3)
h = self._makeHandler()
# emit all requests
self._handlers.emit(r1, 0, None)
self._handlers.emit(r2, 0, None)
self._handlers.emit(r3, 0, None)
applied = self._handlers.setHandler(h)
self.assertFalse(applied)
# process answers
self._handlers.handle(self._connection, a1)
self._checkCurrentHandler(self._handler)
self._handlers.handle(self._connection, a2)
self._checkCurrentHandler(self._handler)
self._handlers.handle(self._connection, a3)
self._checkCurrentHandler(h)
def testHandleUnexpected(self):
# process out of order
r1 = self._makeRequest(1)
r2 = self._makeRequest(2)
a2 = self._makeAnswer(2)
h = self._makeHandler()
# emit requests around state setHandler
self._handlers.emit(r1, 0, None)
applied = self._handlers.setHandler(h)
self.assertFalse(applied)
self._handlers.emit(r2, 0, None)
# process answer for next state
self._handlers.handle(self._connection, a2)
self.checkAborted(self._connection)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from mock import Mock
from . import NeoTestBase from . import NeoTestBase
from neo.lib.dispatcher import Dispatcher, ForgottenPacket from neo.lib.dispatcher import Dispatcher, ForgottenPacket
from Queue import Queue from Queue import Queue
...@@ -26,88 +25,6 @@ class DispatcherTests(NeoTestBase): ...@@ -26,88 +25,6 @@ class DispatcherTests(NeoTestBase):
NeoTestBase.setUp(self) NeoTestBase.setUp(self)
self.dispatcher = Dispatcher() self.dispatcher = Dispatcher()
def testRegister(self):
conn = object()
queue = Queue()
MARKER = object()
self.dispatcher.register(conn, 1, queue)
self.assertTrue(queue.empty())
self.assertTrue(self.dispatcher.dispatch(conn, 1, MARKER, {}))
self.assertFalse(queue.empty())
self.assertEqual(queue.get(block=False), (conn, MARKER, {}))
self.assertTrue(queue.empty())
self.assertFalse(self.dispatcher.dispatch(conn, 2, None, {}))
def testUnregister(self):
conn = object()
queue = Mock()
self.dispatcher.register(conn, 2, queue)
self.dispatcher.unregister(conn)
self.assertEqual(len(queue.mockGetNamedCalls('put')), 1)
self.assertFalse(self.dispatcher.dispatch(conn, 2, None, {}))
def testRegistered(self):
conn1 = object()
conn2 = object()
self.assertFalse(self.dispatcher.registered(conn1))
self.assertFalse(self.dispatcher.registered(conn2))
self.dispatcher.register(conn1, 1, Mock())
self.assertTrue(self.dispatcher.registered(conn1))
self.assertFalse(self.dispatcher.registered(conn2))
self.dispatcher.register(conn2, 2, Mock())
self.assertTrue(self.dispatcher.registered(conn1))
self.assertTrue(self.dispatcher.registered(conn2))
self.dispatcher.unregister(conn1)
self.assertFalse(self.dispatcher.registered(conn1))
self.assertTrue(self.dispatcher.registered(conn2))
self.dispatcher.unregister(conn2)
self.assertFalse(self.dispatcher.registered(conn1))
self.assertFalse(self.dispatcher.registered(conn2))
def testPending(self):
conn1 = object()
conn2 = object()
class Queue(object):
_empty = True
def empty(self):
return self._empty
def put(self, value):
pass
queue1 = Queue()
queue2 = Queue()
self.dispatcher.register(conn1, 1, queue1)
self.assertTrue(self.dispatcher.pending(queue1))
self.dispatcher.register(conn2, 2, queue1)
self.assertTrue(self.dispatcher.pending(queue1))
self.dispatcher.register(conn2, 3, queue2)
self.assertTrue(self.dispatcher.pending(queue1))
self.assertTrue(self.dispatcher.pending(queue2))
self.dispatcher.dispatch(conn1, 1, None, {})
self.assertTrue(self.dispatcher.pending(queue1))
self.assertTrue(self.dispatcher.pending(queue2))
self.dispatcher.dispatch(conn2, 2, None, {})
self.assertFalse(self.dispatcher.pending(queue1))
self.assertTrue(self.dispatcher.pending(queue2))
queue1._empty = False
self.assertTrue(self.dispatcher.pending(queue1))
queue1._empty = True
self.dispatcher.register(conn1, 4, queue1)
self.dispatcher.register(conn2, 5, queue1)
self.assertTrue(self.dispatcher.pending(queue1))
self.assertTrue(self.dispatcher.pending(queue2))
self.dispatcher.unregister(conn2)
self.assertTrue(self.dispatcher.pending(queue1))
self.assertFalse(self.dispatcher.pending(queue2))
self.dispatcher.unregister(conn1)
self.assertFalse(self.dispatcher.pending(queue1))
self.assertFalse(self.dispatcher.pending(queue2))
def testForget(self): def testForget(self):
conn = object() conn = object()
queue = Queue() queue = Queue()
......
...@@ -14,13 +14,14 @@ ...@@ -14,13 +14,14 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import shutil
import unittest import unittest
from mock import Mock from mock import Mock
from neo.lib.protocol import NodeTypes, NodeStates from neo.lib.protocol import NodeTypes, NodeStates
from neo.lib.node import Node, MasterDB from neo.lib.node import Node, MasterDB
from . import NeoUnitTestBase, getTempDirectory from . import NeoUnitTestBase, getTempDirectory
from time import time from time import time
from os import chmod, mkdir, rmdir, unlink from os import chmod, mkdir, rmdir
from os.path import join, exists from os.path import join, exists
class NodesTests(NeoUnitTestBase): class NodesTests(NeoUnitTestBase):
...@@ -29,16 +30,6 @@ class NodesTests(NeoUnitTestBase): ...@@ -29,16 +30,6 @@ class NodesTests(NeoUnitTestBase):
NeoUnitTestBase.setUp(self) NeoUnitTestBase.setUp(self)
self.nm = Mock() self.nm = Mock()
def _updatedByAddress(self, node, index=0):
calls = self.nm.mockGetNamedCalls('_updateAddress')
self.assertEqual(len(calls), index + 1)
self.assertEqual(calls[index].getParam(0), node)
def _updatedByUUID(self, node, index=0):
calls = self.nm.mockGetNamedCalls('_updateUUID')
self.assertEqual(len(calls), index + 1)
self.assertEqual(calls[index].getParam(0), node)
def testInit(self): def testInit(self):
""" Check the node initialization """ """ Check the node initialization """
address = ('127.0.0.1', 10000) address = ('127.0.0.1', 10000)
...@@ -60,23 +51,6 @@ class NodesTests(NeoUnitTestBase): ...@@ -60,23 +51,6 @@ class NodesTests(NeoUnitTestBase):
self.assertTrue(previous_time < node.getLastStateChange()) self.assertTrue(previous_time < node.getLastStateChange())
self.assertTrue(time() - 1 < node.getLastStateChange() < time()) self.assertTrue(time() - 1 < node.getLastStateChange() < time())
def testAddress(self):
""" Check if the node is indexed by address """
node = Node(self.nm)
self.assertEqual(node.getAddress(), None)
address = ('127.0.0.1', 10000)
node.setAddress(address)
self._updatedByAddress(node)
def testUUID(self):
""" As for Address but UUID """
node = Node(self.nm)
self.assertEqual(node.getAddress(), None)
uuid = self.getNewUUID(None)
node.setUUID(uuid)
self._updatedByUUID(node)
class NodeManagerTests(NeoUnitTestBase): class NodeManagerTests(NeoUnitTestBase):
def _addStorage(self): def _addStorage(self):
...@@ -209,12 +183,6 @@ class NodeManagerTests(NeoUnitTestBase): ...@@ -209,12 +183,6 @@ class NodeManagerTests(NeoUnitTestBase):
self.assertEqual(self.admin.getState(), NodeStates.UNKNOWN) self.assertEqual(self.admin.getState(), NodeStates.UNKNOWN)
class MasterDBTests(NeoUnitTestBase): class MasterDBTests(NeoUnitTestBase):
def _checkMasterDB(self, path, expected_master_list):
db = list(MasterDB(path))
db_set = set(db)
# Generic sanity check
self.assertEqual(len(db), len(db_set))
self.assertEqual(db_set, set(expected_master_list))
def testInitialAccessRights(self): def testInitialAccessRights(self):
""" """
...@@ -254,9 +222,7 @@ class MasterDBTests(NeoUnitTestBase): ...@@ -254,9 +222,7 @@ class MasterDBTests(NeoUnitTestBase):
db2 = MasterDB(db_file) db2 = MasterDB(db_file)
self.assertFalse(address in db2, [x for x in db2]) self.assertFalse(address in db2, [x for x in db2])
finally: finally:
if exists(db_file): shutil.rmtree(directory)
unlink(db_file)
rmdir(directory)
def testPersistence(self): def testPersistence(self):
temp_dir = getTempDirectory() temp_dir = getTempDirectory()
...@@ -280,9 +246,7 @@ class MasterDBTests(NeoUnitTestBase): ...@@ -280,9 +246,7 @@ class MasterDBTests(NeoUnitTestBase):
self.assertFalse(address in db3, [x for x in db3]) self.assertFalse(address in db3, [x for x in db3])
self.assertTrue(address2 in db3, [x for x in db3]) self.assertTrue(address2 in db3, [x for x in db3])
finally: finally:
if exists(db_file): shutil.rmtree(directory)
unlink(db_file)
rmdir(directory)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
...@@ -26,6 +26,7 @@ from collections import defaultdict ...@@ -26,6 +26,7 @@ from collections import defaultdict
from functools import wraps from functools import wraps
from neo.lib import logging from neo.lib import logging
from neo.client.exception import NEOStorageError from neo.client.exception import NEOStorageError
from neo.master.handlers.backup import BackupHandler
from neo.storage.checker import CHECK_COUNT from neo.storage.checker import CHECK_COUNT
from neo.storage.replicator import Replicator from neo.storage.replicator import Replicator
from neo.lib.connector import SocketConnector from neo.lib.connector import SocketConnector
...@@ -368,6 +369,31 @@ class ReplicationTests(NEOThreadedTest): ...@@ -368,6 +369,31 @@ class ReplicationTests(NEOThreadedTest):
# TODO check tids # TODO check tids
self.assertEqual(1, self.checkBackup(backup)) self.assertEqual(1, self.checkBackup(backup))
def testBackupEarlyInvalidation(self):
"""
The backup master must ignore notification before being fully
initialized.
"""
upstream = NEOCluster()
try:
upstream.start()
backup = NEOCluster(upstream=upstream)
try:
backup.start()
with ConnectionFilter() as f:
f.add(lambda conn, packet:
isinstance(packet, Packets.AskPartitionTable) and
isinstance(conn.getHandler(), BackupHandler))
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
upstream.importZODB()(1)
self.tic()
self.tic()
self.assertTrue(backup.master.isAlive())
finally:
backup.stop()
finally:
upstream.stop()
def testSafeTweak(self): def testSafeTweak(self):
""" """
Check that tweak always tries to keep a minimum of (replicas + 1) Check that tweak always tries to keep a minimum of (replicas + 1)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment