Commit 96897224 authored by Julien Muchembled's avatar Julien Muchembled

Global connection filtering in threaded tests, and fix id of delayed packets

parent 930af0fb
...@@ -64,12 +64,11 @@ class MasterClientElectionTests(MasterClientElectionTestBase): ...@@ -64,12 +64,11 @@ class MasterClientElectionTests(MasterClientElectionTestBase):
self.app.unconnected_master_node_set = set() self.app.unconnected_master_node_set = set()
self.app.negotiating_master_node_set = set() self.app.negotiating_master_node_set = set()
# apply monkey patches # apply monkey patches
self._addPacket = ClientConnection._addPacket
ClientConnection._addPacket = _addPacket ClientConnection._addPacket = _addPacket
def _tearDown(self, success): def _tearDown(self, success):
# restore patched methods # restore patched methods
ClientConnection._addPacket = self._addPacket del ClientConnection._addPacket
NeoUnitTestBase._tearDown(self, success) NeoUnitTestBase._tearDown(self, success)
def _checkUnconnected(self, node): def _checkUnconnected(self, node):
...@@ -220,13 +219,12 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -220,13 +219,12 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
self.storage_address = (self.local_ip, 2000) self.storage_address = (self.local_ip, 2000)
self.master_address = (self.local_ip, 3000) self.master_address = (self.local_ip, 3000)
# apply monkey patches # apply monkey patches
self._addPacket = ClientConnection._addPacket
ClientConnection._addPacket = _addPacket ClientConnection._addPacket = _addPacket
def _tearDown(self, success): def _tearDown(self, success):
NeoUnitTestBase._tearDown(self, success) NeoUnitTestBase._tearDown(self, success)
# restore environnement # restore environnement
ClientConnection._addPacket = self._addPacket del ClientConnection._addPacket
def test_requestIdentification1(self): def test_requestIdentification1(self):
""" A non-master node request identification """ """ A non-master node request identification """
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
import os, random, socket, sys, tempfile, threading, time, types, weakref import os, random, socket, sys, tempfile, threading, time, types, weakref
import traceback import traceback
from collections import deque from collections import deque
from contextlib import contextmanager
from itertools import count from itertools import count
from functools import wraps from functools import wraps
from zlib import decompress from zlib import decompress
...@@ -163,15 +164,17 @@ class SerializedEventManager(EventManager): ...@@ -163,15 +164,17 @@ class SerializedEventManager(EventManager):
class Node(object): class Node(object):
def filterConnection(self, *peers): def getConnectionList(self, *peers):
addr = lambda c: c and (c.accepted_from or c.getAddress()) addr = lambda c: c and (c.accepted_from or c.getAddress())
addr_set = set(addr(c.connector) for peer in peers addr_set = set(addr(c.connector) for peer in peers
for c in peer.em.connection_dict.itervalues() for c in peer.em.connection_dict.itervalues()
if isinstance(c, Connection)) if isinstance(c, Connection))
addr_set.discard(None) addr_set.discard(None)
conn_list = (c for c in self.em.connection_dict.itervalues() return (c for c in self.em.connection_dict.itervalues()
if isinstance(c, Connection) and addr(c.connector) in addr_set) if isinstance(c, Connection) and addr(c.connector) in addr_set)
return ConnectionFilter(*conn_list)
def filterConnection(self, *peers):
return ConnectionFilter(self.getConnectionList(*peers))
class ServerNode(Node): class ServerNode(Node):
...@@ -334,16 +337,14 @@ class ClientApplication(Node, neo.client.app.Application): ...@@ -334,16 +337,14 @@ class ClientApplication(Node, neo.client.app.Application):
Serialized.background() Serialized.background()
close = __del__ close = __del__
def filterConnection(self, *peers): def getConnectionList(self, *peers):
conn_list = []
for peer in peers: for peer in peers:
if isinstance(peer, MasterApplication): if isinstance(peer, MasterApplication):
conn = self._getMasterConnection() conn = self._getMasterConnection()
else: else:
assert isinstance(peer, StorageApplication) assert isinstance(peer, StorageApplication)
conn = self.cp.getConnForNode(self.nm.getByUUID(peer.uuid)) conn = self.cp.getConnForNode(self.nm.getByUUID(peer.uuid))
conn_list.append(conn) yield conn
return ConnectionFilter(*conn_list)
class NeoCTL(neo.neoctl.app.NeoCTL): class NeoCTL(neo.neoctl.app.NeoCTL):
...@@ -391,59 +392,65 @@ class Patch(object): ...@@ -391,59 +392,65 @@ class Patch(object):
class ConnectionFilter(object): class ConnectionFilter(object):
filtered_count = 0 filtered_count = 0
filter_list = []
def __init__(self, *conns): filter_queue = weakref.WeakKeyDictionary()
lock = threading.Lock()
_addPacket = Connection._addPacket
@contextmanager
def __new__(cls, conn_list=()):
self = object.__new__(cls)
self.filter_dict = {} self.filter_dict = {}
self.lock = threading.Lock() self.conn_list = frozenset(conn_list)
self.conn_list = [(conn, self._patch(conn)) for conn in conns] if not cls.filter_list:
def _addPacket(conn, packet):
def _patch(self, conn): with cls.lock:
assert '_addPacket' not in conn.__dict__, "already patched" try:
lock = self.lock queue = cls.filter_queue[conn]
filter_dict = self.filter_dict except KeyError:
orig = conn.__class__._addPacket for self in cls.filter_list:
queue = deque() if self(conn, packet):
def _addPacket(packet): self.filtered_count += 1
lock.acquire() break
try: else:
if not queue: return cls._addPacket(conn, packet)
for filter in filter_dict: cls.filter_queue[conn] = queue = deque()
if filter(conn, packet): p = packet.__new__(packet.__class__)
self.filtered_count += 1 p.__dict__.update(packet.__dict__)
break queue.append(p)
else: Connection._addPacket = _addPacket
return orig(conn, packet) try:
queue.append(packet) cls.filter_list.append(self)
finally: yield self
lock.release() finally:
conn._addPacket = _addPacket del cls.filter_list[-1:]
return queue if not cls.filter_list:
Connection._addPacket = cls._addPacket.im_func
with cls.lock:
cls._retry()
def __call__(self, conn, packet):
if not self.conn_list or conn in self.conn_list:
for filter in self.filter_dict:
if filter(conn, packet):
return True
return False
def __call__(self, revert=1): @classmethod
with self.lock: def _retry(cls):
self.filter_dict.clear() for conn, queue in cls.filter_queue.items():
self._retry()
if revert:
for conn, queue in self.conn_list:
assert not queue
del conn._addPacket
del self.conn_list[:]
def _retry(self):
for conn, queue in self.conn_list:
while queue: while queue:
packet = queue.popleft() packet = queue.popleft()
for filter in self.filter_dict: for self in cls.filter_list:
if filter(conn, packet): if self(conn, packet):
queue.appendleft(packet) queue.appendleft(packet)
break break
else: else:
conn.__class__._addPacket(conn, packet) cls._addPacket(conn, packet)
continue continue
break break
else:
def clear(self): del cls.filter_queue[conn]
self(0)
def add(self, filter, *patches): def add(self, filter, *patches):
with self.lock: with self.lock:
......
...@@ -141,8 +141,8 @@ class Test(NEOThreadedTest): ...@@ -141,8 +141,8 @@ class Test(NEOThreadedTest):
def delayUnlockInformation(conn, packet): def delayUnlockInformation(conn, packet):
return isinstance(packet, Packets.NotifyUnlockInformation) return isinstance(packet, Packets.NotifyUnlockInformation)
def onStoreObject(orig, tm, ttid, serial, oid, *args): def onStoreObject(orig, tm, ttid, serial, oid, *args):
if oid == resume_oid and delayUnlockInformation in master_storage: if oid == resume_oid and delayUnlockInformation in m2s:
master_storage.remove(delayUnlockInformation) m2s.remove(delayUnlockInformation)
try: try:
return orig(tm, ttid, serial, oid, *args) return orig(tm, ttid, serial, oid, *args)
except Exception, e: except Exception, e:
...@@ -153,18 +153,15 @@ class Test(NEOThreadedTest): ...@@ -153,18 +153,15 @@ class Test(NEOThreadedTest):
cluster.start() cluster.start()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
c.root()[0] = ob = PCounter() c.root()[0] = ob = PCounter()
master_storage = cluster.master.filterConnection(cluster.storage) with cluster.master.filterConnection(cluster.storage) as m2s:
try:
resume_oid = None resume_oid = None
master_storage.add(delayUnlockInformation, m2s.add(delayUnlockInformation,
Patch(TransactionManager, storeObject=onStoreObject)) Patch(TransactionManager, storeObject=onStoreObject))
t.commit() t.commit()
resume_oid = ob._p_oid resume_oid = ob._p_oid
ob._p_changed = 1 ob._p_changed = 1
t.commit() t.commit()
self.assertFalse(delayUnlockInformation in master_storage) self.assertFalse(delayUnlockInformation in m2s)
finally:
master_storage()
finally: finally:
cluster.stop() cluster.stop()
self.assertEqual(except_list, [DelayedError]) self.assertEqual(except_list, [DelayedError])
...@@ -561,9 +558,8 @@ class Test(NEOThreadedTest): ...@@ -561,9 +558,8 @@ class Test(NEOThreadedTest):
client.tpc_begin(txn) client.tpc_begin(txn)
client.store(x1._p_oid, x1._p_serial, y, '', txn) client.store(x1._p_oid, x1._p_serial, y, '', txn)
# Delay invalidation for x # Delay invalidation for x
master_client = cluster.master.filterConnection(cluster.client) with cluster.master.filterConnection(cluster.client) as m2c:
try: m2c.add(lambda conn, packet:
master_client.add(lambda conn, packet:
isinstance(packet, Packets.InvalidateObjects)) isinstance(packet, Packets.InvalidateObjects))
tid = client.tpc_finish(txn, None) tid = client.tpc_finish(txn, None)
client.setPoll(0) client.setPoll(0)
...@@ -574,8 +570,6 @@ class Test(NEOThreadedTest): ...@@ -574,8 +570,6 @@ class Test(NEOThreadedTest):
x2 = c2.root()['x'] x2 = c2.root()['x']
cache.clear() # bypass cache cache.clear() # bypass cache
self.assertEqual(x2.value, 0) self.assertEqual(x2.value, 0)
finally:
master_client()
x2._p_deactivate() x2._p_deactivate()
t1.begin() # process invalidation and sync connection storage t1.begin() # process invalidation and sync connection storage
self.assertEqual(x2.value, 0) self.assertEqual(x2.value, 0)
......
...@@ -191,16 +191,13 @@ class ReplicationTests(NEOThreadedTest): ...@@ -191,16 +191,13 @@ class ReplicationTests(NEOThreadedTest):
# and node 1 must switch to node 2 # and node 1 must switch to node 2
pt: 0: UU.|U.U|.UU pt: 0: UU.|U.U|.UU
""" """
def connected(orig, *args, **kw):
patch[0] = s1.filterConnection(s0)
patch[0].add(delayAskFetch,
Patch(s0.dm, changePartitionTable=changePartitionTable))
return orig(*args, **kw)
def delayAskFetch(conn, packet): def delayAskFetch(conn, packet):
return isinstance(packet, delayed) and packet.decode()[0] == offset return isinstance(packet, delayed) and \
packet.decode()[0] == offset and \
conn in s1.getConnectionList(s0)
def changePartitionTable(orig, ptid, cell_list): def changePartitionTable(orig, ptid, cell_list):
if (offset, s0.uuid, CellStates.DISCARDED) in cell_list: if (offset, s0.uuid, CellStates.DISCARDED) in cell_list:
patch[0].remove(delayAskFetch) connection_filter.remove(delayAskFetch)
# XXX: this is currently not done by # XXX: this is currently not done by
# default for performance reason # default for performance reason
orig.im_self.dropPartitions((offset,)) orig.im_self.dropPartitions((offset,))
...@@ -221,13 +218,11 @@ class ReplicationTests(NEOThreadedTest): ...@@ -221,13 +218,11 @@ class ReplicationTests(NEOThreadedTest):
offset, = [offset for offset, row in enumerate( offset, = [offset for offset, row in enumerate(
cluster.master.pt.partition_list) cluster.master.pt.partition_list)
for cell in row if cell.isFeeding()] for cell in row if cell.isFeeding()]
patch = [Patch(s1.replicator, fetchTransactions=connected)] with ConnectionFilter() as connection_filter:
try: connection_filter.add(delayAskFetch,
Patch(s0.dm, changePartitionTable=changePartitionTable))
cluster.tic() cluster.tic()
self.assertEqual(1, patch[0].filtered_count) self.assertEqual(1, connection_filter.filtered_count)
patch[0]()
finally:
del patch[:]
cluster.tic() cluster.tic()
self.checkPartitionReplicated(s1, s2, offset) self.checkPartitionReplicated(s1, s2, offset)
finally: finally:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment