Commit 6a945252 authored by Jim Fulton's avatar Jim Fulton

Refactored the zrpc implementation to:

- Most server methods now return data to clients more quickly by writing to
  client sockets immediately, rather than waiting for the asyncore
  select loop to get around to it.

- More clearly define client and server responsibilities. Machinery
  needed for just clients or just servers has been moved to the
  corresponding connection subclasses.

- Degeneralized "flags" argument to many methods. There's just one
  async flag.
parent 365774c9
......@@ -1340,10 +1340,10 @@ class ClientStub:
self.rpc.callAsyncNoPoll('invalidateTransaction', tid, args)
def serialnos(self, arg):
self.rpc.callAsync('serialnos', arg)
self.rpc.callAsyncNoPoll('serialnos', arg)
def info(self, arg):
self.rpc.callAsync('info', arg)
self.rpc.callAsyncNoPoll('info', arg)
def storeBlob(self, oid, serial, blobfilename):
......
......@@ -56,3 +56,5 @@ class Connection:
def callAsync(self, meth, *args):
print self.name, 'callAsync', meth, repr(args)
callAsyncNoPoll = callAsync
......@@ -78,9 +78,10 @@ will conflict. It will be blocked at the vote call.
>>> zs2.storeBlobEnd(oid, serial, data, '1')
>>> delay = zs2.vote('1')
>>> def send_reply(id, reply):
>>> class Sender:
... def send_reply(self, id, reply):
... print 'reply', id, reply
>>> delay.set_sender(1, send_reply, None)
>>> delay.set_sender(1, Sender())
>>> logger = logging.getLogger('ZEO')
>>> handler = logging.StreamHandler(sys.stdout)
......
......@@ -30,7 +30,6 @@ from ZODB.loglevels import BLATHER, TRACE
import ZODB.POSException
REPLY = ".reply" # message name used for replies
ASYNC = 1
exception_type_type = type(Exception)
......@@ -180,34 +179,33 @@ class Delay:
the mainloop from sending a response.
"""
def set_sender(self, msgid, send_reply, return_error):
def set_sender(self, msgid, conn):
self.msgid = msgid
self.send_reply = send_reply
self.return_error = return_error
self.conn = conn
def reply(self, obj):
self.send_reply(self.msgid, obj)
self.conn.send_reply(self.msgid, obj)
def error(self, exc_info):
log("Error raised in delayed method", logging.ERROR, exc_info=True)
self.return_error(self.msgid, 0, *exc_info[:2])
self.conn.return_error(self.msgid, *exc_info[:2])
class MTDelay(Delay):
def __init__(self):
self.ready = threading.Event()
def set_sender(self, msgid, send_reply, return_error):
Delay.set_sender(self, msgid, send_reply, return_error)
def set_sender(self, *args):
Delay.set_sender(self, *args)
self.ready.set()
def reply(self, obj):
self.ready.wait()
Delay.reply(self, obj)
self.conn.call_from_thread(self.conn.send_reply, self.msgid, obj)
def error(self, exc_info):
self.ready.wait()
Delay.error(self, exc_info)
self.conn.call_from_thread(Delay.error, self, exc_info)
# PROTOCOL NEGOTIATION
#
......@@ -304,9 +302,7 @@ class Connection(smac.SizedMessageAsyncConnection, object):
client for that particular call.
The protocol also supports asynchronous calls. The client does
not wait for a return value for an asynchronous call. The only
defined flag is ASYNC. If a method call message has the ASYNC
flag set, the server will raise an exception.
not wait for a return value for an asynchronous call.
If a method call raises an Exception, the exception is propagated
back to the client via the REPLY message. The client side will
......@@ -428,15 +424,6 @@ class Connection(smac.SizedMessageAsyncConnection, object):
# The singleton dict is a socket map containing only this object.
self._singleton = {self._fileno: self}
# msgid_lock guards access to msgid
self.msgid = 0
self.msgid_lock = threading.Lock()
# replies_cond is used to block when a synchronous call is
# waiting for a response
self.replies_cond = threading.Condition()
self.replies = {}
# waiting_for_reply is used internally to indicate whether
# a call is in progress. setting a session key is deferred
# until after the call returns.
......@@ -488,9 +475,6 @@ class Connection(smac.SizedMessageAsyncConnection, object):
self.closed = True
self.__super_close()
self.trigger.pull_trigger()
self.replies_cond.acquire()
self.replies_cond.notifyAll()
self.replies_cond.release()
def register_object(self, obj):
"""Register obj as the true object to invoke methods on."""
......@@ -537,29 +521,19 @@ class Connection(smac.SizedMessageAsyncConnection, object):
# will raise an exception. The exception will ultimately
# result in asycnore calling handle_error(), which will
# close the connection.
msgid, flags, name, args = self.marshal.decode(message)
msgid, async, name, args = self.marshal.decode(message)
if debug_zrpc:
self.log("recv msg: %s, %s, %s, %s" % (msgid, flags, name,
self.log("recv msg: %s, %s, %s, %s" % (msgid, async, name,
short_repr(args)),
level=TRACE)
if name == REPLY:
self.handle_reply(msgid, flags, args)
assert not async
self.handle_reply(msgid, args)
else:
self.handle_request(msgid, flags, name, args)
self.handle_request(msgid, async, name, args)
def handle_reply(self, msgid, flags, args):
if debug_zrpc:
self.log("recv reply: %s, %s, %s"
% (msgid, flags, short_repr(args)), level=TRACE)
self.replies_cond.acquire()
try:
self.replies[msgid] = flags, args
self.replies_cond.notifyAll()
finally:
self.replies_cond.release()
def handle_request(self, msgid, flags, name, args):
def handle_request(self, msgid, async, name, args):
obj = self.obj
if name.startswith('_') or not hasattr(obj, name):
......@@ -590,9 +564,14 @@ class Connection(smac.SizedMessageAsyncConnection, object):
self.log("%s() raised exception: %s" % (name, msg),
logging.ERROR, exc_info=True)
error = sys.exc_info()[:2]
return self.return_error(msgid, flags, *error)
if async:
self.log("Asynchronous call raised exception: %s" % self,
level=logging.ERROR, exc_info=True)
else:
self.return_error(msgid, *error)
return
if flags & ASYNC:
if async:
if ret is not None:
raise ZRPCError("async method %s returned value %s" %
(name, short_repr(ret)))
......@@ -601,43 +580,19 @@ class Connection(smac.SizedMessageAsyncConnection, object):
self.log("%s returns %s" % (name, short_repr(ret)),
logging.DEBUG)
if isinstance(ret, Delay):
ret.set_sender(msgid, self.send_reply, self.return_error)
ret.set_sender(msgid, self)
else:
self.send_reply(msgid, ret)
self.send_reply(msgid, ret, not self.delay_sesskey)
if self.delay_sesskey:
self.__super_setSessionKey(self.delay_sesskey)
self.delay_sesskey = None
def handle_error(self):
if sys.exc_info()[0] == SystemExit:
raise sys.exc_info()
self.log("Error caught in asyncore",
level=logging.ERROR, exc_info=True)
self.close()
def send_reply(self, msgid, ret):
# encode() can pass on a wide variety of exceptions from cPickle.
# While a bare `except` is generally poor practice, in this case
# it's acceptable -- we really do want to catch every exception
# cPickle may raise.
try:
msg = self.marshal.encode(msgid, 0, REPLY, ret)
except: # see above
try:
r = short_repr(ret)
except:
r = "<unreprable>"
err = ZRPCError("Couldn't pickle return %.100s" % r)
msg = self.marshal.encode(msgid, 0, REPLY, (ZRPCError, err))
self.message_output(msg)
self.poll()
def return_error(self, msgid, err_type, err_value):
# Note that, ideally, this should be defined soley for
# servers, but a test arranges to get it called on
# a client. Too much trouble to fix it now. :/
def return_error(self, msgid, flags, err_type, err_value):
if flags & ASYNC:
self.log("Asynchronous call raised exception: %s" % self,
level=logging.ERROR, exc_info=True)
return
if not isinstance(err_value, Exception):
err_value = err_type, err_value
......@@ -657,79 +612,37 @@ class Connection(smac.SizedMessageAsyncConnection, object):
self.message_output(msg)
self.poll()
def handle_error(self):
if sys.exc_info()[0] == SystemExit:
raise sys.exc_info()
self.log("Error caught in asyncore",
level=logging.ERROR, exc_info=True)
self.close()
def setSessionKey(self, key):
if self.waiting_for_reply:
self.delay_sesskey = key
else:
self.__super_setSessionKey(key)
# The next two public methods (call and callAsync) are used by
# clients to invoke methods on remote objects
def __new_msgid(self):
self.msgid_lock.acquire()
try:
msgid = self.msgid
self.msgid = self.msgid + 1
return msgid
finally:
self.msgid_lock.release()
def __call_message(self, method, args, flags):
# compute a message and return it
msgid = self.__new_msgid()
if debug_zrpc:
self.log("send msg: %d, %d, %s, ..." % (msgid, flags, method),
level=TRACE)
return self.marshal.encode(msgid, flags, method, args)
def send_call(self, method, args, flags):
def send_call(self, method, args, async=False):
# send a message and return its msgid
msgid = self.__new_msgid()
if async:
msgid = 0
else:
msgid = self._new_msgid()
if debug_zrpc:
self.log("send msg: %d, %d, %s, ..." % (msgid, flags, method),
self.log("send msg: %d, %d, %s, ..." % (msgid, async, method),
level=TRACE)
buf = self.marshal.encode(msgid, flags, method, args)
buf = self.marshal.encode(msgid, async, method, args)
self.message_output(buf)
return msgid
def call(self, method, *args):
if self.closed:
raise DisconnectedError()
msgid = self.send_call(method, args, 0)
r_flags, r_args = self.wait(msgid)
if (isinstance(r_args, tuple) and len(r_args) > 1
and type(r_args[0]) == exception_type_type
and issubclass(r_args[0], Exception)):
inst = r_args[1]
raise inst # error raised by server
else:
return r_args
# For testing purposes, it is useful to begin a synchronous call
# but not block waiting for its response.
def _deferred_call(self, method, *args):
if self.closed:
raise DisconnectedError()
msgid = self.send_call(method, args, 0)
self.trigger.pull_trigger()
return msgid
def _deferred_wait(self, msgid):
r_flags, r_args = self.wait(msgid)
if (isinstance(r_args, tuple)
and type(r_args[0]) == exception_type_type
and issubclass(r_args[0], Exception)):
inst = r_args[1]
raise inst # error raised by server
else:
return r_args
def callAsync(self, method, *args):
if self.closed:
raise DisconnectedError()
self.send_call(method, args, ASYNC)
self.send_call(method, args, 1)
self.poll()
def callAsyncNoPoll(self, method, *args):
......@@ -738,7 +651,7 @@ class Connection(smac.SizedMessageAsyncConnection, object):
# allowing any client to sneak in a load request.
if self.closed:
raise DisconnectedError()
self.send_call(method, args, ASYNC)
self.send_call(method, args, 1)
def callAsyncIterator(self, iterator):
"""Queue a sequence of calls using an iterator
......@@ -746,46 +659,11 @@ class Connection(smac.SizedMessageAsyncConnection, object):
The calls will not be interleaved with other calls from the same
client.
"""
self.message_output(self.__outputIterator(iterator))
def __outputIterator(self, iterator):
for method, args in iterator:
yield self.__call_message(method, args, ASYNC)
def wait(self, msgid):
"""Invoke asyncore mainloop and wait for reply."""
if debug_zrpc:
self.log("wait(%d)" % msgid, level=TRACE)
self.trigger.pull_trigger()
# Delay used when we call asyncore.poll() directly.
# Start with a 1 msec delay, double until 1 sec.
delay = 0.001
self.replies_cond.acquire()
try:
while 1:
if self.closed:
raise DisconnectedError()
reply = self.replies.get(msgid)
if reply is not None:
del self.replies[msgid]
if debug_zrpc:
self.log("wait(%d): reply=%s" %
(msgid, short_repr(reply)), level=TRACE)
return reply
self.replies_cond.wait()
finally:
self.replies_cond.release()
self.message_output(self.marshal.encode(0, 1, method, args)
for method, args in iterator)
def flush(self):
"""Invoke poll() until the output buffer is empty."""
if debug_zrpc:
self.log("flush")
while self.writable():
self.poll()
def handle_reply(self, msgid, ret):
assert msgid == -1 and ret is None
def poll(self):
"""Invoke asyncore mainloop to get pending message out."""
......@@ -794,7 +672,6 @@ class Connection(smac.SizedMessageAsyncConnection, object):
self.trigger.pull_trigger()
class ManagedServerConnection(Connection):
"""Server-side Connection subclass."""
......@@ -803,6 +680,7 @@ class ManagedServerConnection(Connection):
# Servers use a shared server trigger that uses the asyncore socket map
trigger = trigger()
call_from_thread = trigger.pull_trigger
def __init__(self, sock, addr, obj, mgr):
self.mgr = mgr
......@@ -821,13 +699,33 @@ class ManagedServerConnection(Connection):
self.obj.notifyDisconnected()
Connection.close(self)
def send_reply(self, msgid, ret, immediately=True):
# encode() can pass on a wide variety of exceptions from cPickle.
# While a bare `except` is generally poor practice, in this case
# it's acceptable -- we really do want to catch every exception
# cPickle may raise.
try:
msg = self.marshal.encode(msgid, 0, REPLY, ret)
except: # see above
try:
r = short_repr(ret)
except:
r = "<unreprable>"
err = ZRPCError("Couldn't pickle return %.100s" % r)
msg = self.marshal.encode(msgid, 0, REPLY, (ZRPCError, err))
self.message_output(msg)
if immediately:
self.poll()
poll = smac.SizedMessageAsyncConnection.handle_write
class ManagedClientConnection(Connection):
"""Client-side Connection subclass."""
__super_init = Connection.__init__
__super_close = Connection.close
base_message_output = Connection.message_output
trigger = client_trigger
call_from_thread = trigger.pull_trigger
def __init__(self, sock, addr, mgr):
self.mgr = mgr
......@@ -846,9 +744,24 @@ class ManagedClientConnection(Connection):
self.queue_output = True
self.queued_messages = []
# msgid_lock guards access to msgid
self.msgid = 0
self.msgid_lock = threading.Lock()
# replies_cond is used to block when a synchronous call is
# waiting for a response
self.replies_cond = threading.Condition()
self.replies = {}
self.__super_init(sock, addr, None, tag='C', map=client_map)
client_trigger.pull_trigger()
def close(self):
Connection.close(self)
self.replies_cond.acquire()
self.replies_cond.notifyAll()
self.replies_cond.release()
# Our message_ouput() queues messages until recv_handshake() gets the
# protocol handshake from the server.
def message_output(self, message):
......@@ -890,3 +803,88 @@ class ManagedClientConnection(Connection):
self.queue_output = False
finally:
self.output_lock.release()
def _new_msgid(self):
self.msgid_lock.acquire()
try:
msgid = self.msgid
self.msgid = self.msgid + 1
return msgid
finally:
self.msgid_lock.release()
def call(self, method, *args):
if self.closed:
raise DisconnectedError()
msgid = self.send_call(method, args)
r_args = self.wait(msgid)
if (isinstance(r_args, tuple) and len(r_args) > 1
and type(r_args[0]) == exception_type_type
and issubclass(r_args[0], Exception)):
inst = r_args[1]
raise inst # error raised by server
else:
return r_args
def wait(self, msgid):
"""Invoke asyncore mainloop and wait for reply."""
if debug_zrpc:
self.log("wait(%d)" % msgid, level=TRACE)
self.trigger.pull_trigger()
# Delay used when we call asyncore.poll() directly.
# Start with a 1 msec delay, double until 1 sec.
delay = 0.001
self.replies_cond.acquire()
try:
while 1:
if self.closed:
raise DisconnectedError()
reply = self.replies.get(msgid, self)
if reply is not self:
del self.replies[msgid]
if debug_zrpc:
self.log("wait(%d): reply=%s" %
(msgid, short_repr(reply)), level=TRACE)
return reply
self.replies_cond.wait()
finally:
self.replies_cond.release()
# For testing purposes, it is useful to begin a synchronous call
# but not block waiting for its response.
def _deferred_call(self, method, *args):
if self.closed:
raise DisconnectedError()
msgid = self.send_call(method, args)
self.trigger.pull_trigger()
return msgid
def _deferred_wait(self, msgid):
r_args = self.wait(msgid)
if (isinstance(r_args, tuple)
and type(r_args[0]) == exception_type_type
and issubclass(r_args[0], Exception)):
inst = r_args[1]
raise inst # error raised by server
else:
return r_args
def handle_reply(self, msgid, args):
if debug_zrpc:
self.log("recv reply: %s, %s"
% (msgid, short_repr(args)), level=TRACE)
self.replies_cond.acquire()
try:
self.replies[msgid] = args
self.replies_cond.notifyAll()
finally:
self.replies_cond.release()
def send_reply(self, msgid, ret):
# Whimper. Used to send heartbeat
assert msgid == -1 and ret is None
self.message_output('(J\xff\xff\xff\xffK\x00U\x06.replyNt.')
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment