Commit 1d738521 authored by Julien Muchembled's avatar Julien Muchembled

qa: add threaded test helper to filter connection by packet type

parent 7dc8d4db
...@@ -22,7 +22,7 @@ from collections import deque ...@@ -22,7 +22,7 @@ from collections import deque
from ConfigParser import SafeConfigParser from ConfigParser import SafeConfigParser
from contextlib import contextmanager from contextlib import contextmanager
from itertools import count from itertools import count
from functools import wraps from functools import partial, wraps
from thread import get_ident from thread import get_ident
from zlib import decompress from zlib import decompress
from mock import Mock from mock import Mock
...@@ -36,7 +36,7 @@ from neo.lib.connection import BaseConnection, \ ...@@ -36,7 +36,7 @@ from neo.lib.connection import BaseConnection, \
from neo.lib.connector import SocketConnector, ConnectorException from neo.lib.connector import SocketConnector, ConnectorException
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.locking import SimpleQueue from neo.lib.locking import SimpleQueue
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, Packets
from neo.lib.util import cached_property, parseMasterList, p64 from neo.lib.util import cached_property, parseMasterList, p64
from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \ from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_SOCKET, DB_USER ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_SOCKET, DB_USER
...@@ -553,6 +553,22 @@ class ConnectionFilter(object): ...@@ -553,6 +553,22 @@ class ConnectionFilter(object):
def __contains__(self, filter): def __contains__(self, filter):
return filter in self.filter_dict return filter in self.filter_dict
def byPacket(self, packet_type, *args):
patches = []
other = []
for x in args:
(patches if isinstance(x, Patch) else other).append(x)
def delay(conn, packet):
return isinstance(packet, packet_type) and False not in (
callback(conn) for callback in other)
self.add(delay, *patches)
return delay
def __getattr__(self, attr):
if attr.startswith('delay'):
return partial(self.byPacket, getattr(Packets, attr[5:]))
return self.__getattribute__(attr)
class NEOCluster(object): class NEOCluster(object):
SSL = None SSL = None
......
...@@ -248,8 +248,6 @@ class Test(NEOThreadedTest): ...@@ -248,8 +248,6 @@ class Test(NEOThreadedTest):
def testDelayedUnlockInformation(self): def testDelayedUnlockInformation(self):
except_list = [] except_list = []
def delayUnlockInformation(conn, packet):
return isinstance(packet, Packets.NotifyUnlockInformation)
def onStoreObject(orig, tm, ttid, serial, oid, *args): def onStoreObject(orig, tm, ttid, serial, oid, *args):
if oid == resume_oid and delayUnlockInformation in m2s: if oid == resume_oid and delayUnlockInformation in m2s:
m2s.remove(delayUnlockInformation) m2s.remove(delayUnlockInformation)
...@@ -265,13 +263,13 @@ class Test(NEOThreadedTest): ...@@ -265,13 +263,13 @@ class Test(NEOThreadedTest):
c.root()[0] = ob = PCounter() c.root()[0] = ob = PCounter()
with cluster.master.filterConnection(cluster.storage) as m2s: with cluster.master.filterConnection(cluster.storage) as m2s:
resume_oid = None resume_oid = None
m2s.add(delayUnlockInformation, delayUnlockInformation = m2s.delayNotifyUnlockInformation(
Patch(TransactionManager, storeObject=onStoreObject)) Patch(TransactionManager, storeObject=onStoreObject))
t.commit() t.commit()
resume_oid = ob._p_oid resume_oid = ob._p_oid
ob._p_changed = 1 ob._p_changed = 1
t.commit() t.commit()
self.assertFalse(delayUnlockInformation in m2s) self.assertNotIn(delayUnlockInformation, m2s)
finally: finally:
cluster.stop() cluster.stop()
self.assertEqual(except_list, [DelayedError]) self.assertEqual(except_list, [DelayedError])
...@@ -451,8 +449,7 @@ class Test(NEOThreadedTest): ...@@ -451,8 +449,7 @@ class Test(NEOThreadedTest):
r[''] = '' r[''] = ''
with Patch(ClientOperationHandler, askObject=askObject): with Patch(ClientOperationHandler, askObject=askObject):
with cluster.master.filterConnection(cluster.storage) as m2s: with cluster.master.filterConnection(cluster.storage) as m2s:
m2s.add(lambda conn, packet: # delay unlock m2s.delayNotifyUnlockInformation()
isinstance(packet, Packets.NotifyUnlockInformation))
t.commit() t.commit()
c.cacheMinimize() c.cacheMinimize()
cluster.client._cache.clear() cluster.client._cache.clear()
...@@ -524,8 +521,7 @@ class Test(NEOThreadedTest): ...@@ -524,8 +521,7 @@ class Test(NEOThreadedTest):
orig() orig()
def stop(): def stop():
with cluster.master.filterConnection(s0) as m2s0: with cluster.master.filterConnection(s0) as m2s0:
m2s0.add(lambda conn, packet: m2s0.delayNotifyPartitionChanges()
isinstance(packet, Packets.NotifyPartitionChanges))
s1.stop() s1.stop()
cluster.join((s1,)) cluster.join((s1,))
self.assertEqual(getClusterState(), ClusterStates.RUNNING) self.assertEqual(getClusterState(), ClusterStates.RUNNING)
...@@ -566,8 +562,6 @@ class Test(NEOThreadedTest): ...@@ -566,8 +562,6 @@ class Test(NEOThreadedTest):
def testVerificationCommitUnfinishedTransactions(self): def testVerificationCommitUnfinishedTransactions(self):
""" Verification step should commit locked transactions """ """ Verification step should commit locked transactions """
def delayUnlockInformation(conn, packet):
return isinstance(packet, Packets.NotifyUnlockInformation)
def onLockTransaction(storage, die=False): def onLockTransaction(storage, die=False):
def lock(orig, *args, **kw): def lock(orig, *args, **kw):
if die: if die:
...@@ -608,7 +602,7 @@ class Test(NEOThreadedTest): ...@@ -608,7 +602,7 @@ class Test(NEOThreadedTest):
self.assertEqual([u64(o._p_oid) for o in (r, x, y)], range(3)) self.assertEqual([u64(o._p_oid) for o in (r, x, y)], range(3))
r[2] = 'ok' r[2] = 'ok'
with cluster.master.filterConnection(s0) as m2s: with cluster.master.filterConnection(s0) as m2s:
m2s.add(delayUnlockInformation) m2s.delayNotifyUnlockInformation()
t.commit() t.commit()
x.value = 1 x.value = 1
# s0 will accept to store y (because it's not locked) but will # s0 will accept to store y (because it's not locked) but will
...@@ -916,8 +910,7 @@ class Test(NEOThreadedTest): ...@@ -916,8 +910,7 @@ class Test(NEOThreadedTest):
client.store(x1._p_oid, x1._p_serial, y, '', txn) client.store(x1._p_oid, x1._p_serial, y, '', txn)
# Delay invalidation for x # Delay invalidation for x
with cluster.master.filterConnection(cluster.client) as m2c: with cluster.master.filterConnection(cluster.client) as m2c:
m2c.add(lambda conn, packet: m2c.delayInvalidateObjects()
isinstance(packet, Packets.InvalidateObjects))
tid = client.tpc_finish(txn, None) tid = client.tpc_finish(txn, None)
# Change to x is committed. Testing connection must ask the # Change to x is committed. Testing connection must ask the
# storage node to return original value of x, even if we # storage node to return original value of x, even if we
...@@ -1164,8 +1157,7 @@ class Test(NEOThreadedTest): ...@@ -1164,8 +1157,7 @@ class Test(NEOThreadedTest):
cluster.master.filterConnection(cluster.storage) as m2s: cluster.master.filterConnection(cluster.storage) as m2s:
s2m.add(delayAnswerLockInformation, Patch(cluster.client, s2m.add(delayAnswerLockInformation, Patch(cluster.client,
_connectToPrimaryNode=_connectToPrimaryNode)) _connectToPrimaryNode=_connectToPrimaryNode))
m2s.add(lambda conn, packet: m2s.delayNotifyUnlockInformation()
isinstance(packet, Packets.NotifyUnlockInformation))
t.commit() # the final TID is returned by the storage (tm) t.commit() # the final TID is returned by the storage (tm)
t.begin() t.begin()
self.assertEqual(r['x'].value, 2) self.assertEqual(r['x'].value, 2)
...@@ -1208,8 +1200,6 @@ class Test(NEOThreadedTest): ...@@ -1208,8 +1200,6 @@ class Test(NEOThreadedTest):
cluster.stop() cluster.stop()
def testRecycledClientUUID(self): def testRecycledClientUUID(self):
def delayNotifyInformation(conn, packet):
return isinstance(packet, Packets.NotifyNodeInformation)
def notReady(orig, *args): def notReady(orig, *args):
m2s.discard(delayNotifyInformation) m2s.discard(delayNotifyInformation)
return orig(*args) return orig(*args)
...@@ -1218,7 +1208,7 @@ class Test(NEOThreadedTest): ...@@ -1218,7 +1208,7 @@ class Test(NEOThreadedTest):
cluster.start() cluster.start()
cluster.getTransaction() cluster.getTransaction()
with cluster.master.filterConnection(cluster.storage) as m2s: with cluster.master.filterConnection(cluster.storage) as m2s:
m2s.add(delayNotifyInformation) delayNotifyInformation = m2s.delayNotifyNodeInformation()
cluster.client.master_conn.close() cluster.client.master_conn.close()
with cluster.newClient() as client, Patch( with cluster.newClient() as client, Patch(
client.storage_bootstrap_handler, notReady=notReady): client.storage_bootstrap_handler, notReady=notReady):
...@@ -1504,8 +1494,7 @@ class Test(NEOThreadedTest): ...@@ -1504,8 +1494,7 @@ class Test(NEOThreadedTest):
with LockLock() as ll, s1.filterConnection(cluster.client) as f, \ with LockLock() as ll, s1.filterConnection(cluster.client) as f, \
Patch(cluster.client.storage_handler, Patch(cluster.client.storage_handler,
answerStoreObject=answerStoreObject) as p: answerStoreObject=answerStoreObject) as p:
f.add(lambda conn, packet: f.delayAnswerStoreObject()
isinstance(packet, Packets.AnswerStoreObject))
t = self.newThread(t1.commit) t = self.newThread(t1.commit)
ll() ll()
t.join() t.join()
......
...@@ -155,8 +155,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -155,8 +155,7 @@ class ReplicationTests(NEOThreadedTest):
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP) backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
self.tic() self.tic()
with ConnectionFilter() as f: with ConnectionFilter() as f:
f.add(lambda conn, packet: conn.getUUID() is None and f.delayAddObject(lambda conn: conn.getUUID() is None)
isinstance(packet, Packets.AddObject))
while not f.filtered_count: while not f.filtered_count:
importZODB(1) importZODB(1)
self.tic() self.tic()
...@@ -271,8 +270,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -271,8 +270,7 @@ class ReplicationTests(NEOThreadedTest):
def testBackupUpstreamStorageDead(self, backup): def testBackupUpstreamStorageDead(self, backup):
upstream = backup.upstream upstream = backup.upstream
with ConnectionFilter() as f: with ConnectionFilter() as f:
f.add(lambda conn, packet: f.delayInvalidateObjects()
isinstance(packet, Packets.InvalidateObjects))
upstream.importZODB()(1) upstream.importZODB()(1)
count = [0] count = [0]
def _connect(orig, conn): def _connect(orig, conn):
...@@ -301,8 +299,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -301,8 +299,7 @@ class ReplicationTests(NEOThreadedTest):
""" """
upstream = backup.upstream upstream = backup.upstream
with upstream.master.filterConnection(upstream.storage) as f: with upstream.master.filterConnection(upstream.storage) as f:
f.add(lambda conn, packet: f.delayNotifyUnlockInformation()
isinstance(packet, Packets.NotifyUnlockInformation))
upstream.importZODB()(1) upstream.importZODB()(1)
self.tic() self.tic()
self.tic() self.tic()
...@@ -320,8 +317,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -320,8 +317,7 @@ class ReplicationTests(NEOThreadedTest):
try: try:
backup.start() backup.start()
with ConnectionFilter() as f: with ConnectionFilter() as f:
f.add(lambda conn, packet: f.delayAskPartitionTable(lambda conn:
isinstance(packet, Packets.AskPartitionTable) and
isinstance(conn.getHandler(), BackupHandler)) isinstance(conn.getHandler(), BackupHandler))
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP) backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
upstream.importZODB()(1) upstream.importZODB()(1)
...@@ -349,8 +345,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -349,8 +345,7 @@ class ReplicationTests(NEOThreadedTest):
importZODB(1) importZODB(1)
backup.reset() backup.reset()
with ConnectionFilter() as f: with ConnectionFilter() as f:
f.add(lambda conn, packet: f.delayAskFetchTransactions()
isinstance(packet, Packets.AskFetchTransactions))
backup.start() backup.start()
self.assertEqual(last_tid, backup.backup_tid) self.assertEqual(last_tid, backup.backup_tid)
self.tic() self.tic()
...@@ -457,8 +452,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -457,8 +452,7 @@ class ReplicationTests(NEOThreadedTest):
cluster.neoctl.enableStorageList([s1.uuid]) cluster.neoctl.enableStorageList([s1.uuid])
cluster.neoctl.tweakPartitionTable() cluster.neoctl.tweakPartitionTable()
with cluster.master.filterConnection(cluster.client) as m2c: with cluster.master.filterConnection(cluster.client) as m2c:
m2c.add(lambda conn, packet: m2c.delayNotifyPartitionChanges()
isinstance(packet, Packets.NotifyPartitionChanges))
self.tic() self.tic()
self.assertEqual('foo', storage.load(oid)[0]) self.assertEqual('foo', storage.load(oid)[0])
finally: finally:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment