Commit e8bec203 authored by Jim Fulton's avatar Jim Fulton

Merge remote-tracking branch 'origin/asyncio' into uvloop

Conflicts:
	setup.py
parents b76fea7d e5653d43
...@@ -34,6 +34,7 @@ import BTrees.OOBTree ...@@ -34,6 +34,7 @@ import BTrees.OOBTree
import zc.lockfile import zc.lockfile
import ZODB import ZODB
import ZODB.BaseStorage import ZODB.BaseStorage
import ZODB.ConflictResolution
import ZODB.interfaces import ZODB.interfaces
import zope.interface import zope.interface
import six import six
...@@ -53,10 +54,7 @@ logger = logging.getLogger(__name__) ...@@ -53,10 +54,7 @@ logger = logging.getLogger(__name__)
# max signed 64-bit value ~ infinity :) Signed cuz LBTree and TimeStamp # max signed 64-bit value ~ infinity :) Signed cuz LBTree and TimeStamp
m64 = b'\x7f\xff\xff\xff\xff\xff\xff\xff' m64 = b'\x7f\xff\xff\xff\xff\xff\xff\xff'
try: from ZODB.ConflictResolution import ResolvedSerial
from ZODB.ConflictResolution import ResolvedSerial
except ImportError:
ResolvedSerial = 'rs'
def tid2time(tid): def tid2time(tid):
return str(TimeStamp(tid)) return str(TimeStamp(tid))
...@@ -77,7 +75,8 @@ def get_timestamp(prev_ts=None): ...@@ -77,7 +75,8 @@ def get_timestamp(prev_ts=None):
MB = 1024**2 MB = 1024**2
class ClientStorage(object): @zope.interface.implementer(ZODB.interfaces.IMultiCommitStorage)
class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
"""A storage class that is a network client to a remote storage. """A storage class that is a network client to a remote storage.
This is a faithful implementation of the Storage API. This is a faithful implementation of the Storage API.
...@@ -333,6 +332,7 @@ class ClientStorage(object): ...@@ -333,6 +332,7 @@ class ClientStorage(object):
The storage isn't really ready to use until after this call. The storage isn't really ready to use until after this call.
""" """
super(ClientStorage, self).registerDB(db)
self._db = db self._db = db
def is_connected(self, test=False): def is_connected(self, test=False):
...@@ -724,18 +724,51 @@ class ClientStorage(object): ...@@ -724,18 +724,51 @@ class ClientStorage(object):
""" """
tbuf = self._check_trans(txn, 'tpc_vote') tbuf = self._check_trans(txn, 'tpc_vote')
try: try:
self._call('vote', id(txn))
conflicts = True
vote_attempts = 0
while conflicts and vote_attempts < 9: # 9? Mainly avoid inf. loop
conflicts = False
for oid in self._call('vote', id(txn)) or ():
if isinstance(oid, dict):
# Conflict, let's try to resolve it
conflicts = True
conflict = oid
oid = conflict['oid']
committed, read = conflict['serials']
data = self.tryToResolveConflict(
oid, committed, read, conflict['data'])
self._async('storea', oid, committed, data, id(txn))
tbuf.resolve(oid, data)
else:
tbuf.serial(oid, ResolvedSerial)
vote_attempts += 1
except POSException.StorageTransactionError: except POSException.StorageTransactionError:
# Hm, we got disconnected and reconnected bwtween # Hm, we got disconnected and reconnected bwtween
# _check_trans and voting. Let's chack the transaction again: # _check_trans and voting. Let's chack the transaction again:
tbuf = self._check_trans(txn, 'tpc_vote') self._check_trans(txn, 'tpc_vote')
raise
except POSException.ConflictError as err:
oid = getattr(err, 'oid', None)
if oid is not None:
# This is a band-aid to help recover from a situation
# that shouldn't happen. A Client somehow misses some
# invalidations and has out of date data in its
# cache. We need some whay to invalidate the cache
# entry without invalidations. So, if we see a
# (unresolved) conflict error, we assume that the
# cache entry is bad and invalidate it.
self._cache.invalidate(oid, None)
raise raise
if tbuf.exception: if tbuf.exception:
raise tbuf.exception raise tbuf.exception
if tbuf.serials: if tbuf.server_resolved or tbuf.client_resolved:
return list(tbuf.serials.items()) return list(tbuf.server_resolved) + list(tbuf.client_resolved)
else: else:
return None return None
...@@ -830,6 +863,8 @@ class ClientStorage(object): ...@@ -830,6 +863,8 @@ class ClientStorage(object):
self._update_blob_cache(tbuf, tid) self._update_blob_cache(tbuf, tid)
return tid
def _update_blob_cache(self, tbuf, tid): def _update_blob_cache(self, tbuf, tid):
"""Internal helper move blobs updated by a transaction to the cache. """Internal helper move blobs updated by a transaction to the cache.
""" """
......
...@@ -85,10 +85,11 @@ class ZEOStorage: ...@@ -85,10 +85,11 @@ class ZEOStorage:
blob_tempfile = None blob_tempfile = None
log_label = 'unconnected' log_label = 'unconnected'
locked = False # Don't have storage lock locked = False # Don't have storage lock
verifying = store_failed = 0 verifying = 0
def __init__(self, server, read_only=0): def __init__(self, server, read_only=0):
self.server = server self.server = server
self.client_conflict_resolution = server.client_conflict_resolution
# timeout and stats will be initialized in register() # timeout and stats will be initialized in register()
self.read_only = read_only self.read_only = read_only
# The authentication protocol may define extra methods. # The authentication protocol may define extra methods.
...@@ -334,12 +335,12 @@ class ZEOStorage: ...@@ -334,12 +335,12 @@ class ZEOStorage:
t._extension = ext t._extension = ext
self.serials = [] self.serials = []
self.conflicts = {}
self.invalidated = [] self.invalidated = []
self.txnlog = CommitLog() self.txnlog = CommitLog()
self.blob_log = [] self.blob_log = []
self.tid = tid self.tid = tid
self.status = status self.status = status
self.store_failed = 0
self.stats.active_txns += 1 self.stats.active_txns += 1
# Assign the transaction attribute last. This is so we don't # Assign the transaction attribute last. This is so we don't
...@@ -414,6 +415,7 @@ class ZEOStorage: ...@@ -414,6 +415,7 @@ class ZEOStorage:
self.locked, delay = self.server.lock_storage(self, delay) self.locked, delay = self.server.lock_storage(self, delay)
if self.locked: if self.locked:
result = None
try: try:
self.log( self.log(
"Preparing to commit transaction: %d objects, %d bytes" "Preparing to commit transaction: %d objects, %d bytes"
...@@ -427,38 +429,56 @@ class ZEOStorage: ...@@ -427,38 +429,56 @@ class ZEOStorage:
self.storage.tpc_begin(self.transaction) self.storage.tpc_begin(self.transaction)
for op, args in self.txnlog: for op, args in self.txnlog:
if not getattr(self, op)(*args): getattr(self, op)(*args)
break
# Blob support # Blob support
while self.blob_log and not self.store_failed: while self.blob_log:
oid, oldserial, data, blobfilename = self.blob_log.pop() oid, oldserial, data, blobfilename = self.blob_log.pop()
self._store(oid, oldserial, data, blobfilename) self._store(oid, oldserial, data, blobfilename)
if not self.store_failed:
# Only call tpc_vote of no store call failed,
# otherwise the serialnos() call will deliver an
# exception that will be handled by the client in
# its tpc_vote() method.
serials = self.storage.tpc_vote(self.transaction)
if serials:
self.serials.extend(serials)
self.connection.async('serialnos', self.serials)
except Exception: if not self.conflicts:
try:
serials = self.storage.tpc_vote(self.transaction)
except ConflictError as err:
if (self.client_conflict_resolution and
err.oid and err.serials and err.data
):
self.conflicts[err.oid] = dict(
oid=err.oid, serials=err.serials, data=err.data)
else:
raise
else:
if serials:
self.serials.extend(serials)
result = self.serials
if self.conflicts:
result = list(self.conflicts.values())
self.storage.tpc_abort(self.transaction)
self.server.unlock_storage(self)
self.locked = False
self.server.stop_waiting(self)
except Exception as err:
self.storage.tpc_abort(self.transaction) self.storage.tpc_abort(self.transaction)
self._clear_transaction() self._clear_transaction()
if isinstance(err, ConflictError):
self.stats.conflicts += 1
self.log("conflict error %s" % err, BLATHER)
if not isinstance(err, TransactionError):
logger.exception("While voting")
if delay is not None: if delay is not None:
delay.error(sys.exc_info()) delay.error(sys.exc_info())
else: else:
raise raise
else: else:
if delay is not None: if delay is not None:
delay.reply(None) delay.reply(result)
else: else:
return None return result
else: else:
return delay return delay
...@@ -550,120 +570,41 @@ class ZEOStorage: ...@@ -550,120 +570,41 @@ class ZEOStorage:
self._check_tid(tid, exc=StorageTransactionError) self._check_tid(tid, exc=StorageTransactionError)
self.txnlog.undo(trans_id) self.txnlog.undo(trans_id)
def _op_error(self, oid, err, op):
self.store_failed = 1
if isinstance(err, ConflictError):
self.stats.conflicts += 1
self.log("conflict error oid=%s msg=%s" %
(oid_repr(oid), str(err)), BLATHER)
if not isinstance(err, TransactionError):
# Unexpected errors are logged and passed to the client
self.log("%s error: %s, %s" % ((op,)+ sys.exc_info()[:2]),
logging.ERROR, exc_info=True)
err = self._marshal_error(err)
# The exception is reported back as newserial for this oid
self.serials.append((oid, err))
def _delete(self, oid, serial): def _delete(self, oid, serial):
err = None self.storage.deleteObject(oid, serial, self.transaction)
try:
self.storage.deleteObject(oid, serial, self.transaction)
except (SystemExit, KeyboardInterrupt):
raise
except Exception as e:
err = e
self._op_error(oid, err, 'delete')
return err is None
def _checkread(self, oid, serial): def _checkread(self, oid, serial):
err = None self.storage.checkCurrentSerialInTransaction(
try: oid, serial, self.transaction)
self.storage.checkCurrentSerialInTransaction(
oid, serial, self.transaction)
except (SystemExit, KeyboardInterrupt):
raise
except Exception as e:
err = e
self._op_error(oid, err, 'checkCurrentSerialInTransaction')
return err is None
def _store(self, oid, serial, data, blobfile=None): def _store(self, oid, serial, data, blobfile=None):
err = None
try: try:
if blobfile is None: if blobfile is None:
newserial = self.storage.store( self.storage.store(oid, serial, data, '', self.transaction)
oid, serial, data, '', self.transaction)
else: else:
newserial = self.storage.storeBlob( self.storage.storeBlob(
oid, serial, data, blobfile, '', self.transaction) oid, serial, data, blobfile, '', self.transaction)
except (SystemExit, KeyboardInterrupt): except ConflictError as err:
raise if self.client_conflict_resolution and err.serials:
except Exception as error: self.conflicts[oid] = dict(
self._op_error(oid, error, 'store') oid=oid, serials=err.serials, data=data)
err = error else:
raise
else: else:
if oid in self.conflicts:
del self.conflicts[oid]
if serial != b"\0\0\0\0\0\0\0\0": if serial != b"\0\0\0\0\0\0\0\0":
self.invalidated.append(oid) self.invalidated.append(oid)
if isinstance(newserial, bytes):
newserial = [(oid, newserial)]
for oid, s in newserial or ():
if s == ResolvedSerial:
self.stats.conflicts_resolved += 1
self.log("conflict resolved oid=%s"
% oid_repr(oid), BLATHER)
self.serials.append((oid, s))
return err is None
def _restore(self, oid, serial, data, prev_txn): def _restore(self, oid, serial, data, prev_txn):
err = None self.storage.restore(oid, serial, data, '', prev_txn,
try: self.transaction)
self.storage.restore(oid, serial, data, '', prev_txn,
self.transaction)
except (SystemExit, KeyboardInterrupt):
raise
except Exception as err:
self._op_error(oid, err, 'restore')
return err is None
def _undo(self, trans_id): def _undo(self, trans_id):
err = None tid, oids = self.storage.undo(trans_id, self.transaction)
try: self.invalidated.extend(oids)
tid, oids = self.storage.undo(trans_id, self.transaction) self.serials.extend(oids)
except (SystemExit, KeyboardInterrupt):
raise
except Exception as e:
err = e
self._op_error(z64, err, 'undo')
else:
self.invalidated.extend(oids)
self.serials.extend((oid, ResolvedSerial) for oid in oids)
return err is None
def _marshal_error(self, error):
# Try to pickle the exception. If it can't be pickled,
# the RPC response would fail, so use something that can be pickled.
if PY3:
pickler = Pickler(BytesIO(), 3)
else:
# The pure-python version requires at least one argument (PyPy)
pickler = Pickler(0)
pickler.fast = 1
try:
pickler.dump(error)
except:
msg = "Couldn't pickle storage exception: %s" % repr(error)
self.log(msg, logging.ERROR)
error = StorageServerError(msg)
return error
# IStorageIteration support # IStorageIteration support
...@@ -771,6 +712,7 @@ class StorageServer: ...@@ -771,6 +712,7 @@ class StorageServer:
invalidation_age=None, invalidation_age=None,
transaction_timeout=None, transaction_timeout=None,
ssl=None, ssl=None,
client_conflict_resolution=False,
): ):
"""StorageServer constructor. """StorageServer constructor.
...@@ -841,15 +783,23 @@ class StorageServer: ...@@ -841,15 +783,23 @@ class StorageServer:
for name, storage in storages.items(): for name, storage in storages.items():
self._setup_invq(name, storage) self._setup_invq(name, storage)
storage.registerDB(StorageServerDB(self, name)) storage.registerDB(StorageServerDB(self, name))
if client_conflict_resolution:
# XXX this may go away later, when storages grow
# configuration for this.
storage.tryToResolveConflict = never_resolve_conflict
self.invalidation_age = invalidation_age self.invalidation_age = invalidation_age
self.zeo_storages_by_storage_id = {} # {storage_id -> [ZEOStorage]} self.zeo_storages_by_storage_id = {} # {storage_id -> [ZEOStorage]}
self.acceptor = Acceptor(self, addr, ssl) self.client_conflict_resolution = client_conflict_resolution
if isinstance(addr, tuple) and addr[0]:
self.addr = self.acceptor.addr if addr is not None:
else: self.acceptor = Acceptor(self, addr, ssl)
self.addr = addr if isinstance(addr, tuple) and addr[0]:
self.loop = self.acceptor.loop self.addr = self.acceptor.addr
ZODB.event.notify(Serving(self, address=self.acceptor.addr)) else:
self.addr = addr
self.loop = self.acceptor.loop
ZODB.event.notify(Serving(self, address=self.acceptor.addr))
self.stats = {} self.stats = {}
self.timeouts = {} self.timeouts = {}
for name in self.storages.keys(): for name in self.storages.keys():
...@@ -1383,7 +1333,7 @@ class Serving(ServerEvent): ...@@ -1383,7 +1333,7 @@ class Serving(ServerEvent):
class Closed(ServerEvent): class Closed(ServerEvent):
pass pass
default_cert_authenticate = 'SIGNED' def never_resolve_conflict(oid, committedSerial, oldSerial, newpickle,
def ssl_config(section): committedData=b''):
from .sslconfig import ssl_config raise ConflictError(oid=oid, serials=(committedSerial, oldSerial),
return ssl_config(section, True) data=newpickle)
...@@ -46,7 +46,8 @@ class TransactionBuffer: ...@@ -46,7 +46,8 @@ class TransactionBuffer:
# stored are builtin types -- strings or None. # stored are builtin types -- strings or None.
self.pickler = Pickler(self.file, 1) self.pickler = Pickler(self.file, 1)
self.pickler.fast = 1 self.pickler.fast = 1
self.serials = {} # processed { oid -> serial } self.server_resolved = set() # {oid}
self.client_resolved = {} # {oid -> buffer_record_number}
self.exception = None self.exception = None
def close(self): def close(self):
...@@ -59,12 +60,17 @@ class TransactionBuffer: ...@@ -59,12 +60,17 @@ class TransactionBuffer:
# Estimate per-record cache size # Estimate per-record cache size
self.size = self.size + (data and len(data) or 0) + 31 self.size = self.size + (data and len(data) or 0) + 31
def resolve(self, oid, data):
"""Record client-resolved data
"""
self.store(oid, data)
self.client_resolved[oid] = self.count - 1
def serial(self, oid, serial): def serial(self, oid, serial):
if isinstance(serial, Exception): if isinstance(serial, Exception):
self.exception = serial self.exception = serial # This transaction will never be committed
self.serials[oid] = None elif serial == ResolvedSerial:
else: self.server_resolved.add(oid)
self.serials[oid] = serial
def storeBlob(self, oid, blobfilename): def storeBlob(self, oid, blobfilename):
self.blobs.append((oid, blobfilename)) self.blobs.append((oid, blobfilename))
...@@ -72,7 +78,8 @@ class TransactionBuffer: ...@@ -72,7 +78,8 @@ class TransactionBuffer:
def __iter__(self): def __iter__(self):
self.file.seek(0) self.file.seek(0)
unpickler = Unpickler(self.file) unpickler = Unpickler(self.file)
serials = self.serials server_resolved = self.server_resolved
client_resolved = self.client_resolved
# Gaaaa, this is awkward. There can be entries in serials that # Gaaaa, this is awkward. There can be entries in serials that
# aren't in the buffer, because undo. Entries can be repeated # aren't in the buffer, because undo. Entries can be repeated
...@@ -82,10 +89,11 @@ class TransactionBuffer: ...@@ -82,10 +89,11 @@ class TransactionBuffer:
seen = set() seen = set()
for i in range(self.count): for i in range(self.count):
oid, data = unpickler.load() oid, data = unpickler.load()
seen.add(oid) if client_resolved.get(oid, i) == i:
yield oid, data, serials.get(oid) == ResolvedSerial seen.add(oid)
yield oid, data, oid in server_resolved
# We may have leftover serials because undo # We may have leftover oids because undo
for oid, serial in serials.items(): for oid in server_resolved:
if oid not in seen: if oid not in seen:
yield oid, None, serial == ResolvedSerial yield oid, None, True
from struct import unpack from struct import unpack
import asyncio import asyncio
import logging import logging
import socket
import sys
from .marshal import encoder from .marshal import encoder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
INET_FAMILIES = socket.AF_INET, socket.AF_INET6
class Protocol(asyncio.Protocol): class Protocol(asyncio.Protocol):
"""asyncio low-level ZEO base interface """asyncio low-level ZEO base interface
""" """
...@@ -41,7 +45,15 @@ class Protocol(asyncio.Protocol): ...@@ -41,7 +45,15 @@ class Protocol(asyncio.Protocol):
def connection_made(self, transport): def connection_made(self, transport):
logger.info("Connected %s", self) logger.info("Connected %s", self)
if sys.version_info < (3, 6):
sock = transport.get_extra_info('socket')
if sock is not None and sock.family in INET_FAMILIES:
# See https://bugs.python.org/issue27456 :(
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
self.transport = transport self.transport = transport
paused = self.paused paused = self.paused
output = self.output output = self.output
append = output.append append = output.append
......
...@@ -17,7 +17,7 @@ class ServerProtocol(base.Protocol): ...@@ -17,7 +17,7 @@ class ServerProtocol(base.Protocol):
"""asyncio low-level ZEO server interface """asyncio low-level ZEO server interface
""" """
protocols = b'Z4', b'Z5' protocols = (b'Z5', )
name = 'server protocol' name = 'server protocol'
methods = set(('register', )) methods = set(('register', ))
...@@ -162,7 +162,7 @@ class Delay: ...@@ -162,7 +162,7 @@ class Delay:
def error(self, exc_info): def error(self, exc_info):
self.sent = 'error' self.sent = 'error'
log("Error raised in delayed method", logging.ERROR, exc_info=exc_info) logger.error("Error raised in delayed method", exc_info=exc_info)
self.protocol.send_error(self.msgid, exc_info[1]) self.protocol.send_error(self.msgid, exc_info[1])
def __repr__(self): def __repr__(self):
...@@ -199,7 +199,6 @@ class MTDelay(Delay): ...@@ -199,7 +199,6 @@ class MTDelay(Delay):
def error(self, exc_info): def error(self, exc_info):
self.ready.wait() self.ready.wait()
log("Error raised in delayed method", logging.ERROR, exc_info=exc_info)
self.protocol.call_soon_threadsafe(Delay.error, self, exc_info) self.protocol.call_soon_threadsafe(Delay.error, self, exc_info)
......
...@@ -86,7 +86,7 @@ class Transport: ...@@ -86,7 +86,7 @@ class Transport:
capacity = 1 << 64 capacity = 1 << 64
paused = False paused = False
extra = dict(peername='1.2.3.4', sockname=('127.0.0.1', 4200)) extra = dict(peername='1.2.3.4', sockname=('127.0.0.1', 4200), socket=None)
def __init__(self, protocol): def __init__(self, protocol):
self.data = [] self.data = []
......
...@@ -750,7 +750,7 @@ class ServerTests(Base, setupstack.TestCase): ...@@ -750,7 +750,7 @@ class ServerTests(Base, setupstack.TestCase):
self.target = protocol.zeo_storage self.target = protocol.zeo_storage
if finish: if finish:
self.assertEqual(self.pop(parse=False), best_protocol_version) self.assertEqual(self.pop(parse=False), best_protocol_version)
protocol.data_received(sized(b'Z4')) protocol.data_received(sized(b'Z5'))
return protocol return protocol
message_id = 0 message_id = 0
...@@ -788,9 +788,9 @@ class ServerTests(Base, setupstack.TestCase): ...@@ -788,9 +788,9 @@ class ServerTests(Base, setupstack.TestCase):
self.assertEqual(self.pop(parse=False), best_protocol_version) self.assertEqual(self.pop(parse=False), best_protocol_version)
# The client sends it's protocol: # The client sends it's protocol:
protocol.data_received(sized(b'Z4')) protocol.data_received(sized(b'Z5'))
self.assertEqual(protocol.protocol_version, b'Z4') self.assertEqual(protocol.protocol_version, b'Z5')
protocol.zeo_storage.notify_connected.assert_called_once_with(protocol) protocol.zeo_storage.notify_connected.assert_called_once_with(protocol)
......
...@@ -33,7 +33,7 @@ import time ...@@ -33,7 +33,7 @@ import time
import ZODB.fsIndex import ZODB.fsIndex
import zc.lockfile import zc.lockfile
from ZODB.utils import p64, u64, z64 from ZODB.utils import p64, u64, z64, RLock
import six import six
from ._compat import PYPY from ._compat import PYPY
...@@ -182,6 +182,8 @@ class ClientCache(object): ...@@ -182,6 +182,8 @@ class ClientCache(object):
# currentofs. # currentofs.
self.currentofs = ZEC_HEADER_SIZE self.currentofs = ZEC_HEADER_SIZE
self._lock = RLock()
# self.f is the open file object. # self.f is the open file object.
# When we're not reusing an existing file, self.f is left None # When we're not reusing an existing file, self.f is left None
# here -- the scan() method must be called then to open the file # here -- the scan() method must be called then to open the file
...@@ -239,9 +241,10 @@ class ClientCache(object): ...@@ -239,9 +241,10 @@ class ClientCache(object):
return self return self
def clear(self): def clear(self):
self.f.seek(ZEC_HEADER_SIZE) with self._lock:
self.f.truncate() self.f.seek(ZEC_HEADER_SIZE)
self._initfile(ZEC_HEADER_SIZE) self.f.truncate()
self._initfile(ZEC_HEADER_SIZE)
## ##
# Scan the current contents of the cache file, calling `install` # Scan the current contents of the cache file, calling `install`
...@@ -451,26 +454,28 @@ class ClientCache(object): ...@@ -451,26 +454,28 @@ class ClientCache(object):
# new tid must be strictly greater than our current idea of the most # new tid must be strictly greater than our current idea of the most
# recent tid. # recent tid.
def setLastTid(self, tid): def setLastTid(self, tid):
if (not tid) or (tid == z64): with self._lock:
return if (not tid) or (tid == z64):
if (tid <= self.tid) and self._len: return
if tid == self.tid: if (tid <= self.tid) and self._len:
return # Be a little forgiving if tid == self.tid:
raise ValueError("new last tid (%s) must be greater than " return # Be a little forgiving
"previous one (%s)" raise ValueError("new last tid (%s) must be greater than "
% (u64(tid), u64(self.tid))) "previous one (%s)"
assert isinstance(tid, bytes) and len(tid) == 8, tid % (u64(tid), u64(self.tid)))
self.tid = tid assert isinstance(tid, bytes) and len(tid) == 8, tid
self.f.seek(len(magic)) self.tid = tid
self.f.write(tid) self.f.seek(len(magic))
self.f.flush() self.f.write(tid)
self.f.flush()
## ##
# Return the last transaction seen by the cache. # Return the last transaction seen by the cache.
# @return a transaction id # @return a transaction id
# @defreturn string, or 8 nulls if no transaction is yet known # @defreturn string, or 8 nulls if no transaction is yet known
def getLastTid(self): def getLastTid(self):
return self.tid with self._lock:
return self.tid
## ##
# Return the current data record for oid. # Return the current data record for oid.
...@@ -479,52 +484,54 @@ class ClientCache(object): ...@@ -479,52 +484,54 @@ class ClientCache(object):
# in the cache # in the cache
# @defreturn 3-tuple: (string, string, string) # @defreturn 3-tuple: (string, string, string)
def load(self, oid, before_tid=None): def load(self, oid, before_tid=None):
ofs = self.current.get(oid) with self._lock:
if ofs is None: ofs = self.current.get(oid)
self._trace(0x20, oid) if ofs is None:
return None self._trace(0x20, oid)
self.f.seek(ofs) return None
read = self.f.read
status = read(1)
assert status == b'a', (ofs, self.f.tell(), oid)
size, saved_oid, tid, end_tid, lver, ldata = unpack(
">I8s8s8sHI", read(34))
assert saved_oid == oid, (ofs, self.f.tell(), oid, saved_oid)
assert end_tid == z64, (ofs, self.f.tell(), oid, tid, end_tid)
assert lver == 0, "Versions aren't supported"
if before_tid and tid >= before_tid:
return None
data = read(ldata)
assert len(data) == ldata, (ofs, self.f.tell(), oid, len(data), ldata)
# WARNING: The following assert changes the file position.
# We must not depend on this below or we'll fail in optimized mode.
assert read(8) == oid, (ofs, self.f.tell(), oid)
self._n_accesses += 1
self._trace(0x22, oid, tid, end_tid, ldata)
ofsofs = self.currentofs - ofs
if ofsofs < 0:
ofsofs += self.maxsize
if (ofsofs > self.rearrange and
self.maxsize > 10*len(data) and
size > 4):
# The record is far back and might get evicted, but it's
# valuable, so move it forward.
# Remove fromn old loc:
del self.current[oid]
self.f.seek(ofs) self.f.seek(ofs)
self.f.write(b'f'+pack(">I", size)) read = self.f.read
status = read(1)
assert status == b'a', (ofs, self.f.tell(), oid)
size, saved_oid, tid, end_tid, lver, ldata = unpack(
">I8s8s8sHI", read(34))
assert saved_oid == oid, (ofs, self.f.tell(), oid, saved_oid)
assert end_tid == z64, (ofs, self.f.tell(), oid, tid, end_tid)
assert lver == 0, "Versions aren't supported"
if before_tid and tid >= before_tid:
return None
data = read(ldata)
assert len(data) == ldata, (
ofs, self.f.tell(), oid, len(data), ldata)
# WARNING: The following assert changes the file position.
# We must not depend on this below or we'll fail in optimized mode.
assert read(8) == oid, (ofs, self.f.tell(), oid)
self._n_accesses += 1
self._trace(0x22, oid, tid, end_tid, ldata)
# Write to new location: ofsofs = self.currentofs - ofs
self._store(oid, tid, None, data, size) if ofsofs < 0:
ofsofs += self.maxsize
return data, tid if (ofsofs > self.rearrange and
self.maxsize > 10*len(data) and
size > 4):
# The record is far back and might get evicted, but it's
# valuable, so move it forward.
# Remove fromn old loc:
del self.current[oid]
self.f.seek(ofs)
self.f.write(b'f'+pack(">I", size))
# Write to new location:
self._store(oid, tid, None, data, size)
return data, tid
## ##
# Return a non-current revision of oid that was current before tid. # Return a non-current revision of oid that was current before tid.
...@@ -533,54 +540,56 @@ class ClientCache(object): ...@@ -533,54 +540,56 @@ class ClientCache(object):
# @return data record, serial number, start tid, and end tid # @return data record, serial number, start tid, and end tid
# @defreturn 4-tuple: (string, string, string, string) # @defreturn 4-tuple: (string, string, string, string)
def loadBefore(self, oid, before_tid): def loadBefore(self, oid, before_tid):
noncurrent_for_oid = self.noncurrent.get(u64(oid)) with self._lock:
if noncurrent_for_oid is None: noncurrent_for_oid = self.noncurrent.get(u64(oid))
result = self.load(oid, before_tid) if noncurrent_for_oid is None:
if result: result = self.load(oid, before_tid)
return result[0], result[1], None if result:
else: return result[0], result[1], None
self._trace(0x24, oid, "", before_tid) else:
return result self._trace(0x24, oid, "", before_tid)
return result
items = noncurrent_for_oid.items(None, u64(before_tid)-1)
if not items: items = noncurrent_for_oid.items(None, u64(before_tid)-1)
result = self.load(oid, before_tid) if not items:
if result: result = self.load(oid, before_tid)
return result[0], result[1], None if result:
else: return result[0], result[1], None
self._trace(0x24, oid, "", before_tid) else:
return result self._trace(0x24, oid, "", before_tid)
return result
tid, ofs = items[-1] tid, ofs = items[-1]
self.f.seek(ofs) self.f.seek(ofs)
read = self.f.read read = self.f.read
status = read(1) status = read(1)
assert status == b'a', (ofs, self.f.tell(), oid, before_tid) assert status == b'a', (ofs, self.f.tell(), oid, before_tid)
size, saved_oid, saved_tid, end_tid, lver, ldata = unpack( size, saved_oid, saved_tid, end_tid, lver, ldata = unpack(
">I8s8s8sHI", read(34)) ">I8s8s8sHI", read(34))
assert saved_oid == oid, (ofs, self.f.tell(), oid, saved_oid) assert saved_oid == oid, (ofs, self.f.tell(), oid, saved_oid)
assert saved_tid == p64(tid), (ofs, self.f.tell(), oid, saved_tid, tid) assert saved_tid == p64(tid), (
assert end_tid != z64, (ofs, self.f.tell(), oid) ofs, self.f.tell(), oid, saved_tid, tid)
assert lver == 0, "Versions aren't supported" assert end_tid != z64, (ofs, self.f.tell(), oid)
data = read(ldata) assert lver == 0, "Versions aren't supported"
assert len(data) == ldata, (ofs, self.f.tell()) data = read(ldata)
assert len(data) == ldata, (ofs, self.f.tell())
# WARNING: The following assert changes the file position.
# We must not depend on this below or we'll fail in optimized mode. # WARNING: The following assert changes the file position.
assert read(8) == oid, (ofs, self.f.tell(), oid) # We must not depend on this below or we'll fail in optimized mode.
assert read(8) == oid, (ofs, self.f.tell(), oid)
if end_tid < before_tid:
result = self.load(oid, before_tid) if end_tid < before_tid:
if result: result = self.load(oid, before_tid)
return result[0], result[1], None if result:
else: return result[0], result[1], None
self._trace(0x24, oid, "", before_tid) else:
return result self._trace(0x24, oid, "", before_tid)
return result
self._n_accesses += 1 self._n_accesses += 1
self._trace(0x26, oid, "", saved_tid) self._trace(0x26, oid, "", saved_tid)
return data, saved_tid, end_tid return data, saved_tid, end_tid
## ##
# Store a new data record in the cache. # Store a new data record in the cache.
...@@ -591,45 +600,48 @@ class ClientCache(object): ...@@ -591,45 +600,48 @@ class ClientCache(object):
# current. # current.
# @param data the actual data # @param data the actual data
def store(self, oid, start_tid, end_tid, data): def store(self, oid, start_tid, end_tid, data):
seek = self.f.seek with self._lock:
if end_tid is None: seek = self.f.seek
ofs = self.current.get(oid) if end_tid is None:
if ofs: ofs = self.current.get(oid)
seek(ofs) if ofs:
read = self.f.read seek(ofs)
status = read(1) read = self.f.read
assert status == b'a', (ofs, self.f.tell(), oid) status = read(1)
size, saved_oid, saved_tid, end_tid = unpack( assert status == b'a', (ofs, self.f.tell(), oid)
">I8s8s8s", read(28)) size, saved_oid, saved_tid, end_tid = unpack(
assert saved_oid == oid, (ofs, self.f.tell(), oid, saved_oid) ">I8s8s8s", read(28))
assert end_tid == z64, (ofs, self.f.tell(), oid) assert saved_oid == oid, (
if saved_tid == start_tid: ofs, self.f.tell(), oid, saved_oid)
assert end_tid == z64, (ofs, self.f.tell(), oid)
if saved_tid == start_tid:
return
raise ValueError("already have current data for oid")
else:
noncurrent_for_oid = self.noncurrent.get(u64(oid))
if noncurrent_for_oid and (
u64(start_tid) in noncurrent_for_oid):
return return
raise ValueError("already have current data for oid")
else:
noncurrent_for_oid = self.noncurrent.get(u64(oid))
if noncurrent_for_oid and (u64(start_tid) in noncurrent_for_oid):
return
size = allocated_record_overhead + len(data) size = allocated_record_overhead + len(data)
# A number of cache simulation experiments all concluded that the # A number of cache simulation experiments all concluded that the
# 2nd-level ZEO cache got a much higher hit rate if "very large" # 2nd-level ZEO cache got a much higher hit rate if "very large"
# objects simply weren't cached. For now, we ignore the request # objects simply weren't cached. For now, we ignore the request
# only if the entire cache file is too small to hold the object. # only if the entire cache file is too small to hold the object.
if size >= min(max_block_size, self.maxsize - ZEC_HEADER_SIZE): if size >= min(max_block_size, self.maxsize - ZEC_HEADER_SIZE):
return return
self._n_adds += 1 self._n_adds += 1
self._n_added_bytes += size self._n_added_bytes += size
self._len += 1 self._len += 1
self._store(oid, start_tid, end_tid, data, size) self._store(oid, start_tid, end_tid, data, size)
if end_tid: if end_tid:
self._trace(0x54, oid, start_tid, end_tid, dlen=len(data)) self._trace(0x54, oid, start_tid, end_tid, dlen=len(data))
else: else:
self._trace(0x52, oid, start_tid, dlen=len(data)) self._trace(0x52, oid, start_tid, dlen=len(data))
def _store(self, oid, start_tid, end_tid, data, size): def _store(self, oid, start_tid, end_tid, data, size):
# Low-level store used by store and load # Low-level store used by store and load
...@@ -696,35 +708,37 @@ class ClientCache(object): ...@@ -696,35 +708,37 @@ class ClientCache(object):
# - tid the id of the transaction that wrote a new revision of oid, # - tid the id of the transaction that wrote a new revision of oid,
# or None to forget all cached info about oid. # or None to forget all cached info about oid.
def invalidate(self, oid, tid): def invalidate(self, oid, tid):
ofs = self.current.get(oid) with self._lock:
if ofs is None: ofs = self.current.get(oid)
# 0x10 == invalidate (miss) if ofs is None:
self._trace(0x10, oid, tid) # 0x10 == invalidate (miss)
return self._trace(0x10, oid, tid)
return
self.f.seek(ofs)
read = self.f.read
status = read(1)
assert status == b'a', (ofs, self.f.tell(), oid)
size, saved_oid, saved_tid, end_tid = unpack(">I8s8s8s", read(28))
assert saved_oid == oid, (ofs, self.f.tell(), oid, saved_oid)
assert end_tid == z64, (ofs, self.f.tell(), oid)
del self.current[oid]
if tid is None:
self.f.seek(ofs) self.f.seek(ofs)
self.f.write(b'f'+pack(">I", size)) read = self.f.read
# 0x1E = invalidate (hit, discarding current or non-current) status = read(1)
self._trace(0x1E, oid, tid) assert status == b'a', (ofs, self.f.tell(), oid)
self._len -= 1 size, saved_oid, saved_tid, end_tid = unpack(">I8s8s8s", read(28))
else: assert saved_oid == oid, (ofs, self.f.tell(), oid, saved_oid)
if tid == saved_tid: assert end_tid == z64, (ofs, self.f.tell(), oid)
logger.warning("Ignoring invalidation with same tid as current") del self.current[oid]
return if tid is None:
self.f.seek(ofs+21) self.f.seek(ofs)
self.f.write(tid) self.f.write(b'f'+pack(">I", size))
self._set_noncurrent(oid, saved_tid, ofs) # 0x1E = invalidate (hit, discarding current or non-current)
# 0x1C = invalidate (hit, saving non-current) self._trace(0x1E, oid, tid)
self._trace(0x1C, oid, tid) self._len -= 1
else:
if tid == saved_tid:
logger.warning(
"Ignoring invalidation with same tid as current")
return
self.f.seek(ofs+21)
self.f.write(tid)
self._set_noncurrent(oid, saved_tid, ofs)
# 0x1C = invalidate (hit, saving non-current)
self._trace(0x1C, oid, tid)
## ##
# Generates (oid, serial) oairs for all objects in the # Generates (oid, serial) oairs for all objects in the
......
...@@ -24,8 +24,7 @@ class StaleCache(object): ...@@ -24,8 +24,7 @@ class StaleCache(object):
class IClientCache(zope.interface.Interface): class IClientCache(zope.interface.Interface):
"""Client cache interface. """Client cache interface.
Note that caches need not be thread safe, fpr the most part, Note that caches need to be thread safe.
except for getLastTid, which may be called from multiple threads.
""" """
def close(): def close():
......
...@@ -98,6 +98,9 @@ class ZEOOptionsMixin: ...@@ -98,6 +98,9 @@ class ZEOOptionsMixin:
self.add("address", "zeo.address.address", self.add("address", "zeo.address.address",
required="no server address specified; use -a or -C") required="no server address specified; use -a or -C")
self.add("read_only", "zeo.read_only", default=0) self.add("read_only", "zeo.read_only", default=0)
self.add("client_conflict_resolution",
"zeo.client_conflict_resolution",
default=0)
self.add("invalidation_queue_size", "zeo.invalidation_queue_size", self.add("invalidation_queue_size", "zeo.invalidation_queue_size",
default=100) default=100)
self.add("invalidation_age", "zeo.invalidation_age") self.add("invalidation_age", "zeo.invalidation_age")
...@@ -339,6 +342,7 @@ def create_server(storages, options): ...@@ -339,6 +342,7 @@ def create_server(storages, options):
options.address, options.address,
storages, storages,
read_only = options.read_only, read_only = options.read_only,
client_conflict_resolution=options.client_conflict_resolution,
invalidation_queue_size = options.invalidation_queue_size, invalidation_queue_size = options.invalidation_queue_size,
invalidation_age = options.invalidation_age, invalidation_age = options.invalidation_age,
transaction_timeout = options.transaction_timeout, transaction_timeout = options.transaction_timeout,
......
...@@ -107,6 +107,14 @@ ...@@ -107,6 +107,14 @@
<metadefault>$INSTANCE/var/ZEO.pid (or $clienthome/ZEO.pid)</metadefault> <metadefault>$INSTANCE/var/ZEO.pid (or $clienthome/ZEO.pid)</metadefault>
</key> </key>
<key name="client-conflict-resolution" datatype="boolean"
required="no" default="false">
<description>
Flag indicating whether the server should return conflict
errors to the client, for resolution there.
</description>
</key>
</sectiontype> </sectiontype>
</component> </component>
...@@ -30,6 +30,8 @@ class DummyDB: ...@@ -30,6 +30,8 @@ class DummyDB:
def invalidate(self, *args, **kwargs): def invalidate(self, *args, **kwargs):
pass pass
transform_record_data = untransform_record_data = lambda self, data: data
class WorkerThread(TestThread): class WorkerThread(TestThread):
# run the entire test in a thread so that the blocking call for # run the entire test in a thread so that the blocking call for
......
...@@ -59,6 +59,9 @@ class DummyDB: ...@@ -59,6 +59,9 @@ class DummyDB:
def invalidateCache(self): def invalidateCache(self):
pass pass
transform_record_data = untransform_record_data = lambda self, data: data
class CommonSetupTearDown(StorageTestBase): class CommonSetupTearDown(StorageTestBase):
"""Common boilerplate""" """Common boilerplate"""
...@@ -1018,90 +1021,6 @@ class TimeoutTests(CommonSetupTearDown): ...@@ -1018,90 +1021,6 @@ class TimeoutTests(CommonSetupTearDown):
# or the server. # or the server.
self.assertRaises(KeyError, storage.load, oid, '') self.assertRaises(KeyError, storage.load, oid, '')
def checkTimeoutProvokingConflicts(self):
self._storage = storage = self.openClientStorage()
# Assert that the zeo cache is empty.
self.assert_(not list(storage._cache.contents()))
# Create the object
oid = storage.new_oid()
obj = MinPO(7)
# We need to successfully commit an object now so we have something to
# conflict about.
t = Transaction()
storage.tpc_begin(t)
revid1a = storage.store(oid, ZERO, zodb_pickle(obj), '', t)
revid1b = storage.tpc_vote(t)
revid1 = handle_serials(oid, revid1a, revid1b)
storage.tpc_finish(t)
# Now do a store, sleeping before the finish so as to cause a timeout.
obj.value = 8
t = Transaction()
old_connection_count = storage.connection_count_for_tests
storage.tpc_begin(t)
revid2a = storage.store(oid, revid1, zodb_pickle(obj), '', t)
revid2b = storage.tpc_vote(t)
revid2 = handle_serials(oid, revid2a, revid2b)
# Now sleep long enough for the storage to time out.
# This used to sleep for 3 seconds, and sometimes (but very rarely)
# failed then. Now we try for a minute. It typically succeeds
# on the second time thru the loop, and, since self.timeout is 1,
# it's typically faster now (2/1.8 ~= 1.11 seconds sleeping instead
# of 3).
deadline = time.time() + 60 # wait up to a minute
while time.time() < deadline:
if (storage.is_connected() and
(storage.connection_count_for_tests == old_connection_count)
):
time.sleep(self.timeout / 1.8)
else:
break
self.assert_(
(not storage.is_connected())
or
(storage.connection_count_for_tests > old_connection_count)
)
storage._wait()
self.assert_(storage.is_connected())
# We expect finish to fail.
self.assertRaises(ClientDisconnected, storage.tpc_finish, t)
storage.tpc_abort(t)
# Now we think we've committed the second transaction, but we really
# haven't. A third one should produce a POSKeyError on the server,
# which manifests as a ConflictError on the client.
obj.value = 9
t = Transaction()
storage.tpc_begin(t)
storage.store(oid, revid2, zodb_pickle(obj), '', t)
self.assertRaises(ConflictError, storage.tpc_vote, t)
# Even aborting won't help.
storage.tpc_abort(t)
self.assertRaises(ZODB.POSException.StorageTransactionError,
storage.tpc_finish, t)
# Try again.
obj.value = 10
t = Transaction()
storage.tpc_begin(t)
storage.store(oid, revid2, zodb_pickle(obj), '', t)
# Even aborting won't help.
self.assertRaises(ConflictError, storage.tpc_vote, t)
# Abort this one and try a transaction that should succeed.
storage.tpc_abort(t)
# Now do a store.
obj.value = 11
t = Transaction()
storage.tpc_begin(t)
revid2a = storage.store(oid, revid1, zodb_pickle(obj), '', t)
revid2b = storage.tpc_vote(t)
revid2 = handle_serials(oid, revid2a, revid2b)
storage.tpc_finish(t)
# Now load the object and verify that it has a value of 11.
data, revid = storage.load(oid, '')
self.assertEqual(zodb_unpickle(data), MinPO(11))
self.assertEqual(revid, revid2)
class MSTThread(threading.Thread): class MSTThread(threading.Thread):
__super_init = threading.Thread.__init__ __super_init = threading.Thread.__init__
......
...@@ -324,8 +324,8 @@ class InvalidationTests: ...@@ -324,8 +324,8 @@ class InvalidationTests:
def checkConcurrentUpdates2Storages_emulated(self): def checkConcurrentUpdates2Storages_emulated(self):
self._storage = storage1 = self.openClientStorage() self._storage = storage1 = self.openClientStorage()
storage2 = self.openClientStorage()
db1 = DB(storage1) db1 = DB(storage1)
storage2 = self.openClientStorage()
db2 = DB(storage2) db2 = DB(storage2)
cn = db1.open() cn = db1.open()
...@@ -349,8 +349,8 @@ class InvalidationTests: ...@@ -349,8 +349,8 @@ class InvalidationTests:
def checkConcurrentUpdates2Storages(self): def checkConcurrentUpdates2Storages(self):
self._storage = storage1 = self.openClientStorage() self._storage = storage1 = self.openClientStorage()
storage2 = self.openClientStorage()
db1 = DB(storage1) db1 = DB(storage1)
storage2 = self.openClientStorage()
db2 = DB(storage2) db2 = DB(storage2)
stop = threading.Event() stop = threading.Event()
......
...@@ -33,7 +33,7 @@ logger = logging.getLogger('ZEO.tests.forker') ...@@ -33,7 +33,7 @@ logger = logging.getLogger('ZEO.tests.forker')
class ZEOConfig: class ZEOConfig:
"""Class to generate ZEO configuration file. """ """Class to generate ZEO configuration file. """
def __init__(self, addr): def __init__(self, addr, **options):
if isinstance(addr, str): if isinstance(addr, str):
self.logpath = addr+'.log' self.logpath = addr+'.log'
else: else:
...@@ -42,6 +42,7 @@ class ZEOConfig: ...@@ -42,6 +42,7 @@ class ZEOConfig:
self.address = addr self.address = addr
self.read_only = None self.read_only = None
self.loglevel = 'INFO' self.loglevel = 'INFO'
self.__dict__.update(options)
def dump(self, f): def dump(self, f):
print("<zeo>", file=f) print("<zeo>", file=f)
...@@ -52,7 +53,7 @@ class ZEOConfig: ...@@ -52,7 +53,7 @@ class ZEOConfig:
for name in ( for name in (
'invalidation_queue_size', 'invalidation_age', 'invalidation_queue_size', 'invalidation_age',
'transaction_timeout', 'pid_filename', 'transaction_timeout', 'pid_filename',
'ssl_certificate', 'ssl_key', 'ssl_certificate', 'ssl_key', 'client_conflict_resolution',
): ):
v = getattr(self, name, None) v = getattr(self, name, None)
if v: if v:
...@@ -95,6 +96,10 @@ def runner(config, qin, qout, timeout=None, ...@@ -95,6 +96,10 @@ def runner(config, qin, qout, timeout=None,
import ZEO.asyncio.server import ZEO.asyncio.server
old_protocol = ZEO.asyncio.server.best_protocol_version old_protocol = ZEO.asyncio.server.best_protocol_version
ZEO.asyncio.server.best_protocol_version = protocol ZEO.asyncio.server.best_protocol_version = protocol
old_protocols = ZEO.asyncio.server.ServerProtocol.protocols
ZEO.asyncio.server.ServerProtocol.protocols = tuple(sorted(
set(old_protocols) | set([protocol])
))
try: try:
import ZEO.runzeo, threading import ZEO.runzeo, threading
...@@ -142,8 +147,8 @@ def runner(config, qin, qout, timeout=None, ...@@ -142,8 +147,8 @@ def runner(config, qin, qout, timeout=None,
finally: finally:
if old_protocol: if old_protocol:
ZEO.asyncio.server.best_protocol_version = protocol ZEO.asyncio.server.best_protocol_version = old_protocol
ZEO.asyncio.server.ServerProtocol.protocols = old_protocols
def stop_runner(thread, config, qin, qout, stop_timeout=9, pid=None): def stop_runner(thread, config, qin, qout, stop_timeout=9, pid=None):
qin.put('stop') qin.put('stop')
...@@ -155,7 +160,7 @@ def stop_runner(thread, config, qin, qout, stop_timeout=9, pid=None): ...@@ -155,7 +160,7 @@ def stop_runner(thread, config, qin, qout, stop_timeout=9, pid=None):
# The runner thread didn't stop. If it was a process, # The runner thread didn't stop. If it was a process,
# give it some time to exit # give it some time to exit
if hasattr(thread, 'pid') and thread.pid: if hasattr(thread, 'pid') and thread.pid:
os.waitpid(thread.pid) os.waitpid(thread.pid, 0)
else: else:
# Gaaaa, force gc in hopes of maybe getting the unclosed # Gaaaa, force gc in hopes of maybe getting the unclosed
# sockets to get GCed # sockets to get GCed
......
...@@ -5,7 +5,7 @@ A full test of all protocols isn't practical. But we'll do a limited ...@@ -5,7 +5,7 @@ A full test of all protocols isn't practical. But we'll do a limited
test that at least the current and previous protocols are supported in test that at least the current and previous protocols are supported in
both directions. both directions.
Let's start a Z309 server Let's start a Z4 server
>>> storage_conf = ''' >>> storage_conf = '''
... <blobstorage> ... <blobstorage>
...@@ -94,82 +94,85 @@ A current client should be able to connect to a old server: ...@@ -94,82 +94,85 @@ A current client should be able to connect to a old server:
>>> zope.testing.setupstack.rmtree('blobs') >>> zope.testing.setupstack.rmtree('blobs')
>>> zope.testing.setupstack.rmtree('server-blobs') >>> zope.testing.setupstack.rmtree('server-blobs')
And the other way around: #############################################################################
# Note that the ZEO 5.0 server only supports clients that use the Z5 protocol
>>> addr, _ = start_server(storage_conf, dict(invalidation_queue_size=5)) # And the other way around:
Note that we'll have to pull some hijinks: # >>> addr, _ = start_server(storage_conf, dict(invalidation_queue_size=5))
>>> import ZEO.asyncio.client # Note that we'll have to pull some hijinks:
>>> old_protocols = ZEO.asyncio.client.Protocol.protocols
>>> ZEO.asyncio.client.Protocol.protocols = [b'Z4']
>>> db = ZEO.DB(addr, client='client', blob_dir='blobs') # >>> import ZEO.asyncio.client
>>> db.storage.protocol_version # >>> old_protocols = ZEO.asyncio.client.Protocol.protocols
b'Z4' # >>> ZEO.asyncio.client.Protocol.protocols = [b'Z4']
>>> wait_connected(db.storage)
>>> conn = db.open()
>>> conn.root().x = 0
>>> transaction.commit()
>>> len(db.history(conn.root()._p_oid, 99))
2
>>> conn.root()['blob1'] = ZODB.blob.Blob() # >>> db = ZEO.DB(addr, client='client', blob_dir='blobs')
>>> with conn.root()['blob1'].open('w') as f: # >>> db.storage.protocol_version
... r = f.write(b'blob data 1') # b'Z4'
>>> transaction.commit() # >>> wait_connected(db.storage)
# >>> conn = db.open()
# >>> conn.root().x = 0
# >>> transaction.commit()
# >>> len(db.history(conn.root()._p_oid, 99))
# 2
>>> db2 = ZEO.DB(addr, blob_dir='server-blobs', shared_blob_dir=True) # >>> conn.root()['blob1'] = ZODB.blob.Blob()
>>> wait_connected(db2.storage) # >>> with conn.root()['blob1'].open('w') as f:
>>> conn2 = db2.open() # ... r = f.write(b'blob data 1')
>>> for i in range(5): # >>> transaction.commit()
... conn2.root().x += 1
... transaction.commit()
>>> conn2.root()['blob2'] = ZODB.blob.Blob()
>>> with conn2.root()['blob2'].open('w') as f:
... r = f.write(b'blob data 2')
>>> transaction.commit()
# >>> db2 = ZEO.DB(addr, blob_dir='server-blobs', shared_blob_dir=True)
# >>> wait_connected(db2.storage)
# >>> conn2 = db2.open()
# >>> for i in range(5):
# ... conn2.root().x += 1
# ... transaction.commit()
# >>> conn2.root()['blob2'] = ZODB.blob.Blob()
# >>> with conn2.root()['blob2'].open('w') as f:
# ... r = f.write(b'blob data 2')
# >>> transaction.commit()
>>> @wait_until()
... def x_to_be_5():
... conn.sync()
... return conn.root().x == 5
>>> db.close() # >>> @wait_until()
# ... def x_to_be_5():
# ... conn.sync()
# ... return conn.root().x == 5
>>> for i in range(2): # >>> db.close()
... conn2.root().x += 1
... transaction.commit()
>>> db = ZEO.DB(addr, client='client', blob_dir='blobs') # >>> for i in range(2):
>>> wait_connected(db.storage) # ... conn2.root().x += 1
>>> conn = db.open() # ... transaction.commit()
>>> conn.root().x
7
>>> db.close() # >>> db = ZEO.DB(addr, client='client', blob_dir='blobs')
# >>> wait_connected(db.storage)
# >>> conn = db.open()
# >>> conn.root().x
# 7
>>> for i in range(10): # >>> db.close()
... conn2.root().x += 1
... transaction.commit()
>>> db = ZEO.DB(addr, client='client', blob_dir='blobs') # >>> for i in range(10):
>>> wait_connected(db.storage) # ... conn2.root().x += 1
>>> conn = db.open() # ... transaction.commit()
>>> conn.root().x
17
>>> with conn.root()['blob1'].open() as f: # >>> db = ZEO.DB(addr, client='client', blob_dir='blobs')
... f.read() # >>> wait_connected(db.storage)
b'blob data 1' # >>> conn = db.open()
>>> with conn.root()['blob2'].open() as f: # >>> conn.root().x
... f.read() # 17
b'blob data 2'
>>> db2.close() # >>> with conn.root()['blob1'].open() as f:
>>> db.close() # ... f.read()
# b'blob data 1'
# >>> with conn.root()['blob2'].open() as f:
# ... f.read()
# b'blob data 2'
# >>> db2.close()
# >>> db.close()
Undo the hijinks: # Undo the hijinks:
>>> ZEO.asyncio.client.Protocol.protocols = old_protocols # >>> ZEO.asyncio.client.Protocol.protocols = old_protocols
...@@ -52,6 +52,8 @@ class FakeServer: ...@@ -52,6 +52,8 @@ class FakeServer:
def register_connection(*args): def register_connection(*args):
return None, None return None, None
client_conflict_resolution = False
class FakeConnection: class FakeConnection:
protocol_version = b'Z4' protocol_version = b'Z4'
addr = 'test' addr = 'test'
......
...@@ -143,23 +143,9 @@ class MiscZEOTests: ...@@ -143,23 +143,9 @@ class MiscZEOTests:
self.assertNotEquals(ZODB.utils.z64, storage3.lastTransaction()) self.assertNotEquals(ZODB.utils.z64, storage3.lastTransaction())
storage3.close() storage3.close()
class GenericTests( class GenericTestBase(
# Base class for all ZODB tests # Base class for all ZODB tests
StorageTestBase.StorageTestBase, StorageTestBase.StorageTestBase):
# ZODB test mixin classes (in the same order as imported)
BasicStorage.BasicStorage,
PackableStorage.PackableStorage,
Synchronization.SynchronizedStorage,
MTStorage.MTStorage,
ReadOnlyStorage.ReadOnlyStorage,
# ZEO test mixin classes (in the same order as imported)
CommitLockTests.CommitLockVoteTests,
ThreadTests.ThreadTests,
# Locally defined (see above)
MiscZEOTests,
):
"""Combine tests from various origins in one class."""
shared_blob_dir = False shared_blob_dir = False
blob_cache_dir = None blob_cache_dir = None
...@@ -200,14 +186,23 @@ class GenericTests( ...@@ -200,14 +186,23 @@ class GenericTests(
stop() stop()
StorageTestBase.StorageTestBase.tearDown(self) StorageTestBase.StorageTestBase.tearDown(self)
def runTest(self): class GenericTests(
try: GenericTestBase,
super(GenericTests, self).runTest()
except: # ZODB test mixin classes (in the same order as imported)
self._failed = True BasicStorage.BasicStorage,
raise PackableStorage.PackableStorage,
else: Synchronization.SynchronizedStorage,
self._failed = False MTStorage.MTStorage,
ReadOnlyStorage.ReadOnlyStorage,
# ZEO test mixin classes (in the same order as imported)
CommitLockTests.CommitLockVoteTests,
ThreadTests.ThreadTests,
# Locally defined (see above)
MiscZEOTests,
):
"""Combine tests from various origins in one class.
"""
def open(self, read_only=0): def open(self, read_only=0):
# Needed to support ReadOnlyStorage tests. Ought to be a # Needed to support ReadOnlyStorage tests. Ought to be a
...@@ -394,7 +389,16 @@ class FileStorageClientHexTests(FileStorageHexTests): ...@@ -394,7 +389,16 @@ class FileStorageClientHexTests(FileStorageHexTests):
def _wrap_client(self, client): def _wrap_client(self, client):
return ZODB.tests.hexstorage.HexStorage(client) return ZODB.tests.hexstorage.HexStorage(client)
class ClientConflictResolutionTests(
GenericTestBase,
ConflictResolution.ConflictResolvingStorage,
):
def getConfig(self):
return '<mappingstorage>\n</mappingstorage>\n'
def getZEOConfig(self):
return forker.ZEOConfig(('', 0), client_conflict_resolution=True)
class MappingStorageTests(GenericTests): class MappingStorageTests(GenericTests):
"""ZEO backed by a Mapping storage.""" """ZEO backed by a Mapping storage."""
...@@ -492,6 +496,8 @@ class ZRPCConnectionTests(ZEO.tests.ConnectionTests.CommonSetupTearDown): ...@@ -492,6 +496,8 @@ class ZRPCConnectionTests(ZEO.tests.ConnectionTests.CommonSetupTearDown):
self._invalidatedCache += 1 self._invalidatedCache += 1
def invalidate(*a, **k): def invalidate(*a, **k):
pass pass
transform_record_data = untransform_record_data = \
lambda self, data: data
db = DummyDB() db = DummyDB()
storage.registerDB(db) storage.registerDB(db)
...@@ -753,24 +759,23 @@ class StorageServerWrapper: ...@@ -753,24 +759,23 @@ class StorageServerWrapper:
self.server.tpc_begin(id(transaction), '', '', {}, None, ' ') self.server.tpc_begin(id(transaction), '', '', {}, None, ' ')
def tpc_vote(self, transaction): def tpc_vote(self, transaction):
vote_result = self.server.vote(id(transaction)) result = self.server.vote(id(transaction))
assert vote_result is None assert result == self.server.connection.serials[:]
result = self.server.connection.serials[:]
del self.server.connection.serials[:] del self.server.connection.serials[:]
return result return result
def store(self, oid, serial, data, version_ignored, transaction): def store(self, oid, serial, data, version_ignored, transaction):
self.server.storea(oid, serial, data, id(transaction)) self.server.storea(oid, serial, data, id(transaction))
def send_reply(self, *args): # Masquerade as conn def send_reply(self, _, result): # Masquerade as conn
pass self._result = result
def tpc_abort(self, transaction): def tpc_abort(self, transaction):
self.server.tpc_abort(id(transaction)) self.server.tpc_abort(id(transaction))
def tpc_finish(self, transaction, func = lambda: None): def tpc_finish(self, transaction, func = lambda: None):
self.server.tpc_finish(id(transaction)).set_sender(0, self) self.server.tpc_finish(id(transaction)).set_sender(0, self)
return self._result
def multiple_storages_invalidation_queue_is_not_insane(): def multiple_storages_invalidation_queue_is_not_insane():
""" """
...@@ -937,7 +942,7 @@ def tpc_finish_error(): ...@@ -937,7 +942,7 @@ def tpc_finish_error():
buffer, sadly, using implementation details: buffer, sadly, using implementation details:
>>> tbuf = t.data(client) >>> tbuf = t.data(client)
>>> tbuf.serials = None >>> tbuf.client_resolved = None
tpc_finish will fail: tpc_finish will fail:
...@@ -1596,6 +1601,7 @@ def test_suite(): ...@@ -1596,6 +1601,7 @@ def test_suite():
"ClientDisconnected"), "ClientDisconnected"),
)), )),
)) ))
zeo.addTest(unittest.makeSuite(ClientConflictResolutionTests, 'check'))
zeo.layer = ZODB.tests.util.MininalTestLayer('testZeo-misc') zeo.layer = ZODB.tests.util.MininalTestLayer('testZeo-misc')
suite.addTest(zeo) suite.addTest(zeo)
......
...@@ -78,6 +78,8 @@ will conflict. It will be blocked at the vote call. ...@@ -78,6 +78,8 @@ will conflict. It will be blocked at the vote call.
>>> class Sender: >>> class Sender:
... def send_reply(self, id, reply): ... def send_reply(self, id, reply):
... print('reply', id, reply) ... print('reply', id, reply)
... def send_error(self, id, err):
... print('error', id, err)
>>> delay.set_sender(1, Sender()) >>> delay.set_sender(1, Sender())
>>> logger = logging.getLogger('ZEO') >>> logger = logging.getLogger('ZEO')
...@@ -87,13 +89,20 @@ will conflict. It will be blocked at the vote call. ...@@ -87,13 +89,20 @@ will conflict. It will be blocked at the vote call.
Now, when we abort the transaction for the first client. The second Now, when we abort the transaction for the first client. The second
client will be restarted. It will get a conflict error, that is client will be restarted. It will get a conflict error, that is
handled correctly: raised to the client:
>>> zs1.tpc_abort('0') # doctest: +ELLIPSIS >>> zs1.tpc_abort('0') # doctest: +ELLIPSIS
reply 1 None Error raised in delayed method
Traceback (most recent call last):
...
ZODB.POSException.ConflictError: ...
error 1 database conflict error ...
The transaction is aborted by the server:
>>> fs.tpc_transaction() is not None >>> fs.tpc_transaction() is None
True True
>>> zs2.connected >>> zs2.connected
True True
...@@ -116,7 +125,7 @@ And an initial client. ...@@ -116,7 +125,7 @@ And an initial client.
>>> zs1 = ZEO.tests.servertesting.client(server, 1) >>> zs1 = ZEO.tests.servertesting.client(server, 1)
>>> zs1.tpc_begin('0', '', '', {}) >>> zs1.tpc_begin('0', '', '', {})
>>> zs1.storea(ZODB.utils.p64(99), ZODB.utils.z64, 'x', '0') >>> zs1.storea(ZODB.utils.p64(99), ZODB.utils.z64, b'x', '0')
Intentionally break zs1: Intentionally break zs1:
...@@ -135,7 +144,7 @@ We can start another client and get the storage lock. ...@@ -135,7 +144,7 @@ We can start another client and get the storage lock.
>>> zs1 = ZEO.tests.servertesting.client(server, 1) >>> zs1 = ZEO.tests.servertesting.client(server, 1)
>>> zs1.tpc_begin('1', '', '', {}) >>> zs1.tpc_begin('1', '', '', {})
>>> zs1.storea(ZODB.utils.p64(99), ZODB.utils.z64, 'x', '1') >>> zs1.storea(ZODB.utils.p64(99), ZODB.utils.z64, b'x', '1')
>>> _ = zs1.vote('1') # doctest: +ELLIPSIS >>> _ = zs1.vote('1') # doctest: +ELLIPSIS
>>> zs1.tpc_finish('1').set_sender(0, zs1.connection) >>> zs1.tpc_finish('1').set_sender(0, zs1.connection)
...@@ -220,7 +229,7 @@ We start a transaction and vote, this leads to getting the lock. ...@@ -220,7 +229,7 @@ We start a transaction and vote, this leads to getting the lock.
ZEO.asyncio.server INFO ZEO.asyncio.server INFO
received handshake b'Z5' received handshake b'Z5'
>>> tid1 = start_trans(zs1) >>> tid1 = start_trans(zs1)
>>> zs1.vote(tid1) # doctest: +ELLIPSIS >>> resolved1 = zs1.vote(tid1) # doctest: +ELLIPSIS
ZEO.StorageServer DEBUG ZEO.StorageServer DEBUG
(test-addr-1) ('1') lock: transactions waiting: 0 (test-addr-1) ('1') lock: transactions waiting: 0
ZEO.StorageServer BLATHER ZEO.StorageServer BLATHER
...@@ -477,7 +486,7 @@ ZEOStorage as closed and see if trying to get a lock cleans it up: ...@@ -477,7 +486,7 @@ ZEOStorage as closed and see if trying to get a lock cleans it up:
ZEO.asyncio.server INFO ZEO.asyncio.server INFO
received handshake b'Z5' received handshake b'Z5'
>>> tid1 = start_trans(zs1) >>> tid1 = start_trans(zs1)
>>> zs1.vote(tid1) # doctest: +ELLIPSIS >>> resolved1 = zs1.vote(tid1) # doctest: +ELLIPSIS
ZEO.StorageServer DEBUG ZEO.StorageServer DEBUG
(test-addr-1) ('1') lock: transactions waiting: 0 (test-addr-1) ('1') lock: transactions waiting: 0
ZEO.StorageServer BLATHER ZEO.StorageServer BLATHER
...@@ -493,7 +502,7 @@ ZEOStorage as closed and see if trying to get a lock cleans it up: ...@@ -493,7 +502,7 @@ ZEOStorage as closed and see if trying to get a lock cleans it up:
ZEO.asyncio.server INFO ZEO.asyncio.server INFO
received handshake b'Z5' received handshake b'Z5'
>>> tid2 = start_trans(zs2) >>> tid2 = start_trans(zs2)
>>> zs2.vote(tid2) # doctest: +ELLIPSIS >>> resolved2 = zs2.vote(tid2) # doctest: +ELLIPSIS
ZEO.StorageServer DEBUG ZEO.StorageServer DEBUG
(test-addr-2) ('1') lock: transactions waiting: 0 (test-addr-2) ('1') lock: transactions waiting: 0
ZEO.StorageServer BLATHER ZEO.StorageServer BLATHER
......
import unittest
import zope.testing.setupstack
from BTrees.Length import Length
from ZODB import serialize
from ZODB.DemoStorage import DemoStorage
from ZODB.utils import p64, z64, maxtid
from ZODB.broken import find_global
import ZEO
from .utils import StorageServer
class Var(object):
def __eq__(self, other):
self.value = other
return True
class ClientSideConflictResolutionTests(zope.testing.setupstack.TestCase):
def test_server_side(self):
# First, verify default conflict resolution.
server = StorageServer(self, DemoStorage())
zs = server.zs
reader = serialize.ObjectReader(
factory=lambda conn, *args: find_global(*args))
writer = serialize.ObjectWriter()
ob = Length(0)
ob._p_oid = z64
# 2 non-conflicting transactions:
zs.tpc_begin(1, '', '', {})
zs.storea(ob._p_oid, z64, writer.serialize(ob), 1)
self.assertEqual(zs.vote(1), [])
tid1 = server.unpack_result(zs.tpc_finish(1))
server.assert_calls(self, ('info', {'length': 1, 'size': Var()}))
ob.change(1)
zs.tpc_begin(2, '', '', {})
zs.storea(ob._p_oid, tid1, writer.serialize(ob), 2)
self.assertEqual(zs.vote(2), [])
tid2 = server.unpack_result(zs.tpc_finish(2))
server.assert_calls(self, ('info', {'size': Var(), 'length': 1}))
# Now, a cnflicting one:
zs.tpc_begin(3, '', '', {})
zs.storea(ob._p_oid, tid1, writer.serialize(ob), 3)
# Vote returns the object id, indicating that a conflict was resolved.
self.assertEqual(zs.vote(3), [ob._p_oid])
tid3 = server.unpack_result(zs.tpc_finish(3))
p, serial, next_serial = zs.loadBefore(ob._p_oid, maxtid)
self.assertEqual((serial, next_serial), (tid3, None))
self.assertEqual(reader.getClassName(p), 'BTrees.Length.Length')
self.assertEqual(reader.getState(p), 2)
# Now, we'll create a server that expects the client to
# resolve conflicts:
server = StorageServer(
self, DemoStorage(), client_conflict_resolution=True)
zs = server.zs
# 2 non-conflicting transactions:
zs.tpc_begin(1, '', '', {})
zs.storea(ob._p_oid, z64, writer.serialize(ob), 1)
self.assertEqual(zs.vote(1), [])
tid1 = server.unpack_result(zs.tpc_finish(1))
server.assert_calls(self, ('info', {'size': Var(), 'length': 1}))
ob.change(1)
zs.tpc_begin(2, '', '', {})
zs.storea(ob._p_oid, tid1, writer.serialize(ob), 2)
self.assertEqual(zs.vote(2), [])
tid2 = server.unpack_result(zs.tpc_finish(2))
server.assert_calls(self, ('info', {'length': 1, 'size': Var()}))
# Now, a conflicting one:
zs.tpc_begin(3, '', '', {})
zs.storea(ob._p_oid, tid1, writer.serialize(ob), 3)
# Vote returns an object, indicating that a conflict was not resolved.
self.assertEqual(
zs.vote(3),
[dict(oid=ob._p_oid,
serials=(tid2, tid1),
data=writer.serialize(ob),
)],
)
# Now, it's up to the client to resolve the conflict. It can
# do this by making another store call. In this call, we use
# tid2 as the starting tid:
ob.change(1)
zs.storea(ob._p_oid, tid2, writer.serialize(ob), 3)
self.assertEqual(zs.vote(3), [])
tid3 = server.unpack_result(zs.tpc_finish(3))
server.assert_calls(self, ('info', {'size': Var(), 'length': 1}))
p, serial, next_serial = zs.loadBefore(ob._p_oid, maxtid)
self.assertEqual((serial, next_serial), (tid3, None))
self.assertEqual(reader.getClassName(p), 'BTrees.Length.Length')
self.assertEqual(reader.getState(p), 3)
def test_client_side(self):
# First, traditional:
addr, stop = ZEO.server('data.fs')
db = ZEO.DB(addr)
with db.transaction() as conn:
conn.root.l = Length(0)
conn2 = db.open()
conn2.root.l.change(1)
with db.transaction() as conn:
conn.root.l.change(1)
conn2.transaction_manager.commit()
self.assertEqual(conn2.root.l.value, 2)
db.close(); stop()
# Now, do conflict resolution on the client.
addr2, stop = ZEO.server(
storage_conf='<mappingstorage>\n</mappingstorage>\n',
zeo_conf=dict(client_conflict_resolution=True),
)
db = ZEO.DB(addr2)
with db.transaction() as conn:
conn.root.l = Length(0)
conn2 = db.open()
conn2.root.l.change(1)
with db.transaction() as conn:
conn.root.l.change(1)
self.assertEqual(conn2.root.l.value, 1)
conn2.transaction_manager.commit()
self.assertEqual(conn2.root.l.value, 2)
db.close(); stop()
def test_suite():
return unittest.makeSuite(ClientSideConflictResolutionTests)
"""Testing helpers
"""
import ZEO.StorageServer
from ..asyncio.server import best_protocol_version
class ServerProtocol:
method = ('register', )
def __init__(self, zs,
protocol_version=best_protocol_version,
addr='test-address'):
self.calls = []
self.addr = addr
self.zs = zs
self.protocol_version = protocol_version
zs.notify_connected(self)
closed = False
def close(self):
if not self.closed:
self.closed = True
self.zs.notify_disconnected()
def call_soon_threadsafe(self, func, *args):
func(*args)
def async(self, *args):
self.calls.append(args)
class StorageServer:
"""Create a client interface to a StorageServer.
This is for testing StorageServer. It interacts with the storgr
server through its network interface, but without creating a
network connection.
"""
def __init__(self, test, storage,
protocol_version=best_protocol_version,
**kw):
self.test = test
self.storage_server = ZEO.StorageServer.StorageServer(
None, {'1': storage}, **kw)
self.zs = self.storage_server.create_client_handler()
self.protocol = ServerProtocol(self.zs,
protocol_version=protocol_version)
self.zs.register('1', kw.get('read_only', False))
def assert_calls(self, test, *argss):
if argss:
for args in argss:
test.assertEqual(self.protocol.calls.pop(0), args)
else:
test.assertEqual(self.protocol.calls, ())
def unpack_result(self, result):
"""For methods that return Result objects, unwrap the results
"""
result, callback = result.args
callback()
return result
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment