Commit 96897224 authored by Julien Muchembled

Global connection filtering in threaded tests, and fix id of delayed packets

parent 930af0fb
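The hunks below replace the old per-connection monkey patching of Connection._addPacket with a single class-level patch shared by all ConnectionFilter instances, turn the filter into a context manager, and queue a copy of each delayed packet. A hedged sketch of the resulting test idiom, reusing names from the updated tests further down (cluster, t and Packets come from the neo.tests.threaded fixtures, so this is illustrative rather than standalone):

with cluster.master.filterConnection(cluster.storage) as m2s:
    # hold back NotifyUnlockInformation packets sent by the master
    m2s.add(lambda conn, packet:
        isinstance(packet, Packets.NotifyUnlockInformation))
    t.commit()
# leaving the block removes the filter; packets it delayed are re-sent
# unless another still-active filter keeps delaying them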
@@ -64,12 +64,11 @@ class MasterClientElectionTests(MasterClientElectionTestBase):
self.app.unconnected_master_node_set = set()
self.app.negotiating_master_node_set = set()
# apply monkey patches
self._addPacket = ClientConnection._addPacket
ClientConnection._addPacket = _addPacket
def _tearDown(self, success):
# restore patched methods
ClientConnection._addPacket = self._addPacket
del ClientConnection._addPacket
NeoUnitTestBase._tearDown(self, success)
def _checkUnconnected(self, node):
@@ -220,13 +219,12 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
self.storage_address = (self.local_ip, 2000)
self.master_address = (self.local_ip, 3000)
# apply monkey patches
self._addPacket = ClientConnection._addPacket
ClientConnection._addPacket = _addPacket
def _tearDown(self, success):
NeoUnitTestBase._tearDown(self, success)
# restore environment
ClientConnection._addPacket = self._addPacket
del ClientConnection._addPacket
def test_requestIdentification1(self):
""" A non-master node request identification """
......
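Both election-test hunks above rely on the same detail: ClientConnection only inherits _addPacket from Connection, so assigning the stub on ClientConnection merely shadows the inherited method and del restores the original behaviour without having to save it first. A minimal standalone sketch of that shadow-and-delete pattern (toy classes, not the neo API):

class Base(object):
    def greet(self):
        return 'original'

class Child(Base):
    pass

Child.greet = lambda self: 'patched'   # shadows the inherited method
assert Child().greet() == 'patched'
del Child.greet                        # lookup falls back to Base.greet
assert Child().greet() == 'original'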
@@ -19,6 +19,7 @@
import os, random, socket, sys, tempfile, threading, time, types, weakref
import traceback
from collections import deque
from contextlib import contextmanager
from itertools import count
from functools import wraps
from zlib import decompress
@@ -163,15 +164,17 @@ class SerializedEventManager(EventManager):
class Node(object):
def filterConnection(self, *peers):
def getConnectionList(self, *peers):
addr = lambda c: c and (c.accepted_from or c.getAddress())
addr_set = set(addr(c.connector) for peer in peers
for c in peer.em.connection_dict.itervalues()
if isinstance(c, Connection))
addr_set.discard(None)
conn_list = (c for c in self.em.connection_dict.itervalues()
return (c for c in self.em.connection_dict.itervalues()
if isinstance(c, Connection) and addr(c.connector) in addr_set)
return ConnectionFilter(*conn_list)
def filterConnection(self, *peers):
return ConnectionFilter(self.getConnectionList(*peers))
class ServerNode(Node):
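A tiny standalone illustration (a toy helper, not the neo connector API) of the matching idea in Node.getConnectionList above: compute an address for every connection held by the given peers, collect those addresses into a set, and keep only this node's connections whose address falls in that set.

def shared_connections(my_conns, peer_conns, addr):
    # addr() maps a connection to a comparable address, or None
    addr_set = set(addr(c) for c in peer_conns)
    addr_set.discard(None)
    return (c for c in my_conns if addr(c) in addr_set)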
@@ -334,16 +337,14 @@ class ClientApplication(Node, neo.client.app.Application):
Serialized.background()
close = __del__
def filterConnection(self, *peers):
conn_list = []
def getConnectionList(self, *peers):
for peer in peers:
if isinstance(peer, MasterApplication):
conn = self._getMasterConnection()
else:
assert isinstance(peer, StorageApplication)
conn = self.cp.getConnForNode(self.nm.getByUUID(peer.uuid))
conn_list.append(conn)
return ConnectionFilter(*conn_list)
yield conn
class NeoCTL(neo.neoctl.app.NeoCTL):
@@ -391,59 +392,65 @@ class Patch(object):
class ConnectionFilter(object):
filtered_count = 0
def __init__(self, *conns):
filter_list = []
filter_queue = weakref.WeakKeyDictionary()
lock = threading.Lock()
_addPacket = Connection._addPacket
@contextmanager
def __new__(cls, conn_list=()):
self = object.__new__(cls)
self.filter_dict = {}
self.lock = threading.Lock()
self.conn_list = [(conn, self._patch(conn)) for conn in conns]
def _patch(self, conn):
assert '_addPacket' not in conn.__dict__, "already patched"
lock = self.lock
filter_dict = self.filter_dict
orig = conn.__class__._addPacket
queue = deque()
def _addPacket(packet):
lock.acquire()
try:
if not queue:
for filter in filter_dict:
if filter(conn, packet):
self.filtered_count += 1
break
else:
return orig(conn, packet)
queue.append(packet)
finally:
lock.release()
conn._addPacket = _addPacket
return queue
self.conn_list = frozenset(conn_list)
if not cls.filter_list:
def _addPacket(conn, packet):
with cls.lock:
try:
queue = cls.filter_queue[conn]
except KeyError:
for self in cls.filter_list:
if self(conn, packet):
self.filtered_count += 1
break
else:
return cls._addPacket(conn, packet)
cls.filter_queue[conn] = queue = deque()
p = packet.__new__(packet.__class__)
p.__dict__.update(packet.__dict__)
queue.append(p)
Connection._addPacket = _addPacket
try:
cls.filter_list.append(self)
yield self
finally:
del cls.filter_list[-1:]
if not cls.filter_list:
Connection._addPacket = cls._addPacket.im_func
with cls.lock:
cls._retry()
def __call__(self, conn, packet):
if not self.conn_list or conn in self.conn_list:
for filter in self.filter_dict:
if filter(conn, packet):
return True
return False
def __call__(self, revert=1):
with self.lock:
self.filter_dict.clear()
self._retry()
if revert:
for conn, queue in self.conn_list:
assert not queue
del conn._addPacket
del self.conn_list[:]
def _retry(self):
for conn, queue in self.conn_list:
@classmethod
def _retry(cls):
for conn, queue in cls.filter_queue.items():
while queue:
packet = queue.popleft()
for filter in self.filter_dict:
if filter(conn, packet):
for self in cls.filter_list:
if self(conn, packet):
queue.appendleft(packet)
break
else:
conn.__class__._addPacket(conn, packet)
cls._addPacket(conn, packet)
continue
break
def clear(self):
self(0)
else:
del cls.filter_queue[conn]
def add(self, filter, *patches):
with self.lock:
......
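In the new version above, ConnectionFilter.__new__ is wrapped in @contextmanager, so instantiating the class actually enters a context manager that yields the filter object. The first active filter installs one class-level replacement of Connection._addPacket; on the way out a filter removes itself from the registry, the original method is restored once no filter remains, and _retry() re-sends the queued packets that no longer match any remaining filter. Delayed packets are stored per connection in a class-wide WeakKeyDictionary of deques, and once a connection has a queue every later packet for it is queued as well, preserving order; each queued packet is a shallow copy of the caller's packet object, which is presumably the "fix id of delayed packets" part of the commit message. The following standalone sketch (a toy Channel class, standard library only, not the neo API) shows the core technique of patching a method at class level, delaying matching calls in per-instance queues and replaying them on exit; it deliberately omits the packet copy, the shared filter registry and the keep-queueing-once-queued rule of the real class.

from collections import deque
from contextlib import contextmanager

class Channel(object):
    def send(self, msg):
        print('sent %r' % (msg,))

@contextmanager
def delay_sends(predicate):
    orig = Channel.send            # original method, restored on exit
    queues = {}                    # channel -> deque of delayed messages
    def send(self, msg):
        if predicate(self, msg):
            queues.setdefault(self, deque()).append(msg)
        else:
            orig(self, msg)
    Channel.send = send
    try:
        yield queues
    finally:
        Channel.send = orig
        for channel, queue in queues.items():
            while queue:           # replay what was delayed
                orig(channel, queue.popleft())

# usage: messages containing 'x' are held back until the block exits
c = Channel()
with delay_sends(lambda ch, msg: 'x' in msg):
    c.send('x1')   # delayed
    c.send('y1')   # sent immediately
# 'x1' is sent here, once the patch is removed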
@@ -141,8 +141,8 @@ class Test(NEOThreadedTest):
def delayUnlockInformation(conn, packet):
return isinstance(packet, Packets.NotifyUnlockInformation)
def onStoreObject(orig, tm, ttid, serial, oid, *args):
if oid == resume_oid and delayUnlockInformation in master_storage:
master_storage.remove(delayUnlockInformation)
if oid == resume_oid and delayUnlockInformation in m2s:
m2s.remove(delayUnlockInformation)
try:
return orig(tm, ttid, serial, oid, *args)
except Exception, e:
@@ -153,18 +153,15 @@ class Test(NEOThreadedTest):
cluster.start()
t, c = cluster.getTransaction()
c.root()[0] = ob = PCounter()
master_storage = cluster.master.filterConnection(cluster.storage)
try:
with cluster.master.filterConnection(cluster.storage) as m2s:
resume_oid = None
master_storage.add(delayUnlockInformation,
m2s.add(delayUnlockInformation,
Patch(TransactionManager, storeObject=onStoreObject))
t.commit()
resume_oid = ob._p_oid
ob._p_changed = 1
t.commit()
self.assertFalse(delayUnlockInformation in master_storage)
finally:
master_storage()
self.assertFalse(delayUnlockInformation in m2s)
finally:
cluster.stop()
self.assertEqual(except_list, [DelayedError])
@@ -561,9 +558,8 @@ class Test(NEOThreadedTest):
client.tpc_begin(txn)
client.store(x1._p_oid, x1._p_serial, y, '', txn)
# Delay invalidation for x
master_client = cluster.master.filterConnection(cluster.client)
try:
master_client.add(lambda conn, packet:
with cluster.master.filterConnection(cluster.client) as m2c:
m2c.add(lambda conn, packet:
isinstance(packet, Packets.InvalidateObjects))
tid = client.tpc_finish(txn, None)
client.setPoll(0)
@@ -574,8 +570,6 @@ class Test(NEOThreadedTest):
x2 = c2.root()['x']
cache.clear() # bypass cache
self.assertEqual(x2.value, 0)
finally:
master_client()
x2._p_deactivate()
t1.begin() # process invalidation and sync connection storage
self.assertEqual(x2.value, 0)
......
@@ -191,16 +191,13 @@ class ReplicationTests(NEOThreadedTest):
# and node 1 must switch to node 2
pt: 0: UU.|U.U|.UU
"""
def connected(orig, *args, **kw):
patch[0] = s1.filterConnection(s0)
patch[0].add(delayAskFetch,
Patch(s0.dm, changePartitionTable=changePartitionTable))
return orig(*args, **kw)
def delayAskFetch(conn, packet):
return isinstance(packet, delayed) and packet.decode()[0] == offset
return isinstance(packet, delayed) and \
packet.decode()[0] == offset and \
conn in s1.getConnectionList(s0)
def changePartitionTable(orig, ptid, cell_list):
if (offset, s0.uuid, CellStates.DISCARDED) in cell_list:
patch[0].remove(delayAskFetch)
connection_filter.remove(delayAskFetch)
# XXX: this is currently not done by
# default for performance reason
orig.im_self.dropPartitions((offset,))
@@ -221,13 +218,11 @@ class ReplicationTests(NEOThreadedTest):
offset, = [offset for offset, row in enumerate(
cluster.master.pt.partition_list)
for cell in row if cell.isFeeding()]
patch = [Patch(s1.replicator, fetchTransactions=connected)]
try:
with ConnectionFilter() as connection_filter:
connection_filter.add(delayAskFetch,
Patch(s0.dm, changePartitionTable=changePartitionTable))
cluster.tic()
self.assertEqual(1, patch[0].filtered_count)
patch[0]()
finally:
del patch[:]
self.assertEqual(1, connection_filter.filtered_count)
cluster.tic()
self.checkPartitionReplicated(s1, s2, offset)
finally:
......
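Because the patch now lives on the Connection class itself, a ConnectionFilter created with no connection list (as in the replication test above) sees every connection in the process, so the scoping has to happen in the predicate; that is why delayAskFetch gains the conn in s1.getConnectionList(s0) check. A hedged sketch of the two idioms, reusing names from the tests above (s0, s1 and delayed come from the neo.tests.threaded fixtures, so this is illustrative rather than standalone):

# scoped at construction: only packets travelling on the s1<->s0
# connections are ever offered to the predicate
with s1.filterConnection(s0) as f:
    f.add(lambda conn, packet: isinstance(packet, delayed))

# global filter: every connection is offered, so the predicate
# restricts itself to the s1<->s0 connections
with ConnectionFilter() as f:
    f.add(lambda conn, packet: isinstance(packet, delayed)
                               and conn in s1.getConnectionList(s0))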