Commit 15bcd495 authored by Julien Muchembled

tests: new helper to synchronize threads

parent 645920e8
@@ -16,13 +16,14 @@
 # XXX: Consider using ClusterStates.STOPPING to stop clusters
-import os, random, select, socket, sys, tempfile, threading, time, weakref
-import traceback
+import os, random, select, socket, sys, tempfile
+import thread, threading, time, traceback, weakref
 from collections import deque
 from ConfigParser import SafeConfigParser
 from contextlib import contextmanager
 from itertools import count
 from functools import wraps
+from thread import get_ident
 from zlib import decompress
 from mock import Mock
 import transaction, ZODB
@@ -44,6 +45,37 @@ BIND = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE], 0
 LOCAL_IP = socket.inet_pton(ADDRESS_TYPE, IP_VERSION_FORMAT_DICT[ADDRESS_TYPE])

+
+class LockLock(object):
+    """Double lock used as synchronisation point between 2 threads
+
+    Used to wait that a slave thread has reached a specific location, and to
+    keep it suspended there. It resumes on __exit__
+    """
+
+    def __init__(self):
+        self._l = threading.Lock(), threading.Lock()
+
+    def __call__(self):
+        """Define synchronisation point for both threads"""
+        if self._owner == thread.get_ident():
+            self._l[0].acquire()
+        else:
+            self._l[0].release()
+            self._l[1].acquire()
+
+    def __enter__(self):
+        self._owner = thread.get_ident()
+        for l in self._l:
+            l.acquire(0)
+        return self
+
+    def __exit__(self, t, v, tb):
+        try:
+            self._l[1].release()
+        except thread.error:
+            pass
+

 class FairLock(deque):
     """Same as a threading.Lock except that waiting threads are queued, so that
     the first one waiting for the lock is the first to get it. This is useful
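The docstring is terse, so here is a minimal standalone sketch (not part of the commit) of how the helper is meant to be used. The thread that enters the with block becomes the owner; the patched slave thread calls ll() to signal that it has reached the synchronisation point and stays parked there; the owner calls ll() to wait for that signal, and leaving the with block resumes the slave. The import path and the worker function below are assumptions for illustration (the test module further down obtains the class with "from . import LockLock").

import threading
from neo.tests.threaded import LockLock  # assumed import path of the helper

def worker(ll):
    # ... the code path under test runs here ...
    ll()  # signal the owner thread, then stay suspended at this point
    # execution resumes only once the owner leaves its 'with' block
    print 'worker resumed'

with LockLock() as ll:  # __enter__ records the owner thread
    t = threading.Thread(target=worker, args=(ll,))
    t.start()
    ll()  # block until the worker has reached its own ll() call
    # the worker is now parked; the owner can inspect or modify shared
    # state here without racing against it
t.join()  # __exit__ already released the worker

Note that the two ll() calls may happen in either order: whichever thread arrives first simply waits for the other, which is what makes the helper usable from monkey-patched hooks whose exact timing is not deterministic.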
@@ -30,7 +30,7 @@ from neo.lib.exception import DatabaseFailure, StoppedOperation
 from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
     ZERO_TID
 from .. import expectedFailure, _ExpectedFailure, _UnexpectedSuccess, Patch
-from . import NEOCluster, NEOThreadedTest
+from . import LockLock, NEOCluster, NEOThreadedTest
 from neo.lib.util import add64, makeChecksum, p64, u64
 from neo.client.exception import NEOStorageError
 from neo.client.pool import CELL_CONNECTED, CELL_GOOD
@@ -751,12 +751,9 @@ class Test(NEOThreadedTest):
         self.assertEqual(list(s.dm.getPartitionTable()), pt)

     def testInternalInvalidation(self):
-        l1 = threading.Lock(); l1.acquire()
-        l2 = threading.Lock(); l2.acquire()
         def _handlePacket(orig, conn, packet, kw={}, handler=None):
             if type(packet) is Packets.AnswerTransactionFinished:
-                l1.release()
-                l2.acquire()
+                ll()
             orig(conn, packet, kw, handler)
         cluster = NEOCluster()
         try:
@@ -768,15 +765,11 @@ class Test(NEOThreadedTest):
             x1.value = 1
             t2, c2 = cluster.getTransaction()
             x2 = c2.root()['x']
-            p = Patch(cluster.client, _handlePacket=_handlePacket)
-            try:
-                p.apply()
+            with LockLock() as ll, Patch(cluster.client,
+                    _handlePacket=_handlePacket):
                 t = self.newThread(t1.commit)
-                l1.acquire()
+                ll()
                 t2.begin()
-            finally:
-                del p
-                l2.release()
             t.join()
             self.assertEqual(x2.value, 1)
         finally:
@@ -824,22 +817,18 @@ class Test(NEOThreadedTest):
             self.assertEqual(x2.value, 1)

             # Now test cache invalidation during a load from a storage
-            l1 = threading.Lock(); l1.acquire()
-            l2 = threading.Lock(); l2.acquire()
+            ll = LockLock()
             def _loadFromStorage(orig, *args):
                 try:
                     return orig(*args)
                 finally:
-                    l1.release()
-                    l2.acquire()
+                    ll()
             x2._p_deactivate()
             # Remove last version of x from cache
             cache._remove(cache._oid_dict[x2._p_oid].pop())
-            p = Patch(cluster.client, _loadFromStorage=_loadFromStorage)
-            try:
-                p.apply()
+            with ll, Patch(cluster.client, _loadFromStorage=_loadFromStorage):
                 t = self.newThread(x2._p_activate)
-                l1.acquire()
+                ll()
                 # At this point, x could not be found the cache and the result
                 # from the storage (which is <value=1, next_tid=None>) is about
                 # to be processed.
@@ -849,11 +838,8 @@ class Test(NEOThreadedTest):
                 client.store(x2._p_oid, tid, x, '', txn) # value=0
                 tid = client.tpc_finish(txn, None)
                 t1.begin() # make sure invalidation is processed
-            finally:
-                del p
             # Resume processing of answer from storage. An entry should be
             # added in cache for x=1 with a fixed next_tid (i.e. not None)
-            l2.release()
             t.join()
             self.assertEqual(x2.value, 1)
             self.assertEqual(x1.value, 0)
@@ -863,24 +849,18 @@ class Test(NEOThreadedTest):
             # is suspended at the beginning of the transaction t1,
             # between Storage.sync() and flush of invalidations.
             def _flush_invalidations(orig):
-                l1.release()
-                l2.acquire()
+                ll()
                 orig()
             x1._p_deactivate()
             t1.abort()
-            p = Patch(c1, _flush_invalidations=_flush_invalidations)
-            try:
-                p.apply()
+            with ll, Patch(c1, _flush_invalidations=_flush_invalidations):
                 t = self.newThread(t1.begin)
-                l1.acquire()
+                ll()
                 txn = transaction.Transaction()
                 client.tpc_begin(txn)
                 client.store(x2._p_oid, tid, y, '', txn)
                 tid = client.tpc_finish(txn, None)
                 client.close()
-            finally:
-                del p
-                l2.release()
             t.join()
             # A transaction really begins when it acquires the lock to flush
             # invalidations. The previous lastTransaction() only does a ping
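For comparison, the idiom that LockLock replaces throughout these tests is the ad-hoc pair of locks visible on the removed lines above. Reduced to a standalone sketch (names invented for illustration):

import threading

l1 = threading.Lock(); l1.acquire()
l2 = threading.Lock(); l2.acquire()

def hook():
    # code reached by the slave thread
    l1.release()  # signal the main thread that we got here
    l2.acquire()  # and stay suspended until it lets us go

t = threading.Thread(target=hook)
t.start()
l1.acquire()  # wait for the slave to reach the hook
# ... main thread works while the slave is parked ...
l2.release()  # resume the slave
t.join()

LockLock packages this dance behind a context manager, so the tests no longer have to create two locks and release them by hand in finally clauses.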