Commit db8db123 authored by Grégory Wisniewski's avatar Grégory Wisniewski

MTConnection handle local queue to unify ask() prototype.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@1925 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent f7707668
...@@ -246,7 +246,7 @@ class Application(object): ...@@ -246,7 +246,7 @@ class Application(object):
def _askStorage(self, conn, packet): def _askStorage(self, conn, packet):
""" Send a request to a storage node and process it's answer """ """ Send a request to a storage node and process it's answer """
try: try:
msg_id = conn.ask(self.local_var.queue, packet) msg_id = conn.ask(packet)
finally: finally:
# assume that the connection was already locked # assume that the connection was already locked
conn.unlock() conn.unlock()
...@@ -258,7 +258,7 @@ class Application(object): ...@@ -258,7 +258,7 @@ class Application(object):
conn = self._getMasterConnection() conn = self._getMasterConnection()
conn.lock() conn.lock()
try: try:
msg_id = conn.ask(self.local_var.queue, packet) msg_id = conn.ask(packet)
finally: finally:
conn.unlock() conn.unlock()
self._waitMessage(conn, msg_id, self.primary_handler) self._waitMessage(conn, msg_id, self.primary_handler)
...@@ -321,7 +321,8 @@ class Application(object): ...@@ -321,7 +321,8 @@ class Application(object):
self.trying_master_node = master_list[0] self.trying_master_node = master_list[0]
index += 1 index += 1
# Connect to master # Connect to master
conn = MTClientConnection(self.em, self.notifications_handler, conn = MTClientConnection(self.local_var, self.em,
self.notifications_handler,
addr=self.trying_master_node.getAddress(), addr=self.trying_master_node.getAddress(),
connector=self.connector_handler(), connector=self.connector_handler(),
dispatcher=self.dispatcher) dispatcher=self.dispatcher)
...@@ -333,8 +334,7 @@ class Application(object): ...@@ -333,8 +334,7 @@ class Application(object):
logging.error('Connection to master node %s failed', logging.error('Connection to master node %s failed',
self.trying_master_node) self.trying_master_node)
continue continue
msg_id = conn.ask(self.local_var.queue, msg_id = conn.ask(Packets.AskPrimary())
Packets.AskPrimary())
finally: finally:
conn.unlock() conn.unlock()
try: try:
...@@ -358,7 +358,7 @@ class Application(object): ...@@ -358,7 +358,7 @@ class Application(object):
break break
p = Packets.RequestIdentification(NodeTypes.CLIENT, p = Packets.RequestIdentification(NodeTypes.CLIENT,
self.uuid, None, self.name) self.uuid, None, self.name)
msg_id = conn.ask(self.local_var.queue, p) msg_id = conn.ask(p)
finally: finally:
conn.unlock() conn.unlock()
try: try:
...@@ -373,16 +373,14 @@ class Application(object): ...@@ -373,16 +373,14 @@ class Application(object):
if self.uuid is not None: if self.uuid is not None:
conn.lock() conn.lock()
try: try:
msg_id = conn.ask(self.local_var.queue, msg_id = conn.ask(Packets.AskNodeInformation())
Packets.AskNodeInformation())
finally: finally:
conn.unlock() conn.unlock()
self._waitMessage(conn, msg_id, self._waitMessage(conn, msg_id,
handler=self.primary_bootstrap_handler) handler=self.primary_bootstrap_handler)
conn.lock() conn.lock()
try: try:
msg_id = conn.ask(self.local_var.queue, msg_id = conn.ask(Packets.AskPartitionTable([]))
Packets.AskPartitionTable([]))
finally: finally:
conn.unlock() conn.unlock()
self._waitMessage(conn, msg_id, self._waitMessage(conn, msg_id,
...@@ -600,14 +598,13 @@ class Application(object): ...@@ -600,14 +598,13 @@ class Application(object):
# Store data on each node # Store data on each node
self.local_var.object_stored_counter_dict[oid] = 0 self.local_var.object_stored_counter_dict[oid] = 0
self.local_var.object_serial_dict[oid] = (serial, version) self.local_var.object_serial_dict[oid] = (serial, version)
local_queue = self.local_var.queue
for cell in cell_list: for cell in cell_list:
conn = self.cp.getConnForCell(cell) conn = self.cp.getConnForCell(cell)
if conn is None: if conn is None:
continue continue
try: try:
try: try:
conn.ask(local_queue, p) conn.ask(p)
finally: finally:
conn.unlock() conn.unlock()
except ConnectionClosed: except ConnectionClosed:
...@@ -882,8 +879,7 @@ class Application(object): ...@@ -882,8 +879,7 @@ class Application(object):
continue continue
try: try:
conn.ask(self.local_var.queue, Packets.AskTIDs(first, last, conn.ask(Packets.AskTIDs(first, last, INVALID_PARTITION))
INVALID_PARTITION))
finally: finally:
conn.unlock() conn.unlock()
......
...@@ -50,7 +50,8 @@ class ConnectionPool(object): ...@@ -50,7 +50,8 @@ class ConnectionPool(object):
while True: while True:
logging.debug('trying to connect to %s - %s', node, node.getState()) logging.debug('trying to connect to %s - %s', node, node.getState())
app.setNodeReady() app.setNodeReady()
conn = MTClientConnection(app.em, app.storage_event_handler, addr, conn = MTClientConnection(app.local_var, app.em,
app.storage_event_handler, addr,
connector=app.connector_handler(), dispatcher=app.dispatcher) connector=app.connector_handler(), dispatcher=app.dispatcher)
conn.lock() conn.lock()
...@@ -62,7 +63,7 @@ class ConnectionPool(object): ...@@ -62,7 +63,7 @@ class ConnectionPool(object):
p = Packets.RequestIdentification(NodeTypes.CLIENT, p = Packets.RequestIdentification(NodeTypes.CLIENT,
app.uuid, None, app.name) app.uuid, None, app.name)
msg_id = conn.ask(app.local_var.queue, p) msg_id = conn.ask(p)
finally: finally:
conn.unlock() conn.unlock()
......
...@@ -565,9 +565,10 @@ class ServerConnection(Connection): ...@@ -565,9 +565,10 @@ class ServerConnection(Connection):
class MTClientConnection(ClientConnection): class MTClientConnection(ClientConnection):
"""A Multithread-safe version of ClientConnection.""" """A Multithread-safe version of ClientConnection."""
def __init__(self, *args, **kwargs): def __init__(self, local_var, *args, **kwargs):
# _lock is only here for lock debugging purposes. Do not use. # _lock is only here for lock debugging purposes. Do not use.
self._lock = lock = RLock() self._lock = lock = RLock()
self._local_var = local_var
self.acquire = lock.acquire self.acquire = lock.acquire
self.release = lock.release self.release = lock.release
self.dispatcher = kwargs.pop('dispatcher') self.dispatcher = kwargs.pop('dispatcher')
...@@ -600,10 +601,10 @@ class MTClientConnection(ClientConnection): ...@@ -600,10 +601,10 @@ class MTClientConnection(ClientConnection):
return super(MTClientConnection, self).notify(*args, **kw) return super(MTClientConnection, self).notify(*args, **kw)
@lockCheckWrapper @lockCheckWrapper
def ask(self, queue, packet, timeout=CRITICAL_TIMEOUT): def ask(self, packet, timeout=CRITICAL_TIMEOUT):
msg_id = self._getNextId() msg_id = self._getNextId()
packet.setId(msg_id) packet.setId(msg_id)
self.dispatcher.register(self, msg_id, queue) self.dispatcher.register(self, msg_id, self._local_var.queue)
self._addPacket(packet) self._addPacket(packet)
if not self._handlers.isPending(): if not self._handlers.isPending():
self._timeout.update(time(), timeout=timeout) self._timeout.update(time(), timeout=timeout)
......
...@@ -76,7 +76,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -76,7 +76,7 @@ class ClientApplicationTests(NeoTestBase):
calls = conn.mockGetNamedCalls('ask') calls = conn.mockGetNamedCalls('ask')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
# client connection got queue as first parameter # client connection got queue as first parameter
packet = calls[0].getParam(1) packet = calls[0].getParam(0)
self.assertTrue(isinstance(packet, Packet)) self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), packet_type) self.assertEquals(packet.getType(), packet_type)
if decode: if decode:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment