Commit cefe62f3 authored by Vincent Pelletier's avatar Vincent Pelletier

Add back "queue" parameter on MTClientConnection.ask

Make it optional, to suit "ping" use, but check that it's always passed
except in that special case.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2143 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 9c3b756d
...@@ -260,14 +260,14 @@ class Application(object): ...@@ -260,14 +260,14 @@ class Application(object):
@profiler_decorator @profiler_decorator
def _askStorage(self, conn, packet): def _askStorage(self, conn, packet):
""" Send a request to a storage node and process it's answer """ """ Send a request to a storage node and process it's answer """
msg_id = conn.ask(packet) msg_id = conn.ask(packet, queue=self.local_var.queue)
self._waitMessage(conn, msg_id, self.storage_handler) self._waitMessage(conn, msg_id, self.storage_handler)
@profiler_decorator @profiler_decorator
def _askPrimary(self, packet): def _askPrimary(self, packet):
""" Send a request to the primary master and process it's answer """ """ Send a request to the primary master and process it's answer """
conn = self._getMasterConnection() conn = self._getMasterConnection()
msg_id = conn.ask(packet) msg_id = conn.ask(packet, queue=self.local_var.queue)
self._waitMessage(conn, msg_id, self.primary_handler) self._waitMessage(conn, msg_id, self.primary_handler)
@profiler_decorator @profiler_decorator
...@@ -308,6 +308,7 @@ class Application(object): ...@@ -308,6 +308,7 @@ class Application(object):
logging.debug('connecting to primary master...') logging.debug('connecting to primary master...')
ready = False ready = False
nm = self.nm nm = self.nm
queue = self.local_var.queue
while not ready: while not ready:
# Get network connection to primary master # Get network connection to primary master
index = 0 index = 0
...@@ -328,7 +329,7 @@ class Application(object): ...@@ -328,7 +329,7 @@ class Application(object):
self.trying_master_node = master_list[0] self.trying_master_node = master_list[0]
index += 1 index += 1
# Connect to master # Connect to master
conn = MTClientConnection(self.local_var, self.em, conn = MTClientConnection(self.em,
self.notifications_handler, self.notifications_handler,
addr=self.trying_master_node.getAddress(), addr=self.trying_master_node.getAddress(),
connector=self.connector_handler(), connector=self.connector_handler(),
...@@ -339,7 +340,7 @@ class Application(object): ...@@ -339,7 +340,7 @@ class Application(object):
logging.error('Connection to master node %s failed', logging.error('Connection to master node %s failed',
self.trying_master_node) self.trying_master_node)
continue continue
msg_id = conn.ask(Packets.AskPrimary()) msg_id = conn.ask(Packets.AskPrimary(), queue=queue)
try: try:
self._waitMessage(conn, msg_id, self._waitMessage(conn, msg_id,
handler=self.primary_bootstrap_handler) handler=self.primary_bootstrap_handler)
...@@ -359,7 +360,7 @@ class Application(object): ...@@ -359,7 +360,7 @@ class Application(object):
break break
p = Packets.RequestIdentification(NodeTypes.CLIENT, p = Packets.RequestIdentification(NodeTypes.CLIENT,
self.uuid, None, self.name) self.uuid, None, self.name)
msg_id = conn.ask(p) msg_id = conn.ask(p, queue=queue)
try: try:
self._waitMessage(conn, msg_id, self._waitMessage(conn, msg_id,
handler=self.primary_bootstrap_handler) handler=self.primary_bootstrap_handler)
...@@ -370,10 +371,10 @@ class Application(object): ...@@ -370,10 +371,10 @@ class Application(object):
# Node identification was refused by master. # Node identification was refused by master.
time.sleep(1) time.sleep(1)
if self.uuid is not None: if self.uuid is not None:
msg_id = conn.ask(Packets.AskNodeInformation()) msg_id = conn.ask(Packets.AskNodeInformation(), queue=queue)
self._waitMessage(conn, msg_id, self._waitMessage(conn, msg_id,
handler=self.primary_bootstrap_handler) handler=self.primary_bootstrap_handler)
msg_id = conn.ask(Packets.AskPartitionTable()) msg_id = conn.ask(Packets.AskPartitionTable(), queue=queue)
self._waitMessage(conn, msg_id, self._waitMessage(conn, msg_id,
handler=self.primary_bootstrap_handler) handler=self.primary_bootstrap_handler)
ready = self.uuid is not None and self.pt is not None \ ready = self.uuid is not None and self.pt is not None \
...@@ -597,12 +598,13 @@ class Application(object): ...@@ -597,12 +598,13 @@ class Application(object):
self.local_var.object_stored_counter_dict[oid] = {} self.local_var.object_stored_counter_dict[oid] = {}
self.local_var.object_serial_dict[oid] = (serial, version) self.local_var.object_serial_dict[oid] = (serial, version)
getConnForCell = self.cp.getConnForCell getConnForCell = self.cp.getConnForCell
queue = self.local_var.queue
for cell in cell_list: for cell in cell_list:
conn = getConnForCell(cell) conn = getConnForCell(cell)
if conn is None: if conn is None:
continue continue
try: try:
conn.ask(p, on_timeout=on_timeout) conn.ask(p, on_timeout=on_timeout, queue=queue)
except ConnectionClosed: except ConnectionClosed:
continue continue
...@@ -870,9 +872,10 @@ class Application(object): ...@@ -870,9 +872,10 @@ class Application(object):
undo_error_oid_list = self.local_var.undo_error_oid_list = [] undo_error_oid_list = self.local_var.undo_error_oid_list = []
ask_undo_transaction = Packets.AskUndoTransaction(tid, undone_tid) ask_undo_transaction = Packets.AskUndoTransaction(tid, undone_tid)
getConnForNode = self.cp.getConnForNode getConnForNode = self.cp.getConnForNode
queue = self.local_var.queue
for storage_node in self.nm.getStorageList(): for storage_node in self.nm.getStorageList():
storage_conn = getConnForNode(storage_node) storage_conn = getConnForNode(storage_node)
storage_conn.ask(ask_undo_transaction) storage_conn.ask(ask_undo_transaction, queue=queue)
# Wait for all AnswerUndoTransaction. # Wait for all AnswerUndoTransaction.
self.waitResponses() self.waitResponses()
...@@ -927,11 +930,12 @@ class Application(object): ...@@ -927,11 +930,12 @@ class Application(object):
storage_node_list = pt.getNodeList() storage_node_list = pt.getNodeList()
self.local_var.node_tids = {} self.local_var.node_tids = {}
queue = self.local_var.queue
for storage_node in storage_node_list: for storage_node in storage_node_list:
conn = self.cp.getConnForNode(storage_node) conn = self.cp.getConnForNode(storage_node)
if conn is None: if conn is None:
continue continue
conn.ask(Packets.AskTIDs(first, last, INVALID_PARTITION)) conn.ask(Packets.AskTIDs(first, last, INVALID_PARTITION), queue=queue)
# Wait for answers from all storages. # Wait for answers from all storages.
while len(self.local_var.node_tids) != len(storage_node_list): while len(self.local_var.node_tids) != len(storage_node_list):
......
...@@ -50,7 +50,7 @@ class ConnectionPool(object): ...@@ -50,7 +50,7 @@ class ConnectionPool(object):
while True: while True:
logging.debug('trying to connect to %s - %s', node, node.getState()) logging.debug('trying to connect to %s - %s', node, node.getState())
app.setNodeReady() app.setNodeReady()
conn = MTClientConnection(app.local_var, app.em, conn = MTClientConnection(app.em,
app.storage_event_handler, addr, app.storage_event_handler, addr,
connector=app.connector_handler(), dispatcher=app.dispatcher) connector=app.connector_handler(), dispatcher=app.dispatcher)
conn.lock() conn.lock()
...@@ -63,7 +63,7 @@ class ConnectionPool(object): ...@@ -63,7 +63,7 @@ class ConnectionPool(object):
p = Packets.RequestIdentification(NodeTypes.CLIENT, p = Packets.RequestIdentification(NodeTypes.CLIENT,
app.uuid, None, app.name) app.uuid, None, app.name)
msg_id = conn.ask(p) msg_id = conn.ask(p, queue=app.local_var.queue)
finally: finally:
conn.unlock() conn.unlock()
......
...@@ -684,10 +684,9 @@ class ServerConnection(Connection): ...@@ -684,10 +684,9 @@ class ServerConnection(Connection):
class MTClientConnection(ClientConnection): class MTClientConnection(ClientConnection):
"""A Multithread-safe version of ClientConnection.""" """A Multithread-safe version of ClientConnection."""
def __init__(self, local_var, *args, **kwargs): def __init__(self, *args, **kwargs):
# _lock is only here for lock debugging purposes. Do not use. # _lock is only here for lock debugging purposes. Do not use.
self._lock = lock = RLock() self._lock = lock = RLock()
self._local_var = local_var
self.acquire = lock.acquire self.acquire = lock.acquire
self.release = lock.release self.release = lock.release
self.dispatcher = kwargs.pop('dispatcher') self.dispatcher = kwargs.pop('dispatcher')
...@@ -723,7 +722,8 @@ class MTClientConnection(ClientConnection): ...@@ -723,7 +722,8 @@ class MTClientConnection(ClientConnection):
self.unlock() self.unlock()
@profiler_decorator @profiler_decorator
def ask(self, packet, timeout=CRITICAL_TIMEOUT, on_timeout=None): def ask(self, packet, timeout=CRITICAL_TIMEOUT, on_timeout=None,
queue=None):
self.lock() self.lock()
try: try:
# XXX: Here, we duplicate Connection.ask because we need to call # XXX: Here, we duplicate Connection.ask because we need to call
...@@ -731,7 +731,12 @@ class MTClientConnection(ClientConnection): ...@@ -731,7 +731,12 @@ class MTClientConnection(ClientConnection):
# _addPacket is called. # _addPacket is called.
msg_id = self._getNextId() msg_id = self._getNextId()
packet.setId(msg_id) packet.setId(msg_id)
self.dispatcher.register(self, msg_id, self._local_var.queue) if queue is None:
if not isinstance(packet, Packets.Ping):
raise TypeError, 'Only Ping packet can be asked ' \
'without a queue, got a %r.' % (packet, )
else:
self.dispatcher.register(self, msg_id, queue)
self._addPacket(packet) self._addPacket(packet)
t = time() t = time()
# If there is no pending request, initialise timeout values. # If there is no pending request, initialise timeout values.
......
...@@ -27,6 +27,7 @@ from neo.connector import ConnectorException, ConnectorTryAgainException, \ ...@@ -27,6 +27,7 @@ from neo.connector import ConnectorException, ConnectorTryAgainException, \
from neo.protocol import Packets, ParserState from neo.protocol import Packets, ParserState
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
from neo.util import ReadBuffer from neo.util import ReadBuffer
from neo.locking import Queue
class ConnectionTests(NeoTestBase): class ConnectionTests(NeoTestBase):
...@@ -808,6 +809,30 @@ class ConnectionTests(NeoTestBase): ...@@ -808,6 +809,30 @@ class ConnectionTests(NeoTestBase):
self.assertEqual(bc.aborted, True) self.assertEqual(bc.aborted, True)
self.assertTrue(bc.isServer()) self.assertTrue(bc.isServer())
class MTConnectionTests(ConnectionTests):
# XXX: here we test non-client-connection-related things too, which
# duplicates test suite work... Should be fragmented into finer-grained
# test classes.
def setUp(self):
super(MTConnectionTests, self).setUp()
self.dispatcher = Mock({'__repr__': 'Fake Dispatcher'})
def _makeClientConnection(self):
self.connector = DoNothingConnector()
return MTClientConnection(event_manager=self.em, handler=self.handler,
connector=self.connector, addr=self.address,
dispatcher=self.dispatcher)
def test_MTClientConnectionQueueParameter(self):
queue = Queue()
ask = self._makeClientConnection().ask
packet = Packets.AskPrimary() # Any non-Ping simple "ask" packet
# One cannot "ask" anything without a queue
self.assertRaises(TypeError, ask, packet)
ask(packet, queue=queue)
# ... except Ping
ask(Packets.Ping())
class HandlerSwitcherTests(NeoTestBase): class HandlerSwitcherTests(NeoTestBase):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment