Commit ea54aa2b authored by Jeremy Hylton's avatar Jeremy Hylton

Add per-storage transaction timeout feature and a couple of tests.

parent 1fd91693
......@@ -33,6 +33,7 @@ from ZEO import ClientStub
from ZEO.CommitLog import CommitLog
from ZEO.zrpc.server import Dispatcher
from ZEO.zrpc.connection import ManagedServerConnection, Delay, MTDelay
from ZEO.zrpc.trigger import trigger
import zLOG
from ZODB.POSException import StorageError, StorageTransactionError
......@@ -72,7 +73,8 @@ class StorageServer:
ManagedServerConnectionClass = ManagedServerConnection
def __init__(self, addr, storages, read_only=0,
invalidation_queue_size=100):
invalidation_queue_size=100,
transaction_timeout=None):
"""StorageServer constructor.
This is typically invoked from the start.py script.
......@@ -104,6 +106,11 @@ class StorageServer:
N == invalidation_queue_size. This queue is used to
speed client cache verification when a client disconnects
for a short period of time.
transaction_timout -- The maximum amount of time to wait for
a transaction to commit after acquiring the storage lock.
If the transaction takes too long, the client connection
will be closed and the transaction aborted.
"""
self.addr = addr
......@@ -125,6 +132,15 @@ class StorageServer:
self.dispatcher = self.DispatcherClass(addr,
factory=self.new_connection,
reuse_addr=1)
self.timeouts = {}
for name in self.storages.keys():
if transaction_timeout is None:
# An object with no-op methods
timeout = StubTimeoutThread()
else:
timeout = TimeoutThread(transaction_timeout)
timeout.start()
self.timeouts[name] = timeout
def new_connection(self, sock, addr):
"""Internal: factory to create a new connection.
......@@ -147,11 +163,14 @@ class StorageServer:
list of current connections for that storage; this information
is needed to handle invalidation. This function updates this
dictionary.
Returns the timeout object for the appropriate storage.
"""
l = self.connections.get(storage_id)
if l is None:
l = self.connections[storage_id] = []
l.append(conn)
return self.timeouts[storage_id]
def invalidate(self, conn, storage_id, tid, invalidated=(), info=None):
"""Internal: broadcast info and invalidations to clients.
......@@ -216,6 +235,7 @@ class StorageServer:
This is only called from the test suite, AFAICT.
"""
self.timeout.stop()
self.dispatcher.close()
for storage in self.storages.values():
storage.close()
......@@ -246,6 +266,7 @@ class ZEOStorage:
def __init__(self, server, read_only=0):
self.server = server
self.timeout = None
self.connection = None
self.client = None
self.storage = None
......@@ -350,7 +371,7 @@ class ZEOStorage:
self.storage_id = storage_id
self.storage = storage
self.setup_delegation()
self.server.register_connection(storage_id, self)
self.timeout = self.server.register_connection(storage_id, self)
def get_info(self):
return {'length': len(self.storage),
......@@ -512,6 +533,7 @@ class ZEOStorage:
self.invalidated, self.get_size_info())
self.transaction = None
self.locked = 0
self.timeout.end(self)
# Return the tid, for cache invalidation optimization
self._handle_waiting()
return tid
......@@ -523,6 +545,7 @@ class ZEOStorage:
self.storage.tpc_abort(self.transaction)
self.transaction = None
self.locked = 0
self.timeout.end(self)
self._handle_waiting()
def _abort(self):
......@@ -584,6 +607,7 @@ class ZEOStorage:
def _tpc_begin(self, txn, tid, status):
self.locked = 1
self.storage.tpc_begin(txn, tid, status)
self.timeout.begin(self)
def _store(self, oid, serial, data, version):
try:
......@@ -702,6 +726,86 @@ class ZEOStorage:
else:
return 1
class StubTimeoutThread:
def begin(self, client):
pass
def end(self, client):
pass
def stop(self):
pass
class TimeoutThread(threading.Thread):
"""Monitors transaction progress and generates timeouts."""
def __init__(self, timeout):
threading.Thread.__init__(self)
self.setDaemon(1)
self._timeout = timeout
self._client = None
self._deadline = None
self._stop = 0
self._active = threading.Event()
self._lock = threading.Lock()
self._trigger = trigger()
def stop(self):
self._stop = 1
def begin(self, client):
self._lock.acquire()
try:
self._active.set()
self._client = client
self._deadline = time.time() + self._timeout
finally:
self._lock.release()
def end(self, client):
# The ZEOStorage will call this message for every aborted
# transaction, regardless of whether the transaction started
# the 2PC. Ignore here if 2PC never began.
if client is not self._client:
return
self._lock.acquire()
try:
self._active.clear()
self._client = None
self._deadline = None
finally:
self._lock.release()
def run(self):
while not self._stop:
self._active.wait()
self._lock.acquire()
try:
howlong = self._deadline - time.time()
finally:
self._lock.release()
if howlong <= 0:
self.timeout()
else:
time.sleep(howlong)
def timeout(self):
self._lock.acquire()
try:
client = self._client
deadline = self._deadline
self._active.clear()
self._client = None
self._deadline = None
finally:
self._lock.release()
if client is None:
return
elapsed = time.time() - (deadline - self._timeout)
client.log("Transaction timeout after %d seconds" % int(elapsed))
self._trigger.pull_trigger(lambda: client.connection.close())
def run_in_thread(method, *args):
t = SlowMethodThread(method, args)
t.start()
......
......@@ -60,6 +60,7 @@ class CommonSetupTearDown(StorageTestBase):
__super_tearDown = StorageTestBase.tearDown
keep = 0
invq = None
timeout = None
def setUp(self):
"""Test setup for connection tests.
......@@ -131,7 +132,7 @@ class CommonSetupTearDown(StorageTestBase):
path = "%s.%d" % (self.file, index)
conf = self.getConfig(path, create, read_only)
zeoport, adminaddr, pid = forker.start_zeo_server(
conf, addr, ro_svr, self.keep, self.invq)
conf, addr, ro_svr, self.keep, self.invq, self.timeout)
self._pids.append(pid)
self._servers.append(adminaddr)
......@@ -674,6 +675,29 @@ class ReconnectionTests(CommonSetupTearDown):
perstorage.close()
class TimeoutTests(CommonSetupTearDown):
timeout = 1
def checkTimeout(self):
storage = self.openClientStorage()
txn = Transaction()
storage.tpc_begin(txn)
storage.tpc_vote(txn)
time.sleep(2)
self.assertRaises(Disconnected, storage.tpc_finish, txn)
def checkTimeoutOnAbort(self):
storage = self.openClientStorage()
txn = Transaction()
storage.tpc_begin(txn)
storage.tpc_vote(txn)
storage.tpc_abort(txn)
def checkTimeoutOnAbortNoLock(self):
storage = self.openClientStorage()
txn = Transaction()
storage.tpc_begin(txn)
storage.tpc_abort(txn)
class MSTThread(threading.Thread):
......
......@@ -51,7 +51,8 @@ def get_port():
raise RuntimeError, "Can't find port"
def start_zeo_server(conf, addr=None, ro_svr=0, keep=0, invq=None):
def start_zeo_server(conf, addr=None, ro_svr=0, keep=0, invq=None,
timeout=None):
"""Start a ZEO server in a separate process.
Returns the ZEO port, the test server port, and the pid.
......@@ -79,6 +80,8 @@ def start_zeo_server(conf, addr=None, ro_svr=0, keep=0, invq=None):
args.append('-k')
if invq:
args += ['-Q', str(invq)]
if timeout:
args += ['-T', str(timeout)]
args.append(str(port))
d = os.environ.copy()
d['PYTHONPATH'] = os.pathsep.join(sys.path)
......
......@@ -59,6 +59,12 @@ class FileStorageReconnectionTests(
):
"""FileStorage-specific re-connection tests."""
class FileStorageTimeoutTests(
FileStorageConfig,
ConnectionTests.TimeoutTests
):
# doesn't test anything that is storage-specific
pass
class BDBConnectionTests(
BerkeleyStorageConfig,
......@@ -74,7 +80,8 @@ class BDBReconnectionTests(
"""Berkeley storage re-connection tests."""
test_classes = [FileStorageConnectionTests, FileStorageReconnectionTests]
test_classes = [FileStorageConnectionTests, FileStorageReconnectionTests,
FileStorageTimeoutTests]
import BDBStorage
if BDBStorage.is_available:
......
......@@ -117,8 +117,9 @@ def main():
keep = 0
configfile = None
invalidation_queue_size = 100
transaction_timeout = None
# Parse the arguments and let getopt.error percolate
opts, args = getopt.getopt(sys.argv[1:], 'rkC:Q:')
opts, args = getopt.getopt(sys.argv[1:], 'rkC:Q:T:')
for opt, arg in opts:
if opt == '-r':
ro_svr = 1
......@@ -128,6 +129,8 @@ def main():
configfile = arg
elif opt == '-Q':
invalidation_queue_size = int(arg)
elif opt == '-T':
transaction_timeout = int(arg)
# Open the config file and let ZConfig parse the data there. Then remove
# the config file, otherwise we'll leave turds.
fp = open(configfile, 'r')
......@@ -150,7 +153,8 @@ def main():
log(label, 'creating the storage server')
serv = ZEO.StorageServer.StorageServer(
addr, {'1': storage}, ro_svr,
invalidation_queue_size=invalidation_queue_size)
invalidation_queue_size=invalidation_queue_size,
transaction_timeout=transaction_timeout)
log(label, 'entering ThreadedAsync loop')
ThreadedAsync.LoopCallback.loop()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment