Commit 70ed5e54 authored by Jim Fulton's avatar Jim Fulton

Dealt with some serialization issues

- Need to handle exception instances embedded within others.

  I dealt with this in msgpack using a "default" option (essentially a
  msgpack/json form of reduce).

  For pickle, we're still creating instance pickles in this case. :/

- Use a python-msgpack option to produce tuples rather than
  lists.  The ZEO protocol uses tuples far more often than lists.

  This really mostly or entirely affects tests.

  Removed workarounds for some test code that expected tuples and
  added some for test code that expects lists. :)
parent 28f7a924
...@@ -26,14 +26,17 @@ from ..shortrepr import short_repr ...@@ -26,14 +26,17 @@ from ..shortrepr import short_repr
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def encoder(protocol): def encoder(protocol, server=False):
"""Return a non-thread-safe encoder """Return a non-thread-safe encoder
""" """
if protocol[:1] == b'M': if protocol[:1] == b'M':
from msgpack import packb from msgpack import packb
default = server_default if server else None
def encode(*args): def encode(*args):
return packb(args, use_bin_type=True) return packb(
args, use_bin_type=True, default=default)
return encode return encode
else: else:
assert protocol[:1] == b'Z' assert protocol[:1] == b'Z'
...@@ -69,7 +72,7 @@ def decoder(protocol): ...@@ -69,7 +72,7 @@ def decoder(protocol):
from msgpack import unpackb from msgpack import unpackb
def msgpack_decode(data): def msgpack_decode(data):
"""Decodes msg and returns its parts""" """Decodes msg and returns its parts"""
return unpackb(data, encoding='utf-8') return unpackb(data, encoding='utf-8', use_list=False)
return msgpack_decode return msgpack_decode
else: else:
assert protocol[:1] == b'Z' assert protocol[:1] == b'Z'
...@@ -113,6 +116,17 @@ def pickle_server_decode(msg): ...@@ -113,6 +116,17 @@ def pickle_server_decode(msg):
logger.error("can't decode message: %s" % short_repr(msg)) logger.error("can't decode message: %s" % short_repr(msg))
raise raise
def server_default(obj):
if isinstance(obj, Exception):
return reduce_exception(obj)
else:
return obj
def reduce_exception(exc):
class_ = exc.__class__
class_ = "%s.%s" % (class_.__module__, class_.__name__)
return class_, exc.__dict__ or exc.args
_globals = globals() _globals = globals()
_silly = ('__doc__',) _silly = ('__doc__',)
......
...@@ -11,7 +11,7 @@ from ..shortrepr import short_repr ...@@ -11,7 +11,7 @@ from ..shortrepr import short_repr
from . import base from . import base
from .compat import asyncio, new_event_loop from .compat import asyncio, new_event_loop
from .marshal import server_decoder, encoder from .marshal import server_decoder, encoder, reduce_exception
class ServerProtocol(base.Protocol): class ServerProtocol(base.Protocol):
"""asyncio low-level ZEO server interface """asyncio low-level ZEO server interface
...@@ -70,7 +70,7 @@ class ServerProtocol(base.Protocol): ...@@ -70,7 +70,7 @@ class ServerProtocol(base.Protocol):
logger.info("received handshake %r" % logger.info("received handshake %r" %
str(protocol_version.decode('ascii'))) str(protocol_version.decode('ascii')))
self.protocol_version = protocol_version self.protocol_version = protocol_version
self.encode = encoder(protocol_version) self.encode = encoder(protocol_version, True)
self.decode = server_decoder(protocol_version) self.decode = server_decoder(protocol_version)
self.zeo_storage.notify_connected(self) self.zeo_storage.notify_connected(self)
else: else:
...@@ -135,10 +135,7 @@ class ServerProtocol(base.Protocol): ...@@ -135,10 +135,7 @@ class ServerProtocol(base.Protocol):
def send_error(self, message_id, exc, send_error=False): def send_error(self, message_id, exc, send_error=False):
"""Abstracting here so we can make this cleaner in the future """Abstracting here so we can make this cleaner in the future
""" """
class_ = exc.__class__ self.send_reply(message_id, reduce_exception(exc), send_error, 2)
class_ = "%s.%s" % (class_.__module__, class_.__name__)
args = class_, exc.__dict__ or exc.args
self.send_reply(message_id, args, send_error, 2)
def async(self, method, *args): def async(self, method, *args):
self.call_async(method, args) self.call_async(method, args)
......
...@@ -26,6 +26,7 @@ from .marshal import encoder, decoder ...@@ -26,6 +26,7 @@ from .marshal import encoder, decoder
class Base(object): class Base(object):
enc = b'Z' enc = b'Z'
seq_type = list
def setUp(self): def setUp(self):
super(Base, self).setUp() super(Base, self).setUp()
...@@ -39,11 +40,7 @@ class Base(object): ...@@ -39,11 +40,7 @@ class Base(object):
data = data[2:] data = data[2:]
self.assertEqual(struct.unpack(">I", size)[0], len(message)) self.assertEqual(struct.unpack(">I", size)[0], len(message))
if unpickle: if unpickle:
message = tuple(self.decode(message)) message = self.decode(message)
if isinstance(message[-1], list):
message = message[:-1] + (tuple(message[-1]),)
if isinstance(message[0], list):
message = (tuple(message[-1]),) + message[1:]
result.append(message) result.append(message)
if len(result) == 1: if len(result) == 1:
...@@ -205,7 +202,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -205,7 +202,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertEqual(self.pop(), (5, False, 'loadBefore', (b'1'*8, maxtid))) self.assertEqual(self.pop(), (5, False, 'loadBefore', (b'1'*8, maxtid)))
# Note load_before uses the oid as the message id. # Note load_before uses the oid as the message id.
self.respond(5, (b'data', b'a'*8, None)) self.respond(5, (b'data', b'a'*8, None))
self.assertEqual(tuple(loaded.result()), (b'data', b'a'*8, None)) self.assertEqual(loaded.result(), (b'data', b'a'*8, None))
# If we make another request, it will be satisfied from the cache: # If we make another request, it will be satisfied from the cache:
loaded = self.load_before(b'1'*8, maxtid) loaded = self.load_before(b'1'*8, maxtid)
...@@ -213,7 +210,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -213,7 +210,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertFalse(transport.data) self.assertFalse(transport.data)
# Let's send an invalidation: # Let's send an invalidation:
self.send('invalidateTransaction', b'b'*8, [b'1'*8]) self.send('invalidateTransaction', b'b'*8, self.seq_type([b'1'*8]))
# Now, if we try to load current again, we'll make a server request. # Now, if we try to load current again, we'll make a server request.
loaded = self.load_before(b'1'*8, maxtid) loaded = self.load_before(b'1'*8, maxtid)
...@@ -224,21 +221,21 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -224,21 +221,21 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertEqual(self.pop(), (6, False, 'loadBefore', (b'1'*8, maxtid))) self.assertEqual(self.pop(), (6, False, 'loadBefore', (b'1'*8, maxtid)))
self.respond(6, (b'data2', b'b'*8, None)) self.respond(6, (b'data2', b'b'*8, None))
self.assertEqual(tuple(loaded.result()), (b'data2', b'b'*8, None)) self.assertEqual(loaded.result(), (b'data2', b'b'*8, None))
self.assertEqual(tuple(loaded2.result()), (b'data2', b'b'*8, None)) self.assertEqual(loaded2.result(), (b'data2', b'b'*8, None))
# Loading non-current data may also be satisfied from cache # Loading non-current data may also be satisfied from cache
loaded = self.load_before(b'1'*8, b'b'*8) loaded = self.load_before(b'1'*8, b'b'*8)
self.assertEqual(tuple(loaded.result()), (b'data', b'a'*8, b'b'*8)) self.assertEqual(loaded.result(), (b'data', b'a'*8, b'b'*8))
self.assertFalse(transport.data) self.assertFalse(transport.data)
loaded = self.load_before(b'1'*8, b'c'*8) loaded = self.load_before(b'1'*8, b'c'*8)
self.assertEqual(tuple(loaded.result()), (b'data2', b'b'*8, None)) self.assertEqual(loaded.result(), (b'data2', b'b'*8, None))
self.assertFalse(transport.data) self.assertFalse(transport.data)
loaded = self.load_before(b'1'*8, b'_'*8) loaded = self.load_before(b'1'*8, b'_'*8)
self.assertEqual(self.pop(), (7, False, 'loadBefore', (b'1'*8, b'_'*8))) self.assertEqual(self.pop(), (7, False, 'loadBefore', (b'1'*8, b'_'*8)))
self.respond(7, (b'data0', b'^'*8, b'_'*8)) self.respond(7, (b'data0', b'^'*8, b'_'*8))
self.assertEqual(tuple(loaded.result()), (b'data0', b'^'*8, b'_'*8)) self.assertEqual(loaded.result(), (b'data0', b'^'*8, b'_'*8))
# When committing transactions, we need to update the cache # When committing transactions, we need to update the cache
# with committed data. To do this, we pass a (oid, data, resolved) # with committed data. To do this, we pass a (oid, data, resolved)
...@@ -549,7 +546,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -549,7 +546,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.pop(4) self.pop(4)
self.send('invalidateTransaction', b'b'*8, [b'1'*8], called=False) self.send('invalidateTransaction', b'b'*8, [b'1'*8], called=False)
self.respond(2, b'a'*8) self.respond(2, b'a'*8)
self.send('invalidateTransaction', b'c'*8, [b'1'*8], no_output=False) self.send('invalidateTransaction', b'c'*8, self.seq_type([b'1'*8]),
no_output=False)
self.assertEqual(self.pop(), (3, False, 'get_info', ())) self.assertEqual(self.pop(), (3, False, 'get_info', ()))
# We'll disconnect: # We'll disconnect:
...@@ -567,7 +565,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -567,7 +565,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.pop(4) self.pop(4)
self.send('invalidateTransaction', b'd'*8, [b'1'*8], called=False) self.send('invalidateTransaction', b'd'*8, [b'1'*8], called=False)
self.respond(2, b'c'*8) self.respond(2, b'c'*8)
self.send('invalidateTransaction', b'e'*8, [b'1'*8], no_output=False) self.send('invalidateTransaction', b'e'*8, self.seq_type([b'1'*8]),
no_output=False)
self.assertEqual(self.pop(), (3, False, 'get_info', ())) self.assertEqual(self.pop(), (3, False, 'get_info', ()))
def test_flow_control(self): def test_flow_control(self):
...@@ -691,6 +690,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -691,6 +690,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
class MsgpackClientTests(ClientTests): class MsgpackClientTests(ClientTests):
enc = b'M' enc = b'M'
seq_type = tuple
class MemoryCache(object): class MemoryCache(object):
...@@ -830,6 +830,7 @@ class ServerTests(Base, setupstack.TestCase): ...@@ -830,6 +830,7 @@ class ServerTests(Base, setupstack.TestCase):
class MsgpackServerTests(ServerTests): class MsgpackServerTests(ServerTests):
enc = b'M' enc = b'M'
seq_type = tuple
def server_protocol(msgpack, def server_protocol(msgpack,
zeo_storage=None, zeo_storage=None,
......
...@@ -24,8 +24,8 @@ A current client should be able to connect to a old server: ...@@ -24,8 +24,8 @@ A current client should be able to connect to a old server:
>>> import ZEO, ZODB.blob, transaction >>> import ZEO, ZODB.blob, transaction
>>> db = ZEO.DB(addr, client='client', blob_dir='blobs') >>> db = ZEO.DB(addr, client='client', blob_dir='blobs')
>>> wait_connected(db.storage) >>> wait_connected(db.storage)
>>> str(db.storage.protocol_version.decode('ascii')) >>> str(db.storage.protocol_version.decode('ascii'))[1:]
'Z4' '4'
>>> conn = db.open() >>> conn = db.open()
>>> conn.root().x = 0 >>> conn.root().x = 0
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment