Commit e1e8558d authored by Jim Fulton's avatar Jim Fulton

Implemented ZEO servers with asyncio.

parent a94d9b50
Changelog Changelog
========= =========
5.0.0 (unreleases)
------------------
This is a major ZEO revision, which replaces the ZEO network protocol
implementation.
New features:
Dropped features:
- The ZEO authentication protocol.
This will be replaced by new authentication mechanims leveraging SSL.
- The ZEO monitor server.
- Full cache verification.
- Client suppprt for servers older than ZODB 3.9
- Server support for clients older than ZEO 4.2.0
4.2.0 (2016-06-15) 4.2.0 (2016-06-15)
------------------ ------------------
......
...@@ -165,7 +165,7 @@ class ClientStorage(object): ...@@ -165,7 +165,7 @@ class ClientStorage(object):
shared_blob_dir shared_blob_dir
Flag whether the blob_dir is a server-shared filesystem Flag whether the blob_dir is a server-shared filesystem
that should be used instead of transferring blob data over that should be used instead of transferring blob data over
zrpc. ZEO protocol.
blob_cache_size blob_cache_size
Maximum size of the ZEO blob cache, in bytes. If not set, then Maximum size of the ZEO blob cache, in bytes. If not set, then
...@@ -479,12 +479,10 @@ class ClientStorage(object): ...@@ -479,12 +479,10 @@ class ClientStorage(object):
return buf return buf
def history(self, oid, size=1, def history(self, oid, size=1):
timeout=None, # for tests
):
"""Storage API: return a sequence of HistoryEntry objects. """Storage API: return a sequence of HistoryEntry objects.
""" """
return self._call('history', oid, size, timeout=timeout) return self._call('history', oid, size)
def record_iternext(self, next=None): def record_iternext(self, next=None):
"""Storage API: get the next database record. """Storage API: get the next database record.
...@@ -762,8 +760,8 @@ class ClientStorage(object): ...@@ -762,8 +760,8 @@ class ClientStorage(object):
txn.set_data(self, TransactionBuffer(self._connection_generation)) txn.set_data(self, TransactionBuffer(self._connection_generation))
if self.protocol_version < b'Z5': # XXX we'd like to allow multiple transactions at a time at some point,
# Earlier protocols only allowed one transaction at a time :( # but for now, due to server limitations, TCBOO.
self._commit_lock.acquire() self._commit_lock.acquire()
self._tbuf = txn.data(self) self._tbuf = txn.data(self)
...@@ -780,7 +778,6 @@ class ClientStorage(object): ...@@ -780,7 +778,6 @@ class ClientStorage(object):
if tbuf is not None: if tbuf is not None:
tbuf.close() tbuf.close()
txn.set_data(self, None) txn.set_data(self, None)
if self.protocol_version < b'Z5':
self._commit_lock.release() self._commit_lock.release()
def lastTransaction(self): def lastTransaction(self):
......
...@@ -19,18 +19,20 @@ file storage or Berkeley storage. ...@@ -19,18 +19,20 @@ file storage or Berkeley storage.
TODO: Need some basic access control-- a declaration of the methods TODO: Need some basic access control-- a declaration of the methods
exported for invocation by the server. exported for invocation by the server.
""" """
import asyncore import asyncio
import codecs import codecs
import itertools import itertools
import logging import logging
import os import os
import socket
import sys import sys
import tempfile import tempfile
import threading import threading
import time import time
import transaction import transaction
import warnings import warnings
import ZEO.zrpc.error import ZEO.acceptor
import ZEO.asyncio.server
import ZODB.blob import ZODB.blob
import ZODB.event import ZODB.event
import ZODB.serialize import ZODB.serialize
...@@ -40,9 +42,8 @@ import six ...@@ -40,9 +42,8 @@ import six
from ZEO._compat import Pickler, Unpickler, PY3, BytesIO from ZEO._compat import Pickler, Unpickler, PY3, BytesIO
from ZEO.Exceptions import AuthError from ZEO.Exceptions import AuthError
from ZEO.monitor import StorageStats, StatsServer from ZEO.monitor import StorageStats
from ZEO.zrpc.connection import ManagedServerConnection, Delay, MTDelay, Result from ZEO.asyncio.server import Delay, MTDelay, Result
from ZEO.zrpc.server import Dispatcher
from ZODB.ConflictResolution import ResolvedSerial from ZODB.ConflictResolution import ResolvedSerial
from ZODB.loglevels import BLATHER from ZODB.loglevels import BLATHER
from ZODB.POSException import StorageError, StorageTransactionError from ZODB.POSException import StorageError, StorageTransactionError
...@@ -62,6 +63,15 @@ def log(message, level=logging.INFO, label='', exc_info=False): ...@@ -62,6 +63,15 @@ def log(message, level=logging.INFO, label='', exc_info=False):
class StorageServerError(StorageError): class StorageServerError(StorageError):
"""Error reported when an unpicklable exception is raised.""" """Error reported when an unpicklable exception is raised."""
registered_methods = set(( 'get_info', 'lastTransaction',
'getInvalidations', 'new_oids', 'pack', 'loadBefore', 'storea',
'checkCurrentSerialInTransaction', 'restorea', 'storeBlobStart',
'storeBlobChunk', 'storeBlobEnd', 'storeBlobShared',
'deleteObject', 'tpc_begin', 'vote', 'tpc_finish', 'tpc_abort',
'history', 'record_iternext', 'sendBlob', 'getTid', 'loadSerial',
'new_oid', 'undoa', 'undoLog', 'undoInfo', 'iterator_start',
'iterator_next', 'iterator_record_start', 'iterator_record_next',
'iterator_gc', 'server_status', 'set_client_label'))
class ZEOStorage: class ZEOStorage:
"""Proxy to underlying storage for a single remote client.""" """Proxy to underlying storage for a single remote client."""
...@@ -70,23 +80,16 @@ class ZEOStorage: ...@@ -70,23 +80,16 @@ class ZEOStorage:
# should override. # should override.
extensions = [] extensions = []
def __init__(self, server, read_only=0, auth_realm=None): connected = connection = stats = storage = storage_id = transaction = None
blob_tempfile = None
log_label = 'unconnected'
locked = False # Don't have storage lock
verifying = store_failed = 0
def __init__(self, server, read_only=0):
self.server = server self.server = server
# timeout and stats will be initialized in register() # timeout and stats will be initialized in register()
self.stats = None
self.connection = None
self.client = None
self.storage = None
self.storage_id = "uninitialized"
self.transaction = None
self.read_only = read_only self.read_only = read_only
self.log_label = 'unconnected'
self.locked = False # Don't have storage lock
self.verifying = 0
self.store_failed = 0
self.authenticated = 0
self.auth_realm = auth_realm
self.blob_tempfile = None
# The authentication protocol may define extra methods. # The authentication protocol may define extra methods.
self._extensions = {} self._extensions = {}
for func in self.extensions: for func in self.extensions:
...@@ -97,26 +100,16 @@ class ZEOStorage: ...@@ -97,26 +100,16 @@ class ZEOStorage:
# transaction iterator. # transaction iterator.
self._txn_iterators_last = {} self._txn_iterators_last = {}
def _finish_auth(self, authenticated):
if not self.auth_realm:
return 1
self.authenticated = authenticated
return authenticated
def set_database(self, database): def set_database(self, database):
self.database = database self.database = database
def notifyConnected(self, conn): def notify_connected(self, conn):
self.connection = conn self.connection = conn
assert conn.peer_protocol_version is not None self.connected = True
if conn.peer_protocol_version < b'Z309': assert conn.protocol_version is not None
self.client = ClientStub308(conn)
conn.register_object(ZEOStorage308Adapter(self))
else:
self.client = ClientStub(conn)
self.log_label = _addr_label(conn.addr) self.log_label = _addr_label(conn.addr)
def notifyDisconnected(self): def notify_disconnected(self):
# When this storage closes, we must ensure that it aborts # When this storage closes, we must ensure that it aborts
# any pending transaction. # any pending transaction.
if self.transaction is not None: if self.transaction is not None:
...@@ -126,7 +119,8 @@ class ZEOStorage: ...@@ -126,7 +119,8 @@ class ZEOStorage:
else: else:
self.log("disconnected") self.log("disconnected")
self.connection = None self.connected = False
self.server.close_conn(self)
def __repr__(self): def __repr__(self):
tid = self.transaction and repr(self.transaction.id) tid = self.transaction and repr(self.transaction.id)
...@@ -185,6 +179,8 @@ class ZEOStorage: ...@@ -185,6 +179,8 @@ class ZEOStorage:
else: else:
raise raise
self.connection.methods = registered_methods
def history(self,tid,size=1): def history(self,tid,size=1):
# This caters for storages which still accept # This caters for storages which still accept
# a version parameter. # a version parameter.
...@@ -212,15 +208,6 @@ class ZEOStorage: ...@@ -212,15 +208,6 @@ class ZEOStorage:
return 0 return 0
return 1 return 1
def getAuthProtocol(self):
"""Return string specifying name of authentication module to use.
The module name should be auth_%s where %s is auth_protocol."""
protocol = self.server.auth_protocol
if not protocol or protocol == 'none':
return None
return protocol
def register(self, storage_id, read_only): def register(self, storage_id, read_only):
"""Select the storage that this client will use """Select the storage that this client will use
...@@ -228,9 +215,6 @@ class ZEOStorage: ...@@ -228,9 +215,6 @@ class ZEOStorage:
For authenticated storages this method will be called by the client For authenticated storages this method will be called by the client
immediately after authentication is finished. immediately after authentication is finished.
""" """
if self.auth_realm and not self.authenticated:
raise AuthError("Client was never authenticated with server!")
if self.storage is not None: if self.storage is not None:
self.log("duplicate register() call") self.log("duplicate register() call")
raise ValueError("duplicate register() call") raise ValueError("duplicate register() call")
...@@ -252,9 +236,8 @@ class ZEOStorage: ...@@ -252,9 +236,8 @@ class ZEOStorage:
def get_info(self): def get_info(self):
storage = self.storage storage = self.storage
supportsUndo = (getattr(storage, 'supportsUndo', lambda : False)() supportsUndo = (getattr(storage, 'supportsUndo', lambda : False)()
and self.connection.peer_protocol_version >= b'Z310') and self.connection.protocol_version >= b'Z310')
# Communicate the backend storage interfaces to the client # Communicate the backend storage interfaces to the client
storage_provides = zope.interface.providedBy(storage) storage_provides = zope.interface.providedBy(storage)
...@@ -295,37 +278,6 @@ class ZEOStorage: ...@@ -295,37 +278,6 @@ class ZEOStorage:
% (len(invlist), u64(invtid))) % (len(invlist), u64(invtid)))
return invtid, invlist return invtid, invlist
def verify(self, oid, tid):
try:
t = self.getTid(oid)
except KeyError:
self.client.invalidateVerify(oid)
else:
if tid != t:
self.client.invalidateVerify(oid)
def zeoVerify(self, oid, s):
if not self.verifying:
self.verifying = 1
self.stats.verifying_clients += 1
try:
os = self.getTid(oid)
except KeyError:
self.client.invalidateVerify((oid, ''))
# It's not clear what we should do now. The KeyError
# could be caused by an object uncreation, in which case
# invalidation is right. It could be an application bug
# that left a dangling reference, in which case it's bad.
else:
if s != os:
self.client.invalidateVerify((oid, ''))
def endZeoVerify(self):
if self.verifying:
self.stats.verifying_clients -= 1
self.verifying = 0
self.client.endVerify()
def pack(self, time, wait=1): def pack(self, time, wait=1):
# Yes, you can pack a read-only server or storage! # Yes, you can pack a read-only server or storage!
if wait: if wait:
...@@ -449,14 +401,16 @@ class ZEOStorage: ...@@ -449,14 +401,16 @@ class ZEOStorage:
return self._try_to_vote() return self._try_to_vote()
def _try_to_vote(self, delay=None): def _try_to_vote(self, delay=None):
if self.connection is None: if not self.connected:
return # We're disconnected return # We're disconnected
if delay is not None and delay.sent: if delay is not None and delay.sent:
# as a consequence of the unlocking strategy, _try_to_vote # as a consequence of the unlocking strategy, _try_to_vote
# may be called multiple times for delayed # may be called multiple times for delayed
# transactions. The first call will mark the delay as # transactions. The first call will mark the delay as
# sent. We should skip if the delay was already sent. # sent. We should skip if the delay was already sent.
return return
self.locked, delay = self.server.lock_storage(self, delay) self.locked, delay = self.server.lock_storage(self, delay)
if self.locked: if self.locked:
try: try:
...@@ -490,7 +444,7 @@ class ZEOStorage: ...@@ -490,7 +444,7 @@ class ZEOStorage:
if serials: if serials:
self.serials.extend(serials) self.serials.extend(serials)
self.client.serialnos(self.serials) self.connection.async('serialnos', self.serials)
except Exception: except Exception:
self.storage.tpc_abort(self.transaction) self.storage.tpc_abort(self.transaction)
...@@ -509,11 +463,10 @@ class ZEOStorage: ...@@ -509,11 +463,10 @@ class ZEOStorage:
return delay return delay
def _unlock_callback(self, delay): def _unlock_callback(self, delay):
connection = self.connection if self.connected:
if connection is None: self.connection.call_soon_threadsafe(self._try_to_vote, delay)
self.server.stop_waiting(self)
else: else:
connection.call_from_thread(self._try_to_vote, delay) self.server.stop_waiting(self)
# The public methods of the ZEO client API do not do the real work. # The public methods of the ZEO client API do not do the real work.
# They defer work until after the storage lock has been acquired. # They defer work until after the storage lock has been acquired.
...@@ -575,7 +528,19 @@ class ZEOStorage: ...@@ -575,7 +528,19 @@ class ZEOStorage:
self.blob_log.append((oid, serial, data, filename)) self.blob_log.append((oid, serial, data, filename))
def sendBlob(self, oid, serial): def sendBlob(self, oid, serial):
self.client.storeBlob(oid, serial, self.storage.loadBlob(oid, serial)) blobfilename = self.storage.loadBlob(oid, serial)
def store():
yield ('receiveBlobStart', (oid, serial))
with open(blobfilename, 'rb') as f:
while 1:
chunk = f.read(59000)
if not chunk:
break
yield ('receiveBlobChunk', (oid, serial, chunk, ))
yield ('receiveBlobStop', (oid, serial))
self.connection.call_async_iter(store())
def undo(*a, **k): def undo(*a, **k):
raise NotImplementedError raise NotImplementedError
...@@ -760,7 +725,18 @@ class ZEOStorage: ...@@ -760,7 +725,18 @@ class ZEOStorage:
def set_client_label(self, label): def set_client_label(self, label):
self.log_label = str(label)+' '+_addr_label(self.connection.addr) self.log_label = str(label)+' '+_addr_label(self.connection.addr)
def ruok(self):
return self.server.ruok()
class StorageServerDB: class StorageServerDB:
"""Adapter from StorageServerDB to ZODB.interfaces.IStorageWrapper
This is used in a ZEO fan-out situation, where a storage server
calls registerDB on a ClientStorage.
Note that this is called from the Client-storage's IO thread, so
always a separate thread from the storge-server connections.
"""
def __init__(self, server, storage_id): def __init__(self, server, storage_id):
self.server = server self.server = server
...@@ -788,21 +764,11 @@ class StorageServer: ...@@ -788,21 +764,11 @@ class StorageServer:
ZEOStorage instance only handles a single storage. ZEOStorage instance only handles a single storage.
""" """
# Classes we instantiate. A subclass might override.
DispatcherClass = ZEO.zrpc.server.Dispatcher
ZEOStorageClass = ZEOStorage
ManagedServerConnectionClass = ManagedServerConnection
def __init__(self, addr, storages, def __init__(self, addr, storages,
read_only=0, read_only=0,
invalidation_queue_size=100, invalidation_queue_size=100,
invalidation_age=None, invalidation_age=None,
transaction_timeout=None, transaction_timeout=None,
monitor_address=None,
auth_protocol=None,
auth_database=None,
auth_realm=None,
): ):
"""StorageServer constructor. """StorageServer constructor.
...@@ -847,29 +813,8 @@ class StorageServer: ...@@ -847,29 +813,8 @@ class StorageServer:
a transaction to commit after acquiring the storage lock. a transaction to commit after acquiring the storage lock.
If the transaction takes too long, the client connection If the transaction takes too long, the client connection
will be closed and the transaction aborted. will be closed and the transaction aborted.
monitor_address -- The address at which the monitor server
should listen. If specified, a monitor server is started.
The monitor server provides server statistics in a simple
text format.
auth_protocol -- The name of the authentication protocol to use.
Examples are "digest" and "srp".
auth_database -- The name of the password database filename.
It should be in a format compatible with the authentication
protocol used; for instance, "sha" and "srp" require different
formats.
Note that to implement an authentication protocol, a server
and client authentication mechanism must be implemented in a
auth_* module, which should be stored inside the "auth"
subdirectory. This module may also define a DatabaseClass
variable that should indicate what database should be used
by the authenticator.
""" """
self.addr = addr
self.storages = storages self.storages = storages
msg = ", ".join( msg = ", ".join(
["%s:%s:%s" % (name, storage.isReadOnly() and "RO" or "RW", ["%s:%s:%s" % (name, storage.isReadOnly() and "RO" or "RW",
...@@ -884,12 +829,7 @@ class StorageServer: ...@@ -884,12 +829,7 @@ class StorageServer:
self._waiting = dict((name, []) for name in storages) self._waiting = dict((name, []) for name in storages)
self.read_only = read_only self.read_only = read_only
self.auth_protocol = auth_protocol
self.auth_database = auth_database
self.auth_realm = auth_realm
self.database = None self.database = None
if auth_protocol:
self._setup_auth(auth_protocol)
# A list, by server, of at most invalidation_queue_size invalidations. # A list, by server, of at most invalidation_queue_size invalidations.
# The list is kept in sorted order with the most recent # The list is kept in sorted order with the most recent
# invalidation at the front. The list never has more than # invalidation at the front. The list never has more than
...@@ -900,19 +840,20 @@ class StorageServer: ...@@ -900,19 +840,20 @@ class StorageServer:
self._setup_invq(name, storage) self._setup_invq(name, storage)
storage.registerDB(StorageServerDB(self, name)) storage.registerDB(StorageServerDB(self, name))
self.invalidation_age = invalidation_age self.invalidation_age = invalidation_age
self.connections = {} self.zeo_storages_by_storage_id = {} # {storage_id -> [ZEOStorage]}
self.socket_map = {} self.acceptor = ZEO.acceptor.Acceptor(addr, self.new_connection)
self.dispatcher = self.DispatcherClass( if isinstance(addr, tuple) and addr[0]:
addr, factory=self.new_connection, map=self.socket_map) self.addr = self.acceptor.addr
if len(self.addr) == 2 and self.addr[1] == 0 and self.addr[0]: else:
self.addr = self.dispatcher.socket.getsockname() self.addr = addr
ZODB.event.notify( self.loop = self.acceptor.loop
Serving(self, address=self.dispatcher.socket.getsockname())) ZODB.event.notify(Serving(self, address=self.acceptor.addr))
self.stats = {} self.stats = {}
self.timeouts = {} self.timeouts = {}
for name in self.storages.keys(): for name in self.storages.keys():
self.connections[name] = [] self.zeo_storages_by_storage_id[name] = []
self.stats[name] = StorageStats(self.connections[name]) self.stats[name] = StorageStats(
self.zeo_storages_by_storage_id[name])
if transaction_timeout is None: if transaction_timeout is None:
# An object with no-op methods # An object with no-op methods
timeout = StubTimeoutThread() timeout = StubTimeoutThread()
...@@ -921,14 +862,6 @@ class StorageServer: ...@@ -921,14 +862,6 @@ class StorageServer:
timeout.setName("TimeoutThread for %s" % name) timeout.setName("TimeoutThread for %s" % name)
timeout.start() timeout.start()
self.timeouts[name] = timeout self.timeouts[name] = timeout
if monitor_address:
warnings.warn(
"The monitor server is deprecated. Use the server_status\n"
"ZEO method instead.",
DeprecationWarning)
self.monitor = StatsServer(monitor_address, self.stats)
else:
self.monitor = None
def _setup_invq(self, name, storage): def _setup_invq(self, name, storage):
lastInvalidations = getattr(storage, 'lastInvalidations', None) lastInvalidations = getattr(storage, 'lastInvalidations', None)
...@@ -944,72 +877,36 @@ class StorageServer: ...@@ -944,72 +877,36 @@ class StorageServer:
self.invq[name] = list(lastInvalidations(self.invq_bound)) self.invq[name] = list(lastInvalidations(self.invq_bound))
self.invq[name].reverse() self.invq[name].reverse()
def _setup_auth(self, protocol):
# Can't be done in global scope, because of cyclic references
from ZEO.auth import get_module
name = self.__class__.__name__
module = get_module(protocol)
if not module:
log("%s: no such an auth protocol: %s" % (name, protocol))
return
storage_class, client, db_class = module
if not storage_class or not issubclass(storage_class, ZEOStorage):
log(("%s: %s isn't a valid protocol, must have a StorageClass" %
(name, protocol)))
self.auth_protocol = None
return
self.ZEOStorageClass = storage_class
log("%s: using auth protocol: %s" % (name, protocol))
# We create a Database instance here for use with the authenticator
# modules. Having one instance allows it to be shared between multiple
# storages, avoiding the need to bloat each with a new authenticator
# Database that would contain the same info, and also avoiding any
# possibly synchronization issues between them.
self.database = db_class(self.auth_database)
if self.database.realm != self.auth_realm:
raise ValueError("password database realm %r "
"does not match storage realm %r"
% (self.database.realm, self.auth_realm))
def new_connection(self, sock, addr): def new_connection(self, sock, addr):
"""Internal: factory to create a new connection. """Internal: factory to create a new connection.
This is called by the Dispatcher class in ZEO.zrpc.server
whenever accept() returns a socket for a new incoming
connection.
""" """
if self.auth_protocol and self.database: logger.debug("new connection %s" % (addr,))
zstorage = self.ZEOStorageClass(self, self.read_only,
auth_realm=self.auth_realm) def run():
zstorage.set_database(self.database) loop = asyncio.new_event_loop()
else: asyncio.set_event_loop(loop)
zstorage = self.ZEOStorageClass(self, self.read_only) ZEO.asyncio.server.new_connection(
loop, addr, sock, ZEOStorage(self, self.read_only))
c = self.ManagedServerConnectionClass(sock, addr, zstorage, self) loop.run_forever()
log("new connection %s: %s" % (addr, repr(c)), logging.DEBUG) loop.close()
return c
thread = threading.Thread(target=run, name='zeo_client_hander')
thread.setDaemon(True)
thread.start()
def register_connection(self, storage_id, conn): def register_connection(self, storage_id, zeo_storage):
"""Internal: register a connection with a particular storage. """Internal: register a ZEOStorage with a particular storage.
This is called by ZEOStorage.register(). This is called by ZEOStorage.register().
The dictionary self.connections maps each storage name to a The dictionary self.zeo_storages_by_storage_id maps each
list of current connections for that storage; this information storage name to a list of current ZEOStorages for that
is needed to handle invalidation. This function updates this storage; this information is needed to handle invalidation.
dictionary. This function updates this dictionary.
Returns the timeout and stats objects for the appropriate storage. Returns the timeout and stats objects for the appropriate storage.
""" """
self.connections[storage_id].append(conn) self.zeo_storages_by_storage_id[storage_id].append(zeo_storage)
return self.stats[storage_id] return self.stats[storage_id]
def _invalidateCache(self, storage_id): def _invalidateCache(self, storage_id):
...@@ -1020,7 +917,7 @@ class StorageServer: ...@@ -1020,7 +917,7 @@ class StorageServer:
and making them reconnect. and making them reconnect.
""" """
# This method can be called from foreign threads. We have to # This method is called from foreign threads. We have to
# worry about interaction with the main thread. # worry about interaction with the main thread.
# 1. We modify self.invq which is read by get_invalidations # 1. We modify self.invq which is read by get_invalidations
...@@ -1045,15 +942,11 @@ class StorageServer: ...@@ -1045,15 +942,11 @@ class StorageServer:
# connections indirectoy by closing them. We don't care about # connections indirectoy by closing them. We don't care about
# later transactions since they will have to validate their # later transactions since they will have to validate their
# caches anyway. # caches anyway.
for p in self.connections[storage_id][:]: for zs in self.zeo_storages_by_storage_id[storage_id][:]:
try: zs.connection.call_soon_threadsafe(zs.connection.close)
p.connection.should_close()
p.connection.trigger.pull_trigger()
except ZEO.zrpc.error.DisconnectedError:
pass
def invalidate(
def invalidate(self, conn, storage_id, tid, invalidated=(), info=None): self, zeo_storage, storage_id, tid, invalidated=(), info=None):
"""Internal: broadcast info and invalidations to clients. """Internal: broadcast info and invalidations to clients.
This is called from several ZEOStorage methods. This is called from several ZEOStorage methods.
...@@ -1064,7 +957,7 @@ class StorageServer: ...@@ -1064,7 +957,7 @@ class StorageServer:
- If the invalidated argument is non-empty, it broadcasts - If the invalidated argument is non-empty, it broadcasts
invalidateTransaction() messages to all clients of the given invalidateTransaction() messages to all clients of the given
storage except the current client (the conn argument). storage except the current client (the zeo_storage argument).
- If the invalidated argument is empty and the info argument - If the invalidated argument is empty and the info argument
is a non-empty dictionary, it broadcasts info() messages to is a non-empty dictionary, it broadcasts info() messages to
...@@ -1104,15 +997,17 @@ class StorageServer: ...@@ -1104,15 +997,17 @@ class StorageServer:
if len(invq) >= self.invq_bound: if len(invq) >= self.invq_bound:
invq.pop() invq.pop()
invq.insert(0, (tid, invalidated)) invq.insert(0, (tid, invalidated))
# serialize invalidation message, so we don't have to to
for p in self.connections[storage_id]: # it over and over
try:
if invalidated and p is not conn: for zs in self.zeo_storages_by_storage_id[storage_id]:
p.client.invalidateTransaction(tid, invalidated) connection = zs.connection
if invalidated and zs is not zeo_storage:
connection.call_soon_threadsafe(
connection.async, 'invalidateTransaction', tid, invalidated)
elif info is not None: elif info is not None:
p.client.info(info) connection.call_soon_threadsafe(
except ZEO.zrpc.error.DisconnectedError: connection.async, 'info', info)
pass
def get_invalidations(self, storage_id, tid): def get_invalidations(self, storage_id, tid):
"""Return a tid and list of all objects invalidation since tid. """Return a tid and list of all objects invalidation since tid.
...@@ -1159,13 +1054,6 @@ class StorageServer: ...@@ -1159,13 +1054,6 @@ class StorageServer:
return latest_tid, list(oids) return latest_tid, list(oids)
def loop(self):
try:
asyncore.loop(map=self.socket_map)
except Exception:
if not self.__closed:
raise # Unexpected exc
__thread = None __thread = None
def start_thread(self, daemon=True): def start_thread(self, daemon=True):
self.__thread = thread = threading.Thread(target=self.loop) self.__thread = thread = threading.Thread(target=self.loop)
...@@ -1184,19 +1072,18 @@ class StorageServer: ...@@ -1184,19 +1072,18 @@ class StorageServer:
self.__closed = True self.__closed = True
# Stop accepting connections # Stop accepting connections
self.dispatcher.close() self.acceptor.close()
if self.monitor is not None:
self.monitor.close()
ZODB.event.notify(Closed(self)) ZODB.event.notify(Closed(self))
# Close open client connections # Close open client connections
for sid, connections in self.connections.items(): for sid, zeo_storages in self.zeo_storages_by_storage_id.items():
for conn in connections[:]: for zs in zeo_storages[:]:
try: try:
conn.connection.close() zs.connection.call_soon_threadsafe(
except: zs.connection.close)
pass except Exception:
logger.exception("closing connection %r", zs)
for name, storage in six.iteritems(self.storages): for name, storage in six.iteritems(self.storages):
logger.info("closing storage %r", name) logger.info("closing storage %r", name)
...@@ -1205,14 +1092,14 @@ class StorageServer: ...@@ -1205,14 +1092,14 @@ class StorageServer:
if self.__thread is not None: if self.__thread is not None:
self.__thread.join(join_timeout) self.__thread.join(join_timeout)
def close_conn(self, conn): def close_conn(self, zeo_storage):
"""Internal: remove the given connection from self.connections. """Remove the given zeo_storage from self.zeo_storages_by_storage_id.
This is the inverse of register_connection(). This is the inverse of register_connection().
""" """
for cl in self.connections.values(): for zeo_storages in self.zeo_storages_by_storage_id.values():
if conn.obj in cl: if zeo_storage in zeo_storages:
cl.remove(conn.obj) zeo_storages.remove(zeo_storage)
def lock_storage(self, zeostore, delay): def lock_storage(self, zeostore, delay):
storage_id = zeostore.storage_id storage_id = zeostore.storage_id
...@@ -1226,7 +1113,7 @@ class StorageServer: ...@@ -1226,7 +1113,7 @@ class StorageServer:
assert locked is not zeostore, (storage_id, delay) assert locked is not zeostore, (storage_id, delay)
if locked.connection is None: if not locked.connected:
locked.log("Still locked after disconnected. Unlocking.", locked.log("Still locked after disconnected. Unlocking.",
logging.CRITICAL) logging.CRITICAL)
if locked.transaction: if locked.transaction:
...@@ -1328,6 +1215,7 @@ class StorageServer: ...@@ -1328,6 +1215,7 @@ class StorageServer:
return dict((storage_id, self.server_status(storage_id)) return dict((storage_id, self.server_status(storage_id))
for storage_id in self.storages) for storage_id in self.storages)
def _level_for_waiting(waiting): def _level_for_waiting(waiting):
if len(waiting) > 9: if len(waiting) > 9:
return logging.CRITICAL return logging.CRITICAL
...@@ -1396,7 +1284,8 @@ class TimeoutThread(threading.Thread): ...@@ -1396,7 +1284,8 @@ class TimeoutThread(threading.Thread):
client.log("Transaction timeout after %s seconds" % client.log("Transaction timeout after %s seconds" %
self._timeout, logging.CRITICAL) self._timeout, logging.CRITICAL)
try: try:
client.connection.call_from_thread(client.connection.close) client.connection.call_soon_threadsafe(
client.connection.close)
except: except:
client.log("Timeout failure", logging.CRITICAL, client.log("Timeout failure", logging.CRITICAL,
exc_info=sys.exc_info()) exc_info=sys.exc_info())
...@@ -1442,141 +1331,6 @@ class SlowMethodThread(threading.Thread): ...@@ -1442,141 +1331,6 @@ class SlowMethodThread(threading.Thread):
self.delay.reply(result) self.delay.reply(result)
class ClientStub:
def __init__(self, rpc):
self.rpc = rpc
def beginVerify(self):
self.rpc.callAsync('beginVerify')
def invalidateVerify(self, args):
self.rpc.callAsync('invalidateVerify', args)
def endVerify(self):
self.rpc.callAsync('endVerify')
def invalidateTransaction(self, tid, args):
# Note that this method is *always* called from a different
# thread than self.rpc's async thread. It is the only method
# for which this is true and requires special consideration!
# callAsyncNoSend is important here because:
# - callAsyncNoPoll isn't appropriate because
# the network thread may not wake up for a long time,
# delaying invalidations for too long. (This is demonstrateed
# by a test failure.)
# - callAsync isn't appropriate because (on the server) it tries
# to write to the socket. If self.rpc's network thread also
# tries to write at the ame time, we can run into problems
# because handle_write isn't thread safe.
self.rpc.callAsyncNoSend('invalidateTransaction', tid, args)
def serialnos(self, arg):
self.rpc.callAsyncNoPoll('serialnos', arg)
def info(self, arg):
self.rpc.callAsyncNoPoll('info', arg)
def storeBlob(self, oid, serial, blobfilename):
def store():
yield ('receiveBlobStart', (oid, serial))
f = open(blobfilename, 'rb')
while 1:
chunk = f.read(59000)
if not chunk:
break
yield ('receiveBlobChunk', (oid, serial, chunk, ))
f.close()
yield ('receiveBlobStop', (oid, serial))
self.rpc.callAsyncIterator(store())
class ClientStub308(ClientStub):
def invalidateTransaction(self, tid, args):
ClientStub.invalidateTransaction(
self, tid, [(arg, '') for arg in args])
def invalidateVerify(self, oid):
ClientStub.invalidateVerify(self, (oid, ''))
class ZEOStorage308Adapter:
def __init__(self, storage):
self.storage = storage
def __eq__(self, other):
return self is other or self.storage is other
def getSerial(self, oid):
return self.storage.loadEx(oid)[1] # Z200
def history(self, oid, version, size=1):
if version:
raise ValueError("Versions aren't supported.")
return self.storage.history(oid, size=size)
def getInvalidations(self, tid):
result = self.storage.getInvalidations(tid)
if result is not None:
result = result[0], [(oid, '') for oid in result[1]]
return result
def verify(self, oid, version, tid):
if version:
raise StorageServerError("Versions aren't supported.")
return self.storage.verify(oid, tid)
def loadEx(self, oid, version=''):
if version:
raise StorageServerError("Versions aren't supported.")
data, serial = self.storage.loadEx(oid)
return data, serial, ''
def storea(self, oid, serial, data, version, id):
if version:
raise StorageServerError("Versions aren't supported.")
self.storage.storea(oid, serial, data, id)
def storeBlobEnd(self, oid, serial, data, version, id):
if version:
raise StorageServerError("Versions aren't supported.")
self.storage.storeBlobEnd(oid, serial, data, id)
def storeBlobShared(self, oid, serial, data, filename, version, id):
if version:
raise StorageServerError("Versions aren't supported.")
self.storage.storeBlobShared(oid, serial, data, filename, id)
def getInfo(self):
result = self.storage.getInfo()
result['supportsVersions'] = False
return result
def zeoVerify(self, oid, s, sv=None):
if sv:
raise StorageServerError("Versions aren't supported.")
self.storage.zeoVerify(oid, s)
def modifiedInVersion(self, oid):
return ''
def versions(self):
return ()
def versionEmpty(self, version):
return True
def commitVersion(self, *a, **k):
raise NotImplementedError
abortVersion = commitVersion
def __getattr__(self, name):
return getattr(self.storage, name)
def _addr_label(addr): def _addr_label(addr):
if isinstance(addr, six.binary_type): if isinstance(addr, six.binary_type):
return addr.decode('ascii') return addr.decode('ascii')
...@@ -1639,3 +1393,4 @@ class Serving(ServerEvent): ...@@ -1639,3 +1393,4 @@ class Serving(ServerEvent):
class Closed(ServerEvent): class Closed(ServerEvent):
pass pass
...@@ -26,14 +26,24 @@ def client(*args, **kw): ...@@ -26,14 +26,24 @@ def client(*args, **kw):
return ZEO.ClientStorage.ClientStorage(*args, **kw) return ZEO.ClientStorage.ClientStorage(*args, **kw)
def DB(*args, **kw): def DB(*args, **kw):
s = client(*args, **kw)
try:
import ZODB import ZODB
return ZODB.DB(client(*args, **kw)) return ZODB.DB(s)
except Exception:
s.close()
raise
def connection(*args, **kw): def connection(*args, **kw):
return DB(*args, **kw).open_then_close_db_when_connection_closes() db = DB(*args, **kw)
try:
return db.open_then_close_db_when_connection_closes()
except Exception:
db.close()
ra
def server(path=None, blob_dir=None, storage_conf=None, zeo_conf=None, def server(path=None, blob_dir=None, storage_conf=None, zeo_conf=None,
port=None): port=0, **kw):
"""Convenience function to start a server for interactive exploration """Convenience function to start a server for interactive exploration
This fuction starts a ZEO server, given a storage configuration or This fuction starts a ZEO server, given a storage configuration or
...@@ -74,14 +84,7 @@ def server(path=None, blob_dir=None, storage_conf=None, zeo_conf=None, ...@@ -74,14 +84,7 @@ def server(path=None, blob_dir=None, storage_conf=None, zeo_conf=None,
import os, ZEO.tests.forker import os, ZEO.tests.forker
if storage_conf is None and path is None: if storage_conf is None and path is None:
storage_conf = '<mappingstorage>\n</mappingstorage>' storage_conf = '<mappingstorage>\n</mappingstorage>'
if port is None and zeo_conf is None:
port = ZEO.tests.forker.get_port()
addr, admin, pid, config = ZEO.tests.forker.start_zeo_server( return ZEO.tests.forker.start_zeo_server(
storage_conf, zeo_conf, port, keep=True, path=path, storage_conf, zeo_conf, port, keep=True, path=path,
blob_dir=blob_dir, suicide=False) blob_dir=blob_dir, suicide=False, threaded=True, **kw)
os.remove(config)
def stop_server():
ZEO.tests.forker.shutdown_zeo_server(admin)
os.waitpid(pid, 0)
return addr, stop_server
...@@ -31,33 +31,31 @@ else: ...@@ -31,33 +31,31 @@ else:
s.close() s.close()
del s del s
from ZEO.zrpc.connection import Connection
from ZEO.zrpc.log import log
import ZEO.zrpc.log
import logging import logging
# Export the main asyncore loop logger = logging.getLogger(__name__)
loop = asyncore.loop
class Dispatcher(asyncore.dispatcher): class Acceptor(asyncore.dispatcher):
"""A server that accepts incoming RPC connections""" """A server that accepts incoming RPC connections"""
__super_init = asyncore.dispatcher.__init__
def __init__(self, addr, factory=Connection, map=None): def __init__(self, addr, factory):
self.__super_init(map=map) self.socket_map = {}
asyncore.dispatcher.__init__(self, map=self.socket_map)
self.addr = addr self.addr = addr
self.factory = factory self.factory = factory
self._open_socket() self._open_socket()
def _open_socket(self): def _open_socket(self):
if type(self.addr) == tuple: addr = self.addr
if self.addr[0] == '' and _has_dualstack:
if type(addr) == tuple:
if addr[0] == '' and _has_dualstack:
# Wildcard listen on all interfaces, both IPv4 and # Wildcard listen on all interfaces, both IPv4 and
# IPv6 if possible # IPv6 if possible
self.create_socket(socket.AF_INET6, socket.SOCK_STREAM) self.create_socket(socket.AF_INET6, socket.SOCK_STREAM)
self.socket.setsockopt( self.socket.setsockopt(
socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False) socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False)
elif ':' in self.addr[0]: elif ':' in addr[0]:
self.create_socket(socket.AF_INET6, socket.SOCK_STREAM) self.create_socket(socket.AF_INET6, socket.SOCK_STREAM)
if _has_dualstack: if _has_dualstack:
# On Linux, IPV6_V6ONLY is off by default. # On Linux, IPV6_V6ONLY is off by default.
...@@ -68,20 +66,28 @@ class Dispatcher(asyncore.dispatcher): ...@@ -68,20 +66,28 @@ class Dispatcher(asyncore.dispatcher):
self.create_socket(socket.AF_INET, socket.SOCK_STREAM) self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
else: else:
self.create_socket(socket.AF_UNIX, socket.SOCK_STREAM) self.create_socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.set_reuse_addr() self.set_reuse_addr()
log("listening on %s" % str(self.addr), logging.INFO)
for i in range(25): for i in range(25):
try: try:
self.bind(self.addr) self.bind(addr)
except Exception as exc: except Exception as exc:
log("bind failed %s waiting", i) logger.info("bind on %s failed %s waiting", addr, i)
if i == 24: if i == 24:
raise raise
else: else:
time.sleep(5) time.sleep(5)
except:
logger.exception('binding')
raise
else: else:
break break
if isinstance(addr, tuple) and addr[1] == 0:
self.addr = addr = self.socket.getsockname()
logger.info("listening on %s", str(addr))
self.listen(5) self.listen(5)
def writable(self): def writable(self):
...@@ -94,7 +100,7 @@ class Dispatcher(asyncore.dispatcher): ...@@ -94,7 +100,7 @@ class Dispatcher(asyncore.dispatcher):
try: try:
sock, addr = self.accept() sock, addr = self.accept()
except socket.error as msg: except socket.error as msg:
log("accepted failed: %s" % msg) logger.info("accepted failed: %s", msg)
return return
...@@ -115,9 +121,24 @@ class Dispatcher(asyncore.dispatcher): ...@@ -115,9 +121,24 @@ class Dispatcher(asyncore.dispatcher):
try: try:
c = self.factory(sock, addr) c = self.factory(sock, addr)
except: except Exception:
if sock.fileno() in asyncore.socket_map: if sock.fileno() in asyncore.socket_map:
del asyncore.socket_map[sock.fileno()] del asyncore.socket_map[sock.fileno()]
ZEO.zrpc.log.logger.exception("Error in handle_accept") logger.exception("Error in handle_accept")
else: else:
log("connect from %s: %s" % (repr(addr), c)) logger.info("connect from %s: %s", repr(addr), c)
def loop(self):
try:
asyncore.loop(map=self.socket_map)
except Exception:
if not self.__closed:
raise # Unexpected exc
logger.debug('acceptor %s loop stopped', self.addr)
__closed = False
def close(self):
if not self.__closed:
self.__closed = True
asyncore.dispatcher.close(self)
from struct import unpack
import asyncio
import logging
from .marshal import encoder
logger = logging.getLogger(__name__)
class Protocol(asyncio.Protocol):
"""asyncio low-level ZEO base interface
"""
# All of the code in this class runs in a single dedicated
# thread. Thus, we can mostly avoid worrying about interleaved
# operations.
# One place where special care was required was in cache setup on
# connect. See finish connect below.
transport = protocol_version = None
def __init__(self, loop, addr):
self.loop = loop
self.addr = addr
self.input = [] # Input buffer when assembling messages
self.output = [] # Output buffer when paused
self.paused = [] # Paused indicator, mutable to avoid attr lookup
# Handle the first message, the protocol handshake, differently
self.message_received = self.first_message_received
def __repr__(self):
return self.name
closed = False
def close(self):
if not self.closed:
self.closed = True
if self.transport is not None:
self.transport.close()
def connection_made(self, transport):
logger.info("Connected %s", self)
self.transport = transport
paused = self.paused
output = self.output
append = output.append
writelines = transport.writelines
from struct import pack
def write(message):
if paused:
append(message)
else:
writelines((pack(">I", len(message)), message))
self._write = write
def writeit(data):
# Note, don't worry about combining messages. Iters
# will be used with blobs, in which case, the individual
# messages will be big to begin with.
data = iter(data)
for message in data:
writelines((pack(">I", len(message)), message))
if paused:
append(data)
break
self._writeit = writeit
got = 0
want = 4
getting_size = True
def data_received(self, data):
# Low-level input handler collects data into sized messages.
# Note that the logic below assume that when new data pushes
# us over what we want, we process it in one call until we
# need more, because we assume that excess data is all in the
# last item of self.input. This is why the exception handling
# in the while loop is critical. Without it, an exception
# might cause us to exit before processing all of the data we
# should, when then causes the logic to be broken in
# subsequent calls.
self.got += len(data)
self.input.append(data)
while self.got >= self.want:
try:
extra = self.got - self.want
if extra == 0:
collected = b''.join(self.input)
self.input = []
else:
input = self.input
self.input = [input[-1][-extra:]]
input[-1] = input[-1][:-extra]
collected = b''.join(input)
self.got = extra
if self.getting_size:
# we were recieving the message size
assert self.want == 4
self.want = unpack(">I", collected)[0]
self.getting_size = False
else:
self.want = 4
self.getting_size = True
self.message_received(collected)
except Exception:
logger.exception("data_received %s %s %s",
self.want, self.got, self.getting_size)
def first_message_received(self, protocol_version):
# Handler for first/handshake message, set up in __init__
del self.message_received # use default handler from here on
self.encode = encoder()
self.finish_connect(protocol_version)
def call_async(self, method, args):
self._write(self.encode(0, True, method, args))
def call_async_iter(self, it):
self._writeit(self.encode(0, True, method, args)
for method, args in it)
def pause_writing(self):
self.paused.append(1)
def resume_writing(self):
paused = self.paused
del paused[:]
output = self.output
writelines = self.transport.writelines
from struct import pack
while output and not paused:
message = output.pop(0)
if isinstance(message, bytes):
writelines((pack(">I", len(message)), message))
else:
data = message
for message in data:
writelines((pack(">I", len(message)), message))
if paused: # paused again. Put iter back.
output.insert(0, data)
break
def get_peername(self):
return self.transport.get_extra_info('peername')
from pickle import loads, dumps
from ZEO.Exceptions import ClientDisconnected from ZEO.Exceptions import ClientDisconnected
from ZODB.ConflictResolution import ResolvedSerial from ZODB.ConflictResolution import ResolvedSerial
from struct import unpack
import asyncio import asyncio
import concurrent.futures import concurrent.futures
import logging import logging
import random import random
import threading import threading
import traceback
import ZODB.event import ZODB.event
import ZODB.POSException import ZODB.POSException
...@@ -15,17 +12,16 @@ import ZODB.POSException ...@@ -15,17 +12,16 @@ import ZODB.POSException
import ZEO.Exceptions import ZEO.Exceptions
import ZEO.interfaces import ZEO.interfaces
from . import base
from .marshal import decode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
Fallback = object() Fallback = object()
local_random = random.Random() # use separate generator to facilitate tests local_random = random.Random() # use separate generator to facilitate tests
class Closed(Exception): class Protocol(base.Protocol):
"""A connection has been closed
"""
class Protocol(asyncio.Protocol):
"""asyncio low-level ZEO client interface """asyncio low-level ZEO client interface
""" """
...@@ -36,9 +32,7 @@ class Protocol(asyncio.Protocol): ...@@ -36,9 +32,7 @@ class Protocol(asyncio.Protocol):
# One place where special care was required was in cache setup on # One place where special care was required was in cache setup on
# connect. See finish connect below. # connect. See finish connect below.
transport = protocol_version = None protocols = b'Z309', b'Z310', b'Z3101', b'Z4', b'Z5'
protocols = b"Z309", b"Z310", b"Z3101"
def __init__(self, loop, def __init__(self, loop,
addr, client, storage_key, read_only, connect_poll=1, addr, client, storage_key, read_only, connect_poll=1,
...@@ -51,8 +45,7 @@ class Protocol(asyncio.Protocol): ...@@ -51,8 +45,7 @@ class Protocol(asyncio.Protocol):
cache is a ZEO.interfaces.IClientCache. cache is a ZEO.interfaces.IClientCache.
""" """
self.loop = loop super(Protocol, self).__init__(loop, addr)
self.addr = addr
self.storage_key = storage_key self.storage_key = storage_key
self.read_only = read_only self.read_only = read_only
self.name = "%s(%r, %r, %r)" % ( self.name = "%s(%r, %r, %r)" % (
...@@ -61,19 +54,9 @@ class Protocol(asyncio.Protocol): ...@@ -61,19 +54,9 @@ class Protocol(asyncio.Protocol):
self.connect_poll = connect_poll self.connect_poll = connect_poll
self.heartbeat_interval = heartbeat_interval self.heartbeat_interval = heartbeat_interval
self.futures = {} # { message_id -> future } self.futures = {} # { message_id -> future }
self.input = [] # Buffer when assembling messages
self.output = [] # Buffer when paused
self.paused = [] # Paused indicator, mutable to avoid attr lookup
# Handle the first message, the protocol handshake, differently
self.message_received = self.first_message_received
self.connect() self.connect()
def __repr__(self):
return self.name
closed = False
def close(self): def close(self):
if not self.closed: if not self.closed:
self.closed = True self.closed = True
...@@ -118,35 +101,7 @@ class Protocol(asyncio.Protocol): ...@@ -118,35 +101,7 @@ class Protocol(asyncio.Protocol):
) )
def connection_made(self, transport): def connection_made(self, transport):
logger.info("Connected %s", self) super(Protocol, self).connection_made(transport)
self.transport = transport
paused = self.paused
output = self.output
append = output.append
writelines = transport.writelines
from struct import pack
def write(message):
if paused:
append(message)
else:
writelines((pack(">I", len(message)), message))
self._write = write
def writeit(data):
# Note, don't worry about combining messages. Iters
# will be used with blobs, in which case, the individual
# messages will be big to begin with.
data = iter(data)
for message in data:
writelines((pack(">I", len(message)), message))
if paused:
append(data)
break
self._writeit = writeit
self.heartbeat(write=False) self.heartbeat(write=False)
def connection_lost(self, exc): def connection_lost(self, exc):
...@@ -181,6 +136,7 @@ class Protocol(asyncio.Protocol): ...@@ -181,6 +136,7 @@ class Protocol(asyncio.Protocol):
# invalidations. # invalidations.
self.protocol_version = min(protocol_version, self.protocols[-1]) self.protocol_version = min(protocol_version, self.protocols[-1])
if self.protocol_version not in self.protocols: if self.protocol_version not in self.protocols:
self.client.register_failed( self.client.register_failed(
self, ZEO.Exceptions.ProtocolError(protocol_version)) self, ZEO.Exceptions.ProtocolError(protocol_version))
...@@ -236,59 +192,9 @@ class Protocol(asyncio.Protocol): ...@@ -236,59 +192,9 @@ class Protocol(asyncio.Protocol):
else: else:
self.client.register_failed(self, exc) self.client.register_failed(self, exc)
got = 0
want = 4
getting_size = True
def data_received(self, data):
# Low-level input handler collects data into sized messages.
# Note that the logic below assume that when new data pushes
# us over what we want, we process it in one call until we
# need more, because we assume that excess data is all in the
# last item of self.input. This is why the exception handling
# in the while loop is critical. Without it, an exception
# might cause us to exit before processing all of the data we
# should, when then causes the logic to be broken in
# subsequent calls.
self.got += len(data)
self.input.append(data)
while self.got >= self.want:
try:
extra = self.got - self.want
if extra == 0:
collected = b''.join(self.input)
self.input = []
else:
input = self.input
self.input = [input[-1][-extra:]]
input[-1] = input[-1][:-extra]
collected = b''.join(input)
self.got = extra
if self.getting_size:
# we were recieving the message size
assert self.want == 4
self.want = unpack(">I", collected)[0]
self.getting_size = False
else:
self.want = 4
self.getting_size = True
self.message_received(collected)
except Exception:
logger.exception("data_received %s %s %s",
self.want, self.got, self.getting_size)
def first_message_received(self, data):
# Handler for first/handshake message, set up in __init__
del self.message_received # use default handler from here on
self.finish_connect(data)
exception_type_type = type(Exception) exception_type_type = type(Exception)
def message_received(self, data): def message_received(self, data):
msgid, async, name, args = loads(data) msgid, async, name, args = decode(data)
if name == '.reply': if name == '.reply':
future = self.futures.pop(msgid) future = self.futures.pop(msgid)
if (isinstance(args, tuple) and len(args) > 1 and if (isinstance(args, tuple) and len(args) > 1 and
...@@ -315,46 +221,16 @@ class Protocol(asyncio.Protocol): ...@@ -315,46 +221,16 @@ class Protocol(asyncio.Protocol):
else: else:
raise AttributeError(name) raise AttributeError(name)
def call_async(self, method, args):
self._write(dumps((0, True, method, args), 3))
def call_async_iter(self, it):
self._writeit(dumps((0, True, method, args), 3) for method, args in it)
message_id = 0 message_id = 0
def call(self, future, method, args): def call(self, future, method, args):
self.message_id += 1 self.message_id += 1
self.futures[self.message_id] = future self.futures[self.message_id] = future
self._write(dumps((self.message_id, False, method, args), 3)) self._write(self.encode(self.message_id, False, method, args))
return future return future
def promise(self, method, *args): def promise(self, method, *args):
return self.call(Promise(), method, args) return self.call(Promise(), method, args)
def pause_writing(self):
self.paused.append(1)
def resume_writing(self):
paused = self.paused
del paused[:]
output = self.output
writelines = self.transport.writelines
from struct import pack
while output and not paused:
message = output.pop(0)
if isinstance(message, bytes):
writelines((pack(">I", len(message)), message))
else:
data = message
for message in data:
writelines((pack(">I", len(message)), message))
if paused: # paused again. Put iter back.
output.insert(0, data)
break
def get_peername(self):
return self.transport.get_extra_info('peername')
# Methods called by the server. # Methods called by the server.
# WARNING WARNING we can't call methods that call back to us # WARNING WARNING we can't call methods that call back to us
# syncronously, as that would lead to DEADLOCK! # syncronously, as that would lead to DEADLOCK!
...@@ -825,6 +701,7 @@ class ClientThread(ClientRunner): ...@@ -825,6 +701,7 @@ class ClientThread(ClientRunner):
exception = None exception = None
def run(self): def run(self):
loop = None
try: try:
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
...@@ -832,15 +709,20 @@ class ClientThread(ClientRunner): ...@@ -832,15 +709,20 @@ class ClientThread(ClientRunner):
self.started.set() self.started.set()
loop.run_forever() loop.run_forever()
except Exception as exc: except Exception as exc:
raise
logger.exception("Client thread") logger.exception("Client thread")
self.exception = exc self.exception = exc
finally: finally:
if not self.closed: if not self.closed:
if self.client.ready:
self.closed = True self.closed = True
try:
if self.client.ready:
self.client.ready = False self.client.ready = False
self.client.client.notify_disconnected() self.client.client.notify_disconnected()
except AttributeError:
pass
logger.critical("Client loop stopped unexpectedly") logger.critical("Client loop stopped unexpectedly")
if loop is not None:
loop.close() loop.close()
logger.debug('Stopping client thread') logger.debug('Stopping client thread')
......
...@@ -11,78 +11,64 @@ ...@@ -11,78 +11,64 @@
# FOR A PARTICULAR PURPOSE # FOR A PARTICULAR PURPOSE
# #
############################################################################## ##############################################################################
"""Support for marshaling ZEO messages
Not to be confused with marshaling objects in ZODB.
We currently use pickle. In the future, we may use a
Python-independent format, or possibly a minimal pickle subset.
"""
import logging import logging
from ZEO._compat import Unpickler, Pickler, BytesIO, PY3, PYPY from .._compat import Unpickler, Pickler, BytesIO, PY3, PYPY
from ZEO.zrpc.error import ZRPCError from ..shortrepr import short_repr
from ZEO.zrpc.log import log, short_repr
logger = logging.getLogger(__name__)
def encode(*args): # args: (msgid, flags, name, args)
# (We used to have a global pickler, but that's not thread-safe. :-( ) def encoder():
"""Return a non-thread-safe encoder
# It's not thread safe if, in the couse of pickling, we call the """
# Python interpeter, which releases the GIL.
if PY3 or PYPY:
# Note that args may contain very large binary pickles already; for
# this reason, it's important to use proto 1 (or higher) pickles here
# too. For a long time, this used proto 0 pickles, and that can
# bloat our pickle to 4x the size (due to high-bit and control bytes
# being represented by \xij escapes in proto 0).
# Undocumented: cPickle.Pickler accepts a lone protocol argument;
# pickle.py does not.
if PY3:
# XXX: Py3: Needs optimization.
f = BytesIO() f = BytesIO()
pickler = Pickler(f, 3) getvalue = f.getvalue
seek = f.seek
truncate = f.truncate
pickler = Pickler(f, 3 if PY3 else 1)
pickler.fast = 1 pickler.fast = 1
pickler.dump(args) dump = pickler.dump
res = f.getvalue() def encode(*args):
return res seek(0)
truncate()
dump(args)
return getvalue()
else: else:
pickler = Pickler(1)
pickler.fast = 1
# Only CPython's cPickle supports dumping
# and returning in one operation:
# return pickler.dump(args, 1)
# For PyPy we must return the value; fortunately this
# works the same on CPython and is no more expensive
pickler.dump(args)
return pickler.getvalue()
if PY3:
# XXX: Py3: Needs optimization.
fast_encode = encode
elif PYPY:
# can't use the python-2 branch, need a new pickler
# every time, getvalue() only works once
fast_encode = encode
else:
def fast_encode():
# Only use in cases where you *know* the data contains only basic
# Python objects
pickler = Pickler(1) pickler = Pickler(1)
pickler.fast = 1 pickler.fast = 1
dump = pickler.dump dump = pickler.dump
def fast_encode(*args): def encode(*args):
return dump(args, 1) return dump(args, 2)
return fast_encode
fast_encode = fast_encode() return encode
def encode(*args):
return encoder()(*args)
def decode(msg): def decode(msg):
"""Decodes msg and returns its parts""" """Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg)) unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = find_global unpickler.find_global = find_global
try: try:
unpickler.find_class = find_global # PyPy, zodbpickle, the non-c-accelerated version # PyPy, zodbpickle, the non-c-accelerated version
unpickler.find_class = find_global
except AttributeError: except AttributeError:
pass pass
try: try:
return unpickler.load() # msgid, flags, name, args return unpickler.load() # msgid, flags, name, args
except: except:
log("can't decode message: %s" % short_repr(msg), logger.error("can't decode message: %s" % short_repr(msg))
level=logging.ERROR)
raise raise
def server_decode(msg): def server_decode(msg):
...@@ -90,15 +76,15 @@ def server_decode(msg): ...@@ -90,15 +76,15 @@ def server_decode(msg):
unpickler = Unpickler(BytesIO(msg)) unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = server_find_global unpickler.find_global = server_find_global
try: try:
unpickler.find_class = server_find_global # PyPy, zodbpickle, the non-c-accelerated version # PyPy, zodbpickle, the non-c-accelerated version
unpickler.find_class = server_find_global
except AttributeError: except AttributeError:
pass pass
try: try:
return unpickler.load() # msgid, flags, name, args return unpickler.load() # msgid, flags, name, args
except: except:
log("can't decode message: %s" % short_repr(msg), logger.error("can't decode message: %s" % short_repr(msg))
level=logging.ERROR)
raise raise
_globals = globals() _globals = globals()
...@@ -111,12 +97,12 @@ def find_global(module, name): ...@@ -111,12 +97,12 @@ def find_global(module, name):
try: try:
m = __import__(module, _globals, _globals, _silly) m = __import__(module, _globals, _globals, _silly)
except ImportError as msg: except ImportError as msg:
raise ZRPCError("import error %s: %s" % (module, msg)) raise ImportError("import error %s: %s" % (module, msg))
try: try:
r = getattr(m, name) r = getattr(m, name)
except AttributeError: except AttributeError:
raise ZRPCError("module %s has no global %s" % (module, name)) raise ImportError("module %s has no global %s" % (module, name))
safe = getattr(r, '__no_side_effects__', 0) safe = getattr(r, '__no_side_effects__', 0)
if safe: if safe:
...@@ -126,7 +112,7 @@ def find_global(module, name): ...@@ -126,7 +112,7 @@ def find_global(module, name):
if type(r) == exception_type_type and issubclass(r, Exception): if type(r) == exception_type_type and issubclass(r, Exception):
return r return r
raise ZRPCError("Unsafe global: %s.%s" % (module, name)) raise ImportError("Unsafe global: %s.%s" % (module, name))
def server_find_global(module, name): def server_find_global(module, name):
"""Helper for message unpickler""" """Helper for message unpickler"""
...@@ -135,11 +121,11 @@ def server_find_global(module, name): ...@@ -135,11 +121,11 @@ def server_find_global(module, name):
raise ImportError raise ImportError
m = __import__(module, _globals, _globals, _silly) m = __import__(module, _globals, _globals, _silly)
except ImportError as msg: except ImportError as msg:
raise ZRPCError("import error %s: %s" % (module, msg)) raise ImportError("import error %s: %s" % (module, msg))
try: try:
r = getattr(m, name) r = getattr(m, name)
except AttributeError: except AttributeError:
raise ZRPCError("module %s has no global %s" % (module, name)) raise ImportError("module %s has no global %s" % (module, name))
return r return r
import asyncio
import json
import logging
import os
import random
import threading
import ZODB.POSException
logger = logging.getLogger(__name__)
from ..shortrepr import short_repr
from . import base
from .marshal import server_decode
class ServerProtocol(base.Protocol):
"""asyncio low-level ZEO server interface
"""
protocols = b'Z4', b'Z5'
name = 'server protocol'
methods = set(('register', ))
unlogged_exception_types = (
ZODB.POSException.POSKeyError,
)
def __init__(self, loop, addr, zeo_storage):
"""Create a server's client interface
"""
super(ServerProtocol, self).__init__(loop, addr)
self.zeo_storage = zeo_storage
closed = False
def close(self):
if not self.closed:
self.closed = True
if self.transport is not None:
self.transport.close()
connected = None # for tests
def connection_made(self, transport):
self.connected = True
super(ServerProtocol, self).connection_made(transport)
self._write(best_protocol_version)
def connection_lost(self, exc):
self.connected = False
if exc:
logger.error("Disconnected %s:%s", exc.__class__.__name__, exc)
self.zeo_storage.notify_disconnected()
self.loop.stop()
def finish_connect(self, protocol_version):
if protocol_version == b'ruok':
self._write(json.dumps(self.zeo_storage.ruok()).encode("ascii"))
self.close()
else:
if protocol_version in self.protocols:
logger.info("received handshake %r" % protocol_version)
self.protocol_version = protocol_version
self.zeo_storage.notify_connected(self)
else:
logger.error("bad handshake %s" % short_repr(protocol_version))
self.close()
def call_soon_threadsafe(self, func, *args):
try:
self.loop.call_soon_threadsafe(func, *args)
except RuntimeError:
if self.connected:
logger.exception("call_soon_threadsafe failed while connected")
def message_received(self, message):
try:
message_id, async, name, args = server_decode(message)
except Exception:
logger.exception("Can't deserialize message")
self.close()
if message_id == -1:
return # keep-alive
if name not in self.methods:
logger.error('Invalid method, %r', name)
self.close()
try:
result = getattr(self.zeo_storage, name)(*args)
except Exception as exc:
if not isinstance(exc, self.unlogged_exception_types):
logger.exception(
"Bad %srequest, %r", 'async ' if async else '', name)
if async:
return self.close() # No way to recover/cry for help
else:
return self.send_error(message_id, exc)
if not async:
self.send_reply(message_id, result)
def send_reply(self, message_id, result, send_error=False):
try:
result = self.encode(message_id, 0, '.reply', result)
except Exception:
if isinstance(result, Delay):
result.set_sender(message_id, self)
return
else:
logger.exception("Unpicklable response %r", result)
if not send_error:
self.send_error(
message_id,
ValueError("Couldn't pickle response"),
True)
self._write(result)
def send_reply_threadsafe(self, message_id, result):
self.loop.call_soon_threadsafe(self.reply, message_id, result)
def send_error(self, message_id, exc, send_error=False):
"""Abstracting here so we can make this cleaner in the future
"""
self.send_reply(message_id, (exc.__class__, exc), send_error)
def async(self, method, *args):
self.call_async(method, args)
best_protocol_version = os.environ.get(
'ZEO_SERVER_PROTOCOL',
ServerProtocol.protocols[-1].decode('utf-8')).encode('utf-8')
assert best_protocol_version in ServerProtocol.protocols
def new_connection(loop, addr, socket, zeo_storage):
protocol = ServerProtocol(loop, addr, zeo_storage)
cr = loop.create_connection((lambda : protocol), sock=socket)
asyncio.async(cr, loop=loop)
class Delay:
"""Used to delay response to client for synchronous calls.
When a synchronous call is made and the original handler returns
without handling the call, it returns a Delay object that prevents
the mainloop from sending a response.
"""
msgid = protocol = sent = None
def set_sender(self, msgid, protocol):
self.msgid = msgid
self.protocol = protocol
def reply(self, obj):
self.sent = 'reply'
self.protocol.send_reply(self.msgid, obj)
def error(self, exc_info):
self.sent = 'error'
log("Error raised in delayed method", logging.ERROR, exc_info=exc_info)
self.protocol.send_error(self.msgid, exc_info[1])
def __repr__(self):
return "%s[%s, %r, %r, %r]" % (
self.__class__.__name__, id(self),
self.msgid, self.protocol, self.sent)
def __reduce__(self):
raise TypeError("Can't pickle delays.")
class Result(Delay):
def __init__(self, *args):
self.args = args
def set_sender(self, msgid, protocol):
reply, callback = self.args
protocol.send_reply(msgid, reply)
callback()
class MTDelay(Delay):
def __init__(self):
self.ready = threading.Event()
def set_sender(self, *args):
Delay.set_sender(self, *args)
self.ready.set()
def reply(self, obj):
self.ready.wait()
self.protocol.call_soon_threadsafe(
self.protocol.send_reply, self.msgid, obj)
def error(self, exc_info):
self.ready.wait()
log("Error raised in delayed method", logging.ERROR, exc_info=exc_info)
self.protocol.call_soon_threadsafe(Delay.error, self, exc_info)
...@@ -30,13 +30,18 @@ class Loop: ...@@ -30,13 +30,18 @@ class Loop:
if not future.cancelled(): if not future.cancelled():
future.set_exception(ConnectionRefusedError()) future.set_exception(ConnectionRefusedError())
def create_connection(self, protocol_factory, host, port): def create_connection(
self, protocol_factory, host=None, port=None, sock=None
):
future = asyncio.Future(loop=self) future = asyncio.Future(loop=self)
if sock is None:
addr = host, port addr = host, port
if addr in self.addrs: if addr in self.addrs:
self._connect(future, protocol_factory) self._connect(future, protocol_factory)
else: else:
self.connecting[addr] = future, protocol_factory self.connecting[addr] = future, protocol_factory
else:
self._connect(future, protocol_factory)
return future return future
...@@ -61,6 +66,14 @@ class Loop: ...@@ -61,6 +66,14 @@ class Loop:
def call_exception_handler(self, context): def call_exception_handler(self, context):
self.exceptions.append(context) self.exceptions.append(context)
closed = False
def close(self):
self.closed = True
stopped = False
def stop(self):
self.stopped = True
class Handle: class Handle:
cancelled = False cancelled = False
......
...@@ -7,17 +7,62 @@ import asyncio ...@@ -7,17 +7,62 @@ import asyncio
import collections import collections
import logging import logging
import pdb import pdb
import pickle
import struct import struct
import unittest import unittest
import ZEO.Exceptions
from ..Exceptions import ClientDisconnected, ProtocolError
from ..ClientStorage import m64
from .testing import Loop from .testing import Loop
from .client import ClientRunner, Fallback from .client import ClientRunner, Fallback
from ..Exceptions import ClientDisconnected from .server import new_connection, best_protocol_version
from ..ClientStorage import m64 from .marshal import encoder, decode
class Base(object):
def setUp(self):
super(Base, self).setUp()
self.encode = encoder()
def unsized(self, data, unpickle=False):
result = []
while data:
size, message, *data = data
self.assertEqual(struct.unpack(">I", size)[0], len(message))
if unpickle:
message = decode(message)
result.append(message)
if len(result) == 1:
result = result[0]
return result
def parse(self, data):
return self.unsized(data, True)
target = None
def send(self, method, *args, **kw):
target = kw.pop('target', self.target)
called = kw.pop('called', True)
no_output = kw.pop('no_output', True)
self.assertFalse(kw)
self.loop.protocol.data_received(
sized(self.encode(0, True, method, args)))
if target is not None:
target = getattr(target, method)
if called:
target.assert_called_with(*args)
target.reset_mock()
else:
self.assertFalse(target.called)
if no_output:
self.assertFalse(self.loop.transport.pop())
def pop(self, count=None, parse=True):
return self.unsized(self.loop.transport.pop(count), parse)
class AsyncTests(setupstack.TestCase, ClientRunner): class ClientTests(Base, setupstack.TestCase, ClientRunner):
def start(self, def start(self,
addrs=(('127.0.0.1', 8200), ), loop_addrs=None, addrs=(('127.0.0.1', 8200), ), loop_addrs=None,
...@@ -28,6 +73,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -28,6 +73,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# object and a cache. # object and a cache.
wrapper = mock.Mock() wrapper = mock.Mock()
self.target = wrapper
cache = MemoryCache() cache = MemoryCache()
self.set_options(addrs, wrapper, cache, 'TEST', read_only, timeout=1) self.set_options(addrs, wrapper, cache, 'TEST', read_only, timeout=1)
...@@ -39,42 +85,35 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -39,42 +85,35 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
protocol = loop.protocol protocol = loop.protocol
transport = loop.transport transport = loop.transport
def send(meth, *args):
loop.protocol.data_received(
sized(pickle.dumps((0, True, meth, args), 3)))
def respond(message_id, result):
loop.protocol.data_received(
sized(pickle.dumps((message_id, False, '.reply', result), 3)))
if finish_start: if finish_start:
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.pop(2, False), b'Z3101')
parse = self.parse self.assertEqual(self.pop(),
self.assertEqual(parse(transport.pop()),
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()), (2, False, 'lastTransaction', ()),
]) ])
respond(1, None) self.respond(1, None)
respond(2, 'a'*8) self.respond(2, 'a'*8)
self.assertEqual(parse(transport.pop()), (3, False, 'get_info', ())) self.assertEqual(self.pop(), (3, False, 'get_info', ()))
respond(3, dict(length=42)) self.respond(3, dict(length=42))
return (wrapper, cache, self.loop, self.client, protocol, transport)
return (wrapper, cache, self.loop, self.client, protocol, transport, def respond(self, message_id, result):
send, respond) self.loop.protocol.data_received(
sized(self.encode(message_id, False, '.reply', result)))
def wait_for_result(self, future, timeout): def wait_for_result(self, future, timeout):
return future return future
def testBasics(self): def testClientBasics(self):
# Here, we'll go through the basic usage of the asyncio ZEO # Here, we'll go through the basic usage of the asyncio ZEO
# network client. The client is responsible for the core # network client. The client is responsible for the core
# functionality of a ZEO client storage. The client storage # functionality of a ZEO client storage. The client storage
# is largely just a wrapper around the asyncio client. # is largely just a wrapper around the asyncio client.
wrapper, cache, loop, client, protocol, transport, send, respond = ( wrapper, cache, loop, client, protocol, transport = self.start()
self.start())
self.assertFalse(wrapper.notify_disconnected.called) self.assertFalse(wrapper.notify_disconnected.called)
# The client isn't connected until the server sends it some data. # The client isn't connected until the server sends it some data.
...@@ -87,9 +126,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -87,9 +126,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# The client sends back a handshake, and registers the # The client sends back a handshake, and registers the
# storage, and requests the last transaction. # storage, and requests the last transaction.
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.pop(2, False), b'Z5')
parse = self.parse self.assertEqual(self.pop(),
self.assertEqual(parse(transport.pop()),
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()), (2, False, 'lastTransaction', ()),
]) ])
...@@ -119,37 +157,36 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -119,37 +157,36 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# Let's respond to those first 2 calls: # Let's respond to those first 2 calls:
respond(1, None) self.respond(1, None)
respond(2, 'a'*8) self.respond(2, 'a'*8)
# After verification, the client requests info: # After verification, the client requests info:
self.assertEqual(parse(transport.pop()), (3, False, 'get_info', ())) self.assertEqual(self.pop(), (3, False, 'get_info', ()))
respond(3, dict(length=42)) self.respond(3, dict(length=42))
# Now we're connected, the cache was initialized, and the # Now we're connected, the cache was initialized, and the
# queued message has been sent: # queued message has been sent:
self.assert_(client.connected.done()) self.assert_(client.connected.done())
self.assertEqual(cache.getLastTid(), 'a'*8) self.assertEqual(cache.getLastTid(), 'a'*8)
self.assertEqual(parse(transport.pop()), (4, False, 'foo', (1, 2))) self.assertEqual(self.pop(), (4, False, 'foo', (1, 2)))
# The wrapper object (ClientStorage) has been notified: # The wrapper object (ClientStorage) has been notified:
wrapper.notify_connected.assert_called_with(client, {'length': 42}) wrapper.notify_connected.assert_called_with(client, {'length': 42})
respond(4, 42) self.respond(4, 42)
self.assertEqual(f1.result(), 42) self.assertEqual(f1.result(), 42)
# Now we can make async calls: # Now we can make async calls:
f2 = self.async('bar', 3, 4) f2 = self.async('bar', 3, 4)
self.assert_(f2.done() and f2.exception() is None) self.assert_(f2.done() and f2.exception() is None)
self.assertEqual(parse(transport.pop()), (0, True, 'bar', (3, 4))) self.assertEqual(self.pop(), (0, True, 'bar', (3, 4)))
# Loading objects gets special handling to leverage the cache. # Loading objects gets special handling to leverage the cache.
loaded = self.load_before(b'1'*8, m64) loaded = self.load_before(b'1'*8, m64)
# The data wasn't in the cache, so we make a server call: # The data wasn't in the cache, so we make a server call:
self.assertEqual(parse(transport.pop()), self.assertEqual(self.pop(), (5, False, 'loadBefore', (b'1'*8, m64)))
(5, False, 'loadBefore', (b'1'*8, m64))) self.respond(5, (b'data', b'a'*8, None))
respond(5, (b'data', b'a'*8, None))
self.assertEqual(loaded.result(), (b'data', b'a'*8, None)) self.assertEqual(loaded.result(), (b'data', b'a'*8, None))
# If we make another request, it will be satisfied from the cache: # If we make another request, it will be satisfied from the cache:
...@@ -158,15 +195,12 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -158,15 +195,12 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertFalse(transport.data) self.assertFalse(transport.data)
# Let's send an invalidation: # Let's send an invalidation:
send('invalidateTransaction', b'b'*8, [b'1'*8]) self.send('invalidateTransaction', b'b'*8, [b'1'*8])
wrapper.invalidateTransaction.assert_called_with(b'b'*8, [b'1'*8])
wrapper.invalidateTransaction.reset_mock()
# Now, if we try to load current again, we'll make a server request. # Now, if we try to load current again, we'll make a server request.
loaded = self.load_before(b'1'*8, m64) loaded = self.load_before(b'1'*8, m64)
self.assertEqual(parse(transport.pop()), self.assertEqual(self.pop(), (6, False, 'loadBefore', (b'1'*8, m64)))
(6, False, 'loadBefore', (b'1'*8, m64))) self.respond(6, (b'data2', b'b'*8, None))
respond(6, (b'data2', b'b'*8, None))
self.assertEqual(loaded.result(), (b'data2', b'b'*8, None)) self.assertEqual(loaded.result(), (b'data2', b'b'*8, None))
# Loading non-current data may also be satisfied from cache # Loading non-current data may also be satisfied from cache
...@@ -178,9 +212,9 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -178,9 +212,9 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertFalse(transport.data) self.assertFalse(transport.data)
loaded = self.load_before(b'1'*8, b'_'*8) loaded = self.load_before(b'1'*8, b'_'*8)
self.assertEqual(parse(transport.pop()), self.assertEqual(self.pop(),
(7, False, 'loadBefore', (b'1'*8, b'_'*8))) (7, False, 'loadBefore', (b'1'*8, b'_'*8)))
respond(7, (b'data0', b'^'*8, b'_'*8)) self.respond(7, (b'data0', b'^'*8, b'_'*8))
self.assertEqual(loaded.result(), (b'data0', b'^'*8, b'_'*8)) self.assertEqual(loaded.result(), (b'data0', b'^'*8, b'_'*8))
# When committing transactions, we need to update the cache # When committing transactions, we need to update the cache
...@@ -202,9 +236,9 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -202,9 +236,9 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
cache.load(b'2'*8) or cache.load(b'2'*8) or
cache.load(b'4'*8)) cache.load(b'4'*8))
self.assertEqual(cache.load(b'1'*8), (b'data2', b'b'*8)) self.assertEqual(cache.load(b'1'*8), (b'data2', b'b'*8))
self.assertEqual(parse(transport.pop()), self.assertEqual(self.pop(),
(8, False, 'tpc_finish', (b'd'*8,))) (8, False, 'tpc_finish', (b'd'*8,)))
respond(8, b'e'*8) self.respond(8, b'e'*8)
self.assertEqual(committed.result(), b'e'*8) self.assertEqual(committed.result(), b'e'*8)
self.assertEqual(cache.load(b'1'*8), None) self.assertEqual(cache.load(b'1'*8), None)
self.assertEqual(cache.load(b'2'*8), ('committed 2', b'e'*8)) self.assertEqual(cache.load(b'2'*8), ('committed 2', b'e'*8))
...@@ -216,8 +250,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -216,8 +250,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
loaded = self.load_before(b'1'*8, m64) loaded = self.load_before(b'1'*8, m64)
f1 = self.call('foo', 1, 2) f1 = self.call('foo', 1, 2)
self.assertFalse(loaded.done() or f1.done()) self.assertFalse(loaded.done() or f1.done())
self.assertEqual(parse(transport.pop()), self.assertEqual(self.pop(), [(9, False, 'loadBefore', (b'1'*8, m64)),
[(9, False, 'loadBefore', (b'1'*8, m64)),
(10, False, 'foo', (1, 2))], (10, False, 'foo', (1, 2))],
) )
exc = TypeError(43) exc = TypeError(43)
...@@ -246,15 +279,15 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -246,15 +279,15 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# protocol: # protocol:
protocol.data_received(sized(b'Z310')) protocol.data_received(sized(b'Z310'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z310') self.assertEqual(self.unsized(transport.pop(2)), b'Z310')
self.assertEqual(parse(transport.pop()), self.assertEqual(self.pop(),
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()), (2, False, 'lastTransaction', ()),
]) ])
self.assertFalse(wrapper.notify_connected.called) self.assertFalse(wrapper.notify_connected.called)
respond(1, None) self.respond(1, None)
respond(2, b'e'*8) self.respond(2, b'e'*8)
self.assertEqual(parse(transport.pop()), (3, False, 'get_info', ())) self.assertEqual(self.pop(), (3, False, 'get_info', ()))
respond(3, dict(length=42)) self.respond(3, dict(length=42))
# Because the server tid matches the cache tid, we're done connecting # Because the server tid matches the cache tid, we're done connecting
wrapper.notify_connected.assert_called_with(client, {'length': 42}) wrapper.notify_connected.assert_called_with(client, {'length': 42})
...@@ -274,8 +307,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -274,8 +307,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertEqual(loop.transport, transport) self.assertEqual(loop.transport, transport)
def test_cache_behind(self): def test_cache_behind(self):
wrapper, cache, loop, client, protocol, transport, send, respond = ( wrapper, cache, loop, client, protocol, transport = self.start()
self.start())
cache.setLastTid(b'a'*8) cache.setLastTid(b'a'*8)
cache.store(b'4'*8, b'a'*8, None, '4 data') cache.store(b'4'*8, b'a'*8, None, '4 data')
...@@ -284,22 +316,20 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -284,22 +316,20 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertFalse(client.connected.done() or transport.data) self.assertFalse(client.connected.done() or transport.data)
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.pop(),
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()), (2, False, 'lastTransaction', ()),
]) ])
respond(1, None) self.respond(1, None)
respond(2, b'e'*8) self.respond(2, b'e'*8)
# We have to verify the cache, so we're not done connecting: # We have to verify the cache, so we're not done connecting:
self.assertFalse(client.connected.done()) self.assertFalse(client.connected.done())
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.pop(), (3, False, 'getInvalidations', (b'a'*8, )))
(3, False, 'getInvalidations', (b'a'*8, ))) self.respond(3, (b'e'*8, [b'4'*8]))
respond(3, (b'e'*8, [b'4'*8]))
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.pop(), (4, False, 'get_info', ()))
(4, False, 'get_info', ())) self.respond(4, dict(length=42))
respond(4, dict(length=42))
# Now that verification is done, we're done connecting # Now that verification is done, we're done connecting
self.assert_(client.connected.done() and not transport.data) self.assert_(client.connected.done() and not transport.data)
...@@ -315,8 +345,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -315,8 +345,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertFalse(wrapper.invalidateCache.called) self.assertFalse(wrapper.invalidateCache.called)
def test_cache_way_behind(self): def test_cache_way_behind(self):
wrapper, cache, loop, client, protocol, transport, send, respond = ( wrapper, cache, loop, client, protocol, transport = self.start()
self.start())
cache.setLastTid(b'a'*8) cache.setLastTid(b'a'*8)
cache.store(b'4'*8, b'a'*8, None, '4 data') cache.store(b'4'*8, b'a'*8, None, '4 data')
...@@ -325,24 +354,22 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -325,24 +354,22 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertFalse(client.connected.done() or transport.data) self.assertFalse(client.connected.done() or transport.data)
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.pop(),
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()), (2, False, 'lastTransaction', ()),
]) ])
respond(1, None) self.respond(1, None)
respond(2, b'e'*8) self.respond(2, b'e'*8)
# We have to verify the cache, so we're not done connecting: # We have to verify the cache, so we're not done connecting:
self.assertFalse(client.connected.done()) self.assertFalse(client.connected.done())
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.pop(), (3, False, 'getInvalidations', (b'a'*8, )))
(3, False, 'getInvalidations', (b'a'*8, )))
# We respond None, indicating that we're too far out of date: # We respond None, indicating that we're too far out of date:
respond(3, None) self.respond(3, None)
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.pop(), (4, False, 'get_info', ()))
(4, False, 'get_info', ())) self.respond(4, dict(length=42))
respond(4, dict(length=42))
# Now that verification is done, we're done connecting # Now that verification is done, we're done connecting
self.assert_(client.connected.done() and not transport.data) self.assert_(client.connected.done() and not transport.data)
...@@ -355,8 +382,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -355,8 +382,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
def test_multiple_addresses(self): def test_multiple_addresses(self):
# We can pass multiple addresses to client constructor # We can pass multiple addresses to client constructor
addrs = [('1.2.3.4', 8200), ('2.2.3.4', 8200)] addrs = [('1.2.3.4', 8200), ('2.2.3.4', 8200)]
wrapper, cache, loop, client, protocol, transport, send, respond = ( wrapper, cache, loop, client, protocol, transport = self.start(
self.start(addrs, ())) addrs, ())
# We haven't connected yet # We haven't connected yet
self.assert_(protocol is None and transport is None) self.assert_(protocol is None and transport is None)
...@@ -381,7 +408,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -381,7 +408,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
transport = loop.transport transport = loop.transport
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
respond(1, None) self.respond(1, None)
# Now, when the first connection fails, it won't be retried, # Now, when the first connection fails, it won't be retried,
# because we're already connected. # because we're already connected.
...@@ -394,19 +421,17 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -394,19 +421,17 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
def test_bad_server_tid(self): def test_bad_server_tid(self):
# If in verification we get a server_tid behing the cache's, make sure # If in verification we get a server_tid behing the cache's, make sure
# we retry the connection later. # we retry the connection later.
wrapper, cache, loop, client, protocol, transport, send, respond = ( wrapper, cache, loop, client, protocol, transport = self.start()
self.start())
cache.store(b'4'*8, b'a'*8, None, '4 data') cache.store(b'4'*8, b'a'*8, None, '4 data')
cache.setLastTid('b'*8) cache.setLastTid('b'*8)
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
parse = self.parse self.assertEqual(self.pop(),
self.assertEqual(parse(transport.pop()),
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()), (2, False, 'lastTransaction', ()),
]) ])
respond(1, None) self.respond(1, None)
respond(2, 'a'*8) self.respond(2, 'a'*8)
self.assertFalse(client.connected.done() or transport.data) self.assertFalse(client.connected.done() or transport.data)
delay, func, args, _ = loop.later.pop(1) # first in later is heartbeat delay, func, args, _ = loop.later.pop(1) # first in later is heartbeat
self.assert_(8 < delay < 10) self.assert_(8 < delay < 10)
...@@ -418,21 +443,21 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -418,21 +443,21 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
transport = loop.transport transport = loop.transport
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(parse(transport.pop()), self.assertEqual(self.pop(),
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()), (2, False, 'lastTransaction', ()),
]) ])
respond(1, None) self.respond(1, None)
respond(2, 'b'*8) self.respond(2, 'b'*8)
self.assertEqual(parse(transport.pop()), (3, False, 'get_info', ())) self.assertEqual(self.pop(), (3, False, 'get_info', ()))
respond(3, dict(length=42)) self.respond(3, dict(length=42))
self.assert_(client.connected.done() and not transport.data) self.assert_(client.connected.done() and not transport.data)
self.assert_(client.ready) self.assert_(client.ready)
def test_readonly_fallback(self): def test_readonly_fallback(self):
addrs = [('1.2.3.4', 8200), ('2.2.3.4', 8200)] addrs = [('1.2.3.4', 8200), ('2.2.3.4', 8200)]
wrapper, cache, loop, client, protocol, transport, send, respond = ( wrapper, cache, loop, client, protocol, transport = self.start(
self.start(addrs, (), read_only=Fallback)) addrs, (), read_only=Fallback)
self.assertTrue(self.is_read_only()) self.assertTrue(self.is_read_only())
...@@ -442,20 +467,20 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -442,20 +467,20 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
# We see that the client tried a writable connection: # We see that the client tried a writable connection:
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.pop(),
(1, False, 'register', ('TEST', False))) (1, False, 'register', ('TEST', False)))
# We respond with a read-only exception: # We respond with a read-only exception:
respond(1, (ReadOnlyError, ReadOnlyError())) self.respond(1, (ReadOnlyError, ReadOnlyError()))
self.assertTrue(self.is_read_only()) self.assertTrue(self.is_read_only())
# The client tries for a read-only connection: # The client tries for a read-only connection:
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.pop(),
[(2, False, 'register', ('TEST', True)), [(2, False, 'register', ('TEST', True)),
(3, False, 'lastTransaction', ()), (3, False, 'lastTransaction', ()),
]) ])
# We respond with successfully: # We respond with successfully:
respond(2, None) self.respond(2, None)
respond(3, 'b'*8) self.respond(3, 'b'*8)
self.assertTrue(self.is_read_only()) self.assertTrue(self.is_read_only())
# At this point, the client is ready and using the protocol, # At this point, the client is ready and using the protocol,
...@@ -466,9 +491,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -466,9 +491,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
connected = client.connected connected = client.connected
# The client asks for info, and we respond: # The client asks for info, and we respond:
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.pop(), (4, False, 'get_info', ()))
(4, False, 'get_info', ())) self.respond(4, dict(length=42))
respond(4, dict(length=42))
self.assert_(connected.done()) self.assert_(connected.done())
...@@ -481,7 +505,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -481,7 +505,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertTrue(self.is_read_only()) self.assertTrue(self.is_read_only())
# We respond and the writable connection succeeds: # We respond and the writable connection succeeds:
respond(1, None) self.respond(1, None)
self.assertFalse(self.is_read_only()) self.assertFalse(self.is_read_only())
# at this point, a lastTransaction request is emitted: # at this point, a lastTransaction request is emitted:
...@@ -501,28 +525,25 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -501,28 +525,25 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertEqual(protocol.read_only, False) self.assertEqual(protocol.read_only, False)
# Now, we finish verification # Now, we finish verification
respond(2, 'b'*8) self.respond(2, 'b'*8)
respond(3, dict(length=42)) self.respond(3, dict(length=42))
self.assert_(client.ready) self.assert_(client.ready)
self.assert_(client.connected.done()) self.assert_(client.connected.done())
def test_invalidations_while_verifying(self): def test_invalidations_while_verifying(self):
# While we're verifying, invalidations are ignored # While we're verifying, invalidations are ignored
wrapper, cache, loop, client, protocol, transport, send, respond = ( wrapper, cache, loop, client, protocol, transport = self.start()
self.start())
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.pop(),
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()), (2, False, 'lastTransaction', ()),
]) ])
respond(1, None) self.respond(1, None)
send('invalidateTransaction', b'b'*8, [b'1'*8]) self.send('invalidateTransaction', b'b'*8, [b'1'*8], called=False)
self.assertFalse(wrapper.invalidateTransaction.called) self.respond(2, b'a'*8)
respond(2, b'a'*8) self.send('invalidateTransaction', b'c'*8, [b'1'*8], no_output=False)
send('invalidateTransaction', b'c'*8, [b'1'*8]) self.assertEqual(self.pop(), (3, False, 'get_info', ()))
wrapper.invalidateTransaction.assert_called_with(b'c'*8, [b'1'*8])
wrapper.invalidateTransaction.reset_mock()
# We'll disconnect: # We'll disconnect:
protocol.connection_lost(Exception("lost")) protocol.connection_lost(Exception("lost"))
...@@ -535,17 +556,15 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -535,17 +556,15 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.pop(),
[(1, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()), (2, False, 'lastTransaction', ()),
]) ])
respond(1, None) self.respond(1, None)
send('invalidateTransaction', b'd'*8, [b'1'*8]) self.send('invalidateTransaction', b'd'*8, [b'1'*8], called=False)
self.assertFalse(wrapper.invalidateTransaction.called) self.respond(2, b'c'*8)
respond(2, b'c'*8) self.send('invalidateTransaction', b'e'*8, [b'1'*8], no_output=False)
send('invalidateTransaction', b'e'*8, [b'1'*8]) self.assertEqual(self.pop(), (3, False, 'get_info', ()))
wrapper.invalidateTransaction.assert_called_with(b'e'*8, [b'1'*8])
wrapper.invalidateTransaction.reset_mock()
def test_flow_control(self): def test_flow_control(self):
# When sending a lot of data (blobs), we don't want to fill up # When sending a lot of data (blobs), we don't want to fill up
...@@ -553,8 +572,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -553,8 +572,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# seems a bit complicated. We'd rather pass an iterator that's # seems a bit complicated. We'd rather pass an iterator that's
# consumed as we can. # consumed as we can.
wrapper, cache, loop, client, protocol, transport, send, respond = ( wrapper, cache, loop, client, protocol, transport = self.start(
self.start(finish_start=True)) finish_start=True)
# Give the transport a small capacity: # Give the transport a small capacity:
transport.capacity = 2 transport.capacity = 2
...@@ -564,52 +583,45 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -564,52 +583,45 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.async('splat') self.async('splat')
# The first 2 were sent, but the remaining were queued. # The first 2 were sent, but the remaining were queued.
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.pop(),
[(0, True, 'foo', ()), (0, True, 'bar', ())]) [(0, True, 'foo', ()), (0, True, 'bar', ())])
# But popping them allowed sending to resume: # But popping them allowed sending to resume:
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.pop(),
[(0, True, 'baz', ()), (0, True, 'splat', ())]) [(0, True, 'baz', ()), (0, True, 'splat', ())])
# This is especially handy with iterators: # This is especially handy with iterators:
self.async_iter((name, ()) for name in 'abcde') self.async_iter((name, ()) for name in 'abcde')
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.pop(), [(0, True, 'a', ()), (0, True, 'b', ())])
[(0, True, 'a', ()), (0, True, 'b', ())]) self.assertEqual(self.pop(), [(0, True, 'c', ()), (0, True, 'd', ())])
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.pop(), (0, True, 'e', ()))
[(0, True, 'c', ()), (0, True, 'd', ())]) self.assertEqual(self.pop(), [])
self.assertEqual(self.parse(transport.pop()),
(0, True, 'e', ()))
self.assertEqual(self.parse(transport.pop()),
[])
def test_bad_protocol(self): def test_bad_protocol(self):
wrapper, cache, loop, client, protocol, transport, send, respond = ( wrapper, cache, loop, client, protocol, transport = self.start()
self.start())
with mock.patch("ZEO.asyncio.client.logger.error") as error: with mock.patch("ZEO.asyncio.client.logger.error") as error:
self.assertFalse(error.called) self.assertFalse(error.called)
protocol.data_received(sized(b'Z200')) protocol.data_received(sized(b'Z200'))
self.assert_(isinstance(error.call_args[0][1], self.assert_(isinstance(error.call_args[0][1], ProtocolError))
ZEO.Exceptions.ProtocolError))
def test_get_peername(self): def test_get_peername(self):
wrapper, cache, loop, client, protocol, transport, send, respond = ( wrapper, cache, loop, client, protocol, transport = self.start(
self.start(finish_start=True)) finish_start=True)
self.assertEqual(client.get_peername(), '1.2.3.4') self.assertEqual(client.get_peername(), '1.2.3.4')
def test_call_async_from_same_thread(self): def test_call_async_from_same_thread(self):
# There are a few (1?) cases where we call into client storage # There are a few (1?) cases where we call into client storage
# where it needs to call back asyncronously. Because we're # where it needs to call back asyncronously. Because we're
# calling from the same thread, we don't need to use a futurte. # calling from the same thread, we don't need to use a futurte.
wrapper, cache, loop, client, protocol, transport, send, respond = ( wrapper, cache, loop, client, protocol, transport = self.start(
self.start(finish_start=True)) finish_start=True)
client.call_async_from_same_thread('foo', 1) client.call_async_from_same_thread('foo', 1)
self.assertEqual(self.parse(transport.pop()), (0, True, 'foo', (1, ))) self.assertEqual(self.pop(), (0, True, 'foo', (1, )))
def test_ClientDisconnected_on_call_timeout(self): def test_ClientDisconnected_on_call_timeout(self):
wrapper, cache, loop, client, protocol, transport, send, respond = ( wrapper, cache, loop, client, protocol, transport = self.start()
self.start())
self.wait_for_result = super().wait_for_result self.wait_for_result = super().wait_for_result
self.assertRaises(ClientDisconnected, self.call, 'foo') self.assertRaises(ClientDisconnected, self.call, 'foo')
client.ready = False client.ready = False
...@@ -620,34 +632,35 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -620,34 +632,35 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# that caused it to fail badly if errors were raised while # that caused it to fail badly if errors were raised while
# handling data. # handling data.
wrapper, cache, loop, client, protocol, transport, send, respond = ( wrapper, cache, loop, client, protocol, transport =self.start(
self.start(finish_start=True)) finish_start=True)
wrapper.receiveBlobStart.side_effect = ValueError('test') wrapper.receiveBlobStart.side_effect = ValueError('test')
chunk = 'x' * 99999 chunk = 'x' * 99999
try: try:
loop.protocol.data_received(b''.join((
sized(pickle.dumps(
(0, True, 'receiveBlobStart', ('oid', 'serial')), 3)),
sized(pickle.dumps(
(0, True, 'receiveBlobChunk',
('oid', 'serial', chunk)), 3)),
)))
except ValueError:
pass
loop.protocol.data_received( loop.protocol.data_received(
sized(pickle.dumps( sized(
(0, True, 'receiveBlobStop', ('oid', 'serial')), 3)), self.encode(0, True, 'receiveBlobStart', ('oid', 'serial'))
) +
sized(
self.encode(
0, True, 'receiveBlobChunk', ('oid', 'serial', chunk))
)
) )
except ValueError:
pass
loop.protocol.data_received(sized(
self.encode(0, True, 'receiveBlobStop', ('oid', 'serial'))
))
wrapper.receiveBlobChunk.assert_called_with('oid', 'serial', chunk) wrapper.receiveBlobChunk.assert_called_with('oid', 'serial', chunk)
wrapper.receiveBlobStop.assert_called_with('oid', 'serial') wrapper.receiveBlobStop.assert_called_with('oid', 'serial')
def test_heartbeat(self): def test_heartbeat(self):
# Protocols run heartbeats on a configurable (sort of) # Protocols run heartbeats on a configurable (sort of)
# heartbeat interval, which defaults to every 60 seconds. # heartbeat interval, which defaults to every 60 seconds.
wrapper, cache, loop, client, protocol, transport, send, respond = ( wrapper, cache, loop, client, protocol, transport = self.start(
self.start(finish_start=True)) finish_start=True)
delay, func, args, handle = loop.later.pop() delay, func, args, handle = loop.later.pop()
self.assertEqual( self.assertEqual(
...@@ -658,7 +671,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -658,7 +671,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# The heartbeat function sends heartbeat data and reschedules itself. # The heartbeat function sends heartbeat data and reschedules itself.
func() func()
self.assertEqual(self.parse(transport.pop()), (-1, 0, '.reply', None)) self.assertEqual(self.pop(), (-1, 0, '.reply', None))
self.assertTrue(protocol.heartbeat_handle != handle) self.assertTrue(protocol.heartbeat_handle != handle)
delay, func, args, handle = loop.later.pop() delay, func, args, handle = loop.later.pop()
...@@ -672,27 +685,6 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -672,27 +685,6 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
protocol.connection_lost(None) protocol.connection_lost(None)
self.assertTrue(handle.cancelled) self.assertTrue(handle.cancelled)
def unsized(self, data, unpickle=False):
result = []
while data:
size, message, *data = data
self.assertEqual(struct.unpack(">I", size)[0], len(message))
if unpickle:
message = pickle.loads(message)
result.append(message)
if len(result) == 1:
result = result[0]
return result
def parse(self, data):
return self.unsized(data, True)
def response(*data):
return sized(pickle.dumps(data, 3))
def sized(message):
return struct.pack(">I", len(message)) + message
class MemoryCache: class MemoryCache:
...@@ -745,6 +737,106 @@ class MemoryCache: ...@@ -745,6 +737,106 @@ class MemoryCache:
def setLastTid(self, tid): def setLastTid(self, tid):
self.last_tid = tid self.last_tid = tid
class ServerTests(Base, setupstack.TestCase):
# The server side of things is pretty simple compared to the
# client, because it's the clien't job to make and keep
# connections. Servers are pretty passive.
def connect(self, finish=False):
protocol = server_protocol()
self.loop = protocol.loop
self.target = protocol.zeo_storage
if finish:
self.assertEqual(self.pop(parse=False), best_protocol_version)
protocol.data_received(sized(b'Z4'))
return protocol
message_id = 0
target = None
def call(self, meth, *args, **kw):
if kw:
expect = kw.pop('expect', self)
target = kw.pop('target', self.target)
self.assertFalse(kw)
if target is not None:
target = getattr(target, meth)
if expect is not self:
target.return_value = expect
self.message_id += 1
self.loop.protocol.data_received(
sized(self.encode(self.message_id, False, meth, args)))
if target is not None:
target.assert_called_once_with(*args)
target.reset_mock()
if expect is not self:
self.assertEqual(self.pop(),
(self.message_id, False, '.reply', expect))
def testServerBasics(self):
# A simple listening thread accepts connections. It creats
# asyncio connections by calling ZEO.asyncio.new_connection:
protocol = self.connect()
self.assertFalse(protocol.zeo_storage.notify_connected.called)
# The server sends it's protocol.
self.assertEqual(self.pop(parse=False), best_protocol_version)
# The client sends it's protocol:
protocol.data_received(sized(b'Z4'))
self.assertEqual(protocol.protocol_version, b'Z4')
protocol.zeo_storage.notify_connected.assert_called_once_with(protocol)
# The client registers:
self.call('register', False, expect=None)
# It does other things, like, send hearbeats:
protocol.data_received(sized(b'(J\xff\xff\xff\xffK\x00U\x06.replyNt.'))
# The client can make async calls:
self.send('register')
# Let's close the connection
self.assertFalse(protocol.zeo_storage.notify_disconnected.called)
protocol.connection_lost(None)
protocol.zeo_storage.notify_disconnected.assert_called_once_with()
def test_invalid_methods(self):
protocol = self.connect(True)
protocol.zeo_storage.notify_connected.assert_called_once_with(protocol)
# If we try to call a methid that isn't in the protocol's
# white list, it will disconnect:
self.assertFalse(protocol.loop.transport.closed)
self.call('foo', target=None)
self.assertTrue(protocol.loop.transport.closed)
def server_protocol(zeo_storage=None,
protocol_version=None,
addr=('1.2.3.4', '42'),
):
if zeo_storage is None:
zeo_storage = mock.Mock()
loop = Loop()
sock = () # anything not None
new_connection(loop, addr, sock, zeo_storage)
if protocol_version:
loop.protocol.data_received(sized(protocol_version))
return loop.protocol
def response(*data):
return sized(self.encode(*data))
def sized(message):
return struct.pack(">I", len(message)) + message
class Logging: class Logging:
def __init__(self, level=logging.ERROR): def __init__(self, level=logging.ERROR):
...@@ -762,5 +854,6 @@ class Logging: ...@@ -762,5 +854,6 @@ class Logging:
def test_suite(): def test_suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(AsyncTests)) suite.addTest(unittest.makeSuite(ClientTests))
suite.addTest(unittest.makeSuite(ServerTests))
return suite return suite
##############################################################################
#
# Copyright (c) 2003 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
_auth_modules = {}
def get_module(name):
if name == 'sha':
from auth_sha import StorageClass, SHAClient, Database
return StorageClass, SHAClient, Database
elif name == 'digest':
from .auth_digest import StorageClass, DigestClient, DigestDatabase
return StorageClass, DigestClient, DigestDatabase
else:
return _auth_modules.get(name)
def register_module(name, storage_class, client, db):
if name in _auth_modules:
raise TypeError("%s is already registred" % name)
_auth_modules[name] = storage_class, client, db
##############################################################################
#
# Copyright (c) 2003 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
"""Digest authentication for ZEO
This authentication mechanism follows the design of HTTP digest
authentication (RFC 2069). It is a simple challenge-response protocol
that does not send passwords in the clear, but does not offer strong
security. The RFC discusses many of the limitations of this kind of
protocol.
Guard the password database as if it contained plaintext passwords.
It stores the hash of a username and password. This does not expose
the plaintext password, but it is sensitive nonetheless. An attacker
with the hash can impersonate the real user. This is a limitation of
the simple digest scheme.
HTTP is a stateless protocol, and ZEO is a stateful protocol. The
security requirements are quite different as a result. The HTTP
protocol uses a nonce as a challenge. The ZEO protocol requires a
separate session key that is used for message authentication. We
generate a second nonce for this purpose; the hash of nonce and
user/realm/password is used as the session key.
TODO: I'm not sure if this is a sound approach; SRP would be preferred.
"""
import os
import random
import struct
import time
from ZEO.auth.base import Database, Client
from ZEO.StorageServer import ZEOStorage
from ZEO.Exceptions import AuthError
from ZEO.hash import sha1
def get_random_bytes(n=8):
try:
b = os.urandom(n)
except NotImplementedError:
L = [chr(random.randint(0, 255)) for i in range(n)]
b = b"".join(L)
return b
def hexdigest(s):
return sha1(s.encode()).hexdigest()
class DigestDatabase(Database):
def __init__(self, filename, realm=None):
Database.__init__(self, filename, realm)
# Initialize a key used to build the nonce for a challenge.
# We need one key for the lifetime of the server, so it
# is convenient to store in on the database.
self.noncekey = get_random_bytes(8)
def _store_password(self, username, password):
dig = hexdigest("%s:%s:%s" % (username, self.realm, password))
self._users[username] = dig
def session_key(h_up, nonce):
# The hash itself is a bit too short to be a session key.
# HMAC wants a 64-byte key. We don't want to use h_up
# directly because it would never change over time. Instead
# use the hash plus part of h_up.
return (sha1(("%s:%s" % (h_up, nonce)).encode('latin-1')).digest() +
h_up.encode('utf-8')[:44])
class StorageClass(ZEOStorage):
def set_database(self, database):
assert isinstance(database, DigestDatabase)
self.database = database
self.noncekey = database.noncekey
def _get_time(self):
# Return a string representing the current time.
t = int(time.time())
return struct.pack("i", t)
def _get_nonce(self):
# RFC 2069 recommends a nonce of the form
# H(client-IP ":" time-stamp ":" private-key)
dig = sha1()
dig.update(str(self.connection.addr).encode('latin-1'))
dig.update(self._get_time())
dig.update(self.noncekey)
return dig.hexdigest()
def auth_get_challenge(self):
"""Return realm, challenge, and nonce."""
self._challenge = self._get_nonce()
self._key_nonce = self._get_nonce()
return self.auth_realm, self._challenge, self._key_nonce
def auth_response(self, resp):
# verify client response
user, challenge, response = resp
# Since zrpc is a stateful protocol, we just store the nonce
# we sent to the client. It will need to generate a new
# nonce for a new connection anyway.
if self._challenge != challenge:
raise ValueError("invalid challenge")
# lookup user in database
h_up = self.database.get_password(user)
# regeneration resp from user, password, and nonce
check = hexdigest("%s:%s" % (h_up, challenge))
if check == response:
self.connection.setSessionKey(session_key(h_up, self._key_nonce))
return self._finish_auth(check == response)
extensions = [auth_get_challenge, auth_response]
class DigestClient(Client):
extensions = ["auth_get_challenge", "auth_response"]
def start(self, username, realm, password):
_realm, challenge, nonce = self.stub.auth_get_challenge()
if _realm != realm:
raise AuthError("expected realm %r, got realm %r"
% (_realm, realm))
h_up = hexdigest("%s:%s:%s" % (username, realm, password))
resp_dig = hexdigest("%s:%s" % (h_up, challenge))
result = self.stub.auth_response((username, challenge, resp_dig))
if result:
return session_key(h_up, nonce)
else:
return None
##############################################################################
#
# Copyright (c) 2003 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
"""Base classes for defining an authentication protocol.
Database -- abstract base class for password database
Client -- abstract base class for authentication client
"""
from __future__ import print_function
from __future__ import print_function
import os
from ZEO.hash import sha1
class Client:
# Subclass should override to list the names of methods that
# will be called on the server.
extensions = []
def __init__(self, stub):
self.stub = stub
for m in self.extensions:
setattr(self.stub, m, self.stub.extensionMethod(m))
def sort(L):
"""Sort a list in-place and return it."""
L.sort()
return L
class Database:
"""Abstracts a password database.
This class is used both in the authentication process (via
get_password()) and by client scripts that manage the password
database file.
The password file is a simple, colon-separated text file mapping
usernames to password hashes. The hashes are SHA hex digests
produced from the password string.
"""
realm = None
def __init__(self, filename, realm=None):
"""Creates a new Database
filename: a string containing the full pathname of
the password database file. Must be readable by the user
running ZEO. Must be writeable by any client script that
accesses the database.
realm: the realm name (a string)
"""
self._users = {}
self.filename = filename
self.load()
if realm:
if self.realm and self.realm != realm:
raise ValueError("Specified realm %r differs from database "
"realm %r" % (realm or '', self.realm))
else:
self.realm = realm
def save(self, fd=None):
filename = self.filename
needs_closed = False
if not fd:
fd = open(filename, 'w')
needs_closed = True
try:
if self.realm:
print("realm", self.realm, file=fd)
for username in sorted(self._users.keys()):
print("%s: %s" % (username, self._users[username]), file=fd)
finally:
if needs_closed:
fd.close()
def load(self):
filename = self.filename
if not filename:
return
if not os.path.exists(filename):
return
with open(filename) as fd:
L = fd.readlines()
if not L:
return
if L[0].startswith("realm "):
line = L.pop(0).strip()
self.realm = line[len("realm "):]
for line in L:
username, hash = line.strip().split(":", 1)
self._users[username] = hash.strip()
def _store_password(self, username, password):
self._users[username] = self.hash(password)
def get_password(self, username):
"""Returns password hash for specified username.
Callers must check for LookupError, which is raised in
the case of a non-existent user specified."""
if username not in self._users:
raise LookupError("No such user: %s" % username)
return self._users[username]
def hash(self, s):
return sha1(s.encode()).hexdigest()
def add_user(self, username, password):
if username in self._users:
raise LookupError("User %s already exists" % username)
self._store_password(username, password)
def del_user(self, username):
if username not in self._users:
raise LookupError("No such user: %s" % username)
del self._users[username]
def change_password(self, username, password):
if username not in self._users:
raise LookupError("No such user: %s" % username)
self._store_password(username, password)
"""HMAC (Keyed-Hashing for Message Authentication) Python module.
Implements the HMAC algorithm as described by RFC 2104.
"""
from six.moves import map
from six.moves import zip
def _strxor(s1, s2):
"""Utility method. XOR the two strings s1 and s2 (must have same length).
"""
return "".join(map(lambda x, y: chr(ord(x) ^ ord(y)), s1, s2))
# The size of the digests returned by HMAC depends on the underlying
# hashing module used.
digest_size = None
class HMAC:
"""RFC2104 HMAC class.
This supports the API for Cryptographic Hash Functions (PEP 247).
"""
def __init__(self, key, msg = None, digestmod = None):
"""Create a new HMAC object.
key: key for the keyed hash object.
msg: Initial input for the hash, if provided.
digestmod: A module supporting PEP 247. Defaults to the md5 module.
"""
if digestmod is None:
import md5
digestmod = md5
self.digestmod = digestmod
self.outer = digestmod.new()
self.inner = digestmod.new()
self.digest_size = digestmod.digest_size
blocksize = 64
ipad = "\x36" * blocksize
opad = "\x5C" * blocksize
if len(key) > blocksize:
key = digestmod.new(key).digest()
key = key + chr(0) * (blocksize - len(key))
self.outer.update(_strxor(key, opad))
self.inner.update(_strxor(key, ipad))
if msg is not None:
self.update(msg)
## def clear(self):
## raise NotImplementedError("clear() method not available in HMAC.")
def update(self, msg):
"""Update this hashing object with the string msg.
"""
self.inner.update(msg)
def copy(self):
"""Return a separate copy of this hashing object.
An update to this copy won't affect the original object.
"""
other = HMAC("")
other.digestmod = self.digestmod
other.inner = self.inner.copy()
other.outer = self.outer.copy()
return other
def digest(self):
"""Return the hash value of this hashing object.
This returns a string containing 8-bit data. The object is
not altered in any way by this function; you can continue
updating the object after calling this function.
"""
h = self.outer.copy()
h.update(self.inner.digest())
return h.digest()
def hexdigest(self):
"""Like digest(), but returns a string of hexadecimal digits instead.
"""
return "".join([hex(ord(x))[2:].zfill(2)
for x in tuple(self.digest())])
def new(key, msg = None, digestmod = None):
"""Create a new hashing object and return it.
key: The starting key for the hash.
msg: if available, will immediately be hashed into the object's starting
state.
You can now feed arbitrary strings into the object using its update()
method, and can ask for the hash value at any time by calling its digest()
method.
"""
return HMAC(key, msg, digestmod)
...@@ -55,22 +55,6 @@ ...@@ -55,22 +55,6 @@
</description> </description>
</key> </key>
<key name="monitor-address" datatype="socket-binding-address"
required="no">
<description>
The address at which the monitor server should listen. If
specified, a monitor server is started. The monitor server
provides server statistics in a simple text format. This can
be in the form 'host:port' to signify a TCP/IP connection or a
pathname string to signify a Unix domain socket connection (at
least one '/' is required). A hostname may be a DNS name or a
dotted IP address. If the hostname is omitted, the platform's
default behavior is used when binding the listening socket (''
is passed to socket.bind() as the hostname portion of the
address).
</description>
</key>
<key name="transaction-timeout" datatype="integer" <key name="transaction-timeout" datatype="integer"
required="no"> required="no">
<description> <description>
...@@ -81,28 +65,6 @@ ...@@ -81,28 +65,6 @@
</description> </description>
</key> </key>
<key name="authentication-protocol" required="no">
<description>
The name of the protocol used for authentication. The
only protocol provided with ZEO is "digest," but extensions
may provide other protocols.
</description>
</key>
<key name="authentication-database" required="no">
<description>
The path of the database containing authentication credentials.
</description>
</key>
<key name="authentication-realm" required="no">
<description>
The authentication realm of the server. Some authentication
schemes use a realm to identify the logical set of usernames
that are accepted by this server.
</description>
</key>
<key name="pid-filename" datatype="existing-dirpath" <key name="pid-filename" datatype="existing-dirpath"
required="no"> required="no">
<description> <description>
......
...@@ -24,7 +24,8 @@ class StaleCache(object): ...@@ -24,7 +24,8 @@ class StaleCache(object):
class IClientCache(zope.interface.Interface): class IClientCache(zope.interface.Interface):
"""Client cache interface. """Client cache interface.
Note that caches need not be thread safe. Note that caches need not be thread safe, fpr the most part,
except for getLastTid, which may be called from multiple threads.
""" """
def close(): def close():
...@@ -73,6 +74,9 @@ class IClientCache(zope.interface.Interface): ...@@ -73,6 +74,9 @@ class IClientCache(zope.interface.Interface):
"""Get the last tid seen by the cache """Get the last tid seen by the cache
This is the cached last tid we've seen from the server. This is the cached last tid we've seen from the server.
This method may be called from multiple threads. (It's assumed
to be trivial.)
""" """
def setLastTid(tid): def setLastTid(tid):
......
...@@ -59,7 +59,6 @@ class StorageStats: ...@@ -59,7 +59,6 @@ class StorageStats:
self.commits = 0 self.commits = 0
self.aborts = 0 self.aborts = 0
self.active_txns = 0 self.active_txns = 0
self.verifying_clients = 0
self.lock_time = None self.lock_time = None
self.conflicts = 0 self.conflicts = 0
self.conflicts_resolved = 0 self.conflicts_resolved = 0
...@@ -114,79 +113,3 @@ class StorageStats: ...@@ -114,79 +113,3 @@ class StorageStats:
print("Stores:", self.stores, file=f) print("Stores:", self.stores, file=f)
print("Conflicts:", self.conflicts, file=f) print("Conflicts:", self.conflicts, file=f)
print("Conflicts resolved:", self.conflicts_resolved, file=f) print("Conflicts resolved:", self.conflicts_resolved, file=f)
class StatsClient(asyncore.dispatcher):
def __init__(self, sock, addr):
asyncore.dispatcher.__init__(self, sock)
self.buf = []
self.closed = 0
def close(self):
self.closed = 1
# The socket is closed after all the data is written.
# See handle_write().
def write(self, s):
self.buf.append(s)
def writable(self):
return len(self.buf)
def readable(self):
return 0
def handle_write(self):
s = "".join(self.buf)
self.buf = []
n = self.socket.send(s.encode('ascii'))
if n < len(s):
self.buf.append(s[:n])
if self.closed and not self.buf:
asyncore.dispatcher.close(self)
class StatsServer(asyncore.dispatcher):
StatsConnectionClass = StatsClient
def __init__(self, addr, stats):
asyncore.dispatcher.__init__(self)
self.addr = addr
self.stats = stats
if type(self.addr) == tuple:
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
else:
self.create_socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.set_reuse_addr()
logger = logging.getLogger('ZEO.monitor')
logger.info("listening on %s", repr(self.addr))
self.bind(self.addr)
self.listen(5)
def writable(self):
return 0
def readable(self):
return 1
def handle_accept(self):
try:
sock, addr = self.accept()
except socket.error:
return
f = self.StatsConnectionClass(sock, addr)
self.dump(f)
f.close()
def dump(self, f):
print("ZEO monitor server version %s" % zeo_version, file=f)
print(time.ctime(), file=f)
print(file=f)
L = sorted(self.stats.keys())
for k in L:
stats = self.stats[k]
print("Storage:", k, file=f)
stats.dump(f)
print(file=f)
...@@ -10,133 +10,212 @@ An object history is a sequence of object revisions. Each revision has ...@@ -10,133 +10,212 @@ An object history is a sequence of object revisions. Each revision has
a tid, which is essentially a time stamp. a tid, which is essentially a time stamp.
We load objects using either ``load``, which returns the current We load objects using either ``load``, which returns the current
object. or loadBefore, which returns the object before a specific time/tid. object. or ``loadBefore``, which returns the object before a specific time/tid.
When we cache revisions, we record the tid and the next/end tid, which When we cache revisions, we record the tid and the next/end tid, which
may be None. The end tid is important for choosing a revision for may be None. The end tid is important for choosing a revision for
loadBefore, as well as for determining whether a cached value is ``loadBefore``, as well as for determining whether a cached value is
current, for load. current, for ``load``.
Because the client and server are multi-threaded, the client may see Because the client and server are multi-threaded, the client may see
data out of order. Let's consider some scenarios. In these data out of order. Let's consider some scenarios. In these
scenarios, we'll consider a single object with revisions t1, t2, etc. scenarios
We consider loading pretty generically, as bath load and loadBefore
are similar in that they may have data about current revisions.
Scenarios Scenarios
========= =========
S1 When considering ordering scenarioes, we'll consider 2 different
Client sees load results before earlier invalidations client behaviors, traditional (T) and loadBefore (B).
- server commits t1 The *traditional* behaviors is that used in ZODB 4. It uses the storage
``load(oid)`` method to load objects if it hasn't seen an invalidation
for the object. If it has seen an invalidation, it uses
``loadBefore(oid, START)``, where ``START`` is the transaction time of
the first invalidation it's seen. If it hasn't seen an invalidation
*for an object*, it uses ``load(oid)`` and then checks again for an
invalidation. If it sees an invalidation, then it retries using
``loadBefore``. This approach **assumes that invalidations for a tid
are returned before loads for a tid**.
- server commits t2 The *loadBefore* behavior, used in ZODB5, always determines
transaction start time, ``START`` at the beginning of a transaction by
calling the storage's ``sync`` method and then querying the storage's
``lastTransaction`` method (and adding 1). It loads objects
exclusively using ``loadBefore(oid, START)``.
- client makes load request, server loads t2 Scenario 1, Invalidations seen after loads for transaction
----------------------------------------------------------
- client gets load result for t2 This scenario could occur because the commits are for a different
client, and a hypothetical; server doesn't block loads while
committing, or sends invalidations in a way that might delay them (but
not send them out of order).
- client gets invalidation for t1, client should ignore T1
- client gets invalidation for t2, client should ignore - client starts a transaction
This scenario could occur because the commits are for a different - client load(O1) gets O1-T1
client, and a hypothetical; server doesn't block loads while
committing. This situation is pretty easy to deal with, as we just
ignore invalidations for earlier revisions.
Note that invalidations will never come out of order from the server. - client load(O2)
S2 - Server commits O2-T2
Client sees load results before finish results (for another client thread)
- Client commits, server commits t1 - Server loads (O2-T2)
- Client commits, server commits t2 - Client gets O2-T2, updates the client cache, and completes load
- Client makes load request, server reads t2. - Client sees invalidation for O2-T2. If the
client is smart, it doesn't update the cache.
- Client receives t2 in load result. The transaction now has inconsistent data, because it should have
loaded whatever O2 was before T2. Because the invalidation came
in after O2 was loaded, the load was unaffected.
- Client receives t1 in tpc_finish result, doesn't invalidate anything B1
- Client receives t2 in tpc_finish result, doesn't invalidate anything - client starts a transaction. Sets START to T1+1
This scenario is equivalent to S1. - client loadBefore(O1, T1+1) gets O1-T1, T1, None
S3 - client loadBefore(O2, T1+1)
Client sees invalidations before load results.
- Client loads, storage reads t1. - Server commits O2-T2
- server commits t2 - Server loadBefore(O2, T1+1) -> O2-T0-T2
- Client receives invalidation for t2. (assuming that the revision of O2 before T2 was T0)
- Client receives load result for t1. - Client gets O2-T0-T2, updates cache.
This scenario is worrisome because the data that needs to be - Client sees invalidation for O2-T2. No update to the cache is
invalidated isn't present when the invalidation arrives. necessary.
S4 In this scenario, loadBefore prevents reading incorrect data.
Client sees commit results before load results.
- Client loads, storage reads t1. A variation on this scenario is that client sees invalidations
tpc_finish in another thread after loads for the same transaction.
- Client commits, storage commits t2. Scenario 2, Client sees invalidations for later transaction before load result
------------------------------------------------------------------------------
- Client receives t2 in tpc_finish result. T2
- Client receives load result for t1. - client starts a transaction
This scenario is equivalent to S3. - client load(O1) gets O1-T1
Implementation notes - client load(O2)
===================
First, it's worth noting that the server sends data to the client in - Server loads (O2-T0)
correct order with respect to loads and invalidations (or tpc_finish
results). This is a consequence of the fact that invalidations are
sent in a callback called when the storage lock is held, blocking
loadd while committing, and the fact that client requests, for a
particular client, are handled by a single thread on the server.
Invalidations are sent from different threads that clients. Outgoing - Server commits O2-T2
data is queued, however, using Python lists, which are protected by
the GIL. This means that the serialization provided though storage - Client sees invalidation for O2-T2. O2 isn't in the cache, so
locks is preserved by the way that server outputs are queued. nothing to do.
- Client gets O2-T0, updates the client cache, and completes load
The cache is now incorrect. It has O2-T0-None, meaning it thinks
O2-T0 is current.
The transaction is OK, because it got a consistent value for O2.
B2
- client starts a transaction. Sets START to T1+1
- client loadBefore(O1, T1+1) gets O1-T1, T1, None
- client loadBefore(O2, T1+1)
- Server loadBefore(O2, T1+1) -> O2-T0-None
- Server commits O2-T2
- Client sees invalidation for O2-T2. O2 isn't in the cache, so
nothing to do.
- Client gets O2-T0-None, and completes load
ZEO 4 doesn't cache loadBefore results with no ending transaction.
Assume ZEO 5 updates the client cache.
For ZEO 5, the cache is now incorrect. It has O2-T0-None, meaning
it thinks O2-T0 is current.
The transaction is OK, because it got a consistent value for O2.
In this case, ``loadBefore`` didn't prevent an invalid cache value.
Scenario 3, client sees invalidation after lastTransaction result
------------------------------------------------------------------
(This doesn't effect the traditional behavior.)
B3
- The client cache has a last tid of T1.
- ZODB calls sync() then calls lastTransaction. Is so configured,
ZEO calls lastTransaction on the server. This is mainly to make a
round trip to get in-flight invalidations. We don't necessarily
need to use the value. In fact, in protocol 5, we could just add a
sync method that just makes a round trip, but does nothing else.
- Server commits O1-T2, O2-T2.
- Server reads and returns T2. (It doesn't mater what it returns
- client sets START to T1+1, because lastTransaction is based on
what's in the cache, which is based on invalidations.
- Client loadBefore(O1, T2+1), finds O1-T1-None in cache and uses
it.
- Client gets invalidation for O1-T2. Updates cache to O1-T1-T2.
- Client loadBefore(O2, T1+1), gets O2-T1-None
This is OK, as long as the client doesn't do anything with the
lastTransaction result in ``sync``.
Implementation notes
===================
ZEO 4 ZEO 4
----- -----
In ZEO 4, invalidations and loads are handled by separate The ZEO 4 server sends data to the client in correct order with
respect to loads and invalidations (or tpc_finish results). This is a
consequence of the fact that invalidations are sent in a callback
called when the storage lock is held, blocking loads while committing,
and, fact that client requests, for a particular client, are
handled by a single thread on the server, and that all output for a
client goes through a thread-safe queue.
Invalidations are sent from different threads than clients. Outgoing
data is queued, however, using Python lists, which are protected by
the GIL. This means that the serialization provided though storage
locks is preserved by the way that server outputs are queued. **The
queueing mechanism is in part a consequence of the way asyncore, used
by ZEO4, works.
In ZEO 4 clients, invalidations and loads are handled by separate
threads. This means that even though data arive in order, they may not threads. This means that even though data arive in order, they may not
be processed in order, be processed in order,
S1 T1
The existing servers mitigate this by blocking loads while The existing servers mitigate this by blocking loads while
committing. On the client, this is still a potential issue because loads committing. On the client, this is still a potential issue because loads
and invalidations are handled by separate threads. and invalidations are handled by separate threads, however, locks are
used on the client to assure that invalidations are processed before
blocked loads complete.
The client cache is conservative because it always forgets current data in T2
memory when it sees an invalidation data for an object. Existing storage servers serialize commits (and thus sending of
invalidations) and loads. As with scenario T1, threading on the
The client gets this scenario wrong, in an edge case, because it
checks for invalidations matching the current tid, but not
invalidations before the current tid. If the thread handling
invalidations was slow enough for this scenario to occur, then the
cache would end up with an end tid < a starting tid. This is
probably very unlikely.
S2
The existing clients prevent this by serializing commits with each
other (only one at a time on the client) and with loads.
S3
Existing storages serialize commits (and thus sending of
invalidations) and loads. As with scenario S1, threading on the
client can cause load results and invalidations to be processed out client can cause load results and invalidations to be processed out
of order. To mitigate this, the client uses a load lock to track of order. To mitigate this, the client uses a load lock to track
when loads are invalidated while in flight and doesn't save to the when loads are invalidated while in flight and doesn't save to the
...@@ -145,9 +224,10 @@ S3 ...@@ -145,9 +224,10 @@ S3
to the cache unnecessarily, if the invalidation is for a revision to the cache unnecessarily, if the invalidation is for a revision
before the one that was loaded. before the one that was loaded.
S4 B2
As with S2, clients mitigate this by preventing simultaneous loads Here, we avoid incorrect returned values and incorrect cache at the
and commits. cost of caching nothing. For this reason, a future ZEO 4 revision
will require ZODB 4 or earlier.
ZEO 5 ZEO 5
----- -----
...@@ -156,34 +236,39 @@ In ZEO(/ZODB) 5, we want to get more concurrency, both on the client, ...@@ -156,34 +236,39 @@ In ZEO(/ZODB) 5, we want to get more concurrency, both on the client,
and on the server. On the client, cache invalidations and loads are and on the server. On the client, cache invalidations and loads are
done by the same thread, which makes things a bit simpler. This let's done by the same thread, which makes things a bit simpler. This let's
us get rid of the client load lock and prevents the scenarios above us get rid of the client load lock and prevents the scenarios above
with existing servers and storages. with existing servers.
On the client, we'd like to stop serializing loads and commits. We'd On the client, we'd like to stop serializing loads and commits. We'd
like commits (tpc_finish calls) to in flight with loads (and with like commits (tpc_finish calls) to be in flight with loads (and with
other commits). In the current protocol, tpc_finish, load and other commits). In the current protocol, tpc_finish, load and
loadBefore are all synchronous calls that are handled by a single loadBefore are all synchronous calls that are handled by a single
thread on the server, so these calls end up being serialized on the thread on the server, so these calls end up being serialized on the
server. server anyway.
If we ever allowed multiple threads to service client requests, then The server-side hndling of invalidations is a bit tricker in ZEO 5
we'd need to consider scenario S4, but this isn't an issue now (or for because there isn't a thread-safe queue of outgoing messages in ZEO 5
the foreseeable future). as there was in ZEO 4. The natural approach in ZEO 5 would be to use
asyncio's ``call_soon_threadsafe`` to send invalidations in a client's
thread. This could easily cause invalidations to be sent after loads.
As shown above, this isn't a problem for ZODB 5, at least assuming
that invalidations arrive in order. This would be a problem for
ZODB 4.
Note that this approach can't cause invalidations to be sent early,
because they could only be sent by the thread that's busy loading, so
scenario 2 wouldn't happen.
To mitigate T1, we could create a thread-safe server-side message
queue that's used when sending results. Unfortunately, this puts us
back in the position of having to wake up the event loop again (via
``call_soon_threadsafe``). Maybe that's OK.
The main server opportunity is allowing commits for separate oids to The main server opportunity is allowing commits for separate oids to
happen concurrently. This wouldn't effect the invalidation/load happen concurrently. This wouldn't effect the invalidation/load
ordering though, assuming we continued to block loading an oid while ordering though.
it was being committed in tpc_finish.
It would be nice not to block loads while making tpc_finish calls, but
We could also allow loads to proceed while invalidations are being storages do this anyway now, so there's nothing to be done about it
queued for an object. Queuing invalidations is pretty fast though. It's now. Storage locking requirements aren't well specified, and probably
not clear that this would be much of a win. This probably isn't worth should be rethought in light of ZODB5/loadBefore.
fooling with for now. If we did want to relax this, we could, on the
client, track invalidations for outstanding load requests and adjust
how we wrote data to the cache accordingly. Again, we won't bother in
the short term.
So, for now, we can rely on the server sending clients
properly-ordered loads and invalidations. Also, because invalidations
and loads will be performed by a single thread on the client, we can
count on the ordering being preserved on the client.
...@@ -22,7 +22,6 @@ Options: ...@@ -22,7 +22,6 @@ Options:
-f/--filename FILENAME -- filename for FileStorage -f/--filename FILENAME -- filename for FileStorage
-t/--timeout TIMEOUT -- transaction timeout in seconds (default no timeout) -t/--timeout TIMEOUT -- transaction timeout in seconds (default no timeout)
-h/--help -- print this usage message and exit -h/--help -- print this usage message and exit
-m/--monitor ADDRESS -- address of monitor server ([HOST:]PORT or PATH)
--pid-file PATH -- relative path to output file containing this process's pid; --pid-file PATH -- relative path to output file containing this process's pid;
default $(INSTANCE_HOME)/var/ZEO.pid but only if envar default $(INSTANCE_HOME)/var/ZEO.pid but only if envar
INSTANCE_HOME is defined INSTANCE_HOME is defined
...@@ -72,9 +71,6 @@ class ZEOOptionsMixin: ...@@ -72,9 +71,6 @@ class ZEOOptionsMixin:
def handle_address(self, arg): def handle_address(self, arg):
self.family, self.address = parse_binding_address(arg) self.family, self.address = parse_binding_address(arg)
def handle_monitor_address(self, arg):
self.monitor_family, self.monitor_address = parse_binding_address(arg)
def handle_filename(self, arg): def handle_filename(self, arg):
from ZODB.config import FileStorage # That's a FileStorage *opener*! from ZODB.config import FileStorage # That's a FileStorage *opener*!
class FSConfig: class FSConfig:
...@@ -107,14 +103,6 @@ class ZEOOptionsMixin: ...@@ -107,14 +103,6 @@ class ZEOOptionsMixin:
self.add("invalidation_age", "zeo.invalidation_age") self.add("invalidation_age", "zeo.invalidation_age")
self.add("transaction_timeout", "zeo.transaction_timeout", self.add("transaction_timeout", "zeo.transaction_timeout",
"t:", "timeout=", float) "t:", "timeout=", float)
self.add("monitor_address", "zeo.monitor_address.address",
"m:", "monitor=", self.handle_monitor_address)
self.add('auth_protocol', 'zeo.authentication_protocol',
None, 'auth-protocol=', default=None)
self.add('auth_database', 'zeo.authentication_database',
None, 'auth-database=')
self.add('auth_realm', 'zeo.authentication_realm',
None, 'auth-realm=')
self.add('pid_file', 'zeo.pid_filename', self.add('pid_file', 'zeo.pid_filename',
None, 'pid-file=') None, 'pid-file=')
...@@ -184,6 +172,7 @@ class ZEOServer: ...@@ -184,6 +172,7 @@ class ZEOServer:
self.options.address[1] is None): self.options.address[1] is None):
self.options.address = self.options.address[0], 0 self.options.address = self.options.address[0], 0
return return
if self.can_connect(self.options.family, self.options.address): if self.can_connect(self.options.family, self.options.address):
self.options.usage("address %s already in use" % self.options.usage("address %s already in use" %
repr(self.options.address)) repr(self.options.address))
...@@ -352,10 +341,6 @@ def create_server(storages, options): ...@@ -352,10 +341,6 @@ def create_server(storages, options):
invalidation_queue_size = options.invalidation_queue_size, invalidation_queue_size = options.invalidation_queue_size,
invalidation_age = options.invalidation_age, invalidation_age = options.invalidation_age,
transaction_timeout = options.transaction_timeout, transaction_timeout = options.transaction_timeout,
monitor_address = options.monitor_address,
auth_protocol = options.auth_protocol,
auth_database = options.auth_database,
auth_realm = options.auth_realm,
) )
...@@ -393,5 +378,11 @@ def main(args=None): ...@@ -393,5 +378,11 @@ def main(args=None):
s = ZEOServer(options) s = ZEOServer(options)
s.main() s.main()
def run(args):
options = ZEOOptions()
options.realize(args)
s = ZEOServer(options)
s.run()
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -500,7 +500,8 @@ def days(f): ...@@ -500,7 +500,8 @@ def days(f):
minute(f, 10, detail=0) minute(f, 10, detail=0)
new_connection_idre = re.compile(r"new connection \('(\d+.\d+.\d+.\d+)', (\d+)\):") new_connection_idre = re.compile(
r"new connection \('(\d+.\d+.\d+.\d+)', (\d+)\):")
def verify(f): def verify(f):
f, = f f, = f
......
...@@ -11,27 +11,6 @@ ...@@ -11,27 +11,6 @@
# FOR A PARTICULAR PURPOSE # FOR A PARTICULAR PURPOSE
# #
############################################################################## ##############################################################################
import os
import threading
import logging
from ZODB.loglevels import BLATHER
LOG_THREAD_ID = 0 # Set this to 1 during heavy debugging
logger = logging.getLogger('ZEO.zrpc')
_label = "%s" % os.getpid()
def new_label():
global _label
_label = str(os.getpid())
def log(message, level=BLATHER, label=None, exc_info=False):
label = label or _label
if LOG_THREAD_ID:
label = label + ':' + threading.currentThread().getName()
logger.log(level, '(%s) %s' % (label, message), exc_info=exc_info)
REPR_LIMIT = 60 REPR_LIMIT = 60
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# #
############################################################################## ##############################################################################
import concurrent.futures import concurrent.futures
import contextlib
import os import os
import time import time
import socket import socket
...@@ -21,7 +22,7 @@ import logging ...@@ -21,7 +22,7 @@ import logging
from ZEO.ClientStorage import ClientStorage from ZEO.ClientStorage import ClientStorage
from ZEO.Exceptions import ClientDisconnected from ZEO.Exceptions import ClientDisconnected
from ZEO.zrpc.marshal import encode from ZEO.asyncio.marshal import encode
from ZEO.tests import forker from ZEO.tests import forker
from ZODB.DB import DB from ZODB.DB import DB
...@@ -79,40 +80,22 @@ class CommonSetupTearDown(StorageTestBase): ...@@ -79,40 +80,22 @@ class CommonSetupTearDown(StorageTestBase):
logging.info("setUp() %s", self.id()) logging.info("setUp() %s", self.id())
self.file = 'storage_conf' self.file = 'storage_conf'
self.addr = [] self.addr = []
self._pids = []
self._servers = [] self._servers = []
self.conf_paths = []
self.caches = [] self.caches = []
self._newAddr() self._newAddr()
self.startServer() self.startServer()
# self._old_log_level = logging.getLogger().getEffectiveLevel()
# logging.getLogger().setLevel(logging.WARNING)
# self._log_handler = logging.StreamHandler()
# logging.getLogger().addHandler(self._log_handler)
def tearDown(self): def tearDown(self):
"""Try to cause the tests to halt""" """Try to cause the tests to halt"""
# logging.getLogger().setLevel(self._old_log_level)
# logging.getLogger().removeHandler(self._log_handler)
# logging.info("tearDown() %s" % self.id())
for p in self.conf_paths:
os.remove(p)
if getattr(self, '_storage', None) is not None: if getattr(self, '_storage', None) is not None:
self._storage.close() self._storage.close()
if hasattr(self._storage, 'cleanup'): if hasattr(self._storage, 'cleanup'):
logging.debug("cleanup storage %s" % logging.debug("cleanup storage %s" %
self._storage.__name__) self._storage.__name__)
self._storage.cleanup() self._storage.cleanup()
for adminaddr in self._servers: for stop in self._servers:
if adminaddr is not None: stop()
forker.shutdown_zeo_server(adminaddr)
for pid in self._pids:
try:
os.waitpid(pid, 0)
except OSError:
pass # The subprocess module may already have waited
for c in self.caches: for c in self.caches:
for i in 0, 1: for i in 0, 1:
...@@ -183,7 +166,7 @@ class CommonSetupTearDown(StorageTestBase): ...@@ -183,7 +166,7 @@ class CommonSetupTearDown(StorageTestBase):
return zconf return zconf
def startServer(self, create=1, index=0, read_only=0, ro_svr=0, keep=None, def startServer(self, create=1, index=0, read_only=0, ro_svr=0, keep=None,
path=None): path=None, **kw):
addr = self.addr[index] addr = self.addr[index]
logging.info("startServer(create=%d, index=%d, read_only=%d) @ %s" % logging.info("startServer(create=%d, index=%d, read_only=%d) @ %s" %
(create, index, read_only, addr)) (create, index, read_only, addr))
...@@ -193,19 +176,17 @@ class CommonSetupTearDown(StorageTestBase): ...@@ -193,19 +176,17 @@ class CommonSetupTearDown(StorageTestBase):
zconf = self.getServerConfig(addr, ro_svr) zconf = self.getServerConfig(addr, ro_svr)
if keep is None: if keep is None:
keep = self.keep keep = self.keep
zeoport, adminaddr, pid, path = forker.start_zeo_server( zeoport, stop = forker.start_zeo_server(
sconf, zconf, addr[1], keep) sconf, zconf, addr[1], keep, **kw)
self.conf_paths.append(path) self._servers.append(stop)
self._pids.append(pid)
self._servers.append(adminaddr)
def shutdownServer(self, index=0): def shutdownServer(self, index=0):
logging.info("shutdownServer(index=%d) @ %s" % logging.info("shutdownServer(index=%d) @ %s" %
(index, self._servers[index])) (index, self._servers[index]))
adminaddr = self._servers[index] stop = self._servers[index]
if adminaddr is not None: if stop is not None:
forker.shutdown_zeo_server(adminaddr) stop()
self._servers[index] = None self._servers[index] = lambda : None
def pollUp(self, timeout=30.0, storage=None): def pollUp(self, timeout=30.0, storage=None):
if storage is None: if storage is None:
...@@ -310,6 +291,7 @@ class ConnectionTests(CommonSetupTearDown): ...@@ -310,6 +291,7 @@ class ConnectionTests(CommonSetupTearDown):
# object is not in the cache. # object is not in the cache.
self.shutdownServer() self.shutdownServer()
self._storage = self.openClientStorage('test', 1000, wait=0) self._storage = self.openClientStorage('test', 1000, wait=0)
with short_timeout(self):
self.assertRaises(ClientDisconnected, self.assertRaises(ClientDisconnected,
self._storage.load, b'fredwash', '') self._storage.load, b'fredwash', '')
self._storage.close() self._storage.close()
...@@ -377,6 +359,7 @@ class ConnectionTests(CommonSetupTearDown): ...@@ -377,6 +359,7 @@ class ConnectionTests(CommonSetupTearDown):
self.assertEqual(expected2, self._storage.load(oid2, '')) self.assertEqual(expected2, self._storage.load(oid2, ''))
# But oid1 should have been purged, so that trying to load it will # But oid1 should have been purged, so that trying to load it will
# try to fetch it from the (non-existent) ZEO server. # try to fetch it from the (non-existent) ZEO server.
with short_timeout(self):
self.assertRaises(ClientDisconnected, self._storage.load, oid1, '') self.assertRaises(ClientDisconnected, self._storage.load, oid1, '')
self._storage.close() self._storage.close()
...@@ -569,13 +552,17 @@ class ConnectionTests(CommonSetupTearDown): ...@@ -569,13 +552,17 @@ class ConnectionTests(CommonSetupTearDown):
self._storage = self.openClientStorage() self._storage = self.openClientStorage()
self._dostore() self._dostore()
self.shutdownServer() self.shutdownServer()
self.assertRaises(ClientDisconnected, self._storage.load, b'\0'*8, '') with short_timeout(self):
self.assertRaises(ClientDisconnected,
self._storage.load, b'\0'*8, '')
self.startServer() self.startServer()
# No matter how long we wait, the client won't reconnect: # No matter how long we wait, the client won't reconnect:
time.sleep(2) time.sleep(2)
self.assertRaises(ClientDisconnected, self._storage.load, b'\0'*8, '') with short_timeout(self):
self.assertRaises(ClientDisconnected,
self._storage.load, b'\0'*8, '')
class InvqTests(CommonSetupTearDown): class InvqTests(CommonSetupTearDown):
invq = 3 invq = 3
...@@ -701,6 +688,7 @@ class ReconnectionTests(CommonSetupTearDown): ...@@ -701,6 +688,7 @@ class ReconnectionTests(CommonSetupTearDown):
# Poll until the client disconnects # Poll until the client disconnects
self.pollDown() self.pollDown()
# Stores should fail now # Stores should fail now
with short_timeout(self):
self.assertRaises(ClientDisconnected, self._dostore) self.assertRaises(ClientDisconnected, self._dostore)
# Restart the server # Restart the server
...@@ -750,6 +738,7 @@ class ReconnectionTests(CommonSetupTearDown): ...@@ -750,6 +738,7 @@ class ReconnectionTests(CommonSetupTearDown):
# Poll until the client disconnects # Poll until the client disconnects
self.pollDown() self.pollDown()
# Stores should fail now # Stores should fail now
with short_timeout(self):
self.assertRaises(ClientDisconnected, self._dostore) self.assertRaises(ClientDisconnected, self._dostore)
# Restart the server # Restart the server
...@@ -780,8 +769,8 @@ class ReconnectionTests(CommonSetupTearDown): ...@@ -780,8 +769,8 @@ class ReconnectionTests(CommonSetupTearDown):
self.pollDown() self.pollDown()
# Accesses should fail now # Accesses should fail now
self.assertRaises(ClientDisconnected, self._storage.history, ZERO, with short_timeout(self):
timeout=1) self.assertRaises(ClientDisconnected, self._storage.history, ZERO)
# Restart the server, this time read-write # Restart the server, this time read-write
self.startServer(create=0, keep=0) self.startServer(create=0, keep=0)
...@@ -881,6 +870,7 @@ class ReconnectionTests(CommonSetupTearDown): ...@@ -881,6 +870,7 @@ class ReconnectionTests(CommonSetupTearDown):
data = zodb_pickle(MinPO(oid)) data = zodb_pickle(MinPO(oid))
self._storage.store(oid, None, data, '', txn) self._storage.store(oid, None, data, '', txn)
self.shutdownServer() self.shutdownServer()
with short_timeout(self):
self.assertRaises(ClientDisconnected, self._storage.tpc_vote, txn) self.assertRaises(ClientDisconnected, self._storage.tpc_vote, txn)
self.startServer(create=0) self.startServer(create=0)
self._storage.tpc_abort(txn) self._storage.tpc_abort(txn)
...@@ -967,11 +957,12 @@ class TimeoutTests(CommonSetupTearDown): ...@@ -967,11 +957,12 @@ class TimeoutTests(CommonSetupTearDown):
timeout = 1 timeout = 1
def checkTimeout(self): def checkTimeout(self):
storage = self.openClientStorage() self._storage = storage = self.openClientStorage()
txn = Transaction() txn = Transaction()
storage.tpc_begin(txn) storage.tpc_begin(txn)
storage.tpc_vote(txn) storage.tpc_vote(txn)
time.sleep(2) time.sleep(2)
with short_timeout(self):
self.assertRaises(ClientDisconnected, storage.tpc_finish, txn) self.assertRaises(ClientDisconnected, storage.tpc_finish, txn)
# Make sure it's logged as CRITICAL # Make sure it's logged as CRITICAL
...@@ -1188,6 +1179,14 @@ class MSTThread(threading.Thread): ...@@ -1188,6 +1179,14 @@ class MSTThread(threading.Thread):
except: except:
pass pass
@contextlib.contextmanager
def short_timeout(self):
old = self._storage._server.timeout
self._storage._server.timeout = 1
yield
self._storage._server.timeout = old
# Run IPv6 tests if V6 sockets are supported # Run IPv6 tests if V6 sockets are supported
try: try:
socket.socket(socket.AF_INET6, socket.SOCK_STREAM) socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
......
...@@ -20,7 +20,7 @@ verification is no longer supported. ...@@ -20,7 +20,7 @@ verification is no longer supported.
Here's an example that shows that this is actually what happens. Here's an example that shows that this is actually what happens.
Start a server, create a cient to it and commit some data Start a server, create a client to it and commit some data
>>> addr, admin = start_server(keep=1) >>> addr, admin = start_server(keep=1)
>>> import ZEO, transaction >>> import ZEO, transaction
...@@ -57,6 +57,7 @@ logging and event data: ...@@ -57,6 +57,7 @@ logging and event data:
... 'ZEO', level=logging.ERROR) ... 'ZEO', level=logging.ERROR)
>>> events = [] >>> events = []
>>> def event_handler(e): >>> def event_handler(e):
... if hasattr(e, 'storage'):
... events.append(( ... events.append((
... len(e.storage._server.client.cache), str(handler), e.__class__.__name__)) ... len(e.storage._server.client.cache), str(handler), e.__class__.__name__))
...@@ -70,7 +71,8 @@ is generated before the cache is dropped or the message is logged. ...@@ -70,7 +71,8 @@ is generated before the cache is dropped or the message is logged.
Now, we'll restart the server on the original address: Now, we'll restart the server on the original address:
>>> _, admin = start_server(zeo_conf=dict(invalidation_queue_size=1), >>> _, admin = start_server(zeo_conf=dict(invalidation_queue_size=1),
... addr=addr, keep=1) ... addr=addr, keep=1, threaded=True)
>>> wait_connected(db.storage) >>> wait_connected(db.storage)
Now, let's verify our assertions above: Now, let's verify our assertions above:
......
...@@ -42,11 +42,7 @@ class ZEOConfig: ...@@ -42,11 +42,7 @@ class ZEOConfig:
self.read_only = None self.read_only = None
self.invalidation_queue_size = None self.invalidation_queue_size = None
self.invalidation_age = None self.invalidation_age = None
self.monitor_address = None
self.transaction_timeout = None self.transaction_timeout = None
self.authentication_protocol = None
self.authentication_database = None
self.authentication_realm = None
self.loglevel = 'INFO' self.loglevel = 'INFO'
def dump(self, f): def dump(self, f):
...@@ -55,19 +51,12 @@ class ZEOConfig: ...@@ -55,19 +51,12 @@ class ZEOConfig:
if self.read_only is not None: if self.read_only is not None:
print("read-only", self.read_only and "true" or "false", file=f) print("read-only", self.read_only and "true" or "false", file=f)
if self.invalidation_queue_size is not None: if self.invalidation_queue_size is not None:
print("invalidation-queue-size", self.invalidation_queue_size, file=f) print("invalidation-queue-size",
self.invalidation_queue_size, file=f)
if self.invalidation_age is not None: if self.invalidation_age is not None:
print("invalidation-age", self.invalidation_age, file=f) print("invalidation-age", self.invalidation_age, file=f)
if self.monitor_address is not None:
print("monitor-address %s:%s" % self.monitor_address, file=f)
if self.transaction_timeout is not None: if self.transaction_timeout is not None:
print("transaction-timeout", self.transaction_timeout, file=f) print("transaction-timeout", self.transaction_timeout, file=f)
if self.authentication_protocol is not None:
print("authentication-protocol", self.authentication_protocol, file=f)
if self.authentication_database is not None:
print("authentication-database", self.authentication_database, file=f)
if self.authentication_realm is not None:
print("authentication-realm", self.authentication_realm, file=f)
print("</zeo>", file=f) print("</zeo>", file=f)
print(""" print("""
...@@ -93,10 +82,81 @@ def encode_format(fmt): ...@@ -93,10 +82,81 @@ def encode_format(fmt):
fmt = fmt.replace(*xform) fmt = fmt.replace(*xform)
return fmt return fmt
def runner(config, qin, qout, timeout=None,
join_timeout=9, debug=False, name=None,
keep=False, protocol=None):
if debug:
debug_logging()
old_protocol = None
if protocol:
import ZEO.asyncio.server
old_protocol = ZEO.asyncio.server.best_protocol_version
ZEO.asyncio.server.best_protocol_version = protocol
try:
import ZEO.runzeo, threading
from six.moves.queue import Empty
options = ZEO.runzeo.ZEOOptions()
options.realize(['-C', config])
server = ZEO.runzeo.ZEOServer(options)
server.open_storages()
server.clear_socket()
server.create_server()
logger.debug('SERVER CREATED')
qout.put(server.server.acceptor.addr)
logger.debug('ADDRESS SENT')
thread = threading.Thread(
target=server.server.loop,
name = None if name is None else name + '-server',
)
thread.setDaemon(True)
thread.start()
try:
qin.get(timeout=timeout)
except Empty:
pass
server.server.close()
thread.join(join_timeout)
if not keep:
# Try to cleanup storage files
for storage in server.server.storages.values():
try:
storage.cleanup()
except AttributeError:
pass
qout.put('stopped')
if hasattr(qout, 'close'):
qout.close()
qout.join_thread()
except Exception:
logger.exception("In server thread")
finally:
if old_protocol:
ZEO.asyncio.server.best_protocol_version = protocol
def stop_runner(thread, config, qin, qout, stop_timeout=9, pid=None):
qin.put('stop')
if hasattr(qin, 'close'):
qin.close()
qin.join_thread()
qout.get(timeout=stop_timeout)
thread.join(stop_timeout)
os.remove(config)
def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False, def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False,
path='Data.fs', protocol=None, blob_dir=None, path='Data.fs', protocol=None, blob_dir=None,
suicide=True, debug=False): suicide=True, debug=False,
threaded=False, start_timeout=150, name=None,
):
"""Start a ZEO server in a separate process. """Start a ZEO server in a separate process.
Takes two positional arguments a string containing the storage conf Takes two positional arguments a string containing the storage conf
...@@ -118,7 +178,6 @@ def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False, ...@@ -118,7 +178,6 @@ def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False,
if isinstance(port, int): if isinstance(port, int):
addr = 'localhost', port addr = 'localhost', port
adminaddr = 'localhost', port+1
else: else:
addr = port addr = port
adminaddr = port+'-test' adminaddr = port+'-test'
...@@ -136,59 +195,29 @@ def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False, ...@@ -136,59 +195,29 @@ def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False,
fp.write(storage_conf) fp.write(storage_conf)
fp.close() fp.close()
# Find the zeoserver script if threaded:
import ZEO.tests.zeoserver from threading import Thread
script = ZEO.tests.zeoserver.__file__ from six.moves.queue import Queue
if script.endswith('.pyc'):
script = script[:-1]
# Create a list of arguments, which we'll tuplify below
qa = _quote_arg
args = [qa(sys.executable), qa(script), '-C', qa(tmpfile)]
if keep:
args.append("-k")
if debug:
args.append("-d")
if not suicide:
args.append("-S")
if protocol:
args.extend(["-v", protocol])
d = os.environ.copy()
d['PYTHONPATH'] = os.pathsep.join(sys.path)
if sys.platform.startswith('win'):
pid = os.spawnve(os.P_NOWAIT, sys.executable, tuple(args), d)
else:
pid = subprocess.Popen(args, env=d, close_fds=True).pid
# We need to wait until the server starts, but not forever. 150
# seconds is a somewhat arbitrary upper bound, but probably helps
# in an address already in use situation.
for i in range(1500):
time.sleep(0.1)
try:
if isinstance(adminaddr, str) and not os.path.exists(adminaddr):
continue
logger.debug('connect %s', i)
if isinstance(adminaddr, str):
s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
else: else:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) from multiprocessing import Process as Thread
s.connect(adminaddr) from multiprocessing import Queue
ack = s.recv(1024)
s.close() qin = Queue()
logging.debug('acked: %s' % ack) qout = Queue()
break thread = Thread(
except socket.error as e: target=runner,
if e.args[0] not in (errno.ECONNREFUSED, errno.ECONNRESET): args=[tmpfile, qin, qout, 999 if suicide else None],
raise kwargs=dict(debug=debug, name=name, protocol=protocol, keep=keep),
s.close() name = None if name is None else name + '-server-runner',
else: )
logging.debug('boo hoo') thread.daemon = True
raise RuntimeError("Failed to start server") thread.start()
return addr, adminaddr, pid, tmpfile addr = qout.get(timeout=start_timeout)
def stop(stop_timeout=9):
stop_runner(thread, tmpfile, qin, qout, stop_timeout)
return addr, stop
if sys.platform[:3].lower() == "win": if sys.platform[:3].lower() == "win":
def _quote_arg(s): def _quote_arg(s):
...@@ -197,40 +226,8 @@ else: ...@@ -197,40 +226,8 @@ else:
def _quote_arg(s): def _quote_arg(s):
return s return s
def shutdown_zeo_server(stop):
def shutdown_zeo_server(adminaddr): stop()
# Do this in a loop to guard against the possibility that the
# client failed to connect to the adminaddr earlier. That really
# only requires two iterations, but do a third for pure
# superstition.
for i in range(3):
if isinstance(adminaddr, str):
s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
else:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.settimeout(.3)
try:
s.connect(adminaddr)
except socket.timeout:
# On FreeBSD 5.3 the connection just timed out
if i > 0:
break
raise
except socket.error as e:
if (e.args[0] == errno.ECONNREFUSED
or
# MAC OS X uses EINVAL when connecting to a port
# that isn't being listened on.
(sys.platform == 'darwin' and e.args[0] == errno.EINVAL)
) and i > 0:
break
raise
try:
ack = s.recv(1024)
except socket.error as e:
ack = 'no ack received'
logger.debug('shutdown_zeo_server(): acked: %s' % ack)
s.close()
def get_port(test=None): def get_port(test=None):
"""Return a port that is not in use. """Return a port that is not in use.
...@@ -311,11 +308,11 @@ def can_connect(port): ...@@ -311,11 +308,11 @@ def can_connect(port):
def setUp(test): def setUp(test):
ZODB.tests.util.setUp(test) ZODB.tests.util.setUp(test)
servers = {} servers = []
def start_server(storage_conf=None, zeo_conf=None, port=None, keep=False, def start_server(storage_conf=None, zeo_conf=None, port=None, keep=False,
addr=None, path='Data.fs', protocol=None, blob_dir=None, addr=None, path='Data.fs', protocol=None, blob_dir=None,
suicide=True, debug=False): suicide=True, debug=False, **kw):
"""Start a ZEO server. """Start a ZEO server.
Return the server and admin addresses. Return the server and admin addresses.
...@@ -327,12 +324,19 @@ def setUp(test): ...@@ -327,12 +324,19 @@ def setUp(test):
port = addr[1] port = addr[1]
elif addr is not None: elif addr is not None:
raise TypeError("Can't specify port and addr") raise TypeError("Can't specify port and addr")
addr, adminaddr, pid, config_path = start_zeo_server( addr, stop = start_zeo_server(
storage_conf, zeo_conf, port, keep, path, protocol, blob_dir, storage_conf=storage_conf,
suicide, debug) zeo_conf=zeo_conf,
os.remove(config_path) port=port,
servers[adminaddr] = pid keep=keep,
return addr, adminaddr path=path,
protocol=protocol,
blob_dir=blob_dir,
suicide=suicide,
debug=debug,
**kw)
servers.append(stop)
return addr, stop
test.globs['start_server'] = start_server test.globs['start_server'] = start_server
...@@ -341,16 +345,15 @@ def setUp(test): ...@@ -341,16 +345,15 @@ def setUp(test):
test.globs['get_port'] = get_port test.globs['get_port'] = get_port
def stop_server(adminaddr): def stop_server(stop):
pid = servers.pop(adminaddr) stop()
shutdown_zeo_server(adminaddr) servers.remove(stop)
os.waitpid(pid, 0)
test.globs['stop_server'] = stop_server test.globs['stop_server'] = stop_server
def cleanup_servers(): def cleanup_servers():
for adminaddr in list(servers): for stop in list(servers):
stop_server(adminaddr) stop()
zope.testing.setupstack.register(test, cleanup_servers) zope.testing.setupstack.register(test, cleanup_servers)
...@@ -400,3 +403,4 @@ def debug_logging(logger='ZEO', stream='stderr', level=logging.DEBUG): ...@@ -400,3 +403,4 @@ def debug_logging(logger='ZEO', stream='stderr', level=logging.DEBUG):
logger.setLevel(logging.NOTSET) logger.setLevel(logging.NOTSET)
return stop return stop
...@@ -5,7 +5,7 @@ A full test of all protocols isn't practical. But we'll do a limited ...@@ -5,7 +5,7 @@ A full test of all protocols isn't practical. But we'll do a limited
test that at least the current and previous protocols are supported in test that at least the current and previous protocols are supported in
both directions. both directions.
Let's start a Z308 server Let's start a Z309 server
>>> storage_conf = ''' >>> storage_conf = '''
... <blobstorage> ... <blobstorage>
...@@ -16,8 +16,8 @@ Let's start a Z308 server ...@@ -16,8 +16,8 @@ Let's start a Z308 server
... </blobstorage> ... </blobstorage>
... ''' ... '''
>>> addr, admin = start_server( >>> addr, stop = start_server(
... storage_conf, dict(invalidation_queue_size=5), protocol=b'Z309') ... storage_conf, dict(invalidation_queue_size=5), protocol=b'Z4')
A current client should be able to connect to a old server: A current client should be able to connect to a old server:
...@@ -25,7 +25,7 @@ A current client should be able to connect to a old server: ...@@ -25,7 +25,7 @@ A current client should be able to connect to a old server:
>>> db = ZEO.DB(addr, client='client', blob_dir='blobs') >>> db = ZEO.DB(addr, client='client', blob_dir='blobs')
>>> wait_connected(db.storage) >>> wait_connected(db.storage)
>>> db.storage.protocol_version >>> db.storage.protocol_version
b'Z309' b'Z4'
>>> conn = db.open() >>> conn = db.open()
>>> conn.root().x = 0 >>> conn.root().x = 0
...@@ -87,7 +87,7 @@ A current client should be able to connect to a old server: ...@@ -87,7 +87,7 @@ A current client should be able to connect to a old server:
>>> db2.close() >>> db2.close()
>>> db.close() >>> db.close()
>>> stop_server(admin) >>> stop_server(stop)
>>> import os, zope.testing.setupstack >>> import os, zope.testing.setupstack
>>> os.remove('client-1.zec') >>> os.remove('client-1.zec')
...@@ -102,11 +102,11 @@ Note that we'll have to pull some hijinks: ...@@ -102,11 +102,11 @@ Note that we'll have to pull some hijinks:
>>> import ZEO.asyncio.client >>> import ZEO.asyncio.client
>>> old_protocols = ZEO.asyncio.client.Protocol.protocols >>> old_protocols = ZEO.asyncio.client.Protocol.protocols
>>> ZEO.asyncio.client.Protocol.protocols = [b'Z309'] >>> ZEO.asyncio.client.Protocol.protocols = [b'Z4']
>>> db = ZEO.DB(addr, client='client', blob_dir='blobs') >>> db = ZEO.DB(addr, client='client', blob_dir='blobs')
>>> db.storage.protocol_version >>> db.storage.protocol_version
b'Z309' b'Z4'
>>> wait_connected(db.storage) >>> wait_connected(db.storage)
>>> conn = db.open() >>> conn = db.open()
>>> conn.root().x = 0 >>> conn.root().x = 0
......
...@@ -30,9 +30,8 @@ from __future__ import print_function ...@@ -30,9 +30,8 @@ from __future__ import print_function
# Here, we'll try to provide some testing infrastructure to isolate # Here, we'll try to provide some testing infrastructure to isolate
# servers from the network. # servers from the network.
import ZEO.asyncio.tests
import ZEO.StorageServer import ZEO.StorageServer
import ZEO.zrpc.connection
import ZEO.zrpc.error
import ZODB.MappingStorage import ZODB.MappingStorage
class StorageServer(ZEO.StorageServer.StorageServer): class StorageServer(ZEO.StorageServer.StorageServer):
...@@ -42,44 +41,10 @@ class StorageServer(ZEO.StorageServer.StorageServer): ...@@ -42,44 +41,10 @@ class StorageServer(ZEO.StorageServer.StorageServer):
storages = {'1': ZODB.MappingStorage.MappingStorage()} storages = {'1': ZODB.MappingStorage.MappingStorage()}
ZEO.StorageServer.StorageServer.__init__(self, addr, storages, **kw) ZEO.StorageServer.StorageServer.__init__(self, addr, storages, **kw)
def client(server, name='client'):
class DispatcherClass:
__init__ = lambda *a, **kw: None
class socket:
getsockname = staticmethod(lambda : 'socket')
class Connection:
peer_protocol_version = ZEO.zrpc.connection.Connection.current_protocol
connected = True
def __init__(self, name='connection', addr=''):
name = str(name)
self.name = name
self.addr = addr or 'test-addr-'+name
def close(self):
print(self.name, 'closed')
self.connected = False
def poll(self):
if not self.connected:
raise ZEO.zrpc.error.DisconnectedError()
def callAsync(self, meth, *args):
print(self.name, 'callAsync', meth, repr(args))
callAsyncNoPoll = callAsync
def call_from_thread(self, *args):
if args:
args[0](*args[1:])
def send_reply(self, *args):
pass
def client(server, name='client', addr=''):
zs = ZEO.StorageServer.ZEOStorage(server) zs = ZEO.StorageServer.ZEOStorage(server)
zs.notifyConnected(Connection(name, addr)) protocol = ZEO.asyncio.tests.server_protocol(
zs, protocol_version=b'Z5', addr='test-addr-%s' % name)
zs.notify_connected(protocol)
zs.register('1', 0) zs.register('1', 0)
return zs return zs
...@@ -28,8 +28,6 @@ else: ...@@ -28,8 +28,6 @@ else:
import doctest import doctest
import unittest import unittest
import ZEO.tests.forker import ZEO.tests.forker
import ZEO.tests.testMonitor
import ZEO.zrpc.connection
import ZODB.tests.util import ZODB.tests.util
class FileStorageConfig: class FileStorageConfig:
...@@ -90,41 +88,6 @@ class MappingStorageTimeoutTests( ...@@ -90,41 +88,6 @@ class MappingStorageTimeoutTests(
): ):
pass pass
class MonitorTests(ZEO.tests.testMonitor.MonitorTests):
def check_connection_management(self):
# Open and close a few connections, making sure that
# the resulting number of clients is 0.
s1 = self.openClientStorage()
s2 = self.openClientStorage()
s3 = self.openClientStorage()
stats = self.parse(self.get_monitor_output())[1]
self.assertEqual(stats.clients, 3)
s1.close()
s3.close()
s2.close()
ZEO.tests.forker.wait_until(
"Number of clients shown in monitor drops to 0",
lambda :
self.parse(self.get_monitor_output())[1].clients == 0
)
def check_connection_management_with_old_client(self):
# Check that connection management works even when using an
# older protcool that requires a connection adapter.
test_protocol = b"Z303"
current_protocol = ZEO.zrpc.connection.Connection.current_protocol
ZEO.zrpc.connection.Connection.current_protocol = test_protocol
ZEO.zrpc.connection.Connection.servers_we_can_talk_to.append(
test_protocol)
try:
self.check_connection_management()
finally:
ZEO.zrpc.connection.Connection.current_protocol = current_protocol
ZEO.zrpc.connection.Connection.servers_we_can_talk_to.pop()
test_classes = [FileStorageConnectionTests, test_classes = [FileStorageConnectionTests,
FileStorageReconnectionTests, FileStorageReconnectionTests,
...@@ -132,7 +95,6 @@ test_classes = [FileStorageConnectionTests, ...@@ -132,7 +95,6 @@ test_classes = [FileStorageConnectionTests,
FileStorageTimeoutTests, FileStorageTimeoutTests,
MappingStorageConnectionTests, MappingStorageConnectionTests,
MappingStorageTimeoutTests, MappingStorageTimeoutTests,
MonitorTests,
] ]
def invalidations_while_connecting(): def invalidations_while_connecting():
......
...@@ -52,6 +52,10 @@ class FakeServer: ...@@ -52,6 +52,10 @@ class FakeServer:
def register_connection(*args): def register_connection(*args):
return None, None return None, None
class FakeConnection:
protocol_version = b'Z4'
addr = 'test'
def test_server_record_iternext(): def test_server_record_iternext():
""" """
...@@ -61,6 +65,7 @@ underlying storage. ...@@ -61,6 +65,7 @@ underlying storage.
>>> import ZEO.StorageServer >>> import ZEO.StorageServer
>>> zeo = ZEO.StorageServer.ZEOStorage(FakeServer(), False) >>> zeo = ZEO.StorageServer.ZEOStorage(FakeServer(), False)
>>> zeo.notify_connected(FakeConnection())
>>> zeo.register('1', False) >>> zeo.register('1', False)
>>> next = None >>> next = None
...@@ -80,6 +85,7 @@ The storage info also reflects the fact that record_iternext is supported. ...@@ -80,6 +85,7 @@ The storage info also reflects the fact that record_iternext is supported.
True True
>>> zeo = ZEO.StorageServer.ZEOStorage(FakeServer(), False) >>> zeo = ZEO.StorageServer.ZEOStorage(FakeServer(), False)
>>> zeo.notify_connected(FakeConnection())
>>> zeo.register('2', False) >>> zeo.register('2', False)
>>> zeo.get_info()['supports_record_iternext'] >>> zeo.get_info()['supports_record_iternext']
...@@ -129,41 +135,6 @@ Now we'll have our way with it's private _server attr: ...@@ -129,41 +135,6 @@ Now we'll have our way with it's private _server attr:
""" """
def history_to_version_compatible_storage():
"""
Some storages work under ZODB <= 3.8 and ZODB >= 3.9.
This means they have a history method that accepts a version parameter:
>>> class VersionCompatibleStorage(FakeStorageBase):
... def history(self,oid,version='',size=1):
... return oid,version,size
A ZEOStorage such as the following should support this type of storage:
>>> class OurFakeServer(FakeServer):
... storages = {'1':VersionCompatibleStorage()}
>>> import ZEO.StorageServer
>>> zeo = ZEO.StorageServer.ZEOStorage(OurFakeServer(), False)
>>> zeo.register('1', False)
The ZEOStorage should sort out the following call such that the storage gets
the correct parameters and so should return the parameters it was called with:
>>> zeo.history('oid',99)
('oid', '', 99)
The same problem occurs when a Z308 client connects to a Z309 server,
but different code is executed:
>>> from ZEO.StorageServer import ZEOStorage308Adapter
>>> zeo = ZEOStorage308Adapter(VersionCompatibleStorage())
The history method should still return the parameters it was called with:
>>> zeo.history('oid','',99)
('oid', '', 99)
"""
def test_suite(): def test_suite():
return doctest.DocTestSuite() return doctest.DocTestSuite()
......
##############################################################################
#
# Copyright (c) 2003 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
"""Test that the monitor produce sensible results.
$Id$
"""
import socket
import unittest
from ZEO.tests.ConnectionTests import CommonSetupTearDown
from ZEO.monitor import StorageStats
class MonitorTests(CommonSetupTearDown):
monitor = 1
def get_monitor_output(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect(('localhost', 42000))
L = []
while 1:
buf = s.recv(8192)
if buf:
L.append(buf)
else:
break
s.close()
return b"".join(L).decode('ascii')
def parse(self, s):
# Return a list of StorageStats, one for each storage.
lines = s.split("\n")
self.assert_(lines[0].startswith("ZEO monitor server"))
# lines[1] is a date
# Break up rest of lines into sections starting with Storage:
# and ending with a blank line.
sections = []
cur = None
for line in lines[2:]:
if line.startswith("Storage:"):
cur = [line]
elif line:
cur.append(line)
else:
if cur is not None:
sections.append(cur)
cur = None
assert cur is None # bug in the test code if this fails
d = {}
for sect in sections:
hdr = sect[0]
key, value = hdr.split(":")
storage = int(value)
s = d[storage] = StorageStats()
s.parse("\n".join(sect[1:]))
return d
def getConfig(self, path, create, read_only):
return """<mappingstorage 1/>"""
def testMonitor(self):
# Just open a client to know that the server is up and running
# TODO: should put this in setUp.
self.storage = self.openClientStorage()
s = self.get_monitor_output()
self.storage.close()
self.assert_(s.find("monitor") != -1)
d = self.parse(s)
stats = d[1]
self.assertEqual(stats.clients, 1)
self.assertEqual(stats.commits, 0)
def test_suite():
return unittest.makeSuite(MonitorTests)
...@@ -16,11 +16,10 @@ from __future__ import print_function ...@@ -16,11 +16,10 @@ from __future__ import print_function
import multiprocessing import multiprocessing
import re import re
from ZEO.ClientStorage import ClientStorage from ZEO.ClientStorage import ClientStorage, m64
from ZEO.tests.forker import get_port from ZEO.tests.forker import get_port
from ZEO.tests import forker, Cache, CommitLockTests, ThreadTests from ZEO.tests import forker, Cache, CommitLockTests, ThreadTests
from ZEO.tests import IterationTests from ZEO.tests import IterationTests
from ZEO.zrpc.error import DisconnectedError
from ZEO._compat import PY3 from ZEO._compat import PY3
from ZODB.tests import StorageTestBase, BasicStorage, \ from ZODB.tests import StorageTestBase, BasicStorage, \
TransactionalUndoStorage, \ TransactionalUndoStorage, \
...@@ -48,7 +47,6 @@ import transaction ...@@ -48,7 +47,6 @@ import transaction
import unittest import unittest
import ZEO.StorageServer import ZEO.StorageServer
import ZEO.tests.ConnectionTests import ZEO.tests.ConnectionTests
import ZEO.zrpc.connection
import ZODB import ZODB
import ZODB.blob import ZODB.blob
import ZODB.tests.hexstorage import ZODB.tests.hexstorage
...@@ -168,11 +166,9 @@ class GenericTests( ...@@ -168,11 +166,9 @@ class GenericTests(
logger.info("setUp() %s", self.id()) logger.info("setUp() %s", self.id())
port = get_port(self) port = get_port(self)
zconf = forker.ZEOConfig(('', port)) zconf = forker.ZEOConfig(('', port))
zport, adminaddr, pid, path = forker.start_zeo_server(self.getConfig(), zport, stop = forker.start_zeo_server(self.getConfig(),
zconf, port) zconf, port)
self._pids = [pid] self._servers = [stop]
self._servers = [adminaddr]
self._conf_path = path
if not self.blob_cache_dir: if not self.blob_cache_dir:
# This is the blob cache for ClientStorage # This is the blob cache for ClientStorage
self.blob_cache_dir = tempfile.mkdtemp( self.blob_cache_dir = tempfile.mkdtemp(
...@@ -190,12 +186,8 @@ class GenericTests( ...@@ -190,12 +186,8 @@ class GenericTests(
def tearDown(self): def tearDown(self):
self._storage.close() self._storage.close()
for server in self._servers: for stop in self._servers:
forker.shutdown_zeo_server(server) stop()
if hasattr(os, 'waitpid'):
# Not in Windows Python until 2.3
for pid in self._pids:
os.waitpid(pid, 0)
StorageTestBase.StorageTestBase.tearDown(self) StorageTestBase.StorageTestBase.tearDown(self)
def runTest(self): def runTest(self):
...@@ -278,10 +270,9 @@ class FileStorageRecoveryTests(StorageTestBase.StorageTestBase, ...@@ -278,10 +270,9 @@ class FileStorageRecoveryTests(StorageTestBase.StorageTestBase,
def _new_storage(self): def _new_storage(self):
port = get_port(self) port = get_port(self)
zconf = forker.ZEOConfig(('', port)) zconf = forker.ZEOConfig(('', port))
zport, adminaddr, pid, path = forker.start_zeo_server(self.getConfig(), zport, stop = forker.start_zeo_server(self.getConfig(),
zconf, port) zconf, port)
self._pids.append(pid) self._servers.append(stop)
self._servers.append(adminaddr)
blob_cache_dir = tempfile.mkdtemp(dir='.') blob_cache_dir = tempfile.mkdtemp(dir='.')
...@@ -294,7 +285,6 @@ class FileStorageRecoveryTests(StorageTestBase.StorageTestBase, ...@@ -294,7 +285,6 @@ class FileStorageRecoveryTests(StorageTestBase.StorageTestBase,
def setUp(self): def setUp(self):
StorageTestBase.StorageTestBase.setUp(self) StorageTestBase.StorageTestBase.setUp(self)
self._pids = []
self._servers = [] self._servers = []
self._storage = self._new_storage() self._storage = self._new_storage()
...@@ -304,12 +294,8 @@ class FileStorageRecoveryTests(StorageTestBase.StorageTestBase, ...@@ -304,12 +294,8 @@ class FileStorageRecoveryTests(StorageTestBase.StorageTestBase,
self._storage.close() self._storage.close()
self._dst.close() self._dst.close()
for server in self._servers: for stop in self._servers:
forker.shutdown_zeo_server(server) stop()
if hasattr(os, 'waitpid'):
# Not in Windows Python until 2.3
for pid in self._pids:
os.waitpid(pid, 0)
StorageTestBase.StorageTestBase.tearDown(self) StorageTestBase.StorageTestBase.tearDown(self)
def new_dest(self): def new_dest(self):
...@@ -708,27 +694,23 @@ class BlobWritableCacheTests(FullGenericTests, CommonBlobTests): ...@@ -708,27 +694,23 @@ class BlobWritableCacheTests(FullGenericTests, CommonBlobTests):
class FauxConn: class FauxConn:
addr = 'x' addr = 'x'
peer_protocol_version = ZEO.zrpc.connection.Connection.current_protocol protocol_version = ZEO.asyncio.server.best_protocol_version
peer_protocol_version = protocol_version
class StorageServerClientWrapper: serials = []
def async(self, method, *args):
if method == 'serialnos':
self.serials.extend(args[0])
def __init__(self): call_soon_threadsafe = async
self.serials = []
def serialnos(self, serials):
self.serials.extend(serials)
def info(self, info):
pass
class StorageServerWrapper: class StorageServerWrapper:
def __init__(self, server, storage_id): def __init__(self, server, storage_id):
self.storage_id = storage_id self.storage_id = storage_id
self.server = ZEO.StorageServer.ZEOStorage(server, server.read_only) self.server = ZEO.StorageServer.ZEOStorage(server, server.read_only)
self.server.notifyConnected(FauxConn()) self.server.notify_connected(FauxConn())
self.server.register(storage_id, False) self.server.register(storage_id, False)
self.server.client = StorageServerClientWrapper()
def sortKey(self): def sortKey(self):
return self.storage_id return self.storage_id
...@@ -751,8 +733,8 @@ class StorageServerWrapper: ...@@ -751,8 +733,8 @@ class StorageServerWrapper:
def tpc_vote(self, transaction): def tpc_vote(self, transaction):
vote_result = self.server.vote(id(transaction)) vote_result = self.server.vote(id(transaction))
assert vote_result is None assert vote_result is None
result = self.server.client.serials[:] result = self.server.connection.serials[:]
del self.server.client.serials[:] del self.server.connection.serials[:]
return result return result
def store(self, oid, serial, data, version_ignored, transaction): def store(self, oid, serial, data, version_ignored, transaction):
...@@ -838,7 +820,7 @@ Now we'll open a storage server on the data, simulating a restart: ...@@ -838,7 +820,7 @@ Now we'll open a storage server on the data, simulating a restart:
>>> fs = FileStorage('t.fs') >>> fs = FileStorage('t.fs')
>>> sv = StorageServer(('', get_port()), dict(fs=fs)) >>> sv = StorageServer(('', get_port()), dict(fs=fs))
>>> s = ZEOStorage(sv, sv.read_only) >>> s = ZEOStorage(sv, sv.read_only)
>>> s.notifyConnected(FauxConn()) >>> s.notify_connected(FauxConn())
>>> s.register('fs', False) >>> s.register('fs', False)
If we ask for the last transaction, we should get the last transaction If we ask for the last transaction, we should get the last transaction
...@@ -848,7 +830,7 @@ we saved: ...@@ -848,7 +830,7 @@ we saved:
True True
If a storage implements the method lastInvalidations, as FileStorage If a storage implements the method lastInvalidations, as FileStorage
does, then the stroage server will populate its invalidation data does, then the storage server will populate its invalidation data
structure using lastTransactions. structure using lastTransactions.
...@@ -1085,7 +1067,7 @@ def runzeo_without_configfile(): ...@@ -1085,7 +1067,7 @@ def runzeo_without_configfile():
------ ------
--T INFO ZEO.StorageServer StorageServer created RW with storages 1RWt --T INFO ZEO.StorageServer StorageServer created RW with storages 1RWt
------ ------
--T INFO ZEO.zrpc () listening on ... --T INFO ZEO.acceptor listening on ...
------ ------
--T INFO ZEO.StorageServer closing storage '1' --T INFO ZEO.StorageServer closing storage '1'
testing exit immediately testing exit immediately
...@@ -1150,7 +1132,6 @@ def test_server_status(): ...@@ -1150,7 +1132,6 @@ def test_server_status():
'start': 'Tue May 4 10:55:20 2010', 'start': 'Tue May 4 10:55:20 2010',
'stores': 1, 'stores': 1,
'timeout-thread-is-alive': True, 'timeout-thread-is-alive': True,
'verifying_clients': 0,
'waiting': 0} 'waiting': 0}
>>> db.close() >>> db.close()
...@@ -1169,7 +1150,8 @@ def test_ruok(): ...@@ -1169,7 +1150,8 @@ def test_ruok():
>>> _ = writer.write(struct.pack(">I", 4)+b"ruok") >>> _ = writer.write(struct.pack(">I", 4)+b"ruok")
>>> writer.close() >>> writer.close()
>>> proto = s.recv(struct.unpack(">I", s.recv(4))[0]) >>> proto = s.recv(struct.unpack(">I", s.recv(4))[0])
>>> data = json.loads(s.recv(struct.unpack(">I", s.recv(4))[0]).decode("ascii")) >>> data = json.loads(
... s.recv(struct.unpack(">I", s.recv(4))[0]).decode("ascii"))
>>> pprint.pprint(data['1']) >>> pprint.pprint(data['1'])
{u'aborts': 0, {u'aborts': 0,
u'active_txns': 0, u'active_txns': 0,
...@@ -1183,7 +1165,6 @@ def test_ruok(): ...@@ -1183,7 +1165,6 @@ def test_ruok():
u'start': u'Sun Jan 4 09:37:03 2015', u'start': u'Sun Jan 4 09:37:03 2015',
u'stores': 1, u'stores': 1,
u'timeout-thread-is-alive': True, u'timeout-thread-is-alive': True,
u'verifying_clients': 0,
u'waiting': 0} u'waiting': 0}
>>> db.close(); s.close() >>> db.close(); s.close()
""" """
...@@ -1410,7 +1391,7 @@ Now we'll try to use the connection, mainly to wait for everything to ...@@ -1410,7 +1391,7 @@ Now we'll try to use the connection, mainly to wait for everything to
get processed. Before we fixed this by making tpc_finish a synchronous get processed. Before we fixed this by making tpc_finish a synchronous
call to the server. we'd get some sort of error here. call to the server. we'd get some sort of error here.
>>> _ = client._call('loadEx', b'\0'*8) >>> _ = client._call('loadBefore', b'\0'*8, m64)
>>> c.close() >>> c.close()
...@@ -1519,7 +1500,7 @@ class ServerManagingClientStorage(ClientStorage): ...@@ -1519,7 +1500,7 @@ class ServerManagingClientStorage(ClientStorage):
server_blob_dir = 'server-'+blob_dir server_blob_dir = 'server-'+blob_dir
self.globs = {} self.globs = {}
port = forker.get_port2(self) port = forker.get_port2(self)
addr, admin, pid, config = forker.start_zeo_server( addr, stop = forker.start_zeo_server(
""" """
<blobstorage> <blobstorage>
blob-dir %s blob-dir %s
...@@ -1531,10 +1512,7 @@ class ServerManagingClientStorage(ClientStorage): ...@@ -1531,10 +1512,7 @@ class ServerManagingClientStorage(ClientStorage):
""" % (server_blob_dir, name+'.fs', extrafsoptions), """ % (server_blob_dir, name+'.fs', extrafsoptions),
port=port, port=port,
) )
os.remove(config) zope.testing.setupstack.register(self, stop)
zope.testing.setupstack.register(self, os.waitpid, pid, 0)
zope.testing.setupstack.register(
self, forker.shutdown_zeo_server, admin)
if shared: if shared:
ClientStorage.__init__(self, addr, blob_dir=blob_dir, ClientStorage.__init__(self, addr, blob_dir=blob_dir,
shared_blob_dir=True) shared_blob_dir=True)
......
...@@ -33,7 +33,7 @@ def proper_handling_of_blob_conflicts(): ...@@ -33,7 +33,7 @@ def proper_handling_of_blob_conflicts():
Conflict errors weren't properly handled when storing blobs, the Conflict errors weren't properly handled when storing blobs, the
result being that the storage was left in a transaction. result being that the storage was left in a transaction.
We originally saw this when restarting a block transaction, although We originally saw this when restarting a blob transaction, although
it doesn't really matter. it doesn't really matter.
Set up the storage with some initial blob data. Set up the storage with some initial blob data.
...@@ -44,7 +44,7 @@ Set up the storage with some initial blob data. ...@@ -44,7 +44,7 @@ Set up the storage with some initial blob data.
>>> conn.root.b = ZODB.blob.Blob(b'x') >>> conn.root.b = ZODB.blob.Blob(b'x')
>>> transaction.commit() >>> transaction.commit()
Get the iod and first serial. We'll use the serial later to provide Get the oid and first serial. We'll use the serial later to provide
out-of-date data. out-of-date data.
>>> oid = conn.root.b._p_oid >>> oid = conn.root.b._p_oid
...@@ -60,22 +60,15 @@ Create the server: ...@@ -60,22 +60,15 @@ Create the server:
And an initial client. And an initial client.
>>> zs1 = ZEO.StorageServer.ZEOStorage(server) >>> zs1 = ZEO.tests.servertesting.client(server, 1)
>>> conn1 = ZEO.tests.servertesting.Connection(1)
>>> zs1.notifyConnected(conn1)
>>> zs1.register('1', 0)
>>> zs1.tpc_begin('0', '', '', {}) >>> zs1.tpc_begin('0', '', '', {})
>>> zs1.storea(ZODB.utils.p64(99), ZODB.utils.z64, b'x', '0') >>> zs1.storea(ZODB.utils.p64(99), ZODB.utils.z64, b'x', '0')
>>> _ = zs1.vote('0') # doctest: +ELLIPSIS >>> _ = zs1.vote('0') # doctest: +ELLIPSIS
1 callAsync serialnos ...
In a second client, we'll try to commit using the old serial. This In a second client, we'll try to commit using the old serial. This
will conflict. It will be blocked at the vote call. will conflict. It will be blocked at the vote call.
>>> zs2 = ZEO.StorageServer.ZEOStorage(server) >>> zs2 = ZEO.tests.servertesting.client(server, 2)
>>> conn2 = ZEO.tests.servertesting.Connection(2)
>>> zs2.notifyConnected(conn2)
>>> zs2.register('1', 0)
>>> zs2.tpc_begin('1', '', '', {}) >>> zs2.tpc_begin('1', '', '', {})
>>> zs2.storeBlobStart() >>> zs2.storeBlobStart()
>>> zs2.storeBlobChunk(b'z') >>> zs2.storeBlobChunk(b'z')
...@@ -97,12 +90,11 @@ client will be restarted. It will get a conflict error, that is ...@@ -97,12 +90,11 @@ client will be restarted. It will get a conflict error, that is
handled correctly: handled correctly:
>>> zs1.tpc_abort('0') # doctest: +ELLIPSIS >>> zs1.tpc_abort('0') # doctest: +ELLIPSIS
2 callAsync serialnos ...
reply 1 None reply 1 None
>>> fs.tpc_transaction() is not None >>> fs.tpc_transaction() is not None
True True
>>> conn2.connected >>> zs2.connected
True True
>>> logger.setLevel(logging.NOTSET) >>> logger.setLevel(logging.NOTSET)
...@@ -122,10 +114,7 @@ storage isn't left in tpc. ...@@ -122,10 +114,7 @@ storage isn't left in tpc.
And an initial client. And an initial client.
>>> zs1 = ZEO.StorageServer.ZEOStorage(server) >>> zs1 = ZEO.tests.servertesting.client(server, 1)
>>> conn1 = ZEO.tests.servertesting.Connection(1)
>>> zs1.notifyConnected(conn1)
>>> zs1.register('1', 0)
>>> zs1.tpc_begin('0', '', '', {}) >>> zs1.tpc_begin('0', '', '', {})
>>> zs1.storea(ZODB.utils.p64(99), ZODB.utils.z64, 'x', '0') >>> zs1.storea(ZODB.utils.p64(99), ZODB.utils.z64, 'x', '0')
...@@ -144,16 +133,12 @@ We're not in a transaction: ...@@ -144,16 +133,12 @@ We're not in a transaction:
We can start another client and get the storage lock. We can start another client and get the storage lock.
>>> zs1 = ZEO.StorageServer.ZEOStorage(server) >>> zs1 = ZEO.tests.servertesting.client(server, 1)
>>> conn1 = ZEO.tests.servertesting.Connection(1)
>>> zs1.notifyConnected(conn1)
>>> zs1.register('1', 0)
>>> zs1.tpc_begin('1', '', '', {}) >>> zs1.tpc_begin('1', '', '', {})
>>> zs1.storea(ZODB.utils.p64(99), ZODB.utils.z64, 'x', '1') >>> zs1.storea(ZODB.utils.p64(99), ZODB.utils.z64, 'x', '1')
>>> _ = zs1.vote('1') # doctest: +ELLIPSIS >>> _ = zs1.vote('1') # doctest: +ELLIPSIS
1 callAsync serialnos ...
>>> zs1.tpc_finish('1').set_sender(0, conn1) >>> zs1.tpc_finish('1').set_sender(0, zs1.connection)
>>> fs.close() >>> fs.close()
""" """
...@@ -173,10 +158,7 @@ So, we arrange to get an error in vote: ...@@ -173,10 +158,7 @@ So, we arrange to get an error in vote:
>>> server = ZEO.tests.servertesting.StorageServer( >>> server = ZEO.tests.servertesting.StorageServer(
... 'x', {'1': MappingStorage()}) ... 'x', {'1': MappingStorage()})
>>> zs = ZEO.StorageServer.ZEOStorage(server) >>> zs = ZEO.tests.servertesting.client(server, 1)
>>> conn = ZEO.tests.servertesting.Connection(1)
>>> zs.notifyConnected(conn)
>>> zs.register('1', 0)
>>> zs.tpc_begin('0', '', '', {}) >>> zs.tpc_begin('0', '', '', {})
>>> zs.storea(ZODB.utils.p64(99), ZODB.utils.z64, 'x', '0') >>> zs.storea(ZODB.utils.p64(99), ZODB.utils.z64, 'x', '0')
>>> zs.vote('0') >>> zs.vote('0')
...@@ -195,7 +177,6 @@ Of course, if vote suceeds, the lock will be held: ...@@ -195,7 +177,6 @@ Of course, if vote suceeds, the lock will be held:
>>> zs.tpc_begin('1', '', '', {}) >>> zs.tpc_begin('1', '', '', {})
>>> zs.storea(ZODB.utils.p64(99), ZODB.utils.z64, 'x', '1') >>> zs.storea(ZODB.utils.p64(99), ZODB.utils.z64, 'x', '1')
>>> _ = zs.vote('1') # doctest: +ELLIPSIS >>> _ = zs.vote('1') # doctest: +ELLIPSIS
1 callAsync serialnos ...
>>> '1' in server._commit_locks >>> '1' in server._commit_locks
True True
...@@ -234,18 +215,25 @@ quit working in Python 3.4: ...@@ -234,18 +215,25 @@ quit working in Python 3.4:
We start a transaction and vote, this leads to getting the lock. We start a transaction and vote, this leads to getting the lock.
>>> zs1 = ZEO.tests.servertesting.client(server, '1') >>> zs1 = ZEO.tests.servertesting.client(server, '1')
ZEO.asyncio.base INFO
Connected server protocol
ZEO.asyncio.server INFO
received handshake b'Z5'
>>> tid1 = start_trans(zs1) >>> tid1 = start_trans(zs1)
>>> zs1.vote(tid1) # doctest: +ELLIPSIS >>> zs1.vote(tid1) # doctest: +ELLIPSIS
ZEO.StorageServer DEBUG ZEO.StorageServer DEBUG
(test-addr-1) ('1') lock: transactions waiting: 0 (test-addr-1) ('1') lock: transactions waiting: 0
ZEO.StorageServer BLATHER ZEO.StorageServer BLATHER
(test-addr-1) Preparing to commit transaction: 1 objects, ... bytes (test-addr-1) Preparing to commit transaction: 1 objects, 108 bytes
1 callAsync serialnos ...
If another client tried to vote, it's lock request will be queued and If another client tried to vote, it's lock request will be queued and
a delay will be returned: a delay will be returned:
>>> zs2 = ZEO.tests.servertesting.client(server, '2') >>> zs2 = ZEO.tests.servertesting.client(server, '2')
ZEO.asyncio.base INFO
Connected server protocol
ZEO.asyncio.server INFO
received handshake b'Z5'
>>> tid2 = start_trans(zs2) >>> tid2 = start_trans(zs2)
>>> delay = zs2.vote(tid2) >>> delay = zs2.vote(tid2)
ZEO.StorageServer DEBUG ZEO.StorageServer DEBUG
...@@ -262,7 +250,6 @@ When we end the first transaction, the queued vote gets the lock. ...@@ -262,7 +250,6 @@ When we end the first transaction, the queued vote gets the lock.
(test-addr-2) ('1') lock: transactions waiting: 0 (test-addr-2) ('1') lock: transactions waiting: 0
ZEO.StorageServer BLATHER ZEO.StorageServer BLATHER
(test-addr-2) Preparing to commit transaction: 1 objects, ... bytes (test-addr-2) Preparing to commit transaction: 1 objects, ... bytes
2 callAsync serialnos ...
Let's try again with the first client. The vote will be queued: Let's try again with the first client. The vote will be queued:
...@@ -306,29 +293,65 @@ increased, so does the logging level: ...@@ -306,29 +293,65 @@ increased, so does the logging level:
... tid = start_trans(client) ... tid = start_trans(client)
... delay = client.vote(tid) ... delay = client.vote(tid)
... clients.append(client) ... clients.append(client)
ZEO.asyncio.base INFO
Connected server protocol
ZEO.asyncio.server INFO
received handshake b'Z5'
ZEO.StorageServer DEBUG ZEO.StorageServer DEBUG
(test-addr-10) ('1') queue lock: transactions waiting: 2 (test-addr-10) ('1') queue lock: transactions waiting: 2
ZEO.asyncio.base INFO
Connected server protocol
ZEO.asyncio.server INFO
received handshake b'Z5'
ZEO.StorageServer DEBUG ZEO.StorageServer DEBUG
(test-addr-11) ('1') queue lock: transactions waiting: 3 (test-addr-11) ('1') queue lock: transactions waiting: 3
ZEO.asyncio.base INFO
Connected server protocol
ZEO.asyncio.server INFO
received handshake b'Z5'
ZEO.StorageServer WARNING ZEO.StorageServer WARNING
(test-addr-12) ('1') queue lock: transactions waiting: 4 (test-addr-12) ('1') queue lock: transactions waiting: 4
ZEO.asyncio.base INFO
Connected server protocol
ZEO.asyncio.server INFO
received handshake b'Z5'
ZEO.StorageServer WARNING ZEO.StorageServer WARNING
(test-addr-13) ('1') queue lock: transactions waiting: 5 (test-addr-13) ('1') queue lock: transactions waiting: 5
ZEO.asyncio.base INFO
Connected server protocol
ZEO.asyncio.server INFO
received handshake b'Z5'
ZEO.StorageServer WARNING ZEO.StorageServer WARNING
(test-addr-14) ('1') queue lock: transactions waiting: 6 (test-addr-14) ('1') queue lock: transactions waiting: 6
ZEO.asyncio.base INFO
Connected server protocol
ZEO.asyncio.server INFO
received handshake b'Z5'
ZEO.StorageServer WARNING ZEO.StorageServer WARNING
(test-addr-15) ('1') queue lock: transactions waiting: 7 (test-addr-15) ('1') queue lock: transactions waiting: 7
ZEO.asyncio.base INFO
Connected server protocol
ZEO.asyncio.server INFO
received handshake b'Z5'
ZEO.StorageServer WARNING ZEO.StorageServer WARNING
(test-addr-16) ('1') queue lock: transactions waiting: 8 (test-addr-16) ('1') queue lock: transactions waiting: 8
ZEO.asyncio.base INFO
Connected server protocol
ZEO.asyncio.server INFO
received handshake b'Z5'
ZEO.StorageServer WARNING ZEO.StorageServer WARNING
(test-addr-17) ('1') queue lock: transactions waiting: 9 (test-addr-17) ('1') queue lock: transactions waiting: 9
ZEO.asyncio.base INFO
Connected server protocol
ZEO.asyncio.server INFO
received handshake b'Z5'
ZEO.StorageServer CRITICAL ZEO.StorageServer CRITICAL
(test-addr-18) ('1') queue lock: transactions waiting: 10 (test-addr-18) ('1') queue lock: transactions waiting: 10
If a client with the transaction lock disconnects, it will abort and If a client with the transaction lock disconnects, it will abort and
release the lock and one of the waiting clients will get the lock. release the lock and one of the waiting clients will get the lock.
>>> zs2.notifyDisconnected() # doctest: +ELLIPSIS >>> zs2.notify_disconnected() # doctest: +ELLIPSIS
ZEO.StorageServer INFO ZEO.StorageServer INFO
(test-addr-2) disconnected during locked transaction (test-addr-2) disconnected during locked transaction
ZEO.StorageServer CRITICAL ZEO.StorageServer CRITICAL
...@@ -337,7 +360,6 @@ release the lock and one of the waiting clients will get the lock. ...@@ -337,7 +360,6 @@ release the lock and one of the waiting clients will get the lock.
(test-addr-1) ('1') lock: transactions waiting: 9 (test-addr-1) ('1') lock: transactions waiting: 9
ZEO.StorageServer BLATHER ZEO.StorageServer BLATHER
(test-addr-1) Preparing to commit transaction: 1 objects, ... bytes (test-addr-1) Preparing to commit transaction: 1 objects, ... bytes
1 callAsync serialnos ...
(In practice, waiting clients won't necessarily get the lock in order.) (In practice, waiting clients won't necessarily get the lock in order.)
...@@ -350,23 +372,19 @@ statistics using the server_status method: ...@@ -350,23 +372,19 @@ statistics using the server_status method:
'commits': 0, 'commits': 0,
'conflicts': 0, 'conflicts': 0,
'conflicts_resolved': 0, 'conflicts_resolved': 0,
'connections': 11, 'connections': 10,
'last-transaction': '0000000000000000', 'last-transaction': '0000000000000000',
'loads': 0, 'loads': 0,
'lock_time': 1272653598.693882, 'lock_time': 1272653598.693882,
'start': 'Fri Apr 30 14:53:18 2010', 'start': 'Fri Apr 30 14:53:18 2010',
'stores': 13, 'stores': 13,
'timeout-thread-is-alive': 'stub', 'timeout-thread-is-alive': 'stub',
'verifying_clients': 0,
'waiting': 9} 'waiting': 9}
(Note that the connections count above is off by 1 due to the way the
test infrastructure works.)
If clients disconnect while waiting, they will be dequeued: If clients disconnect while waiting, they will be dequeued:
>>> for client in clients: >>> for client in clients:
... client.notifyDisconnected() ... client.notify_disconnected()
ZEO.StorageServer INFO ZEO.StorageServer INFO
(test-addr-10) disconnected during unlocked transaction (test-addr-10) disconnected during unlocked transaction
ZEO.StorageServer WARNING ZEO.StorageServer WARNING
...@@ -454,28 +472,33 @@ Now, we'll start a transaction, get the lock and then mark the ...@@ -454,28 +472,33 @@ Now, we'll start a transaction, get the lock and then mark the
ZEOStorage as closed and see if trying to get a lock cleans it up: ZEOStorage as closed and see if trying to get a lock cleans it up:
>>> zs1 = ZEO.tests.servertesting.client(server, '1') >>> zs1 = ZEO.tests.servertesting.client(server, '1')
ZEO.asyncio.base INFO
Connected server protocol
ZEO.asyncio.server INFO
received handshake b'Z5'
>>> tid1 = start_trans(zs1) >>> tid1 = start_trans(zs1)
>>> zs1.vote(tid1) # doctest: +ELLIPSIS >>> zs1.vote(tid1) # doctest: +ELLIPSIS
ZEO.StorageServer DEBUG ZEO.StorageServer DEBUG
(test-addr-1) ('1') lock: transactions waiting: 0 (test-addr-1) ('1') lock: transactions waiting: 0
ZEO.StorageServer BLATHER ZEO.StorageServer BLATHER
(test-addr-1) Preparing to commit transaction: 1 objects, ... bytes (test-addr-1) Preparing to commit transaction: 1 objects, ... bytes
1 callAsync serialnos ...
>>> zs1.connection = None >>> zs1.connection.connection_lost(None)
ZEO.StorageServer INFO
(test-addr-1) disconnected during locked transaction
>>> zs2 = ZEO.tests.servertesting.client(server, '2') >>> zs2 = ZEO.tests.servertesting.client(server, '2')
ZEO.asyncio.base INFO
Connected server protocol
ZEO.asyncio.server INFO
received handshake b'Z5'
>>> tid2 = start_trans(zs2) >>> tid2 = start_trans(zs2)
>>> zs2.vote(tid2) # doctest: +ELLIPSIS >>> zs2.vote(tid2) # doctest: +ELLIPSIS
ZEO.StorageServer CRITICAL
(test-addr-1) Still locked after disconnected. Unlocking.
ZEO.StorageServer DEBUG ZEO.StorageServer DEBUG
(test-addr-2) ('1') lock: transactions waiting: 0 (test-addr-2) ('1') lock: transactions waiting: 0
ZEO.StorageServer BLATHER ZEO.StorageServer BLATHER
(test-addr-2) Preparing to commit transaction: 1 objects, ... bytes (test-addr-2) Preparing to commit transaction: 1 objects, ... bytes
2 callAsync serialnos ...
>>> zs1.txnlog.close()
>>> zs2.tpc_abort(tid2) >>> zs2.tpc_abort(tid2)
>>> logging.getLogger('ZEO').setLevel(logging.NOTSET) >>> logging.getLogger('ZEO').setLevel(logging.NOTSET)
......
##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
"""Helper file used to launch a ZEO server cross platform"""
import asyncore
import errno
import getopt
import logging
import os
import signal
import socket
import sys
import threading
import time
import ZEO.runzeo
import ZEO.zrpc.connection
def cleanup(storage):
# FileStorage and the Berkeley storages have this method, which deletes
# all files and directories used by the storage. This prevents @-files
# from clogging up /tmp
try:
storage.cleanup()
except AttributeError:
pass
logger = logging.getLogger('ZEO.tests.zeoserver')
def log(label, msg, *args):
message = "(%s) %s" % (label, msg)
logger.debug(message, *args)
class ZEOTestServer(asyncore.dispatcher):
"""A server for killing the whole process at the end of a test.
The first time we connect to this server, we write an ack character down
the socket. The other end should block on a recv() of the socket so it
can guarantee the server has started up before continuing on.
The second connect to the port immediately exits the process, via
os._exit(), without writing data on the socket. It does close and clean
up the storage first. The other end will get the empty string from its
recv() which will be enough to tell it that the server has exited.
I think this should prevent us from ever getting a legitimate addr-in-use
error.
"""
__super_init = asyncore.dispatcher.__init__
def __init__(self, addr, server, keep):
self.__super_init()
self._server = server
self._sockets = [self]
self._keep = keep
# Count down to zero, the number of connects
self._count = 1
self._label ='%d @ %s' % (os.getpid(), addr)
if isinstance(addr, str):
self.create_socket(socket.AF_UNIX, socket.SOCK_STREAM)
else:
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
# Some ZEO tests attempt a quick start of the server using the same
# port so we have to set the reuse flag.
self.set_reuse_addr()
try:
self.bind(addr)
except:
# We really want to see these exceptions
import traceback
traceback.print_exc()
raise
self.listen(5)
self.log('bound and listening')
def log(self, msg, *args):
log(self._label, msg, *args)
def handle_accept(self):
sock, addr = self.accept()
self.log('in handle_accept()')
# When we're done with everything, close the storage. Do not write
# the ack character until the storage is finished closing.
if self._count <= 0:
self.log('closing the storage')
self._server.close()
if not self._keep:
for storage in self._server.storages.values():
cleanup(storage)
self.log('exiting')
# Close all the other sockets so that we don't have to wait
# for os._exit() to get to it before starting the next
# server process.
for s in self._sockets:
s.close()
# Now explicitly close the socket returned from accept(),
# since it didn't go through the wrapper.
sock.close()
os._exit(0)
self.log('continuing')
sock.send(b'X')
self._count -= 1
def register_socket(self, sock):
# Register a socket to be closed when server shutsdown.
self._sockets.append(sock)
class Suicide(threading.Thread):
def __init__(self, addr):
threading.Thread.__init__(self)
self._adminaddr = addr
def run(self):
# If this process doesn't exit in 330 seconds, commit suicide.
# The client threads in the ConcurrentUpdate tests will run for
# as long as 300 seconds. Set this timeout to 330 to minimize
# chance that the server gives up before the clients.
time.sleep(999)
log(str(os.getpid()), "suicide thread invoking shutdown")
# If the server hasn't shut down yet, the client may not be
# able to connect to it. If so, try to kill the process to
# force it to shutdown.
if hasattr(os, "kill"):
os.kill(pid, signal.SIGTERM)
time.sleep(5)
os.kill(pid, signal.SIGKILL)
else:
from ZEO.tests.forker import shutdown_zeo_server
# Nott: If the -k option was given to zeoserver, then the
# process will go away but the temp files won't get
# cleaned up.
shutdown_zeo_server(self._adminaddr)
def main():
global pid
pid = os.getpid()
label = str(pid)
log(label, "starting")
# We don't do much sanity checking of the arguments, since if we get it
# wrong, it's a bug in the test suite.
keep = 0
configfile = None
suicide = True
# Parse the arguments and let getopt.error percolate
opts, args = getopt.getopt(sys.argv[1:], 'dkSC:v:')
for opt, arg in opts:
if opt == '-k':
keep = 1
if opt == '-d':
ZEO.zrpc.connection.debug_zrpc = True
elif opt == '-C':
configfile = arg
elif opt == '-S':
suicide = False
elif opt == '-v':
ZEO.zrpc.connection.Connection.current_protocol = arg.encode(
'ascii')
zo = ZEO.runzeo.ZEOOptions()
zo.realize(["-C", configfile])
addr = zo.address
if isinstance(addr, tuple):
test_addr = addr[0], addr[1]+1
else:
test_addr = addr + '-test'
log(label, 'creating the storage server')
mon_addr = None
if zo.monitor_address:
mon_addr = zo.monitor_address
storages = dict((s.name or '1', s.open()) for s in zo.storages)
server = ZEO.runzeo.create_server(storages, zo)
try:
log(label, 'creating the test server, keep: %s', keep)
t = ZEOTestServer(test_addr, server, keep)
except socket.error as e:
if e[0] != errno.EADDRINUSE:
raise
log(label, 'addr in use, closing and exiting')
for storage in storages.values():
storage.close()
cleanup(storage)
sys.exit(2)
t.register_socket(server.dispatcher)
if suicide:
# Create daemon suicide thread
d = Suicide(test_addr)
d.setDaemon(1)
d.start()
# Loop for socket events
log(label, 'entering asyncore loop')
server.start_thread()
asyncore.loop()
if __name__ == '__main__':
import warnings
warnings.simplefilter('ignore')
main()
##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
# zrpc is a package with the following modules
# client -- manages connection creation to remote server
# connection -- object dispatcher
# log -- logging helper
# error -- exceptions raised by zrpc
# marshal -- internal, handles basic protocol issues
# server -- manages incoming connections from remote clients
# smac -- sized message async connections
# trigger -- medusa's trigger
# zrpc is not an advertised subpackage of ZEO; its interfaces are internal
# This file is a slightly modified copy of Python 2.3's Lib/hmac.py.
# This file is under the Python Software Foundation (PSF) license.
"""HMAC (Keyed-Hashing for Message Authentication) Python module.
Implements the HMAC algorithm as described by RFC 2104.
"""
from six.moves import map
from six.moves import zip
def _strxor(s1, s2):
"""Utility method. XOR the two strings s1 and s2 (must have same length).
"""
return "".join(map(lambda x, y: chr(ord(x) ^ ord(y)), s1, s2))
# The size of the digests returned by HMAC depends on the underlying
# hashing module used.
digest_size = None
class HMAC:
"""RFC2104 HMAC class.
This supports the API for Cryptographic Hash Functions (PEP 247).
"""
def __init__(self, key, msg = None, digestmod = None):
"""Create a new HMAC object.
key: key for the keyed hash object.
msg: Initial input for the hash, if provided.
digestmod: A module supporting PEP 247. Defaults to the md5 module.
"""
if digestmod is None:
import md5
digestmod = md5
self.digestmod = digestmod
self.outer = digestmod.new()
self.inner = digestmod.new()
# Python 2.1 and 2.2 differ about the correct spelling
try:
self.digest_size = digestmod.digestsize
except AttributeError:
self.digest_size = digestmod.digest_size
blocksize = 64
ipad = "\x36" * blocksize
opad = "\x5C" * blocksize
if len(key) > blocksize:
key = digestmod.new(key).digest()
key = key + chr(0) * (blocksize - len(key))
self.outer.update(_strxor(key, opad))
self.inner.update(_strxor(key, ipad))
if msg is not None:
self.update(msg)
## def clear(self):
## raise NotImplementedError("clear() method not available in HMAC.")
def update(self, msg):
"""Update this hashing object with the string msg.
"""
self.inner.update(msg)
def copy(self):
"""Return a separate copy of this hashing object.
An update to this copy won't affect the original object.
"""
other = HMAC("")
other.digestmod = self.digestmod
other.inner = self.inner.copy()
other.outer = self.outer.copy()
return other
def digest(self):
"""Return the hash value of this hashing object.
This returns a string containing 8-bit data. The object is
not altered in any way by this function; you can continue
updating the object after calling this function.
"""
h = self.outer.copy()
h.update(self.inner.digest())
return h.digest()
def hexdigest(self):
"""Like digest(), but returns a string of hexadecimal digits instead.
"""
return "".join([hex(ord(x))[2:].zfill(2)
for x in tuple(self.digest())])
def new(key, msg = None, digestmod = None):
"""Create a new hashing object and return it.
key: The starting key for the hash.
msg: if available, will immediately be hashed into the object's starting
state.
You can now feed arbitrary strings into the object using its update()
method, and can ask for the hash value at any time by calling its digest()
method.
"""
return HMAC(key, msg, digestmod)
##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
import asyncore
import errno
import json
import sys
import threading
import logging
import ZEO.zrpc.marshal
import ZEO.zrpc.trigger
from ZEO.zrpc import smac
from ZEO.zrpc.error import ZRPCError, DisconnectedError
from ZEO.zrpc.log import short_repr, log
from ZODB.loglevels import BLATHER, TRACE
import ZODB.POSException
REPLY = ".reply" # message name used for replies
exception_type_type = type(Exception)
debug_zrpc = False
class Delay:
"""Used to delay response to client for synchronous calls.
When a synchronous call is made and the original handler returns
without handling the call, it returns a Delay object that prevents
the mainloop from sending a response.
"""
msgid = conn = sent = None
def set_sender(self, msgid, conn):
self.msgid = msgid
self.conn = conn
def reply(self, obj):
self.sent = 'reply'
self.conn.send_reply(self.msgid, obj)
def error(self, exc_info):
self.sent = 'error'
log("Error raised in delayed method", logging.ERROR, exc_info=exc_info)
self.conn.return_error(self.msgid, *exc_info[:2])
def __repr__(self):
return "%s[%s, %r, %r, %r]" % (
self.__class__.__name__, id(self), self.msgid, self.conn, self.sent)
class Result(Delay):
def __init__(self, *args):
self.args = args
def set_sender(self, msgid, conn):
reply, callback = self.args
conn.send_reply(msgid, reply, False)
callback()
class MTDelay(Delay):
def __init__(self):
self.ready = threading.Event()
def set_sender(self, *args):
Delay.set_sender(self, *args)
self.ready.set()
def reply(self, obj):
self.ready.wait()
self.conn.call_from_thread(self.conn.send_reply, self.msgid, obj)
def error(self, exc_info):
self.ready.wait()
log("Error raised in delayed method", logging.ERROR, exc_info=exc_info)
self.conn.call_from_thread(Delay.error, self, exc_info)
# PROTOCOL NEGOTIATION
#
# The code implementing protocol version 2.0.0 (which is deployed
# in the field and cannot be changed) *only* talks to peers that
# send a handshake indicating protocol version 2.0.0. In that
# version, both the client and the server immediately send out
# their protocol handshake when a connection is established,
# without waiting for their peer, and disconnect when a different
# handshake is receive.
#
# The new protocol uses this to enable new clients to talk to
# 2.0.0 servers. In the new protocol:
#
# The server sends its protocol handshake to the client at once.
#
# The client waits until it receives the server's protocol handshake
# before sending its own handshake. The client sends the lower of its
# own protocol version and the server protocol version, allowing it to
# talk to servers using later protocol versions (2.0.2 and higher) as
# well: the effective protocol used will be the lower of the client
# and server protocol. However, this changed in ZODB 3.3.1 (and
# should have changed in ZODB 3.3) because an older server doesn't
# support MVCC methods required by 3.3 clients.
#
# [Ugly details: In order to treat the first received message (protocol
# handshake) differently than all later messages, both client and server
# start by patching their message_input() method to refer to their
# recv_handshake() method instead. In addition, the client has to arrange
# to queue (delay) outgoing messages until it receives the server's
# handshake, so that the first message the client sends to the server is
# the client's handshake. This multiply-special treatment of the first
# message is delicate, and several asyncore and thread subtleties were
# handled unsafely before ZODB 3.2.6.
# ]
#
# The ZEO modules ClientStorage and ServerStub have backwards
# compatibility code for dealing with the previous version of the
# protocol. The client accepts the old version of some messages,
# and will not send new messages when talking to an old server.
#
# As long as the client hasn't sent its handshake, it can't send
# anything else; output messages are queued during this time.
# (Output can happen because the connection testing machinery can
# start sending requests before the handshake is received.)
#
# UPGRADING FROM ZEO 2.0.0 TO NEWER VERSIONS:
#
# Because a new client can talk to an old server, but not vice
# versa, all clients should be upgraded before upgrading any
# servers. Protocol upgrades beyond 2.0.1 will not have this
# restriction, because clients using protocol 2.0.1 or later can
# talk to both older and newer servers.
#
# No compatibility with protocol version 1 is provided.
# Connection is abstract (it must be derived from). ManagedServerConnection
# and ManagedClientConnection are the concrete subclasses. They need to
# supply a handshake() method appropriate for their role in protocol
# negotiation.
class Connection(smac.SizedMessageAsyncConnection, object):
"""Dispatcher for RPC on object on both sides of socket.
The connection supports synchronous calls, which expect a return,
and asynchronous calls, which do not.
It uses the Marshaller class to handle encoding and decoding of
method calls and arguments. Marshaller uses pickle to encode
arbitrary Python objects. The code here doesn't ever see the wire
format.
A Connection is designed for use in a multithreaded application,
where a synchronous call must block until a response is ready.
A socket connection between a client and a server allows either
side to invoke methods on the other side. The processes on each
end of the socket use a Connection object to manage communication.
The Connection deals with decoded RPC messages. They are
represented as four-tuples containing: msgid, flags, method name,
and a tuple of method arguments.
The msgid starts at zero and is incremented by one each time a
method call message is sent. Each side of the connection has a
separate msgid state.
When one side of the connection (the client) calls a method, it
sends a message with a new msgid. The other side (the server),
replies with a message that has the same msgid, the string
".reply" (the global variable REPLY) as the method name, and the
actual return value in the args position. Note that each side of
the Connection can initiate a call, in which case it will be the
client for that particular call.
The protocol also supports asynchronous calls. The client does
not wait for a return value for an asynchronous call.
If a method call raises an Exception, the exception is propagated
back to the client via the REPLY message. The client side will
raise any exception it receives instead of returning the value to
the caller.
"""
__super_init = smac.SizedMessageAsyncConnection.__init__
__super_close = smac.SizedMessageAsyncConnection.close
__super_setSessionKey = smac.SizedMessageAsyncConnection.setSessionKey
# Protocol history:
#
# Z200 -- Original ZEO 2.0 protocol
#
# Z201 -- Added invalidateTransaction() to client.
# Renamed several client methods.
# Added several sever methods:
# lastTransaction()
# getAuthProtocol() and scheme-specific authentication methods
# getExtensionMethods().
# getInvalidations().
#
# Z303 -- named after the ZODB release 3.3
# Added methods for MVCC:
# loadBefore()
# A Z303 client cannot talk to a Z201 server, because the latter
# doesn't support MVCC. A Z201 client can talk to a Z303 server,
# but because (at least) the type of the root object changed
# from ZODB.PersistentMapping to persistent.mapping, the older
# client can't actually make progress if a Z303 client created,
# or ever modified, the root.
#
# Z308 -- named after the ZODB release 3.8
# Added blob-support server methods:
# sendBlob
# storeBlobStart
# storeBlobChunk
# storeBlobEnd
# storeBlobShared
# Added blob-support client methods:
# receiveBlobStart
# receiveBlobChunk
# receiveBlobStop
#
# Z309 -- named after the ZODB release 3.9
# New server methods:
# restorea, iterator_start, iterator_next,
# iterator_record_start, iterator_record_next,
# iterator_gc
#
# Z310 -- named after the ZODB release 3.10
# New server methods:
# undoa
# Doesn't support undo for older clients.
# Undone oid info returned by vote.
#
# Z3101 -- checkCurrentSerialInTransaction
#
# Z4 -- checkCurrentSerialInTransaction
# No-longer call load.
# Protocol variables:
# Our preferred protocol.
current_protocol = b"Z4"
# If we're a client, an exhaustive list of the server protocols we
# can accept.
servers_we_can_talk_to = [b"Z308", b"Z309", b"Z310", b"Z3101",
current_protocol]
# If we're a server, an exhaustive list of the client protocols we
# can accept.
clients_we_can_talk_to = [
b"Z200", b"Z201", b"Z303", b"Z308", b"Z309", b"Z310", b"Z3101",
current_protocol]
# This is pretty excruciating. Details:
#
# 3.3 server 3.2 client
# server sends Z303 to client
# client computes min(Z303, Z201) == Z201 as the protocol to use
# client sends Z201 to server
# OK, because Z201 is in the server's clients_we_can_talk_to
#
# 3.2 server 3.3 client
# server sends Z201 to client
# client computes min(Z303, Z201) == Z201 as the protocol to use
# Z201 isn't in the client's servers_we_can_talk_to, so client
# raises exception
#
# 3.3 server 3.3 client
# server sends Z303 to client
# client computes min(Z303, Z303) == Z303 as the protocol to use
# Z303 is in the client's servers_we_can_talk_to, so client
# sends Z303 to server
# OK, because Z303 is in the server's clients_we_can_talk_to
# Exception types that should not be logged:
unlogged_exception_types = ()
# Client constructor passes b'C' for tag, server constructor b'S'. This
# is used in log messages, and to determine whether we can speak with
# our peer.
def __init__(self, sock, addr, obj, tag, map=None):
self.obj = None
self.decode = ZEO.zrpc.marshal.decode
self.encode = ZEO.zrpc.marshal.encode
self.fast_encode = ZEO.zrpc.marshal.fast_encode
self.closed = False
self.peer_protocol_version = None # set in recv_handshake()
assert tag in b"CS"
self.tag = tag
self.logger = logging.getLogger('ZEO.zrpc.Connection(%r)' % tag)
if isinstance(addr, tuple):
self.log_label = "(%s:%d) " % addr
else:
self.log_label = "(%s) " % addr
# Supply our own socket map, so that we don't get registered with
# the asyncore socket map just yet. The initial protocol messages
# are treated very specially, and we dare not get invoked by asyncore
# before that special-case setup is complete. Some of that setup
# occurs near the end of this constructor, and the rest is done by
# a concrete subclass's handshake() method. Unfortunately, because
# we ultimately derive from asyncore.dispatcher, it's not possible
# to invoke the superclass constructor without asyncore stuffing
# us into _some_ socket map.
ourmap = {}
self.__super_init(sock, addr, map=ourmap)
# The singleton dict is used in synchronous mode when a method
# needs to call into asyncore to try to force some I/O to occur.
# The singleton dict is a socket map containing only this object.
self._singleton = {self._fileno: self}
# waiting_for_reply is used internally to indicate whether
# a call is in progress. setting a session key is deferred
# until after the call returns.
self.waiting_for_reply = False
self.delay_sesskey = None
self.register_object(obj)
# The first message we see is a protocol handshake. message_input()
# is temporarily replaced by recv_handshake() to treat that message
# specially. revc_handshake() does "del self.message_input", which
# uncovers the normal message_input() method thereafter.
self.message_input = self.recv_handshake
# Server and client need to do different things for protocol
# negotiation, and handshake() is implemented differently in each.
self.handshake()
# Now it's safe to register with asyncore's socket map; it was not
# safe before message_input was replaced, or before handshake() was
# invoked.
# Obscure: in Python 2.4, the base asyncore.dispatcher class grew
# a ._map attribute, which is used instead of asyncore's global
# socket map when ._map isn't None. Because we passed `ourmap` to
# the base class constructor above, in 2.4 asyncore believes we want
# to use `ourmap` instead of the global socket map -- but we don't.
# So we have to replace our ._map with the global socket map, and
# update the global socket map with `ourmap`. Replacing our ._map
# isn't necessary before Python 2.4, but doesn't hurt then (it just
# gives us an unused attribute in 2.3); updating the global socket
# map is necessary regardless of Python version.
if map is None:
map = asyncore.socket_map
self._map = map
map.update(ourmap)
def __repr__(self):
return "<%s %s>" % (self.__class__.__name__, self.addr)
__str__ = __repr__ # Defeat asyncore's dreaded __getattr__
def log(self, message, level=BLATHER, exc_info=False):
self.logger.log(level, self.log_label + message, exc_info=exc_info)
def close(self):
self.mgr.close_conn(self)
if self.closed:
return
self._singleton.clear()
self.closed = True
self.__super_close()
self.trigger.pull_trigger()
def register_object(self, obj):
"""Register obj as the true object to invoke methods on."""
self.obj = obj
# Subclass must implement. handshake() is called by the constructor,
# near its end, but before self is added to asyncore's socket map.
# When a connection is created the first message sent is a 4-byte
# protocol version. This allows the protocol to evolve over time, and
# lets servers handle clients using multiple versions of the protocol.
# In general, the server's handshake() just needs to send the server's
# preferred protocol; the client's also needs to queue (delay) outgoing
# messages until it sees the handshake from the server.
def handshake(self):
raise NotImplementedError
# Replaces message_input() for the first message received. Records the
# protocol sent by the peer in `peer_protocol_version`, restores the
# normal message_input() method, and raises an exception if the peer's
# protocol is unacceptable. That's all the server needs to do. The
# client needs to do additional work in response to the server's
# handshake, and extends this method.
def recv_handshake(self, proto):
# Extended by ManagedClientConnection.
del self.message_input # uncover normal-case message_input()
self.peer_protocol_version = proto
if self.tag == b'C':
good_protos = self.servers_we_can_talk_to
else:
assert self.tag == b'S'
good_protos = self.clients_we_can_talk_to
if proto in good_protos:
self.log("received handshake %r" % proto, level=logging.INFO)
else:
self.log("bad handshake %s" % short_repr(proto),
level=logging.ERROR)
raise ZRPCError("bad handshake %r" % proto)
def message_input(self, message):
"""Decode an incoming message and dispatch it"""
# If something goes wrong during decoding, the marshaller
# will raise an exception. The exception will ultimately
# result in asycnore calling handle_error(), which will
# close the connection.
msgid, async, name, args = self.decode(message)
if debug_zrpc:
self.log("recv msg: %s, %s, %s, %s" % (msgid, async, name,
short_repr(args)),
level=TRACE)
if name == 'loadEx':
# Special case and inline the heck out of load case:
try:
ret = self.obj.loadEx(*args)
except (SystemExit, KeyboardInterrupt):
raise
except Exception as msg:
if not isinstance(msg, self.unlogged_exception_types):
self.log("%s() raised exception: %s" % (name, msg),
logging.ERROR, exc_info=True)
self.return_error(msgid, *sys.exc_info()[:2])
else:
try:
self.message_output(self.fast_encode(msgid, 0, REPLY, ret))
self.poll()
except:
# Fall back to normal version for better error handling
self.send_reply(msgid, ret)
elif name == REPLY:
assert not async
self.handle_reply(msgid, args)
else:
self.handle_request(msgid, async, name, args)
def handle_request(self, msgid, async, name, args):
obj = self.obj
if name.startswith('_') or not hasattr(obj, name):
if obj is None:
if debug_zrpc:
self.log("no object calling %s%s"
% (name, short_repr(args)),
level=logging.DEBUG)
return
msg = "Invalid method name: %s on %s" % (name, repr(obj))
raise ZRPCError(msg)
if debug_zrpc:
self.log("calling %s%s" % (name, short_repr(args)),
level=logging.DEBUG)
meth = getattr(obj, name)
try:
self.waiting_for_reply = True
try:
ret = meth(*args)
finally:
self.waiting_for_reply = False
except (SystemExit, KeyboardInterrupt):
raise
except Exception as msg:
if not isinstance(msg, self.unlogged_exception_types):
self.log("%s() raised exception: %s" % (name, msg),
logging.ERROR, exc_info=True)
error = sys.exc_info()[:2]
if async:
self.log("Asynchronous call raised exception: %s" % self,
level=logging.ERROR, exc_info=True)
else:
self.return_error(msgid, *error)
return
if async:
if ret is not None:
raise ZRPCError("async method %s returned value %s" %
(name, short_repr(ret)))
else:
if debug_zrpc:
self.log("%s returns %s" % (name, short_repr(ret)),
logging.DEBUG)
if isinstance(ret, Delay):
ret.set_sender(msgid, self)
else:
self.send_reply(msgid, ret, not self.delay_sesskey)
if self.delay_sesskey:
self.__super_setSessionKey(self.delay_sesskey)
self.delay_sesskey = None
def return_error(self, msgid, err_type, err_value):
# Note that, ideally, this should be defined soley for
# servers, but a test arranges to get it called on
# a client. Too much trouble to fix it now. :/
if not isinstance(err_value, Exception):
err_value = err_type, err_value
# encode() can pass on a wide variety of exceptions from cPickle.
# While a bare `except` is generally poor practice, in this case
# it's acceptable -- we really do want to catch every exception
# cPickle may raise.
try:
msg = self.encode(msgid, 0, REPLY, (err_type, err_value))
except: # see above
try:
r = short_repr(err_value)
except:
r = "<unreprable>"
err = ZRPCError("Couldn't pickle error %.100s" % r)
msg = self.encode(msgid, 0, REPLY, (ZRPCError, err))
self.message_output(msg)
self.poll()
def handle_error(self):
if sys.exc_info()[0] == SystemExit:
raise sys.exc_info()
self.log("Error caught in asyncore",
level=logging.ERROR, exc_info=True)
self.close()
def setSessionKey(self, key):
if self.waiting_for_reply:
self.delay_sesskey = key
else:
self.__super_setSessionKey(key)
def send_call(self, method, args, async=False):
# send a message and return its msgid
if async:
msgid = 0
else:
msgid = self._new_msgid()
if debug_zrpc:
self.log("send msg: %d, %d, %s, ..." % (msgid, async, method),
level=TRACE)
buf = self.encode(msgid, async, method, args)
self.message_output(buf)
return msgid
def callAsync(self, method, *args):
if self.closed:
raise DisconnectedError()
self.send_call(method, args, 1)
self.poll()
def callAsyncNoPoll(self, method, *args):
# Like CallAsync but doesn't poll. This exists so that we can
# send invalidations atomically to all clients without
# allowing any client to sneak in a load request.
if self.closed:
raise DisconnectedError()
self.send_call(method, args, 1)
def callAsyncNoSend(self, method, *args):
# Like CallAsync but doesn't poll. This exists so that we can
# send invalidations atomically to all clients without
# allowing any client to sneak in a load request.
if self.closed:
raise DisconnectedError()
self.send_call(method, args, 1)
self.call_from_thread()
def callAsyncIterator(self, iterator):
"""Queue a sequence of calls using an iterator
The calls will not be interleaved with other calls from the same
client.
"""
self.message_output(self.encode(0, 1, method, args)
for method, args in iterator)
def handle_reply(self, msgid, ret):
assert msgid == -1 and ret is None
def poll(self):
"""Invoke asyncore mainloop to get pending message out."""
if debug_zrpc:
self.log("poll()", level=TRACE)
self.trigger.pull_trigger()
# import cProfile, time
class ManagedServerConnection(Connection):
"""Server-side Connection subclass."""
# Exception types that should not be logged:
unlogged_exception_types = (ZODB.POSException.POSKeyError, )
def __init__(self, sock, addr, obj, mgr):
self.mgr = mgr
map = {}
Connection.__init__(self, sock, addr, obj, b'S', map=map)
self.decode = ZEO.zrpc.marshal.server_decode
self.trigger = ZEO.zrpc.trigger.trigger(map)
self.call_from_thread = self.trigger.pull_trigger
t = threading.Thread(target=server_loop, args=(map,))
t.setName("ManagedServerConnection thread")
t.setDaemon(True)
t.start()
# self.profile = cProfile.Profile()
# def message_input(self, message):
# self.profile.enable()
# try:
# Connection.message_input(self, message)
# finally:
# self.profile.disable()
def handshake(self):
# Send the server's preferred protocol to the client.
self.message_output(self.current_protocol)
def recv_handshake(self, proto):
if proto == b'ruok':
self.message_output(json.dumps(self.mgr.ruok()).encode("ascii"))
self.poll()
Connection.close(self)
else:
Connection.recv_handshake(self, proto)
self.obj.notifyConnected(self)
def close(self):
self.obj.notifyDisconnected()
Connection.close(self)
# self.profile.dump_stats(str(time.time())+'.stats')
def send_reply(self, msgid, ret, immediately=True):
# encode() can pass on a wide variety of exceptions from cPickle.
# While a bare `except` is generally poor practice, in this case
# it's acceptable -- we really do want to catch every exception
# cPickle may raise.
try:
msg = self.encode(msgid, 0, REPLY, ret)
except: # see above
try:
r = short_repr(ret)
except:
r = "<unreprable>"
err = ZRPCError("Couldn't pickle return %.100s" % r)
msg = self.encode(msgid, 0, REPLY, (ZRPCError, err))
self.message_output(msg)
if immediately:
self.poll()
poll = smac.SizedMessageAsyncConnection.handle_write
def server_loop(map):
while len(map) > 1:
try:
asyncore.poll(30.0, map)
except Exception as v:
if v.args[0] != errno.EBADF:
raise
for o in tuple(map.values()):
o.close()
class ManagedClientConnection(Connection):
"""Client-side Connection subclass."""
__super_init = Connection.__init__
base_message_output = Connection.message_output
def __init__(self, sock, addr, mgr):
self.mgr = mgr
# We can't use the base smac's message_output directly because the
# client needs to queue outgoing messages until it's seen the
# initial protocol handshake from the server. So we have our own
# message_ouput() method, and support for initial queueing. This is
# a delicate design, requiring an output mutex to be wholly
# thread-safe.
# Caution: we must set this up before calling the base class
# constructor, because the latter registers us with asyncore;
# we need to guarantee that we'll queue outgoing messages before
# asyncore learns about us.
self.output_lock = threading.Lock()
self.queue_output = True
self.queued_messages = []
# msgid_lock guards access to msgid
self.msgid = 0
self.msgid_lock = threading.Lock()
# replies_cond is used to block when a synchronous call is
# waiting for a response
self.replies_cond = threading.Condition()
self.replies = {}
self.__super_init(sock, addr, None, tag=b'C', map=mgr.map)
self.trigger = mgr.trigger
self.call_from_thread = self.trigger.pull_trigger
self.call_from_thread()
def close(self):
Connection.close(self)
self.replies_cond.acquire()
self.replies_cond.notifyAll()
self.replies_cond.release()
# Our message_ouput() queues messages until recv_handshake() gets the
# protocol handshake from the server.
def message_output(self, message):
self.output_lock.acquire()
try:
if self.queue_output:
self.queued_messages.append(message)
else:
assert not self.queued_messages
self.base_message_output(message)
finally:
self.output_lock.release()
def handshake(self):
# The client waits to see the server's handshake. Outgoing messages
# are queued for the duration. The client will send its own
# handshake after the server's handshake is seen, in recv_handshake()
# below. It will then send any messages queued while waiting.
assert self.queue_output # the constructor already set this
def recv_handshake(self, proto):
# The protocol to use is the older of our and the server's preferred
# protocols.
proto = min(proto, self.current_protocol)
# Restore the normal message_input method, and raise an exception
# if the protocol version is too old.
Connection.recv_handshake(self, proto)
# Tell the server the protocol in use, then send any messages that
# were queued while waiting to hear the server's protocol, and stop
# queueing messages.
self.output_lock.acquire()
try:
self.base_message_output(proto)
for message in self.queued_messages:
self.base_message_output(message)
self.queued_messages = []
self.queue_output = False
finally:
self.output_lock.release()
def _new_msgid(self):
self.msgid_lock.acquire()
try:
msgid = self.msgid
self.msgid = self.msgid + 1
return msgid
finally:
self.msgid_lock.release()
def call(self, method, *args):
if self.closed:
raise DisconnectedError()
msgid = self.send_call(method, args)
r_args = self.wait(msgid)
if (isinstance(r_args, tuple) and len(r_args) > 1
and type(r_args[0]) == exception_type_type
and issubclass(r_args[0], Exception)):
inst = r_args[1]
raise inst # error raised by server
else:
return r_args
def wait(self, msgid):
"""Invoke asyncore mainloop and wait for reply."""
if debug_zrpc:
self.log("wait(%d)" % msgid, level=TRACE)
self.trigger.pull_trigger()
self.replies_cond.acquire()
try:
while 1:
if self.closed:
raise DisconnectedError()
reply = self.replies.get(msgid, self)
if reply is not self:
del self.replies[msgid]
if debug_zrpc:
self.log("wait(%d): reply=%s" %
(msgid, short_repr(reply)), level=TRACE)
return reply
self.replies_cond.wait()
finally:
self.replies_cond.release()
# For testing purposes, it is useful to begin a synchronous call
# but not block waiting for its response.
def _deferred_call(self, method, *args):
if self.closed:
raise DisconnectedError()
msgid = self.send_call(method, args)
self.trigger.pull_trigger()
return msgid
def _deferred_wait(self, msgid):
r_args = self.wait(msgid)
if (isinstance(r_args, tuple)
and type(r_args[0]) == exception_type_type
and issubclass(r_args[0], Exception)):
inst = r_args[1]
raise inst # error raised by server
else:
return r_args
def handle_reply(self, msgid, args):
if debug_zrpc:
self.log("recv reply: %s, %s"
% (msgid, short_repr(args)), level=TRACE)
self.replies_cond.acquire()
try:
self.replies[msgid] = args
self.replies_cond.notifyAll()
finally:
self.replies_cond.release()
def send_reply(self, msgid, ret):
# Whimper. Used to send heartbeat
assert msgid == -1 and ret is None
self.message_output(b'(J\xff\xff\xff\xffK\x00U\x06.replyNt.')
##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
from ZODB import POSException
from ZEO.Exceptions import ClientDisconnected
class ZRPCError(POSException.StorageError):
pass
class DisconnectedError(ZRPCError, ClientDisconnected):
"""The database storage is disconnected from the storage server.
The error occurred because a problem in the low-level RPC connection,
or because the connection was closed.
"""
# This subclass is raised when zrpc catches the error.
##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
"""Sized Message Async Connections.
This class extends the basic asyncore layer with a record-marking
layer. The message_output() method accepts an arbitrary sized string
as its argument. It sends over the wire the length of the string
encoded using struct.pack('>I') and the string itself. The receiver
passes the original string to message_input().
This layer also supports an optional message authentication code
(MAC). If a session key is present, it uses HMAC-SHA-1 to generate a
20-byte MAC. If a MAC is present, the high-order bit of the length
is set to 1 and the MAC immediately follows the length.
"""
import asyncore
import errno
import six
try:
import hmac
except ImportError:
from . import _hmac as hmac
import socket
import struct
import threading
from ZEO.zrpc.log import log
from ZEO.zrpc.error import DisconnectedError
import ZEO.hash
# Use the dictionary to make sure we get the minimum number of errno
# entries. We expect that EWOULDBLOCK == EAGAIN on most systems --
# or that only one is actually used.
tmp_dict = {errno.EWOULDBLOCK: 0,
errno.EAGAIN: 0,
errno.EINTR: 0,
}
expected_socket_read_errors = tuple(tmp_dict.keys())
tmp_dict = {errno.EAGAIN: 0,
errno.EWOULDBLOCK: 0,
errno.ENOBUFS: 0,
errno.EINTR: 0,
}
expected_socket_write_errors = tuple(tmp_dict.keys())
del tmp_dict
# We chose 60000 as the socket limit by looking at the largest strings
# that we could pass to send() without blocking.
SEND_SIZE = 60000
MAC_BIT = 0x80000000
_close_marker = object()
class SizedMessageAsyncConnection(asyncore.dispatcher):
__super_init = asyncore.dispatcher.__init__
__super_close = asyncore.dispatcher.close
__closed = True # Marker indicating that we're closed
socket = None # to outwit Sam's getattr
def __init__(self, sock, addr, map=None):
self.addr = addr
# __input_lock protects __inp, __input_len, __state, __msg_size
self.__input_lock = threading.Lock()
self.__inp = None # None, a single String, or a list
self.__input_len = 0
# Instance variables __state, __msg_size and __has_mac work together:
# when __state == 0:
# __msg_size == 4, and the next thing read is a message size;
# __has_mac is set according to the MAC_BIT in the header
# when __state == 1:
# __msg_size is variable, and the next thing read is a message.
# __has_mac indicates if we're in MAC mode or not (and
# therefore, if we need to check the mac header)
# The next thing read is always of length __msg_size.
# The state alternates between 0 and 1.
self.__state = 0
self.__has_mac = 0
self.__msg_size = 4
self.__output_messages = []
self.__output = []
self.__closed = False
# Each side of the connection sends and receives messages. A
# MAC is generated for each message and depends on each
# previous MAC; the state of the MAC generator depends on the
# history of operations it has performed. So the MACs must be
# generated in the same order they are verified.
# Each side is guaranteed to receive messages in the order
# they are sent, but there is no ordering constraint between
# message sends and receives. If the two sides are A and B
# and message An indicates the nth message sent by A, then
# A1 A2 B1 B2 and A1 B1 B2 A2 are both legitimate total
# orderings of the messages.
# As a result, there must be seperate MAC generators for each
# side of the connection. If not, the generator state would
# be different after A1 A2 B1 B2 than it would be after
# A1 B1 B2 A2; if the generator state was different, the MAC
# could not be verified.
self.__hmac_send = None
self.__hmac_recv = None
self.__super_init(sock, map)
# asyncore overwrites addr with the getpeername result
# restore our value
self.addr = addr
def setSessionKey(self, sesskey):
log("set session key %r" % sesskey)
# Low-level construction is now delayed until data are sent.
# This is to allow use of iterators that generate messages
# only when we're ready to do I/O so that we can effeciently
# transmit large files. Because we delay messages, we also
# have to delay setting the session key to retain proper
# ordering.
# The low-level output queue supports strings, a special close
# marker, and iterators. It doesn't support callbacks. We
# can create a allback by providing an iterator that doesn't
# yield anything.
# The hack fucntion below is a callback in iterator's
# clothing. :) It never yields anything, but is a generator
# and thus iterator, because it contains a yield statement.
def hack():
self.__hmac_send = hmac.HMAC(sesskey, digestmod=ZEO.hash)
self.__hmac_recv = hmac.HMAC(sesskey, digestmod=ZEO.hash)
if False:
yield b''
self.message_output(hack())
def get_addr(self):
return self.addr
# TODO: avoid expensive getattr calls? Can't remember exactly what
# this comment was supposed to mean, but it has something to do
# with the way asyncore uses getattr and uses if sock:
def __nonzero__(self):
return 1
def handle_read(self):
self.__input_lock.acquire()
try:
# Use a single __inp buffer and integer indexes to make this fast.
try:
d = self.recv(8192)
except socket.error as err:
# Python >= 3.3 makes select.error an alias of OSError,
# which is not subscriptable but does have the 'errno' attribute
err_errno = getattr(err, 'errno', None) or err[0]
if err_errno in expected_socket_read_errors:
return
raise
if not d:
return
input_len = self.__input_len + len(d)
msg_size = self.__msg_size
state = self.__state
has_mac = self.__has_mac
inp = self.__inp
if msg_size > input_len:
if inp is None:
self.__inp = d
elif isinstance(self.__inp, six.binary_type):
self.__inp = [self.__inp, d]
else:
self.__inp.append(d)
self.__input_len = input_len
return # keep waiting for more input
# load all previous input and d into single string inp
if isinstance(inp, six.binary_type):
inp = inp + d
elif inp is None:
inp = d
else:
inp.append(d)
inp = b"".join(inp)
offset = 0
while (offset + msg_size) <= input_len:
msg = inp[offset:offset + msg_size]
offset = offset + msg_size
if not state:
msg_size = struct.unpack(">I", msg)[0]
has_mac = msg_size & MAC_BIT
if has_mac:
msg_size ^= MAC_BIT
msg_size += 20
elif self.__hmac_send:
raise ValueError("Received message without MAC")
state = 1
else:
msg_size = 4
state = 0
# Obscure: We call message_input() with __input_lock
# held!!! And message_input() may end up calling
# message_output(), which has its own lock. But
# message_output() cannot call message_input(), so
# the locking order is always consistent, which
# prevents deadlock. Also, message_input() may
# take a long time, because it can cause an
# incoming call to be handled. During all this
# time, the __input_lock is held. That's a good
# thing, because it serializes incoming calls.
if has_mac:
mac = msg[:20]
msg = msg[20:]
if self.__hmac_recv:
self.__hmac_recv.update(msg)
_mac = self.__hmac_recv.digest()
if mac != _mac:
raise ValueError("MAC failed: %r != %r"
% (_mac, mac))
else:
log("Received MAC but no session key set")
elif self.__hmac_send:
raise ValueError("Received message without MAC")
self.message_input(msg)
self.__state = state
self.__has_mac = has_mac
self.__msg_size = msg_size
self.__inp = inp[offset:]
self.__input_len = input_len - offset
finally:
self.__input_lock.release()
def readable(self):
return True
def writable(self):
return bool(self.__output_messages or self.__output)
def should_close(self):
self.__output_messages.append(_close_marker)
def handle_write(self):
output = self.__output
messages = self.__output_messages
while output or messages:
# Process queued messages until we have enough output
size = sum((len(s) for s in output))
while (size <= SEND_SIZE) and messages:
message = messages[0]
if isinstance(message, six.binary_type):
size += self.__message_output(messages.pop(0), output)
elif isinstance(message, six.text_type):
# XXX This can silently lead to data loss and client hangs
# if asserts aren't enabled. Encountered this under Python3
# and 'ruok' protocol
assert False, "Got a unicode message: %s" % repr(message)
elif message is _close_marker:
del messages[:]
del output[:]
return self.close()
else:
try:
message = six.advance_iterator(message)
except StopIteration:
messages.pop(0)
else:
assert(isinstance(message, six.binary_type))
size += self.__message_output(message, output)
v = b"".join(output)
del output[:]
try:
n = self.send(v)
except socket.error as err:
# Fix for https://bugs.launchpad.net/zodb/+bug/182833
# ensure the above mentioned "output" invariant
output.insert(0, v)
# Python >= 3.3 makes select.error an alias of OSError,
# which is not subscriptable but does have the 'errno' attribute
err_errno = getattr(err, 'errno', None) or err[0]
if err_errno in expected_socket_write_errors:
break # we couldn't write anything
raise
if n < len(v):
output.append(v[n:])
break # we can't write any more
def handle_close(self):
self.close()
def message_output(self, message):
if self.__closed:
raise DisconnectedError(
"This action is temporarily unavailable.<p>")
self.__output_messages.append(message)
def __message_output(self, message, output):
# do two separate appends to avoid copying the message string
size = 4
if self.__hmac_send:
output.append(struct.pack(">I", len(message) | MAC_BIT))
self.__hmac_send.update(message)
output.append(self.__hmac_send.digest())
size += 20
else:
output.append(struct.pack(">I", len(message)))
if len(message) <= SEND_SIZE:
output.append(message)
else:
for i in range(0, len(message), SEND_SIZE):
output.append(message[i:i+SEND_SIZE])
return size + len(message)
def close(self):
if not self.__closed:
self.__closed = True
self.__super_close()
from __future__ import print_function
##############################################################################
#
# Copyright (c) 2001-2005 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
from __future__ import with_statement
import asyncore
import os
import socket
import errno
from ZODB.utils import positive_id
from ZEO._compat import thread, get_ident
# Original comments follow; they're hard to follow in the context of
# ZEO's use of triggers. TODO: rewrite from a ZEO perspective.
# Wake up a call to select() running in the main thread.
#
# This is useful in a context where you are using Medusa's I/O
# subsystem to deliver data, but the data is generated by another
# thread. Normally, if Medusa is in the middle of a call to
# select(), new output data generated by another thread will have
# to sit until the call to select() either times out or returns.
# If the trigger is 'pulled' by another thread, it should immediately
# generate a READ event on the trigger object, which will force the
# select() invocation to return.
#
# A common use for this facility: letting Medusa manage I/O for a
# large number of connections; but routing each request through a
# thread chosen from a fixed-size thread pool. When a thread is
# acquired, a transaction is performed, but output data is
# accumulated into buffers that will be emptied more efficiently
# by Medusa. [picture a server that can process database queries
# rapidly, but doesn't want to tie up threads waiting to send data
# to low-bandwidth connections]
#
# The other major feature provided by this class is the ability to
# move work back into the main thread: if you call pull_trigger()
# with a thunk argument, when select() wakes up and receives the
# event it will call your thunk from within that thread. The main
# purpose of this is to remove the need to wrap thread locks around
# Medusa's data structures, which normally do not need them. [To see
# why this is true, imagine this scenario: A thread tries to push some
# new data onto a channel's outgoing data queue at the same time that
# the main thread is trying to remove some]
class _triggerbase(object):
"""OS-independent base class for OS-dependent trigger class."""
kind = None # subclass must set to "pipe" or "loopback"; used by repr
def __init__(self):
self._closed = False
# `lock` protects the `thunks` list from being traversed and
# appended to simultaneously.
self.lock = thread.allocate_lock()
# List of no-argument callbacks to invoke when the trigger is
# pulled. These run in the thread running the asyncore mainloop,
# regardless of which thread pulls the trigger.
self.thunks = []
def readable(self):
return 1
def writable(self):
return 0
def handle_connect(self):
pass
def handle_close(self):
self.close()
# Override the asyncore close() method, because it doesn't know about
# (so can't close) all the gimmicks we have open. Subclass must
# supply a _close() method to do platform-specific closing work. _close()
# will be called iff we're not already closed.
def close(self):
if not self._closed:
self._closed = True
self.del_channel()
self._close() # subclass does OS-specific stuff
def _close(self): # see close() above; subclass must supply
raise NotImplementedError
def pull_trigger(self, *thunk):
if thunk:
with self.lock:
self.thunks.append(thunk)
try:
self._physical_pull()
except Exception:
if not self._closed:
raise
# Subclass must supply _physical_pull, which does whatever the OS
# needs to do to provoke the "write" end of the trigger.
def _physical_pull(self):
raise NotImplementedError
def handle_read(self):
try:
self.recv(8192)
except socket.error:
return
while 1:
with self.lock:
if self.thunks:
thunk = self.thunks.pop(0)
else:
return
try:
thunk[0](*thunk[1:])
except:
nil, t, v, tbinfo = asyncore.compact_traceback()
print(('exception in trigger thunk:'
' (%s:%s %s)' % (t, v, tbinfo)))
def __repr__(self):
return '<select-trigger (%s) at %x>' % (self.kind, positive_id(self))
if os.name == 'posix':
class trigger(_triggerbase, asyncore.file_dispatcher):
kind = "pipe"
def __init__(self, map=None):
_triggerbase.__init__(self)
r, self.trigger = os.pipe()
asyncore.file_dispatcher.__init__(self, r, map)
if self.socket.fd != r:
# Starting in Python 2.6, the descriptor passed to
# file_dispatcher gets duped and assigned to
# self.socket.fd. This breals the instantiation semantics and
# is a bug imo. I dount it will get fixed, but maybe
# it will. Who knows. For that reason, we test for the
# fd changing rather than just checking the Python version.
os.close(r)
def _close(self):
os.close(self.trigger)
asyncore.file_dispatcher.close(self)
def _physical_pull(self):
os.write(self.trigger, b'x')
else:
# Windows version; uses just sockets, because a pipe isn't select'able
# on Windows.
class BindError(Exception):
pass
class trigger(_triggerbase, asyncore.dispatcher):
kind = "loopback"
def __init__(self, map=None):
_triggerbase.__init__(self)
# Get a pair of connected sockets. The trigger is the 'w'
# end of the pair, which is connected to 'r'. 'r' is put
# in the asyncore socket map. "pulling the trigger" then
# means writing something on w, which will wake up r.
w = socket.socket()
# Disable buffering -- pulling the trigger sends 1 byte,
# and we want that sent immediately, to wake up asyncore's
# select() ASAP.
w.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
count = 0
while 1:
count += 1
# Bind to a local port; for efficiency, let the OS pick
# a free port for us.
# Unfortunately, stress tests showed that we may not
# be able to connect to that port ("Address already in
# use") despite that the OS picked it. This appears
# to be a race bug in the Windows socket implementation.
# So we loop until a connect() succeeds (almost always
# on the first try). See the long thread at
# http://mail.zope.org/pipermail/zope/2005-July/160433.html
# for hideous details.
a = socket.socket()
a.bind(("127.0.0.1", 0))
connect_address = a.getsockname() # assigned (host, port) pair
a.listen(1)
try:
w.connect(connect_address)
break # success
except socket.error as detail:
if detail[0] != errno.WSAEADDRINUSE:
# "Address already in use" is the only error
# I've seen on two WinXP Pro SP2 boxes, under
# Pythons 2.3.5 and 2.4.1.
raise
# (10048, 'Address already in use')
# assert count <= 2 # never triggered in Tim's tests
if count >= 10: # I've never seen it go above 2
a.close()
w.close()
raise BindError("Cannot bind trigger!")
# Close `a` and try again. Note: I originally put a short
# sleep() here, but it didn't appear to help or hurt.
a.close()
r, addr = a.accept() # r becomes asyncore's (self.)socket
a.close()
self.trigger = w
asyncore.dispatcher.__init__(self, r, map)
def _close(self):
# self.socket is r, and self.trigger is w, from __init__
self.socket.close()
self.trigger.close()
def _physical_pull(self):
self.trigger.send('x')
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment