Commit ea54aa2b authored by Jeremy Hylton

Add per-storage transaction timeout feature and a couple of tests.

parent 1fd91693
...@@ -33,6 +33,7 @@ from ZEO import ClientStub ...@@ -33,6 +33,7 @@ from ZEO import ClientStub
from ZEO.CommitLog import CommitLog from ZEO.CommitLog import CommitLog
from ZEO.zrpc.server import Dispatcher from ZEO.zrpc.server import Dispatcher
from ZEO.zrpc.connection import ManagedServerConnection, Delay, MTDelay from ZEO.zrpc.connection import ManagedServerConnection, Delay, MTDelay
from ZEO.zrpc.trigger import trigger
import zLOG import zLOG
from ZODB.POSException import StorageError, StorageTransactionError from ZODB.POSException import StorageError, StorageTransactionError
...@@ -72,7 +73,8 @@ class StorageServer: ...@@ -72,7 +73,8 @@ class StorageServer:
ManagedServerConnectionClass = ManagedServerConnection ManagedServerConnectionClass = ManagedServerConnection
def __init__(self, addr, storages, read_only=0, def __init__(self, addr, storages, read_only=0,
invalidation_queue_size=100): invalidation_queue_size=100,
transaction_timeout=None):
"""StorageServer constructor. """StorageServer constructor.
This is typically invoked from the start.py script. This is typically invoked from the start.py script.
...@@ -104,6 +106,11 @@ class StorageServer: ...@@ -104,6 +106,11 @@ class StorageServer:
N == invalidation_queue_size. This queue is used to N == invalidation_queue_size. This queue is used to
speed client cache verification when a client disconnects speed client cache verification when a client disconnects
for a short period of time. for a short period of time.
transaction_timeout -- The maximum time, in seconds, to wait for
a transaction to commit after acquiring the storage lock.
If the transaction takes too long, the client connection
will be closed and the transaction aborted.
""" """
self.addr = addr self.addr = addr
...@@ -125,6 +132,15 @@ class StorageServer: ...@@ -125,6 +132,15 @@ class StorageServer:
self.dispatcher = self.DispatcherClass(addr, self.dispatcher = self.DispatcherClass(addr,
factory=self.new_connection, factory=self.new_connection,
reuse_addr=1) reuse_addr=1)
self.timeouts = {}
for name in self.storages.keys():
if transaction_timeout is None:
# An object with no-op methods
timeout = StubTimeoutThread()
else:
timeout = TimeoutThread(transaction_timeout)
timeout.start()
self.timeouts[name] = timeout
def new_connection(self, sock, addr): def new_connection(self, sock, addr):
"""Internal: factory to create a new connection. """Internal: factory to create a new connection.
...@@ -147,11 +163,14 @@ class StorageServer: ...@@ -147,11 +163,14 @@ class StorageServer:
list of current connections for that storage; this information list of current connections for that storage; this information
is needed to handle invalidation. This function updates this is needed to handle invalidation. This function updates this
dictionary. dictionary.
Returns the timeout object for the appropriate storage.
""" """
l = self.connections.get(storage_id) l = self.connections.get(storage_id)
if l is None: if l is None:
l = self.connections[storage_id] = [] l = self.connections[storage_id] = []
l.append(conn) l.append(conn)
return self.timeouts[storage_id]
def invalidate(self, conn, storage_id, tid, invalidated=(), info=None): def invalidate(self, conn, storage_id, tid, invalidated=(), info=None):
"""Internal: broadcast info and invalidations to clients. """Internal: broadcast info and invalidations to clients.
...@@ -216,6 +235,7 @@ class StorageServer: ...@@ -216,6 +235,7 @@ class StorageServer:
This is only called from the test suite, AFAICT. This is only called from the test suite, AFAICT.
""" """
self.timeout.stop()
self.dispatcher.close() self.dispatcher.close()
for storage in self.storages.values(): for storage in self.storages.values():
storage.close() storage.close()
...@@ -246,6 +266,7 @@ class ZEOStorage: ...@@ -246,6 +266,7 @@ class ZEOStorage:
def __init__(self, server, read_only=0): def __init__(self, server, read_only=0):
self.server = server self.server = server
self.timeout = None
self.connection = None self.connection = None
self.client = None self.client = None
self.storage = None self.storage = None
...@@ -350,7 +371,7 @@ class ZEOStorage: ...@@ -350,7 +371,7 @@ class ZEOStorage:
self.storage_id = storage_id self.storage_id = storage_id
self.storage = storage self.storage = storage
self.setup_delegation() self.setup_delegation()
self.server.register_connection(storage_id, self) self.timeout = self.server.register_connection(storage_id, self)
def get_info(self): def get_info(self):
return {'length': len(self.storage), return {'length': len(self.storage),
...@@ -512,6 +533,7 @@ class ZEOStorage: ...@@ -512,6 +533,7 @@ class ZEOStorage:
self.invalidated, self.get_size_info()) self.invalidated, self.get_size_info())
self.transaction = None self.transaction = None
self.locked = 0 self.locked = 0
self.timeout.end(self)
# Return the tid, for cache invalidation optimization # Return the tid, for cache invalidation optimization
self._handle_waiting() self._handle_waiting()
return tid return tid
...@@ -523,6 +545,7 @@ class ZEOStorage: ...@@ -523,6 +545,7 @@ class ZEOStorage:
self.storage.tpc_abort(self.transaction) self.storage.tpc_abort(self.transaction)
self.transaction = None self.transaction = None
self.locked = 0 self.locked = 0
self.timeout.end(self)
self._handle_waiting() self._handle_waiting()
def _abort(self): def _abort(self):
...@@ -584,6 +607,7 @@ class ZEOStorage: ...@@ -584,6 +607,7 @@ class ZEOStorage:
def _tpc_begin(self, txn, tid, status): def _tpc_begin(self, txn, tid, status):
self.locked = 1 self.locked = 1
self.storage.tpc_begin(txn, tid, status) self.storage.tpc_begin(txn, tid, status)
self.timeout.begin(self)
def _store(self, oid, serial, data, version): def _store(self, oid, serial, data, version):
try: try:
...@@ -702,6 +726,86 @@ class ZEOStorage: ...@@ -702,6 +726,86 @@ class ZEOStorage:
else: else:
return 1 return 1
class StubTimeoutThread:
    """Do-nothing stand-in for TimeoutThread.

    Offers the same begin() / end() / stop() methods as TimeoutThread so
    that callers can use either object without checking whether a
    transaction timeout was configured.  Every method ignores its
    arguments and returns None.
    """

    def begin(self, client):
        """Ignore the start of a transaction for client."""
        return None

    def end(self, client):
        """Ignore the end of a transaction for client."""
        return None

    def stop(self):
        """Nothing to stop: no thread is ever started."""
        return None
class TimeoutThread(threading.Thread):
    """Monitors transaction progress and generates timeouts.

    One client at a time is watched.  begin(client) records a deadline
    of now + timeout seconds; end(client) cancels it.  If the deadline
    passes first, the client is logged and its connection is closed via
    the trigger object.

    All shared state (_client, _deadline, _active, _stop_requested) is
    only changed while holding _lock, which maintains the invariant
    "_active is set  <=>  a deadline is pending or stop was requested".
    """

    def __init__(self, timeout):
        """Create the monitor; timeout is the allowed time in seconds."""
        threading.Thread.__init__(self)
        self.setDaemon(1)
        self._timeout = timeout
        self._client = None
        self._deadline = None
        # Deliberately not called "_stop": newer versions of
        # threading.Thread use that name internally.
        self._stop_requested = 0
        self._active = threading.Event()
        self._lock = threading.Lock()
        # NOTE(review): presumably pull_trigger() hands the callable to
        # the main I/O loop -- confirm against ZEO.zrpc.trigger.
        self._trigger = trigger()

    def stop(self):
        """Ask run() to exit and wake it if it is waiting for work."""
        self._lock.acquire()
        try:
            self._stop_requested = 1
            # Without this, a thread parked in self._active.wait()
            # would never notice the stop request.
            self._active.set()
        finally:
            self._lock.release()

    def begin(self, client):
        """Start timing a transaction on behalf of client."""
        self._lock.acquire()
        try:
            self._active.set()
            self._client = client
            self._deadline = time.time() + self._timeout
        finally:
            self._lock.release()

    def end(self, client):
        """Stop timing client's transaction.

        The ZEOStorage will call this method for every aborted
        transaction, regardless of whether the transaction started
        the 2PC.  The call is ignored if client is not the one
        currently being timed.
        """
        self._lock.acquire()
        try:
            # Checked under the lock so a concurrent begin() cannot
            # slip in between the comparison and the reset.
            if client is not self._client:
                return
            self._claim()
        finally:
            self._lock.release()

    def run(self):
        """Thread body: wait for a deadline, then sleep or fire."""
        while not self._stop_requested:
            self._active.wait()
            victim = None
            howlong = None
            self._lock.acquire()
            try:
                if self._stop_requested:
                    break
                # end() may have run between wait() returning and the
                # lock being acquired, leaving no deadline to examine.
                if self._deadline is not None:
                    howlong = self._deadline - time.time()
                    if howlong <= 0:
                        # Decide and claim atomically, so that only the
                        # transaction that actually expired is hit.
                        victim = self._claim()
            finally:
                self._lock.release()
            if victim is not None:
                self._disconnect(victim[0], victim[1])
            elif howlong is not None:
                time.sleep(howlong)

    def timeout(self):
        """Unconditionally time out the client currently being timed.

        Does nothing if no client is registered.
        """
        self._lock.acquire()
        try:
            client, deadline = self._claim()
        finally:
            self._lock.release()
        self._disconnect(client, deadline)

    def _claim(self):
        """Reset the watch state and return the old (client, deadline).

        The caller must hold self._lock.
        """
        claimed = (self._client, self._deadline)
        self._client = None
        self._deadline = None
        # Once a stop has been requested the event must stay set so
        # that run() cannot block in wait() again.
        if not self._stop_requested:
            self._active.clear()
        return claimed

    def _disconnect(self, client, deadline):
        """Log the timeout and schedule client's connection to close."""
        if client is None:
            return
        elapsed = time.time() - (deadline - self._timeout)
        client.log("Transaction timeout after %d seconds" % int(elapsed))
        self._trigger.pull_trigger(lambda: client.connection.close())
def run_in_thread(method, *args): def run_in_thread(method, *args):
t = SlowMethodThread(method, args) t = SlowMethodThread(method, args)
t.start() t.start()
......
...@@ -60,6 +60,7 @@ class CommonSetupTearDown(StorageTestBase): ...@@ -60,6 +60,7 @@ class CommonSetupTearDown(StorageTestBase):
__super_tearDown = StorageTestBase.tearDown __super_tearDown = StorageTestBase.tearDown
keep = 0 keep = 0
invq = None invq = None
timeout = None
def setUp(self): def setUp(self):
"""Test setup for connection tests. """Test setup for connection tests.
...@@ -131,7 +132,7 @@ class CommonSetupTearDown(StorageTestBase): ...@@ -131,7 +132,7 @@ class CommonSetupTearDown(StorageTestBase):
path = "%s.%d" % (self.file, index) path = "%s.%d" % (self.file, index)
conf = self.getConfig(path, create, read_only) conf = self.getConfig(path, create, read_only)
zeoport, adminaddr, pid = forker.start_zeo_server( zeoport, adminaddr, pid = forker.start_zeo_server(
conf, addr, ro_svr, self.keep, self.invq) conf, addr, ro_svr, self.keep, self.invq, self.timeout)
self._pids.append(pid) self._pids.append(pid)
self._servers.append(adminaddr) self._servers.append(adminaddr)
...@@ -674,6 +675,29 @@ class ReconnectionTests(CommonSetupTearDown): ...@@ -674,6 +675,29 @@ class ReconnectionTests(CommonSetupTearDown):
perstorage.close() perstorage.close()
class TimeoutTests(CommonSetupTearDown):
    """Tests for the server's per-storage transaction timeout.

    The class attribute timeout is passed to forker.start_zeo_server()
    by CommonSetupTearDown, so the server under test is started with a
    one-second transaction timeout.
    """

    timeout = 1

    def checkTimeout(self):
        """A transaction held past the timeout gets disconnected."""
        storage = self.openClientStorage()
        try:
            txn = Transaction()
            storage.tpc_begin(txn)
            storage.tpc_vote(txn)
            # Outlast the server-side timeout by a full second.
            time.sleep(self.timeout + 1)
            self.assertRaises(Disconnected, storage.tpc_finish, txn)
        finally:
            # NOTE(review): assumes close() is safe on a storage whose
            # connection the server already dropped -- confirm.
            storage.close()

    def checkTimeoutOnAbort(self):
        """Aborting after the vote must not upset the timeout logic."""
        storage = self.openClientStorage()
        try:
            txn = Transaction()
            storage.tpc_begin(txn)
            storage.tpc_vote(txn)
            storage.tpc_abort(txn)
        finally:
            storage.close()

    def checkTimeoutOnAbortNoLock(self):
        """Aborting before the vote must not upset the timeout logic."""
        storage = self.openClientStorage()
        try:
            txn = Transaction()
            storage.tpc_begin(txn)
            storage.tpc_abort(txn)
        finally:
            storage.close()
class MSTThread(threading.Thread): class MSTThread(threading.Thread):
......
...@@ -51,7 +51,8 @@ def get_port(): ...@@ -51,7 +51,8 @@ def get_port():
raise RuntimeError, "Can't find port" raise RuntimeError, "Can't find port"
def start_zeo_server(conf, addr=None, ro_svr=0, keep=0, invq=None): def start_zeo_server(conf, addr=None, ro_svr=0, keep=0, invq=None,
timeout=None):
"""Start a ZEO server in a separate process. """Start a ZEO server in a separate process.
Returns the ZEO port, the test server port, and the pid. Returns the ZEO port, the test server port, and the pid.
...@@ -79,6 +80,8 @@ def start_zeo_server(conf, addr=None, ro_svr=0, keep=0, invq=None): ...@@ -79,6 +80,8 @@ def start_zeo_server(conf, addr=None, ro_svr=0, keep=0, invq=None):
args.append('-k') args.append('-k')
if invq: if invq:
args += ['-Q', str(invq)] args += ['-Q', str(invq)]
if timeout:
args += ['-T', str(timeout)]
args.append(str(port)) args.append(str(port))
d = os.environ.copy() d = os.environ.copy()
d['PYTHONPATH'] = os.pathsep.join(sys.path) d['PYTHONPATH'] = os.pathsep.join(sys.path)
......
...@@ -59,6 +59,12 @@ class FileStorageReconnectionTests( ...@@ -59,6 +59,12 @@ class FileStorageReconnectionTests(
): ):
"""FileStorage-specific re-connection tests.""" """FileStorage-specific re-connection tests."""
class FileStorageTimeoutTests(
    FileStorageConfig,
    ConnectionTests.TimeoutTests
    ):
    """Transaction timeout tests run against a FileStorage.

    Doesn't test anything that is storage-specific; the class only
    combines the FileStorage configuration with the generic tests.
    """
class BDBConnectionTests( class BDBConnectionTests(
BerkeleyStorageConfig, BerkeleyStorageConfig,
...@@ -74,7 +80,8 @@ class BDBReconnectionTests( ...@@ -74,7 +80,8 @@ class BDBReconnectionTests(
"""Berkeley storage re-connection tests.""" """Berkeley storage re-connection tests."""
test_classes = [FileStorageConnectionTests, FileStorageReconnectionTests] test_classes = [FileStorageConnectionTests, FileStorageReconnectionTests,
FileStorageTimeoutTests]
import BDBStorage import BDBStorage
if BDBStorage.is_available: if BDBStorage.is_available:
......
...@@ -117,8 +117,9 @@ def main(): ...@@ -117,8 +117,9 @@ def main():
keep = 0 keep = 0
configfile = None configfile = None
invalidation_queue_size = 100 invalidation_queue_size = 100
transaction_timeout = None
# Parse the arguments and let getopt.error percolate # Parse the arguments and let getopt.error percolate
opts, args = getopt.getopt(sys.argv[1:], 'rkC:Q:') opts, args = getopt.getopt(sys.argv[1:], 'rkC:Q:T:')
for opt, arg in opts: for opt, arg in opts:
if opt == '-r': if opt == '-r':
ro_svr = 1 ro_svr = 1
...@@ -128,6 +129,8 @@ def main(): ...@@ -128,6 +129,8 @@ def main():
configfile = arg configfile = arg
elif opt == '-Q': elif opt == '-Q':
invalidation_queue_size = int(arg) invalidation_queue_size = int(arg)
elif opt == '-T':
transaction_timeout = int(arg)
# Open the config file and let ZConfig parse the data there. Then remove # Open the config file and let ZConfig parse the data there. Then remove
# the config file, otherwise we'll leave turds. # the config file, otherwise we'll leave turds.
fp = open(configfile, 'r') fp = open(configfile, 'r')
...@@ -150,7 +153,8 @@ def main(): ...@@ -150,7 +153,8 @@ def main():
log(label, 'creating the storage server') log(label, 'creating the storage server')
serv = ZEO.StorageServer.StorageServer( serv = ZEO.StorageServer.StorageServer(
addr, {'1': storage}, ro_svr, addr, {'1': storage}, ro_svr,
invalidation_queue_size=invalidation_queue_size) invalidation_queue_size=invalidation_queue_size,
transaction_timeout=transaction_timeout)
log(label, 'entering ThreadedAsync loop') log(label, 'entering ThreadedAsync loop')
ThreadedAsync.LoopCallback.loop() ThreadedAsync.LoopCallback.loop()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment