Commit d048a52d authored by Julien Muchembled's avatar Julien Muchembled

Remove AskNodeInformation packet

When Client (including backup master) and admin nodes are identified,
the primary master now sends them automatically all nodes with
NotifyNodeInformation, as with storage nodes.
parent 35664759
...@@ -125,7 +125,6 @@ class Application(BaseApplication): ...@@ -125,7 +125,6 @@ class Application(BaseApplication):
# passive handler # passive handler
self.master_conn.setHandler(self.master_event_handler) self.master_conn.setHandler(self.master_event_handler)
self.master_conn.ask(Packets.AskClusterState()) self.master_conn.ask(Packets.AskClusterState())
self.master_conn.ask(Packets.AskNodeInformation())
self.master_conn.ask(Packets.AskPartitionTable()) self.master_conn.ask(Packets.AskPartitionTable())
def sendPartitionTable(self, conn, min_offset, max_offset, uuid): def sendPartitionTable(self, conn, min_offset, max_offset, uuid):
......
...@@ -106,11 +106,6 @@ class MasterEventHandler(EventHandler): ...@@ -106,11 +106,6 @@ class MasterEventHandler(EventHandler):
def answerClusterState(self, conn, state): def answerClusterState(self, conn, state):
self.app.cluster_state = state self.app.cluster_state = state
def answerNodeInformation(self, conn):
# XXX: This will no more exists when the initialization module will be
# implemented for factorize code (as done for bootstrap)
logging.debug("answerNodeInformation")
def notifyPartitionChanges(self, conn, ptid, cell_list): def notifyPartitionChanges(self, conn, ptid, cell_list):
self.app.pt.update(ptid, cell_list, self.app.nm) self.app.pt.update(ptid, cell_list, self.app.nm)
......
...@@ -256,7 +256,6 @@ class Application(ThreadedApplication): ...@@ -256,7 +256,6 @@ class Application(ThreadedApplication):
# operational. Might raise ConnectionClosed so that the new # operational. Might raise ConnectionClosed so that the new
# primary can be looked-up again. # primary can be looked-up again.
logging.info('Initializing from master') logging.info('Initializing from master')
ask(conn, Packets.AskNodeInformation(), handler=handler)
ask(conn, Packets.AskPartitionTable(), handler=handler) ask(conn, Packets.AskPartitionTable(), handler=handler)
ask(conn, Packets.AskLastTransaction(), handler=handler) ask(conn, Packets.AskLastTransaction(), handler=handler)
if self.pt.operational(): if self.pt.operational():
......
...@@ -30,6 +30,16 @@ class PrimaryBootstrapHandler(AnswerBaseHandler): ...@@ -30,6 +30,16 @@ class PrimaryBootstrapHandler(AnswerBaseHandler):
self.app.trying_master_node = None self.app.trying_master_node = None
conn.close() conn.close()
def answerPartitionTable(self, conn, ptid, row_list):
assert row_list
self.app.pt.load(ptid, row_list, self.app.nm)
def answerLastTransaction(*args):
pass
class PrimaryNotificationsHandler(MTEventHandler):
""" Handler that process the notifications from the primary master """
def _acceptIdentification(self, node, uuid, num_partitions, def _acceptIdentification(self, node, uuid, num_partitions,
num_replicas, your_uuid, primary, known_master_list): num_replicas, your_uuid, primary, known_master_list):
app = self.app app = self.app
...@@ -81,23 +91,8 @@ class PrimaryBootstrapHandler(AnswerBaseHandler): ...@@ -81,23 +91,8 @@ class PrimaryBootstrapHandler(AnswerBaseHandler):
# Always create partition table # Always create partition table
app.pt = PartitionTable(num_partitions, num_replicas) app.pt = PartitionTable(num_partitions, num_replicas)
def answerPartitionTable(self, conn, ptid, row_list):
assert row_list
self.app.pt.load(ptid, row_list, self.app.nm)
def answerNodeInformation(self, conn):
pass
def answerLastTransaction(self, conn, ltid): def answerLastTransaction(self, conn, ltid):
pass
class PrimaryNotificationsHandler(MTEventHandler):
""" Handler that process the notifications from the primary master """
def packetReceived(self, conn, packet, kw={}):
if type(packet) is Packets.AnswerLastTransaction:
app = self.app app = self.app
ltid = packet.decode()[0]
if app.last_tid != ltid: if app.last_tid != ltid:
# Either we're connecting or we already know the last tid # Either we're connecting or we already know the last tid
# via invalidations. # via invalidations.
...@@ -124,15 +119,15 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -124,15 +119,15 @@ class PrimaryNotificationsHandler(MTEventHandler):
db = app.getDB() db = app.getDB()
db is None or db.invalidateCache() db is None or db.invalidateCache()
app.last_tid = ltid app.last_tid = ltid
elif type(packet) is Packets.AnswerTransactionFinished:
def answerTransactionFinished(self, conn, _, tid, callback, cache_dict):
app = self.app app = self.app
app.last_tid = tid = packet.decode()[1] app.last_tid = tid
callback = kw.pop('callback')
# Update cache # Update cache
cache = app._cache cache = app._cache
app._cache_lock_acquire() app._cache_lock_acquire()
try: try:
for oid, data in kw.pop('cache_dict').iteritems(): for oid, data in cache_dict.iteritems():
# Update ex-latest value in cache # Update ex-latest value in cache
cache.invalidate(oid, tid) cache.invalidate(oid, tid)
if data is not None: if data is not None:
...@@ -142,7 +137,6 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -142,7 +137,6 @@ class PrimaryNotificationsHandler(MTEventHandler):
callback(tid) callback(tid)
finally: finally:
app._cache_lock_release() app._cache_lock_release()
MTEventHandler.packetReceived(self, conn, packet, kw)
def connectionClosed(self, conn): def connectionClosed(self, conn):
app = self.app app = self.app
......
...@@ -41,14 +41,6 @@ class StorageEventHandler(MTEventHandler): ...@@ -41,14 +41,6 @@ class StorageEventHandler(MTEventHandler):
self.app.cp.removeConnection(node) self.app.cp.removeConnection(node)
super(StorageEventHandler, self).connectionFailed(conn) super(StorageEventHandler, self).connectionFailed(conn)
class StorageBootstrapHandler(AnswerBaseHandler):
""" Handler used when connecting to a storage node """
def notReady(self, conn, message):
conn.close()
raise NodeNotReady(message)
def _acceptIdentification(self, node, def _acceptIdentification(self, node,
uuid, num_partitions, num_replicas, your_uuid, primary, uuid, num_partitions, num_replicas, your_uuid, primary,
master_list): master_list):
...@@ -57,6 +49,13 @@ class StorageBootstrapHandler(AnswerBaseHandler): ...@@ -57,6 +49,13 @@ class StorageBootstrapHandler(AnswerBaseHandler):
primary, self.app.master_conn) primary, self.app.master_conn)
assert uuid == node.getUUID(), (uuid, node.getUUID()) assert uuid == node.getUUID(), (uuid, node.getUUID())
class StorageBootstrapHandler(AnswerBaseHandler):
""" Handler used when connecting to a storage node """
def notReady(self, conn, message):
conn.close()
raise NodeNotReady(message)
class StorageAnswersHandler(AnswerBaseHandler): class StorageAnswersHandler(AnswerBaseHandler):
""" Handle all messages related to ZODB operations """ """ Handle all messages related to ZODB operations """
......
...@@ -227,6 +227,9 @@ class MTEventHandler(EventHandler): ...@@ -227,6 +227,9 @@ class MTEventHandler(EventHandler):
def packetReceived(self, conn, packet, kw={}): def packetReceived(self, conn, packet, kw={}):
"""Redirect all received packet to dispatcher thread.""" """Redirect all received packet to dispatcher thread."""
if packet.isResponse(): if packet.isResponse():
if packet.poll_thread:
self.dispatch(conn, packet, kw)
kw = {}
if not (self.dispatcher.dispatch(conn, packet.getId(), packet, kw) if not (self.dispatcher.dispatch(conn, packet.getId(), packet, kw)
or type(packet) is Packets.Pong): or type(packet) is Packets.Pong):
raise ProtocolError('Unexpected response packet from %r: %r' raise ProtocolError('Unexpected response packet from %r: %r'
...@@ -254,3 +257,6 @@ class AnswerBaseHandler(EventHandler): ...@@ -254,3 +257,6 @@ class AnswerBaseHandler(EventHandler):
packetReceived = unexpectedInAnswerHandler packetReceived = unexpectedInAnswerHandler
peerBroken = unexpectedInAnswerHandler peerBroken = unexpectedInAnswerHandler
protocolError = unexpectedInAnswerHandler protocolError = unexpectedInAnswerHandler
def acceptIdentification(*args):
pass
...@@ -234,6 +234,7 @@ class Packet(object): ...@@ -234,6 +234,7 @@ class Packet(object):
_code = None _code = None
_fmt = None _fmt = None
_id = None _id = None
poll_thread = False
def __init__(self, *args, **kw): def __init__(self, *args, **kw):
assert self._code is not None, "Packet class not registered" assert self._code is not None, "Packet class not registered"
...@@ -680,6 +681,7 @@ class RequestIdentification(Packet): ...@@ -680,6 +681,7 @@ class RequestIdentification(Packet):
Request a node identification. This must be the first packet for any Request a node identification. This must be the first packet for any
connection. Any -> Any. connection. Any -> Any.
""" """
poll_thread = True
_fmt = PStruct('request_identification', _fmt = PStruct('request_identification',
PProtocol('protocol_version'), PProtocol('protocol_version'),
...@@ -867,6 +869,8 @@ class FinishTransaction(Packet): ...@@ -867,6 +869,8 @@ class FinishTransaction(Packet):
Finish a transaction. C -> PM. Finish a transaction. C -> PM.
Answer when a transaction is finished. PM -> C. Answer when a transaction is finished. PM -> C.
""" """
poll_thread = True
_fmt = PStruct('ask_finish_transaction', _fmt = PStruct('ask_finish_transaction',
PTID('tid'), PTID('tid'),
PFOidList, PFOidList,
...@@ -1152,12 +1156,6 @@ class NotifyNodeInformation(Packet): ...@@ -1152,12 +1156,6 @@ class NotifyNodeInformation(Packet):
PFNodeList, PFNodeList,
) )
class NodeInformation(Packet):
"""
Ask node information
"""
_answer = PFEmpty
class SetClusterState(Packet): class SetClusterState(Packet):
""" """
Set the cluster state Set the cluster state
...@@ -1373,6 +1371,7 @@ class LastTransaction(Packet): ...@@ -1373,6 +1371,7 @@ class LastTransaction(Packet):
Answer last committed TID. Answer last committed TID.
M -> C M -> C
""" """
poll_thread = True
_answer = PStruct('answer_last_transaction', _answer = PStruct('answer_last_transaction',
PTID('tid'), PTID('tid'),
...@@ -1521,6 +1520,7 @@ def register(request, ignore_when_closed=None): ...@@ -1521,6 +1520,7 @@ def register(request, ignore_when_closed=None):
# build a class for the answer # build a class for the answer
answer = type('Answer%s' % (request.__name__, ), (Packet, ), {}) answer = type('Answer%s' % (request.__name__, ), (Packet, ), {})
answer._fmt = request._answer answer._fmt = request._answer
answer.poll_thread = request.poll_thread
# compute the answer code # compute the answer code
code = code | RESPONSE_MASK code = code | RESPONSE_MASK
answer._request = request answer._request = request
...@@ -1673,8 +1673,6 @@ class Packets(dict): ...@@ -1673,8 +1673,6 @@ class Packets(dict):
AddPendingNodes, ignore_when_closed=False) AddPendingNodes, ignore_when_closed=False)
TweakPartitionTable = register( TweakPartitionTable = register(
TweakPartitionTable, ignore_when_closed=False) TweakPartitionTable, ignore_when_closed=False)
AskNodeInformation, AnswerNodeInformation = register(
NodeInformation)
SetClusterState = register( SetClusterState = register(
SetClusterState, ignore_when_closed=False) SetClusterState, ignore_when_closed=False)
NotifyClusterInformation = register( NotifyClusterInformation = register(
......
...@@ -114,7 +114,6 @@ class BackupApplication(object): ...@@ -114,7 +114,6 @@ class BackupApplication(object):
raise RuntimeError("inconsistent number of partitions") raise RuntimeError("inconsistent number of partitions")
self.pt = PartitionTable(num_partitions, num_replicas) self.pt = PartitionTable(num_partitions, num_replicas)
conn.setHandler(BackupHandler(self)) conn.setHandler(BackupHandler(self))
conn.ask(Packets.AskNodeInformation())
conn.ask(Packets.AskPartitionTable()) conn.ask(Packets.AskPartitionTable())
conn.ask(Packets.AskLastTransaction()) conn.ask(Packets.AskLastTransaction())
# debug variable to log how big 'tid_list' can be. # debug variable to log how big 'tid_list' can be.
......
...@@ -27,6 +27,8 @@ class MasterHandler(EventHandler): ...@@ -27,6 +27,8 @@ class MasterHandler(EventHandler):
def connectionCompleted(self, conn, new=None): def connectionCompleted(self, conn, new=None):
if new is None: if new is None:
super(MasterHandler, self).connectionCompleted(conn) super(MasterHandler, self).connectionCompleted(conn)
elif new:
self._notifyNodeInformation(conn)
def requestIdentification(self, conn, node_type, uuid, address, name): def requestIdentification(self, conn, node_type, uuid, address, name):
self.checkClusterName(name) self.checkClusterName(name)
...@@ -88,10 +90,6 @@ class MasterHandler(EventHandler): ...@@ -88,10 +90,6 @@ class MasterHandler(EventHandler):
node_list.extend(n.asTuple() for n in nm.getStorageList()) node_list.extend(n.asTuple() for n in nm.getStorageList())
conn.notify(Packets.NotifyNodeInformation(node_list)) conn.notify(Packets.NotifyNodeInformation(node_list))
def askNodeInformation(self, conn):
self._notifyNodeInformation(conn)
conn.answer(Packets.AnswerNodeInformation())
def askPartitionTable(self, conn): def askPartitionTable(self, conn):
pt = self.app.pt pt = self.app.pt
conn.answer(Packets.AnswerPartitionTable(pt.getID(), pt.getRowList())) conn.answer(Packets.AnswerPartitionTable(pt.getID(), pt.getRowList()))
......
...@@ -31,9 +31,6 @@ class BackupHandler(EventHandler): ...@@ -31,9 +31,6 @@ class BackupHandler(EventHandler):
def notifyPartitionChanges(self, conn, ptid, cell_list): def notifyPartitionChanges(self, conn, ptid, cell_list):
self.app.pt.update(ptid, cell_list, self.app.nm) self.app.pt.update(ptid, cell_list, self.app.nm)
def answerNodeInformation(self, conn):
pass
def notifyNodeInformation(self, conn, node_list): def notifyNodeInformation(self, conn, node_list):
self.app.nm.update(node_list) self.app.nm.update(node_list)
......
...@@ -31,14 +31,13 @@ class ClientServiceHandler(MasterHandler): ...@@ -31,14 +31,13 @@ class ClientServiceHandler(MasterHandler):
app.broadcastNodesInformation([node]) app.broadcastNodesInformation([node])
app.nm.remove(node) app.nm.remove(node)
def askNodeInformation(self, conn): def _notifyNodeInformation(self, conn):
# send informations about master and storages only # send informations about master and storages only
nm = self.app.nm nm = self.app.nm
node_list = [] node_list = []
node_list.extend(n.asTuple() for n in nm.getMasterList()) node_list.extend(n.asTuple() for n in nm.getMasterList())
node_list.extend(n.asTuple() for n in nm.getStorageList()) node_list.extend(n.asTuple() for n in nm.getStorageList())
conn.notify(Packets.NotifyNodeInformation(node_list)) conn.notify(Packets.NotifyNodeInformation(node_list))
conn.answer(Packets.AnswerNodeInformation())
def askBeginTransaction(self, conn, tid): def askBeginTransaction(self, conn, tid):
""" """
......
...@@ -23,6 +23,9 @@ from . import MasterHandler ...@@ -23,6 +23,9 @@ from . import MasterHandler
class BaseElectionHandler(EventHandler): class BaseElectionHandler(EventHandler):
def _notifyNodeInformation(self, conn):
pass
def reelectPrimary(self, conn): def reelectPrimary(self, conn):
raise ElectionFailure, 'reelection requested' raise ElectionFailure, 'reelection requested'
......
...@@ -27,13 +27,11 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -27,13 +27,11 @@ class StorageServiceHandler(BaseServiceHandler):
def connectionCompleted(self, conn, new): def connectionCompleted(self, conn, new):
app = self.app app = self.app
uuid = conn.getUUID() uuid = conn.getUUID()
node = app.nm.getByUUID(uuid)
app.setStorageNotReady(uuid) app.setStorageNotReady(uuid)
if new: if new:
super(StorageServiceHandler, self).connectionCompleted(conn, new) super(StorageServiceHandler, self).connectionCompleted(conn, new)
# XXX: what other values could happen ? if app.nm.getByUUID(uuid).isRunning(): # node may be PENDING
if node.isRunning(): conn.notify(Packets.StartOperation(app.backup_tid))
conn.notify(Packets.StartOperation(bool(app.backup_tid)))
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
app = self.app app = self.app
......
...@@ -20,9 +20,6 @@ from neo.lib.protocol import Packets, ProtocolError, ZERO_TID ...@@ -20,9 +20,6 @@ from neo.lib.protocol import Packets, ProtocolError, ZERO_TID
class InitializationHandler(BaseMasterHandler): class InitializationHandler(BaseMasterHandler):
def answerNodeInformation(self, conn):
pass
def sendPartitionTable(self, conn, ptid, row_list): def sendPartitionTable(self, conn, ptid, row_list):
app = self.app app = self.app
pt = app.pt pt = app.pt
......
...@@ -753,11 +753,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -753,11 +753,7 @@ class ClientApplicationTests(NeoUnitTestBase):
# will raise IndexError at the third iteration # will raise IndexError at the third iteration
app = self.getApp('127.0.0.1:10010 127.0.0.1:10011') app = self.getApp('127.0.0.1:10010 127.0.0.1:10011')
# TODO: test more connection failure cases # TODO: test more connection failure cases
all_passed = []
# askLastTransaction # askLastTransaction
def _ask9(_):
all_passed.append(1)
# Seventh packet : askNodeInformation succeeded
def _ask8(_): def _ask8(_):
pass pass
# Sixth packet : askPartitionTable succeeded # Sixth packet : askPartitionTable succeeded
...@@ -789,8 +785,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -789,8 +785,7 @@ class ClientApplicationTests(NeoUnitTestBase):
# telling us what its address is.) # telling us what its address is.)
def _ask1(_): def _ask1(_):
pass pass
ask_func_list = [_ask1, _ask2, _ask3, _ask4, _ask6, _ask7, ask_func_list = [_ask1, _ask2, _ask3, _ask4, _ask6, _ask7, _ask8]
_ask8, _ask9]
def _ask_base(conn, _, handler=None): def _ask_base(conn, _, handler=None):
ask_func_list.pop(0)(conn) ask_func_list.pop(0)(conn)
app.nm.getByAddress(conn.getAddress())._connection = None app.nm.getByAddress(conn.getAddress())._connection = None
...@@ -801,7 +796,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -801,7 +796,7 @@ class ClientApplicationTests(NeoUnitTestBase):
app.pt = Mock({ 'operational': False}) app.pt = Mock({ 'operational': False})
app.start = lambda: None app.start = lambda: None
app.master_conn = app._connectToPrimaryNode() app.master_conn = app._connectToPrimaryNode()
self.assertEqual(len(all_passed), 1) self.assertFalse(ask_func_list)
self.assertTrue(app.master_conn is not None) self.assertTrue(app.master_conn is not None)
self.assertTrue(app.pt.operational()) self.assertTrue(app.pt.operational())
......
...@@ -44,69 +44,6 @@ class MasterHandlerTests(NeoUnitTestBase): ...@@ -44,69 +44,6 @@ class MasterHandlerTests(NeoUnitTestBase):
node.setConnection(conn) node.setConnection(conn)
return node, conn return node, conn
class MasterBootstrapHandlerTests(MasterHandlerTests):
def setUp(self):
super(MasterBootstrapHandlerTests, self).setUp()
self.handler = PrimaryBootstrapHandler(self.app)
def checkCalledOnApp(self, method, index=0):
calls = self.app.mockGetNamedCalls(method)
self.assertTrue(len(calls) > index)
return calls[index].params
def test_notReady(self):
conn = self.getFakeConnection()
self.handler.notReady(conn, 'message')
self.assertEqual(self.app.trying_master_node, None)
def test_acceptIdentification1(self):
""" Non-master node """
node, conn = self.getKnownMaster()
self.handler.acceptIdentification(conn, NodeTypes.CLIENT,
node.getUUID(), 100, 0, None, None, [])
self.checkClosed(conn)
def test_acceptIdentification2(self):
""" No UUID supplied """
node, conn = self.getKnownMaster()
uuid = self.getMasterUUID()
addr = conn.getAddress()
self.checkProtocolErrorRaised(self.handler.acceptIdentification,
conn, NodeTypes.MASTER, uuid, 100, 0, None,
addr, [(addr, uuid)],
)
def test_acceptIdentification3(self):
""" identification accepted """
node, conn = self.getKnownMaster()
uuid = self.getMasterUUID()
addr = conn.getAddress()
your_uuid = self.getClientUUID()
self.handler.acceptIdentification(conn, NodeTypes.MASTER, uuid,
100, 2, your_uuid, addr, [(addr, uuid)])
self.assertEqual(self.app.uuid, your_uuid)
self.assertEqual(node.getUUID(), uuid)
self.assertTrue(isinstance(self.app.pt, PartitionTable))
def _getMasterList(self, uuid_list):
port = 1000
master_list = []
for uuid in uuid_list:
master_list.append((('127.0.0.1', port), uuid))
port += 1
return master_list
def test_answerPartitionTable(self):
conn = self.getFakeConnection()
self.app.pt = Mock()
ptid = 0
row_list = ([], [])
self.handler.answerPartitionTable(conn, ptid, row_list)
load_calls = self.app.pt.mockGetNamedCalls('load')
self.assertEqual(len(load_calls), 1)
# load_calls[0].checkArgs(ptid, row_list, self.app.nm)
class MasterNotificationsHandlerTests(MasterHandlerTests): class MasterNotificationsHandlerTests(MasterHandlerTests):
......
...@@ -144,18 +144,6 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -144,18 +144,6 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.assertEqual(len(txn.getOIDList()), 0) self.assertEqual(len(txn.getOIDList()), 0)
self.assertEqual(len(txn.getUUIDList()), 1) self.assertEqual(len(txn.getUUIDList()), 1)
def test_askNodeInformations(self):
# check that only informations about master and storages nodes are
# send to a client
self.app.nm.createClient()
conn = self.getFakeConnection()
self.service.askNodeInformation(conn)
calls = conn.mockGetNamedCalls('notify')
self.assertEqual(len(calls), 1)
packet = calls[0].getParam(0)
(node_list, ) = packet.decode()
self.assertEqual(len(node_list), 2)
def test_connectionClosed(self): def test_connectionClosed(self):
# give a client uuid which have unfinished transactions # give a client uuid which have unfinished transactions
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment