Commit 4e739de4 authored by Julien Muchembled's avatar Julien Muchembled

client: review connection locking (MTClientConnection)

This mainly changes several methods to lock automatically instead of asserting
that the caller did it. This removes any overhead for non-MT classes, and
the use of 'with' instead of lock/unlock methods also simplifies the API.
parent e438f864
......@@ -175,11 +175,8 @@ class Application(object):
handler = self.primary_handler
else:
raise ValueError, 'Unknown node type: %r' % (node.__class__, )
conn.lock()
try:
with conn.lock:
handler.dispatch(conn, packet, kw)
finally:
conn.unlock()
def _waitAnyMessage(self, queue, block=True):
"""
......
......@@ -26,7 +26,7 @@ class BaseHandler(EventHandler):
self.dispatcher = app.dispatcher
def dispatch(self, conn, packet, kw={}):
assert conn._lock._is_owned()
assert conn.lock._is_owned() # XXX: see also lockCheckWrapper
super(BaseHandler, self).dispatch(conn, packet, kw)
def packetReceived(self, conn, packet, kw={}):
......
......@@ -72,8 +72,7 @@ class ConnectionPool(object):
"""Drop connections."""
for conn in self.connection_dict.values():
# Drop first connection which looks not used
conn.lock()
try:
with conn.lock:
if not conn.pending() and \
not self.app.dispatcher.registered(conn):
del self.connection_dict[conn.getUUID()]
......@@ -82,8 +81,6 @@ class ConnectionPool(object):
'storage node %s:%d closed', *conn.getAddress())
if len(self.connection_dict) <= self.max_pool_size:
break
finally:
conn.unlock()
def notifyFailure(self, node):
self.node_failure_dict[node.getUUID()] = time.time() + MAX_FAILURE_AGE
......
......@@ -41,28 +41,6 @@ def not_closed(func):
return wraps(func)(decorator)
def lockCheckWrapper(func):
"""
This function is to be used as a wrapper around
MT(Client|Server)Connection class methods.
It uses a "_" method on RLock class, so it might stop working without
notice (sadly, RLock does not offer any "acquired" method, but that one
will do as it checks that current thread holds this lock).
It requires moniroted class to have an RLock instance in self._lock
property.
"""
def wrapper(self, *args, **kw):
if not self._lock._is_owned():
import traceback
logging.warning('%s called on %s instance without being locked.'
' Stack:\n%s', func.func_code.co_name,
self.__class__.__name__, ''.join(traceback.format_stack()))
# Call anyway
return func(self, *args, **kw)
return wraps(func)(wrapper)
class HandlerSwitcher(object):
_next_timeout = None
_next_timeout_msg_id = None
......@@ -250,11 +228,8 @@ class BaseConnection(object):
def checkTimeout(self, t):
pass
def lock(self):
return 1
def unlock(self):
return None
def lockWrapper(self, func):
return func
def getConnector(self):
return self.connector
......@@ -495,6 +470,7 @@ class Connection(BaseConnection):
self.analyse()
if self.aborted:
self.em.removeReader(self)
return not not self._queue
def analyse(self):
"""Analyse received data."""
......@@ -562,8 +538,8 @@ class Connection(BaseConnection):
global connect_limit
t = time()
if t < connect_limit:
self.checkTimeout = lambda t: t < connect_limit or \
self._delayed_closure()
self.checkTimeout = self.lockWrapper(lambda t:
t < connect_limit or self._delayed_closure())
self.readable = self.writable = lambda: None
else:
connect_limit = t + 1
......@@ -707,7 +683,7 @@ class ClientConnection(Connection):
self.writable()
def _connectionCompleted(self):
self.writable = super(ClientConnection, self).writable
self.writable = self.lockWrapper(super(ClientConnection, self).writable)
self.connecting = False
self.updateTimeout(time())
self.getHandler().connectionCompleted(self)
......@@ -729,36 +705,53 @@ class ServerConnection(Connection):
self.updateTimeout(time())
class MTConnectionType(type):
def __init__(cls, *args):
if __debug__:
for name in 'analyse', 'answer':
setattr(cls, name, cls.lockCheckWrapper(name))
for name in ('close', 'checkTimeout', 'notify',
'process', 'readable', 'writable'):
setattr(cls, name, cls.__class__.lockWrapper(cls, name))
def lockCheckWrapper(cls, name):
def wrapper(self, *args, **kw):
# XXX: Unfortunately, RLock does not has any public method
# to test whether we own the lock or not.
assert self.lock._is_owned(), (self, args, kw)
return getattr(super(cls, self), name)(*args, **kw)
return wraps(getattr(cls, name).im_func)(wrapper)
def lockWrapper(cls, name):
def wrapper(self, *args, **kw):
with self.lock:
return getattr(super(cls, self), name)(*args, **kw)
return wraps(getattr(cls, name).im_func)(wrapper)
class MTClientConnection(ClientConnection):
"""A Multithread-safe version of ClientConnection."""
def __metaclass__(name, base, d):
for k in ('analyse', 'answer', 'checkTimeout',
'process', 'readable', 'writable'):
d[k] = lockCheckWrapper(getattr(base[0], k).im_func)
return type(name, base, d)
__metaclass__ = MTConnectionType
def lockWrapper(self, func):
lock = self.lock
def wrapper(*args, **kw):
with lock:
return func(*args, **kw)
return wrapper
def __init__(self, *args, **kwargs):
# _lock is only here for lock debugging purposes. Do not use.
self._lock = lock = RLock()
self.lock = lock.acquire
self.unlock = lock.release
self.lock = lock = RLock()
self.dispatcher = kwargs.pop('dispatcher')
self.dispatcher.needPollThread()
with lock:
super(MTClientConnection, self).__init__(*args, **kwargs)
def notify(self, *args, **kw):
self.lock()
try:
return super(MTClientConnection, self).notify(*args, **kw)
finally:
self.unlock()
def ask(self, packet, timeout=CRITICAL_TIMEOUT, on_timeout=None,
queue=None, **kw):
self.lock()
try:
with self.lock:
if self.isClosed():
raise ConnectionClosed
# XXX: Here, we duplicate Connection.ask because we need to call
......@@ -778,12 +771,3 @@ class MTClientConnection(ClientConnection):
handlers.emit(packet, timeout, on_timeout, kw)
self.updateTimeout(t)
return msg_id
finally:
self.unlock()
def close(self):
self.lock()
try:
super(MTClientConnection, self).close()
finally:
self.unlock()
......@@ -92,16 +92,12 @@ class EpollEventManager(object):
if not self._pending_processing:
return
to_process = self._pending_processing.pop(0)
to_process.lock()
try:
try:
to_process.process()
finally:
# ...and requeue if there are pending messages
if to_process.hasPendingMessages():
self._addPendingConnection(to_process)
to_process.process()
finally:
to_process.unlock()
# ...and requeue if there are pending messages
if to_process.hasPendingMessages():
self._addPendingConnection(to_process)
# Non-blocking call: as we handled a packet, we should just offer
# poll a chance to fetch & send already-available data, but it must
# not delay us.
......@@ -122,12 +118,7 @@ class EpollEventManager(object):
for fd, event in event_list:
if event & EPOLLIN:
conn = self.connection_dict[fd]
conn.lock()
try:
conn.readable()
finally:
conn.unlock()
if conn.hasPendingMessages():
if conn.readable():
self._addPendingConnection(conn)
if event & EPOLLOUT:
wlist.append(fd)
......@@ -140,11 +131,7 @@ class EpollEventManager(object):
conn = self.connection_dict[fd]
except KeyError:
continue
conn.lock()
try:
conn.writable()
finally:
conn.unlock()
conn.writable()
for fd in elist:
# This can fail, if a connection is closed in previous calls to
......@@ -153,21 +140,12 @@ class EpollEventManager(object):
conn = self.connection_dict[fd]
except KeyError:
continue
conn.lock()
try:
conn.readable()
finally:
conn.unlock()
if conn.hasPendingMessages():
if conn.readable():
self._addPendingConnection(conn)
t = time()
for conn in self.connection_dict.values():
conn.lock()
try:
conn.checkTimeout(t)
finally:
conn.unlock()
conn.checkTimeout(t)
def addReader(self, conn):
connector = conn.getConnector()
......
......@@ -14,6 +14,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import threading
import unittest
from cPickle import dumps
from mock import Mock, ReturnValues
......@@ -44,6 +45,11 @@ def _getMasterConnection(self):
self.master_conn = Mock()
return self.master_conn
def getConnection(kw):
conn = Mock(kw)
conn.lock = threading.RLock()
return conn
def _ask(self, conn, packet, handler=None, **kw):
self.setHandlerData(None)
conn.ask(packet, **kw)
......@@ -110,7 +116,7 @@ class ClientApplicationTests(NeoUnitTestBase):
makeTID = makeOID
def getNodeCellConn(self, index=1, address=('127.0.0.1', 10000), uuid=None):
conn = Mock({
conn = getConnection({
'getAddress': address,
'__repr__': 'connection mock',
'getUUID': uuid,
......@@ -167,8 +173,7 @@ class ClientApplicationTests(NeoUnitTestBase):
response_packet = Packets.AnswerNewOIDs(test_oid_list[:])
response_packet.setId(0)
app.master_conn = Mock({'getNextId': test_msg_id, '_addPacket': None,
'expectMessage': None, 'lock': None,
'unlock': None,
'expectMessage': None,
# Test-specific method
'fakeReceived': response_packet})
new_oid = app.new_oid()
......@@ -434,12 +439,12 @@ class ClientApplicationTests(NeoUnitTestBase):
packet2 = Packets.AnswerStoreObject(conflicting=1, oid=oid1, serial=tid)
packet3 = Packets.AnswerStoreObject(conflicting=0, oid=oid2, serial=tid)
[p.setId(i) for p, i in zip([packet1, packet2, packet3], range(3))]
conn1 = Mock({'__repr__': 'conn1', 'getAddress': address1,
'fakeReceived': packet1, 'getUUID': uuid1})
conn2 = Mock({'__repr__': 'conn2', 'getAddress': address2,
'fakeReceived': packet2, 'getUUID': uuid2})
conn3 = Mock({'__repr__': 'conn3', 'getAddress': address3,
'fakeReceived': packet3, 'getUUID': uuid3})
conn1 = getConnection({'__repr__': 'conn1', 'getAddress': address1,
'fakeReceived': packet1, 'getUUID': uuid1})
conn2 = getConnection({'__repr__': 'conn2', 'getAddress': address2,
'fakeReceived': packet2, 'getUUID': uuid2})
conn3 = getConnection({'__repr__': 'conn3', 'getAddress': address3,
'fakeReceived': packet3, 'getUUID': uuid3})
node1 = Mock({'__repr__': 'node1', '__hash__': 1, 'getConnection': conn1})
node2 = Mock({'__repr__': 'node2', '__hash__': 2, 'getConnection': conn2})
node3 = Mock({'__repr__': 'node3', '__hash__': 3, 'getConnection': conn3})
......@@ -520,7 +525,7 @@ class ClientApplicationTests(NeoUnitTestBase):
transaction_info = Packets.AnswerTransactionInformation(tid1, '', '',
'', False, (oid0, ))
transaction_info.setId(1)
conn = Mock({
conn = getConnection({
'getNextId': 1,
'fakeReceived': transaction_info,
'getAddress': ('127.0.0.1', 10020),
......@@ -706,7 +711,7 @@ class ClientApplicationTests(NeoUnitTestBase):
})
asked = []
def answerTIDs(packet):
conn = Mock({'getAddress': packet})
conn = getConnection({'getAddress': packet})
app.nm.createStorage(address=conn.getAddress())
def ask(p, queue, **kw):
asked.append(p)
......
......@@ -587,7 +587,7 @@ class ConnectionTests(NeoUnitTestBase):
DoNothingConnector.receive = receive
try:
bc = self._makeConnection()
bc._queue = Mock()
bc._queue = Mock({'__len__': 0})
self._checkReadBuf(bc, '')
self.assertFalse(bc.aborted)
bc.readable()
......
......@@ -104,13 +104,9 @@ class EventTests(NeoUnitTestBase):
self.assertEqual(data, 10)
# need to rebuild completely this test and the the packet queue
# check readable conn
#self.assertEqual(len(r_conn.mockGetNamedCalls("lock")), 1)
#self.assertEqual(len(r_conn.mockGetNamedCalls("unlock")), 1)
#self.assertEqual(len(r_conn.mockGetNamedCalls("readable")), 1)
#self.assertEqual(len(r_conn.mockGetNamedCalls("writable")), 0)
# check writable conn
#self.assertEqual(len(w_conn.mockGetNamedCalls("lock")), 1)
#self.assertEqual(len(w_conn.mockGetNamedCalls("unlock")), 1)
#self.assertEqual(len(w_conn.mockGetNamedCalls("readable")), 0)
#self.assertEqual(len(w_conn.mockGetNamedCalls("writable")), 1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment