Commit 6f6d071d authored by Julien Muchembled's avatar Julien Muchembled

Simplify setup of monkey-patches in threaded tests

parent 3e1ed6a4
......@@ -519,11 +519,15 @@ class Patch(object):
def __init__(self, patched, **patch):
(name, patch), = patch.iteritems()
wrapped = getattr(patched, name)
wrapper = lambda *args, **kw: patch(wrapped, *args, **kw)
self._patched = patched
self._name = name
self._wrapper = wraps(wrapped)(wrapper)
if callable(patch):
wrapped = getattr(patched, name, None)
func = patch
patch = lambda *args, **kw: func(wrapped, *args, **kw)
if callable(wrapped):
patch = wraps(wrapped)(patch)
self._patch = patch
try:
orig = patched.__dict__[name]
self._revert = lambda: setattr(patched, name, orig)
......@@ -532,7 +536,7 @@ class Patch(object):
def apply(self):
assert not self.applied
setattr(self._patched, self._name, self._wrapper)
setattr(self._patched, self._name, self._patch)
self.applied = True
def revert(self):
......
......@@ -504,11 +504,23 @@ class ConnectionFilter(object):
class NEOCluster(object):
BaseConnection_getTimeout = staticmethod(BaseConnection.getTimeout)
CONNECT_LIMIT = SocketConnector.CONNECT_LIMIT
SimpleQueue__init__ = staticmethod(SimpleQueue.__init__)
SocketConnector_bind = staticmethod(SocketConnector._bind)
SocketConnector_connect = staticmethod(SocketConnector._connect)
def __init__(orig, self): # temporary definition for SimpleQueue patch
orig(self)
lock = self._lock
def _lock(blocking=True):
if blocking:
while not lock(False):
Serialized.tic(step=1)
return True
return lock(False)
self._lock = _lock
_patches = (
Patch(BaseConnection, getTimeout=lambda orig, self: None),
Patch(SimpleQueue, __init__=__init__),
Patch(SocketConnector, CONNECT_LIMIT=0),
Patch(SocketConnector, _bind=lambda orig, self, addr: orig(self, BIND)),
Patch(SocketConnector, _connect = lambda orig, self, addr:
orig(self, ServerNode.resolv(addr))))
_patch_count = 0
_resource_dict = weakref.WeakValueDictionary()
......@@ -525,23 +537,8 @@ class NEOCluster(object):
cls._patch_count += 1
if cls._patch_count > 1:
return
def __init__(self):
cls.SimpleQueue__init__(self)
lock = self._lock
def _lock(blocking=True):
if blocking:
while not lock(False):
Serialized.tic(step=1)
return True
return lock(False)
self._lock = _lock
BaseConnection.getTimeout = lambda self: None
SimpleQueue.__init__ = __init__
SocketConnector.CONNECT_LIMIT = 0
SocketConnector._bind = lambda self, addr: \
cls.SocketConnector_bind(self, BIND)
SocketConnector._connect = lambda self, addr: \
cls.SocketConnector_connect(self, ServerNode.resolv(addr))
for patch in cls._patches:
patch.apply()
Serialized.init()
@staticmethod
......@@ -551,11 +548,8 @@ class NEOCluster(object):
cls._patch_count -= 1
if cls._patch_count:
return
BaseConnection.getTimeout = cls.BaseConnection_getTimeout
SimpleQueue.__init__ = cls.SimpleQueue__init__
SocketConnector.CONNECT_LIMIT = cls.CONNECT_LIMIT
SocketConnector._bind = cls.SocketConnector_bind
SocketConnector._connect = cls.SocketConnector_connect
for patch in cls._patches:
patch.revert()
Serialized.stop()
def __init__(self, master_count=1, partitions=1, replicas=0, upstream=None,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment