Commit 4d571267 authored by Julien Muchembled

client: merge load optimizations

parents d4603189 8ba42463
@@ -80,7 +80,7 @@ class Application(ThreadedApplication):
# no self-assigned NID, primary master will supply us one
self._cache = ClientCache() if cache_size is None else \
ClientCache(max_size=cache_size)
self._loading_oid = None
self._loading = defaultdict(lambda: (Lock(), []))
self.new_oids = ()
self.last_oid = '\0' * 8
self.storage_event_handler = storage.StorageEventHandler(self)
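The removed `_loading_oid` attribute could only track a single in-flight load; the new `_loading` table keeps a (lock, invalidated-tids) pair per oid. A minimal standalone sketch of this registry pattern, with illustrative names (only the `_loading` shape comes from the patch):

    from collections import defaultdict
    from threading import Lock

    class LoadingRegistry(object):
        def __init__(self):
            # oid -> (per-oid lock, tids invalidated while a load is in flight)
            self._loading = defaultdict(lambda: (Lock(), []))

        def try_begin(self, oid):
            # First thread to miss the cache wins the per-oid lock
            # (non-blocking acquire); followers will block on it later.
            lock = self._loading[oid][0]
            return lock, lock.acquire(0)

        def note_invalidation(self, oid, tid):
            # Poll-thread side: record tids for in-flight loads only
            # (.get() does not create entries, unlike plain indexing).
            entry = self._loading.get(oid)
            if entry:
                entry[1].append(tid)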
@@ -91,19 +91,13 @@ class Application(ThreadedApplication):
self.notifications_handler = master.PrimaryNotificationsHandler(self)
self._txn_container = TransactionContainer()
# Lock definition:
# _load_lock is used to make loading and storing atomic
lock = Lock()
self._load_lock_acquire = lock.acquire
self._load_lock_release = lock.release
# _oid_lock is used in order to not call multiple oid
# generation at the same time
lock = Lock()
self._oid_lock_acquire = lock.acquire
self._oid_lock_release = lock.release
lock = Lock()
# _cache_lock is used for the client cache
self._cache_lock_acquire = lock.acquire
self._cache_lock_release = lock.release
self._cache_lock = Lock()
# _connecting_to_master_node is used to prevent simultaneous master
# node connection attempts
self._connecting_to_master_node = Lock()
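`_load_lock` disappears entirely (superseded by the per-oid locks above) and `_cache_lock` becomes a plain `Lock` used as a context manager, so release is guaranteed even when the critical section raises. A trivial before/after sketch, assuming nothing beyond the standard library:

    from threading import Lock

    lock = Lock()

    # Before: bound acquire/release methods and explicit try/finally.
    lock.acquire()
    try:
        pass  # critical section
    finally:
        lock.release()

    # After: the with statement acquires and releases automatically.
    with lock:
        pass  # critical section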
@@ -402,21 +396,32 @@ class Application(ThreadedApplication):
"""
# TODO:
# - rename parameters (here? and in handlers & packet definitions)
acquire = self._cache_lock_acquire
release = self._cache_lock_release
# XXX: Consider using a more fine-grained lock.
self._load_lock_acquire()
acquired = False
lock = self._cache_lock
try:
acquire()
try:
result = self._loadFromCache(oid, tid, before_tid)
if result:
return result
self._loading_oid = oid
self._loading_invalidated = []
finally:
release()
while 1:
with lock:
if tid:
result = self._cache.load(oid, tid + '*')
assert not result or result[1] == tid
else:
result = self._cache.load(oid, before_tid)
if result:
return result
load_lock = self._loading[oid][0]
acquired = load_lock.acquire(0)
# Several concurrent cache misses for the same oid are probably
# for the same tid, so we use a per-oid lock to avoid asking the
# storage node for the same data.
if acquired:
# The first thread does load from storage,
# and fills cache with the response.
break
# The other threads wait for the first one to complete and
# loop, possibly resulting in a new cache miss if a different
# tid is actually wanted or if the data was too big.
with load_lock:
pass
# While the cache lock is released, an arbitrary number of
# invalidations may be processed, for this oid or not. And at this
# precise moment, if both tid and before_tid are None (which is
@@ -432,20 +437,24 @@ class Application(ThreadedApplication):
# we got from master.
before_tid = p64(u64(self.last_tid) + 1)
data, tid, next_tid, _ = self._loadFromStorage(oid, tid, before_tid)
acquire()
try:
if self._loading_oid:
with lock:
loading = self._loading.pop(oid, None)
if loading:
assert loading[0] is load_lock
if not next_tid:
for t in self._loading_invalidated:
for t in loading[1]:
if tid < t:
next_tid = t
break
self._cache.store(oid, data, tid, next_tid)
# Else, we just reconnected to the master.
finally:
release()
finally:
self._load_lock_release()
load_lock.release()
except:
if acquired:
with lock:
self._loading.pop(oid, None)
load_lock.release()
raise
return data, tid, next_tid
def _loadFromStorage(self, oid, at_tid, before_tid):
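The rewritten `load` coalesces concurrent cache misses: only the thread that wins `load_lock.acquire(0)` queries storage, while the others block on `with load_lock: pass` until the winner has filled the cache, then loop and retry the lookup. A standalone sketch of that control flow, with a hypothetical `fetch` standing in for `_loadFromStorage` and a plain dict standing in for `ClientCache`:

    from collections import defaultdict
    from threading import Lock

    cache = {}                                 # stand-in for ClientCache
    cache_lock = Lock()
    loading = defaultdict(lambda: (Lock(), []))

    def fetch(key):                            # stand-in for _loadFromStorage
        return 'data-for-%r' % key

    def load(key):
        while 1:
            with cache_lock:
                if key in cache:
                    return cache[key]          # hit, possibly filled by the winner
                load_lock = loading[key][0]
                acquired = load_lock.acquire(0)
            if acquired:
                break                          # this thread does the real load
            with load_lock:                    # wait for the winner to finish,
                pass                           # then loop and retry the cache
        try:
            value = fetch(key)
            with cache_lock:
                loading.pop(key, None)
                cache[key] = value
        finally:
            load_lock.release()
        return value

The real method additionally re-checks the `_loading` entry under the lock before storing (a reconnection to the master may have cleared it), folds the tids gathered in `loading[1]` into `next_tid`, and discards the entry on failure; the sketch omits all three.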
@@ -464,16 +473,6 @@ class Application(ThreadedApplication):
Packets.AskObject(oid, at_tid, before_tid),
askStorage)
def _loadFromCache(self, oid, at_tid=None, before_tid=None):
"""
Load from local cache, return None if not found.
"""
if at_tid:
result = self._cache.load(oid, at_tid + '*')
assert not result or result[1] == at_tid
return result
return self._cache.load(oid, before_tid)
def tpc_begin(self, storage, transaction, tid=None, status=' '):
"""Begin a new transaction."""
# First get a transaction, only one is allowed at a time
@@ -729,29 +728,23 @@ class Application(ThreadedApplication):
txn_container = self._txn_container
if not txn_container.get(transaction).voted:
self.tpc_vote(transaction)
checked_list = []
self._load_lock_acquire()
txn_context = txn_container.pop(transaction)
cache_dict = txn_context.cache_dict
checked_list = [oid for oid, data in cache_dict.iteritems()
if data is CHECKED_SERIAL]
for oid in checked_list:
del cache_dict[oid]
ttid = txn_context.ttid
p = Packets.AskFinishTransaction(ttid, list(cache_dict),
checked_list)
try:
# Call finish on master
txn_context = txn_container.pop(transaction)
cache_dict = txn_context.cache_dict
checked_list = [oid for oid, data in cache_dict.iteritems()
if data is CHECKED_SERIAL]
for oid in checked_list:
del cache_dict[oid]
ttid = txn_context.ttid
p = Packets.AskFinishTransaction(ttid, list(cache_dict),
checked_list)
try:
tid = self._askPrimary(p, cache_dict=cache_dict, callback=f)
assert tid
except ConnectionClosed:
tid = self._getFinalTID(ttid)
if not tid:
raise
return tid
finally:
self._load_lock_release()
tid = self._askPrimary(p, cache_dict=cache_dict, callback=f)
assert tid
except ConnectionClosed:
tid = self._getFinalTID(ttid)
if not tid:
raise
return tid
def _getFinalTID(self, ttid):
try:
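`tpc_finish` no longer wraps the whole finish sequence in the removed `_load_lock`; its body is simply de-indented out of the old try/finally, consistency with concurrent loads being handled by the per-oid `_loading` machinery instead. The `CHECKED_SERIAL` split it keeps can be read in isolation; a sketch with a stand-in sentinel (the real one lives in the NEO client code):

    CHECKED_SERIAL = object()   # stand-in sentinel

    cache_dict = {'oid1': 'some data', 'oid2': CHECKED_SERIAL}
    # oids that were merely checked (not written) are reported separately
    # and must not be pushed to the cache.
    checked_list = [oid for oid, data in cache_dict.items()
                    if data is CHECKED_SERIAL]
    for oid in checked_list:
        del cache_dict[oid]
    assert checked_list == ['oid2'] and list(cache_dict) == ['oid1']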
@@ -991,11 +984,8 @@ class Application(ThreadedApplication):
# It should not be otherwise required (clients should be free to load
# old data as long as it is available in cache, even if it was pruned
# by a pack), so don't bother invalidating on other clients.
self._cache_lock_acquire()
try:
with self._cache_lock:
self._cache.clear()
finally:
self._cache_lock_release()
def getLastTID(self, oid):
return self.load(oid)[1]
@@ -45,8 +45,7 @@ class PrimaryNotificationsHandler(MTEventHandler):
# Either we're connecting or we already know the last tid
# via invalidations.
assert app.master_conn is None, app.master_conn
app._cache_lock_acquire()
try:
with app._cache_lock:
if app_last_tid < ltid:
app._cache.clear_current()
# In the past, we tried not to invalidate the
@@ -60,9 +59,7 @@ class PrimaryNotificationsHandler(MTEventHandler):
app._cache.clear()
# Make sure a parallel load won't refill the cache
# with garbage.
app._loading_oid = app._loading_invalidated = None
finally:
app._cache_lock_release()
app._loading.clear()
db = app.getDB()
db is None or db.invalidateCache()
app.last_tid = ltid
@@ -73,18 +70,20 @@ class PrimaryNotificationsHandler(MTEventHandler):
app.last_tid = tid
# Update cache
cache = app._cache
app._cache_lock_acquire()
try:
with app._cache_lock:
invalidate = app._cache.invalidate
loading_get = app._loading.get
for oid, data in cache_dict.iteritems():
# Update ex-latest value in cache
cache.invalidate(oid, tid)
invalidate(oid, tid)
loading = loading_get(oid)
if loading:
loading[1].append(tid)
if data is not None:
# Store in cache with no next_tid
cache.store(oid, data, tid, None)
if callback is not None:
callback(tid)
finally:
app._cache_lock_release()
def connectionClosed(self, conn):
app = self.app
@@ -113,19 +112,17 @@ class PrimaryNotificationsHandler(MTEventHandler):
if app.ignore_invalidations:
return
app.last_tid = tid
app._cache_lock_acquire()
try:
with app._cache_lock:
invalidate = app._cache.invalidate
loading = app._loading_oid
loading_get = app._loading.get
for oid in oid_list:
invalidate(oid, tid)
if oid == loading:
app._loading_invalidated.append(tid)
loading = loading_get(oid)
if loading:
loading[1].append(tid)
db = app.getDB()
if db is not None:
db.invalidate(tid, oid_list)
finally:
app._cache_lock_release()
def sendPartitionTable(self, conn, ptid, num_replicas, row_list):
pt = self.app.pt = object.__new__(PartitionTable)
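Both invalidation handlers now publish tids into the per-oid `loading` lists instead of updating a single `_loading_invalidated`. Back in `load`, the winning thread scans that list to derive `next_tid` for the record it is about to cache (the `for t in loading[1]` loop above). A compact worked example of that consumer side, with illustrative 4-character tid strings:

    tid = '0004'                             # serial of the record just fetched
    invalidated = ['0002', '0005', '0007']   # appended by the poll thread
    next_tid = None
    for t in invalidated:
        if tid < t:          # first invalidation newer than our record
            next_tid = t     # bounds the validity interval stored in cache
            break
    assert next_tid == '0005'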
@@ -197,8 +197,7 @@ elif IF == 'trace-cache':
@defer
def profile(app):
app._cache_lock_acquire()
try:
with app._cache_lock:
cache = app._cache
if type(cache) is ClientCache:
app._cache = CacheTracer(cache, '%s-%s.neo-cache-trace' %
@@ -206,5 +205,3 @@ elif IF == 'trace-cache':
app._cache.clear()
else:
app._cache = cache.close()
finally:
app._cache_lock_release()
@@ -1105,8 +1105,7 @@ class NEOThreadedTest(NeoTestBase):
def run(self):
try:
apply(*self.__target)
self.__exc_info = None
self.__result = apply(*self.__target)
except:
self.__exc_info = sys.exc_info()
if self.__exc_info[0] is NEOThreadedTest.failureException:
@@ -1114,10 +1113,13 @@ class NEOThreadedTest(NeoTestBase):
def join(self, timeout=None):
threading.Thread.join(self, timeout)
if not self.is_alive() and self.__exc_info:
etype, value, tb = self.__exc_info
del self.__exc_info
raise etype, value, tb
if not self.is_alive():
try:
return self.__result
except AttributeError:
etype, value, tb = self.__exc_info
del self.__exc_info
raise etype, value, tb
class newThread(newPausedThread):
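`newPausedThread` now captures the target's return value so that `join` can hand it back to the caller, while still re-raising a captured exception with the Python 2 three-argument raise. A generic sketch of the same pattern outside the test framework (class and attribute names are illustrative):

    import sys
    import threading

    class ResultThread(threading.Thread):
        def __init__(self, func, *args):
            threading.Thread.__init__(self)
            self._call = (func, args)
            self._exc_info = None

        def run(self):
            try:
                func, args = self._call
                self._result = func(*args)     # only set on success
            except:
                self._exc_info = sys.exc_info()

        def join(self, timeout=None):
            threading.Thread.join(self, timeout)
            if not self.is_alive():
                try:
                    return self._result
                except AttributeError:         # run() raised instead
                    etype, value, tb = self._exc_info
                    del self._exc_info
                    raise etype, value, tb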
@@ -902,6 +902,27 @@ class Test(NEOThreadedTest):
self.assertEqual(c.root()['1'].value, 1)
self.assertNotIn('2', c.root())
@with_cluster()
def testLoadVsFinish(self, cluster):
t1, c1 = cluster.getTransaction()
c1.root()['x'] = x1 = PCounter()
t1.commit()
t1.begin()
x1.value = 1
t2, c2 = cluster.getTransaction()
x2 = c2.root()['x']
cluster.client._cache.clear()
def _loadFromStorage(orig, *args):
r = orig(*args)
ll()
return r
with LockLock() as ll, Patch(cluster.client,
_loadFromStorage=_loadFromStorage):
t = self.newThread(x2._p_activate)
ll()
t1.commit()
t.join()
@with_cluster()
def testInternalInvalidation(self, cluster):
def _handlePacket(orig, conn, packet, kw={}, handler=None):
@@ -989,6 +1010,8 @@ class Test(NEOThreadedTest):
t.join()
self.assertEqual(x2.value, 1)
self.assertEqual(x1.value, 0)
self.assertEqual((x2._p_serial, x1._p_serial),
cluster.client._cache.load(x1._p_oid, x1._p_serial)[1:])
def invalidations(conn):
try:
@@ -1026,7 +1049,7 @@ class Test(NEOThreadedTest):
x = r[''] = PCounter()
t.commit()
tid1 = x._p_serial
nonlocal_ = [0, 1]
nonlocal_ = [0, 0, 0]
l1 = threading.Lock(); l1.acquire()
l2 = threading.Lock(); l2.acquire()
def invalidateObjects(orig, *args):
@@ -1036,27 +1059,72 @@ class Test(NEOThreadedTest):
nonlocal_[0] += 1
if nonlocal_[0] == 2:
l2.release()
def _cache_lock_release(orig):
orig()
if nonlocal_[1]:
nonlocal_[1] = 0
class CacheLock(object):
def __init__(self, client):
self._lock = client._cache_lock
def __enter__(self):
self._lock.acquire()
def __exit__(self, t, v, tb):
count = nonlocal_[1]
nonlocal_[1] = count + 1
self._lock.release()
if count == 0:
load_same.start()
l2.acquire()
elif count == 1:
load_other.start()
def _loadFromStorage(orig, *args):
count = nonlocal_[2]
nonlocal_[2] = count + 1
if not count:
l1.release()
l2.acquire()
return orig(*args)
with cluster.newClient() as client, \
Patch(client.notifications_handler,
invalidateObjects=invalidateObjects):
client.sync()
with cluster.master.filterConnection(client) as mc2:
mc2.delayInvalidateObjects()
# A first client node (C1) modifies an oid whereas
# invalidations to the other node (C2) are delayed.
x._p_changed = 1
t.commit()
tid2 = x._p_serial
# C2 loads the most recent revision of this oid (last_tid=tid1).
self.assertEqual((tid1, tid2), client.load(x._p_oid)[1:])
# C2 poll thread is frozen just before processing invalidation
# packet for tid2. C1 modifies something else -> tid3
r._p_changed = 1
t.commit()
with Patch(client, _cache_lock_release=_cache_lock_release):
self.assertEqual((tid2, None), client.load(x._p_oid)[1:])
self.assertEqual(nonlocal_, [2, 0])
self.assertEqual(tid1, client.last_tid)
load_same = self.newPausedThread(client.load, x._p_oid)
load_other = self.newPausedThread(client.load, r._p_oid)
with Patch(client, _cache_lock=CacheLock(client)), \
Patch(client, _loadFromStorage=_loadFromStorage):
# 1. Just after having found nothing in cache, the worker
# thread asks the poll thread to get notified about
# invalidations for the loading oid.
# <context switch> (l1)
# 2. Both invalidations are processed. -> last_tid=tid3
# <context switch> (l2)
# 3. The worker thread loads before tid3+1.
# The poll thread notified [tid2], which must be ignored.
# In parallel, 2 other loads are done (both cache misses):
# - one for the same oid, which waits for the first load to
# complete and, in particular, fill the cache, in order to
# avoid asking the storage node for the same data
# - another for a different oid, which doesn't wait, as shown
# by the fact that it returns an old record (i.e. before any
# invalidation packet is processed)
loaded = client.load(x._p_oid)
self.assertEqual((tid2, None), loaded[1:])
self.assertEqual(loaded, load_same.join())
self.assertEqual((tid1, r._p_serial), load_other.join()[1:])
# To summarize:
# - 3 concurrent loads starting with cache misses
# - 2 loads from storage
# - 1 load ending with a cache hit
self.assertEqual(nonlocal_, [2, 8, 2])
@with_cluster(storage_count=2, partitions=2)
def testReadVerifyingStorage(self, cluster):