Commit ccd47c48 authored by Julien Muchembled's avatar Julien Muchembled Committed by Levin Zimmermann

protocol: switch to msgpack for packet serialization

Not only for performance reasons (at least 3% faster) but also because of
several ugly things in the way packets were defined:
- packet field names, which are only documentary; for roots fields,
  they even just duplicate the packet names
- a lot of repetitions for packet names, and even confusion between the name
  of the packet definition and the name of the actual notify/request packet
- the need to implement field types for anything, like PByte to support new
  compression formats, since PBoolean is not enough

neo/lib/protocol.py is now much smaller.
parent 1ad088c8
...@@ -13,6 +13,13 @@ ...@@ -13,6 +13,13 @@
############################################################################## ##############################################################################
def patch(): def patch():
# For msgpack & Py2/ZODB5.
try:
from zodbpickle import binary
binary._pack = bytes.__str__
except ImportError:
pass
from hashlib import md5 from hashlib import md5
from ZODB.Connection import Connection from ZODB.Connection import Connection
......
...@@ -181,7 +181,7 @@ class Application(ThreadedApplication): ...@@ -181,7 +181,7 @@ class Application(ThreadedApplication):
with self._connecting_to_master_node: with self._connecting_to_master_node:
result = self.master_conn result = self.master_conn
if result is None: if result is None:
self.new_oid_list = () self.new_oids = ()
result = self.master_conn = self._connectToPrimaryNode() result = self.master_conn = self._connectToPrimaryNode()
return result return result
...@@ -305,15 +305,19 @@ class Application(ThreadedApplication): ...@@ -305,15 +305,19 @@ class Application(ThreadedApplication):
"""Get a new OID.""" """Get a new OID."""
self._oid_lock_acquire() self._oid_lock_acquire()
try: try:
if not self.new_oid_list: for oid in self.new_oids:
break
else:
# Get new oid list from master node # Get new oid list from master node
# we manage a list of oid here to prevent # we manage a list of oid here to prevent
# from asking too many time new oid one by one # from asking too many time new oid one by one
# from master node # from master node
self._askPrimary(Packets.AskNewOIDs(100)) self._askPrimary(Packets.AskNewOIDs(100))
if not self.new_oid_list: for oid in self.new_oids:
break
else:
raise NEOStorageError('new_oid failed') raise NEOStorageError('new_oid failed')
self.last_oid = oid = self.new_oid_list.pop() self.last_oid = oid
return oid return oid
finally: finally:
self._oid_lock_release() self._oid_lock_release()
...@@ -612,7 +616,7 @@ class Application(ThreadedApplication): ...@@ -612,7 +616,7 @@ class Application(ThreadedApplication):
# user and description are cast to str in case they're unicode. # user and description are cast to str in case they're unicode.
# BBB: This is not required anymore with recent ZODB. # BBB: This is not required anymore with recent ZODB.
packet = Packets.AskStoreTransaction(ttid, str(transaction.user), packet = Packets.AskStoreTransaction(ttid, str(transaction.user),
str(transaction.description), ext, txn_context.cache_dict) str(transaction.description), ext, list(txn_context.cache_dict))
queue = txn_context.queue queue = txn_context.queue
conn_dict = txn_context.conn_dict conn_dict = txn_context.conn_dict
# Ask in parallel all involved storage nodes to commit object metadata. # Ask in parallel all involved storage nodes to commit object metadata.
...@@ -697,7 +701,7 @@ class Application(ThreadedApplication): ...@@ -697,7 +701,7 @@ class Application(ThreadedApplication):
else: else:
try: try:
notify(Packets.AbortTransaction(txn_context.ttid, notify(Packets.AbortTransaction(txn_context.ttid,
txn_context.conn_dict)) list(txn_context.conn_dict)))
except ConnectionClosed: except ConnectionClosed:
pass pass
# No need to flush queue, as it will be destroyed on return, # No need to flush queue, as it will be destroyed on return,
...@@ -731,7 +735,8 @@ class Application(ThreadedApplication): ...@@ -731,7 +735,8 @@ class Application(ThreadedApplication):
for oid in checked_list: for oid in checked_list:
del cache_dict[oid] del cache_dict[oid]
ttid = txn_context.ttid ttid = txn_context.ttid
p = Packets.AskFinishTransaction(ttid, cache_dict, checked_list) p = Packets.AskFinishTransaction(ttid, list(cache_dict),
checked_list)
try: try:
tid = self._askPrimary(p, cache_dict=cache_dict, callback=f) tid = self._askPrimary(p, cache_dict=cache_dict, callback=f)
assert tid assert tid
......
...@@ -163,8 +163,7 @@ class PrimaryAnswersHandler(AnswerBaseHandler): ...@@ -163,8 +163,7 @@ class PrimaryAnswersHandler(AnswerBaseHandler):
self.app.setHandlerData(ttid) self.app.setHandlerData(ttid)
def answerNewOIDs(self, conn, oid_list): def answerNewOIDs(self, conn, oid_list):
oid_list.reverse() self.app.new_oids = iter(oid_list)
self.app.new_oid_list = oid_list
def incompleteTransaction(self, conn, message): def incompleteTransaction(self, conn, message):
raise NEOStorageError("storage nodes for which vote failed can not be" raise NEOStorageError("storage nodes for which vote failed can not be"
......
...@@ -26,7 +26,7 @@ from .exception import NEOStorageError ...@@ -26,7 +26,7 @@ from .exception import NEOStorageError
class _WakeupPacket(object): class _WakeupPacket(object):
handler_method_name = 'pong' handler_method_name = 'pong'
decode = tuple _args = ()
getId = int getId = int
class Transaction(object): class Transaction(object):
......
...@@ -16,12 +16,19 @@ ...@@ -16,12 +16,19 @@
from functools import wraps from functools import wraps
from time import time from time import time
import msgpack
from msgpack.exceptions import UnpackValueError
from . import attributeTracker, logging from . import attributeTracker, logging
from .connector import ConnectorException, ConnectorDelayedConnection from .connector import ConnectorException, ConnectorDelayedConnection
from .locking import RLock from .locking import RLock
from .protocol import uuid_str, Errors, PacketMalformedError, Packets from .protocol import uuid_str, Errors, PacketMalformedError, Packets, \
from .util import dummy_read_buffer, ReadBuffer Unpacker
@apply
class dummy_read_buffer(msgpack.Unpacker):
def feed(self, _):
pass
class ConnectionClosed(Exception): class ConnectionClosed(Exception):
pass pass
...@@ -292,7 +299,7 @@ class ListeningConnection(BaseConnection): ...@@ -292,7 +299,7 @@ class ListeningConnection(BaseConnection):
# message. # message.
else: else:
conn._connected() conn._connected()
self.em.addWriter(conn) # for ENCODED_VERSION self.em.addWriter(conn) # for HANDSHAKE_PACKET
def getAddress(self): def getAddress(self):
return self.connector.getAddress() return self.connector.getAddress()
...@@ -311,12 +318,12 @@ class Connection(BaseConnection): ...@@ -311,12 +318,12 @@ class Connection(BaseConnection):
client = False client = False
server = False server = False
peer_id = None peer_id = None
_parser_state = None _total_unpacked = 0
_timeout = None _timeout = None
def __init__(self, event_manager, *args, **kw): def __init__(self, event_manager, *args, **kw):
BaseConnection.__init__(self, event_manager, *args, **kw) BaseConnection.__init__(self, event_manager, *args, **kw)
self.read_buf = ReadBuffer() self.read_buf = Unpacker()
# NOTE cur_id will be set in Server|Client to maintain `cur_id % 2 == const` invariant # NOTE cur_id will be set in Server|Client to maintain `cur_id % 2 == const` invariant
#self.cur_id = 0 #self.cur_id = 0
self.aborted = False self.aborted = False
...@@ -429,41 +436,38 @@ class Connection(BaseConnection): ...@@ -429,41 +436,38 @@ class Connection(BaseConnection):
self._closure() self._closure()
def _parse(self): def _parse(self):
read = self.read_buf.read from .protocol import HANDSHAKE_PACKET, MAGIC_SIZE, Packets
version = read(4) read_buf = self.read_buf
if version is None: handshake = read_buf.read_bytes(len(HANDSHAKE_PACKET))
if handshake != HANDSHAKE_PACKET:
if HANDSHAKE_PACKET.startswith(handshake): # unlikely so tested last
# Not enough data and there's no API to know it in advance.
# Put it back.
read_buf.feed(handshake)
return return
from .protocol import (ENCODED_VERSION, MAX_PACKET_SIZE, if HANDSHAKE_PACKET.startswith(handshake[:MAGIC_SIZE]):
PACKET_HEADER_FORMAT, Packets)
if version != ENCODED_VERSION:
logging.warning('Protocol version mismatch with %r', self) logging.warning('Protocol version mismatch with %r', self)
else:
logging.debug('Rejecting non-NEO %r', self)
raise ConnectorException raise ConnectorException
header_size = PACKET_HEADER_FORMAT.size read_next = read_buf.next
unpack = PACKET_HEADER_FORMAT.unpack read_pos = read_buf.tell
def parse(): def parse():
state = self._parser_state try:
if state is None: msg_id, msg_type, args = read_next()
header = read(header_size) except StopIteration:
if header is None:
return return
msg_id, msg_type, msg_len = unpack(header) except UnpackValueError as e:
raise PacketMalformedError(str(e))
try: try:
packet_klass = Packets[msg_type] packet_klass = Packets[msg_type]
except KeyError: except KeyError:
raise PacketMalformedError('Unknown packet type') raise PacketMalformedError('Unknown packet type')
if msg_len > MAX_PACKET_SIZE: pos = read_pos()
raise PacketMalformedError('message too big (%d)' % msg_len) packet = packet_klass(*args)
else: packet.setId(msg_id)
msg_id, packet_klass, msg_len = state packet.size = pos - self._total_unpacked
data = read(msg_len) self._total_unpacked = pos
if data is None:
# Not enough.
if state is None:
self._parser_state = msg_id, packet_klass, msg_len
else:
self._parser_state = None
packet = packet_klass()
packet.setContent(msg_id, data)
return packet return packet
self._parse = parse self._parse = parse
return parse() return parse()
...@@ -517,7 +521,7 @@ class Connection(BaseConnection): ...@@ -517,7 +521,7 @@ class Connection(BaseConnection):
def close(self): def close(self):
if self.connector is None: if self.connector is None:
assert self._on_close is None assert self._on_close is None
assert not self.read_buf assert not self.read_buf.read_bytes(1)
assert not self.isPending() assert not self.isPending()
return return
# process the network events with the last registered handler to # process the network events with the last registered handler to
...@@ -528,7 +532,7 @@ class Connection(BaseConnection): ...@@ -528,7 +532,7 @@ class Connection(BaseConnection):
if self._on_close is not None: if self._on_close is not None:
self._on_close() self._on_close()
self._on_close = None self._on_close = None
self.read_buf.clear() self.read_buf = dummy_read_buffer
try: try:
if self.connecting: if self.connecting:
handler.connectionFailed(self) handler.connectionFailed(self)
......
...@@ -19,7 +19,7 @@ import ssl ...@@ -19,7 +19,7 @@ import ssl
import errno import errno
from time import time from time import time
from . import logging from . import logging
from .protocol import ENCODED_VERSION from .protocol import HANDSHAKE_PACKET
# Global connector registry. # Global connector registry.
# Fill by calling registerConnectorHandler. # Fill by calling registerConnectorHandler.
...@@ -74,14 +74,13 @@ class SocketConnector(object): ...@@ -74,14 +74,13 @@ class SocketConnector(object):
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
# disable Nagle algorithm to reduce latency # disable Nagle algorithm to reduce latency
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.queued = [ENCODED_VERSION] self.queued = [HANDSHAKE_PACKET]
self.queue_size = len(ENCODED_VERSION) self.queue_size = len(HANDSHAKE_PACKET)
return self return self
def queue(self, data): def queue(self, data):
was_empty = not self.queued was_empty = not self.queued
self.queued += data self.queued.append(data)
for data in data:
self.queue_size += len(data) self.queue_size += len(data)
return was_empty return was_empty
...@@ -172,7 +171,7 @@ class SocketConnector(object): ...@@ -172,7 +171,7 @@ class SocketConnector(object):
except socket.error, e: except socket.error, e:
self._error('recv', e) self._error('recv', e)
if data: if data:
read_buf.append(data) read_buf.feed(data)
return return
self._error('recv') self._error('recv')
...@@ -283,7 +282,7 @@ class _SSL: ...@@ -283,7 +282,7 @@ class _SSL:
# non-ragged EOF (peer properly closed its side of connection) # non-ragged EOF (peer properly closed its side of connection)
self._error('recv', None) self._error('recv', None)
return return
read_buf.append(data) read_buf.feed(data)
except ssl.SSLWantReadError: except ssl.SSLWantReadError:
pass pass
except socket.error, e: except socket.error, e:
......
...@@ -23,7 +23,7 @@ NOBODY = [] ...@@ -23,7 +23,7 @@ NOBODY = []
class _ConnectionClosed(object): class _ConnectionClosed(object):
handler_method_name = 'connectionClosed' handler_method_name = 'connectionClosed'
decode = tuple _args = ()
class getId(object): class getId(object):
def __eq__(self, other): def __eq__(self, other):
......
...@@ -71,7 +71,7 @@ class EventHandler(object): ...@@ -71,7 +71,7 @@ class EventHandler(object):
method = getattr(self, packet.handler_method_name) method = getattr(self, packet.handler_method_name)
except AttributeError: except AttributeError:
raise UnexpectedPacketError('no handler found') raise UnexpectedPacketError('no handler found')
args = packet.decode() or () args = packet._args
method(conn, *args, **kw) method(conn, *args, **kw)
except DelayEvent, e: except DelayEvent, e:
assert not kw, kw assert not kw, kw
...@@ -79,9 +79,6 @@ class EventHandler(object): ...@@ -79,9 +79,6 @@ class EventHandler(object):
except UnexpectedPacketError, e: except UnexpectedPacketError, e:
if not conn.isClosed(): if not conn.isClosed():
self.__unexpectedPacket(conn, packet, *e.args) self.__unexpectedPacket(conn, packet, *e.args)
except PacketMalformedError, e:
logging.error('malformed packet from %r: %s', conn, e)
conn.close()
except NotReadyError, message: except NotReadyError, message:
if not conn.isClosed(): if not conn.isClosed():
if not message.args: if not message.args:
......
...@@ -154,7 +154,8 @@ class NEOLogger(Logger): ...@@ -154,7 +154,8 @@ class NEOLogger(Logger):
def _setup(self, filename=None, reset=False): def _setup(self, filename=None, reset=False):
from . import protocol as p from . import protocol as p
global uuid_str global packb, uuid_str
packb = p.packb
uuid_str = p.uuid_str uuid_str = p.uuid_str
if self._db is not None: if self._db is not None:
self._db.close() self._db.close()
...@@ -257,7 +258,7 @@ class NEOLogger(Logger): ...@@ -257,7 +258,7 @@ class NEOLogger(Logger):
pktcls.__name__, peer, r.pkt.decode()) pktcls.__name__, peer, r.pkt.decode())
""" """
if msg is not None: if msg is not None:
msg = buffer(msg) msg = buffer(msg if type(msg) is bytes else packb(msg))
q = "INSERT INTO packet VALUES (?,?,?,?,?,?)" q = "INSERT INTO packet VALUES (?,?,?,?,?,?)"
x = [r.created, nid, r.msg_id, r.code, peer, msg] x = [r.created, nid, r.msg_id, r.code, peer, msg]
else: else:
...@@ -307,9 +308,14 @@ class NEOLogger(Logger): ...@@ -307,9 +308,14 @@ class NEOLogger(Logger):
def packet(self, connection, packet, outgoing): def packet(self, connection, packet, outgoing):
#if True or self._db is not None: #if True or self._db is not None:
if self._db is not None: if self._db is not None:
body = packet._body if self._max_packet and self._max_packet < packet.size:
if self._max_packet and self._max_packet < len(body): args = None
body = None else:
args = packet._args
try:
hash(args)
except TypeError:
args = packb(args)
self._queue(PacketRecord( self._queue(PacketRecord(
pkt=packet, pkt=packet,
created=time(), created=time(),
...@@ -318,7 +324,7 @@ class NEOLogger(Logger): ...@@ -318,7 +324,7 @@ class NEOLogger(Logger):
outgoing=outgoing, outgoing=outgoing,
uuid=connection.getUUID(), uuid=connection.getUUID(),
addr=connection.getAddress(), addr=connection.getAddress(),
msg=body)) msg=args))
def node(self, *cluster_nid): def node(self, *cluster_nid):
name = self.name and str(self.name) name = self.name and str(self.name)
......
This diff is collapsed.
...@@ -166,65 +166,6 @@ def parseMasterList(masters): ...@@ -166,65 +166,6 @@ def parseMasterList(masters):
return map(parseNodeAddress, masters.split()) return map(parseNodeAddress, masters.split())
class ReadBuffer(object):
"""
Implementation of a lazy buffer. Main purpose if to reduce useless
copies of data by storing chunks and join them only when the requested
size is available.
TODO: For better performance, use:
- socket.recv_into (64kiB blocks)
- struct.unpack_from
- and a circular buffer of dynamic size (initial size:
twice the length passed to socket.recv_into ?)
"""
def __init__(self):
self.size = 0
self.content = deque()
def append(self, data):
""" Append some data and compute the new buffer size """
self.size += len(data)
self.content.append(data)
def __len__(self):
""" Return the current buffer size """
return self.size
def read(self, size):
""" Read and consume size bytes """
if self.size < size:
return None
self.size -= size
chunk_list = []
pop_chunk = self.content.popleft
append_data = chunk_list.append
to_read = size
# select required chunks
while to_read > 0:
chunk_data = pop_chunk()
to_read -= len(chunk_data)
append_data(chunk_data)
if to_read < 0:
# too many bytes consumed, cut the last chunk
last_chunk = chunk_list[-1]
keep, let = last_chunk[:to_read], last_chunk[to_read:]
self.content.appendleft(let)
chunk_list[-1] = keep
# join all chunks (one copy)
data = ''.join(chunk_list)
assert len(data) == size
return data
def clear(self):
""" Erase all buffer content """
self.size = 0
self.content.clear()
dummy_read_buffer = ReadBuffer()
dummy_read_buffer.append = lambda _: None
class cached_property(object): class cached_property(object):
""" """
A property that is only computed once per instance and then replaces itself A property that is only computed once per instance and then replaces itself
......
...@@ -585,7 +585,9 @@ class Application(BaseApplication): ...@@ -585,7 +585,9 @@ class Application(BaseApplication):
self.tm.executeQueuedEvents() self.tm.executeQueuedEvents()
def startStorage(self, node): def startStorage(self, node):
node.send(Packets.StartOperation(self.backup_tid)) # XXX: Is this boolean 'backup' field needed ?
# Maybe this can be deduced from cluster state.
node.send(Packets.StartOperation(bool(self.backup_tid)))
uuid = node.getUUID() uuid = node.getUUID()
assert uuid not in self.storage_starting_set assert uuid not in self.storage_starting_set
assert uuid not in self.storage_ready_dict assert uuid not in self.storage_ready_dict
......
...@@ -157,8 +157,30 @@ class Log(object): ...@@ -157,8 +157,30 @@ class Log(object):
for x in 'uuid_str', 'Packets', 'PacketMalformedError': for x in 'uuid_str', 'Packets', 'PacketMalformedError':
setattr(self, x, g[x]) setattr(self, x, g[x])
x = {} x = {}
try:
Unpacker = g['Unpacker']
except KeyError:
unpackb = None
else:
from msgpack import ExtraData, UnpackException
def unpackb(data):
u = Unpacker()
u.feed(data)
data = u.unpack()
if u.read_bytes(1):
raise ExtraData
return data
self.PacketMalformedError = UnpackException
self.unpackb = unpackb
if self._decode > 1: if self._decode > 1:
try:
PStruct = g['PStruct'] PStruct = g['PStruct']
except KeyError:
for p in self.Packets.itervalues():
data_path = getattr(p, 'data_path', (None,))
if p._code >> 15 == data_path[0]:
x[p._code] = data_path[1:]
else:
PBoolean = g['PBoolean'] PBoolean = g['PBoolean']
def hasData(item): def hasData(item):
items = item._items items = item._items
...@@ -215,10 +237,12 @@ class Log(object): ...@@ -215,10 +237,12 @@ class Log(object):
if body is not None: if body is not None:
log = getattr(p, '_neolog', None) log = getattr(p, '_neolog', None)
if log or self._decode: if log or self._decode:
try:
if self.unpackb:
args = self.unpackb(body)
else:
p = p() p = p()
p._id = msg_id
p._body = body p._body = body
try:
args = p.decode() args = p.decode()
except self.PacketMalformedError: except self.PacketMalformedError:
msg.append("Can't decode packet") msg.append("Can't decode packet")
......
...@@ -461,8 +461,12 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -461,8 +461,12 @@ class SQLiteDatabaseManager(DatabaseManager):
return r return r
def loadData(self, data_id): def loadData(self, data_id):
return self.query("SELECT compression, hash, value" compression, checksum, data = self.query(
" FROM data WHERE id=?", (data_id,)).fetchone() "SELECT compression, hash, value FROM data WHERE id=?",
(data_id,)).fetchone()
if checksum:
return compression, str(checksum), str(data)
return compression, checksum, data
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
partition = self._getReadablePartition(oid) partition = self._getReadablePartition(oid)
......
...@@ -53,7 +53,7 @@ class ClientOperationHandler(BaseHandler): ...@@ -53,7 +53,7 @@ class ClientOperationHandler(BaseHandler):
p = Errors.TidNotFound('%s does not exist' % dump(tid)) p = Errors.TidNotFound('%s does not exist' % dump(tid))
else: else:
p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3], p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3],
t[4], t[0]) bool(t[4]), t[0])
conn.answer(p) conn.answer(p)
def getEventQueue(self): def getEventQueue(self):
......
...@@ -212,7 +212,7 @@ class StorageOperationHandler(EventHandler): ...@@ -212,7 +212,7 @@ class StorageOperationHandler(EventHandler):
# Sending such packet does not mark the connection # Sending such packet does not mark the connection
# for writing if there's too little data in the buffer. # for writing if there's too little data in the buffer.
conn.send(Packets.AddTransaction(tid, user, conn.send(Packets.AddTransaction(tid, user,
desc, ext, packed, ttid, oid_list), msg_id) desc, ext, bool(packed), ttid, oid_list), msg_id)
# To avoid delaying several connections simultaneously, # To avoid delaying several connections simultaneously,
# and also prevent the backend from scanning different # and also prevent the backend from scanning different
# parts of the DB at the same time, we ask the # parts of the DB at the same time, we ask the
...@@ -248,7 +248,7 @@ class StorageOperationHandler(EventHandler): ...@@ -248,7 +248,7 @@ class StorageOperationHandler(EventHandler):
for serial, oid in object_list: for serial, oid in object_list:
oid_set = object_dict.get(serial) oid_set = object_dict.get(serial)
if oid_set: if oid_set:
if type(oid_set) is list: if type(oid_set) is tuple:
object_dict[serial] = oid_set = set(oid_set) object_dict[serial] = oid_set = set(oid_set)
if oid in oid_set: if oid in oid_set:
oid_set.remove(oid) oid_set.remove(oid)
......
...@@ -71,7 +71,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -71,7 +71,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn) self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid) self.service.askPack(conn, tid)
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
ptid = self.checkAskPacket(storage_conn, Packets.AskPack).decode()[0] ptid = self.checkAskPacket(storage_conn, Packets.AskPack)._args[0]
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
self.assertTrue(self.app.packing[0] is conn) self.assertTrue(self.app.packing[0] is conn)
self.assertEqual(self.app.packing[1], peer_id) self.assertEqual(self.app.packing[1], peer_id)
...@@ -83,7 +83,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -83,7 +83,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn) self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid) self.service.askPack(conn, tid)
self.checkNoPacketSent(storage_conn) self.checkNoPacketSent(storage_conn)
status = self.checkAnswerPacket(conn, Packets.AnswerPack).decode()[0] status = self.checkAnswerPacket(conn, Packets.AnswerPack)._args[0]
self.assertFalse(status) self.assertFalse(status)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -72,7 +72,7 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -72,7 +72,7 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
self.service.answerPack(conn2, False) self.service.answerPack(conn2, False)
packet = self.checkNotifyPacket(client_conn, Packets.AnswerPack) packet = self.checkNotifyPacket(client_conn, Packets.AnswerPack)
# TODO: verify packet peer id # TODO: verify packet peer id
self.assertTrue(packet.decode()[0]) self.assertTrue(packet._args[0])
self.assertEqual(self.app.packing, None) self.assertEqual(self.app.packing, None)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -33,9 +33,9 @@ class HandlerTests(NeoUnitTestBase): ...@@ -33,9 +33,9 @@ class HandlerTests(NeoUnitTestBase):
def getFakePacket(self): def getFakePacket(self):
p = Mock({ p = Mock({
'decode': (),
'__repr__': 'Fake Packet', '__repr__': 'Fake Packet',
}) })
p._args = ()
p.handler_method_name = 'fake_method' p.handler_method_name = 'fake_method'
return p return p
...@@ -53,13 +53,6 @@ class HandlerTests(NeoUnitTestBase): ...@@ -53,13 +53,6 @@ class HandlerTests(NeoUnitTestBase):
self.handler.dispatch(conn, packet) self.handler.dispatch(conn, packet)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
self.checkAborted(conn) self.checkAborted(conn)
# raise PacketMalformedError
conn.mockCalledMethods = {}
def fake(c):
raise PacketMalformedError('message')
self.setFakeMethod(fake)
self.handler.dispatch(conn, packet)
self.checkClosed(conn)
# raise NotReadyError # raise NotReadyError
conn.mockCalledMethods = {} conn.mockCalledMethods = {}
def fake(c): def fake(c):
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import unittest import unittest
import socket import socket
from . import NeoUnitTestBase from . import NeoUnitTestBase
from neo.lib.util import ReadBuffer, parseNodeAddress from neo.lib.util import parseNodeAddress
class UtilTests(NeoUnitTestBase): class UtilTests(NeoUnitTestBase):
...@@ -40,24 +40,6 @@ class UtilTests(NeoUnitTestBase): ...@@ -40,24 +40,6 @@ class UtilTests(NeoUnitTestBase):
self.assertIn(parseNodeAddress('localhost'), local_address(0)) self.assertIn(parseNodeAddress('localhost'), local_address(0))
self.assertIn(parseNodeAddress('localhost:10'), local_address(10)) self.assertIn(parseNodeAddress('localhost:10'), local_address(10))
def testReadBufferRead(self):
""" Append some chunk then consume the data """
buf = ReadBuffer()
self.assertEqual(len(buf), 0)
buf.append('abc')
self.assertEqual(len(buf), 3)
# no enough data
self.assertEqual(buf.read(4), None)
self.assertEqual(len(buf), 3)
buf.append('def')
# consume a part
self.assertEqual(len(buf), 6)
self.assertEqual(buf.read(4), 'abcd')
self.assertEqual(len(buf), 2)
# consume the rest
self.assertEqual(buf.read(3), None)
self.assertEqual(buf.read(2), 'ef')
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
This diff is collapsed.
...@@ -113,7 +113,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -113,7 +113,7 @@ class ReplicationTests(NEOThreadedTest):
importZODB(3) importZODB(3)
def delaySecondary(conn, packet): def delaySecondary(conn, packet):
if isinstance(packet, Packets.Replicate): if isinstance(packet, Packets.Replicate):
tid, upstream_name, source_dict = packet.decode() tid, upstream_name, source_dict = packet._args
return not upstream_name and all(source_dict.itervalues()) return not upstream_name and all(source_dict.itervalues())
# U -> B propagation # U -> B propagation
with NEOCluster(partitions=np, replicas=nr-1, storage_count=5, with NEOCluster(partitions=np, replicas=nr-1, storage_count=5,
...@@ -513,7 +513,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -513,7 +513,7 @@ class ReplicationTests(NEOThreadedTest):
""" """
def delayAskFetch(conn, packet): def delayAskFetch(conn, packet):
return isinstance(packet, delayed) and \ return isinstance(packet, delayed) and \
packet.decode()[0] == offset and \ packet._args[0] == offset and \
conn in s1.getConnectionList(s0) conn in s1.getConnectionList(s0)
def changePartitionTable(orig, ptid, num_replicas, cell_list): def changePartitionTable(orig, ptid, num_replicas, cell_list):
if (offset, s0.uuid, CellStates.DISCARDED) in cell_list: if (offset, s0.uuid, CellStates.DISCARDED) in cell_list:
...@@ -768,7 +768,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -768,7 +768,7 @@ class ReplicationTests(NEOThreadedTest):
def logReplication(conn, packet): def logReplication(conn, packet):
if isinstance(packet, (Packets.AskFetchTransactions, if isinstance(packet, (Packets.AskFetchTransactions,
Packets.AskFetchObjects)): Packets.AskFetchObjects)):
ask.append(packet.decode()[2:]) ask.append(packet._args[2:])
def getTIDList(): def getTIDList():
return [t.tid for t in c.db().storage.iterator()] return [t.tid for t in c.db().storage.iterator()]
s0, s1 = cluster.storage_list s0, s1 = cluster.storage_list
...@@ -869,7 +869,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -869,7 +869,7 @@ class ReplicationTests(NEOThreadedTest):
return True return True
elif not isinstance(packet, Packets.AskFetchTransactions): elif not isinstance(packet, Packets.AskFetchTransactions):
return return
ask.append(packet.decode()) ask.append(packet._args)
conn, = upstream.master.getConnectionList(backup.master) conn, = upstream.master.getConnectionList(backup.master)
with ConnectionFilter() as f, Patch(replicator.Replicator, with ConnectionFilter() as f, Patch(replicator.Replicator,
_nextPartitionSortKey=lambda orig, self, offset: offset): _nextPartitionSortKey=lambda orig, self, offset: offset):
...@@ -930,11 +930,11 @@ class ReplicationTests(NEOThreadedTest): ...@@ -930,11 +930,11 @@ class ReplicationTests(NEOThreadedTest):
@f.add @f.add
def delayReplicate(conn, packet): def delayReplicate(conn, packet):
if isinstance(packet, Packets.AskFetchTransactions): if isinstance(packet, Packets.AskFetchTransactions):
trans.append(packet.decode()[2]) trans.append(packet._args[2])
elif isinstance(packet, Packets.AskFetchObjects): elif isinstance(packet, Packets.AskFetchObjects):
if obj: if obj:
return True return True
obj.append(packet.decode()[2]) obj.append(packet._args[2])
s2.start() s2.start()
self.tic() self.tic()
cluster.neoctl.enableStorageList([s2.uuid]) cluster.neoctl.enableStorageList([s2.uuid])
...@@ -1021,7 +1021,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -1021,7 +1021,7 @@ class ReplicationTests(NEOThreadedTest):
def expected(changed): def expected(changed):
s0 = 1, CellStates.UP_TO_DATE s0 = 1, CellStates.UP_TO_DATE
s = CellStates.OUT_OF_DATE if changed else CellStates.UP_TO_DATE s = CellStates.OUT_OF_DATE if changed else CellStates.UP_TO_DATE
return changed, 3 * [[s0, (2, s)], [s0, (3, s)]] return changed, 3 * ((s0, (2, s)), (s0, (3, s)))
for dry_run in True, False: for dry_run in True, False:
self.assertEqual(expected(True), self.assertEqual(expected(True),
cluster.neoctl.tweakPartitionTable(drop_list, dry_run)) cluster.neoctl.tweakPartitionTable(drop_list, dry_run))
......
...@@ -53,7 +53,7 @@ extras_require = { ...@@ -53,7 +53,7 @@ extras_require = {
'master': [], 'master': [],
'storage-sqlite': [], 'storage-sqlite': [],
'storage-mysqldb': ['mysqlclient'], 'storage-mysqldb': ['mysqlclient'],
'storage-importer': zodb_require + ['msgpack>=0.5.6', 'setproctitle'], 'storage-importer': zodb_require + ['setproctitle'],
} }
extras_require['tests'] = ['coverage', 'zope.testing', 'psutil>=2', extras_require['tests'] = ['coverage', 'zope.testing', 'psutil>=2',
'neoppod[%s]' % ', '.join(extras_require)] 'neoppod[%s]' % ', '.join(extras_require)]
...@@ -109,6 +109,7 @@ setup( ...@@ -109,6 +109,7 @@ setup(
], ],
}, },
install_requires = [ install_requires = [
'msgpack>=0.5.6',
'python-dateutil', # neolog --from 'python-dateutil', # neolog --from
], ],
extras_require = extras_require, extras_require = extras_require,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment