Commit fd0b9c98 authored by Julien Muchembled's avatar Julien Muchembled

tests: make Patch usable as a context manager

parent 91c66356
......@@ -26,6 +26,7 @@ import unittest
import MySQLdb
import transaction
from functools import wraps
from mock import Mock
from neo.lib import debug, logging, protocol
from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES
......@@ -47,7 +48,7 @@ def expectedFailure(exception=AssertionError):
# XXX: passing sys.exc_info() causes deadlocks
raise _ExpectedFailure((type(e), None, None))
raise _UnexpectedSuccess
return functools.wraps(func)(wrapper)
return wraps(func)(wrapper)
if callable(exception) and not isinstance(exception, type):
func = exception
exception = Exception
......@@ -514,6 +515,45 @@ class NeoUnitTestBase(NeoTestBase):
def checkAnswerObjectPresent(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObjectPresent, **kw)
class Patch(object):
applied = False
def __init__(self, patched, **patch):
(name, patch), = patch.iteritems()
wrapped = getattr(patched, name)
wrapper = lambda *args, **kw: patch(wrapped, *args, **kw)
self._patched = patched
self._name = name
self._wrapper = wraps(wrapped)(wrapper)
try:
orig = patched.__dict__[name]
self._revert = lambda: setattr(patched, name, orig)
except KeyError:
self._revert = lambda: delattr(patched, name)
def apply(self):
assert not self.applied
setattr(self._patched, self._name, self._wrapper)
self.applied = True
def revert(self):
del self.applied
self._revert()
def __del__(self):
if self.applied:
self.revert()
def __enter__(self):
self.apply()
return self
def __exit__(self, t, v, tb):
self.__del__()
connector_cpt = 0
class DoNothingConnector(Mock):
......
......@@ -36,7 +36,7 @@ from neo.lib.connector import SocketConnector, \
from neo.lib.event import EventManager
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, NodeTypes
from neo.lib.util import SOCKET_CONNECTORS_DICT, parseMasterList, p64
from .. import NeoTestBase, getTempDirectory, setupMySQLdb, \
from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_USER
BIND = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE], 0
......@@ -386,23 +386,6 @@ class LoggerThreadName(str):
return str.__str__(self)
class Patch(object):
def __init__(self, patched, **patch):
(name, patch), = patch.iteritems()
wrapped = getattr(patched, name)
wrapper = lambda *args, **kw: patch(wrapped, *args, **kw)
orig = patched.__dict__.get(name)
setattr(patched, name, wraps(wrapped)(wrapper))
if orig is None:
self._revert = lambda: delattr(patched, name)
else:
self._revert = lambda: setattr(patched, name, orig)
def __del__(self):
self._revert()
class ConnectionFilter(object):
filtered_count = 0
......@@ -469,6 +452,8 @@ class ConnectionFilter(object):
def add(self, filter, *patches):
with self.lock:
self.filter_dict[filter] = patches
for p in patches:
p.apply()
def remove(self, *filters):
with self.lock:
......
......@@ -28,8 +28,8 @@ from neo.storage.transactions import TransactionManager, \
from neo.lib.connection import ConnectionClosed, MTClientConnection
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
ZERO_TID
from .. import expectedFailure, _UnexpectedSuccess
from . import ClientApplication, NEOCluster, NEOThreadedTest, Patch
from .. import expectedFailure, _UnexpectedSuccess, Patch
from . import ClientApplication, NEOCluster, NEOThreadedTest
from neo.lib.util import add64, makeChecksum
from neo.client.exception import NEOStorageError
from neo.client.pool import CELL_CONNECTED, CELL_GOOD
......@@ -271,14 +271,11 @@ class Test(NEOThreadedTest):
o1.value += 1
o2.value += 2
p = (Patch(TransactionManager, storeObject=onStoreObject),
Patch(MTClientConnection, ask=onAsk))
try:
with Patch(TransactionManager, storeObject=onStoreObject), \
Patch(MTClientConnection, ask=onAsk):
t = self.newThread(t1.commit)
t2.commit()
t.join()
finally:
del p
t1.begin()
t2.begin()
self.assertEqual(o1.value, 3)
......@@ -578,6 +575,7 @@ class Test(NEOThreadedTest):
x2 = c2.root()['x']
p = Patch(cluster.client, _handlePacket=_handlePacket)
try:
p.apply()
t = self.newThread(t1.commit)
l1.acquire()
t2.begin()
......@@ -649,6 +647,7 @@ class Test(NEOThreadedTest):
cache._remove(cache._oid_dict[x2._p_oid].pop())
p = Patch(cluster.client, _loadFromStorage=_loadFromStorage)
try:
p.apply()
t = self.newThread(x2._p_activate)
l1.acquire()
# At this point, x could not be found the cache and the result
......@@ -685,6 +684,7 @@ class Test(NEOThreadedTest):
t1.abort()
p = Patch(c1, _flush_invalidations=_flush_invalidations)
try:
p.apply()
t = self.newThread(t1.begin)
l1.acquire()
cluster.client.setPoll(0)
......@@ -745,9 +745,8 @@ class Test(NEOThreadedTest):
cluster.client.setPoll(1)
# Check reconnection to storage.
p = Patch(cluster.client.cp, getConnForNode=getConnForNode)
with Patch(cluster.client.cp, getConnForNode=getConnForNode):
self.assertFalse(cluster.client.history(x1._p_oid))
del p
self.assertFalse(conn)
self.assertTrue(cluster.client.history(x1._p_oid))
......@@ -831,6 +830,7 @@ class Test(NEOThreadedTest):
master_nodes=cluster.master_nodes)
p = Patch(client.storage_bootstrap_handler, notReady=notReady)
try:
p.apply()
client.setPoll(1)
x = client.load(ZERO_TID)
finally:
......
......@@ -25,7 +25,8 @@ from neo.lib.connection import ClientConnection
from neo.lib.protocol import CellStates, ClusterStates, Packets, \
ZERO_OID, ZERO_TID, MAX_TID, uuid_str
from neo.lib.util import p64
from . import ConnectionFilter, NEOCluster, NEOThreadedTest, Patch, \
from .. import Patch
from . import ConnectionFilter, NEOCluster, NEOThreadedTest, \
predictable_random, Serialized
......@@ -225,10 +226,8 @@ class ReplicationTests(NEOThreadedTest):
# a second replication partially and aborts.
p = Patch(backup.storage_list[storage].replicator,
fetchObjects=fetchObjects)
try:
with p:
importZODB(lambda x: counts[0] > 1)
finally:
del p
upstream.client.setPoll(0)
if event > 5:
backup.neoctl.checkReplicas(check_dict, ZERO_TID, None)
......@@ -274,8 +273,7 @@ class ReplicationTests(NEOThreadedTest):
def __init__(orig, *args, **kw):
count[0] += 1
orig(*args, **kw)
p = Patch(ClientConnection, __init__=__init__)
try:
with Patch(ClientConnection, __init__=__init__):
upstream.storage.listening_conn.close()
Serialized.tic(); self.assertEqual(count[0], 0)
Serialized.tic(); count[0] or Serialized.tic()
......@@ -284,8 +282,6 @@ class ReplicationTests(NEOThreadedTest):
time.sleep(1.1)
Serialized.tic(); self.assertEqual(count[0], 3)
Serialized.tic(); self.assertEqual(count[0], 3)
finally:
del p
@backup_test()
def testBackupDelayedUnlockTransaction(self, backup):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment