Commit fd0b9c98 authored by Julien Muchembled's avatar Julien Muchembled

tests: make Patch usable as a context manager

parent 91c66356
...@@ -26,6 +26,7 @@ import unittest ...@@ -26,6 +26,7 @@ import unittest
import MySQLdb import MySQLdb
import transaction import transaction
from functools import wraps
from mock import Mock from mock import Mock
from neo.lib import debug, logging, protocol from neo.lib import debug, logging, protocol
from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES
...@@ -47,7 +48,7 @@ def expectedFailure(exception=AssertionError): ...@@ -47,7 +48,7 @@ def expectedFailure(exception=AssertionError):
# XXX: passing sys.exc_info() causes deadlocks # XXX: passing sys.exc_info() causes deadlocks
raise _ExpectedFailure((type(e), None, None)) raise _ExpectedFailure((type(e), None, None))
raise _UnexpectedSuccess raise _UnexpectedSuccess
return functools.wraps(func)(wrapper) return wraps(func)(wrapper)
if callable(exception) and not isinstance(exception, type): if callable(exception) and not isinstance(exception, type):
func = exception func = exception
exception = Exception exception = Exception
...@@ -514,6 +515,45 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -514,6 +515,45 @@ class NeoUnitTestBase(NeoTestBase):
def checkAnswerObjectPresent(self, conn, **kw): def checkAnswerObjectPresent(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObjectPresent, **kw) return self.checkAnswerPacket(conn, Packets.AnswerObjectPresent, **kw)
class Patch(object):
applied = False
def __init__(self, patched, **patch):
(name, patch), = patch.iteritems()
wrapped = getattr(patched, name)
wrapper = lambda *args, **kw: patch(wrapped, *args, **kw)
self._patched = patched
self._name = name
self._wrapper = wraps(wrapped)(wrapper)
try:
orig = patched.__dict__[name]
self._revert = lambda: setattr(patched, name, orig)
except KeyError:
self._revert = lambda: delattr(patched, name)
def apply(self):
assert not self.applied
setattr(self._patched, self._name, self._wrapper)
self.applied = True
def revert(self):
del self.applied
self._revert()
def __del__(self):
if self.applied:
self.revert()
def __enter__(self):
self.apply()
return self
def __exit__(self, t, v, tb):
self.__del__()
connector_cpt = 0 connector_cpt = 0
class DoNothingConnector(Mock): class DoNothingConnector(Mock):
......
...@@ -36,7 +36,7 @@ from neo.lib.connector import SocketConnector, \ ...@@ -36,7 +36,7 @@ from neo.lib.connector import SocketConnector, \
from neo.lib.event import EventManager from neo.lib.event import EventManager
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, NodeTypes from neo.lib.protocol import CellStates, ClusterStates, NodeStates, NodeTypes
from neo.lib.util import SOCKET_CONNECTORS_DICT, parseMasterList, p64 from neo.lib.util import SOCKET_CONNECTORS_DICT, parseMasterList, p64
from .. import NeoTestBase, getTempDirectory, setupMySQLdb, \ from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_USER ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_USER
BIND = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE], 0 BIND = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE], 0
...@@ -386,23 +386,6 @@ class LoggerThreadName(str): ...@@ -386,23 +386,6 @@ class LoggerThreadName(str):
return str.__str__(self) return str.__str__(self)
class Patch(object):
def __init__(self, patched, **patch):
(name, patch), = patch.iteritems()
wrapped = getattr(patched, name)
wrapper = lambda *args, **kw: patch(wrapped, *args, **kw)
orig = patched.__dict__.get(name)
setattr(patched, name, wraps(wrapped)(wrapper))
if orig is None:
self._revert = lambda: delattr(patched, name)
else:
self._revert = lambda: setattr(patched, name, orig)
def __del__(self):
self._revert()
class ConnectionFilter(object): class ConnectionFilter(object):
filtered_count = 0 filtered_count = 0
...@@ -469,6 +452,8 @@ class ConnectionFilter(object): ...@@ -469,6 +452,8 @@ class ConnectionFilter(object):
def add(self, filter, *patches): def add(self, filter, *patches):
with self.lock: with self.lock:
self.filter_dict[filter] = patches self.filter_dict[filter] = patches
for p in patches:
p.apply()
def remove(self, *filters): def remove(self, *filters):
with self.lock: with self.lock:
......
...@@ -28,8 +28,8 @@ from neo.storage.transactions import TransactionManager, \ ...@@ -28,8 +28,8 @@ from neo.storage.transactions import TransactionManager, \
from neo.lib.connection import ConnectionClosed, MTClientConnection from neo.lib.connection import ConnectionClosed, MTClientConnection
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \ from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
ZERO_TID ZERO_TID
from .. import expectedFailure, _UnexpectedSuccess from .. import expectedFailure, _UnexpectedSuccess, Patch
from . import ClientApplication, NEOCluster, NEOThreadedTest, Patch from . import ClientApplication, NEOCluster, NEOThreadedTest
from neo.lib.util import add64, makeChecksum from neo.lib.util import add64, makeChecksum
from neo.client.exception import NEOStorageError from neo.client.exception import NEOStorageError
from neo.client.pool import CELL_CONNECTED, CELL_GOOD from neo.client.pool import CELL_CONNECTED, CELL_GOOD
...@@ -271,14 +271,11 @@ class Test(NEOThreadedTest): ...@@ -271,14 +271,11 @@ class Test(NEOThreadedTest):
o1.value += 1 o1.value += 1
o2.value += 2 o2.value += 2
p = (Patch(TransactionManager, storeObject=onStoreObject), with Patch(TransactionManager, storeObject=onStoreObject), \
Patch(MTClientConnection, ask=onAsk)) Patch(MTClientConnection, ask=onAsk):
try:
t = self.newThread(t1.commit) t = self.newThread(t1.commit)
t2.commit() t2.commit()
t.join() t.join()
finally:
del p
t1.begin() t1.begin()
t2.begin() t2.begin()
self.assertEqual(o1.value, 3) self.assertEqual(o1.value, 3)
...@@ -578,6 +575,7 @@ class Test(NEOThreadedTest): ...@@ -578,6 +575,7 @@ class Test(NEOThreadedTest):
x2 = c2.root()['x'] x2 = c2.root()['x']
p = Patch(cluster.client, _handlePacket=_handlePacket) p = Patch(cluster.client, _handlePacket=_handlePacket)
try: try:
p.apply()
t = self.newThread(t1.commit) t = self.newThread(t1.commit)
l1.acquire() l1.acquire()
t2.begin() t2.begin()
...@@ -649,6 +647,7 @@ class Test(NEOThreadedTest): ...@@ -649,6 +647,7 @@ class Test(NEOThreadedTest):
cache._remove(cache._oid_dict[x2._p_oid].pop()) cache._remove(cache._oid_dict[x2._p_oid].pop())
p = Patch(cluster.client, _loadFromStorage=_loadFromStorage) p = Patch(cluster.client, _loadFromStorage=_loadFromStorage)
try: try:
p.apply()
t = self.newThread(x2._p_activate) t = self.newThread(x2._p_activate)
l1.acquire() l1.acquire()
# At this point, x could not be found the cache and the result # At this point, x could not be found the cache and the result
...@@ -685,6 +684,7 @@ class Test(NEOThreadedTest): ...@@ -685,6 +684,7 @@ class Test(NEOThreadedTest):
t1.abort() t1.abort()
p = Patch(c1, _flush_invalidations=_flush_invalidations) p = Patch(c1, _flush_invalidations=_flush_invalidations)
try: try:
p.apply()
t = self.newThread(t1.begin) t = self.newThread(t1.begin)
l1.acquire() l1.acquire()
cluster.client.setPoll(0) cluster.client.setPoll(0)
...@@ -745,9 +745,8 @@ class Test(NEOThreadedTest): ...@@ -745,9 +745,8 @@ class Test(NEOThreadedTest):
cluster.client.setPoll(1) cluster.client.setPoll(1)
# Check reconnection to storage. # Check reconnection to storage.
p = Patch(cluster.client.cp, getConnForNode=getConnForNode) with Patch(cluster.client.cp, getConnForNode=getConnForNode):
self.assertFalse(cluster.client.history(x1._p_oid)) self.assertFalse(cluster.client.history(x1._p_oid))
del p
self.assertFalse(conn) self.assertFalse(conn)
self.assertTrue(cluster.client.history(x1._p_oid)) self.assertTrue(cluster.client.history(x1._p_oid))
...@@ -831,6 +830,7 @@ class Test(NEOThreadedTest): ...@@ -831,6 +830,7 @@ class Test(NEOThreadedTest):
master_nodes=cluster.master_nodes) master_nodes=cluster.master_nodes)
p = Patch(client.storage_bootstrap_handler, notReady=notReady) p = Patch(client.storage_bootstrap_handler, notReady=notReady)
try: try:
p.apply()
client.setPoll(1) client.setPoll(1)
x = client.load(ZERO_TID) x = client.load(ZERO_TID)
finally: finally:
......
...@@ -25,7 +25,8 @@ from neo.lib.connection import ClientConnection ...@@ -25,7 +25,8 @@ from neo.lib.connection import ClientConnection
from neo.lib.protocol import CellStates, ClusterStates, Packets, \ from neo.lib.protocol import CellStates, ClusterStates, Packets, \
ZERO_OID, ZERO_TID, MAX_TID, uuid_str ZERO_OID, ZERO_TID, MAX_TID, uuid_str
from neo.lib.util import p64 from neo.lib.util import p64
from . import ConnectionFilter, NEOCluster, NEOThreadedTest, Patch, \ from .. import Patch
from . import ConnectionFilter, NEOCluster, NEOThreadedTest, \
predictable_random, Serialized predictable_random, Serialized
...@@ -225,10 +226,8 @@ class ReplicationTests(NEOThreadedTest): ...@@ -225,10 +226,8 @@ class ReplicationTests(NEOThreadedTest):
# a second replication partially and aborts. # a second replication partially and aborts.
p = Patch(backup.storage_list[storage].replicator, p = Patch(backup.storage_list[storage].replicator,
fetchObjects=fetchObjects) fetchObjects=fetchObjects)
try: with p:
importZODB(lambda x: counts[0] > 1) importZODB(lambda x: counts[0] > 1)
finally:
del p
upstream.client.setPoll(0) upstream.client.setPoll(0)
if event > 5: if event > 5:
backup.neoctl.checkReplicas(check_dict, ZERO_TID, None) backup.neoctl.checkReplicas(check_dict, ZERO_TID, None)
...@@ -274,8 +273,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -274,8 +273,7 @@ class ReplicationTests(NEOThreadedTest):
def __init__(orig, *args, **kw): def __init__(orig, *args, **kw):
count[0] += 1 count[0] += 1
orig(*args, **kw) orig(*args, **kw)
p = Patch(ClientConnection, __init__=__init__) with Patch(ClientConnection, __init__=__init__):
try:
upstream.storage.listening_conn.close() upstream.storage.listening_conn.close()
Serialized.tic(); self.assertEqual(count[0], 0) Serialized.tic(); self.assertEqual(count[0], 0)
Serialized.tic(); count[0] or Serialized.tic() Serialized.tic(); count[0] or Serialized.tic()
...@@ -284,8 +282,6 @@ class ReplicationTests(NEOThreadedTest): ...@@ -284,8 +282,6 @@ class ReplicationTests(NEOThreadedTest):
time.sleep(1.1) time.sleep(1.1)
Serialized.tic(); self.assertEqual(count[0], 3) Serialized.tic(); self.assertEqual(count[0], 3)
Serialized.tic(); self.assertEqual(count[0], 3) Serialized.tic(); self.assertEqual(count[0], 3)
finally:
del p
@backup_test() @backup_test()
def testBackupDelayedUnlockTransaction(self, backup): def testBackupDelayedUnlockTransaction(self, backup):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment