Commit a2e278d5 authored by Julien Muchembled

client: fix race condition between Storage.load() and invalidations

This fixes a bug that could manifest as follows:

  Traceback (most recent call last):
    File "neo/client/app.py", line 432, in load
      self._cache.store(oid, data, tid, next_tid)
    File "neo/client/cache.py", line 223, in store
      assert item.tid == tid, (item, tid)
  AssertionError: (<CacheItem oid='\x00\x00\x00\x00\x00\x00\x00\x01' tid='\x03\xcb\xc6\xca\xfd\xc7\xda\xee' next_tid='\x03\xcb\xc6\xca\xfd\xd8\t\x88' data='...' counter=1 level=1 expire=10000 prev=<...> next=<...>>, '\x03\xcb\xc6\xca\xfd\xd8\t\x88')

The big changes in the threaded test framework are required because we need to
reproduce a race condition between client threads, which conflicts with the
serialization of epoll events (it would deadlock).
parent 743026d5
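
A minimal sketch of the idea behind the fix (not NEO code; the function name and
the tid values are invented for illustration): while the cache lock is released,
any number of invalidations may be queued for the oid being loaded, and the
storage may answer with any of the revisions written during that window, so
next_tid must be derived from the whole list of invalidated tids rather than
from a single remembered one.

  def pick_next_tid(loaded_tid, invalidated_tids, next_tid=None):
      # Invalidations arrive in increasing tid order, so the first one
      # greater than the loaded serial is the revision that superseded it.
      if next_tid is None:
          for t in invalidated_tids:
              if loaded_tid < t:
                  return t
      return next_tid

  assert pick_next_tid(10, [20, 30]) == 20   # loaded the oldest revision
  assert pick_next_tid(20, [20, 30]) == 30   # loaded the revision written at 20
  assert pick_next_tid(30, [20, 30]) is None # loaded the most recent revision
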
......@@ -410,8 +410,16 @@ class Application(ThreadedApplication):
if result:
return result
self._loading_oid = oid
self._loading_invalidated = []
finally:
release()
# While the cache lock is released, an arbitrary number of
# invalidations may be processed, for this oid or not. And at this
# precise moment, if both tid and before_tid are None (which is
# unlikely to happen with recent ZODB), self.last_tid can be any
# new tid. Since we can get any serial from storage, fixing
# next_tid requires keeping a list of all possible serials.
# When not bound to a ZODB Connection, load() may be the
# first method called and last_tid may still be None.
# This happens, for example, when opening the DB.
@@ -423,12 +431,11 @@ class Application(ThreadedApplication):
acquire()
try:
if self._loading_oid:
# Common case (no race condition).
self._cache.store(oid, data, tid, next_tid)
elif self._loading_invalidated:
# oid has just been invalidated.
if not next_tid:
next_tid = self._loading_invalidated
for t in self._loading_invalidated:
if tid < t:
next_tid = t
break
self._cache.store(oid, data, tid, next_tid)
# Else, we just reconnected to the master.
finally:
@@ -127,8 +127,7 @@ class PrimaryNotificationsHandler(MTEventHandler):
for oid in oid_list:
invalidate(oid, tid)
if oid == loading:
app._loading_oid = None
app._loading_invalidated = tid
app._loading_invalidated.append(tid)
db = app.getDB()
if db is not None:
db.invalidate(tid, oid_list)
@@ -26,6 +26,7 @@ from zlib import decompress
import transaction, ZODB
import neo.admin.app, neo.master.app, neo.storage.app
import neo.client.app, neo.neoctl.app
from neo.admin.handler import MasterEventHandler
from neo.client import Storage
from neo.lib import logging
from neo.lib.connection import BaseConnection, \
@@ -36,6 +37,7 @@ from neo.lib.locking import SimpleQueue
from neo.lib.protocol import uuid_str, \
ClusterStates, Enum, NodeStates, NodeTypes, Packets
from neo.lib.util import cached_property, parseMasterList, p64
from neo.master.recovery import RecoveryManager
from .. import (getTempDirectory, setupMySQLdb,
ImporterConfigParser, NeoTestBase, Patch,
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_SOCKET, DB_USER)
@@ -119,9 +121,12 @@ class Serialized(object):
detect which node has a readable epoll object.
"""
check_timeout = False
_disabled = False
@classmethod
def init(cls):
if cls._disabled:
return
cls._busy = set()
cls._busy_cond = threading.Condition(threading.Lock())
cls._epoll = select.epoll()
@@ -138,6 +143,8 @@ class Serialized(object):
@classmethod
def stop(cls):
if cls._disabled:
return
assert not cls._fd_dict, ("file descriptor leak (%r)\nThis may happen"
" when a test fails, in which case you can see the real exception"
" by disabling this one." % cls._fd_dict)
@@ -148,6 +155,25 @@ class Serialized(object):
def _sort_key(cls, fd_event):
return -cls._fd_dict[fd_event[0]]._last
@classmethod
@contextmanager
def until(cls, patched=None, **patch):
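# Yield a callable that the caller uses to wait for a specific event:
# with epoll serialization enabled, it is simply Serialized.tic; when
# serialization is disabled, the given method of 'patched' is wrapped so
# that it receives a 'release' callback, and the yielded callable blocks
# until that callback is invoked (or is a no-op if nothing is patched).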
if cls._disabled:
if patched is None:
yield int
else:
l = threading.Lock()
l.acquire()
(name, patch), = patch.iteritems()
def release():
p.revert()
l.release()
with Patch(patched, **{name: lambda *args, **kw:
patch(release, *args, **kw)}) as p:
yield l.acquire
else:
yield cls.tic
@classmethod
@contextmanager
def pdb(cls):
@@ -174,6 +200,10 @@ class Serialized(object):
# We also increase SocketConnector.SOMAXCONN in tests so that
# a connection attempt is never delayed inside the kernel.
timeout=0):
if cls._disabled:
if timeout:
time.sleep(timeout)
return
# If you're in a pdb here, 'n' switches to another thread
# (the following lines are not supposed to be debugged into)
with cls._tic_lock, cls.pdb():
@@ -208,6 +238,8 @@ class Serialized(object):
cls._sched_lock.acquire()
def __init__(self, app, busy=True):
if self._disabled:
return
self._epoll = app.em.epoll
app.em.epoll = self
# XXX: It may have been initialized before the SimpleQueue is patched.
@@ -360,7 +392,8 @@ class ServerNode(Node):
finally:
self._afterRun()
logging.debug('stopping %r', self)
self.em.epoll.exit()
if isinstance(self.em.epoll, Serialized):
self.em.epoll.exit()
def _afterRun(self):
try:
@@ -427,7 +460,8 @@ class ClientApplication(Node, neo.client.app.Application):
try:
super(ClientApplication, self)._run()
finally:
self.em.epoll.exit()
if isinstance(self.em.epoll, Serialized):
self.em.epoll.exit()
def start(self):
isinstance(self.em.epoll, Serialized) or Serialized(self)
@@ -616,6 +650,8 @@ class NEOCluster(object):
def __init__(orig, self): # temporary definition for SimpleQueue patch
orig(self)
if Serialized._disabled:
return
lock = self._lock
def _lock(blocking=True):
if blocking:
@@ -765,22 +801,41 @@ class NEOCluster(object):
self.started = True
self._patch()
self.neoctl = NeoCTL(self.admin.getVirtualAddress(), ssl=self.SSL)
for node in self.master_list if master_list is None else master_list:
node.start()
for node in self.admin_list:
node.start()
Serialized.tic()
if master_list is None:
master_list = self.master_list
if storage_list is None:
storage_list = self.storage_list
for node in storage_list:
node.start()
Serialized.tic()
if recovering:
expected_state = ClusterStates.RECOVERING
else:
self.startCluster()
Serialized.tic()
expected_state = ClusterStates.RUNNING, ClusterStates.BACKINGUP
def answerPartitionTable(release, orig, *args):
orig(*args)
release()
def dispatch(release, orig, handler, *args):
orig(handler, *args)
node_list = handler.app.nm.getStorageList(only_identified=True)
if len(node_list) == len(storage_list) and not any(
node.getConnection().isPending() for node in node_list):
release()
expected_state = (ClusterStates.RECOVERING,) if recovering else (
ClusterStates.RUNNING, ClusterStates.BACKINGUP)
def notifyClusterInformation(release, orig, handler, conn, state):
orig(handler, conn, state)
if state in expected_state:
release()
with Serialized.until(MasterEventHandler,
answerPartitionTable=answerPartitionTable) as tic1, \
Serialized.until(RecoveryManager, dispatch=dispatch) as tic2, \
Serialized.until(MasterEventHandler,
notifyClusterInformation=notifyClusterInformation) as tic3:
for node in master_list:
node.start()
for node in self.admin_list:
node.start()
tic1()
for node in storage_list:
node.start()
tic2()
if not recovering:
self.startCluster()
tic3()
self.checkStarted(expected_state, storage_list)
def checkStarted(self, expected_state, storage_list=None):
@@ -1120,12 +1175,16 @@ def predictable_random(seed=None):
return wraps(wrapped)(wrapper)
return decorator
def with_cluster(start_cluster=True, **cluster_kw):
def with_cluster(serialized=True, start_cluster=True, **cluster_kw):
def decorator(wrapped):
def wrapper(self, *args, **kw):
with NEOCluster(**cluster_kw) as cluster:
if start_cluster:
cluster.start()
return wrapped(self, cluster, *args, **kw)
try:
Serialized._disabled = not serialized
with NEOCluster(**cluster_kw) as cluster:
if start_cluster:
cluster.start()
return wrapped(self, cluster, *args, **kw)
finally:
Serialized._disabled = False
return wraps(wrapped)(wrapper)
return decorator
@@ -37,7 +37,7 @@ from neo.lib.protocol import (CellStates, ClusterStates, NodeStates, NodeTypes,
Packets, Packet, uuid_str, ZERO_OID, ZERO_TID, MAX_TID)
from .. import unpickle_state, Patch, TransactionalResource
from . import ClientApplication, ConnectionFilter, LockLock, NEOCluster, \
NEOThreadedTest, RandomConflictDict, ThreadId, with_cluster
NEOThreadedTest, RandomConflictDict, Serialized, ThreadId, with_cluster
from neo.lib.util import add64, makeChecksum, p64, u64
from neo.client.exception import NEOPrimaryMasterLost, NEOStorageError
from neo.client.transactions import Transaction
@@ -979,6 +979,45 @@ class Test(NEOThreadedTest):
self.assertFalse(invalidations(c1))
self.assertEqual(x1.value, 1)
@with_cluster(serialized=False)
def testExternalInvalidation2(self, cluster):
t, c = cluster.getTransaction()
r = c.root()
x = r[''] = PCounter()
t.commit()
tid1 = x._p_serial
nonlocal_ = [0, 1]
l1 = threading.Lock(); l1.acquire()
l2 = threading.Lock(); l2.acquire()
def invalidateObjects(orig, *args):
if not nonlocal_[0]:
l1.acquire()
orig(*args)
nonlocal_[0] += 1
if nonlocal_[0] == 2:
l2.release()
def _cache_lock_release(orig):
orig()
if nonlocal_[1]:
nonlocal_[1] = 0
l1.release()
l2.acquire()
with cluster.newClient() as client, \
Patch(client.notifications_handler,
invalidateObjects=invalidateObjects):
client.sync()
with cluster.master.filterConnection(client) as mc2:
mc2.delayInvalidateObjects()
x._p_changed = 1
t.commit()
tid2 = x._p_serial
self.assertEqual((tid1, tid2), client.load(x._p_oid)[1:])
r._p_changed = 1
t.commit()
with Patch(client, _cache_lock_release=_cache_lock_release):
self.assertEqual((tid2, None), client.load(x._p_oid)[1:])
self.assertEqual(nonlocal_, [2, 0])
@with_cluster(storage_count=2, partitions=2)
def testReadVerifyingStorage(self, cluster):
s1, s2 = cluster.sortStorageList()