Commit b3e5254a authored by Grégory Wisniewski's avatar Grégory Wisniewski

Add the thread-safe version of the partition table for the client app. A

per-instance lock is acquired before invoke any partition table method from the
client app. Update tests.


git-svn-id: https://svn.erp5.org/repos/neo/branches/prototype3@632 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent f92d874c
#
# Copyright (C) 2006-2009 Nexedi SA # Copyright (C) 2006-2009 Nexedi SA
# #
# This program is free software; you can redistribute it and/or # This program is free software; you can redistribute it and/or
...@@ -264,10 +263,6 @@ class Application(object): ...@@ -264,10 +263,6 @@ class Application(object):
lock = Lock() lock = Lock()
self._nm_acquire = lock.acquire self._nm_acquire = lock.acquire
self._nm_release = lock.release self._nm_release = lock.release
# __pt ensure exclusive access to the partition table
lock = Lock()
self._pt_acquire = lock.acquire
self._pt_release = lock.release
def _notifyDeadStorage(self, s_node): def _notifyDeadStorage(self, s_node):
""" Notify a storage failure to the primary master """ """ Notify a storage failure to the primary master """
...@@ -343,14 +338,15 @@ class Application(object): ...@@ -343,14 +338,15 @@ class Application(object):
def _getPartitionTable(self): def _getPartitionTable(self):
""" Return the partition table manager, reconnect the PMN if needed """ """ Return the partition table manager, reconnect the PMN if needed """
self._pt_acquire(True) # this ensure the master connection is established and the partition
try: # table is up to date.
if self.master_conn is None: self._getMasterConnection()
self.master_conn = self._connectToPrimaryMasterNode() return self.pt
assert self.pt is not None
return self.pt def _getCellListForID(self, id, readable=False, writable=False):
finally: """ Return the cells available for the specified (O|T)ID """
self._pt_release() pt = self._getPartitionTable()
return pt.getCellListForID(id, readable, writable)
def _connectToPrimaryMasterNode(self): def _connectToPrimaryMasterNode(self):
logging.debug('connecting to primary master...') logging.debug('connecting to primary master...')
...@@ -409,6 +405,7 @@ class Application(object): ...@@ -409,6 +405,7 @@ class Application(object):
elif self.pt is not None and self.pt.operational(): elif self.pt is not None and self.pt.operational():
# Connected to primary master node # Connected to primary master node
break break
sleep(0.1)
if self.pt is not None and self.pt.operational() \ if self.pt is not None and self.pt.operational() \
and self.uuid != INVALID_UUID: and self.uuid != INVALID_UUID:
# Connected to primary master node and got all informations # Connected to primary master node and got all informations
...@@ -461,8 +458,7 @@ class Application(object): ...@@ -461,8 +458,7 @@ class Application(object):
def _load(self, oid, serial = INVALID_TID, tid = INVALID_TID, cache = 0): def _load(self, oid, serial = INVALID_TID, tid = INVALID_TID, cache = 0):
"""Internal method which manage load ,loadSerial and loadBefore.""" """Internal method which manage load ,loadSerial and loadBefore."""
pt = self._getPartitionTable() cell_list = self._getCellListForID(oid, readable=True)
cell_list = pt.getCellListForID(oid, readable=True)
if len(cell_list) == 0: if len(cell_list) == 0:
# No cells available, so why are we running ? # No cells available, so why are we running ?
logging.error('oid %s not found because no storage is available for it', dump(oid)) logging.error('oid %s not found because no storage is available for it', dump(oid))
...@@ -596,8 +592,7 @@ class Application(object): ...@@ -596,8 +592,7 @@ class Application(object):
logging.debug('storing oid %s serial %s', logging.debug('storing oid %s serial %s',
dump(oid), dump(serial)) dump(oid), dump(serial))
# Find which storage node to use # Find which storage node to use
pt = self._getPartitionTable() cell_list = self._getCellListForID(oid, writable=True)
cell_list = pt.getCellListForID(oid, writable=True)
if len(cell_list) == 0: if len(cell_list) == 0:
# FIXME must wait for cluster to be ready # FIXME must wait for cluster to be ready
raise NEOStorageError raise NEOStorageError
...@@ -653,7 +648,7 @@ class Application(object): ...@@ -653,7 +648,7 @@ class Application(object):
oid_list = self.local_var.data_dict.keys() oid_list = self.local_var.data_dict.keys()
# Store data on each node # Store data on each node
pt = self._getPartitionTable() pt = self._getPartitionTable()
cell_list = pt.getCellListForID(self.local_var.tid, writable=True) cell_list = self._getCellListForID(self.local_var.tid, writable=True)
self.local_var.voted_counter = 0 self.local_var.voted_counter = 0
for cell in cell_list: for cell in cell_list:
logging.info("voting object %s %s" %(cell.getServer(), cell.getState())) logging.info("voting object %s %s" %(cell.getServer(), cell.getState()))
...@@ -683,12 +678,11 @@ class Application(object): ...@@ -683,12 +678,11 @@ class Application(object):
return return
cell_set = set() cell_set = set()
pt = self._getPartitionTable()
# select nodes where objects were stored # select nodes where objects were stored
for oid in self.local_var.data_dict.iterkeys(): for oid in self.local_var.data_dict.iterkeys():
cell_set |= set(pt.getCellListForID(oid, writable=True)) cell_set |= set(self._getCellListForID(oid, writable=True))
# select nodes where transaction was stored # select nodes where transaction was stored
cell_set |= set(pt.getCellListForID(self.local_var.tid, writable=True)) cell_set |= set(self._getCellListForID(self.local_var.tid, writable=True))
# cancel transaction one all those nodes # cancel transaction one all those nodes
for cell in cell_set: for cell in cell_set:
...@@ -746,8 +740,7 @@ class Application(object): ...@@ -746,8 +740,7 @@ class Application(object):
raise StorageTransactionError(self, transaction_id) raise StorageTransactionError(self, transaction_id)
# First get transaction information from a storage node. # First get transaction information from a storage node.
pt = self._getPartitionTable() cell_list = self._getCellListForID(transaction_id, writable=True)
cell_list = pt.getCellListForID(transaction_id, writable=True)
shuffle(cell_list) shuffle(cell_list)
for cell in cell_list: for cell in cell_list:
conn = self.cp.getConnForNode(cell) conn = self.cp.getConnForNode(cell)
...@@ -848,7 +841,7 @@ class Application(object): ...@@ -848,7 +841,7 @@ class Application(object):
# For each transaction, get info # For each transaction, get info
undo_info = [] undo_info = []
for tid in ordered_tids: for tid in ordered_tids:
cell_list = pt.getCellListForID(tid, readable=True) cell_list = self._getCellListForID(tid, readable=True)
shuffle(cell_list) shuffle(cell_list)
for cell in cell_list: for cell in cell_list:
conn = self.cp.getConnForNode(storage_node) conn = self.cp.getConnForNode(storage_node)
...@@ -891,8 +884,7 @@ class Application(object): ...@@ -891,8 +884,7 @@ class Application(object):
# FIXME: filter function isn't used # FIXME: filter function isn't used
def history(self, oid, version=None, length=1, filter=None, object_only=0): def history(self, oid, version=None, length=1, filter=None, object_only=0):
# Get history informations for object first # Get history informations for object first
pt = self._getPartitionTable() cell_list = self._getCellListForID(oid, readable=True)
cell_list = pt.getCellListForID(oid, readable=True)
shuffle(cell_list) shuffle(cell_list)
for cell in cell_list: for cell in cell_list:
...@@ -922,7 +914,7 @@ class Application(object): ...@@ -922,7 +914,7 @@ class Application(object):
# Now that we have object informations, get txn informations # Now that we have object informations, get txn informations
history_list = [] history_list = []
for serial, size in self.local_var.history[1]: for serial, size in self.local_var.history[1]:
pt.getCellListForID(serial, readable=True) self._getCellListForID(serial, readable=True)
shuffle(cell_list) shuffle(cell_list)
for cell in cell_list: for cell in cell_list:
......
...@@ -26,7 +26,7 @@ from neo.protocol import Packet, \ ...@@ -26,7 +26,7 @@ from neo.protocol import Packet, \
BROKEN_STATE, FEEDING_STATE, DISCARDED_STATE, DOWN_STATE, \ BROKEN_STATE, FEEDING_STATE, DISCARDED_STATE, DOWN_STATE, \
HIDDEN_STATE HIDDEN_STATE
from neo.node import MasterNode, StorageNode, ClientNode from neo.node import MasterNode, StorageNode, ClientNode
from neo.pt import PartitionTable from neo.pt import MTPartitionTable as PartitionTable
from neo.client.exception import NEOStorageError from neo.client.exception import NEOStorageError
from neo.exception import ElectionFailure from neo.exception import ElectionFailure
from neo.util import dump from neo.util import dump
...@@ -35,7 +35,6 @@ from neo.handler import identification_required, restrict_node_types ...@@ -35,7 +35,6 @@ from neo.handler import identification_required, restrict_node_types
from ZODB.TimeStamp import TimeStamp from ZODB.TimeStamp import TimeStamp
from ZODB.utils import p64 from ZODB.utils import p64
class BaseHandler(EventHandler): class BaseHandler(EventHandler):
"""Base class for client-side EventHandler implementations.""" """Base class for client-side EventHandler implementations."""
......
...@@ -481,13 +481,14 @@ class ClientHandlerTests(NeoTestBase): ...@@ -481,13 +481,14 @@ class ClientHandlerTests(NeoTestBase):
node = Mock({'getNodeType': node_type}) node = Mock({'getNodeType': node_type})
class App: class App:
nm = Mock({'getNodeByUUID': node}) nm = Mock({'getNodeByUUID': node})
pt = None pt = Mock()
app = App() app = App()
client_handler = PrimaryBootstrapHandler(app, self.getDispatcher()) client_handler = PrimaryBootstrapHandler(app, self.getDispatcher())
conn = self.getConnection() conn = self.getConnection()
client_handler.handleSendPartitionTable(conn, None, 0, []) client_handler.handleSendPartitionTable(conn, None, 0, [])
# Check that nothing happened # Check that nothing happened
self.assertTrue(app.pt is None) self.assertEquals(len(app.pt.mockGetNamedCalls('setCell')), 0)
self.assertEquals(len(app.pt.mockGetNamedCalls('removeCell')), 0)
def test_newSendPartitionTable(self): def test_newSendPartitionTable(self):
node = Mock({'getNodeType': MASTER_NODE_TYPE}) node = Mock({'getNodeType': MASTER_NODE_TYPE})
...@@ -674,7 +675,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -674,7 +675,7 @@ class ClientHandlerTests(NeoTestBase):
node = Mock({'getNodeType': node_type, 'getUUID': test_master_uuid}) node = Mock({'getNodeType': node_type, 'getUUID': test_master_uuid})
class App: class App:
nm = Mock({'getNodeByUUID': node}) nm = Mock({'getNodeByUUID': node})
pt = None pt = Mock()
ptid = INVALID_PTID ptid = INVALID_PTID
primary_master_node = node primary_master_node = node
app = App() app = App()
...@@ -682,13 +683,14 @@ class ClientHandlerTests(NeoTestBase): ...@@ -682,13 +683,14 @@ class ClientHandlerTests(NeoTestBase):
conn = self.getConnection(uuid=test_master_uuid) conn = self.getConnection(uuid=test_master_uuid)
client_handler.handleNotifyPartitionChanges(conn, None, 0, []) client_handler.handleNotifyPartitionChanges(conn, None, 0, [])
# Check that nothing happened # Check that nothing happened
self.assertTrue(app.pt is None) self.assertEquals(len(app.pt.mockGetNamedCalls('setCell')), 0)
self.assertEquals(len(app.pt.mockGetNamedCalls('removeCell')), 0)
def test_noPrimaryMasterNotifyPartitionChanges(self): def test_noPrimaryMasterNotifyPartitionChanges(self):
node = Mock({'getNodeType': MASTER_NODE_TYPE}) node = Mock({'getNodeType': MASTER_NODE_TYPE})
class App: class App:
nm = Mock({'getNodeByUUID': node}) nm = Mock({'getNodeByUUID': node})
pt = None pt = Mock()
ptid = INVALID_PTID ptid = INVALID_PTID
primary_master_node = None primary_master_node = None
app = App() app = App()
...@@ -696,7 +698,8 @@ class ClientHandlerTests(NeoTestBase): ...@@ -696,7 +698,8 @@ class ClientHandlerTests(NeoTestBase):
conn = self.getConnection() conn = self.getConnection()
client_handler.handleNotifyPartitionChanges(conn, None, 0, []) client_handler.handleNotifyPartitionChanges(conn, None, 0, [])
# Check that nothing happened # Check that nothing happened
self.assertTrue(app.pt is None) self.assertEquals(len(app.pt.mockGetNamedCalls('setCell')), 0)
self.assertEquals(len(app.pt.mockGetNamedCalls('removeCell')), 0)
def test_nonPrimaryMasterNotifyPartitionChanges(self): def test_nonPrimaryMasterNotifyPartitionChanges(self):
test_master_uuid = self.getNewUUID() test_master_uuid = self.getNewUUID()
...@@ -707,7 +710,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -707,7 +710,7 @@ class ClientHandlerTests(NeoTestBase):
test_master_node = Mock({'getUUID': test_master_uuid}) test_master_node = Mock({'getUUID': test_master_uuid})
class App: class App:
nm = Mock({'getNodeByUUID': node}) nm = Mock({'getNodeByUUID': node})
pt = None pt = Mock()
ptid = INVALID_PTID ptid = INVALID_PTID
primary_master_node = test_master_node primary_master_node = test_master_node
app = App() app = App()
...@@ -715,7 +718,8 @@ class ClientHandlerTests(NeoTestBase): ...@@ -715,7 +718,8 @@ class ClientHandlerTests(NeoTestBase):
conn = self.getConnection(uuid=test_sender_uuid) conn = self.getConnection(uuid=test_sender_uuid)
client_handler.handleNotifyPartitionChanges(conn, None, 0, []) client_handler.handleNotifyPartitionChanges(conn, None, 0, [])
# Check that nothing happened # Check that nothing happened
self.assertTrue(app.pt is None) self.assertEquals(len(app.pt.mockGetNamedCalls('setCell')), 0)
self.assertEquals(len(app.pt.mockGetNamedCalls('removeCell')), 0)
def test_ignoreOutdatedPTIDNotifyPartitionChanges(self): def test_ignoreOutdatedPTIDNotifyPartitionChanges(self):
test_master_uuid = self.getNewUUID() test_master_uuid = self.getNewUUID()
...@@ -723,7 +727,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -723,7 +727,7 @@ class ClientHandlerTests(NeoTestBase):
test_ptid = 1 test_ptid = 1
class App: class App:
nm = Mock({'getNodeByUUID': node}) nm = Mock({'getNodeByUUID': node})
pt = None pt = Mock()
primary_master_node = node primary_master_node = node
ptid = test_ptid ptid = test_ptid
app = App() app = App()
...@@ -731,7 +735,8 @@ class ClientHandlerTests(NeoTestBase): ...@@ -731,7 +735,8 @@ class ClientHandlerTests(NeoTestBase):
conn = self.getConnection(uuid=test_master_uuid) conn = self.getConnection(uuid=test_master_uuid)
client_handler.handleNotifyPartitionChanges(conn, None, test_ptid, []) client_handler.handleNotifyPartitionChanges(conn, None, test_ptid, [])
# Check that nothing happened # Check that nothing happened
self.assertTrue(app.pt is None) self.assertEquals(len(app.pt.mockGetNamedCalls('setCell')), 0)
self.assertEquals(len(app.pt.mockGetNamedCalls('removeCell')), 0)
self.assertEquals(app.ptid, test_ptid) self.assertEquals(app.ptid, test_ptid)
def test_unknownNodeNotifyPartitionChanges(self): def test_unknownNodeNotifyPartitionChanges(self):
......
...@@ -21,6 +21,8 @@ from neo.protocol import UP_TO_DATE_STATE, OUT_OF_DATE_STATE, FEEDING_STATE, \ ...@@ -21,6 +21,8 @@ from neo.protocol import UP_TO_DATE_STATE, OUT_OF_DATE_STATE, FEEDING_STATE, \
DISCARDED_STATE, RUNNING_STATE, TEMPORARILY_DOWN_STATE, DOWN_STATE, \ DISCARDED_STATE, RUNNING_STATE, TEMPORARILY_DOWN_STATE, DOWN_STATE, \
BROKEN_STATE, VALID_CELL_STATE_LIST, HIDDEN_STATE BROKEN_STATE, VALID_CELL_STATE_LIST, HIDDEN_STATE
from neo.util import dump, u64 from neo.util import dump, u64
from neo.locking import RLock
class Cell(object): class Cell(object):
"""This class represents a cell in a partition table.""" """This class represents a cell in a partition table."""
...@@ -222,3 +224,48 @@ class PartitionTable(object): ...@@ -222,3 +224,48 @@ class PartitionTable(object):
return () return ()
return [(cell.getUUID(), cell.getState()) for cell in row] return [(cell.getUUID(), cell.getState()) for cell in row]
def thread_safe(method):
def wrapper(self, *args, **kwargs):
self.lock()
try:
return method(self, *args, **kwargs)
finally:
self.unlock()
return wrapper
class MTPartitionTable(PartitionTable):
""" Thread-safe aware version of the partition table, override only methods
used in the client """
def __init__(self, *args, **kwargs):
self._lock = RLock()
PartitionTable.__init__(self, *args, **kwargs)
def lock(self):
self._lock.acquire()
def unlock(self):
self._lock.release()
@thread_safe
def getCellListForID(self, *args, **kwargs):
return PartitionTable.getCellListForID(self, *args, **kwargs)
@thread_safe
def setCell(self, *args, **kwargs):
return PartitionTable.setCell(self, *args, **kwargs)
@thread_safe
def clear(self, *args, **kwargs):
return PartitionTable.clear(self, *args, **kwargs)
@thread_safe
def operational(self, *args, **kwargs):
return PartitionTable.operational(self, *args, **kwargs)
@thread_safe
def getNodeList(self, *args, **kwargs):
return PartitionTable.getNodeList(self, *args, **kwargs)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment