Commit 7f885a83 authored by Grégory Wisniewski's avatar Grégory Wisniewski

Index node by identified state as it's often used.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@1828 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent c79c38fa
...@@ -73,6 +73,7 @@ class Node(object): ...@@ -73,6 +73,7 @@ class Node(object):
old_uuid = self._uuid old_uuid = self._uuid
self._uuid = uuid self._uuid = uuid
self._manager._updateUUID(self, old_uuid) self._manager._updateUUID(self, old_uuid)
self._manager._updateIdentified(self)
def getUUID(self): def getUUID(self):
return self._uuid return self._uuid
...@@ -83,6 +84,7 @@ class Node(object): ...@@ -83,6 +84,7 @@ class Node(object):
""" """
assert self._connection is not None assert self._connection is not None
self._connection = None self._connection = None
self._manager._updateIdentified(self)
def setConnection(self, connection): def setConnection(self, connection):
""" """
...@@ -92,6 +94,7 @@ class Node(object): ...@@ -92,6 +94,7 @@ class Node(object):
assert self._connection is None assert self._connection is None
self._connection = connection self._connection = connection
connection.setOnClose(self.onConnectionClosed) connection.setOnClose(self.onConnectionClosed)
self._manager._updateIdentified(self)
def getConnection(self): def getConnection(self):
""" """
...@@ -250,6 +253,7 @@ class NodeManager(object): ...@@ -250,6 +253,7 @@ class NodeManager(object):
self._uuid_dict = {} self._uuid_dict = {}
self._type_dict = {} self._type_dict = {}
self._state_dict = {} self._state_dict = {}
self._identified_dict = {}
def add(self, node): def add(self, node):
if node in self._node_set: if node in self._node_set:
...@@ -259,6 +263,7 @@ class NodeManager(object): ...@@ -259,6 +263,7 @@ class NodeManager(object):
self._updateUUID(node, None) self._updateUUID(node, None)
self.__updateSet(self._type_dict, None, node.__class__, node) self.__updateSet(self._type_dict, None, node.__class__, node)
self.__updateSet(self._state_dict, None, node.getState(), node) self.__updateSet(self._state_dict, None, node.getState(), node)
self._updateIdentified(node)
def remove(self, node): def remove(self, node):
if node is None or node not in self._node_set: if node is None or node not in self._node_set:
...@@ -268,6 +273,7 @@ class NodeManager(object): ...@@ -268,6 +273,7 @@ class NodeManager(object):
self.__drop(self._uuid_dict, node.getUUID()) self.__drop(self._uuid_dict, node.getUUID())
self.__dropSet(self._state_dict, node.getState(), node) self.__dropSet(self._state_dict, node.getState(), node)
self.__dropSet(self._type_dict, node.__class__, node) self.__dropSet(self._type_dict, node.__class__, node)
self._updateIdentified(node)
def __drop(self, index_dict, key): def __drop(self, index_dict, key):
try: try:
...@@ -285,6 +291,16 @@ class NodeManager(object): ...@@ -285,6 +291,16 @@ class NodeManager(object):
if new_key is not None: if new_key is not None:
index_dict[new_key] = node index_dict[new_key] = node
def _updateIdentified(self, node):
uuid = node.getUUID()
if node.isIdentified():
self._identified_dict[uuid] = node
else:
try:
del self._identified_dict[uuid]
except KeyError:
pass
def _updateAddress(self, node, old_address): def _updateAddress(self, node, old_address):
self.__update(self._address_dict, old_address, node.getAddress(), node) self.__update(self._address_dict, old_address, node.getAddress(), node)
...@@ -313,13 +329,12 @@ class NodeManager(object): ...@@ -313,13 +329,12 @@ class NodeManager(object):
def getIdentifiedList(self, pool_set=None): def getIdentifiedList(self, pool_set=None):
""" """
Returns a generator to iterate over identified nodes Returns a generator to iterate over identified nodes
pool_set is an iterable of UUIDs allowed
""" """
# TODO: use an index if pool_set is not None:
if pool_set is None: identified_nodes = self._identified_dict.items()
return [x for x in self._node_set if x.isIdentified()] return [v for k, v in identified_nodes if k in pool_set]
else: return list(self._identified_dict.values())
return [x for x in self._node_set if x.isIdentified() and
x.getUUID() in pool_set]
def getConnectedList(self): def getConnectedList(self):
""" """
...@@ -329,30 +344,33 @@ class NodeManager(object): ...@@ -329,30 +344,33 @@ class NodeManager(object):
return [x for x in self._node_set if x.isConnected()] return [x for x in self._node_set if x.isConnected()]
def __getList(self, index_dict, key): def __getList(self, index_dict, key):
return list(index_dict.setdefault(key, set())) return index_dict.setdefault(key, set())
def getByStateList(self, state): def getByStateList(self, state):
""" Get a node list filtered per the node state """ """ Get a node list filtered per the node state """
return self.__getList(self._state_dict, state) return list(self.__getList(self._state_dict, state))
def __getTypeList(self, type_klass): def __getTypeList(self, type_klass, only_identified=False):
return self.__getList(self._type_dict, type_klass) node_set = self.__getList(self._type_dict, type_klass)
if only_identified:
return [x for x in node_set if x.getUUID() in self._identified_dict]
return list(node_set)
def getMasterList(self): def getMasterList(self, only_identified=False):
""" Return a list with master nodes """ """ Return a list with master nodes """
return self.__getTypeList(MasterNode) return self.__getTypeList(MasterNode, only_identified)
def getStorageList(self): def getStorageList(self, only_identified=False):
""" Return a list with storage nodes """ """ Return a list with storage nodes """
return self.__getTypeList(StorageNode) return self.__getTypeList(StorageNode, only_identified)
def getClientList(self): def getClientList(self, only_identified=False):
""" Return a list with client nodes """ """ Return a list with client nodes """
return self.__getTypeList(ClientNode) return self.__getTypeList(ClientNode, only_identified)
def getAdminList(self): def getAdminList(self, only_identified=False):
""" Return a list with admin nodes """ """ Return a list with admin nodes """
return self.__getTypeList(AdminNode) return self.__getTypeList(AdminNode, only_identified)
def getByAddress(self, address): def getByAddress(self, address):
""" Return the node that match with a given address """ """ Return the node that match with a given address """
......
...@@ -155,6 +155,10 @@ class NodeManagerTests(NeoTestBase): ...@@ -155,6 +155,10 @@ class NodeManagerTests(NeoTestBase):
node_found = self.manager.getByUUID(node.getUUID()) node_found = self.manager.getByUUID(node.getUUID())
self.assertEqual(node_found, node) self.assertEqual(node_found, node)
def checkIdentified(self, node_list, pool_set=None):
identified_node_list = self.manager.getIdentifiedList(pool_set)
self.assertEqual(set(identified_node_list), set(node_list))
def testInit(self): def testInit(self):
""" Check the manager is empty when started """ """ Check the manager is empty when started """
manager = self.manager manager = self.manager
...@@ -273,6 +277,31 @@ class NodeManagerTests(NeoTestBase): ...@@ -273,6 +277,31 @@ class NodeManagerTests(NeoTestBase):
self.checkNodes([self.master, self.admin, new_storage]) self.checkNodes([self.master, self.admin, new_storage])
self.assertEqual(self.admin.getState(), NodeStates.UNKNOWN) self.assertEqual(self.admin.getState(), NodeStates.UNKNOWN)
def testIdentified(self):
# set up four nodes
manager = self.manager
manager.add(self.master)
manager.add(self.storage)
manager.add(self.client)
manager.add(self.admin)
# switch node to connected
self.checkIdentified([])
self.master.setConnection(Mock())
self.checkIdentified([self.master])
self.storage.setConnection(Mock())
self.checkIdentified([self.master, self.storage])
self.client.setConnection(Mock())
self.checkIdentified([self.master, self.storage, self.client])
self.admin.setConnection(Mock())
self.checkIdentified([self.master, self.storage, self.client, self.admin])
# check the pool_set attribute
self.checkIdentified([self.master], pool_set=[self.master.getUUID()])
self.checkIdentified([self.storage], pool_set=[self.storage.getUUID()])
self.checkIdentified([self.client], pool_set=[self.client.getUUID()])
self.checkIdentified([self.admin], pool_set=[self.admin.getUUID()])
self.checkIdentified([self.master, self.storage], pool_set=[
self.master.getUUID(), self.storage.getUUID()])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment