Commit 2b62301d authored by Jim Fulton's avatar Jim Fulton

Made caches thread safe. In theory, caches are protected by ZEO

clients, but ZEO clients haven't provided very good protection,
leading to cache corruption.  We'll hopefully fix these client issues,
which cause other problems beside cache corruption, but it seems
prudent to provide low-level cache protection.
parent 02292886
...@@ -341,8 +341,6 @@ class ClientStorage(object): ...@@ -341,8 +341,6 @@ class ClientStorage(object):
else: else:
cache_path = None cache_path = None
self._cache = self.ClientCacheClass(cache_path, size=cache_size) self._cache = self.ClientCacheClass(cache_path, size=cache_size)
# TODO: maybe there's a better time to open the cache? Unclear.
self._cache.open()
self._rpc_mgr = self.ConnectionManagerClass(addr, self, self._rpc_mgr = self.ConnectionManagerClass(addr, self,
tmin=min_disconnect_poll, tmin=min_disconnect_poll,
......
...@@ -30,6 +30,7 @@ import BTrees.LOBTree ...@@ -30,6 +30,7 @@ import BTrees.LOBTree
import logging import logging
import os import os
import tempfile import tempfile
import threading
import time import time
import ZODB.fsIndex import ZODB.fsIndex
...@@ -119,6 +120,21 @@ ZEC_HEADER_SIZE = 12 ...@@ -119,6 +120,21 @@ ZEC_HEADER_SIZE = 12
# to the end of the file that the new object can't fit in one # to the end of the file that the new object can't fit in one
# contiguous chunk, currentofs is reset to ZEC_HEADER_SIZE first. # contiguous chunk, currentofs is reset to ZEC_HEADER_SIZE first.
class locked(object):
def __init__(self, func):
self.func = func
def __get__(self, inst, class_):
if inst is None:
return self
def call(*args, **kw):
inst._lock.acquire()
try:
return self.func(inst, *args, **kw)
finally:
inst._lock.release()
return call
class ClientCache(object): class ClientCache(object):
"""A simple in-memory cache.""" """A simple in-memory cache."""
...@@ -197,6 +213,10 @@ class ClientCache(object): ...@@ -197,6 +213,10 @@ class ClientCache(object):
self._setup_trace(path) self._setup_trace(path)
self.open()
self._lock = threading.RLock()
# Backward compatibility. Client code used to have to use the fc # Backward compatibility. Client code used to have to use the fc
# attr to get to the file cache to get cache stats. # attr to get to the file cache to get cache stats.
@property @property
...@@ -351,6 +371,7 @@ class ClientCache(object): ...@@ -351,6 +371,7 @@ class ClientCache(object):
# instance, and also written out near the start of the cache file. The # instance, and also written out near the start of the cache file. The
# new tid must be strictly greater than our current idea of the most # new tid must be strictly greater than our current idea of the most
# recent tid. # recent tid.
@locked
def setLastTid(self, tid): def setLastTid(self, tid):
if self.tid is not None and tid <= self.tid: if self.tid is not None and tid <= self.tid:
raise ValueError("new last tid (%s) must be greater than " raise ValueError("new last tid (%s) must be greater than "
...@@ -367,10 +388,11 @@ class ClientCache(object): ...@@ -367,10 +388,11 @@ class ClientCache(object):
# @return a transaction id # @return a transaction id
# @defreturn string, or None if no transaction is yet known # @defreturn string, or None if no transaction is yet known
def getLastTid(self): def getLastTid(self):
if self.tid == z64: tid = self.tid
if tid == z64:
return None return None
else: else:
return self.tid return tid
## ##
# Return the current data record for oid. # Return the current data record for oid.
...@@ -379,6 +401,7 @@ class ClientCache(object): ...@@ -379,6 +401,7 @@ class ClientCache(object):
# in the cache # in the cache
# @defreturn 3-tuple: (string, string, string) # @defreturn 3-tuple: (string, string, string)
@locked
def load(self, oid): def load(self, oid):
ofs = self.current.get(oid) ofs = self.current.get(oid)
if ofs is None: if ofs is None:
...@@ -406,6 +429,7 @@ class ClientCache(object): ...@@ -406,6 +429,7 @@ class ClientCache(object):
# @return data record, serial number, start tid, and end tid # @return data record, serial number, start tid, and end tid
# @defreturn 4-tuple: (string, string, string, string) # @defreturn 4-tuple: (string, string, string, string)
@locked
def loadBefore(self, oid, before_tid): def loadBefore(self, oid, before_tid):
noncurrent_for_oid = self.noncurrent.get(u64(oid)) noncurrent_for_oid = self.noncurrent.get(u64(oid))
if noncurrent_for_oid is None: if noncurrent_for_oid is None:
...@@ -447,6 +471,7 @@ class ClientCache(object): ...@@ -447,6 +471,7 @@ class ClientCache(object):
# current. # current.
# @param data the actual data # @param data the actual data
@locked
def store(self, oid, start_tid, end_tid, data): def store(self, oid, start_tid, end_tid, data):
seek = self.f.seek seek = self.f.seek
if end_tid is None: if end_tid is None:
...@@ -533,6 +558,7 @@ class ClientCache(object): ...@@ -533,6 +558,7 @@ class ClientCache(object):
# @param oid object id # @param oid object id
# @param tid the id of the transaction that wrote a new revision of oid, # @param tid the id of the transaction that wrote a new revision of oid,
# or None to forget all cached info about oid. # or None to forget all cached info about oid.
@locked
def invalidate(self, oid, tid): def invalidate(self, oid, tid):
if tid > self.tid and tid is not None: if tid > self.tid and tid is not None:
self.setLastTid(tid) self.setLastTid(tid)
...@@ -572,12 +598,18 @@ class ClientCache(object): ...@@ -572,12 +598,18 @@ class ClientCache(object):
seek = self.f.seek seek = self.f.seek
read = self.f.read read = self.f.read
for oid, ofs in self.current.iteritems(): for oid, ofs in self.current.iteritems():
self._lock.acquire()
try:
seek(ofs) seek(ofs)
assert read(1) == 'a', (ofs, self.f.tell(), oid) assert read(1) == 'a', (ofs, self.f.tell(), oid)
size, saved_oid, tid, end_tid = unpack(">I8s8s8s", read(28)) size, saved_oid, tid, end_tid = unpack(">I8s8s8s", read(28))
assert saved_oid == oid, (ofs, self.f.tell(), oid, saved_oid) assert saved_oid == oid, (ofs, self.f.tell(), oid, saved_oid)
assert end_tid == z64, (ofs, self.f.tell(), oid) assert end_tid == z64, (ofs, self.f.tell(), oid)
yield oid, tid result = oid, tid
finally:
self._lock.release()
yield result
def dump(self): def dump(self):
from ZODB.utils import oid_repr from ZODB.utils import oid_repr
......
...@@ -67,7 +67,6 @@ class CacheTests(unittest.TestCase): ...@@ -67,7 +67,6 @@ class CacheTests(unittest.TestCase):
# testSerialization reads the entire file into a string, it's not # testSerialization reads the entire file into a string, it's not
# good to leave it that big. # good to leave it that big.
self.cache = ZEO.cache.ClientCache(size=1024**2) self.cache = ZEO.cache.ClientCache(size=1024**2)
self.cache.open()
def tearDown(self): def tearDown(self):
if self.cache.path: if self.cache.path:
...@@ -154,7 +153,6 @@ class CacheTests(unittest.TestCase): ...@@ -154,7 +153,6 @@ class CacheTests(unittest.TestCase):
dst.write(src.read(self.cache.maxsize)) dst.write(src.read(self.cache.maxsize))
dst.close() dst.close()
copy = ZEO.cache.ClientCache(path) copy = ZEO.cache.ClientCache(path)
copy.open()
# Verify that internals of both objects are the same. # Verify that internals of both objects are the same.
# Could also test that external API produces the same results. # Could also test that external API produces the same results.
...@@ -170,7 +168,6 @@ class CacheTests(unittest.TestCase): ...@@ -170,7 +168,6 @@ class CacheTests(unittest.TestCase):
if self.cache.path: if self.cache.path:
os.remove(self.cache.path) os.remove(self.cache.path)
self.cache = ZEO.cache.ClientCache(size=50) self.cache = ZEO.cache.ClientCache(size=50)
self.cache.open()
# We store an object that is a bit larger than the cache can handle. # We store an object that is a bit larger than the cache can handle.
self.cache.store(n1, n2, None, "x"*64) self.cache.store(n1, n2, None, "x"*64)
...@@ -186,7 +183,6 @@ class CacheTests(unittest.TestCase): ...@@ -186,7 +183,6 @@ class CacheTests(unittest.TestCase):
if self.cache.path: if self.cache.path:
os.remove(self.cache.path) os.remove(self.cache.path)
cache = ZEO.cache.ClientCache(size=50) cache = ZEO.cache.ClientCache(size=50)
cache.open()
# We store an object that is a bit larger than the cache can handle. # We store an object that is a bit larger than the cache can handle.
cache.store(n1, n2, n3, "x"*64) cache.store(n1, n2, n3, "x"*64)
...@@ -231,7 +227,6 @@ __test__ = dict( ...@@ -231,7 +227,6 @@ __test__ = dict(
... _ = os.spawnl(os.P_WAIT, sys.executable, sys.executable, 't') ... _ = os.spawnl(os.P_WAIT, sys.executable, sys.executable, 't')
... if os.path.exists('cache'): ... if os.path.exists('cache'):
... cache = ZEO.cache.ClientCache('cache') ... cache = ZEO.cache.ClientCache('cache')
... cache.open()
... cache.close() ... cache.close()
... os.remove('cache') ... os.remove('cache')
... os.remove('cache.lock') ... os.remove('cache.lock')
...@@ -251,7 +246,6 @@ __test__ = dict( ...@@ -251,7 +246,6 @@ __test__ = dict(
>>> cache.store(p64(1), p64(1), None, data) >>> cache.store(p64(1), p64(1), None, data)
>>> cache.close() >>> cache.close()
>>> cache = ZEO.cache.ClientCache('cache', 1000) >>> cache = ZEO.cache.ClientCache('cache', 1000)
>>> cache.open()
>>> cache.store(p64(2), p64(2), None, 'XXX') >>> cache.store(p64(2), p64(2), None, 'XXX')
>>> cache.close() >>> cache.close()
...@@ -267,6 +261,56 @@ __test__ = dict( ...@@ -267,6 +261,56 @@ __test__ = dict(
LockError: Couldn't lock 'cache.lock' LockError: Couldn't lock 'cache.lock'
>>> cache.close() >>> cache.close()
""",
thread_safe =
r"""
>>> import ZEO.cache, ZODB.utils
>>> cache = ZEO.cache.ClientCache('cache', 1000000)
>>> for i in range(100):
... cache.store(ZODB.utils.p64(i), ZODB.utils.p64(1), None, '0')
>>> import random, sys, threading
>>> random = random.Random(0)
>>> stop = False
>>> read_failure = None
>>> def read_thread():
... def pick_oid():
... return ZODB.utils.p64(random.randint(0,99))
...
... try:
... while not stop:
... cache.load(pick_oid())
... cache.loadBefore(pick_oid(), ZODB.utils.p64(2))
... except:
... global read_failure
... read_failure = sys.exc_info()
>>> thread = threading.Thread(target=read_thread)
>>> thread.start()
>>> for tid in range(2,10):
... for oid in range(100):
... oid = ZODB.utils.p64(oid)
... cache.invalidate(oid, ZODB.utils.p64(tid))
... cache.store(oid, ZODB.utils.p64(tid), None, str(tid))
>>> stop = True
>>> thread.join()
>>> if read_failure:
... print 'Read failure:'
... import traceback
... traceback.print_exception(*read_failure)
>>> expected = '9', ZODB.utils.p64(9)
>>> for oid in range(100):
... loaded = cache.load(ZODB.utils.p64(oid))
... if loaded != expected:
... print oid, loaded
""", """,
) )
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment