Commit 9385706f authored by Julien Muchembled's avatar Julien Muchembled

Fix identification issues, including a race condition causing id conflicts

The added test describes how the new id timestamps fix the race condition.
These timestamps could be any unique opaque values, and the protocol is
extended to exchange them along with node ids.

Internally, nodes also reuse timestamps as a marker to identify the first
NotifyNodeInformation packets from the master: since this packet is a complete
list of nodes in the cluster, any other node in the node manager has left the
cluster definitely and is removed.

The secondary masters didn't receive update about master nodes.
It's also useless to send them information about non-master nodes.
parent d048a52d
...@@ -105,13 +105,9 @@ class Application(BaseApplication): ...@@ -105,13 +105,9 @@ class Application(BaseApplication):
""" """
self.cluster_state = None self.cluster_state = None
# search, find, connect and identify to the primary master # search, find, connect and identify to the primary master
bootstrap = BootstrapManager(self, self.name, NodeTypes.ADMIN, bootstrap = BootstrapManager(self, NodeTypes.ADMIN, self.server)
self.uuid, self.server) self.master_node, self.master_conn, num_partitions, num_replicas = \
data = bootstrap.getPrimaryConnection() bootstrap.getPrimaryConnection()
(node, conn, uuid, num_partitions, num_replicas) = data
self.master_node = node
self.master_conn = conn
self.uuid = uuid
if self.pt is None: if self.pt is None:
self.pt = PartitionTable(num_partitions, num_replicas) self.pt = PartitionTable(num_partitions, num_replicas)
......
...@@ -120,8 +120,6 @@ class MasterEventHandler(EventHandler): ...@@ -120,8 +120,6 @@ class MasterEventHandler(EventHandler):
def notifyClusterInformation(self, conn, cluster_state): def notifyClusterInformation(self, conn, cluster_state):
self.app.cluster_state = cluster_state self.app.cluster_state = cluster_state
def notifyNodeInformation(self, conn, node_list):
self.app.nm.update(node_list)
class MasterRequestEventHandler(EventHandler): class MasterRequestEventHandler(EventHandler):
""" This class handle all answer from primary master node""" """ This class handle all answer from primary master node"""
......
...@@ -240,10 +240,10 @@ class Application(ThreadedApplication): ...@@ -240,10 +240,10 @@ class Application(ThreadedApplication):
self.notifications_handler, self.notifications_handler,
node=self.trying_master_node, node=self.trying_master_node,
dispatcher=self.dispatcher) dispatcher=self.dispatcher)
p = Packets.RequestIdentification(
NodeTypes.CLIENT, self.uuid, None, self.name, None)
try: try:
ask(conn, Packets.RequestIdentification( ask(conn, p, handler=handler)
NodeTypes.CLIENT, self.uuid, None, self.name),
handler=handler)
except ConnectionClosed: except ConnectionClosed:
continue continue
# If we reached the primary master node, mark as connected # If we reached the primary master node, mark as connected
......
...@@ -87,6 +87,7 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -87,6 +87,7 @@ class PrimaryNotificationsHandler(MTEventHandler):
raise ProtocolError('No UUID supplied') raise ProtocolError('No UUID supplied')
app.uuid = your_uuid app.uuid = your_uuid
logging.info('Got an UUID: %s', dump(app.uuid)) logging.info('Got an UUID: %s', dump(app.uuid))
app.id_timestamp = None
# Always create partition table # Always create partition table
app.pt = PartitionTable(num_partitions, num_replicas) app.pt = PartitionTable(num_partitions, num_replicas)
...@@ -179,13 +180,14 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -179,13 +180,14 @@ class PrimaryNotificationsHandler(MTEventHandler):
self.app.pt.update(ptid, cell_list, self.app.nm) self.app.pt.update(ptid, cell_list, self.app.nm)
def notifyNodeInformation(self, conn, node_list): def notifyNodeInformation(self, conn, node_list):
nm = self.app.nm super(PrimaryNotificationsHandler, self).notifyNodeInformation(
nm.update(node_list) conn, node_list)
# XXX: 'update' automatically closes DOWN nodes. Do we really want # XXX: 'update' automatically closes DOWN nodes. Do we really want
# to do the same thing for nodes in other non-running states ? # to do the same thing for nodes in other non-running states ?
for node_type, addr, uuid, state in node_list: getByUUID = self.app.nm.getByUUID
if state != NodeStates.RUNNING: for node in node_list:
node = nm.getByUUID(uuid) if node[3] != NodeStates.RUNNING:
node = getByUUID(node[2])
if node and node.isConnected(): if node and node.isConnected():
node.getConnection().close() node.getConnection().close()
......
...@@ -57,7 +57,7 @@ class ConnectionPool(object): ...@@ -57,7 +57,7 @@ class ConnectionPool(object):
conn = MTClientConnection(app, app.storage_event_handler, node, conn = MTClientConnection(app, app.storage_event_handler, node,
dispatcher=app.dispatcher) dispatcher=app.dispatcher)
p = Packets.RequestIdentification(NodeTypes.CLIENT, p = Packets.RequestIdentification(NodeTypes.CLIENT,
app.uuid, None, app.name) app.uuid, None, app.name, app.id_timestamp)
try: try:
app._ask(conn, p, handler=app.storage_bootstrap_handler) app._ask(conn, p, handler=app.storage_bootstrap_handler)
except ConnectionClosed: except ConnectionClosed:
......
...@@ -26,7 +26,7 @@ class BootstrapManager(EventHandler): ...@@ -26,7 +26,7 @@ class BootstrapManager(EventHandler):
""" """
accepted = False accepted = False
def __init__(self, app, name, node_type, uuid=None, server=None): def __init__(self, app, node_type, server=None):
""" """
Manage the bootstrap stage of a non-master node, it lookup for the Manage the bootstrap stage of a non-master node, it lookup for the
primary master node, connect to it then returns when the master node primary master node, connect to it then returns when the master node
...@@ -35,12 +35,12 @@ class BootstrapManager(EventHandler): ...@@ -35,12 +35,12 @@ class BootstrapManager(EventHandler):
self.primary = None self.primary = None
self.server = server self.server = server
self.node_type = node_type self.node_type = node_type
self.uuid = uuid
self.name = name
self.num_replicas = None self.num_replicas = None
self.num_partitions = None self.num_partitions = None
self.current = None self.current = None
uuid = property(lambda self: self.app.uuid)
def announcePrimary(self, conn): def announcePrimary(self, conn):
# We found the primary master early enough to be notified of election # We found the primary master early enough to be notified of election
# end. Lucky. Anyway, we must carry on with identification request, so # end. Lucky. Anyway, we must carry on with identification request, so
...@@ -55,7 +55,7 @@ class BootstrapManager(EventHandler): ...@@ -55,7 +55,7 @@ class BootstrapManager(EventHandler):
EventHandler.connectionCompleted(self, conn) EventHandler.connectionCompleted(self, conn)
self.current.setRunning() self.current.setRunning()
conn.ask(Packets.RequestIdentification(self.node_type, self.uuid, conn.ask(Packets.RequestIdentification(self.node_type, self.uuid,
self.server, self.name)) self.server, self.app.name, None))
def connectionFailed(self, conn): def connectionFailed(self, conn):
""" """
...@@ -106,8 +106,9 @@ class BootstrapManager(EventHandler): ...@@ -106,8 +106,9 @@ class BootstrapManager(EventHandler):
self.num_replicas = num_replicas self.num_replicas = num_replicas
if self.uuid != your_uuid: if self.uuid != your_uuid:
# got an uuid from the primary master # got an uuid from the primary master
self.uuid = your_uuid self.app.uuid = your_uuid
logging.info('Got a new UUID: %s', uuid_str(self.uuid)) logging.info('Got a new UUID: %s', uuid_str(self.uuid))
self.app.id_timestamp = None
self.accepted = True self.accepted = True
def getPrimaryConnection(self): def getPrimaryConnection(self):
...@@ -141,8 +142,4 @@ class BootstrapManager(EventHandler): ...@@ -141,8 +142,4 @@ class BootstrapManager(EventHandler):
continue continue
# still processing # still processing
poll(1) poll(1)
return (self.current, conn, self.uuid, self.num_partitions, return self.current, conn, self.num_partitions, self.num_replicas
self.num_replicas)
...@@ -165,6 +165,10 @@ class EventHandler(object): ...@@ -165,6 +165,10 @@ class EventHandler(object):
return return
conn.close() conn.close()
def notifyNodeInformation(self, conn, node_list):
app = self.app
app.nm.update(app, node_list)
def ping(self, conn): def ping(self, conn):
conn.answer(Packets.Pong()) conn.answer(Packets.Pong())
......
...@@ -27,6 +27,7 @@ class Node(object): ...@@ -27,6 +27,7 @@ class Node(object):
_connection = None _connection = None
_identified = False _identified = False
id_timestamp = None
def __init__(self, manager, address=None, uuid=None, def __init__(self, manager, address=None, uuid=None,
state=NodeStates.UNKNOWN): state=NodeStates.UNKNOWN):
...@@ -172,7 +173,8 @@ class Node(object): ...@@ -172,7 +173,8 @@ class Node(object):
def asTuple(self): def asTuple(self):
""" Returned tuple is intended to be used in protocol encoders """ """ Returned tuple is intended to be used in protocol encoders """
return (self.getType(), self._address, self._uuid, self._state) return (self.getType(), self._address, self._uuid, self._state,
self.id_timestamp)
def __gt__(self, node): def __gt__(self, node):
# sort per UUID if defined # sort per UUID if defined
...@@ -348,9 +350,11 @@ class NodeManager(object): ...@@ -348,9 +350,11 @@ class NodeManager(object):
""" Return the node that match with a given address """ """ Return the node that match with a given address """
return self._address_dict.get(address, None) return self._address_dict.get(address, None)
def getByUUID(self, uuid): def getByUUID(self, uuid, *id_timestamp):
""" Return the node that match with a given UUID """ """ Return the node that match with a given UUID """
return self._uuid_dict.get(uuid, None) node = self._uuid_dict.get(uuid)
if not id_timestamp or node and (node.id_timestamp,) == id_timestamp:
return node
def _createNode(self, klass, address=None, uuid=None, **kw): def _createNode(self, klass, address=None, uuid=None, **kw):
by_address = self.getByAddress(address) by_address = self.getByAddress(address)
...@@ -386,8 +390,9 @@ class NodeManager(object): ...@@ -386,8 +390,9 @@ class NodeManager(object):
def createFromNodeType(self, node_type, **kw): def createFromNodeType(self, node_type, **kw):
return self._createNode(NODE_TYPE_MAPPING[node_type], **kw) return self._createNode(NODE_TYPE_MAPPING[node_type], **kw)
def update(self, node_list): def update(self, app, node_list):
for node_type, addr, uuid, state in node_list: node_set = self._node_set.copy() if app.id_timestamp is None else None
for node_type, addr, uuid, state, id_timestamp in node_list:
# This should be done here (although klass might not be used in this # This should be done here (although klass might not be used in this
# iteration), as it raises if type is not valid. # iteration), as it raises if type is not valid.
klass = NODE_TYPE_MAPPING[node_type] klass = NODE_TYPE_MAPPING[node_type]
...@@ -397,11 +402,11 @@ class NodeManager(object): ...@@ -397,11 +402,11 @@ class NodeManager(object):
node_by_addr = self.getByAddress(addr) node_by_addr = self.getByAddress(addr)
node = node_by_uuid or node_by_addr node = node_by_uuid or node_by_addr
log_args = node_type, uuid_str(uuid), addr, state log_args = node_type, uuid_str(uuid), addr, state, id_timestamp
if node is None: if node is None:
if state == NodeStates.DOWN: if state == NodeStates.DOWN:
logging.debug('NOT creating node %s %s %s %s', *log_args) logging.debug('NOT creating node %s %s %s %s %s', *log_args)
else: continue
node = self._createNode(klass, address=addr, uuid=uuid, node = self._createNode(klass, address=addr, uuid=uuid,
state=state) state=state)
logging.debug('creating node %r', node) logging.debug('creating node %r', node)
...@@ -414,7 +419,7 @@ class NodeManager(object): ...@@ -414,7 +419,7 @@ class NodeManager(object):
'node_by_addr (%r)' % (node_by_uuid, node_by_addr) 'node_by_addr (%r)' % (node_by_uuid, node_by_addr)
if state == NodeStates.DOWN: if state == NodeStates.DOWN:
logging.debug('dropping node %r (%r), found with %s ' logging.debug('dropping node %r (%r), found with %s '
'%s %s %s', node, node.isConnected(), *log_args) '%s %s %s %s', node, node.isConnected(), *log_args)
if node.isConnected(): if node.isConnected():
# Cut this connection, node removed by handler. # Cut this connection, node removed by handler.
# It's important for a storage to disconnect nodes that # It's important for a storage to disconnect nodes that
...@@ -424,12 +429,20 @@ class NodeManager(object): ...@@ -424,12 +429,20 @@ class NodeManager(object):
# partition table upon disconnection. # partition table upon disconnection.
node.getConnection().close() node.getConnection().close()
self.remove(node) self.remove(node)
else: continue
logging.debug('updating node %r to %s %s %s %s', logging.debug('updating node %r to %s %s %s %s %s',
node, *log_args) node, *log_args)
node.setUUID(uuid) node.setUUID(uuid)
node.setAddress(addr) node.setAddress(addr)
node.setState(state) node.setState(state)
node.id_timestamp = id_timestamp
if app.uuid == uuid:
app.id_timestamp = id_timestamp
if node_set:
# For the first notification, we receive a full list of nodes from
# the master. Remove all unknown nodes from a previous connection.
for node in node_set - self._node_set:
self.remove(node)
self.log() self.log()
def log(self): def log(self):
......
...@@ -595,6 +595,13 @@ class PTID(PItem): ...@@ -595,6 +595,13 @@ class PTID(PItem):
# same definition, for now # same definition, for now
POID = PTID POID = PTID
class PFloat(PStructItemOrNone):
"""
A float number (8-bytes length)
"""
_fmt = '!d'
_None = '\xff' * 8
# common definitions # common definitions
PFEmpty = PStruct('no_content') PFEmpty = PStruct('no_content')
...@@ -608,6 +615,7 @@ PFNodeList = PList('node_list', ...@@ -608,6 +615,7 @@ PFNodeList = PList('node_list',
PAddress('address'), PAddress('address'),
PUUID('uuid'), PUUID('uuid'),
PFNodeState, PFNodeState,
PFloat('id_timestamp'),
), ),
) )
...@@ -689,6 +697,7 @@ class RequestIdentification(Packet): ...@@ -689,6 +697,7 @@ class RequestIdentification(Packet):
PUUID('uuid'), PUUID('uuid'),
PAddress('address'), PAddress('address'),
PString('name'), PString('name'),
PFloat('id_timestamp'),
) )
_answer = PStruct('accept_identification', _answer = PStruct('accept_identification',
......
...@@ -43,6 +43,8 @@ class ThreadContainer(threading.local): ...@@ -43,6 +43,8 @@ class ThreadContainer(threading.local):
class ThreadedApplication(BaseApplication): class ThreadedApplication(BaseApplication):
"""The client node application.""" """The client node application."""
uuid = None
def __init__(self, master_nodes, name, **kw): def __init__(self, master_nodes, name, **kw):
super(ThreadedApplication, self).__init__(**kw) super(ThreadedApplication, self).__init__(**kw)
self.poll_thread = threading.Thread(target=self.run, name=name) self.poll_thread = threading.Thread(target=self.run, name=name)
...@@ -56,8 +58,6 @@ class ThreadedApplication(BaseApplication): ...@@ -56,8 +58,6 @@ class ThreadedApplication(BaseApplication):
for address in master_nodes: for address in master_nodes:
self.nm.createMaster(address=address) self.nm.createMaster(address=address)
# no self-assigned UUID, primary master will supply us one
self.uuid = None
# Internal attribute distinct between thread # Internal attribute distinct between thread
self._thread_container = ThreadContainer() self._thread_container = ThreadContainer()
app_set.add(self) # to register self.on_log app_set.add(self) # to register self.on_log
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import sys, weakref import sys, weakref
from collections import defaultdict
from time import time from time import time
from neo.lib import logging from neo.lib import logging
...@@ -44,7 +45,6 @@ class Application(BaseApplication): ...@@ -44,7 +45,6 @@ class Application(BaseApplication):
last_transaction = ZERO_TID last_transaction = ZERO_TID
backup_tid = None backup_tid = None
backup_app = None backup_app = None
uuid = None
truncate_tid = None truncate_tid = None
def __init__(self, config): def __init__(self, config):
...@@ -79,9 +79,7 @@ class Application(BaseApplication): ...@@ -79,9 +79,7 @@ class Application(BaseApplication):
self.primary_master_node = None self.primary_master_node = None
self.cluster_state = None self.cluster_state = None
uuid = config.getUUID() self.uuid = config.getUUID()
if uuid:
self.uuid = uuid
# election related data # election related data
self.unconnected_master_node_set = set() self.unconnected_master_node_set = set()
...@@ -227,19 +225,20 @@ class Application(BaseApplication): ...@@ -227,19 +225,20 @@ class Application(BaseApplication):
Broadcast changes for a set a nodes Broadcast changes for a set a nodes
Send only one packet per connection to reduce bandwidth Send only one packet per connection to reduce bandwidth
""" """
node_dict = {} node_dict = defaultdict(list)
# group modified nodes by destination node type # group modified nodes by destination node type
for node in node_list: for node in node_list:
node_info = node.asTuple() node_info = node.asTuple()
def assign_for_notification(node_type): if node.isAdmin():
# helper function continue
node_dict.setdefault(node_type, []).append(node_info) node_dict[NodeTypes.ADMIN].append(node_info)
if node.isMaster() or node.isStorage(): node_dict[NodeTypes.STORAGE].append(node_info)
# client get notifications for master and storage only if node.isClient():
assign_for_notification(NodeTypes.CLIENT) continue
if node.isMaster() or node.isStorage() or node.isClient(): node_dict[NodeTypes.CLIENT].append(node_info)
assign_for_notification(NodeTypes.STORAGE) if node.isStorage():
assign_for_notification(NodeTypes.ADMIN) continue
node_dict[NodeTypes.MASTER].append(node_info)
# send at most one non-empty notification packet per node # send at most one non-empty notification packet per node
for node in self.nm.getIdentifiedList(): for node in self.nm.getIdentifiedList():
...@@ -498,7 +497,7 @@ class Application(BaseApplication): ...@@ -498,7 +497,7 @@ class Application(BaseApplication):
conn.setHandler(handler) conn.setHandler(handler)
conn.notify(Packets.NotifyNodeInformation((( conn.notify(Packets.NotifyNodeInformation(((
node.getType(), node.getAddress(), node.getUUID(), node.getType(), node.getAddress(), node.getUUID(),
NodeStates.TEMPORARILY_DOWN),))) NodeStates.TEMPORARILY_DOWN, None),)))
conn.abort() conn.abort()
elif conn.pending(): elif conn.pending():
conn.abort() conn.abort()
......
...@@ -65,6 +65,7 @@ There is no UUID conflict between the 2 clusters: ...@@ -65,6 +65,7 @@ There is no UUID conflict between the 2 clusters:
class BackupApplication(object): class BackupApplication(object):
pt = None pt = None
uuid = None
def __init__(self, app, name, master_addresses): def __init__(self, app, name, master_addresses):
self.app = weakref.proxy(app) self.app = weakref.proxy(app)
...@@ -92,7 +93,7 @@ class BackupApplication(object): ...@@ -92,7 +93,7 @@ class BackupApplication(object):
pt = app.pt pt = app.pt
while True: while True:
app.changeClusterState(ClusterStates.STARTING_BACKUP) app.changeClusterState(ClusterStates.STARTING_BACKUP)
bootstrap = BootstrapManager(self, self.name, NodeTypes.CLIENT) bootstrap = BootstrapManager(self, NodeTypes.CLIENT)
# {offset -> node} # {offset -> node}
self.primary_partition_dict = {} self.primary_partition_dict = {}
# [[tid]] # [[tid]]
...@@ -105,7 +106,7 @@ class BackupApplication(object): ...@@ -105,7 +106,7 @@ class BackupApplication(object):
else: else:
break break
poll(1) poll(1)
node, conn, uuid, num_partitions, num_replicas = \ node, conn, num_partitions, num_replicas = \
bootstrap.getPrimaryConnection() bootstrap.getPrimaryConnection()
try: try:
app.changeClusterState(ClusterStates.BACKINGUP) app.changeClusterState(ClusterStates.BACKINGUP)
......
...@@ -30,7 +30,7 @@ class MasterHandler(EventHandler): ...@@ -30,7 +30,7 @@ class MasterHandler(EventHandler):
elif new: elif new:
self._notifyNodeInformation(conn) self._notifyNodeInformation(conn)
def requestIdentification(self, conn, node_type, uuid, address, name): def requestIdentification(self, conn, node_type, uuid, address, name, _):
self.checkClusterName(name) self.checkClusterName(name)
app = self.app app = self.app
node = app.nm.getByUUID(uuid) node = app.nm.getByUUID(uuid)
......
...@@ -31,9 +31,6 @@ class BackupHandler(EventHandler): ...@@ -31,9 +31,6 @@ class BackupHandler(EventHandler):
def notifyPartitionChanges(self, conn, ptid, cell_list): def notifyPartitionChanges(self, conn, ptid, cell_list):
self.app.pt.update(ptid, cell_list, self.app.nm) self.app.pt.update(ptid, cell_list, self.app.nm)
def notifyNodeInformation(self, conn, node_list):
self.app.nm.update(node_list)
def answerLastTransaction(self, conn, tid): def answerLastTransaction(self, conn, tid):
app = self.app app = self.app
if tid != ZERO_TID: if tid != ZERO_TID:
......
...@@ -32,9 +32,8 @@ class ClientServiceHandler(MasterHandler): ...@@ -32,9 +32,8 @@ class ClientServiceHandler(MasterHandler):
app.nm.remove(node) app.nm.remove(node)
def _notifyNodeInformation(self, conn): def _notifyNodeInformation(self, conn):
# send informations about master and storages only
nm = self.app.nm nm = self.app.nm
node_list = [] node_list = [nm.getByUUID(conn.getUUID()).asTuple()] # for id_timestamp
node_list.extend(n.asTuple() for n in nm.getMasterList()) node_list.extend(n.asTuple() for n in nm.getMasterList())
node_list.extend(n.asTuple() for n in nm.getStorageList()) node_list.extend(n.asTuple() for n in nm.getStorageList())
conn.notify(Packets.NotifyNodeInformation(node_list)) conn.notify(Packets.NotifyNodeInformation(node_list))
......
...@@ -56,6 +56,11 @@ class BaseElectionHandler(EventHandler): ...@@ -56,6 +56,11 @@ class BaseElectionHandler(EventHandler):
class ClientElectionHandler(BaseElectionHandler): class ClientElectionHandler(BaseElectionHandler):
def notifyNodeInformation(self, conn, node_list):
# XXX: For the moment, do nothing because
# we'll close this connection and reconnect.
pass
def connectionFailed(self, conn): def connectionFailed(self, conn):
addr = conn.getAddress() addr = conn.getAddress()
node = self.app.nm.getByAddress(addr) node = self.app.nm.getByAddress(addr)
...@@ -71,6 +76,7 @@ class ClientElectionHandler(BaseElectionHandler): ...@@ -71,6 +76,7 @@ class ClientElectionHandler(BaseElectionHandler):
app.uuid, app.uuid,
app.server, app.server,
app.name, app.name,
None,
)) ))
super(ClientElectionHandler, self).connectionCompleted(conn) super(ClientElectionHandler, self).connectionCompleted(conn)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from time import time
from neo.lib import logging from neo.lib import logging
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, \ from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, \
NotReadyError, ProtocolError, uuid_str NotReadyError, ProtocolError, uuid_str
...@@ -91,6 +92,7 @@ class IdentificationHandler(MasterHandler): ...@@ -91,6 +92,7 @@ class IdentificationHandler(MasterHandler):
uuid=uuid, address=address) uuid=uuid, address=address)
else: else:
node.setUUID(uuid) node.setUUID(uuid)
node.id_timestamp = time()
node.setState(state) node.setState(state)
node.setConnection(conn) node.setConnection(conn)
conn.setHandler(handler) conn.setHandler(handler)
......
...@@ -36,6 +36,10 @@ class SecondaryMasterHandler(MasterHandler): ...@@ -36,6 +36,10 @@ class SecondaryMasterHandler(MasterHandler):
def reelectPrimary(self, conn): def reelectPrimary(self, conn):
raise ElectionFailure, 'reelection requested' raise ElectionFailure, 'reelection requested'
def _notifyNodeInformation(self, conn):
node_list = [n.asTuple() for n in self.app.nm.getMasterList()]
conn.notify(Packets.NotifyNodeInformation(node_list))
class PrimaryHandler(EventHandler): class PrimaryHandler(EventHandler):
""" Handler used by secondaries to handle primary master""" """ Handler used by secondaries to handle primary master"""
...@@ -58,6 +62,7 @@ class PrimaryHandler(EventHandler): ...@@ -58,6 +62,7 @@ class PrimaryHandler(EventHandler):
app.uuid, app.uuid,
app.server, app.server,
app.name, app.name,
None,
)) ))
super(PrimaryHandler, self).connectionCompleted(conn) super(PrimaryHandler, self).connectionCompleted(conn)
...@@ -68,27 +73,11 @@ class PrimaryHandler(EventHandler): ...@@ -68,27 +73,11 @@ class PrimaryHandler(EventHandler):
self.app.cluster_state = state self.app.cluster_state = state
def notifyNodeInformation(self, conn, node_list): def notifyNodeInformation(self, conn, node_list):
app = self.app super(PrimaryHandler, self).notifyNodeInformation(conn, node_list)
for node_type, addr, uuid, state in node_list: for node_type, _, uuid, state, _ in node_list:
if node_type != NodeTypes.MASTER: assert node_type == NodeTypes.MASTER, node_type
# No interest. if uuid == self.app.uuid and state == NodeStates.UNKNOWN:
continue
if uuid == app.uuid and state == NodeStates.UNKNOWN:
sys.exit() sys.exit()
# Register new master nodes.
if app.server == addr:
# This is self.
continue
else:
n = app.nm.getByAddress(addr)
# master node must be known
assert n is not None
if uuid is not None:
# If I don't know the UUID yet, believe what the peer
# told me at the moment.
if n.getUUID() is None:
n.setUUID(uuid)
def _acceptIdentification(self, node, uuid, num_partitions, def _acceptIdentification(self, node, uuid, num_partitions,
num_replicas, your_uuid, primary, known_master_list): num_replicas, your_uuid, primary, known_master_list):
...@@ -101,4 +90,5 @@ class PrimaryHandler(EventHandler): ...@@ -101,4 +90,5 @@ class PrimaryHandler(EventHandler):
logging.info('My UUID: ' + uuid_str(your_uuid)) logging.info('My UUID: ' + uuid_str(your_uuid))
node.setUUID(uuid) node.setUUID(uuid)
app.id_timestamp = None
...@@ -146,15 +146,14 @@ class Log(object): ...@@ -146,15 +146,14 @@ class Log(object):
def notifyNodeInformation(self, node_list): def notifyNodeInformation(self, node_list):
node_list.sort(key=lambda x: x[2]) node_list.sort(key=lambda x: x[2])
node_list = [(self.uuid_str(uuid), str(node_type), node_list = [(self.uuid_str(x[2]), str(x[0]),
'%s:%u' % address if address else '?', state) '%s:%u' % x[1] if x[1] else '?', str(x[3]))
for node_type, address, uuid, state in node_list] + ((repr(x[4]),) if len(x) > 4 else ()) # BBB
for x in node_list]
if node_list: if node_list:
t = ' ! %%%us | %%%us | %%%us | %%s' % ( t = ''.join(' %%%us |' % max(len(x[i]) for x in node_list)
max(len(x[0]) for x in node_list), for i in xrange(len(node_list[0]) - 1))
max(len(x[1]) for x in node_list), return map((' !' + t + ' %s').__mod__, node_list)
max(len(x[2]) for x in node_list))
return map(t.__mod__, node_list)
return () return ()
......
...@@ -219,14 +219,11 @@ class Application(BaseApplication): ...@@ -219,14 +219,11 @@ class Application(BaseApplication):
conn.close() conn.close()
# search, find, connect and identify to the primary master # search, find, connect and identify to the primary master
bootstrap = BootstrapManager(self, self.name, bootstrap = BootstrapManager(self, NodeTypes.STORAGE, self.server)
NodeTypes.STORAGE, self.uuid, self.server) self.master_node, self.master_conn, num_partitions, num_replicas = \
data = bootstrap.getPrimaryConnection() bootstrap.getPrimaryConnection()
(node, conn, uuid, num_partitions, num_replicas) = data uuid = self.uuid
self.master_node = node
self.master_conn = conn
logging.info('I am %s', uuid_str(uuid)) logging.info('I am %s', uuid_str(uuid))
self.uuid = uuid
self.dm.setUUID(uuid) self.dm.setUUID(uuid)
# Reload a partition table from the database. This is necessary # Reload a partition table from the database. This is necessary
......
...@@ -50,8 +50,8 @@ class Checker(object): ...@@ -50,8 +50,8 @@ class Checker(object):
conn.asClient() conn.asClient()
else: else:
conn = ClientConnection(app, StorageOperationHandler(app), node) conn = ClientConnection(app, StorageOperationHandler(app), node)
conn.ask(Packets.RequestIdentification( conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE,
NodeTypes.STORAGE, uuid, app.server, name)) uuid, app.server, name, app.id_timestamp))
self.conn_dict[conn] = node.isIdentified() self.conn_dict[conn] = node.isIdentified()
conn_set = set(self.conn_dict) conn_set = set(self.conn_dict)
conn_set.discard(None) conn_set.discard(None)
......
...@@ -38,8 +38,8 @@ class BaseMasterHandler(EventHandler): ...@@ -38,8 +38,8 @@ class BaseMasterHandler(EventHandler):
def notifyNodeInformation(self, conn, node_list): def notifyNodeInformation(self, conn, node_list):
"""Store information on nodes, only if this is sent by a primary """Store information on nodes, only if this is sent by a primary
master node.""" master node."""
self.app.nm.update(node_list) super(BaseMasterHandler, self).notifyNodeInformation(conn, node_list)
for node_type, addr, uuid, state in node_list: for node_type, _, uuid, state, _ in node_list:
if uuid == self.app.uuid: if uuid == self.app.uuid:
# This is me, do what the master tell me # This is me, do what the master tell me
logging.info("I was told I'm %s", state) logging.info("I was told I'm %s", state)
......
...@@ -27,7 +27,8 @@ class IdentificationHandler(EventHandler): ...@@ -27,7 +27,8 @@ class IdentificationHandler(EventHandler):
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
logging.warning('A connection was lost during identification') logging.warning('A connection was lost during identification')
def requestIdentification(self, conn, node_type, uuid, address, name): def requestIdentification(self, conn, node_type, uuid, address, name,
id_timestamp):
self.checkClusterName(name) self.checkClusterName(name)
app = self.app app = self.app
# reject any incoming connections if not ready # reject any incoming connections if not ready
...@@ -41,7 +42,7 @@ class IdentificationHandler(EventHandler): ...@@ -41,7 +42,7 @@ class IdentificationHandler(EventHandler):
else: else:
if uuid == app.uuid: if uuid == app.uuid:
raise ProtocolError("uuid conflict or loopback connection") raise ProtocolError("uuid conflict or loopback connection")
node = app.nm.getByUUID(uuid) node = app.nm.getByUUID(uuid, id_timestamp)
if node is None: if node is None:
# Do never create node automatically, or we could get id # Do never create node automatically, or we could get id
# conflicts. We must only rely on the notifications from the # conflicts. We must only rely on the notifications from the
...@@ -56,12 +57,7 @@ class IdentificationHandler(EventHandler): ...@@ -56,12 +57,7 @@ class IdentificationHandler(EventHandler):
handler = ClientReadOnlyOperationHandler handler = ClientReadOnlyOperationHandler
else: else:
handler = ClientOperationHandler handler = ClientOperationHandler
if node.isConnected(): # XXX assert not node.isConnected(), node
# This can happen if we haven't processed yet a notification
# from the master, telling us the existing node is not
# running anymore. If we accept the new client, we won't
# know what to do with this late notification.
raise NotReadyError('uuid conflict: retry later')
assert node.isRunning(), node assert node.isRunning(), node
elif node_type == NodeTypes.STORAGE: elif node_type == NodeTypes.STORAGE:
handler = StorageOperationHandler handler = StorageOperationHandler
......
...@@ -258,7 +258,8 @@ class Replicator(object): ...@@ -258,7 +258,8 @@ class Replicator(object):
conn = ClientConnection(app, StorageOperationHandler(app), node) conn = ClientConnection(app, StorageOperationHandler(app), node)
try: try:
conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE, conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE,
None if name else app.uuid, app.server, name or app.name)) None if name else app.uuid, app.server, name or app.name,
app.id_timestamp))
except ConnectionClosed: except ConnectionClosed:
if previous_node is self.current_node: if previous_node is self.current_node:
return return
......
...@@ -119,7 +119,7 @@ class NEOProcess(object): ...@@ -119,7 +119,7 @@ class NEOProcess(object):
except ImportError: except ImportError:
raise NotFound, '%s not found' % (command) raise NotFound, '%s not found' % (command)
self.command = command self.command = command
self.arg_dict = {'--' + k: v for k, v in arg_dict.iteritems()} self.arg_dict = arg_dict
self.with_uuid = True self.with_uuid = True
self.setUUID(uuid) self.setUUID(uuid)
...@@ -131,11 +131,11 @@ class NEOProcess(object): ...@@ -131,11 +131,11 @@ class NEOProcess(object):
args = [] args = []
self.with_uuid = with_uuid self.with_uuid = with_uuid
for arg, param in self.arg_dict.iteritems(): for arg, param in self.arg_dict.iteritems():
if with_uuid is False and arg == '--uuid': args.append('--' + arg)
continue
args.append(arg)
if param is not None: if param is not None:
args.append(str(param)) args.append(str(param))
if with_uuid:
args += '--uuid', str(self.uuid)
self.pid = os.fork() self.pid = os.fork()
if self.pid == 0: if self.pid == 0:
# Child # Child
...@@ -213,7 +213,6 @@ class NEOProcess(object): ...@@ -213,7 +213,6 @@ class NEOProcess(object):
Note: for this change to take effect, the node must be restarted. Note: for this change to take effect, the node must be restarted.
""" """
self.uuid = uuid self.uuid = uuid
self.arg_dict['--uuid'] = str(uuid)
def isAlive(self): def isAlive(self):
try: try:
...@@ -297,7 +296,6 @@ class NEOCluster(object): ...@@ -297,7 +296,6 @@ class NEOCluster(object):
def _newProcess(self, node_type, logfile=None, port=None, **kw): def _newProcess(self, node_type, logfile=None, port=None, **kw):
self.uuid_dict[node_type] = uuid = 1 + self.uuid_dict.get(node_type, 0) self.uuid_dict[node_type] = uuid = 1 + self.uuid_dict.get(node_type, 0)
uuid += UUID_NAMESPACES[node_type] << 24 uuid += UUID_NAMESPACES[node_type] << 24
kw['uuid'] = uuid
kw['cluster'] = self.cluster_name kw['cluster'] = self.cluster_name
kw['masters'] = self.master_nodes kw['masters'] = self.master_nodes
if logfile: if logfile:
...@@ -483,13 +481,9 @@ class NEOCluster(object): ...@@ -483,13 +481,9 @@ class NEOCluster(object):
return self.__getNodeList(NodeTypes.CLIENT, state) return self.__getNodeList(NodeTypes.CLIENT, state)
def __getNodeState(self, node_type, uuid): def __getNodeState(self, node_type, uuid):
node_list = self.__getNodeList(node_type) for node in self.__getNodeList(node_type):
for node_type, address, node_uuid, state in node_list: if node[2] == uuid:
if node_uuid == uuid: return node[3]
break
else:
state = None
return state
def getMasterNodeState(self, uuid): def getMasterNodeState(self, uuid):
return self.__getNodeState(NodeTypes.MASTER, uuid) return self.__getNodeState(NodeTypes.MASTER, uuid)
......
...@@ -231,7 +231,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -231,7 +231,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
def test_requestIdentification1(self): def test_requestIdentification1(self):
""" A non-master node request identification """ """ A non-master node request identification """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
args = (node.getUUID(), node.getAddress(), self.app.name) args = node.getUUID(), node.getAddress(), self.app.name, None
self.assertRaises(protocol.NotReadyError, self.assertRaises(protocol.NotReadyError,
self.election.requestIdentification, self.election.requestIdentification,
conn, NodeTypes.CLIENT, *args) conn, NodeTypes.CLIENT, *args)
...@@ -240,7 +240,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -240,7 +240,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
""" A broken master node request identification """ """ A broken master node request identification """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
node.setBroken() node.setBroken()
args = (node.getUUID(), node.getAddress(), self.app.name) args = node.getUUID(), node.getAddress(), self.app.name, None
self.assertRaises(protocol.BrokenNodeDisallowedError, self.assertRaises(protocol.BrokenNodeDisallowedError,
self.election.requestIdentification, self.election.requestIdentification,
conn, NodeTypes.MASTER, *args) conn, NodeTypes.MASTER, *args)
...@@ -248,7 +248,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -248,7 +248,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
def test_requestIdentification4(self): def test_requestIdentification4(self):
""" No conflict """ """ No conflict """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
args = (node.getUUID(), node.getAddress(), self.app.name) args = node.getUUID(), node.getAddress(), self.app.name, None
self.election.requestIdentification(conn, self.election.requestIdentification(conn,
NodeTypes.MASTER, *args) NodeTypes.MASTER, *args)
self.checkUUIDSet(conn, node.getUUID()) self.checkUUIDSet(conn, node.getUUID())
...@@ -280,11 +280,12 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -280,11 +280,12 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
conn = self.__getClient() conn = self.__getClient()
self.checkNotReadyErrorRaised( self.checkNotReadyErrorRaised(
self.election.requestIdentification, self.election.requestIdentification,
conn=conn, conn,
node_type=NodeTypes.CLIENT, NodeTypes.CLIENT,
uuid=conn.getUUID(), conn.getUUID(),
address=conn.getAddress(), conn.getAddress(),
name=self.app.name self.app.name,
None,
) )
def _requestIdentification(self): def _requestIdentification(self):
...@@ -297,6 +298,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -297,6 +298,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
peer_uuid, peer_uuid,
address, address,
self.app.name, self.app.name,
None,
) )
node_type, uuid, partitions, replicas, _peer_uuid, primary, \ node_type, uuid, partitions, replicas, _peer_uuid, primary, \
master_list = self.checkAcceptIdentification(conn, decode=True) master_list = self.checkAcceptIdentification(conn, decode=True)
......
...@@ -50,6 +50,7 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase): ...@@ -50,6 +50,7 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase):
self.getClientUUID(), self.getClientUUID(),
None, None,
self.app.name, self.app.name,
None,
) )
self.app.ready = True self.app.ready = True
self.assertRaises( self.assertRaises(
...@@ -60,6 +61,7 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase): ...@@ -60,6 +61,7 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase):
self.getStorageUUID(), self.getStorageUUID(),
None, None,
self.app.name, self.app.name,
None,
) )
def test_requestIdentification3(self): def test_requestIdentification3(self):
...@@ -75,6 +77,7 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase): ...@@ -75,6 +77,7 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase):
uuid, uuid,
None, None,
self.app.name, self.app.name,
None,
) )
def test_requestIdentification2(self): def test_requestIdentification2(self):
...@@ -87,7 +90,7 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase): ...@@ -87,7 +90,7 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase):
'getAddress': master, 'getAddress': master,
}) })
self.identification.requestIdentification(conn, NodeTypes.CLIENT, uuid, self.identification.requestIdentification(conn, NodeTypes.CLIENT, uuid,
None, self.app.name) None, self.app.name, None)
self.assertTrue(node.isRunning()) self.assertTrue(node.isRunning())
self.assertTrue(node.isConnected()) self.assertTrue(node.isConnected())
self.assertEqual(node.getUUID(), uuid) self.assertEqual(node.getUUID(), uuid)
......
...@@ -28,7 +28,7 @@ class BootstrapManagerTests(NeoUnitTestBase): ...@@ -28,7 +28,7 @@ class BootstrapManagerTests(NeoUnitTestBase):
# create an application object # create an application object
config = self.getStorageConfiguration() config = self.getStorageConfiguration()
self.app = Application(config) self.app = Application(config)
self.bootstrap = BootstrapManager(self.app, 'main', NodeTypes.STORAGE) self.bootstrap = BootstrapManager(self.app, NodeTypes.STORAGE)
# define some variable to simulate client and storage node # define some variable to simulate client and storage node
self.master_port = 10010 self.master_port = 10010
self.storage_port = 10020 self.storage_port = 10020
......
...@@ -183,15 +183,15 @@ class NodeManagerTests(NeoUnitTestBase): ...@@ -183,15 +183,15 @@ class NodeManagerTests(NeoUnitTestBase):
old_uuid = self.storage.getUUID() old_uuid = self.storage.getUUID()
new_uuid = self.getStorageUUID() new_uuid = self.getStorageUUID()
node_list = ( node_list = (
(NodeTypes.CLIENT, None, self.client.getUUID(), NodeStates.DOWN), (NodeTypes.CLIENT, None, self.client.getUUID(), NodeStates.DOWN, None),
(NodeTypes.MASTER, new_address, self.master.getUUID(), NodeStates.RUNNING), (NodeTypes.MASTER, new_address, self.master.getUUID(), NodeStates.RUNNING, None),
(NodeTypes.STORAGE, self.storage.getAddress(), new_uuid, (NodeTypes.STORAGE, self.storage.getAddress(), new_uuid,
NodeStates.RUNNING), NodeStates.RUNNING, None),
(NodeTypes.ADMIN, self.admin.getAddress(), self.admin.getUUID(), (NodeTypes.ADMIN, self.admin.getAddress(), self.admin.getUUID(),
NodeStates.UNKNOWN), NodeStates.UNKNOWN, None),
) )
# update manager content # update manager content
manager.update(node_list) manager.update(Mock(), node_list)
# - the client gets down # - the client gets down
self.checkClients([]) self.checkClients([])
# - master change it's address # - master change it's address
......
...@@ -27,14 +27,14 @@ from ZODB import DB, POSException ...@@ -27,14 +27,14 @@ from ZODB import DB, POSException
from ZODB.DB import TransactionalUndo from ZODB.DB import TransactionalUndo
from neo.storage.transactions import TransactionManager, \ from neo.storage.transactions import TransactionManager, \
DelayedError, ConflictError DelayedError, ConflictError
from neo.lib.connection import MTClientConnection from neo.lib.connection import ServerConnection, MTClientConnection
from neo.lib.exception import DatabaseFailure, StoppedOperation from neo.lib.exception import DatabaseFailure, StoppedOperation
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \ from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
ZERO_TID ZERO_OID, ZERO_TID
from .. import expectedFailure, Patch from .. import expectedFailure, Patch
from . import LockLock, NEOCluster, NEOThreadedTest from . import LockLock, NEOCluster, NEOThreadedTest
from neo.lib.util import add64, makeChecksum, p64, u64 from neo.lib.util import add64, makeChecksum, p64, u64
from neo.client.exception import NEOStorageError from neo.client.exception import NEOPrimaryMasterLost, NEOStorageError
from neo.client.pool import CELL_CONNECTED, CELL_GOOD from neo.client.pool import CELL_CONNECTED, CELL_GOOD
from neo.master.handlers.client import ClientServiceHandler from neo.master.handlers.client import ClientServiceHandler
from neo.storage.handlers.client import ClientOperationHandler from neo.storage.handlers.client import ClientOperationHandler
...@@ -1347,6 +1347,58 @@ class Test(NEOThreadedTest): ...@@ -1347,6 +1347,58 @@ class Test(NEOThreadedTest):
finally: finally:
cluster.stop() cluster.stop()
def testIdTimestamp(self):
"""
Given a master M, a storage S, and 2 clients Ca and Cb.
While Ca(id=1) is being identified by S:
1. connection between Ca and M breaks
2. M -> S: C1 down
3. Cb connect to M: id=1
4. M -> S: C1 up
5. S processes RequestIdentification from Ca with id=1
At 5, S must reject Ca, otherwise Cb can't connect to S. This is where
id timestamps come into play: with C1 up since t2, S rejects Ca due to
a request with t1 < t2.
To avoid issues with clocks that are out of sync, the client gets its
connection timestamp by being notified about itself from the master.
"""
s2c = []
def __init__(orig, self, *args, **kw):
orig(self, *args, **kw)
self.readable = bool
s2c.append(self)
ll()
def connectToStorage(client):
next(client.cp.iterateForObject(0))
cluster = NEOCluster()
try:
cluster.start()
Ca = cluster.client
Ca.pt # only connect to the master
# In a separate thread, connect to the storage but suspend the
# processing of the RequestIdentification packet, until the
# storage is notified about the existence of the other client.
with LockLock() as ll, Patch(ServerConnection, __init__=__init__):
t = self.newThread(connectToStorage, Ca)
ll()
s2c, = s2c
m2c, = cluster.master.getConnectionList(cluster.client)
m2c.close()
Cb = cluster.newClient()
try:
Cb.pt # only connect to the master
del s2c.readable
self.assertRaises(NEOPrimaryMasterLost, t.join)
self.assertTrue(s2c.isClosed())
connectToStorage(Cb)
finally:
Cb.close()
finally:
cluster.stop()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment