Commit e1e8558d authored by Jim Fulton's avatar Jim Fulton

Implemented ZEO servers with asyncio.

parent a94d9b50
Changelog Changelog
========= =========
5.0.0 (unreleases)
------------------
This is a major ZEO revision, which replaces the ZEO network protocol
implementation.
New features:
Dropped features:
- The ZEO authentication protocol.
This will be replaced by new authentication mechanims leveraging SSL.
- The ZEO monitor server.
- Full cache verification.
- Client suppprt for servers older than ZODB 3.9
- Server support for clients older than ZEO 4.2.0
4.2.0 (2016-06-15) 4.2.0 (2016-06-15)
------------------ ------------------
......
...@@ -165,7 +165,7 @@ class ClientStorage(object): ...@@ -165,7 +165,7 @@ class ClientStorage(object):
shared_blob_dir shared_blob_dir
Flag whether the blob_dir is a server-shared filesystem Flag whether the blob_dir is a server-shared filesystem
that should be used instead of transferring blob data over that should be used instead of transferring blob data over
zrpc. ZEO protocol.
blob_cache_size blob_cache_size
Maximum size of the ZEO blob cache, in bytes. If not set, then Maximum size of the ZEO blob cache, in bytes. If not set, then
...@@ -479,12 +479,10 @@ class ClientStorage(object): ...@@ -479,12 +479,10 @@ class ClientStorage(object):
return buf return buf
def history(self, oid, size=1, def history(self, oid, size=1):
timeout=None, # for tests
):
"""Storage API: return a sequence of HistoryEntry objects. """Storage API: return a sequence of HistoryEntry objects.
""" """
return self._call('history', oid, size, timeout=timeout) return self._call('history', oid, size)
def record_iternext(self, next=None): def record_iternext(self, next=None):
"""Storage API: get the next database record. """Storage API: get the next database record.
...@@ -762,8 +760,8 @@ class ClientStorage(object): ...@@ -762,8 +760,8 @@ class ClientStorage(object):
txn.set_data(self, TransactionBuffer(self._connection_generation)) txn.set_data(self, TransactionBuffer(self._connection_generation))
if self.protocol_version < b'Z5': # XXX we'd like to allow multiple transactions at a time at some point,
# Earlier protocols only allowed one transaction at a time :( # but for now, due to server limitations, TCBOO.
self._commit_lock.acquire() self._commit_lock.acquire()
self._tbuf = txn.data(self) self._tbuf = txn.data(self)
...@@ -780,7 +778,6 @@ class ClientStorage(object): ...@@ -780,7 +778,6 @@ class ClientStorage(object):
if tbuf is not None: if tbuf is not None:
tbuf.close() tbuf.close()
txn.set_data(self, None) txn.set_data(self, None)
if self.protocol_version < b'Z5':
self._commit_lock.release() self._commit_lock.release()
def lastTransaction(self): def lastTransaction(self):
......
This diff is collapsed.
...@@ -26,14 +26,24 @@ def client(*args, **kw): ...@@ -26,14 +26,24 @@ def client(*args, **kw):
return ZEO.ClientStorage.ClientStorage(*args, **kw) return ZEO.ClientStorage.ClientStorage(*args, **kw)
def DB(*args, **kw): def DB(*args, **kw):
s = client(*args, **kw)
try:
import ZODB import ZODB
return ZODB.DB(client(*args, **kw)) return ZODB.DB(s)
except Exception:
s.close()
raise
def connection(*args, **kw): def connection(*args, **kw):
return DB(*args, **kw).open_then_close_db_when_connection_closes() db = DB(*args, **kw)
try:
return db.open_then_close_db_when_connection_closes()
except Exception:
db.close()
ra
def server(path=None, blob_dir=None, storage_conf=None, zeo_conf=None, def server(path=None, blob_dir=None, storage_conf=None, zeo_conf=None,
port=None): port=0, **kw):
"""Convenience function to start a server for interactive exploration """Convenience function to start a server for interactive exploration
This fuction starts a ZEO server, given a storage configuration or This fuction starts a ZEO server, given a storage configuration or
...@@ -74,14 +84,7 @@ def server(path=None, blob_dir=None, storage_conf=None, zeo_conf=None, ...@@ -74,14 +84,7 @@ def server(path=None, blob_dir=None, storage_conf=None, zeo_conf=None,
import os, ZEO.tests.forker import os, ZEO.tests.forker
if storage_conf is None and path is None: if storage_conf is None and path is None:
storage_conf = '<mappingstorage>\n</mappingstorage>' storage_conf = '<mappingstorage>\n</mappingstorage>'
if port is None and zeo_conf is None:
port = ZEO.tests.forker.get_port()
addr, admin, pid, config = ZEO.tests.forker.start_zeo_server( return ZEO.tests.forker.start_zeo_server(
storage_conf, zeo_conf, port, keep=True, path=path, storage_conf, zeo_conf, port, keep=True, path=path,
blob_dir=blob_dir, suicide=False) blob_dir=blob_dir, suicide=False, threaded=True, **kw)
os.remove(config)
def stop_server():
ZEO.tests.forker.shutdown_zeo_server(admin)
os.waitpid(pid, 0)
return addr, stop_server
...@@ -31,33 +31,31 @@ else: ...@@ -31,33 +31,31 @@ else:
s.close() s.close()
del s del s
from ZEO.zrpc.connection import Connection
from ZEO.zrpc.log import log
import ZEO.zrpc.log
import logging import logging
# Export the main asyncore loop logger = logging.getLogger(__name__)
loop = asyncore.loop
class Dispatcher(asyncore.dispatcher): class Acceptor(asyncore.dispatcher):
"""A server that accepts incoming RPC connections""" """A server that accepts incoming RPC connections"""
__super_init = asyncore.dispatcher.__init__
def __init__(self, addr, factory=Connection, map=None): def __init__(self, addr, factory):
self.__super_init(map=map) self.socket_map = {}
asyncore.dispatcher.__init__(self, map=self.socket_map)
self.addr = addr self.addr = addr
self.factory = factory self.factory = factory
self._open_socket() self._open_socket()
def _open_socket(self): def _open_socket(self):
if type(self.addr) == tuple: addr = self.addr
if self.addr[0] == '' and _has_dualstack:
if type(addr) == tuple:
if addr[0] == '' and _has_dualstack:
# Wildcard listen on all interfaces, both IPv4 and # Wildcard listen on all interfaces, both IPv4 and
# IPv6 if possible # IPv6 if possible
self.create_socket(socket.AF_INET6, socket.SOCK_STREAM) self.create_socket(socket.AF_INET6, socket.SOCK_STREAM)
self.socket.setsockopt( self.socket.setsockopt(
socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False) socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False)
elif ':' in self.addr[0]: elif ':' in addr[0]:
self.create_socket(socket.AF_INET6, socket.SOCK_STREAM) self.create_socket(socket.AF_INET6, socket.SOCK_STREAM)
if _has_dualstack: if _has_dualstack:
# On Linux, IPV6_V6ONLY is off by default. # On Linux, IPV6_V6ONLY is off by default.
...@@ -68,20 +66,28 @@ class Dispatcher(asyncore.dispatcher): ...@@ -68,20 +66,28 @@ class Dispatcher(asyncore.dispatcher):
self.create_socket(socket.AF_INET, socket.SOCK_STREAM) self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
else: else:
self.create_socket(socket.AF_UNIX, socket.SOCK_STREAM) self.create_socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.set_reuse_addr() self.set_reuse_addr()
log("listening on %s" % str(self.addr), logging.INFO)
for i in range(25): for i in range(25):
try: try:
self.bind(self.addr) self.bind(addr)
except Exception as exc: except Exception as exc:
log("bind failed %s waiting", i) logger.info("bind on %s failed %s waiting", addr, i)
if i == 24: if i == 24:
raise raise
else: else:
time.sleep(5) time.sleep(5)
except:
logger.exception('binding')
raise
else: else:
break break
if isinstance(addr, tuple) and addr[1] == 0:
self.addr = addr = self.socket.getsockname()
logger.info("listening on %s", str(addr))
self.listen(5) self.listen(5)
def writable(self): def writable(self):
...@@ -94,7 +100,7 @@ class Dispatcher(asyncore.dispatcher): ...@@ -94,7 +100,7 @@ class Dispatcher(asyncore.dispatcher):
try: try:
sock, addr = self.accept() sock, addr = self.accept()
except socket.error as msg: except socket.error as msg:
log("accepted failed: %s" % msg) logger.info("accepted failed: %s", msg)
return return
...@@ -115,9 +121,24 @@ class Dispatcher(asyncore.dispatcher): ...@@ -115,9 +121,24 @@ class Dispatcher(asyncore.dispatcher):
try: try:
c = self.factory(sock, addr) c = self.factory(sock, addr)
except: except Exception:
if sock.fileno() in asyncore.socket_map: if sock.fileno() in asyncore.socket_map:
del asyncore.socket_map[sock.fileno()] del asyncore.socket_map[sock.fileno()]
ZEO.zrpc.log.logger.exception("Error in handle_accept") logger.exception("Error in handle_accept")
else: else:
log("connect from %s: %s" % (repr(addr), c)) logger.info("connect from %s: %s", repr(addr), c)
def loop(self):
try:
asyncore.loop(map=self.socket_map)
except Exception:
if not self.__closed:
raise # Unexpected exc
logger.debug('acceptor %s loop stopped', self.addr)
__closed = False
def close(self):
if not self.__closed:
self.__closed = True
asyncore.dispatcher.close(self)
from struct import unpack
import asyncio
import logging
from .marshal import encoder
logger = logging.getLogger(__name__)
class Protocol(asyncio.Protocol):
"""asyncio low-level ZEO base interface
"""
# All of the code in this class runs in a single dedicated
# thread. Thus, we can mostly avoid worrying about interleaved
# operations.
# One place where special care was required was in cache setup on
# connect. See finish connect below.
transport = protocol_version = None
def __init__(self, loop, addr):
self.loop = loop
self.addr = addr
self.input = [] # Input buffer when assembling messages
self.output = [] # Output buffer when paused
self.paused = [] # Paused indicator, mutable to avoid attr lookup
# Handle the first message, the protocol handshake, differently
self.message_received = self.first_message_received
def __repr__(self):
return self.name
closed = False
def close(self):
if not self.closed:
self.closed = True
if self.transport is not None:
self.transport.close()
def connection_made(self, transport):
logger.info("Connected %s", self)
self.transport = transport
paused = self.paused
output = self.output
append = output.append
writelines = transport.writelines
from struct import pack
def write(message):
if paused:
append(message)
else:
writelines((pack(">I", len(message)), message))
self._write = write
def writeit(data):
# Note, don't worry about combining messages. Iters
# will be used with blobs, in which case, the individual
# messages will be big to begin with.
data = iter(data)
for message in data:
writelines((pack(">I", len(message)), message))
if paused:
append(data)
break
self._writeit = writeit
got = 0
want = 4
getting_size = True
def data_received(self, data):
# Low-level input handler collects data into sized messages.
# Note that the logic below assume that when new data pushes
# us over what we want, we process it in one call until we
# need more, because we assume that excess data is all in the
# last item of self.input. This is why the exception handling
# in the while loop is critical. Without it, an exception
# might cause us to exit before processing all of the data we
# should, when then causes the logic to be broken in
# subsequent calls.
self.got += len(data)
self.input.append(data)
while self.got >= self.want:
try:
extra = self.got - self.want
if extra == 0:
collected = b''.join(self.input)
self.input = []
else:
input = self.input
self.input = [input[-1][-extra:]]
input[-1] = input[-1][:-extra]
collected = b''.join(input)
self.got = extra
if self.getting_size:
# we were recieving the message size
assert self.want == 4
self.want = unpack(">I", collected)[0]
self.getting_size = False
else:
self.want = 4
self.getting_size = True
self.message_received(collected)
except Exception:
logger.exception("data_received %s %s %s",
self.want, self.got, self.getting_size)
def first_message_received(self, protocol_version):
# Handler for first/handshake message, set up in __init__
del self.message_received # use default handler from here on
self.encode = encoder()
self.finish_connect(protocol_version)
def call_async(self, method, args):
self._write(self.encode(0, True, method, args))
def call_async_iter(self, it):
self._writeit(self.encode(0, True, method, args)
for method, args in it)
def pause_writing(self):
self.paused.append(1)
def resume_writing(self):
paused = self.paused
del paused[:]
output = self.output
writelines = self.transport.writelines
from struct import pack
while output and not paused:
message = output.pop(0)
if isinstance(message, bytes):
writelines((pack(">I", len(message)), message))
else:
data = message
for message in data:
writelines((pack(">I", len(message)), message))
if paused: # paused again. Put iter back.
output.insert(0, data)
break
def get_peername(self):
return self.transport.get_extra_info('peername')
from pickle import loads, dumps
from ZEO.Exceptions import ClientDisconnected from ZEO.Exceptions import ClientDisconnected
from ZODB.ConflictResolution import ResolvedSerial from ZODB.ConflictResolution import ResolvedSerial
from struct import unpack
import asyncio import asyncio
import concurrent.futures import concurrent.futures
import logging import logging
import random import random
import threading import threading
import traceback
import ZODB.event import ZODB.event
import ZODB.POSException import ZODB.POSException
...@@ -15,17 +12,16 @@ import ZODB.POSException ...@@ -15,17 +12,16 @@ import ZODB.POSException
import ZEO.Exceptions import ZEO.Exceptions
import ZEO.interfaces import ZEO.interfaces
from . import base
from .marshal import decode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
Fallback = object() Fallback = object()
local_random = random.Random() # use separate generator to facilitate tests local_random = random.Random() # use separate generator to facilitate tests
class Closed(Exception): class Protocol(base.Protocol):
"""A connection has been closed
"""
class Protocol(asyncio.Protocol):
"""asyncio low-level ZEO client interface """asyncio low-level ZEO client interface
""" """
...@@ -36,9 +32,7 @@ class Protocol(asyncio.Protocol): ...@@ -36,9 +32,7 @@ class Protocol(asyncio.Protocol):
# One place where special care was required was in cache setup on # One place where special care was required was in cache setup on
# connect. See finish connect below. # connect. See finish connect below.
transport = protocol_version = None protocols = b'Z309', b'Z310', b'Z3101', b'Z4', b'Z5'
protocols = b"Z309", b"Z310", b"Z3101"
def __init__(self, loop, def __init__(self, loop,
addr, client, storage_key, read_only, connect_poll=1, addr, client, storage_key, read_only, connect_poll=1,
...@@ -51,8 +45,7 @@ class Protocol(asyncio.Protocol): ...@@ -51,8 +45,7 @@ class Protocol(asyncio.Protocol):
cache is a ZEO.interfaces.IClientCache. cache is a ZEO.interfaces.IClientCache.
""" """
self.loop = loop super(Protocol, self).__init__(loop, addr)
self.addr = addr
self.storage_key = storage_key self.storage_key = storage_key
self.read_only = read_only self.read_only = read_only
self.name = "%s(%r, %r, %r)" % ( self.name = "%s(%r, %r, %r)" % (
...@@ -61,19 +54,9 @@ class Protocol(asyncio.Protocol): ...@@ -61,19 +54,9 @@ class Protocol(asyncio.Protocol):
self.connect_poll = connect_poll self.connect_poll = connect_poll
self.heartbeat_interval = heartbeat_interval self.heartbeat_interval = heartbeat_interval
self.futures = {} # { message_id -> future } self.futures = {} # { message_id -> future }
self.input = [] # Buffer when assembling messages
self.output = [] # Buffer when paused
self.paused = [] # Paused indicator, mutable to avoid attr lookup
# Handle the first message, the protocol handshake, differently
self.message_received = self.first_message_received
self.connect() self.connect()
def __repr__(self):
return self.name
closed = False
def close(self): def close(self):
if not self.closed: if not self.closed:
self.closed = True self.closed = True
...@@ -118,35 +101,7 @@ class Protocol(asyncio.Protocol): ...@@ -118,35 +101,7 @@ class Protocol(asyncio.Protocol):
) )
def connection_made(self, transport): def connection_made(self, transport):
logger.info("Connected %s", self) super(Protocol, self).connection_made(transport)
self.transport = transport
paused = self.paused
output = self.output
append = output.append
writelines = transport.writelines
from struct import pack
def write(message):
if paused:
append(message)
else:
writelines((pack(">I", len(message)), message))
self._write = write
def writeit(data):
# Note, don't worry about combining messages. Iters
# will be used with blobs, in which case, the individual
# messages will be big to begin with.
data = iter(data)
for message in data:
writelines((pack(">I", len(message)), message))
if paused:
append(data)
break
self._writeit = writeit
self.heartbeat(write=False) self.heartbeat(write=False)
def connection_lost(self, exc): def connection_lost(self, exc):
...@@ -181,6 +136,7 @@ class Protocol(asyncio.Protocol): ...@@ -181,6 +136,7 @@ class Protocol(asyncio.Protocol):
# invalidations. # invalidations.
self.protocol_version = min(protocol_version, self.protocols[-1]) self.protocol_version = min(protocol_version, self.protocols[-1])
if self.protocol_version not in self.protocols: if self.protocol_version not in self.protocols:
self.client.register_failed( self.client.register_failed(
self, ZEO.Exceptions.ProtocolError(protocol_version)) self, ZEO.Exceptions.ProtocolError(protocol_version))
...@@ -236,59 +192,9 @@ class Protocol(asyncio.Protocol): ...@@ -236,59 +192,9 @@ class Protocol(asyncio.Protocol):
else: else:
self.client.register_failed(self, exc) self.client.register_failed(self, exc)
got = 0
want = 4
getting_size = True
def data_received(self, data):
# Low-level input handler collects data into sized messages.
# Note that the logic below assume that when new data pushes
# us over what we want, we process it in one call until we
# need more, because we assume that excess data is all in the
# last item of self.input. This is why the exception handling
# in the while loop is critical. Without it, an exception
# might cause us to exit before processing all of the data we
# should, when then causes the logic to be broken in
# subsequent calls.
self.got += len(data)
self.input.append(data)
while self.got >= self.want:
try:
extra = self.got - self.want
if extra == 0:
collected = b''.join(self.input)
self.input = []
else:
input = self.input
self.input = [input[-1][-extra:]]
input[-1] = input[-1][:-extra]
collected = b''.join(input)
self.got = extra
if self.getting_size:
# we were recieving the message size
assert self.want == 4
self.want = unpack(">I", collected)[0]
self.getting_size = False
else:
self.want = 4
self.getting_size = True
self.message_received(collected)
except Exception:
logger.exception("data_received %s %s %s",
self.want, self.got, self.getting_size)
def first_message_received(self, data):
# Handler for first/handshake message, set up in __init__
del self.message_received # use default handler from here on
self.finish_connect(data)
exception_type_type = type(Exception) exception_type_type = type(Exception)
def message_received(self, data): def message_received(self, data):
msgid, async, name, args = loads(data) msgid, async, name, args = decode(data)
if name == '.reply': if name == '.reply':
future = self.futures.pop(msgid) future = self.futures.pop(msgid)
if (isinstance(args, tuple) and len(args) > 1 and if (isinstance(args, tuple) and len(args) > 1 and
...@@ -315,46 +221,16 @@ class Protocol(asyncio.Protocol): ...@@ -315,46 +221,16 @@ class Protocol(asyncio.Protocol):
else: else:
raise AttributeError(name) raise AttributeError(name)
def call_async(self, method, args):
self._write(dumps((0, True, method, args), 3))
def call_async_iter(self, it):
self._writeit(dumps((0, True, method, args), 3) for method, args in it)
message_id = 0 message_id = 0
def call(self, future, method, args): def call(self, future, method, args):
self.message_id += 1 self.message_id += 1
self.futures[self.message_id] = future self.futures[self.message_id] = future
self._write(dumps((self.message_id, False, method, args), 3)) self._write(self.encode(self.message_id, False, method, args))
return future return future
def promise(self, method, *args): def promise(self, method, *args):
return self.call(Promise(), method, args) return self.call(Promise(), method, args)
def pause_writing(self):
self.paused.append(1)
def resume_writing(self):
paused = self.paused
del paused[:]
output = self.output
writelines = self.transport.writelines
from struct import pack
while output and not paused:
message = output.pop(0)
if isinstance(message, bytes):
writelines((pack(">I", len(message)), message))
else:
data = message
for message in data:
writelines((pack(">I", len(message)), message))
if paused: # paused again. Put iter back.
output.insert(0, data)
break
def get_peername(self):
return self.transport.get_extra_info('peername')
# Methods called by the server. # Methods called by the server.
# WARNING WARNING we can't call methods that call back to us # WARNING WARNING we can't call methods that call back to us
# syncronously, as that would lead to DEADLOCK! # syncronously, as that would lead to DEADLOCK!
...@@ -825,6 +701,7 @@ class ClientThread(ClientRunner): ...@@ -825,6 +701,7 @@ class ClientThread(ClientRunner):
exception = None exception = None
def run(self): def run(self):
loop = None
try: try:
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
...@@ -832,15 +709,20 @@ class ClientThread(ClientRunner): ...@@ -832,15 +709,20 @@ class ClientThread(ClientRunner):
self.started.set() self.started.set()
loop.run_forever() loop.run_forever()
except Exception as exc: except Exception as exc:
raise
logger.exception("Client thread") logger.exception("Client thread")
self.exception = exc self.exception = exc
finally: finally:
if not self.closed: if not self.closed:
if self.client.ready:
self.closed = True self.closed = True
try:
if self.client.ready:
self.client.ready = False self.client.ready = False
self.client.client.notify_disconnected() self.client.client.notify_disconnected()
except AttributeError:
pass
logger.critical("Client loop stopped unexpectedly") logger.critical("Client loop stopped unexpectedly")
if loop is not None:
loop.close() loop.close()
logger.debug('Stopping client thread') logger.debug('Stopping client thread')
......
...@@ -11,78 +11,64 @@ ...@@ -11,78 +11,64 @@
# FOR A PARTICULAR PURPOSE # FOR A PARTICULAR PURPOSE
# #
############################################################################## ##############################################################################
"""Support for marshaling ZEO messages
Not to be confused with marshaling objects in ZODB.
We currently use pickle. In the future, we may use a
Python-independent format, or possibly a minimal pickle subset.
"""
import logging import logging
from ZEO._compat import Unpickler, Pickler, BytesIO, PY3, PYPY from .._compat import Unpickler, Pickler, BytesIO, PY3, PYPY
from ZEO.zrpc.error import ZRPCError from ..shortrepr import short_repr
from ZEO.zrpc.log import log, short_repr
logger = logging.getLogger(__name__)
def encode(*args): # args: (msgid, flags, name, args)
# (We used to have a global pickler, but that's not thread-safe. :-( ) def encoder():
"""Return a non-thread-safe encoder
# It's not thread safe if, in the couse of pickling, we call the """
# Python interpeter, which releases the GIL.
if PY3 or PYPY:
# Note that args may contain very large binary pickles already; for
# this reason, it's important to use proto 1 (or higher) pickles here
# too. For a long time, this used proto 0 pickles, and that can
# bloat our pickle to 4x the size (due to high-bit and control bytes
# being represented by \xij escapes in proto 0).
# Undocumented: cPickle.Pickler accepts a lone protocol argument;
# pickle.py does not.
if PY3:
# XXX: Py3: Needs optimization.
f = BytesIO() f = BytesIO()
pickler = Pickler(f, 3) getvalue = f.getvalue
seek = f.seek
truncate = f.truncate
pickler = Pickler(f, 3 if PY3 else 1)
pickler.fast = 1 pickler.fast = 1
pickler.dump(args) dump = pickler.dump
res = f.getvalue() def encode(*args):
return res seek(0)
truncate()
dump(args)
return getvalue()
else: else:
pickler = Pickler(1)
pickler.fast = 1
# Only CPython's cPickle supports dumping
# and returning in one operation:
# return pickler.dump(args, 1)
# For PyPy we must return the value; fortunately this
# works the same on CPython and is no more expensive
pickler.dump(args)
return pickler.getvalue()
if PY3:
# XXX: Py3: Needs optimization.
fast_encode = encode
elif PYPY:
# can't use the python-2 branch, need a new pickler
# every time, getvalue() only works once
fast_encode = encode
else:
def fast_encode():
# Only use in cases where you *know* the data contains only basic
# Python objects
pickler = Pickler(1) pickler = Pickler(1)
pickler.fast = 1 pickler.fast = 1
dump = pickler.dump dump = pickler.dump
def fast_encode(*args): def encode(*args):
return dump(args, 1) return dump(args, 2)
return fast_encode
fast_encode = fast_encode() return encode
def encode(*args):
return encoder()(*args)
def decode(msg): def decode(msg):
"""Decodes msg and returns its parts""" """Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg)) unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = find_global unpickler.find_global = find_global
try: try:
unpickler.find_class = find_global # PyPy, zodbpickle, the non-c-accelerated version # PyPy, zodbpickle, the non-c-accelerated version
unpickler.find_class = find_global
except AttributeError: except AttributeError:
pass pass
try: try:
return unpickler.load() # msgid, flags, name, args return unpickler.load() # msgid, flags, name, args
except: except:
log("can't decode message: %s" % short_repr(msg), logger.error("can't decode message: %s" % short_repr(msg))
level=logging.ERROR)
raise raise
def server_decode(msg): def server_decode(msg):
...@@ -90,15 +76,15 @@ def server_decode(msg): ...@@ -90,15 +76,15 @@ def server_decode(msg):
unpickler = Unpickler(BytesIO(msg)) unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = server_find_global unpickler.find_global = server_find_global
try: try:
unpickler.find_class = server_find_global # PyPy, zodbpickle, the non-c-accelerated version # PyPy, zodbpickle, the non-c-accelerated version
unpickler.find_class = server_find_global
except AttributeError: except AttributeError:
pass pass
try: try:
return unpickler.load() # msgid, flags, name, args return unpickler.load() # msgid, flags, name, args
except: except:
log("can't decode message: %s" % short_repr(msg), logger.error("can't decode message: %s" % short_repr(msg))
level=logging.ERROR)
raise raise
_globals = globals() _globals = globals()
...@@ -111,12 +97,12 @@ def find_global(module, name): ...@@ -111,12 +97,12 @@ def find_global(module, name):
try: try:
m = __import__(module, _globals, _globals, _silly) m = __import__(module, _globals, _globals, _silly)
except ImportError as msg: except ImportError as msg:
raise ZRPCError("import error %s: %s" % (module, msg)) raise ImportError("import error %s: %s" % (module, msg))
try: try:
r = getattr(m, name) r = getattr(m, name)
except AttributeError: except AttributeError:
raise ZRPCError("module %s has no global %s" % (module, name)) raise ImportError("module %s has no global %s" % (module, name))
safe = getattr(r, '__no_side_effects__', 0) safe = getattr(r, '__no_side_effects__', 0)
if safe: if safe:
...@@ -126,7 +112,7 @@ def find_global(module, name): ...@@ -126,7 +112,7 @@ def find_global(module, name):
if type(r) == exception_type_type and issubclass(r, Exception): if type(r) == exception_type_type and issubclass(r, Exception):
return r return r
raise ZRPCError("Unsafe global: %s.%s" % (module, name)) raise ImportError("Unsafe global: %s.%s" % (module, name))
def server_find_global(module, name): def server_find_global(module, name):
"""Helper for message unpickler""" """Helper for message unpickler"""
...@@ -135,11 +121,11 @@ def server_find_global(module, name): ...@@ -135,11 +121,11 @@ def server_find_global(module, name):
raise ImportError raise ImportError
m = __import__(module, _globals, _globals, _silly) m = __import__(module, _globals, _globals, _silly)
except ImportError as msg: except ImportError as msg:
raise ZRPCError("import error %s: %s" % (module, msg)) raise ImportError("import error %s: %s" % (module, msg))
try: try:
r = getattr(m, name) r = getattr(m, name)
except AttributeError: except AttributeError:
raise ZRPCError("module %s has no global %s" % (module, name)) raise ImportError("module %s has no global %s" % (module, name))
return r return r
import asyncio
import json
import logging
import os
import random
import threading
import ZODB.POSException
logger = logging.getLogger(__name__)
from ..shortrepr import short_repr
from . import base
from .marshal import server_decode
class ServerProtocol(base.Protocol):
"""asyncio low-level ZEO server interface
"""
protocols = b'Z4', b'Z5'
name = 'server protocol'
methods = set(('register', ))
unlogged_exception_types = (
ZODB.POSException.POSKeyError,
)
def __init__(self, loop, addr, zeo_storage):
"""Create a server's client interface
"""
super(ServerProtocol, self).__init__(loop, addr)
self.zeo_storage = zeo_storage
closed = False
def close(self):
if not self.closed:
self.closed = True
if self.transport is not None:
self.transport.close()
connected = None # for tests
def connection_made(self, transport):
self.connected = True
super(ServerProtocol, self).connection_made(transport)
self._write(best_protocol_version)
def connection_lost(self, exc):
self.connected = False
if exc:
logger.error("Disconnected %s:%s", exc.__class__.__name__, exc)
self.zeo_storage.notify_disconnected()
self.loop.stop()
def finish_connect(self, protocol_version):
if protocol_version == b'ruok':
self._write(json.dumps(self.zeo_storage.ruok()).encode("ascii"))
self.close()
else:
if protocol_version in self.protocols:
logger.info("received handshake %r" % protocol_version)
self.protocol_version = protocol_version
self.zeo_storage.notify_connected(self)
else:
logger.error("bad handshake %s" % short_repr(protocol_version))
self.close()
def call_soon_threadsafe(self, func, *args):
try:
self.loop.call_soon_threadsafe(func, *args)
except RuntimeError:
if self.connected:
logger.exception("call_soon_threadsafe failed while connected")
def message_received(self, message):
try:
message_id, async, name, args = server_decode(message)
except Exception:
logger.exception("Can't deserialize message")
self.close()
if message_id == -1:
return # keep-alive
if name not in self.methods:
logger.error('Invalid method, %r', name)
self.close()
try:
result = getattr(self.zeo_storage, name)(*args)
except Exception as exc:
if not isinstance(exc, self.unlogged_exception_types):
logger.exception(
"Bad %srequest, %r", 'async ' if async else '', name)
if async:
return self.close() # No way to recover/cry for help
else:
return self.send_error(message_id, exc)
if not async:
self.send_reply(message_id, result)
def send_reply(self, message_id, result, send_error=False):
try:
result = self.encode(message_id, 0, '.reply', result)
except Exception:
if isinstance(result, Delay):
result.set_sender(message_id, self)
return
else:
logger.exception("Unpicklable response %r", result)
if not send_error:
self.send_error(
message_id,
ValueError("Couldn't pickle response"),
True)
self._write(result)
def send_reply_threadsafe(self, message_id, result):
self.loop.call_soon_threadsafe(self.reply, message_id, result)
def send_error(self, message_id, exc, send_error=False):
"""Abstracting here so we can make this cleaner in the future
"""
self.send_reply(message_id, (exc.__class__, exc), send_error)
def async(self, method, *args):
self.call_async(method, args)
best_protocol_version = os.environ.get(
'ZEO_SERVER_PROTOCOL',
ServerProtocol.protocols[-1].decode('utf-8')).encode('utf-8')
assert best_protocol_version in ServerProtocol.protocols
def new_connection(loop, addr, socket, zeo_storage):
protocol = ServerProtocol(loop, addr, zeo_storage)
cr = loop.create_connection((lambda : protocol), sock=socket)
asyncio.async(cr, loop=loop)
class Delay:
"""Used to delay response to client for synchronous calls.
When a synchronous call is made and the original handler returns
without handling the call, it returns a Delay object that prevents
the mainloop from sending a response.
"""
msgid = protocol = sent = None
def set_sender(self, msgid, protocol):
self.msgid = msgid
self.protocol = protocol
def reply(self, obj):
self.sent = 'reply'
self.protocol.send_reply(self.msgid, obj)
def error(self, exc_info):
self.sent = 'error'
log("Error raised in delayed method", logging.ERROR, exc_info=exc_info)
self.protocol.send_error(self.msgid, exc_info[1])
def __repr__(self):
return "%s[%s, %r, %r, %r]" % (
self.__class__.__name__, id(self),
self.msgid, self.protocol, self.sent)
def __reduce__(self):
raise TypeError("Can't pickle delays.")
class Result(Delay):
def __init__(self, *args):
self.args = args
def set_sender(self, msgid, protocol):
reply, callback = self.args
protocol.send_reply(msgid, reply)
callback()
class MTDelay(Delay):
def __init__(self):
self.ready = threading.Event()
def set_sender(self, *args):
Delay.set_sender(self, *args)
self.ready.set()
def reply(self, obj):
self.ready.wait()
self.protocol.call_soon_threadsafe(
self.protocol.send_reply, self.msgid, obj)
def error(self, exc_info):
self.ready.wait()
log("Error raised in delayed method", logging.ERROR, exc_info=exc_info)
self.protocol.call_soon_threadsafe(Delay.error, self, exc_info)
...@@ -30,13 +30,18 @@ class Loop: ...@@ -30,13 +30,18 @@ class Loop:
if not future.cancelled(): if not future.cancelled():
future.set_exception(ConnectionRefusedError()) future.set_exception(ConnectionRefusedError())
def create_connection(self, protocol_factory, host, port): def create_connection(
self, protocol_factory, host=None, port=None, sock=None
):
future = asyncio.Future(loop=self) future = asyncio.Future(loop=self)
if sock is None:
addr = host, port addr = host, port
if addr in self.addrs: if addr in self.addrs:
self._connect(future, protocol_factory) self._connect(future, protocol_factory)
else: else:
self.connecting[addr] = future, protocol_factory self.connecting[addr] = future, protocol_factory
else:
self._connect(future, protocol_factory)
return future return future
...@@ -61,6 +66,14 @@ class Loop: ...@@ -61,6 +66,14 @@ class Loop:
def call_exception_handler(self, context): def call_exception_handler(self, context):
self.exceptions.append(context) self.exceptions.append(context)
closed = False
def close(self):
self.closed = True
stopped = False
def stop(self):
self.stopped = True
class Handle: class Handle:
cancelled = False cancelled = False
......
This diff is collapsed.
##############################################################################
#
# Copyright (c) 2003 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
_auth_modules = {}
def get_module(name):
if name == 'sha':
from auth_sha import StorageClass, SHAClient, Database
return StorageClass, SHAClient, Database
elif name == 'digest':
from .auth_digest import StorageClass, DigestClient, DigestDatabase
return StorageClass, DigestClient, DigestDatabase
else:
return _auth_modules.get(name)
def register_module(name, storage_class, client, db):
if name in _auth_modules:
raise TypeError("%s is already registred" % name)
_auth_modules[name] = storage_class, client, db
##############################################################################
#
# Copyright (c) 2003 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
"""Digest authentication for ZEO
This authentication mechanism follows the design of HTTP digest
authentication (RFC 2069). It is a simple challenge-response protocol
that does not send passwords in the clear, but does not offer strong
security. The RFC discusses many of the limitations of this kind of
protocol.
Guard the password database as if it contained plaintext passwords.
It stores the hash of a username and password. This does not expose
the plaintext password, but it is sensitive nonetheless. An attacker
with the hash can impersonate the real user. This is a limitation of
the simple digest scheme.
HTTP is a stateless protocol, and ZEO is a stateful protocol. The
security requirements are quite different as a result. The HTTP
protocol uses a nonce as a challenge. The ZEO protocol requires a
separate session key that is used for message authentication. We
generate a second nonce for this purpose; the hash of nonce and
user/realm/password is used as the session key.
TODO: I'm not sure if this is a sound approach; SRP would be preferred.
"""
import os
import random
import struct
import time
from ZEO.auth.base import Database, Client
from ZEO.StorageServer import ZEOStorage
from ZEO.Exceptions import AuthError
from ZEO.hash import sha1
def get_random_bytes(n=8):
try:
b = os.urandom(n)
except NotImplementedError:
L = [chr(random.randint(0, 255)) for i in range(n)]
b = b"".join(L)
return b
def hexdigest(s):
return sha1(s.encode()).hexdigest()
class DigestDatabase(Database):
def __init__(self, filename, realm=None):
Database.__init__(self, filename, realm)
# Initialize a key used to build the nonce for a challenge.
# We need one key for the lifetime of the server, so it
# is convenient to store in on the database.
self.noncekey = get_random_bytes(8)
def _store_password(self, username, password):
dig = hexdigest("%s:%s:%s" % (username, self.realm, password))
self._users[username] = dig
def session_key(h_up, nonce):
# The hash itself is a bit too short to be a session key.
# HMAC wants a 64-byte key. We don't want to use h_up
# directly because it would never change over time. Instead
# use the hash plus part of h_up.
return (sha1(("%s:%s" % (h_up, nonce)).encode('latin-1')).digest() +
h_up.encode('utf-8')[:44])
class StorageClass(ZEOStorage):
def set_database(self, database):
assert isinstance(database, DigestDatabase)
self.database = database
self.noncekey = database.noncekey
def _get_time(self):
# Return a string representing the current time.
t = int(time.time())
return struct.pack("i", t)
def _get_nonce(self):
# RFC 2069 recommends a nonce of the form
# H(client-IP ":" time-stamp ":" private-key)
dig = sha1()
dig.update(str(self.connection.addr).encode('latin-1'))
dig.update(self._get_time())
dig.update(self.noncekey)
return dig.hexdigest()
def auth_get_challenge(self):
"""Return realm, challenge, and nonce."""
self._challenge = self._get_nonce()
self._key_nonce = self._get_nonce()
return self.auth_realm, self._challenge, self._key_nonce
def auth_response(self, resp):
# verify client response
user, challenge, response = resp
# Since zrpc is a stateful protocol, we just store the nonce
# we sent to the client. It will need to generate a new
# nonce for a new connection anyway.
if self._challenge != challenge:
raise ValueError("invalid challenge")
# lookup user in database
h_up = self.database.get_password(user)
# regeneration resp from user, password, and nonce
check = hexdigest("%s:%s" % (h_up, challenge))
if check == response:
self.connection.setSessionKey(session_key(h_up, self._key_nonce))
return self._finish_auth(check == response)
extensions = [auth_get_challenge, auth_response]
class DigestClient(Client):
extensions = ["auth_get_challenge", "auth_response"]
def start(self, username, realm, password):
_realm, challenge, nonce = self.stub.auth_get_challenge()
if _realm != realm:
raise AuthError("expected realm %r, got realm %r"
% (_realm, realm))
h_up = hexdigest("%s:%s:%s" % (username, realm, password))
resp_dig = hexdigest("%s:%s" % (h_up, challenge))
result = self.stub.auth_response((username, challenge, resp_dig))
if result:
return session_key(h_up, nonce)
else:
return None
##############################################################################
#
# Copyright (c) 2003 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
"""Base classes for defining an authentication protocol.
Database -- abstract base class for password database
Client -- abstract base class for authentication client
"""
from __future__ import print_function
from __future__ import print_function
import os
from ZEO.hash import sha1
class Client:
# Subclass should override to list the names of methods that
# will be called on the server.
extensions = []
def __init__(self, stub):
self.stub = stub
for m in self.extensions:
setattr(self.stub, m, self.stub.extensionMethod(m))
def sort(L):
"""Sort a list in-place and return it."""
L.sort()
return L
class Database:
"""Abstracts a password database.
This class is used both in the authentication process (via
get_password()) and by client scripts that manage the password
database file.
The password file is a simple, colon-separated text file mapping
usernames to password hashes. The hashes are SHA hex digests
produced from the password string.
"""
realm = None
def __init__(self, filename, realm=None):
"""Creates a new Database
filename: a string containing the full pathname of
the password database file. Must be readable by the user
running ZEO. Must be writeable by any client script that
accesses the database.
realm: the realm name (a string)
"""
self._users = {}
self.filename = filename
self.load()
if realm:
if self.realm and self.realm != realm:
raise ValueError("Specified realm %r differs from database "
"realm %r" % (realm or '', self.realm))
else:
self.realm = realm
def save(self, fd=None):
filename = self.filename
needs_closed = False
if not fd:
fd = open(filename, 'w')
needs_closed = True
try:
if self.realm:
print("realm", self.realm, file=fd)
for username in sorted(self._users.keys()):
print("%s: %s" % (username, self._users[username]), file=fd)
finally:
if needs_closed:
fd.close()
def load(self):
filename = self.filename
if not filename:
return
if not os.path.exists(filename):
return
with open(filename) as fd:
L = fd.readlines()
if not L:
return
if L[0].startswith("realm "):
line = L.pop(0).strip()
self.realm = line[len("realm "):]
for line in L:
username, hash = line.strip().split(":", 1)
self._users[username] = hash.strip()
def _store_password(self, username, password):
self._users[username] = self.hash(password)
def get_password(self, username):
"""Returns password hash for specified username.
Callers must check for LookupError, which is raised in
the case of a non-existent user specified."""
if username not in self._users:
raise LookupError("No such user: %s" % username)
return self._users[username]
def hash(self, s):
return sha1(s.encode()).hexdigest()
def add_user(self, username, password):
if username in self._users:
raise LookupError("User %s already exists" % username)
self._store_password(username, password)
def del_user(self, username):
if username not in self._users:
raise LookupError("No such user: %s" % username)
del self._users[username]
def change_password(self, username, password):
if username not in self._users:
raise LookupError("No such user: %s" % username)
self._store_password(username, password)
"""HMAC (Keyed-Hashing for Message Authentication) Python module.
Implements the HMAC algorithm as described by RFC 2104.
"""
from six.moves import map
from six.moves import zip
def _strxor(s1, s2):
"""Utility method. XOR the two strings s1 and s2 (must have same length).
"""
return "".join(map(lambda x, y: chr(ord(x) ^ ord(y)), s1, s2))
# The size of the digests returned by HMAC depends on the underlying
# hashing module used.
digest_size = None
class HMAC:
"""RFC2104 HMAC class.
This supports the API for Cryptographic Hash Functions (PEP 247).
"""
def __init__(self, key, msg = None, digestmod = None):
"""Create a new HMAC object.
key: key for the keyed hash object.
msg: Initial input for the hash, if provided.
digestmod: A module supporting PEP 247. Defaults to the md5 module.
"""
if digestmod is None:
import md5
digestmod = md5
self.digestmod = digestmod
self.outer = digestmod.new()
self.inner = digestmod.new()
self.digest_size = digestmod.digest_size
blocksize = 64
ipad = "\x36" * blocksize
opad = "\x5C" * blocksize
if len(key) > blocksize:
key = digestmod.new(key).digest()
key = key + chr(0) * (blocksize - len(key))
self.outer.update(_strxor(key, opad))
self.inner.update(_strxor(key, ipad))
if msg is not None:
self.update(msg)
## def clear(self):
## raise NotImplementedError("clear() method not available in HMAC.")
def update(self, msg):
"""Update this hashing object with the string msg.
"""
self.inner.update(msg)
def copy(self):
"""Return a separate copy of this hashing object.
An update to this copy won't affect the original object.
"""
other = HMAC("")
other.digestmod = self.digestmod
other.inner = self.inner.copy()
other.outer = self.outer.copy()
return other
def digest(self):
"""Return the hash value of this hashing object.
This returns a string containing 8-bit data. The object is
not altered in any way by this function; you can continue
updating the object after calling this function.
"""
h = self.outer.copy()
h.update(self.inner.digest())
return h.digest()
def hexdigest(self):
"""Like digest(), but returns a string of hexadecimal digits instead.
"""
return "".join([hex(ord(x))[2:].zfill(2)
for x in tuple(self.digest())])
def new(key, msg = None, digestmod = None):
"""Create a new hashing object and return it.
key: The starting key for the hash.
msg: if available, will immediately be hashed into the object's starting
state.
You can now feed arbitrary strings into the object using its update()
method, and can ask for the hash value at any time by calling its digest()
method.
"""
return HMAC(key, msg, digestmod)
...@@ -55,22 +55,6 @@ ...@@ -55,22 +55,6 @@
</description> </description>
</key> </key>
<key name="monitor-address" datatype="socket-binding-address"
required="no">
<description>
The address at which the monitor server should listen. If
specified, a monitor server is started. The monitor server
provides server statistics in a simple text format. This can
be in the form 'host:port' to signify a TCP/IP connection or a
pathname string to signify a Unix domain socket connection (at
least one '/' is required). A hostname may be a DNS name or a
dotted IP address. If the hostname is omitted, the platform's
default behavior is used when binding the listening socket (''
is passed to socket.bind() as the hostname portion of the
address).
</description>
</key>
<key name="transaction-timeout" datatype="integer" <key name="transaction-timeout" datatype="integer"
required="no"> required="no">
<description> <description>
...@@ -81,28 +65,6 @@ ...@@ -81,28 +65,6 @@
</description> </description>
</key> </key>
<key name="authentication-protocol" required="no">
<description>
The name of the protocol used for authentication. The
only protocol provided with ZEO is "digest," but extensions
may provide other protocols.
</description>
</key>
<key name="authentication-database" required="no">
<description>
The path of the database containing authentication credentials.
</description>
</key>
<key name="authentication-realm" required="no">
<description>
The authentication realm of the server. Some authentication
schemes use a realm to identify the logical set of usernames
that are accepted by this server.
</description>
</key>
<key name="pid-filename" datatype="existing-dirpath" <key name="pid-filename" datatype="existing-dirpath"
required="no"> required="no">
<description> <description>
......
...@@ -24,7 +24,8 @@ class StaleCache(object): ...@@ -24,7 +24,8 @@ class StaleCache(object):
class IClientCache(zope.interface.Interface): class IClientCache(zope.interface.Interface):
"""Client cache interface. """Client cache interface.
Note that caches need not be thread safe. Note that caches need not be thread safe, fpr the most part,
except for getLastTid, which may be called from multiple threads.
""" """
def close(): def close():
...@@ -73,6 +74,9 @@ class IClientCache(zope.interface.Interface): ...@@ -73,6 +74,9 @@ class IClientCache(zope.interface.Interface):
"""Get the last tid seen by the cache """Get the last tid seen by the cache
This is the cached last tid we've seen from the server. This is the cached last tid we've seen from the server.
This method may be called from multiple threads. (It's assumed
to be trivial.)
""" """
def setLastTid(tid): def setLastTid(tid):
......
...@@ -59,7 +59,6 @@ class StorageStats: ...@@ -59,7 +59,6 @@ class StorageStats:
self.commits = 0 self.commits = 0
self.aborts = 0 self.aborts = 0
self.active_txns = 0 self.active_txns = 0
self.verifying_clients = 0
self.lock_time = None self.lock_time = None
self.conflicts = 0 self.conflicts = 0
self.conflicts_resolved = 0 self.conflicts_resolved = 0
...@@ -114,79 +113,3 @@ class StorageStats: ...@@ -114,79 +113,3 @@ class StorageStats:
print("Stores:", self.stores, file=f) print("Stores:", self.stores, file=f)
print("Conflicts:", self.conflicts, file=f) print("Conflicts:", self.conflicts, file=f)
print("Conflicts resolved:", self.conflicts_resolved, file=f) print("Conflicts resolved:", self.conflicts_resolved, file=f)
class StatsClient(asyncore.dispatcher):
def __init__(self, sock, addr):
asyncore.dispatcher.__init__(self, sock)
self.buf = []
self.closed = 0
def close(self):
self.closed = 1
# The socket is closed after all the data is written.
# See handle_write().
def write(self, s):
self.buf.append(s)
def writable(self):
return len(self.buf)
def readable(self):
return 0
def handle_write(self):
s = "".join(self.buf)
self.buf = []
n = self.socket.send(s.encode('ascii'))
if n < len(s):
self.buf.append(s[:n])
if self.closed and not self.buf:
asyncore.dispatcher.close(self)
class StatsServer(asyncore.dispatcher):
StatsConnectionClass = StatsClient
def __init__(self, addr, stats):
asyncore.dispatcher.__init__(self)
self.addr = addr
self.stats = stats
if type(self.addr) == tuple:
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
else:
self.create_socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.set_reuse_addr()
logger = logging.getLogger('ZEO.monitor')
logger.info("listening on %s", repr(self.addr))
self.bind(self.addr)
self.listen(5)
def writable(self):
return 0
def readable(self):
return 1
def handle_accept(self):
try:
sock, addr = self.accept()
except socket.error:
return
f = self.StatsConnectionClass(sock, addr)
self.dump(f)
f.close()
def dump(self, f):
print("ZEO monitor server version %s" % zeo_version, file=f)
print(time.ctime(), file=f)
print(file=f)
L = sorted(self.stats.keys())
for k in L:
stats = self.stats[k]
print("Storage:", k, file=f)
stats.dump(f)
print(file=f)
This diff is collapsed.
...@@ -22,7 +22,6 @@ Options: ...@@ -22,7 +22,6 @@ Options:
-f/--filename FILENAME -- filename for FileStorage -f/--filename FILENAME -- filename for FileStorage
-t/--timeout TIMEOUT -- transaction timeout in seconds (default no timeout) -t/--timeout TIMEOUT -- transaction timeout in seconds (default no timeout)
-h/--help -- print this usage message and exit -h/--help -- print this usage message and exit
-m/--monitor ADDRESS -- address of monitor server ([HOST:]PORT or PATH)
--pid-file PATH -- relative path to output file containing this process's pid; --pid-file PATH -- relative path to output file containing this process's pid;
default $(INSTANCE_HOME)/var/ZEO.pid but only if envar default $(INSTANCE_HOME)/var/ZEO.pid but only if envar
INSTANCE_HOME is defined INSTANCE_HOME is defined
...@@ -72,9 +71,6 @@ class ZEOOptionsMixin: ...@@ -72,9 +71,6 @@ class ZEOOptionsMixin:
def handle_address(self, arg): def handle_address(self, arg):
self.family, self.address = parse_binding_address(arg) self.family, self.address = parse_binding_address(arg)
def handle_monitor_address(self, arg):
self.monitor_family, self.monitor_address = parse_binding_address(arg)
def handle_filename(self, arg): def handle_filename(self, arg):
from ZODB.config import FileStorage # That's a FileStorage *opener*! from ZODB.config import FileStorage # That's a FileStorage *opener*!
class FSConfig: class FSConfig:
...@@ -107,14 +103,6 @@ class ZEOOptionsMixin: ...@@ -107,14 +103,6 @@ class ZEOOptionsMixin:
self.add("invalidation_age", "zeo.invalidation_age") self.add("invalidation_age", "zeo.invalidation_age")
self.add("transaction_timeout", "zeo.transaction_timeout", self.add("transaction_timeout", "zeo.transaction_timeout",
"t:", "timeout=", float) "t:", "timeout=", float)
self.add("monitor_address", "zeo.monitor_address.address",
"m:", "monitor=", self.handle_monitor_address)
self.add('auth_protocol', 'zeo.authentication_protocol',
None, 'auth-protocol=', default=None)
self.add('auth_database', 'zeo.authentication_database',
None, 'auth-database=')
self.add('auth_realm', 'zeo.authentication_realm',
None, 'auth-realm=')
self.add('pid_file', 'zeo.pid_filename', self.add('pid_file', 'zeo.pid_filename',
None, 'pid-file=') None, 'pid-file=')
...@@ -184,6 +172,7 @@ class ZEOServer: ...@@ -184,6 +172,7 @@ class ZEOServer:
self.options.address[1] is None): self.options.address[1] is None):
self.options.address = self.options.address[0], 0 self.options.address = self.options.address[0], 0
return return
if self.can_connect(self.options.family, self.options.address): if self.can_connect(self.options.family, self.options.address):
self.options.usage("address %s already in use" % self.options.usage("address %s already in use" %
repr(self.options.address)) repr(self.options.address))
...@@ -352,10 +341,6 @@ def create_server(storages, options): ...@@ -352,10 +341,6 @@ def create_server(storages, options):
invalidation_queue_size = options.invalidation_queue_size, invalidation_queue_size = options.invalidation_queue_size,
invalidation_age = options.invalidation_age, invalidation_age = options.invalidation_age,
transaction_timeout = options.transaction_timeout, transaction_timeout = options.transaction_timeout,
monitor_address = options.monitor_address,
auth_protocol = options.auth_protocol,
auth_database = options.auth_database,
auth_realm = options.auth_realm,
) )
...@@ -393,5 +378,11 @@ def main(args=None): ...@@ -393,5 +378,11 @@ def main(args=None):
s = ZEOServer(options) s = ZEOServer(options)
s.main() s.main()
def run(args):
options = ZEOOptions()
options.realize(args)
s = ZEOServer(options)
s.run()
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -500,7 +500,8 @@ def days(f): ...@@ -500,7 +500,8 @@ def days(f):
minute(f, 10, detail=0) minute(f, 10, detail=0)
new_connection_idre = re.compile(r"new connection \('(\d+.\d+.\d+.\d+)', (\d+)\):") new_connection_idre = re.compile(
r"new connection \('(\d+.\d+.\d+.\d+)', (\d+)\):")
def verify(f): def verify(f):
f, = f f, = f
......
...@@ -11,27 +11,6 @@ ...@@ -11,27 +11,6 @@
# FOR A PARTICULAR PURPOSE # FOR A PARTICULAR PURPOSE
# #
############################################################################## ##############################################################################
import os
import threading
import logging
from ZODB.loglevels import BLATHER
LOG_THREAD_ID = 0 # Set this to 1 during heavy debugging
logger = logging.getLogger('ZEO.zrpc')
_label = "%s" % os.getpid()
def new_label():
global _label
_label = str(os.getpid())
def log(message, level=BLATHER, label=None, exc_info=False):
label = label or _label
if LOG_THREAD_ID:
label = label + ':' + threading.currentThread().getName()
logger.log(level, '(%s) %s' % (label, message), exc_info=exc_info)
REPR_LIMIT = 60 REPR_LIMIT = 60
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# #
############################################################################## ##############################################################################
import concurrent.futures import concurrent.futures
import contextlib
import os import os
import time import time
import socket import socket
...@@ -21,7 +22,7 @@ import logging ...@@ -21,7 +22,7 @@ import logging
from ZEO.ClientStorage import ClientStorage from ZEO.ClientStorage import ClientStorage
from ZEO.Exceptions import ClientDisconnected from ZEO.Exceptions import ClientDisconnected
from ZEO.zrpc.marshal import encode from ZEO.asyncio.marshal import encode
from ZEO.tests import forker from ZEO.tests import forker
from ZODB.DB import DB from ZODB.DB import DB
...@@ -79,40 +80,22 @@ class CommonSetupTearDown(StorageTestBase): ...@@ -79,40 +80,22 @@ class CommonSetupTearDown(StorageTestBase):
logging.info("setUp() %s", self.id()) logging.info("setUp() %s", self.id())
self.file = 'storage_conf' self.file = 'storage_conf'
self.addr = [] self.addr = []
self._pids = []
self._servers = [] self._servers = []
self.conf_paths = []
self.caches = [] self.caches = []
self._newAddr() self._newAddr()
self.startServer() self.startServer()
# self._old_log_level = logging.getLogger().getEffectiveLevel()
# logging.getLogger().setLevel(logging.WARNING)
# self._log_handler = logging.StreamHandler()
# logging.getLogger().addHandler(self._log_handler)
def tearDown(self): def tearDown(self):
"""Try to cause the tests to halt""" """Try to cause the tests to halt"""
# logging.getLogger().setLevel(self._old_log_level)
# logging.getLogger().removeHandler(self._log_handler)
# logging.info("tearDown() %s" % self.id())
for p in self.conf_paths:
os.remove(p)
if getattr(self, '_storage', None) is not None: if getattr(self, '_storage', None) is not None:
self._storage.close() self._storage.close()
if hasattr(self._storage, 'cleanup'): if hasattr(self._storage, 'cleanup'):
logging.debug("cleanup storage %s" % logging.debug("cleanup storage %s" %
self._storage.__name__) self._storage.__name__)
self._storage.cleanup() self._storage.cleanup()
for adminaddr in self._servers: for stop in self._servers:
if adminaddr is not None: stop()
forker.shutdown_zeo_server(adminaddr)
for pid in self._pids:
try:
os.waitpid(pid, 0)
except OSError:
pass # The subprocess module may already have waited
for c in self.caches: for c in self.caches:
for i in 0, 1: for i in 0, 1:
...@@ -183,7 +166,7 @@ class CommonSetupTearDown(StorageTestBase): ...@@ -183,7 +166,7 @@ class CommonSetupTearDown(StorageTestBase):
return zconf return zconf
def startServer(self, create=1, index=0, read_only=0, ro_svr=0, keep=None, def startServer(self, create=1, index=0, read_only=0, ro_svr=0, keep=None,
path=None): path=None, **kw):
addr = self.addr[index] addr = self.addr[index]
logging.info("startServer(create=%d, index=%d, read_only=%d) @ %s" % logging.info("startServer(create=%d, index=%d, read_only=%d) @ %s" %
(create, index, read_only, addr)) (create, index, read_only, addr))
...@@ -193,19 +176,17 @@ class CommonSetupTearDown(StorageTestBase): ...@@ -193,19 +176,17 @@ class CommonSetupTearDown(StorageTestBase):
zconf = self.getServerConfig(addr, ro_svr) zconf = self.getServerConfig(addr, ro_svr)
if keep is None: if keep is None:
keep = self.keep keep = self.keep
zeoport, adminaddr, pid, path = forker.start_zeo_server( zeoport, stop = forker.start_zeo_server(
sconf, zconf, addr[1], keep) sconf, zconf, addr[1], keep, **kw)
self.conf_paths.append(path) self._servers.append(stop)
self._pids.append(pid)
self._servers.append(adminaddr)
def shutdownServer(self, index=0): def shutdownServer(self, index=0):
logging.info("shutdownServer(index=%d) @ %s" % logging.info("shutdownServer(index=%d) @ %s" %
(index, self._servers[index])) (index, self._servers[index]))
adminaddr = self._servers[index] stop = self._servers[index]
if adminaddr is not None: if stop is not None:
forker.shutdown_zeo_server(adminaddr) stop()
self._servers[index] = None self._servers[index] = lambda : None
def pollUp(self, timeout=30.0, storage=None): def pollUp(self, timeout=30.0, storage=None):
if storage is None: if storage is None:
...@@ -310,6 +291,7 @@ class ConnectionTests(CommonSetupTearDown): ...@@ -310,6 +291,7 @@ class ConnectionTests(CommonSetupTearDown):
# object is not in the cache. # object is not in the cache.
self.shutdownServer() self.shutdownServer()
self._storage = self.openClientStorage('test', 1000, wait=0) self._storage = self.openClientStorage('test', 1000, wait=0)
with short_timeout(self):
self.assertRaises(ClientDisconnected, self.assertRaises(ClientDisconnected,
self._storage.load, b'fredwash', '') self._storage.load, b'fredwash', '')
self._storage.close() self._storage.close()
...@@ -377,6 +359,7 @@ class ConnectionTests(CommonSetupTearDown): ...@@ -377,6 +359,7 @@ class ConnectionTests(CommonSetupTearDown):
self.assertEqual(expected2, self._storage.load(oid2, '')) self.assertEqual(expected2, self._storage.load(oid2, ''))
# But oid1 should have been purged, so that trying to load it will # But oid1 should have been purged, so that trying to load it will
# try to fetch it from the (non-existent) ZEO server. # try to fetch it from the (non-existent) ZEO server.
with short_timeout(self):
self.assertRaises(ClientDisconnected, self._storage.load, oid1, '') self.assertRaises(ClientDisconnected, self._storage.load, oid1, '')
self._storage.close() self._storage.close()
...@@ -569,13 +552,17 @@ class ConnectionTests(CommonSetupTearDown): ...@@ -569,13 +552,17 @@ class ConnectionTests(CommonSetupTearDown):
self._storage = self.openClientStorage() self._storage = self.openClientStorage()
self._dostore() self._dostore()
self.shutdownServer() self.shutdownServer()
self.assertRaises(ClientDisconnected, self._storage.load, b'\0'*8, '') with short_timeout(self):
self.assertRaises(ClientDisconnected,
self._storage.load, b'\0'*8, '')
self.startServer() self.startServer()
# No matter how long we wait, the client won't reconnect: # No matter how long we wait, the client won't reconnect:
time.sleep(2) time.sleep(2)
self.assertRaises(ClientDisconnected, self._storage.load, b'\0'*8, '') with short_timeout(self):
self.assertRaises(ClientDisconnected,
self._storage.load, b'\0'*8, '')
class InvqTests(CommonSetupTearDown): class InvqTests(CommonSetupTearDown):
invq = 3 invq = 3
...@@ -701,6 +688,7 @@ class ReconnectionTests(CommonSetupTearDown): ...@@ -701,6 +688,7 @@ class ReconnectionTests(CommonSetupTearDown):
# Poll until the client disconnects # Poll until the client disconnects
self.pollDown() self.pollDown()
# Stores should fail now # Stores should fail now
with short_timeout(self):
self.assertRaises(ClientDisconnected, self._dostore) self.assertRaises(ClientDisconnected, self._dostore)
# Restart the server # Restart the server
...@@ -750,6 +738,7 @@ class ReconnectionTests(CommonSetupTearDown): ...@@ -750,6 +738,7 @@ class ReconnectionTests(CommonSetupTearDown):
# Poll until the client disconnects # Poll until the client disconnects
self.pollDown() self.pollDown()
# Stores should fail now # Stores should fail now
with short_timeout(self):
self.assertRaises(ClientDisconnected, self._dostore) self.assertRaises(ClientDisconnected, self._dostore)
# Restart the server # Restart the server
...@@ -780,8 +769,8 @@ class ReconnectionTests(CommonSetupTearDown): ...@@ -780,8 +769,8 @@ class ReconnectionTests(CommonSetupTearDown):
self.pollDown() self.pollDown()
# Accesses should fail now # Accesses should fail now
self.assertRaises(ClientDisconnected, self._storage.history, ZERO, with short_timeout(self):
timeout=1) self.assertRaises(ClientDisconnected, self._storage.history, ZERO)
# Restart the server, this time read-write # Restart the server, this time read-write
self.startServer(create=0, keep=0) self.startServer(create=0, keep=0)
...@@ -881,6 +870,7 @@ class ReconnectionTests(CommonSetupTearDown): ...@@ -881,6 +870,7 @@ class ReconnectionTests(CommonSetupTearDown):
data = zodb_pickle(MinPO(oid)) data = zodb_pickle(MinPO(oid))
self._storage.store(oid, None, data, '', txn) self._storage.store(oid, None, data, '', txn)
self.shutdownServer() self.shutdownServer()
with short_timeout(self):
self.assertRaises(ClientDisconnected, self._storage.tpc_vote, txn) self.assertRaises(ClientDisconnected, self._storage.tpc_vote, txn)
self.startServer(create=0) self.startServer(create=0)
self._storage.tpc_abort(txn) self._storage.tpc_abort(txn)
...@@ -967,11 +957,12 @@ class TimeoutTests(CommonSetupTearDown): ...@@ -967,11 +957,12 @@ class TimeoutTests(CommonSetupTearDown):
timeout = 1 timeout = 1
def checkTimeout(self): def checkTimeout(self):
storage = self.openClientStorage() self._storage = storage = self.openClientStorage()
txn = Transaction() txn = Transaction()
storage.tpc_begin(txn) storage.tpc_begin(txn)
storage.tpc_vote(txn) storage.tpc_vote(txn)
time.sleep(2) time.sleep(2)
with short_timeout(self):
self.assertRaises(ClientDisconnected, storage.tpc_finish, txn) self.assertRaises(ClientDisconnected, storage.tpc_finish, txn)
# Make sure it's logged as CRITICAL # Make sure it's logged as CRITICAL
...@@ -1188,6 +1179,14 @@ class MSTThread(threading.Thread): ...@@ -1188,6 +1179,14 @@ class MSTThread(threading.Thread):
except: except:
pass pass
@contextlib.contextmanager
def short_timeout(self):
old = self._storage._server.timeout
self._storage._server.timeout = 1
yield
self._storage._server.timeout = old
# Run IPv6 tests if V6 sockets are supported # Run IPv6 tests if V6 sockets are supported
try: try:
socket.socket(socket.AF_INET6, socket.SOCK_STREAM) socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
......
...@@ -20,7 +20,7 @@ verification is no longer supported. ...@@ -20,7 +20,7 @@ verification is no longer supported.
Here's an example that shows that this is actually what happens. Here's an example that shows that this is actually what happens.
Start a server, create a cient to it and commit some data Start a server, create a client to it and commit some data
>>> addr, admin = start_server(keep=1) >>> addr, admin = start_server(keep=1)
>>> import ZEO, transaction >>> import ZEO, transaction
...@@ -57,6 +57,7 @@ logging and event data: ...@@ -57,6 +57,7 @@ logging and event data:
... 'ZEO', level=logging.ERROR) ... 'ZEO', level=logging.ERROR)
>>> events = [] >>> events = []
>>> def event_handler(e): >>> def event_handler(e):
... if hasattr(e, 'storage'):
... events.append(( ... events.append((
... len(e.storage._server.client.cache), str(handler), e.__class__.__name__)) ... len(e.storage._server.client.cache), str(handler), e.__class__.__name__))
...@@ -70,7 +71,8 @@ is generated before the cache is dropped or the message is logged. ...@@ -70,7 +71,8 @@ is generated before the cache is dropped or the message is logged.
Now, we'll restart the server on the original address: Now, we'll restart the server on the original address:
>>> _, admin = start_server(zeo_conf=dict(invalidation_queue_size=1), >>> _, admin = start_server(zeo_conf=dict(invalidation_queue_size=1),
... addr=addr, keep=1) ... addr=addr, keep=1, threaded=True)
>>> wait_connected(db.storage) >>> wait_connected(db.storage)
Now, let's verify our assertions above: Now, let's verify our assertions above:
......
This diff is collapsed.
...@@ -5,7 +5,7 @@ A full test of all protocols isn't practical. But we'll do a limited ...@@ -5,7 +5,7 @@ A full test of all protocols isn't practical. But we'll do a limited
test that at least the current and previous protocols are supported in test that at least the current and previous protocols are supported in
both directions. both directions.
Let's start a Z308 server Let's start a Z309 server
>>> storage_conf = ''' >>> storage_conf = '''
... <blobstorage> ... <blobstorage>
...@@ -16,8 +16,8 @@ Let's start a Z308 server ...@@ -16,8 +16,8 @@ Let's start a Z308 server
... </blobstorage> ... </blobstorage>
... ''' ... '''
>>> addr, admin = start_server( >>> addr, stop = start_server(
... storage_conf, dict(invalidation_queue_size=5), protocol=b'Z309') ... storage_conf, dict(invalidation_queue_size=5), protocol=b'Z4')
A current client should be able to connect to a old server: A current client should be able to connect to a old server:
...@@ -25,7 +25,7 @@ A current client should be able to connect to a old server: ...@@ -25,7 +25,7 @@ A current client should be able to connect to a old server:
>>> db = ZEO.DB(addr, client='client', blob_dir='blobs') >>> db = ZEO.DB(addr, client='client', blob_dir='blobs')
>>> wait_connected(db.storage) >>> wait_connected(db.storage)
>>> db.storage.protocol_version >>> db.storage.protocol_version
b'Z309' b'Z4'
>>> conn = db.open() >>> conn = db.open()
>>> conn.root().x = 0 >>> conn.root().x = 0
...@@ -87,7 +87,7 @@ A current client should be able to connect to a old server: ...@@ -87,7 +87,7 @@ A current client should be able to connect to a old server:
>>> db2.close() >>> db2.close()
>>> db.close() >>> db.close()
>>> stop_server(admin) >>> stop_server(stop)
>>> import os, zope.testing.setupstack >>> import os, zope.testing.setupstack
>>> os.remove('client-1.zec') >>> os.remove('client-1.zec')
...@@ -102,11 +102,11 @@ Note that we'll have to pull some hijinks: ...@@ -102,11 +102,11 @@ Note that we'll have to pull some hijinks:
>>> import ZEO.asyncio.client >>> import ZEO.asyncio.client
>>> old_protocols = ZEO.asyncio.client.Protocol.protocols >>> old_protocols = ZEO.asyncio.client.Protocol.protocols
>>> ZEO.asyncio.client.Protocol.protocols = [b'Z309'] >>> ZEO.asyncio.client.Protocol.protocols = [b'Z4']
>>> db = ZEO.DB(addr, client='client', blob_dir='blobs') >>> db = ZEO.DB(addr, client='client', blob_dir='blobs')
>>> db.storage.protocol_version >>> db.storage.protocol_version
b'Z309' b'Z4'
>>> wait_connected(db.storage) >>> wait_connected(db.storage)
>>> conn = db.open() >>> conn = db.open()
>>> conn.root().x = 0 >>> conn.root().x = 0
......
...@@ -30,9 +30,8 @@ from __future__ import print_function ...@@ -30,9 +30,8 @@ from __future__ import print_function
# Here, we'll try to provide some testing infrastructure to isolate # Here, we'll try to provide some testing infrastructure to isolate
# servers from the network. # servers from the network.
import ZEO.asyncio.tests
import ZEO.StorageServer import ZEO.StorageServer
import ZEO.zrpc.connection
import ZEO.zrpc.error
import ZODB.MappingStorage import ZODB.MappingStorage
class StorageServer(ZEO.StorageServer.StorageServer): class StorageServer(ZEO.StorageServer.StorageServer):
...@@ -42,44 +41,10 @@ class StorageServer(ZEO.StorageServer.StorageServer): ...@@ -42,44 +41,10 @@ class StorageServer(ZEO.StorageServer.StorageServer):
storages = {'1': ZODB.MappingStorage.MappingStorage()} storages = {'1': ZODB.MappingStorage.MappingStorage()}
ZEO.StorageServer.StorageServer.__init__(self, addr, storages, **kw) ZEO.StorageServer.StorageServer.__init__(self, addr, storages, **kw)
def client(server, name='client'):
class DispatcherClass:
__init__ = lambda *a, **kw: None
class socket:
getsockname = staticmethod(lambda : 'socket')
class Connection:
peer_protocol_version = ZEO.zrpc.connection.Connection.current_protocol
connected = True
def __init__(self, name='connection', addr=''):
name = str(name)
self.name = name
self.addr = addr or 'test-addr-'+name
def close(self):
print(self.name, 'closed')
self.connected = False
def poll(self):
if not self.connected:
raise ZEO.zrpc.error.DisconnectedError()
def callAsync(self, meth, *args):
print(self.name, 'callAsync', meth, repr(args))
callAsyncNoPoll = callAsync
def call_from_thread(self, *args):
if args:
args[0](*args[1:])
def send_reply(self, *args):
pass
def client(server, name='client', addr=''):
zs = ZEO.StorageServer.ZEOStorage(server) zs = ZEO.StorageServer.ZEOStorage(server)
zs.notifyConnected(Connection(name, addr)) protocol = ZEO.asyncio.tests.server_protocol(
zs, protocol_version=b'Z5', addr='test-addr-%s' % name)
zs.notify_connected(protocol)
zs.register('1', 0) zs.register('1', 0)
return zs return zs
...@@ -28,8 +28,6 @@ else: ...@@ -28,8 +28,6 @@ else:
import doctest import doctest
import unittest import unittest
import ZEO.tests.forker import ZEO.tests.forker
import ZEO.tests.testMonitor
import ZEO.zrpc.connection
import ZODB.tests.util import ZODB.tests.util
class FileStorageConfig: class FileStorageConfig:
...@@ -90,41 +88,6 @@ class MappingStorageTimeoutTests( ...@@ -90,41 +88,6 @@ class MappingStorageTimeoutTests(
): ):
pass pass
class MonitorTests(ZEO.tests.testMonitor.MonitorTests):
def check_connection_management(self):
# Open and close a few connections, making sure that
# the resulting number of clients is 0.
s1 = self.openClientStorage()
s2 = self.openClientStorage()
s3 = self.openClientStorage()
stats = self.parse(self.get_monitor_output())[1]
self.assertEqual(stats.clients, 3)
s1.close()
s3.close()
s2.close()
ZEO.tests.forker.wait_until(
"Number of clients shown in monitor drops to 0",
lambda :
self.parse(self.get_monitor_output())[1].clients == 0
)
def check_connection_management_with_old_client(self):
# Check that connection management works even when using an
# older protcool that requires a connection adapter.
test_protocol = b"Z303"
current_protocol = ZEO.zrpc.connection.Connection.current_protocol
ZEO.zrpc.connection.Connection.current_protocol = test_protocol
ZEO.zrpc.connection.Connection.servers_we_can_talk_to.append(
test_protocol)
try:
self.check_connection_management()
finally:
ZEO.zrpc.connection.Connection.current_protocol = current_protocol
ZEO.zrpc.connection.Connection.servers_we_can_talk_to.pop()
test_classes = [FileStorageConnectionTests, test_classes = [FileStorageConnectionTests,
FileStorageReconnectionTests, FileStorageReconnectionTests,
...@@ -132,7 +95,6 @@ test_classes = [FileStorageConnectionTests, ...@@ -132,7 +95,6 @@ test_classes = [FileStorageConnectionTests,
FileStorageTimeoutTests, FileStorageTimeoutTests,
MappingStorageConnectionTests, MappingStorageConnectionTests,
MappingStorageTimeoutTests, MappingStorageTimeoutTests,
MonitorTests,
] ]
def invalidations_while_connecting(): def invalidations_while_connecting():
......
...@@ -52,6 +52,10 @@ class FakeServer: ...@@ -52,6 +52,10 @@ class FakeServer:
def register_connection(*args): def register_connection(*args):
return None, None return None, None
class FakeConnection:
protocol_version = b'Z4'
addr = 'test'
def test_server_record_iternext(): def test_server_record_iternext():
""" """
...@@ -61,6 +65,7 @@ underlying storage. ...@@ -61,6 +65,7 @@ underlying storage.
>>> import ZEO.StorageServer >>> import ZEO.StorageServer
>>> zeo = ZEO.StorageServer.ZEOStorage(FakeServer(), False) >>> zeo = ZEO.StorageServer.ZEOStorage(FakeServer(), False)
>>> zeo.notify_connected(FakeConnection())
>>> zeo.register('1', False) >>> zeo.register('1', False)
>>> next = None >>> next = None
...@@ -80,6 +85,7 @@ The storage info also reflects the fact that record_iternext is supported. ...@@ -80,6 +85,7 @@ The storage info also reflects the fact that record_iternext is supported.
True True
>>> zeo = ZEO.StorageServer.ZEOStorage(FakeServer(), False) >>> zeo = ZEO.StorageServer.ZEOStorage(FakeServer(), False)
>>> zeo.notify_connected(FakeConnection())
>>> zeo.register('2', False) >>> zeo.register('2', False)
>>> zeo.get_info()['supports_record_iternext'] >>> zeo.get_info()['supports_record_iternext']
...@@ -129,41 +135,6 @@ Now we'll have our way with it's private _server attr: ...@@ -129,41 +135,6 @@ Now we'll have our way with it's private _server attr:
""" """
def history_to_version_compatible_storage():
"""
Some storages work under ZODB <= 3.8 and ZODB >= 3.9.
This means they have a history method that accepts a version parameter:
>>> class VersionCompatibleStorage(FakeStorageBase):
... def history(self,oid,version='',size=1):
... return oid,version,size
A ZEOStorage such as the following should support this type of storage:
>>> class OurFakeServer(FakeServer):
... storages = {'1':VersionCompatibleStorage()}
>>> import ZEO.StorageServer
>>> zeo = ZEO.StorageServer.ZEOStorage(OurFakeServer(), False)
>>> zeo.register('1', False)
The ZEOStorage should sort out the following call such that the storage gets
the correct parameters and so should return the parameters it was called with:
>>> zeo.history('oid',99)
('oid', '', 99)
The same problem occurs when a Z308 client connects to a Z309 server,
but different code is executed:
>>> from ZEO.StorageServer import ZEOStorage308Adapter
>>> zeo = ZEOStorage308Adapter(VersionCompatibleStorage())
The history method should still return the parameters it was called with:
>>> zeo.history('oid','',99)
('oid', '', 99)
"""
def test_suite(): def test_suite():
return doctest.DocTestSuite() return doctest.DocTestSuite()
......
##############################################################################
#
# Copyright (c) 2003 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
"""Test that the monitor produce sensible results.
$Id$
"""
import socket
import unittest
from ZEO.tests.ConnectionTests import CommonSetupTearDown
from ZEO.monitor import StorageStats
class MonitorTests(CommonSetupTearDown):
monitor = 1
def get_monitor_output(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect(('localhost', 42000))
L = []
while 1:
buf = s.recv(8192)
if buf:
L.append(buf)
else:
break
s.close()
return b"".join(L).decode('ascii')
def parse(self, s):
# Return a list of StorageStats, one for each storage.
lines = s.split("\n")
self.assert_(lines[0].startswith("ZEO monitor server"))
# lines[1] is a date
# Break up rest of lines into sections starting with Storage:
# and ending with a blank line.
sections = []
cur = None
for line in lines[2:]:
if line.startswith("Storage:"):
cur = [line]
elif line:
cur.append(line)
else:
if cur is not None:
sections.append(cur)
cur = None
assert cur is None # bug in the test code if this fails
d = {}
for sect in sections:
hdr = sect[0]
key, value = hdr.split(":")
storage = int(value)
s = d[storage] = StorageStats()
s.parse("\n".join(sect[1:]))
return d
def getConfig(self, path, create, read_only):
return """<mappingstorage 1/>"""
def testMonitor(self):
# Just open a client to know that the server is up and running
# TODO: should put this in setUp.
self.storage = self.openClientStorage()
s = self.get_monitor_output()
self.storage.close()
self.assert_(s.find("monitor") != -1)
d = self.parse(s)
stats = d[1]
self.assertEqual(stats.clients, 1)
self.assertEqual(stats.commits, 0)
def test_suite():
return unittest.makeSuite(MonitorTests)
...@@ -16,11 +16,10 @@ from __future__ import print_function ...@@ -16,11 +16,10 @@ from __future__ import print_function
import multiprocessing import multiprocessing
import re import re
from ZEO.ClientStorage import ClientStorage from ZEO.ClientStorage import ClientStorage, m64
from ZEO.tests.forker import get_port from ZEO.tests.forker import get_port
from ZEO.tests import forker, Cache, CommitLockTests, ThreadTests from ZEO.tests import forker, Cache, CommitLockTests, ThreadTests
from ZEO.tests import IterationTests from ZEO.tests import IterationTests
from ZEO.zrpc.error import DisconnectedError
from ZEO._compat import PY3 from ZEO._compat import PY3
from ZODB.tests import StorageTestBase, BasicStorage, \ from ZODB.tests import StorageTestBase, BasicStorage, \
TransactionalUndoStorage, \ TransactionalUndoStorage, \
...@@ -48,7 +47,6 @@ import transaction ...@@ -48,7 +47,6 @@ import transaction
import unittest import unittest
import ZEO.StorageServer import ZEO.StorageServer
import ZEO.tests.ConnectionTests import ZEO.tests.ConnectionTests
import ZEO.zrpc.connection
import ZODB import ZODB
import ZODB.blob import ZODB.blob
import ZODB.tests.hexstorage import ZODB.tests.hexstorage
...@@ -168,11 +166,9 @@ class GenericTests( ...@@ -168,11 +166,9 @@ class GenericTests(
logger.info("setUp() %s", self.id()) logger.info("setUp() %s", self.id())
port = get_port(self) port = get_port(self)
zconf = forker.ZEOConfig(('', port)) zconf = forker.ZEOConfig(('', port))
zport, adminaddr, pid, path = forker.start_zeo_server(self.getConfig(), zport, stop = forker.start_zeo_server(self.getConfig(),
zconf, port) zconf, port)
self._pids = [pid] self._servers = [stop]
self._servers = [adminaddr]
self._conf_path = path
if not self.blob_cache_dir: if not self.blob_cache_dir:
# This is the blob cache for ClientStorage # This is the blob cache for ClientStorage
self.blob_cache_dir = tempfile.mkdtemp( self.blob_cache_dir = tempfile.mkdtemp(
...@@ -190,12 +186,8 @@ class GenericTests( ...@@ -190,12 +186,8 @@ class GenericTests(
def tearDown(self): def tearDown(self):
self._storage.close() self._storage.close()
for server in self._servers: for stop in self._servers:
forker.shutdown_zeo_server(server) stop()
if hasattr(os, 'waitpid'):
# Not in Windows Python until 2.3
for pid in self._pids:
os.waitpid(pid, 0)
StorageTestBase.StorageTestBase.tearDown(self) StorageTestBase.StorageTestBase.tearDown(self)
def runTest(self): def runTest(self):
...@@ -278,10 +270,9 @@ class FileStorageRecoveryTests(StorageTestBase.StorageTestBase, ...@@ -278,10 +270,9 @@ class FileStorageRecoveryTests(StorageTestBase.StorageTestBase,
def _new_storage(self): def _new_storage(self):
port = get_port(self) port = get_port(self)
zconf = forker.ZEOConfig(('', port)) zconf = forker.ZEOConfig(('', port))
zport, adminaddr, pid, path = forker.start_zeo_server(self.getConfig(), zport, stop = forker.start_zeo_server(self.getConfig(),
zconf, port) zconf, port)
self._pids.append(pid) self._servers.append(stop)
self._servers.append(adminaddr)
blob_cache_dir = tempfile.mkdtemp(dir='.') blob_cache_dir = tempfile.mkdtemp(dir='.')
...@@ -294,7 +285,6 @@ class FileStorageRecoveryTests(StorageTestBase.StorageTestBase, ...@@ -294,7 +285,6 @@ class FileStorageRecoveryTests(StorageTestBase.StorageTestBase,
def setUp(self): def setUp(self):
StorageTestBase.StorageTestBase.setUp(self) StorageTestBase.StorageTestBase.setUp(self)
self._pids = []
self._servers = [] self._servers = []
self._storage = self._new_storage() self._storage = self._new_storage()
...@@ -304,12 +294,8 @@ class FileStorageRecoveryTests(StorageTestBase.StorageTestBase, ...@@ -304,12 +294,8 @@ class FileStorageRecoveryTests(StorageTestBase.StorageTestBase,
self._storage.close() self._storage.close()
self._dst.close() self._dst.close()
for server in self._servers: for stop in self._servers:
forker.shutdown_zeo_server(server) stop()
if hasattr(os, 'waitpid'):
# Not in Windows Python until 2.3
for pid in self._pids:
os.waitpid(pid, 0)
StorageTestBase.StorageTestBase.tearDown(self) StorageTestBase.StorageTestBase.tearDown(self)
def new_dest(self): def new_dest(self):
...@@ -708,27 +694,23 @@ class BlobWritableCacheTests(FullGenericTests, CommonBlobTests): ...@@ -708,27 +694,23 @@ class BlobWritableCacheTests(FullGenericTests, CommonBlobTests):
class FauxConn: class FauxConn:
addr = 'x' addr = 'x'
peer_protocol_version = ZEO.zrpc.connection.Connection.current_protocol protocol_version = ZEO.asyncio.server.best_protocol_version
peer_protocol_version = protocol_version
class StorageServerClientWrapper: serials = []
def async(self, method, *args):
if method == 'serialnos':
self.serials.extend(args[0])
def __init__(self): call_soon_threadsafe = async
self.serials = []
def serialnos(self, serials):
self.serials.extend(serials)
def info(self, info):
pass
class StorageServerWrapper: class StorageServerWrapper:
def __init__(self, server, storage_id): def __init__(self, server, storage_id):
self.storage_id = storage_id self.storage_id = storage_id
self.server = ZEO.StorageServer.ZEOStorage(server, server.read_only) self.server = ZEO.StorageServer.ZEOStorage(server, server.read_only)
self.server.notifyConnected(FauxConn()) self.server.notify_connected(FauxConn())
self.server.register(storage_id, False) self.server.register(storage_id, False)
self.server.client = StorageServerClientWrapper()
def sortKey(self): def sortKey(self):
return self.storage_id return self.storage_id
...@@ -751,8 +733,8 @@ class StorageServerWrapper: ...@@ -751,8 +733,8 @@ class StorageServerWrapper:
def tpc_vote(self, transaction): def tpc_vote(self, transaction):
vote_result = self.server.vote(id(transaction)) vote_result = self.server.vote(id(transaction))
assert vote_result is None assert vote_result is None
result = self.server.client.serials[:] result = self.server.connection.serials[:]
del self.server.client.serials[:] del self.server.connection.serials[:]
return result return result
def store(self, oid, serial, data, version_ignored, transaction): def store(self, oid, serial, data, version_ignored, transaction):
...@@ -838,7 +820,7 @@ Now we'll open a storage server on the data, simulating a restart: ...@@ -838,7 +820,7 @@ Now we'll open a storage server on the data, simulating a restart:
>>> fs = FileStorage('t.fs') >>> fs = FileStorage('t.fs')
>>> sv = StorageServer(('', get_port()), dict(fs=fs)) >>> sv = StorageServer(('', get_port()), dict(fs=fs))
>>> s = ZEOStorage(sv, sv.read_only) >>> s = ZEOStorage(sv, sv.read_only)
>>> s.notifyConnected(FauxConn()) >>> s.notify_connected(FauxConn())
>>> s.register('fs', False) >>> s.register('fs', False)
If we ask for the last transaction, we should get the last transaction If we ask for the last transaction, we should get the last transaction
...@@ -848,7 +830,7 @@ we saved: ...@@ -848,7 +830,7 @@ we saved:
True True
If a storage implements the method lastInvalidations, as FileStorage If a storage implements the method lastInvalidations, as FileStorage
does, then the stroage server will populate its invalidation data does, then the storage server will populate its invalidation data
structure using lastTransactions. structure using lastTransactions.
...@@ -1085,7 +1067,7 @@ def runzeo_without_configfile(): ...@@ -1085,7 +1067,7 @@ def runzeo_without_configfile():
------ ------
--T INFO ZEO.StorageServer StorageServer created RW with storages 1RWt --T INFO ZEO.StorageServer StorageServer created RW with storages 1RWt
------ ------
--T INFO ZEO.zrpc () listening on ... --T INFO ZEO.acceptor listening on ...
------ ------
--T INFO ZEO.StorageServer closing storage '1' --T INFO ZEO.StorageServer closing storage '1'
testing exit immediately testing exit immediately
...@@ -1150,7 +1132,6 @@ def test_server_status(): ...@@ -1150,7 +1132,6 @@ def test_server_status():
'start': 'Tue May 4 10:55:20 2010', 'start': 'Tue May 4 10:55:20 2010',
'stores': 1, 'stores': 1,
'timeout-thread-is-alive': True, 'timeout-thread-is-alive': True,
'verifying_clients': 0,
'waiting': 0} 'waiting': 0}
>>> db.close() >>> db.close()
...@@ -1169,7 +1150,8 @@ def test_ruok(): ...@@ -1169,7 +1150,8 @@ def test_ruok():
>>> _ = writer.write(struct.pack(">I", 4)+b"ruok") >>> _ = writer.write(struct.pack(">I", 4)+b"ruok")
>>> writer.close() >>> writer.close()
>>> proto = s.recv(struct.unpack(">I", s.recv(4))[0]) >>> proto = s.recv(struct.unpack(">I", s.recv(4))[0])
>>> data = json.loads(s.recv(struct.unpack(">I", s.recv(4))[0]).decode("ascii")) >>> data = json.loads(
... s.recv(struct.unpack(">I", s.recv(4))[0]).decode("ascii"))
>>> pprint.pprint(data['1']) >>> pprint.pprint(data['1'])
{u'aborts': 0, {u'aborts': 0,
u'active_txns': 0, u'active_txns': 0,
...@@ -1183,7 +1165,6 @@ def test_ruok(): ...@@ -1183,7 +1165,6 @@ def test_ruok():
u'start': u'Sun Jan 4 09:37:03 2015', u'start': u'Sun Jan 4 09:37:03 2015',
u'stores': 1, u'stores': 1,
u'timeout-thread-is-alive': True, u'timeout-thread-is-alive': True,
u'verifying_clients': 0,
u'waiting': 0} u'waiting': 0}
>>> db.close(); s.close() >>> db.close(); s.close()
""" """
...@@ -1410,7 +1391,7 @@ Now we'll try to use the connection, mainly to wait for everything to ...@@ -1410,7 +1391,7 @@ Now we'll try to use the connection, mainly to wait for everything to
get processed. Before we fixed this by making tpc_finish a synchronous get processed. Before we fixed this by making tpc_finish a synchronous
call to the server. we'd get some sort of error here. call to the server. we'd get some sort of error here.
>>> _ = client._call('loadEx', b'\0'*8) >>> _ = client._call('loadBefore', b'\0'*8, m64)
>>> c.close() >>> c.close()
...@@ -1519,7 +1500,7 @@ class ServerManagingClientStorage(ClientStorage): ...@@ -1519,7 +1500,7 @@ class ServerManagingClientStorage(ClientStorage):
server_blob_dir = 'server-'+blob_dir server_blob_dir = 'server-'+blob_dir
self.globs = {} self.globs = {}
port = forker.get_port2(self) port = forker.get_port2(self)
addr, admin, pid, config = forker.start_zeo_server( addr, stop = forker.start_zeo_server(
""" """
<blobstorage> <blobstorage>
blob-dir %s blob-dir %s
...@@ -1531,10 +1512,7 @@ class ServerManagingClientStorage(ClientStorage): ...@@ -1531,10 +1512,7 @@ class ServerManagingClientStorage(ClientStorage):
""" % (server_blob_dir, name+'.fs', extrafsoptions), """ % (server_blob_dir, name+'.fs', extrafsoptions),
port=port, port=port,
) )
os.remove(config) zope.testing.setupstack.register(self, stop)
zope.testing.setupstack.register(self, os.waitpid, pid, 0)
zope.testing.setupstack.register(
self, forker.shutdown_zeo_server, admin)
if shared: if shared:
ClientStorage.__init__(self, addr, blob_dir=blob_dir, ClientStorage.__init__(self, addr, blob_dir=blob_dir,
shared_blob_dir=True) shared_blob_dir=True)
......
This diff is collapsed.
This diff is collapsed.
##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
# zrpc is a package with the following modules
# client -- manages connection creation to remote server
# connection -- object dispatcher
# log -- logging helper
# error -- exceptions raised by zrpc
# marshal -- internal, handles basic protocol issues
# server -- manages incoming connections from remote clients
# smac -- sized message async connections
# trigger -- medusa's trigger
# zrpc is not an advertised subpackage of ZEO; its interfaces are internal
This diff is collapsed.
This diff is collapsed.
##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
from ZODB import POSException
from ZEO.Exceptions import ClientDisconnected
class ZRPCError(POSException.StorageError):
pass
class DisconnectedError(ZRPCError, ClientDisconnected):
"""The database storage is disconnected from the storage server.
The error occurred because a problem in the low-level RPC connection,
or because the connection was closed.
"""
# This subclass is raised when zrpc catches the error.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment