Commit 7f885a83 authored by Grégory Wisniewski's avatar Grégory Wisniewski

Index node by identified state as it's often used.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@1828 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent c79c38fa
......@@ -73,6 +73,7 @@ class Node(object):
old_uuid = self._uuid
self._uuid = uuid
self._manager._updateUUID(self, old_uuid)
self._manager._updateIdentified(self)
def getUUID(self):
return self._uuid
......@@ -83,6 +84,7 @@ class Node(object):
"""
assert self._connection is not None
self._connection = None
self._manager._updateIdentified(self)
def setConnection(self, connection):
"""
......@@ -92,6 +94,7 @@ class Node(object):
assert self._connection is None
self._connection = connection
connection.setOnClose(self.onConnectionClosed)
self._manager._updateIdentified(self)
def getConnection(self):
"""
......@@ -250,6 +253,7 @@ class NodeManager(object):
self._uuid_dict = {}
self._type_dict = {}
self._state_dict = {}
self._identified_dict = {}
def add(self, node):
if node in self._node_set:
......@@ -259,6 +263,7 @@ class NodeManager(object):
self._updateUUID(node, None)
self.__updateSet(self._type_dict, None, node.__class__, node)
self.__updateSet(self._state_dict, None, node.getState(), node)
self._updateIdentified(node)
def remove(self, node):
if node is None or node not in self._node_set:
......@@ -268,6 +273,7 @@ class NodeManager(object):
self.__drop(self._uuid_dict, node.getUUID())
self.__dropSet(self._state_dict, node.getState(), node)
self.__dropSet(self._type_dict, node.__class__, node)
self._updateIdentified(node)
def __drop(self, index_dict, key):
try:
......@@ -285,6 +291,16 @@ class NodeManager(object):
if new_key is not None:
index_dict[new_key] = node
def _updateIdentified(self, node):
uuid = node.getUUID()
if node.isIdentified():
self._identified_dict[uuid] = node
else:
try:
del self._identified_dict[uuid]
except KeyError:
pass
def _updateAddress(self, node, old_address):
self.__update(self._address_dict, old_address, node.getAddress(), node)
......@@ -313,13 +329,12 @@ class NodeManager(object):
def getIdentifiedList(self, pool_set=None):
"""
Returns a generator to iterate over identified nodes
pool_set is an iterable of UUIDs allowed
"""
# TODO: use an index
if pool_set is None:
return [x for x in self._node_set if x.isIdentified()]
else:
return [x for x in self._node_set if x.isIdentified() and
x.getUUID() in pool_set]
if pool_set is not None:
identified_nodes = self._identified_dict.items()
return [v for k, v in identified_nodes if k in pool_set]
return list(self._identified_dict.values())
def getConnectedList(self):
"""
......@@ -329,30 +344,33 @@ class NodeManager(object):
return [x for x in self._node_set if x.isConnected()]
def __getList(self, index_dict, key):
return list(index_dict.setdefault(key, set()))
return index_dict.setdefault(key, set())
def getByStateList(self, state):
""" Get a node list filtered per the node state """
return self.__getList(self._state_dict, state)
return list(self.__getList(self._state_dict, state))
def __getTypeList(self, type_klass):
return self.__getList(self._type_dict, type_klass)
def __getTypeList(self, type_klass, only_identified=False):
node_set = self.__getList(self._type_dict, type_klass)
if only_identified:
return [x for x in node_set if x.getUUID() in self._identified_dict]
return list(node_set)
def getMasterList(self):
def getMasterList(self, only_identified=False):
""" Return a list with master nodes """
return self.__getTypeList(MasterNode)
return self.__getTypeList(MasterNode, only_identified)
def getStorageList(self):
def getStorageList(self, only_identified=False):
""" Return a list with storage nodes """
return self.__getTypeList(StorageNode)
return self.__getTypeList(StorageNode, only_identified)
def getClientList(self):
def getClientList(self, only_identified=False):
""" Return a list with client nodes """
return self.__getTypeList(ClientNode)
return self.__getTypeList(ClientNode, only_identified)
def getAdminList(self):
def getAdminList(self, only_identified=False):
""" Return a list with admin nodes """
return self.__getTypeList(AdminNode)
return self.__getTypeList(AdminNode, only_identified)
def getByAddress(self, address):
""" Return the node that match with a given address """
......
......@@ -155,6 +155,10 @@ class NodeManagerTests(NeoTestBase):
node_found = self.manager.getByUUID(node.getUUID())
self.assertEqual(node_found, node)
def checkIdentified(self, node_list, pool_set=None):
identified_node_list = self.manager.getIdentifiedList(pool_set)
self.assertEqual(set(identified_node_list), set(node_list))
def testInit(self):
""" Check the manager is empty when started """
manager = self.manager
......@@ -273,6 +277,31 @@ class NodeManagerTests(NeoTestBase):
self.checkNodes([self.master, self.admin, new_storage])
self.assertEqual(self.admin.getState(), NodeStates.UNKNOWN)
def testIdentified(self):
# set up four nodes
manager = self.manager
manager.add(self.master)
manager.add(self.storage)
manager.add(self.client)
manager.add(self.admin)
# switch node to connected
self.checkIdentified([])
self.master.setConnection(Mock())
self.checkIdentified([self.master])
self.storage.setConnection(Mock())
self.checkIdentified([self.master, self.storage])
self.client.setConnection(Mock())
self.checkIdentified([self.master, self.storage, self.client])
self.admin.setConnection(Mock())
self.checkIdentified([self.master, self.storage, self.client, self.admin])
# check the pool_set attribute
self.checkIdentified([self.master], pool_set=[self.master.getUUID()])
self.checkIdentified([self.storage], pool_set=[self.storage.getUUID()])
self.checkIdentified([self.client], pool_set=[self.client.getUUID()])
self.checkIdentified([self.admin], pool_set=[self.admin.getUUID()])
self.checkIdentified([self.master, self.storage], pool_set=[
self.master.getUUID(), self.storage.getUUID()])
if __name__ == '__main__':
unittest.main()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment