Commit ccd47c48 authored by Julien Muchembled's avatar Julien Muchembled Committed by Levin Zimmermann

protocol: switch to msgpack for packet serialization

Not only for performance reasons (at least 3% faster) but also because of
several ugly things in the way packets were defined:
- packet field names, which are only documentary; for roots fields,
  they even just duplicate the packet names
- a lot of repetitions for packet names, and even confusion between the name
  of the packet definition and the name of the actual notify/request packet
- the need to implement field types for anything, like PByte to support new
  compression formats, since PBoolean is not enough

neo/lib/protocol.py is now much smaller.
parent 1ad088c8
...@@ -13,6 +13,13 @@ ...@@ -13,6 +13,13 @@
############################################################################## ##############################################################################
def patch(): def patch():
# For msgpack & Py2/ZODB5.
try:
from zodbpickle import binary
binary._pack = bytes.__str__
except ImportError:
pass
from hashlib import md5 from hashlib import md5
from ZODB.Connection import Connection from ZODB.Connection import Connection
......
...@@ -181,7 +181,7 @@ class Application(ThreadedApplication): ...@@ -181,7 +181,7 @@ class Application(ThreadedApplication):
with self._connecting_to_master_node: with self._connecting_to_master_node:
result = self.master_conn result = self.master_conn
if result is None: if result is None:
self.new_oid_list = () self.new_oids = ()
result = self.master_conn = self._connectToPrimaryNode() result = self.master_conn = self._connectToPrimaryNode()
return result return result
...@@ -305,15 +305,19 @@ class Application(ThreadedApplication): ...@@ -305,15 +305,19 @@ class Application(ThreadedApplication):
"""Get a new OID.""" """Get a new OID."""
self._oid_lock_acquire() self._oid_lock_acquire()
try: try:
if not self.new_oid_list: for oid in self.new_oids:
break
else:
# Get new oid list from master node # Get new oid list from master node
# we manage a list of oid here to prevent # we manage a list of oid here to prevent
# from asking too many time new oid one by one # from asking too many time new oid one by one
# from master node # from master node
self._askPrimary(Packets.AskNewOIDs(100)) self._askPrimary(Packets.AskNewOIDs(100))
if not self.new_oid_list: for oid in self.new_oids:
break
else:
raise NEOStorageError('new_oid failed') raise NEOStorageError('new_oid failed')
self.last_oid = oid = self.new_oid_list.pop() self.last_oid = oid
return oid return oid
finally: finally:
self._oid_lock_release() self._oid_lock_release()
...@@ -612,7 +616,7 @@ class Application(ThreadedApplication): ...@@ -612,7 +616,7 @@ class Application(ThreadedApplication):
# user and description are cast to str in case they're unicode. # user and description are cast to str in case they're unicode.
# BBB: This is not required anymore with recent ZODB. # BBB: This is not required anymore with recent ZODB.
packet = Packets.AskStoreTransaction(ttid, str(transaction.user), packet = Packets.AskStoreTransaction(ttid, str(transaction.user),
str(transaction.description), ext, txn_context.cache_dict) str(transaction.description), ext, list(txn_context.cache_dict))
queue = txn_context.queue queue = txn_context.queue
conn_dict = txn_context.conn_dict conn_dict = txn_context.conn_dict
# Ask in parallel all involved storage nodes to commit object metadata. # Ask in parallel all involved storage nodes to commit object metadata.
...@@ -697,7 +701,7 @@ class Application(ThreadedApplication): ...@@ -697,7 +701,7 @@ class Application(ThreadedApplication):
else: else:
try: try:
notify(Packets.AbortTransaction(txn_context.ttid, notify(Packets.AbortTransaction(txn_context.ttid,
txn_context.conn_dict)) list(txn_context.conn_dict)))
except ConnectionClosed: except ConnectionClosed:
pass pass
# No need to flush queue, as it will be destroyed on return, # No need to flush queue, as it will be destroyed on return,
...@@ -731,7 +735,8 @@ class Application(ThreadedApplication): ...@@ -731,7 +735,8 @@ class Application(ThreadedApplication):
for oid in checked_list: for oid in checked_list:
del cache_dict[oid] del cache_dict[oid]
ttid = txn_context.ttid ttid = txn_context.ttid
p = Packets.AskFinishTransaction(ttid, cache_dict, checked_list) p = Packets.AskFinishTransaction(ttid, list(cache_dict),
checked_list)
try: try:
tid = self._askPrimary(p, cache_dict=cache_dict, callback=f) tid = self._askPrimary(p, cache_dict=cache_dict, callback=f)
assert tid assert tid
......
...@@ -163,8 +163,7 @@ class PrimaryAnswersHandler(AnswerBaseHandler): ...@@ -163,8 +163,7 @@ class PrimaryAnswersHandler(AnswerBaseHandler):
self.app.setHandlerData(ttid) self.app.setHandlerData(ttid)
def answerNewOIDs(self, conn, oid_list): def answerNewOIDs(self, conn, oid_list):
oid_list.reverse() self.app.new_oids = iter(oid_list)
self.app.new_oid_list = oid_list
def incompleteTransaction(self, conn, message): def incompleteTransaction(self, conn, message):
raise NEOStorageError("storage nodes for which vote failed can not be" raise NEOStorageError("storage nodes for which vote failed can not be"
......
...@@ -26,7 +26,7 @@ from .exception import NEOStorageError ...@@ -26,7 +26,7 @@ from .exception import NEOStorageError
class _WakeupPacket(object): class _WakeupPacket(object):
handler_method_name = 'pong' handler_method_name = 'pong'
decode = tuple _args = ()
getId = int getId = int
class Transaction(object): class Transaction(object):
......
...@@ -16,12 +16,19 @@ ...@@ -16,12 +16,19 @@
from functools import wraps from functools import wraps
from time import time from time import time
import msgpack
from msgpack.exceptions import UnpackValueError
from . import attributeTracker, logging from . import attributeTracker, logging
from .connector import ConnectorException, ConnectorDelayedConnection from .connector import ConnectorException, ConnectorDelayedConnection
from .locking import RLock from .locking import RLock
from .protocol import uuid_str, Errors, PacketMalformedError, Packets from .protocol import uuid_str, Errors, PacketMalformedError, Packets, \
from .util import dummy_read_buffer, ReadBuffer Unpacker
@apply
class dummy_read_buffer(msgpack.Unpacker):
def feed(self, _):
pass
class ConnectionClosed(Exception): class ConnectionClosed(Exception):
pass pass
...@@ -292,7 +299,7 @@ class ListeningConnection(BaseConnection): ...@@ -292,7 +299,7 @@ class ListeningConnection(BaseConnection):
# message. # message.
else: else:
conn._connected() conn._connected()
self.em.addWriter(conn) # for ENCODED_VERSION self.em.addWriter(conn) # for HANDSHAKE_PACKET
def getAddress(self): def getAddress(self):
return self.connector.getAddress() return self.connector.getAddress()
...@@ -311,12 +318,12 @@ class Connection(BaseConnection): ...@@ -311,12 +318,12 @@ class Connection(BaseConnection):
client = False client = False
server = False server = False
peer_id = None peer_id = None
_parser_state = None _total_unpacked = 0
_timeout = None _timeout = None
def __init__(self, event_manager, *args, **kw): def __init__(self, event_manager, *args, **kw):
BaseConnection.__init__(self, event_manager, *args, **kw) BaseConnection.__init__(self, event_manager, *args, **kw)
self.read_buf = ReadBuffer() self.read_buf = Unpacker()
# NOTE cur_id will be set in Server|Client to maintain `cur_id % 2 == const` invariant # NOTE cur_id will be set in Server|Client to maintain `cur_id % 2 == const` invariant
#self.cur_id = 0 #self.cur_id = 0
self.aborted = False self.aborted = False
...@@ -429,42 +436,39 @@ class Connection(BaseConnection): ...@@ -429,42 +436,39 @@ class Connection(BaseConnection):
self._closure() self._closure()
def _parse(self): def _parse(self):
read = self.read_buf.read from .protocol import HANDSHAKE_PACKET, MAGIC_SIZE, Packets
version = read(4) read_buf = self.read_buf
if version is None: handshake = read_buf.read_bytes(len(HANDSHAKE_PACKET))
return if handshake != HANDSHAKE_PACKET:
from .protocol import (ENCODED_VERSION, MAX_PACKET_SIZE, if HANDSHAKE_PACKET.startswith(handshake): # unlikely so tested last
PACKET_HEADER_FORMAT, Packets) # Not enough data and there's no API to know it in advance.
if version != ENCODED_VERSION: # Put it back.
logging.warning('Protocol version mismatch with %r', self) read_buf.feed(handshake)
return
if HANDSHAKE_PACKET.startswith(handshake[:MAGIC_SIZE]):
logging.warning('Protocol version mismatch with %r', self)
else:
logging.debug('Rejecting non-NEO %r', self)
raise ConnectorException raise ConnectorException
header_size = PACKET_HEADER_FORMAT.size read_next = read_buf.next
unpack = PACKET_HEADER_FORMAT.unpack read_pos = read_buf.tell
def parse(): def parse():
state = self._parser_state try:
if state is None: msg_id, msg_type, args = read_next()
header = read(header_size) except StopIteration:
if header is None: return
return except UnpackValueError as e:
msg_id, msg_type, msg_len = unpack(header) raise PacketMalformedError(str(e))
try: try:
packet_klass = Packets[msg_type] packet_klass = Packets[msg_type]
except KeyError: except KeyError:
raise PacketMalformedError('Unknown packet type') raise PacketMalformedError('Unknown packet type')
if msg_len > MAX_PACKET_SIZE: pos = read_pos()
raise PacketMalformedError('message too big (%d)' % msg_len) packet = packet_klass(*args)
else: packet.setId(msg_id)
msg_id, packet_klass, msg_len = state packet.size = pos - self._total_unpacked
data = read(msg_len) self._total_unpacked = pos
if data is None: return packet
# Not enough.
if state is None:
self._parser_state = msg_id, packet_klass, msg_len
else:
self._parser_state = None
packet = packet_klass()
packet.setContent(msg_id, data)
return packet
self._parse = parse self._parse = parse
return parse() return parse()
...@@ -517,7 +521,7 @@ class Connection(BaseConnection): ...@@ -517,7 +521,7 @@ class Connection(BaseConnection):
def close(self): def close(self):
if self.connector is None: if self.connector is None:
assert self._on_close is None assert self._on_close is None
assert not self.read_buf assert not self.read_buf.read_bytes(1)
assert not self.isPending() assert not self.isPending()
return return
# process the network events with the last registered handler to # process the network events with the last registered handler to
...@@ -528,7 +532,7 @@ class Connection(BaseConnection): ...@@ -528,7 +532,7 @@ class Connection(BaseConnection):
if self._on_close is not None: if self._on_close is not None:
self._on_close() self._on_close()
self._on_close = None self._on_close = None
self.read_buf.clear() self.read_buf = dummy_read_buffer
try: try:
if self.connecting: if self.connecting:
handler.connectionFailed(self) handler.connectionFailed(self)
......
...@@ -19,7 +19,7 @@ import ssl ...@@ -19,7 +19,7 @@ import ssl
import errno import errno
from time import time from time import time
from . import logging from . import logging
from .protocol import ENCODED_VERSION from .protocol import HANDSHAKE_PACKET
# Global connector registry. # Global connector registry.
# Fill by calling registerConnectorHandler. # Fill by calling registerConnectorHandler.
...@@ -74,15 +74,14 @@ class SocketConnector(object): ...@@ -74,15 +74,14 @@ class SocketConnector(object):
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
# disable Nagle algorithm to reduce latency # disable Nagle algorithm to reduce latency
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.queued = [ENCODED_VERSION] self.queued = [HANDSHAKE_PACKET]
self.queue_size = len(ENCODED_VERSION) self.queue_size = len(HANDSHAKE_PACKET)
return self return self
def queue(self, data): def queue(self, data):
was_empty = not self.queued was_empty = not self.queued
self.queued += data self.queued.append(data)
for data in data: self.queue_size += len(data)
self.queue_size += len(data)
return was_empty return was_empty
def _error(self, op, exc=None): def _error(self, op, exc=None):
...@@ -172,7 +171,7 @@ class SocketConnector(object): ...@@ -172,7 +171,7 @@ class SocketConnector(object):
except socket.error, e: except socket.error, e:
self._error('recv', e) self._error('recv', e)
if data: if data:
read_buf.append(data) read_buf.feed(data)
return return
self._error('recv') self._error('recv')
...@@ -283,7 +282,7 @@ class _SSL: ...@@ -283,7 +282,7 @@ class _SSL:
# non-ragged EOF (peer properly closed its side of connection) # non-ragged EOF (peer properly closed its side of connection)
self._error('recv', None) self._error('recv', None)
return return
read_buf.append(data) read_buf.feed(data)
except ssl.SSLWantReadError: except ssl.SSLWantReadError:
pass pass
except socket.error, e: except socket.error, e:
......
...@@ -23,7 +23,7 @@ NOBODY = [] ...@@ -23,7 +23,7 @@ NOBODY = []
class _ConnectionClosed(object): class _ConnectionClosed(object):
handler_method_name = 'connectionClosed' handler_method_name = 'connectionClosed'
decode = tuple _args = ()
class getId(object): class getId(object):
def __eq__(self, other): def __eq__(self, other):
......
...@@ -71,7 +71,7 @@ class EventHandler(object): ...@@ -71,7 +71,7 @@ class EventHandler(object):
method = getattr(self, packet.handler_method_name) method = getattr(self, packet.handler_method_name)
except AttributeError: except AttributeError:
raise UnexpectedPacketError('no handler found') raise UnexpectedPacketError('no handler found')
args = packet.decode() or () args = packet._args
method(conn, *args, **kw) method(conn, *args, **kw)
except DelayEvent, e: except DelayEvent, e:
assert not kw, kw assert not kw, kw
...@@ -79,9 +79,6 @@ class EventHandler(object): ...@@ -79,9 +79,6 @@ class EventHandler(object):
except UnexpectedPacketError, e: except UnexpectedPacketError, e:
if not conn.isClosed(): if not conn.isClosed():
self.__unexpectedPacket(conn, packet, *e.args) self.__unexpectedPacket(conn, packet, *e.args)
except PacketMalformedError, e:
logging.error('malformed packet from %r: %s', conn, e)
conn.close()
except NotReadyError, message: except NotReadyError, message:
if not conn.isClosed(): if not conn.isClosed():
if not message.args: if not message.args:
......
...@@ -154,7 +154,8 @@ class NEOLogger(Logger): ...@@ -154,7 +154,8 @@ class NEOLogger(Logger):
def _setup(self, filename=None, reset=False): def _setup(self, filename=None, reset=False):
from . import protocol as p from . import protocol as p
global uuid_str global packb, uuid_str
packb = p.packb
uuid_str = p.uuid_str uuid_str = p.uuid_str
if self._db is not None: if self._db is not None:
self._db.close() self._db.close()
...@@ -257,7 +258,7 @@ class NEOLogger(Logger): ...@@ -257,7 +258,7 @@ class NEOLogger(Logger):
pktcls.__name__, peer, r.pkt.decode()) pktcls.__name__, peer, r.pkt.decode())
""" """
if msg is not None: if msg is not None:
msg = buffer(msg) msg = buffer(msg if type(msg) is bytes else packb(msg))
q = "INSERT INTO packet VALUES (?,?,?,?,?,?)" q = "INSERT INTO packet VALUES (?,?,?,?,?,?)"
x = [r.created, nid, r.msg_id, r.code, peer, msg] x = [r.created, nid, r.msg_id, r.code, peer, msg]
else: else:
...@@ -307,9 +308,14 @@ class NEOLogger(Logger): ...@@ -307,9 +308,14 @@ class NEOLogger(Logger):
def packet(self, connection, packet, outgoing): def packet(self, connection, packet, outgoing):
#if True or self._db is not None: #if True or self._db is not None:
if self._db is not None: if self._db is not None:
body = packet._body if self._max_packet and self._max_packet < packet.size:
if self._max_packet and self._max_packet < len(body): args = None
body = None else:
args = packet._args
try:
hash(args)
except TypeError:
args = packb(args)
self._queue(PacketRecord( self._queue(PacketRecord(
pkt=packet, pkt=packet,
created=time(), created=time(),
...@@ -318,7 +324,7 @@ class NEOLogger(Logger): ...@@ -318,7 +324,7 @@ class NEOLogger(Logger):
outgoing=outgoing, outgoing=outgoing,
uuid=connection.getUUID(), uuid=connection.getUUID(),
addr=connection.getAddress(), addr=connection.getAddress(),
msg=body)) msg=args))
def node(self, *cluster_nid): def node(self, *cluster_nid):
name = self.name and str(self.name) name = self.name and str(self.name)
......
...@@ -14,27 +14,63 @@ ...@@ -14,27 +14,63 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import sys import threading
import traceback from functools import partial
from cStringIO import StringIO from msgpack import packb
from struct import Struct
# The protocol version must be increased whenever upgrading a node may require # The protocol version must be increased whenever upgrading a node may require
# to upgrade other nodes. It is encoded as a 4-bytes big-endian integer and # to upgrade other nodes.
# the high order byte 0 is different from TLS Handshake (0x16). PROTOCOL_VERSION = 0
PROTOCOL_VERSION = 6 # By encoding the handshake packet with msgpack, the whole NEO stream can be
ENCODED_VERSION = Struct('!L').pack(PROTOCOL_VERSION) # decoded with msgpack. The first byte is 0x92, which is different from TLS
# Handshake (0x16).
HANDSHAKE_PACKET = packb(('NEO', PROTOCOL_VERSION))
# Used to distinguish non-NEO stream from version mismatch.
MAGIC_SIZE = len(HANDSHAKE_PACKET) - len(packb(PROTOCOL_VERSION))
# Avoid memory errors on corrupted data.
MAX_PACKET_SIZE = 0x4000000
PACKET_HEADER_FORMAT = Struct('!LHL')
RESPONSE_MASK = 0x8000 RESPONSE_MASK = 0x8000
# Avoid some memory errors on corrupted data.
# Before we use msgpack, we limited the size of a whole packet. That's not
# possible anymore because the size is not known in advance. Packets bigger
# than the buffer size are possible (e.g. a huge list of small items) and for
# that we could compare the stream position (Unpacker.tell); it's not worth it.
UNPACK_BUFFER_SIZE = 0x4000000
@apply
def Unpacker():
global registerExtType, packb
from msgpack import ExtType, unpackb, Packer, Unpacker
ext_type_dict = []
kw = dict(use_bin_type=True)
pack_ext = Packer(**kw).pack
def registerExtType(getstate, make):
code = len(ext_type_dict)
ext_type_dict.append(lambda data: make(unpackb(data, use_list=False)))
return lambda obj: ExtType(code, pack_ext(getstate(obj)))
iterable_types = set, tuple
def default(obj):
try:
pack = obj._pack
except AttributeError:
assert type(obj) in iterable_types, type(obj)
return list(obj)
return pack()
lock = threading.Lock()
pack = Packer(default, strict_types=True, **kw).pack
def packb(obj):
with lock: # in case that 'default' is called
return pack(obj)
return partial(Unpacker, use_list=False, max_buffer_size=UNPACK_BUFFER_SIZE,
ext_hook=lambda code, data: ext_type_dict[code](data))
class Enum(tuple): class Enum(tuple):
class Item(int): class Item(int):
__slots__ = '_name', '_enum' __slots__ = '_name', '_enum', '_pack'
def __str__(self): def __str__(self):
return self._name return self._name
def __repr__(self): def __repr__(self):
...@@ -49,31 +85,38 @@ class Enum(tuple): ...@@ -49,31 +85,38 @@ class Enum(tuple):
names = func.func_code.co_names names = func.func_code.co_names
self = tuple.__new__(cls, map(cls.Item, xrange(len(names)))) self = tuple.__new__(cls, map(cls.Item, xrange(len(names))))
self._name = func.__name__ self._name = func.__name__
pack = registerExtType(int, self.__getitem__)
for item, name in zip(self, names): for item, name in zip(self, names):
setattr(self, name, item) setattr(self, name, item)
item._name = name item._name = name
item._enum = self item._enum = self
item._pack = (lambda x: lambda: x)(pack(item))
return self return self
def __repr__(self): def __repr__(self):
return "<Enum %s>" % self._name return "<Enum %s>" % self._name
# The order of extension type is important.
# Enum types first, sorted alphabetically.
@Enum @Enum
def ErrorCodes(): def CellStates():
ACK # Write-only cell. Last transactions are missing because storage is/was down
DENIED # for a while, or because it is new for the partition. It usually becomes
NOT_READY # UP_TO_DATE when replication is done.
OID_NOT_FOUND OUT_OF_DATE
TID_NOT_FOUND # Normal state: cell is writable/readable, and it isn't planned to drop it.
OID_DOES_NOT_EXIST UP_TO_DATE
PROTOCOL_ERROR # Same as UP_TO_DATE, except that it will be discarded as soon as another
REPLICATION_ERROR # node finishes to replicate it. It means a partition is moved from 1 node
CHECKING_ERROR # to another. It is also discarded immediately if out-of-date.
BACKEND_NOT_IMPLEMENTED FEEDING
NON_READABLE_CELL # A check revealed that data differs from other replicas. Cell is neither
READ_ONLY_ACCESS # readable nor writable.
INCOMPLETE_TRANSACTION CORRUPTED
# Not really a state: only used in network packets to tell storages to drop
# partitions.
DISCARDED
@Enum @Enum
def ClusterStates(): def ClusterStates():
...@@ -108,11 +151,20 @@ def ClusterStates(): ...@@ -108,11 +151,20 @@ def ClusterStates():
STOPPING_BACKUP STOPPING_BACKUP
@Enum @Enum
def NodeTypes(): def ErrorCodes():
MASTER ACK
STORAGE DENIED
CLIENT NOT_READY
ADMIN OID_NOT_FOUND
TID_NOT_FOUND
OID_DOES_NOT_EXIST
PROTOCOL_ERROR
REPLICATION_ERROR
CHECKING_ERROR
BACKEND_NOT_IMPLEMENTED
NON_READABLE_CELL
READ_ONLY_ACCESS
INCOMPLETE_TRANSACTION
@Enum @Enum
def NodeStates(): def NodeStates():
...@@ -122,23 +174,11 @@ def NodeStates(): ...@@ -122,23 +174,11 @@ def NodeStates():
PENDING PENDING
@Enum @Enum
def CellStates(): def NodeTypes():
# Write-only cell. Last transactions are missing because storage is/was down MASTER
# for a while, or because it is new for the partition. It usually becomes STORAGE
# UP_TO_DATE when replication is done. CLIENT
OUT_OF_DATE ADMIN
# Normal state: cell is writable/readable, and it isn't planned to drop it.
UP_TO_DATE
# Same as UP_TO_DATE, except that it will be discarded as soon as another
# node finishes to replicate it. It means a partition is moved from 1 node
# to another. It is also discarded immediately if out-of-date.
FEEDING
# A check revealed that data differs from other replicas. Cell is neither
# readable nor writable.
CORRUPTED
# Not really a state: only used in network packets to tell storages to drop
# partitions.
DISCARDED
# used for logging # used for logging
node_state_prefix_dict = { node_state_prefix_dict = {
...@@ -214,45 +254,24 @@ class NonReadableCell(Exception): ...@@ -214,45 +254,24 @@ class NonReadableCell(Exception):
On such event, the client must retry, preferably another cell. On such event, the client must retry, preferably another cell.
""" """
class Packet(object): class Packet(object):
""" """
Base class for any packet definition. The _fmt class attribute must be Base class for any packet definition.
defined for any non-empty packet.
""" """
_ignore_when_closed = False _ignore_when_closed = False
_request = None _request = None
_answer = None _answer = None
_body = None
_code = None _code = None
_fmt = None
_id = None _id = None
allow_dict = False
nodelay = True nodelay = True
poll_thread = False poll_thread = False
def __init__(self, *args): def __init__(self, *args):
assert self._code is not None, "Packet class not registered" assert self._code is not None, "Packet class not registered"
if args: assert self.allow_dict or dict not in map(type, args), args
buf = StringIO() self._args = args
self._fmt.encode(buf.write, args)
self._body = buf.getvalue()
else:
self._body = ''
def decode(self):
assert self._body is not None
if self._fmt is None:
return ()
buf = StringIO(self._body)
try:
return self._fmt.decode(buf.read)
except ParseError, msg:
name = self.__class__.__name__
raise PacketMalformedError("%s fail (%s)" % (name, msg))
def setContent(self, msg_id, body):
""" Register the packet content for future decoding """
self._id = msg_id
self._body = body
def setId(self, value): def setId(self, value):
self._id = value self._id = value
...@@ -261,14 +280,11 @@ class Packet(object): ...@@ -261,14 +280,11 @@ class Packet(object):
assert self._id is not None, "No identifier applied on the packet" assert self._id is not None, "No identifier applied on the packet"
return self._id return self._id
def encode(self): def encode(self, packb=packb):
""" Encode a packet as a string to send it over the network """ """ Encode a packet as a string to send it over the network """
content = self._body r = packb((self._id, self._code, self._args))
return (PACKET_HEADER_FORMAT.pack(self._id, self._code, len(content)), self.size = len(r)
content) return r
def __len__(self):
return PACKET_HEADER_FORMAT.size + len(self._body)
def __repr__(self): def __repr__(self):
return '%s[%r]' % (self.__class__.__name__, self._id) return '%s[%r]' % (self.__class__.__name__, self._id)
...@@ -281,10 +297,10 @@ class Packet(object): ...@@ -281,10 +297,10 @@ class Packet(object):
return self._code == other._code return self._code == other._code
def isError(self): def isError(self):
return isinstance(self, Error) return self._code == RESPONSE_MASK
def isResponse(self): def isResponse(self):
return self._code & RESPONSE_MASK == RESPONSE_MASK return self._code & RESPONSE_MASK
def getAnswerClass(self): def getAnswerClass(self):
return self._answer return self._answer
...@@ -296,1548 +312,530 @@ class Packet(object): ...@@ -296,1548 +312,530 @@ class Packet(object):
""" """
return self._ignore_when_closed return self._ignore_when_closed
class ParseError(Exception):
"""
An exception that encapsulate another and build the 'path' of the
packet item that generate the error.
"""
def __init__(self, item, trace):
Exception.__init__(self)
self._trace = trace
self._items = [item]
def append(self, item):
self._items.append(item)
def __repr__(self):
chain = '/'.join([item.getName() for item in reversed(self._items)])
return 'at %s:\n%s' % (chain, self._trace)
__str__ = __repr__
# packet parsers
class PItem(object):
"""
Base class for any packet item, _encode and _decode must be overridden
by subclasses.
"""
def __init__(self, name):
self._name = name
def __repr__(self):
return self.__class__.__name__
def getName(self):
return self._name
def _trace(self, method, *args):
try:
return method(*args)
except ParseError, e:
# trace and forward exception
e.append(self)
raise
except Exception:
# original exception, encapsulate it
trace = ''.join(traceback.format_exception(*sys.exc_info())[2:])
raise ParseError(self, trace)
def encode(self, writer, items):
return self._trace(self._encode, writer, items)
def decode(self, reader):
return self._trace(self._decode, reader)
def _encode(self, writer, items):
raise NotImplementedError, self.__class__.__name__
def _decode(self, reader):
raise NotImplementedError, self.__class__.__name__
class PStruct(PItem):
"""
Aggregate other items
"""
def __init__(self, name, *items):
PItem.__init__(self, name)
self._items = items
def _encode(self, writer, items):
assert len(self._items) == len(items), (items, self._items)
for item, value in zip(self._items, items):
item.encode(writer, value)
def _decode(self, reader):
return tuple([item.decode(reader) for item in self._items])
class PStructItem(PItem):
"""
A single value encoded with struct
"""
def __init__(self, name):
PItem.__init__(self, name)
struct = Struct(self._fmt)
self.pack = struct.pack
self.unpack = struct.unpack
self.size = struct.size
def _encode(self, writer, value):
writer(self.pack(value))
def _decode(self, reader):
return self.unpack(reader(self.size))[0]
class PStructItemOrNone(PStructItem):
def _encode(self, writer, value):
return writer(self._None if value is None else self.pack(value))
def _decode(self, reader):
value = reader(self.size)
return None if value == self._None else self.unpack(value)[0]
class POption(PStruct):
def _encode(self, writer, value):
if value is None:
writer('\0')
else:
writer('\1')
PStruct._encode(self, writer, value)
def _decode(self, reader):
if '\0\1'.index(reader(1)):
return PStruct._decode(self, reader)
class PList(PStructItem):
"""
A list of homogeneous items
"""
_fmt = '!L'
def __init__(self, name, item):
PStructItem.__init__(self, name)
self._item = item
def _encode(self, writer, items):
writer(self.pack(len(items)))
item = self._item
for value in items:
item.encode(writer, value)
def _decode(self, reader):
length = self.unpack(reader(self.size))[0]
item = self._item
return [item.decode(reader) for _ in xrange(length)]
class PDict(PStructItem):
"""
A dictionary with custom key and value formats
"""
_fmt = '!L'
def __init__(self, name, key, value):
PStructItem.__init__(self, name)
self._key = key
self._value = value
def _encode(self, writer, item):
assert isinstance(item , dict), (type(item), item)
writer(self.pack(len(item)))
key, value = self._key, self._value
for k, v in item.iteritems():
key.encode(writer, k)
value.encode(writer, v)
def _decode(self, reader):
length = self.unpack(reader(self.size))[0]
key, value = self._key, self._value
new_dict = {}
for _ in xrange(length):
k = key.decode(reader)
v = value.decode(reader)
new_dict[k] = v
return new_dict
class PEnum(PStructItem):
"""
Encapsulate an enumeration value
"""
_fmt = 'b'
def __init__(self, name, enum): class PacketRegistryFactory(dict):
PStructItem.__init__(self, name)
self._enum = enum
def _encode(self, writer, item): def __call__(self, name, base, d):
if item is None: for k, v in d.items():
item = -1 if isinstance(v, type) and issubclass(v, Packet):
writer(self.pack(item)) v.__name__ = k
v.handler_method_name = k[0].lower() + k[1:]
def _decode(self, reader): # this builds a "singleton"
code = self.unpack(reader(self.size))[0] return type('PacketRegistry', base, d)(self)
if code == -1:
return None def register(self, doc, ignore_when_closed=None, request=False, error=False,
try: _base=(Packet,), **kw):
return self._enum[code] """ Register a packet in the packet registry """
except KeyError: code = len(self)
enum = self._enum.__class__.__name__ if doc is None:
raise ValueError, 'Invalid code for %s enum: %r' % (enum, code) self[code] = None
return # None registered only to skip a code number (for compatibility)
class PString(PStructItem): if error and not request:
""" assert not code
A variable-length string code = RESPONSE_MASK
""" kw.update(__doc__=doc, _code=code)
_fmt = '!L' packet = type('', _base, kw)
# register the request
def _encode(self, writer, value): self[code] = packet
writer(self.pack(len(value))) if request:
writer(value) if ignore_when_closed is None:
# By default, on a closed connection:
def _decode(self, reader): # - request: ignore
length = self.unpack(reader(self.size))[0] # - answer: keep
return reader(length) # - notification: keep
packet._ignore_when_closed = True
class PAddress(PString): else:
""" assert ignore_when_closed is False
An host address (IPv4/IPv6) if error:
""" packet._answer = self[RESPONSE_MASK]
else:
def __init__(self, name): # build a class for the answer
PString.__init__(self, name) code |= RESPONSE_MASK
self._port = Struct('!H') kw['_code'] = code
answer = packet._answer = self[code] = type('', _base, kw)
def _encode(self, writer, address): return packet, answer
if address:
host, port = address
PString._encode(self, writer, host)
writer(self._port.pack(port))
else: else:
PString._encode(self, writer, '') assert ignore_when_closed is None
return packet
def _decode(self, reader):
host = PString._decode(self, reader)
if host:
p = self._port
return host, p.unpack(reader(p.size))[0]
class PBoolean(PStructItem): class Packets(dict):
"""
A boolean value, encoded as a single byte
"""
_fmt = '!?'
class PNumber(PStructItem):
"""
A integer number (4-bytes length)
"""
_fmt = '!L'
class PIndex(PStructItem):
""" """
A big integer to defined indexes in a huge list. Packet registry that checks packet code uniqueness and provides an index
""" """
_fmt = '!Q' __metaclass__ = PacketRegistryFactory()
notify = __metaclass__.register
request = partial(notify, request=True)
class PPTID(PStructItemOrNone): Error = notify("""
""" Error is a special type of message, because this can be sent against
A None value means an invalid PTID any other message, even if such a message does not expect a reply
""" usually.
_fmt = '!Q'
_None = Struct(_fmt).pack(0)
class PChecksum(PItem): :nodes: * -> *
""" """, error=True)
A hash (SHA1)
"""
def _encode(self, writer, checksum):
assert len(checksum) == 20, (len(checksum), checksum)
writer(checksum)
def _decode(self, reader): RequestIdentification, AcceptIdentification = request("""
return reader(20) Request a node identification. This must be the first packet for any
connection.
class PSignedNull(PStructItemOrNone): :nodes: * -> *
_fmt = '!l' """, poll_thread=True)
_None = Struct(_fmt).pack(0)
class PUUID(PSignedNull): Ping, Pong = request("""
""" Empty request used as network barrier.
An UUID (node identifier, 4-bytes signed integer)
"""
class PTID(PItem): :nodes: * -> *
""" """)
A transaction identifier
"""
def _encode(self, writer, tid):
if tid is None:
tid = INVALID_TID
assert len(tid) == 8, (len(tid), tid)
writer(tid)
def _decode(self, reader):
tid = reader(8)
if tid == INVALID_TID:
tid = None
return tid
# same definition, for now
POID = PTID
class PFloat(PStructItemOrNone):
"""
A float number (8-bytes length)
"""
_fmt = '!d'
_None = '\xff' * 8
# common definitions
PFEmpty = PStruct('no_content')
PFNodeType = PEnum('type', NodeTypes)
PFNodeState = PEnum('state', NodeStates)
PFCellState = PEnum('state', CellStates)
PFNodeList = PList('node_list',
PStruct('node',
PFNodeType,
PAddress('address'),
PUUID('uuid'),
PFNodeState,
PFloat('id_timestamp'),
),
)
PFCellList = PList('cell_list',
PStruct('cell',
PUUID('uuid'),
PFCellState,
),
)
PFRowList = PList('row_list',
PFCellList,
)
PFHistoryList = PList('history_list',
PStruct('history_entry',
PTID('serial'),
PNumber('size'),
),
)
PFUUIDList = PList('uuid_list',
PUUID('uuid'),
)
PFTidList = PList('tid_list',
PTID('tid'),
)
PFOidList = PList('oid_list',
POID('oid'),
)
# packets definition
class Error(Packet):
"""
Error is a special type of message, because this can be sent against
any other message, even if such a message does not expect a reply
usually.
:nodes: * -> * CloseClient = notify("""
""" Tell peer that it can close the connection if it has finished with us.
_fmt = PStruct('error',
PNumber('code'),
PString('message'),
)
class Ping(Packet): :nodes: * -> *
""" """)
Empty request used as network barrier.
:nodes: * -> * AskPrimary, AnswerPrimary = request("""
""" Ask node identier of the current primary master.
_answer = PFEmpty
class CloseClient(Packet): :nodes: ctl -> A
""" """)
Tell peer that it can close the connection if it has finished with us.
:nodes: * -> * NotPrimaryMaster = notify("""
""" Notify peer that I'm not the primary master. Attach any extra
information to help the peer joining the cluster.
class RequestIdentification(Packet): :nodes: SM -> *
""" """)
Request a node identification. This must be the first packet for any
connection.
:nodes: * -> * NotifyNodeInformation = notify("""
""" Notify information about one or more nodes.
poll_thread = True
_fmt = PStruct('request_identification',
PFNodeType,
PUUID('uuid'),
PAddress('address'),
PString('name'),
PFloat('id_timestamp'),
# storage:
PList('devpath', PString('devid')),
PList('new_nid', PNumber('offset')),
)
_answer = PStruct('accept_identification',
PFNodeType,
PUUID('my_uuid'),
PUUID('your_uuid'),
)
class PrimaryMaster(Packet):
"""
Ask node identier of the current primary master.
:nodes: ctl -> A :nodes: M -> *
""" """)
_answer = PStruct('answer_primary',
PUUID('primary_uuid'),
)
class NotPrimaryMaster(Packet): AskRecovery, AnswerRecovery = request("""
""" Ask storage nodes data needed by master to recover.
Notify peer that I'm not the primary master. Attach any extra information Reused by `neoctl print ids`.
to help the peer joining the cluster.
:nodes: SM -> * :nodes: M -> S; ctl -> A -> M
""" """)
_fmt = PStruct('not_primary_master',
PSignedNull('primary'),
PList('known_master_list',
PAddress('address'),
),
)
class Recovery(Packet):
"""
Ask storage nodes data needed by master to recover.
Reused by `neoctl print ids`.
:nodes: M -> S; ctl -> A -> M AskLastIDs, AnswerLastIDs = request("""
""" Ask the last OID/TID so that a master can initialize its
_answer = PStruct('answer_recovery', TransactionManager. Reused by `neoctl print ids`.
PPTID('ptid'),
PTID('backup_tid'),
PTID('truncate_tid'),
)
class LastIDs(Packet): :nodes: M -> S; ctl -> A -> M
""" """)
Ask the last OID/TID so that a master can initialize its TransactionManager.
Reused by `neoctl print ids`.
:nodes: M -> S; ctl -> A -> M AskPartitionTable, AnswerPartitionTable = request("""
""" Ask storage node the remaining data needed by master to recover.
_answer = PStruct('answer_last_ids',
POID('last_oid'),
PTID('last_tid'),
)
class PartitionTable(Packet): :nodes: M -> S
""" """)
Ask storage node the remaining data needed by master to recover.
:nodes: M -> S SendPartitionTable = notify("""
""" Send the full partition table to admin/client/storage nodes on
_answer = PStruct('answer_partition_table', connection.
PPTID('ptid'),
PNumber('num_replicas'),
PFRowList,
)
class NotifyPartitionTable(Packet): :nodes: M -> A, C, S
""" """)
Send the full partition table to admin/client/storage nodes on connection.
:nodes: M -> A, C, S NotifyPartitionChanges = notify("""
""" Notify about changes in the partition table.
_fmt = PStruct('send_partition_table',
PPTID('ptid'),
PNumber('num_replicas'),
PFRowList,
)
class PartitionChanges(Packet): :nodes: M -> *
""" """)
Notify about changes in the partition table.
:nodes: M -> * StartOperation = notify("""
""" Tell a storage node to start operation. Before this message,
_fmt = PStruct('notify_partition_changes', it must only communicate with the primary master.
PPTID('ptid'),
PNumber('num_replicas'),
PList('cell_list',
PStruct('cell',
PNumber('offset'),
PUUID('uuid'),
PFCellState,
),
),
)
class StartOperation(Packet):
"""
Tell a storage node to start operation. Before this message, it must only
communicate with the primary master.
:nodes: M -> S :nodes: M -> S
""" """)
_fmt = PStruct('start_operation',
# XXX: Is this boolean needed ? Maybe this
# can be deduced from cluster state.
PBoolean('backup'),
)
class StopOperation(Packet): StopOperation = notify("""
""" Notify that the cluster is not operational anymore.
Notify that the cluster is not operational anymore. Any operation between Any operation between nodes must be aborted.
nodes must be aborted.
:nodes: M -> S, C :nodes: M -> S, C
""" """)
class UnfinishedTransactions(Packet): AskUnfinishedTransactions, AnswerUnfinishedTransactions = request("""
""" Ask unfinished transactions, which will be replicated
Ask unfinished transactions, which will be replicated when they're finished. when they're finished.
:nodes: S -> M :nodes: S -> M
""" """)
_fmt = PStruct('ask_unfinished_transactions',
PList('row_list',
PNumber('offset'),
),
)
_answer = PStruct('answer_unfinished_transactions',
PTID('max_tid'),
PList('tid_list',
PTID('unfinished_tid'),
),
)
class LockedTransactions(Packet):
"""
Ask locked transactions to replay committed transactions that haven't been
unlocked.
:nodes: M -> S AskLockedTransactions, AnswerLockedTransactions = request("""
""" Ask locked transactions to replay committed transactions
_answer = PStruct('answer_locked_transactions', that haven't been unlocked.
PDict('tid_dict',
PTID('ttid'),
PTID('tid'),
),
)
class FinalTID(Packet):
"""
Return final tid if ttid has been committed, to recover from certain
failures during tpc_finish.
:nodes: M -> S; C -> M, S :nodes: M -> S
""" """, allow_dict=True)
_fmt = PStruct('final_tid',
PTID('ttid'),
)
_answer = PStruct('final_tid', AskFinalTID, AnswerFinalTID = request("""
PTID('tid'), Return final tid if ttid has been committed, to recover from certain
) failures during tpc_finish.
class ValidateTransaction(Packet): :nodes: M -> S; C -> M, S
""" """)
Do replay a committed transaction that was not unlocked.
:nodes: M -> S ValidateTransaction = notify("""
""" Do replay a committed transaction that was not unlocked.
_fmt = PStruct('validate_transaction',
PTID('ttid'),
PTID('tid'),
)
class BeginTransaction(Packet): :nodes: M -> S
""" """)
Ask to begin a new transaction. This maps to `tpc_begin`.
:nodes: C -> M AskBeginTransaction, AnswerBeginTransaction = request("""
""" Ask to begin a new transaction. This maps to `tpc_begin`.
_fmt = PStruct('ask_begin_transaction',
PTID('tid'),
)
_answer = PStruct('answer_begin_transaction', :nodes: C -> M
PTID('tid'), """)
)
class FailedVote(Packet): FailedVote = request("""
""" Report storage nodes for which vote failed.
Report storage nodes for which vote failed. True is returned if it's still possible to finish the transaction.
True is returned if it's still possible to finish the transaction.
:nodes: C -> M :nodes: C -> M
""" """, error=True)
_fmt = PStruct('failed_vote',
PTID('tid'),
PFUUIDList,
)
_answer = Error AskFinishTransaction, AnswerTransactionFinished = request("""
Finish a transaction. Return the TID of the committed transaction.
This maps to `tpc_finish`.
class FinishTransaction(Packet): :nodes: C -> M
""" """, ignore_when_closed=False, poll_thread=True)
Finish a transaction. Return the TID of the committed transaction.
This maps to `tpc_finish`.
:nodes: C -> M AskLockInformation, AnswerInformationLocked = request("""
""" Commit a transaction. The new data is read-locked.
poll_thread = True
_fmt = PStruct('ask_finish_transaction',
PTID('tid'),
PFOidList,
PList('checked_list',
POID('oid'),
),
)
_answer = PStruct('answer_information_locked',
PTID('ttid'),
PTID('tid'),
)
class NotifyTransactionFinished(Packet):
"""
Notify that a transaction blocking a replication is now finished.
:nodes: M -> S :nodes: M -> S
""" """, ignore_when_closed=False)
_fmt = PStruct('notify_transaction_finished',
PTID('ttid'),
PTID('max_tid'),
)
class LockInformation(Packet): InvalidateObjects = notify("""
""" Notify about a new transaction modifying objects,
Commit a transaction. The new data is read-locked. invalidating client caches.
:nodes: M -> S :nodes: M -> C
""" """)
_fmt = PStruct('ask_lock_informations',
PTID('ttid'),
PTID('tid'),
)
_answer = PStruct('answer_information_locked', NotifyUnlockInformation = notify("""
PTID('ttid'), Notify about a successfully committed transaction. The new data can be
) unlocked.
class InvalidateObjects(Packet): :nodes: M -> S
""" """)
Notify about a new transaction modifying objects,
invalidating client caches.
:nodes: M -> C AskNewOIDs, AnswerNewOIDs = request("""
""" Ask new OIDs to create objects.
_fmt = PStruct('ask_finish_transaction',
PTID('tid'),
PFOidList,
)
class UnlockInformation(Packet): :nodes: C -> M
""" """)
Notify about a successfully committed transaction. The new data can be
unlocked.
:nodes: M -> S NotifyDeadlock = notify("""
""" Ask master to generate a new TTID that will be used by the client to
_fmt = PStruct('notify_unlock_information', solve a deadlock by rebasing the transaction on top of concurrent
PTID('ttid'), changes.
)
class GenerateOIDs(Packet): :nodes: S -> M -> C
""" """)
Ask new OIDs to create objects.
:nodes: C -> M AskRebaseTransaction, AnswerRebaseTransaction = request("""
""" Rebase a transaction to solve a deadlock.
_fmt = PStruct('ask_new_oids',
PNumber('num_oids'),
)
_answer = PStruct('answer_new_oids', :nodes: C -> S
PFOidList, """)
)
class Deadlock(Packet): AskRebaseObject, AnswerRebaseObject = request("""
""" Rebase an object change to solve a deadlock.
Ask master to generate a new TTID that will be used by the client to solve
a deadlock by rebasing the transaction on top of concurrent changes.
:nodes: S -> M -> C :nodes: C -> S
"""
_fmt = PStruct('notify_deadlock',
PTID('ttid'),
PTID('locking_tid'),
)
class RebaseTransaction(Packet): XXX: It is a request packet to simplify the implementation. For more
""" efficiency, this should be turned into a notification, and the
Rebase a transaction to solve a deadlock. RebaseTransaction should answered once all objects are rebased
(so that the client can still wait on something).
""", data_path=(1, 0, 2, 0))
:nodes: C -> S AskStoreObject, AnswerStoreObject = request("""
""" Ask to create/modify an object. This maps to `store`.
_fmt = PStruct('ask_rebase_transaction',
PTID('ttid'),
PTID('locking_tid'),
)
_answer = PStruct('answer_rebase_transaction', As for IStorage, 'serial' is ZERO_TID for new objects.
PFOidList,
)
class RebaseObject(Packet): :nodes: C -> S
""" """, data_path=(0, 2))
Rebase an object change to solve a deadlock.
:nodes: C -> S AbortTransaction = notify("""
Abort a transaction. This maps to `tpc_abort`.
XXX: It is a request packet to simplify the implementation. For more :nodes: C -> S; C -> M -> S
efficiency, this should be turned into a notification, and the """)
RebaseTransaction should answered once all objects are rebased
(so that the client can still wait on something).
"""
_fmt = PStruct('ask_rebase_object',
PTID('ttid'),
PTID('oid'),
)
_answer = PStruct('answer_rebase_object',
POption('conflict',
PTID('serial'),
PTID('conflict_serial'),
POption('data',
PBoolean('compression'),
PChecksum('checksum'),
PString('data'),
),
)
)
class StoreObject(Packet):
"""
Ask to create/modify an object. This maps to `store`.
As for IStorage, 'serial' is ZERO_TID for new objects. AskStoreTransaction, AnswerStoreTransaction = request("""
Ask to store a transaction. Implies vote.
:nodes: C -> S :nodes: C -> S
""" """)
_fmt = PStruct('ask_store_object',
POID('oid'),
PTID('serial'),
PBoolean('compression'),
PChecksum('checksum'),
PString('data'),
PTID('data_serial'),
PTID('tid'),
)
_answer = PStruct('answer_store_object',
PTID('conflict'),
)
class AbortTransaction(Packet):
"""
Abort a transaction. This maps to `tpc_abort`.
:nodes: C -> S; C -> M -> S AskVoteTransaction, AnswerVoteTransaction = request("""
""" Ask to vote a transaction.
_fmt = PStruct('abort_transaction',
PTID('tid'),
PFUUIDList, # unused for * -> S
)
class StoreTransaction(Packet): :nodes: C -> S
""" """)
Ask to store a transaction. Implies vote.
:nodes: C -> S AskObject, AnswerObject = request("""
""" Ask a stored object by its OID, optionally at/before a specific tid.
_fmt = PStruct('ask_store_transaction', This maps to `load/loadBefore/loadSerial`.
PTID('tid'),
PString('user'),
PString('description'),
PString('extension'),
PFOidList,
)
_answer = PFEmpty
class VoteTransaction(Packet):
"""
Ask to vote a transaction.
:nodes: C -> S :nodes: C -> S
""" """, data_path=(1, 3))
_fmt = PStruct('ask_vote_transaction',
PTID('tid'),
)
_answer = PFEmpty
class GetObject(Packet): AskTIDs, AnswerTIDs = request("""
""" Ask for TIDs between a range of offsets. The order of TIDs is
Ask a stored object by its OID, optionally at/before a specific tid. descending, and the range is [first, last). This maps to `undoLog`.
This maps to `load/loadBefore/loadSerial`.
:nodes: C -> S :nodes: C -> S
""" """)
_fmt = PStruct('ask_object',
POID('oid'),
PTID('at'),
PTID('before'),
)
_answer = PStruct('answer_object',
POID('oid'),
PTID('serial_start'),
PTID('serial_end'),
PBoolean('compression'),
PChecksum('checksum'),
PString('data'),
PTID('data_serial'),
)
class TIDList(Packet):
"""
Ask for TIDs between a range of offsets. The order of TIDs is descending,
and the range is [first, last). This maps to `undoLog`.
:nodes: C -> S AskTransactionInformation, AnswerTransactionInformation = request("""
""" Ask for transaction metadata.
_fmt = PStruct('ask_tids',
PIndex('first'),
PIndex('last'),
PNumber('partition'),
)
_answer = PStruct('answer_tids', :nodes: C -> S
PFTidList, """)
)
class TIDListFrom(Packet): AskObjectHistory, AnswerObjectHistory = request("""
""" Ask history information for a given object. The order of serials is
Ask for length TIDs starting at min_tid. The order of TIDs is ascending. descending, and the range is [first, last]. This maps to `history`.
Used by `iterator`.
:nodes: C -> S :nodes: C -> S
""" """)
_fmt = PStruct('tid_list_from',
PTID('min_tid'),
PTID('max_tid'),
PNumber('length'),
PNumber('partition'),
)
_answer = PStruct('answer_tids',
PFTidList,
)
class TransactionInformation(Packet):
"""
Ask for transaction metadata.
:nodes: C -> S AskPartitionList, AnswerPartitionList = request("""
""" Ask information about partitions.
_fmt = PStruct('ask_transaction_information',
PTID('tid'),
)
_answer = PStruct('answer_transaction_information',
PTID('tid'),
PString('user'),
PString('description'),
PString('extension'),
PBoolean('packed'),
PFOidList,
)
class ObjectHistory(Packet):
"""
Ask history information for a given object. The order of serials is
descending, and the range is [first, last]. This maps to `history`.
:nodes: C -> S :nodes: ctl -> A
""" """)
_fmt = PStruct('ask_object_history',
POID('oid'),
PIndex('first'),
PIndex('last'),
)
_answer = PStruct('answer_object_history',
POID('oid'),
PFHistoryList,
)
class PartitionList(Packet):
"""
Ask information about partitions.
:nodes: ctl -> A AskNodeList, AnswerNodeList = request("""
""" Ask information about nodes.
_fmt = PStruct('ask_partition_list',
PNumber('min_offset'),
PNumber('max_offset'),
PUUID('uuid'),
)
_answer = PStruct('answer_partition_list',
PPTID('ptid'),
PNumber('num_replicas'),
PFRowList,
)
class NodeList(Packet):
"""
Ask information about nodes.
:nodes: ctl -> A :nodes: ctl -> A
""" """)
_fmt = PStruct('ask_node_list',
PFNodeType,
)
_answer = PStruct('answer_node_list', SetNodeState = request("""
PFNodeList, Change the state of a node.
)
class SetNodeState(Packet): :nodes: ctl -> A -> M
""" """, error=True, ignore_when_closed=False)
Change the state of a node.
:nodes: ctl -> A -> M AddPendingNodes = request("""
""" Mark given pending nodes as running, for future inclusion when tweaking
_fmt = PStruct('set_node_state', the partition table.
PUUID('uuid'),
PFNodeState,
)
_answer = Error :nodes: ctl -> A -> M
""", error=True, ignore_when_closed=False)
class AddPendingNodes(Packet): TweakPartitionTable, AnswerTweakPartitionTable = request("""
""" Ask the master to balance the partition table, optionally excluding
Mark given pending nodes as running, for future inclusion when tweaking specific nodes in anticipation of removing them.
the partition table.
:nodes: ctl -> A -> M :nodes: ctl -> A -> M
""" """)
_fmt = PStruct('add_pending_nodes',
PFUUIDList,
)
_answer = Error SetNumReplicas = request("""
Set the number of replicas.
class TweakPartitionTable(Packet): :nodes: ctl -> A -> M
""" """, error=True, ignore_when_closed=False)
Ask the master to balance the partition table, optionally excluding
specific nodes in anticipation of removing them.
:nodes: ctl -> A -> M SetClusterState = request("""
""" Set the cluster state.
_fmt = PStruct('tweak_partition_table',
PBoolean('dry_run'),
PFUUIDList,
)
_answer = PStruct('answer_tweak_partition_table', :nodes: ctl -> A -> M
PBoolean('changed'), """, error=True, ignore_when_closed=False)
PFRowList,
)
class NotifyNodeInformation(Packet): Repair = request("""
""" Ask storage nodes to repair their databases.
Notify information about one or more nodes.
:nodes: M -> * :nodes: ctl -> A -> M
""" """, error=True)
_fmt = PStruct('notify_node_informations',
PFloat('id_timestamp'),
PFNodeList,
)
class SetNumReplicas(Packet): NotifyRepair = notify("""
""" Repair is translated to this message, asking a specific storage node to
Set the number of replicas. repair its database.
:nodes: ctl -> A -> M :nodes: M -> S
""" """)
_fmt = PStruct('set_num_replicas',
PNumber('num_replicas'),
)
_answer = Error NotifyClusterInformation = notify("""
Notify about a cluster state change.
class SetClusterState(Packet): :nodes: M -> *
""" """)
Set the cluster state.
:nodes: ctl -> A -> M AskClusterState, AnswerClusterState = request("""
""" Ask the state of the cluster
_fmt = PStruct('set_cluster_state',
PEnum('state', ClusterStates),
)
_answer = Error :nodes: ctl -> A; A -> M
""")
class Repair(Packet): AskObjectUndoSerial, AnswerObjectUndoSerial = request("""
""" Ask storage the serial where object data is when undoing given
Ask storage nodes to repair their databases. transaction, for a list of OIDs.
:nodes: ctl -> A -> M Answer a dict mapping oids to 3-tuples:
""" current_serial (TID)
_flags = map(PBoolean, ('dry_run', The latest serial visible to the undoing transaction.
# 'prune_orphan' (commented because it's the only option for the moment) undo_serial (TID)
)) Where undone data is (tid at which data is before given undo).
_fmt = PStruct('repair', is_current (bool)
PFUUIDList, If current_serial's data is current on storage.
*_flags)
_answer = Error :nodes: C -> S
""", allow_dict=True)
class RepairOne(Packet): AskTIDsFrom, AnswerTIDsFrom = request("""
""" Ask for length TIDs starting at min_tid. The order of TIDs is ascending.
Repair is translated to this message, asking a specific storage node to Used by `iterator`.
repair its database.
:nodes: M -> S :nodes: C -> S
""" """)
_fmt = PStruct('repair', *Repair._flags)
class ClusterInformation(Packet): AskPack, AnswerPack = request("""
""" Request a pack at given TID.
Notify about a cluster state change.
:nodes: M -> * :nodes: C -> M -> S
""" """, ignore_when_closed=False)
_fmt = PStruct('notify_cluster_information',
PEnum('state', ClusterStates),
)
class ClusterState(Packet): CheckReplicas = request("""
""" Ask the cluster to search for mismatches between replicas, metadata
Ask the state of the cluster only, and optionally within a specific range. Reference nodes can be
specified.
:nodes: ctl -> A; A -> M :nodes: ctl -> A -> M
""" """, error=True, allow_dict=True)
_answer = PStruct('answer_cluster_state', CheckPartition = notify("""
PEnum('state', ClusterStates), Ask a storage node to compare a partition with all other nodes.
) Like for CheckReplicas, only metadata are checked, optionally within a
specific range. A reference node can be specified.
class ObjectUndoSerial(Packet): :nodes: M -> S
""" """)
Ask storage the serial where object data is when undoing given transaction,
for a list of OIDs.
object_tid_dict has the following format: AskCheckTIDRange, AnswerCheckTIDRange = request("""
key: oid Ask some stats about a range of transactions.
value: 3-tuple Used to know if there are differences between a replicating node and
current_serial (TID) reference node.
The latest serial visible to the undoing transaction.
undo_serial (TID)
Where undone data is (tid at which data is before given undo).
is_current (bool)
If current_serial's data is current on storage.
:nodes: C -> S :nodes: S -> S
""" """)
_fmt = PStruct('ask_undo_transaction',
PTID('tid'),
PTID('ltid'),
PTID('undone_tid'),
PFOidList,
)
_answer = PStruct('answer_undo_transaction',
PDict('object_tid_dict',
POID('oid'),
PStruct('object_tid_value',
PTID('current_serial'),
PTID('undo_serial'),
PBoolean('is_current'),
),
),
)
class CheckCurrentSerial(Packet):
"""
Check if given serial is current for the given oid, and lock it so that
this state is not altered until transaction ends.
This maps to `checkCurrentSerialInTransaction`.
:nodes: C -> S AskCheckSerialRange, AnswerCheckSerialRange = request("""
""" Ask some stats about a range of object history.
_fmt = PStruct('ask_check_current_serial', Used to know if there are differences between a replicating node and
PTID('tid'), reference node.
POID('oid'),
PTID('serial'),
)
_answer = StoreObject._answer :nodes: S -> S
""")
class Pack(Packet): NotifyPartitionCorrupted = notify("""
""" Notify that mismatches were found while check replicas for a partition.
Request a pack at given TID.
:nodes: C -> M -> S :nodes: S -> M
""" """)
_fmt = PStruct('ask_pack',
PTID('tid'),
)
_answer = PStruct('answer_pack', NotifyReady = notify("""
PBoolean('status'), Notify that we're ready to serve requests.
)
class CheckReplicas(Packet): :nodes: S -> M
""" """)
Ask the cluster to search for mismatches between replicas, metadata only,
and optionally within a specific range. Reference nodes can be specified.
:nodes: ctl -> A -> M AskLastTransaction, AnswerLastTransaction = request("""
""" Ask last committed TID.
_fmt = PStruct('check_replicas',
PDict('partition_dict',
PNumber('partition'),
PUUID('source'),
),
PTID('min_tid'),
PTID('max_tid'),
)
_answer = Error
class CheckPartition(Packet):
"""
Ask a storage node to compare a partition with all other nodes.
Like for CheckReplicas, only metadata are checked, optionally within a
specific range. A reference node can be specified.
:nodes: M -> S :nodes: C -> M; ctl -> A -> M
""" """, poll_thread=True)
_fmt = PStruct('check_partition',
PNumber('partition'),
PStruct('source',
PString('upstream_name'),
PAddress('address'),
),
PTID('min_tid'),
PTID('max_tid'),
)
class CheckTIDRange(Packet):
"""
Ask some stats about a range of transactions.
Used to know if there are differences between a replicating node and
reference node.
:nodes: S -> S AskCheckCurrentSerial, AnswerCheckCurrentSerial = request("""
""" Check if given serial is current for the given oid, and lock it so that
_fmt = PStruct('ask_check_tid_range', this state is not altered until transaction ends.
PNumber('partition'), This maps to `checkCurrentSerialInTransaction`.
PNumber('length'),
PTID('min_tid'),
PTID('max_tid'),
)
_answer = PStruct('answer_check_tid_range',
PNumber('count'),
PChecksum('checksum'),
PTID('max_tid'),
)
class CheckSerialRange(Packet):
"""
Ask some stats about a range of object history.
Used to know if there are differences between a replicating node and
reference node.
:nodes: S -> S :nodes: C -> S
""" """)
_fmt = PStruct('ask_check_serial_range',
PNumber('partition'),
PNumber('length'),
PTID('min_tid'),
PTID('max_tid'),
POID('min_oid'),
)
_answer = PStruct('answer_check_serial_range',
PNumber('count'),
PChecksum('tid_checksum'),
PTID('max_tid'),
PChecksum('oid_checksum'),
POID('max_oid'),
)
class PartitionCorrupted(Packet):
"""
Notify that mismatches were found while check replicas for a partition.
:nodes: S -> M NotifyTransactionFinished = notify("""
""" Notify that a transaction blocking a replication is now finished.
_fmt = PStruct('partition_corrupted',
PNumber('partition'),
PList('cell_list',
PUUID('uuid'),
),
)
class LastTransaction(Packet):
"""
Ask last committed TID.
:nodes: C -> M; ctl -> A -> M :nodes: M -> S
""" """)
poll_thread = True
_answer = PStruct('answer_last_transaction', Replicate = notify("""
PTID('tid'), Notify a storage node to replicate partitions up to given 'tid'
) and from given sources.
class NotifyReady(Packet): args: tid, upstream_name, {partition: address}
""" - upstream_name: replicate from an upstream cluster
Notify that we're ready to serve requests. - address: address of the source storage node, or None if there's
no new data up to 'tid' for the given partition
:nodes: S -> M :nodes: M -> S
""" """, allow_dict=True)
class FetchTransactions(Packet): NotifyReplicationDone = notify("""
""" Notify the master node that a partition has been successfully
Ask a storage node to send all transaction data we don't have, replicated from a storage to another.
and reply with the list of transactions we should not have.
:nodes: S -> S :nodes: S -> M
""" """)
_fmt = PStruct('ask_transaction_list',
PNumber('partition'),
PNumber('length'),
PTID('min_tid'),
PTID('max_tid'),
PFTidList, # already known transactions
)
_answer = PStruct('answer_transaction_list',
PTID('pack_tid'),
PTID('next_tid'),
PFTidList, # transactions to delete
)
class AddTransaction(Packet):
"""
Send metadata of a transaction to a node that do not have them.
:nodes: S -> S AskFetchTransactions, AnswerFetchTransactions = request("""
""" Ask a storage node to send all transaction data we don't have,
nodelay = False and reply with the list of transactions we should not have.
_fmt = PStruct('add_transaction',
PTID('tid'),
PString('user'),
PString('description'),
PString('extension'),
PBoolean('packed'),
PTID('ttid'),
PFOidList,
)
class FetchObjects(Packet):
"""
Ask a storage node to send object records we don't have,
and reply with the list of records we should not have.
:nodes: S -> S :nodes: S -> S
""" """)
_fmt = PStruct('ask_object_list',
PNumber('partition'),
PNumber('length'),
PTID('min_tid'),
PTID('max_tid'),
POID('min_oid'),
PDict('object_dict', # already known objects
PTID('serial'),
PFOidList,
),
)
_answer = PStruct('answer_object_list',
PTID('pack_tid'),
PTID('next_tid'),
POID('next_oid'),
PDict('object_dict', # objects to delete
PTID('serial'),
PFOidList,
),
)
class AddObject(Packet):
"""
Send an object record to a node that do not have it.
:nodes: S -> S AskFetchObjects, AnswerFetchObjects = request("""
""" Ask a storage node to send object records we don't have,
nodelay = False and reply with the list of records we should not have.
_fmt = PStruct('add_object',
POID('oid'),
PTID('serial'),
PBoolean('compression'),
PChecksum('checksum'),
PString('data'),
PTID('data_serial'),
)
class Replicate(Packet):
"""
Notify a storage node to replicate partitions up to given 'tid'
and from given sources.
- upstream_name: replicate from an upstream cluster :nodes: S -> S
- address: address of the source storage node, or None if there's no new """, allow_dict=True)
data up to 'tid' for the given partition
:nodes: M -> S AddTransaction = notify("""
""" Send metadata of a transaction to a node that does not have them.
_fmt = PStruct('replicate',
PTID('tid'),
PString('upstream_name'),
PDict('source_dict',
PNumber('partition'),
PAddress('address'),
)
)
class ReplicationDone(Packet):
"""
Notify the master node that a partition has been successfully replicated
from a storage to another.
:nodes: S -> M :nodes: S -> S
""" """, nodelay=False)
_fmt = PStruct('notify_replication_done',
PNumber('offset'),
PTID('tid'),
)
class Truncate(Packet): AddObject = notify("""
""" Send an object record to a node that does not have it.
Request DB to be truncated. Also used to leave backup mode.
:nodes: ctl -> A -> M; M -> S :nodes: S -> S
""" """, nodelay=False, data_path=(0, 2))
_fmt = PStruct('truncate',
PTID('tid'),
)
_answer = Error Truncate = request("""
Request DB to be truncated. Also used to leave backup mode.
class FlushLog(Packet): :nodes: ctl -> A -> M; M -> S
""" """, error=True)
Request all nodes to flush their logs.
:nodes: ctl -> A -> M -> * FlushLog = notify("""
""" Request all nodes to flush their logs.
:nodes: ctl -> A -> M -> *
""")
_next_code = 0 del notify, request
def register(request, ignore_when_closed=None):
""" Register a packet in the packet registry """
global _next_code
code = _next_code
assert code < RESPONSE_MASK
_next_code = code + 1
if request is Error:
code |= RESPONSE_MASK
# register the request
request._code = code
answer = request._answer
if ignore_when_closed is None:
# By default, on a closed connection:
# - request: ignore
# - answer: keep
# - notification: keep
ignore_when_closed = answer is not None
request._ignore_when_closed = ignore_when_closed
if answer in (Error, None):
return request
# build a class for the answer
answer = type('Answer' + request.__name__, (Packet, ), {})
answer._fmt = request._answer
answer.poll_thread = request.poll_thread
answer._request = request
assert answer._code is None, "Answer of %s is already used" % (request, )
answer._code = code | RESPONSE_MASK
request._answer = answer
return request, answer
class Packets(dict):
"""
Packet registry that checks packet code uniqueness and provides an index
"""
def __metaclass__(name, base, d):
# this builds a "singleton"
cls = type('PacketRegistry', base, d)()
for k, v in d.iteritems():
if isinstance(v, type) and issubclass(v, Packet):
v.handler_method_name = k[0].lower() + k[1:]
cls[v._code] = v
return cls
Error = register(
Error)
RequestIdentification, AcceptIdentification = register(
RequestIdentification, ignore_when_closed=True)
Ping, Pong = register(
Ping)
CloseClient = register(
CloseClient)
AskPrimary, AnswerPrimary = register(
PrimaryMaster)
NotPrimaryMaster = register(
NotPrimaryMaster)
NotifyNodeInformation = register(
NotifyNodeInformation)
AskRecovery, AnswerRecovery = register(
Recovery)
AskLastIDs, AnswerLastIDs = register(
LastIDs)
AskPartitionTable, AnswerPartitionTable = register(
PartitionTable)
SendPartitionTable = register(
NotifyPartitionTable)
NotifyPartitionChanges = register(
PartitionChanges)
StartOperation = register(
StartOperation)
StopOperation = register(
StopOperation)
AskUnfinishedTransactions, AnswerUnfinishedTransactions = register(
UnfinishedTransactions)
AskLockedTransactions, AnswerLockedTransactions = register(
LockedTransactions)
AskFinalTID, AnswerFinalTID = register(
FinalTID)
ValidateTransaction = register(
ValidateTransaction)
AskBeginTransaction, AnswerBeginTransaction = register(
BeginTransaction)
FailedVote = register(
FailedVote)
AskFinishTransaction, AnswerTransactionFinished = register(
FinishTransaction, ignore_when_closed=False)
AskLockInformation, AnswerInformationLocked = register(
LockInformation, ignore_when_closed=False)
InvalidateObjects = register(
InvalidateObjects)
NotifyUnlockInformation = register(
UnlockInformation)
AskNewOIDs, AnswerNewOIDs = register(
GenerateOIDs)
NotifyDeadlock = register(
Deadlock)
AskRebaseTransaction, AnswerRebaseTransaction = register(
RebaseTransaction)
AskRebaseObject, AnswerRebaseObject = register(
RebaseObject)
AskStoreObject, AnswerStoreObject = register(
StoreObject)
AbortTransaction = register(
AbortTransaction)
AskStoreTransaction, AnswerStoreTransaction = register(
StoreTransaction)
AskVoteTransaction, AnswerVoteTransaction = register(
VoteTransaction)
AskObject, AnswerObject = register(
GetObject)
AskTIDs, AnswerTIDs = register(
TIDList)
AskTransactionInformation, AnswerTransactionInformation = register(
TransactionInformation)
AskObjectHistory, AnswerObjectHistory = register(
ObjectHistory)
AskPartitionList, AnswerPartitionList = register(
PartitionList)
AskNodeList, AnswerNodeList = register(
NodeList)
SetNodeState = register(
SetNodeState, ignore_when_closed=False)
AddPendingNodes = register(
AddPendingNodes, ignore_when_closed=False)
TweakPartitionTable, AnswerTweakPartitionTable = register(
TweakPartitionTable)
SetNumReplicas = register(
SetNumReplicas, ignore_when_closed=False)
SetClusterState = register(
SetClusterState, ignore_when_closed=False)
Repair = register(
Repair)
NotifyRepair = register(
RepairOne)
NotifyClusterInformation = register(
ClusterInformation)
AskClusterState, AnswerClusterState = register(
ClusterState)
AskObjectUndoSerial, AnswerObjectUndoSerial = register(
ObjectUndoSerial)
AskTIDsFrom, AnswerTIDsFrom = register(
TIDListFrom)
AskPack, AnswerPack = register(
Pack, ignore_when_closed=False)
CheckReplicas = register(
CheckReplicas)
CheckPartition = register(
CheckPartition)
AskCheckTIDRange, AnswerCheckTIDRange = register(
CheckTIDRange)
AskCheckSerialRange, AnswerCheckSerialRange = register(
CheckSerialRange)
NotifyPartitionCorrupted = register(
PartitionCorrupted)
NotifyReady = register(
NotifyReady)
AskLastTransaction, AnswerLastTransaction = register(
LastTransaction)
AskCheckCurrentSerial, AnswerCheckCurrentSerial = register(
CheckCurrentSerial)
NotifyTransactionFinished = register(
NotifyTransactionFinished)
Replicate = register(
Replicate)
NotifyReplicationDone = register(
ReplicationDone)
AskFetchTransactions, AnswerFetchTransactions = register(
FetchTransactions)
AskFetchObjects, AnswerFetchObjects = register(
FetchObjects)
AddTransaction = register(
AddTransaction)
AddObject = register(
AddObject)
Truncate = register(
Truncate)
FlushLog = register(
FlushLog)
def Errors(): def Errors():
registry_dict = {} registry_dict = {}
handler_method_name_dict = {} handler_method_name_dict = {}
Error = Packets.Error
def register_error(code): def register_error(code):
return lambda self, message='': Error(code, message) return lambda self, message='': Error(code, message)
for error in ErrorCodes: for error in ErrorCodes:
...@@ -1856,19 +854,20 @@ from operator import itemgetter ...@@ -1856,19 +854,20 @@ from operator import itemgetter
def formatNodeList(node_list, prefix='', _sort_key=itemgetter(2)): def formatNodeList(node_list, prefix='', _sort_key=itemgetter(2)):
if node_list: if node_list:
node_list.sort(key=_sort_key)
node_list = [( node_list = [(
uuid_str(uuid), str(node_type), uuid_str(uuid), str(node_type),
('[%s]:%s' if ':' in addr[0] else '%s:%s') ('[%s]:%s' if ':' in addr[0] else '%s:%s')
% addr if addr else '', str(state), % addr if addr else '', str(state),
str(id_timestamp and datetime.utcfromtimestamp(id_timestamp))) str(id_timestamp and datetime.utcfromtimestamp(id_timestamp)))
for node_type, addr, uuid, state, id_timestamp in node_list] for node_type, addr, uuid, state, id_timestamp
in sorted(node_list, key=_sort_key)]
t = ''.join('%%-%us | ' % max(len(x[i]) for x in node_list) t = ''.join('%%-%us | ' % max(len(x[i]) for x in node_list)
for i in xrange(len(node_list[0]) - 1)) for i in xrange(len(node_list[0]) - 1))
return map((prefix + t + '%s').__mod__, node_list) return map((prefix + t + '%s').__mod__, node_list)
return () return ()
NotifyNodeInformation._neolog = staticmethod(lambda timestamp, node_list: Packets.NotifyNodeInformation._neolog = staticmethod(
lambda timestamp, node_list:
((timestamp,), formatNodeList(node_list, ' ! '))) ((timestamp,), formatNodeList(node_list, ' ! ')))
Error._neolog = staticmethod(lambda *args: ((), ("%s (%s)" % args,))) Packets.Error._neolog = staticmethod(lambda *args: ((), ("%s (%s)" % args,)))
...@@ -166,65 +166,6 @@ def parseMasterList(masters): ...@@ -166,65 +166,6 @@ def parseMasterList(masters):
return map(parseNodeAddress, masters.split()) return map(parseNodeAddress, masters.split())
class ReadBuffer(object):
"""
Implementation of a lazy buffer. Main purpose if to reduce useless
copies of data by storing chunks and join them only when the requested
size is available.
TODO: For better performance, use:
- socket.recv_into (64kiB blocks)
- struct.unpack_from
- and a circular buffer of dynamic size (initial size:
twice the length passed to socket.recv_into ?)
"""
def __init__(self):
self.size = 0
self.content = deque()
def append(self, data):
""" Append some data and compute the new buffer size """
self.size += len(data)
self.content.append(data)
def __len__(self):
""" Return the current buffer size """
return self.size
def read(self, size):
""" Read and consume size bytes """
if self.size < size:
return None
self.size -= size
chunk_list = []
pop_chunk = self.content.popleft
append_data = chunk_list.append
to_read = size
# select required chunks
while to_read > 0:
chunk_data = pop_chunk()
to_read -= len(chunk_data)
append_data(chunk_data)
if to_read < 0:
# too many bytes consumed, cut the last chunk
last_chunk = chunk_list[-1]
keep, let = last_chunk[:to_read], last_chunk[to_read:]
self.content.appendleft(let)
chunk_list[-1] = keep
# join all chunks (one copy)
data = ''.join(chunk_list)
assert len(data) == size
return data
def clear(self):
""" Erase all buffer content """
self.size = 0
self.content.clear()
dummy_read_buffer = ReadBuffer()
dummy_read_buffer.append = lambda _: None
class cached_property(object): class cached_property(object):
""" """
A property that is only computed once per instance and then replaces itself A property that is only computed once per instance and then replaces itself
......
...@@ -585,7 +585,9 @@ class Application(BaseApplication): ...@@ -585,7 +585,9 @@ class Application(BaseApplication):
self.tm.executeQueuedEvents() self.tm.executeQueuedEvents()
def startStorage(self, node): def startStorage(self, node):
node.send(Packets.StartOperation(self.backup_tid)) # XXX: Is this boolean 'backup' field needed ?
# Maybe this can be deduced from cluster state.
node.send(Packets.StartOperation(bool(self.backup_tid)))
uuid = node.getUUID() uuid = node.getUUID()
assert uuid not in self.storage_starting_set assert uuid not in self.storage_starting_set
assert uuid not in self.storage_ready_dict assert uuid not in self.storage_ready_dict
......
...@@ -157,27 +157,49 @@ class Log(object): ...@@ -157,27 +157,49 @@ class Log(object):
for x in 'uuid_str', 'Packets', 'PacketMalformedError': for x in 'uuid_str', 'Packets', 'PacketMalformedError':
setattr(self, x, g[x]) setattr(self, x, g[x])
x = {} x = {}
try:
Unpacker = g['Unpacker']
except KeyError:
unpackb = None
else:
from msgpack import ExtraData, UnpackException
def unpackb(data):
u = Unpacker()
u.feed(data)
data = u.unpack()
if u.read_bytes(1):
raise ExtraData
return data
self.PacketMalformedError = UnpackException
self.unpackb = unpackb
if self._decode > 1: if self._decode > 1:
PStruct = g['PStruct'] try:
PBoolean = g['PBoolean'] PStruct = g['PStruct']
def hasData(item): except KeyError:
items = item._items for p in self.Packets.itervalues():
for i, item in enumerate(items): data_path = getattr(p, 'data_path', (None,))
if isinstance(item, PStruct): if p._code >> 15 == data_path[0]:
j = hasData(item) x[p._code] = data_path[1:]
if j: else:
return (i,) + j PBoolean = g['PBoolean']
elif (isinstance(item, PBoolean) def hasData(item):
and item._name == 'compression' items = item._items
and i + 2 < len(items) for i, item in enumerate(items):
and items[i+2]._name == 'data'): if isinstance(item, PStruct):
return i, j = hasData(item)
for p in self.Packets.itervalues(): if j:
if p._fmt is not None: return (i,) + j
path = hasData(p._fmt) elif (isinstance(item, PBoolean)
if path: and item._name == 'compression'
assert not hasattr(p, '_neolog'), p and i + 2 < len(items)
x[p._code] = path and items[i+2]._name == 'data'):
return i,
for p in self.Packets.itervalues():
if p._fmt is not None:
path = hasData(p._fmt)
if path:
assert not hasattr(p, '_neolog'), p
x[p._code] = path
self._getDataPath = x.get self._getDataPath = x.get
try: try:
...@@ -215,11 +237,13 @@ class Log(object): ...@@ -215,11 +237,13 @@ class Log(object):
if body is not None: if body is not None:
log = getattr(p, '_neolog', None) log = getattr(p, '_neolog', None)
if log or self._decode: if log or self._decode:
p = p()
p._id = msg_id
p._body = body
try: try:
args = p.decode() if self.unpackb:
args = self.unpackb(body)
else:
p = p()
p._body = body
args = p.decode()
except self.PacketMalformedError: except self.PacketMalformedError:
msg.append("Can't decode packet") msg.append("Can't decode packet")
else: else:
......
...@@ -461,8 +461,12 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -461,8 +461,12 @@ class SQLiteDatabaseManager(DatabaseManager):
return r return r
def loadData(self, data_id): def loadData(self, data_id):
return self.query("SELECT compression, hash, value" compression, checksum, data = self.query(
" FROM data WHERE id=?", (data_id,)).fetchone() "SELECT compression, hash, value FROM data WHERE id=?",
(data_id,)).fetchone()
if checksum:
return compression, str(checksum), str(data)
return compression, checksum, data
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
partition = self._getReadablePartition(oid) partition = self._getReadablePartition(oid)
......
...@@ -53,7 +53,7 @@ class ClientOperationHandler(BaseHandler): ...@@ -53,7 +53,7 @@ class ClientOperationHandler(BaseHandler):
p = Errors.TidNotFound('%s does not exist' % dump(tid)) p = Errors.TidNotFound('%s does not exist' % dump(tid))
else: else:
p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3], p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3],
t[4], t[0]) bool(t[4]), t[0])
conn.answer(p) conn.answer(p)
def getEventQueue(self): def getEventQueue(self):
......
...@@ -212,7 +212,7 @@ class StorageOperationHandler(EventHandler): ...@@ -212,7 +212,7 @@ class StorageOperationHandler(EventHandler):
# Sending such packet does not mark the connection # Sending such packet does not mark the connection
# for writing if there's too little data in the buffer. # for writing if there's too little data in the buffer.
conn.send(Packets.AddTransaction(tid, user, conn.send(Packets.AddTransaction(tid, user,
desc, ext, packed, ttid, oid_list), msg_id) desc, ext, bool(packed), ttid, oid_list), msg_id)
# To avoid delaying several connections simultaneously, # To avoid delaying several connections simultaneously,
# and also prevent the backend from scanning different # and also prevent the backend from scanning different
# parts of the DB at the same time, we ask the # parts of the DB at the same time, we ask the
...@@ -248,7 +248,7 @@ class StorageOperationHandler(EventHandler): ...@@ -248,7 +248,7 @@ class StorageOperationHandler(EventHandler):
for serial, oid in object_list: for serial, oid in object_list:
oid_set = object_dict.get(serial) oid_set = object_dict.get(serial)
if oid_set: if oid_set:
if type(oid_set) is list: if type(oid_set) is tuple:
object_dict[serial] = oid_set = set(oid_set) object_dict[serial] = oid_set = set(oid_set)
if oid in oid_set: if oid in oid_set:
oid_set.remove(oid) oid_set.remove(oid)
......
...@@ -71,7 +71,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -71,7 +71,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn) self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid) self.service.askPack(conn, tid)
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
ptid = self.checkAskPacket(storage_conn, Packets.AskPack).decode()[0] ptid = self.checkAskPacket(storage_conn, Packets.AskPack)._args[0]
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
self.assertTrue(self.app.packing[0] is conn) self.assertTrue(self.app.packing[0] is conn)
self.assertEqual(self.app.packing[1], peer_id) self.assertEqual(self.app.packing[1], peer_id)
...@@ -83,7 +83,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -83,7 +83,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn) self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid) self.service.askPack(conn, tid)
self.checkNoPacketSent(storage_conn) self.checkNoPacketSent(storage_conn)
status = self.checkAnswerPacket(conn, Packets.AnswerPack).decode()[0] status = self.checkAnswerPacket(conn, Packets.AnswerPack)._args[0]
self.assertFalse(status) self.assertFalse(status)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -72,7 +72,7 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -72,7 +72,7 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
self.service.answerPack(conn2, False) self.service.answerPack(conn2, False)
packet = self.checkNotifyPacket(client_conn, Packets.AnswerPack) packet = self.checkNotifyPacket(client_conn, Packets.AnswerPack)
# TODO: verify packet peer id # TODO: verify packet peer id
self.assertTrue(packet.decode()[0]) self.assertTrue(packet._args[0])
self.assertEqual(self.app.packing, None) self.assertEqual(self.app.packing, None)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -33,9 +33,9 @@ class HandlerTests(NeoUnitTestBase): ...@@ -33,9 +33,9 @@ class HandlerTests(NeoUnitTestBase):
def getFakePacket(self): def getFakePacket(self):
p = Mock({ p = Mock({
'decode': (),
'__repr__': 'Fake Packet', '__repr__': 'Fake Packet',
}) })
p._args = ()
p.handler_method_name = 'fake_method' p.handler_method_name = 'fake_method'
return p return p
...@@ -53,13 +53,6 @@ class HandlerTests(NeoUnitTestBase): ...@@ -53,13 +53,6 @@ class HandlerTests(NeoUnitTestBase):
self.handler.dispatch(conn, packet) self.handler.dispatch(conn, packet)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
self.checkAborted(conn) self.checkAborted(conn)
# raise PacketMalformedError
conn.mockCalledMethods = {}
def fake(c):
raise PacketMalformedError('message')
self.setFakeMethod(fake)
self.handler.dispatch(conn, packet)
self.checkClosed(conn)
# raise NotReadyError # raise NotReadyError
conn.mockCalledMethods = {} conn.mockCalledMethods = {}
def fake(c): def fake(c):
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import unittest import unittest
import socket import socket
from . import NeoUnitTestBase from . import NeoUnitTestBase
from neo.lib.util import ReadBuffer, parseNodeAddress from neo.lib.util import parseNodeAddress
class UtilTests(NeoUnitTestBase): class UtilTests(NeoUnitTestBase):
...@@ -40,24 +40,6 @@ class UtilTests(NeoUnitTestBase): ...@@ -40,24 +40,6 @@ class UtilTests(NeoUnitTestBase):
self.assertIn(parseNodeAddress('localhost'), local_address(0)) self.assertIn(parseNodeAddress('localhost'), local_address(0))
self.assertIn(parseNodeAddress('localhost:10'), local_address(10)) self.assertIn(parseNodeAddress('localhost:10'), local_address(10))
def testReadBufferRead(self):
""" Append some chunk then consume the data """
buf = ReadBuffer()
self.assertEqual(len(buf), 0)
buf.append('abc')
self.assertEqual(len(buf), 3)
# no enough data
self.assertEqual(buf.read(4), None)
self.assertEqual(len(buf), 3)
buf.append('def')
# consume a part
self.assertEqual(len(buf), 6)
self.assertEqual(buf.read(4), 'abcd')
self.assertEqual(len(buf), 2)
# consume the rest
self.assertEqual(buf.read(3), None)
self.assertEqual(buf.read(2), 'ef')
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -1340,7 +1340,7 @@ class Test(NEOThreadedTest): ...@@ -1340,7 +1340,7 @@ class Test(NEOThreadedTest):
# Also check that the master reset the last oid to a correct value. # Also check that the master reset the last oid to a correct value.
t.begin() t.begin()
self.assertEqual(1, u64(c.root()['x']._p_oid)) self.assertEqual(1, u64(c.root()['x']._p_oid))
self.assertFalse(cluster.client.new_oid_list) self.assertFalse(cluster.client.new_oids)
self.assertEqual(2, u64(cluster.client.new_oid())) self.assertEqual(2, u64(cluster.client.new_oid()))
@with_cluster() @with_cluster()
...@@ -2106,7 +2106,7 @@ class Test(NEOThreadedTest): ...@@ -2106,7 +2106,7 @@ class Test(NEOThreadedTest):
except threading.ThreadError: except threading.ThreadError:
l[j].acquire() l[j].acquire()
threads[j-1].start() threads[j-1].start()
if x != 'StoreTransaction': if x != 'AskStoreTransaction':
try: try:
l[i].acquire() l[i].acquire()
except IndexError: except IndexError:
...@@ -2183,15 +2183,16 @@ class Test(NEOThreadedTest): ...@@ -2183,15 +2183,16 @@ class Test(NEOThreadedTest):
x = self._testComplexDeadlockAvoidanceWithOneStorage(changes, x = self._testComplexDeadlockAvoidanceWithOneStorage(changes,
(1, 1, 0, 1, 2, 2, 2, 2, 0, 1, 2, 1, 0, 0, 1, 0, 0, 1), (1, 1, 0, 1, 2, 2, 2, 2, 0, 1, 2, 1, 0, 0, 1, 0, 0, 1),
('tpc_begin', 'tpc_begin', 1, 2, 3, 'tpc_begin', 1, 2, 4, 3, 4, ('tpc_begin', 'tpc_begin', 1, 2, 3, 'tpc_begin', 1, 2, 4, 3, 4,
'StoreTransaction', 'RebaseTransaction', 'RebaseTransaction', 'AskStoreTransaction', 'AskRebaseTransaction',
'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AskRebaseTransaction', 'AnswerRebaseTransaction',
'RebaseTransaction', 'AnswerRebaseTransaction'), 'AnswerRebaseTransaction', 'AskRebaseTransaction',
'AnswerRebaseTransaction'),
[4, 6, 2, 6]) [4, 6, 2, 6])
try: try:
x[1].remove(1) x[1].remove(1)
except ValueError: except ValueError:
pass pass
self.assertEqual(x, {0: [2, 'StoreTransaction'], 1: ['tpc_abort']}) self.assertEqual(x, {0: [2, 'AskStoreTransaction'], 1: ['tpc_abort']})
def testCascadedDeadlockAvoidanceWithOneStorage2(self): def testCascadedDeadlockAvoidanceWithOneStorage2(self):
def changes(r1, r2, r3): def changes(r1, r2, r3):
...@@ -2214,8 +2215,8 @@ class Test(NEOThreadedTest): ...@@ -2214,8 +2215,8 @@ class Test(NEOThreadedTest):
(0, 1, 1, 0, 1, 2, 2, 2, 2, 0, 1, 2, 1, (0, 1, 1, 0, 1, 2, 2, 2, 2, 0, 1, 2, 1,
0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1), 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1),
('tpc_begin', 1, 'tpc_begin', 1, 2, 3, 'tpc_begin', ('tpc_begin', 1, 'tpc_begin', 1, 2, 3, 'tpc_begin',
2, 3, 4, 3, 4, 'StoreTransaction', 'RebaseTransaction', 2, 3, 4, 3, 4, 'AskStoreTransaction', 'AskRebaseTransaction',
'RebaseTransaction', 'AnswerRebaseTransaction'), 'AskRebaseTransaction', 'AnswerRebaseTransaction'),
[1, 7, 9, 0]) [1, 7, 9, 0])
x[0].sort(key=str) x[0].sort(key=str)
try: try:
...@@ -2224,8 +2225,8 @@ class Test(NEOThreadedTest): ...@@ -2224,8 +2225,8 @@ class Test(NEOThreadedTest):
pass pass
self.assertEqual(x, { self.assertEqual(x, {
0: [2, 3, 'AnswerRebaseTransaction', 0: [2, 3, 'AnswerRebaseTransaction',
'RebaseTransaction', 'StoreTransaction'], 'AskRebaseTransaction', 'AskStoreTransaction'],
1: ['AnswerRebaseTransaction','RebaseTransaction', 1: ['AnswerRebaseTransaction','AskRebaseTransaction',
'AnswerRebaseTransaction', 'tpc_abort'], 'AnswerRebaseTransaction', 'tpc_abort'],
}) })
...@@ -2258,7 +2259,7 @@ class Test(NEOThreadedTest): ...@@ -2258,7 +2259,7 @@ class Test(NEOThreadedTest):
end = self._testComplexDeadlockAvoidanceWithOneStorage(changes, end = self._testComplexDeadlockAvoidanceWithOneStorage(changes,
(0, 1, 1, 0, 1, 1, 0, 0, 2, 2, 2, 2, 1, vote_t2, tic_t1), (0, 1, 1, 0, 1, 1, 0, 0, 2, 2, 2, 2, 1, vote_t2, tic_t1),
('tpc_begin', 1) * 2, [3, 0, 0, 0], None) ('tpc_begin', 1) * 2, [3, 0, 0, 0], None)
self.assertLessEqual(2, end[0].count('RebaseTransaction')) self.assertLessEqual(2, end[0].count('AskRebaseTransaction'))
def testFailedConflictOnBigValueDuringDeadlockAvoidance(self): def testFailedConflictOnBigValueDuringDeadlockAvoidance(self):
def changes(r1, r2, r3): def changes(r1, r2, r3):
...@@ -2274,10 +2275,10 @@ class Test(NEOThreadedTest): ...@@ -2274,10 +2275,10 @@ class Test(NEOThreadedTest):
x = self._testComplexDeadlockAvoidanceWithOneStorage(changes, x = self._testComplexDeadlockAvoidanceWithOneStorage(changes,
(1, 1, 1, 2, 2, 2, 1, 2, 2, 0, 0, 1, 1, 1, 0), (1, 1, 1, 2, 2, 2, 1, 2, 2, 0, 0, 1, 1, 1, 0),
('tpc_begin', 'tpc_begin', 1, 2, 'tpc_begin', 1, 3, 3, 4, ('tpc_begin', 'tpc_begin', 1, 2, 'tpc_begin', 1, 3, 3, 4,
'StoreTransaction', 2, 4, 'RebaseTransaction', 'AskStoreTransaction', 2, 4, 'AskRebaseTransaction',
'AnswerRebaseTransaction', 'tpc_abort'), 'AnswerRebaseTransaction', 'tpc_abort'),
[5, 1, 0, 2], POSException.ConflictError) [5, 1, 0, 2], POSException.ConflictError)
self.assertEqual(x, {0: ['StoreTransaction']}) self.assertEqual(x, {0: ['AskStoreTransaction']})
@with_cluster(replicas=1, partitions=4) @with_cluster(replicas=1, partitions=4)
def testNotifyReplicated(self, cluster): def testNotifyReplicated(self, cluster):
...@@ -2364,7 +2365,7 @@ class Test(NEOThreadedTest): ...@@ -2364,7 +2365,7 @@ class Test(NEOThreadedTest):
def delayConflict(conn, packet): def delayConflict(conn, packet):
app = self.getConnectionApp(conn) app = self.getConnectionApp(conn)
if (isinstance(packet, Packets.AnswerStoreObject) if (isinstance(packet, Packets.AnswerStoreObject)
and packet.decode()[0]): and packet._args[0]):
conn, = cluster.client.getConnectionList(app) conn, = cluster.client.getConnectionList(app)
kw = conn._handlers._pending[0][0][packet._id][1] kw = conn._handlers._pending[0][0][packet._id][1]
return 1 == u64(kw['oid']) and delay_conflict[app.uuid].pop() return 1 == u64(kw['oid']) and delay_conflict[app.uuid].pop()
...@@ -2382,8 +2383,9 @@ class Test(NEOThreadedTest): ...@@ -2382,8 +2383,9 @@ class Test(NEOThreadedTest):
self.thread_switcher(threads, self.thread_switcher(threads,
(1, 2, 3, 0, 1, 0, 2, t3_c, 1, 3, 2, t3_resolve, 0, 0, 0, (1, 2, 3, 0, 1, 0, 2, t3_c, 1, 3, 2, t3_resolve, 0, 0, 0,
t1_rebase, 2, t3_b, 3, t4_d, 0, 2, 2), t1_rebase, 2, t3_b, 3, t4_d, 0, 2, 2),
('tpc_begin', 'tpc_begin', 'tpc_begin', 'tpc_begin', 2, 1, 1, ('tpc_begin', 'tpc_begin', 'tpc_begin', 'tpc_begin',
3, 3, 4, 4, 3, 1, 'RebaseTransaction', 'RebaseTransaction', 2, 1, 1, 3, 3, 4, 4, 3, 1,
'AskRebaseTransaction', 'AskRebaseTransaction',
'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 2 'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 2
)) as end: )) as end:
delay = f.delayAskFetchTransactions() delay = f.delayAskFetchTransactions()
...@@ -2395,11 +2397,11 @@ class Test(NEOThreadedTest): ...@@ -2395,11 +2397,11 @@ class Test(NEOThreadedTest):
t4.begin() t4.begin()
self.assertEqual([15, 11, 13, 16], [r[x].value for x in 'abcd']) self.assertEqual([15, 11, 13, 16], [r[x].value for x in 'abcd'])
self.assertEqual([2, 2], map(end.pop(2).count, self.assertEqual([2, 2], map(end.pop(2).count,
['RebaseTransaction', 'AnswerRebaseTransaction'])) ['AskRebaseTransaction', 'AnswerRebaseTransaction']))
self.assertEqual(end, { self.assertEqual(end, {
0: [1, 'StoreTransaction'], 0: [1, 'AskStoreTransaction'],
1: ['StoreTransaction'], 1: ['AskStoreTransaction'],
3: [4, 'StoreTransaction'], 3: [4, 'AskStoreTransaction'],
}) })
self.assertFalse(s1.dm.getOrphanList()) self.assertFalse(s1.dm.getOrphanList())
...@@ -2435,7 +2437,8 @@ class Test(NEOThreadedTest): ...@@ -2435,7 +2437,8 @@ class Test(NEOThreadedTest):
self.thread_switcher((thread,), self.thread_switcher((thread,),
(1, 0, 1, 1, t2_b, 0, 0, 1, t2_vote, 0, 0), (1, 0, 1, 1, t2_b, 0, 0, 1, t2_vote, 0, 0),
('tpc_begin', 'tpc_begin', 1, 1, 2, 2, ('tpc_begin', 'tpc_begin', 1, 1, 2, 2,
'RebaseTransaction', 'RebaseTransaction', 'StoreTransaction', 'AskRebaseTransaction', 'AskRebaseTransaction',
'AskStoreTransaction',
'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AnswerRebaseTransaction',
)) as end: )) as end:
delay = f.delayAskFetchTransactions() delay = f.delayAskFetchTransactions()
...@@ -2498,9 +2501,10 @@ class Test(NEOThreadedTest): ...@@ -2498,9 +2501,10 @@ class Test(NEOThreadedTest):
with self.thread_switcher((commit23,), with self.thread_switcher((commit23,),
(1, 1, 0, 0, t1_rebase, 0, 0, 0, 1, 1, 1, 1, 0), (1, 1, 0, 0, t1_rebase, 0, 0, 0, 1, 1, 1, 1, 0),
('tpc_begin', 'tpc_begin', 0, 1, 0, ('tpc_begin', 'tpc_begin', 0, 1, 0,
'RebaseTransaction', 'RebaseTransaction', 'AskRebaseTransaction', 'AskRebaseTransaction',
'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AnswerRebaseTransaction',
'StoreTransaction', 'tpc_begin', 1, 'tpc_abort')) as end: 'AskStoreTransaction', 'tpc_begin', 1, 'tpc_abort',
)) as end:
self.assertRaises(POSException.ConflictError, t1.commit) self.assertRaises(POSException.ConflictError, t1.commit)
commit23.join() commit23.join()
self.assertEqual(end, {0: ['tpc_abort']}) self.assertEqual(end, {0: ['tpc_abort']})
...@@ -2587,9 +2591,9 @@ class Test(NEOThreadedTest): ...@@ -2587,9 +2591,9 @@ class Test(NEOThreadedTest):
self.thread_switcher((commit2,), self.thread_switcher((commit2,),
(1, 1, 0, 0, t1_b, t1_resolve, 0, 0, 0, 0, 1, t2_vote, t1_end), (1, 1, 0, 0, t1_b, t1_resolve, 0, 0, 0, 0, 1, t2_vote, t1_end),
('tpc_begin', 'tpc_begin', 2, 1, 2, 1, 1, ('tpc_begin', 'tpc_begin', 2, 1, 2, 1, 1,
'RebaseTransaction', 'RebaseTransaction', 'AskRebaseTransaction', 'AskRebaseTransaction',
'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AnswerRebaseTransaction',
'StoreTransaction')) as end: 'AskStoreTransaction')) as end:
t1.commit() t1.commit()
commit2.join() commit2.join()
t1.begin() t1.begin()
...@@ -2597,7 +2601,7 @@ class Test(NEOThreadedTest): ...@@ -2597,7 +2601,7 @@ class Test(NEOThreadedTest):
self.assertEqual(r['a'].value, 9) self.assertEqual(r['a'].value, 9)
self.assertEqual(r['b'].value, 6) self.assertEqual(r['b'].value, 6)
t1 = end.pop(0) t1 = end.pop(0)
self.assertEqual(t1.pop(), 'StoreTransaction') self.assertEqual(t1.pop(), 'AskStoreTransaction')
self.assertEqual(sorted(t1), [1, 2]) self.assertEqual(sorted(t1), [1, 2])
self.assertFalse(end) self.assertFalse(end)
self.assertPartitionTable(cluster, 'UU|UU') self.assertPartitionTable(cluster, 'UU|UU')
...@@ -2699,9 +2703,9 @@ class Test(NEOThreadedTest): ...@@ -2699,9 +2703,9 @@ class Test(NEOThreadedTest):
with Patch(cluster.client, _loadFromStorage=load) as p, \ with Patch(cluster.client, _loadFromStorage=load) as p, \
self.thread_switcher((commit2,), self.thread_switcher((commit2,),
(1, 0, tic1, 0, t1_resolve, 1, t2_begin, 0, 1, 1, 0), (1, 0, tic1, 0, t1_resolve, 1, t2_begin, 0, 1, 1, 0),
('tpc_begin', 'tpc_begin', 1, 1, 1, 'StoreTransaction', ('tpc_begin', 'tpc_begin', 1, 1, 1, 'AskStoreTransaction',
'tpc_begin', 'RebaseTransaction', 'RebaseTransaction', 1, 'tpc_begin', 'AskRebaseTransaction', 'AskRebaseTransaction',
'StoreTransaction')) as end: 1, 'AskStoreTransaction')) as end:
self.assertRaisesRegexp(NEOStorageError, self.assertRaisesRegexp(NEOStorageError,
'^partition 0 not fully write-locked$', '^partition 0 not fully write-locked$',
t1.commit) t1.commit)
...@@ -2754,13 +2758,14 @@ class Test(NEOThreadedTest): ...@@ -2754,13 +2758,14 @@ class Test(NEOThreadedTest):
f.remove(delayFinish) f.remove(delayFinish)
with self.thread_switcher((commit2,), with self.thread_switcher((commit2,),
(1, 0, 0, 1, t2_b, 0, t1_resolve), (1, 0, 0, 1, t2_b, 0, t1_resolve),
('tpc_begin', 'tpc_begin', 0, 2, 2, 'StoreTransaction')) as end: ('tpc_begin', 'tpc_begin', 0, 2, 2, 'AskStoreTransaction')
) as end:
t1.commit() t1.commit()
commit2.join() commit2.join()
t1.begin() t1.begin()
self.assertEqual(c1.root()['b'].value, 6) self.assertEqual(c1.root()['b'].value, 6)
self.assertPartitionTable(cluster, 'UU|UU') self.assertPartitionTable(cluster, 'UU|UU')
self.assertEqual(end, {0: [2, 2, 'StoreTransaction']}) self.assertEqual(end, {0: [2, 2, 'AskStoreTransaction']})
self.assertFalse(s1.dm.getOrphanList()) self.assertFalse(s1.dm.getOrphanList())
@with_cluster(storage_count=2, partitions=2) @with_cluster(storage_count=2, partitions=2)
...@@ -2783,19 +2788,19 @@ class Test(NEOThreadedTest): ...@@ -2783,19 +2788,19 @@ class Test(NEOThreadedTest):
yield 1 yield 1
self.tic() self.tic()
with self.thread_switcher((t,), (1, 0, 1, 0, t1_b, 0, 0, 0, 1), with self.thread_switcher((t,), (1, 0, 1, 0, t1_b, 0, 0, 0, 1),
('tpc_begin', 'tpc_begin', 1, 3, 3, 1, 'RebaseTransaction', ('tpc_begin', 'tpc_begin', 1, 3, 3, 1, 'AskRebaseTransaction',
2, 'AnswerRebaseTransaction')) as end: 2, 'AnswerRebaseTransaction')) as end:
t1.commit() t1.commit()
t.join() t.join()
t2.begin() t2.begin()
self.assertEqual([6, 9, 6], [r[x].value for x in 'abc']) self.assertEqual([6, 9, 6], [r[x].value for x in 'abc'])
self.assertEqual([2, 2], map(end.pop(1).count, self.assertEqual([2, 2], map(end.pop(1).count,
['RebaseTransaction', 'AnswerRebaseTransaction'])) ['AskRebaseTransaction', 'AnswerRebaseTransaction']))
# Rarely, there's an extra deadlock for t1: # Rarely, there's an extra deadlock for t1:
# 0: ['AnswerRebaseTransaction', 'RebaseTransaction', # 0: ['AnswerRebaseTransaction', 'AskRebaseTransaction',
# 'RebaseTransaction', 'AnswerRebaseTransaction', # 'AskRebaseTransaction', 'AnswerRebaseTransaction',
# 'AnswerRebaseTransaction', 2, 3, 1, # 'AnswerRebaseTransaction', 2, 3, 1,
# 'StoreTransaction', 'VoteTransaction'] # 'AskStoreTransaction', 'VoteTransaction']
self.assertEqual(end.pop(0)[0], 'AnswerRebaseTransaction') self.assertEqual(end.pop(0)[0], 'AnswerRebaseTransaction')
self.assertFalse(end) self.assertFalse(end)
...@@ -2825,13 +2830,13 @@ class Test(NEOThreadedTest): ...@@ -2825,13 +2830,13 @@ class Test(NEOThreadedTest):
threads = map(self.newPausedThread, (t2.commit, t3.commit)) threads = map(self.newPausedThread, (t2.commit, t3.commit))
with self.thread_switcher(threads, (1, 2, 0, 1, 2, 1, 0, 2, 0, 1, 2), with self.thread_switcher(threads, (1, 2, 0, 1, 2, 1, 0, 2, 0, 1, 2),
('tpc_begin', 'tpc_begin', 'tpc_begin', 1, 2, 3, 4, 4, 4, ('tpc_begin', 'tpc_begin', 'tpc_begin', 1, 2, 3, 4, 4, 4,
'RebaseTransaction', 'StoreTransaction')) as end: 'AskRebaseTransaction', 'AskStoreTransaction')) as end:
t1.commit() t1.commit()
for t in threads: for t in threads:
t.join() t.join()
self.assertEqual(end, { self.assertEqual(end, {
0: ['AnswerRebaseTransaction', 'StoreTransaction'], 0: ['AnswerRebaseTransaction', 'AskStoreTransaction'],
2: ['StoreTransaction']}) 2: ['AskStoreTransaction']})
@with_cluster(replicas=1) @with_cluster(replicas=1)
def testConflictAfterDeadlockWithSlowReplica1(self, cluster, def testConflictAfterDeadlockWithSlowReplica1(self, cluster,
...@@ -2874,16 +2879,16 @@ class Test(NEOThreadedTest): ...@@ -2874,16 +2879,16 @@ class Test(NEOThreadedTest):
order[-1] = t1_resolve order[-1] = t1_resolve
delay = f.delayAskStoreObject() delay = f.delayAskStoreObject()
with self.thread_switcher((t,), order, with self.thread_switcher((t,), order,
('tpc_begin', 'tpc_begin', 1, 1, 2, 2, 'RebaseTransaction', ('tpc_begin', 'tpc_begin', 1, 1, 2, 2, 'AskRebaseTransaction',
'RebaseTransaction', 'AnswerRebaseTransaction', 'AskRebaseTransaction', 'AnswerRebaseTransaction',
'StoreTransaction')) as end: 'AskStoreTransaction')) as end:
t1.commit() t1.commit()
t.join() t.join()
self.assertNotIn(delay, f) self.assertNotIn(delay, f)
t2.begin() t2.begin()
end[0].sort(key=str) end[0].sort(key=str)
self.assertEqual(end, {0: [1, 'AnswerRebaseTransaction', self.assertEqual(end, {0: [1, 'AnswerRebaseTransaction',
'StoreTransaction']}) 'AskStoreTransaction']})
self.assertEqual([4, 2], [r[x].value for x in 'ab']) self.assertEqual([4, 2], [r[x].value for x in 'ab'])
def testConflictAfterDeadlockWithSlowReplica2(self): def testConflictAfterDeadlockWithSlowReplica2(self):
...@@ -2934,7 +2939,7 @@ class Test(NEOThreadedTest): ...@@ -2934,7 +2939,7 @@ class Test(NEOThreadedTest):
with ConnectionFilter() as f: with ConnectionFilter() as f:
f.add(lambda conn, packet: f.add(lambda conn, packet:
isinstance(packet, Packets.RequestIdentification) isinstance(packet, Packets.RequestIdentification)
and packet.decode()[0] == NodeTypes.STORAGE) and packet._args[0] == NodeTypes.STORAGE)
self.tic() self.tic()
m2.start() m2.start()
self.tic() self.tic()
...@@ -2974,7 +2979,7 @@ class Test(NEOThreadedTest): ...@@ -2974,7 +2979,7 @@ class Test(NEOThreadedTest):
with ConnectionFilter() as f: with ConnectionFilter() as f:
f.add(lambda conn, packet: f.add(lambda conn, packet:
isinstance(packet, Packets.RequestIdentification) isinstance(packet, Packets.RequestIdentification)
and packet.decode()[0] == NodeTypes.MASTER) and packet._args[0] == NodeTypes.MASTER)
cluster.start(recovering=True) cluster.start(recovering=True)
neoctl = cluster.neoctl neoctl = cluster.neoctl
getClusterState = neoctl.getClusterState getClusterState = neoctl.getClusterState
......
...@@ -113,7 +113,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -113,7 +113,7 @@ class ReplicationTests(NEOThreadedTest):
importZODB(3) importZODB(3)
def delaySecondary(conn, packet): def delaySecondary(conn, packet):
if isinstance(packet, Packets.Replicate): if isinstance(packet, Packets.Replicate):
tid, upstream_name, source_dict = packet.decode() tid, upstream_name, source_dict = packet._args
return not upstream_name and all(source_dict.itervalues()) return not upstream_name and all(source_dict.itervalues())
# U -> B propagation # U -> B propagation
with NEOCluster(partitions=np, replicas=nr-1, storage_count=5, with NEOCluster(partitions=np, replicas=nr-1, storage_count=5,
...@@ -513,7 +513,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -513,7 +513,7 @@ class ReplicationTests(NEOThreadedTest):
""" """
def delayAskFetch(conn, packet): def delayAskFetch(conn, packet):
return isinstance(packet, delayed) and \ return isinstance(packet, delayed) and \
packet.decode()[0] == offset and \ packet._args[0] == offset and \
conn in s1.getConnectionList(s0) conn in s1.getConnectionList(s0)
def changePartitionTable(orig, ptid, num_replicas, cell_list): def changePartitionTable(orig, ptid, num_replicas, cell_list):
if (offset, s0.uuid, CellStates.DISCARDED) in cell_list: if (offset, s0.uuid, CellStates.DISCARDED) in cell_list:
...@@ -768,7 +768,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -768,7 +768,7 @@ class ReplicationTests(NEOThreadedTest):
def logReplication(conn, packet): def logReplication(conn, packet):
if isinstance(packet, (Packets.AskFetchTransactions, if isinstance(packet, (Packets.AskFetchTransactions,
Packets.AskFetchObjects)): Packets.AskFetchObjects)):
ask.append(packet.decode()[2:]) ask.append(packet._args[2:])
def getTIDList(): def getTIDList():
return [t.tid for t in c.db().storage.iterator()] return [t.tid for t in c.db().storage.iterator()]
s0, s1 = cluster.storage_list s0, s1 = cluster.storage_list
...@@ -869,7 +869,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -869,7 +869,7 @@ class ReplicationTests(NEOThreadedTest):
return True return True
elif not isinstance(packet, Packets.AskFetchTransactions): elif not isinstance(packet, Packets.AskFetchTransactions):
return return
ask.append(packet.decode()) ask.append(packet._args)
conn, = upstream.master.getConnectionList(backup.master) conn, = upstream.master.getConnectionList(backup.master)
with ConnectionFilter() as f, Patch(replicator.Replicator, with ConnectionFilter() as f, Patch(replicator.Replicator,
_nextPartitionSortKey=lambda orig, self, offset: offset): _nextPartitionSortKey=lambda orig, self, offset: offset):
...@@ -930,11 +930,11 @@ class ReplicationTests(NEOThreadedTest): ...@@ -930,11 +930,11 @@ class ReplicationTests(NEOThreadedTest):
@f.add @f.add
def delayReplicate(conn, packet): def delayReplicate(conn, packet):
if isinstance(packet, Packets.AskFetchTransactions): if isinstance(packet, Packets.AskFetchTransactions):
trans.append(packet.decode()[2]) trans.append(packet._args[2])
elif isinstance(packet, Packets.AskFetchObjects): elif isinstance(packet, Packets.AskFetchObjects):
if obj: if obj:
return True return True
obj.append(packet.decode()[2]) obj.append(packet._args[2])
s2.start() s2.start()
self.tic() self.tic()
cluster.neoctl.enableStorageList([s2.uuid]) cluster.neoctl.enableStorageList([s2.uuid])
...@@ -1021,7 +1021,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -1021,7 +1021,7 @@ class ReplicationTests(NEOThreadedTest):
def expected(changed): def expected(changed):
s0 = 1, CellStates.UP_TO_DATE s0 = 1, CellStates.UP_TO_DATE
s = CellStates.OUT_OF_DATE if changed else CellStates.UP_TO_DATE s = CellStates.OUT_OF_DATE if changed else CellStates.UP_TO_DATE
return changed, 3 * [[s0, (2, s)], [s0, (3, s)]] return changed, 3 * ((s0, (2, s)), (s0, (3, s)))
for dry_run in True, False: for dry_run in True, False:
self.assertEqual(expected(True), self.assertEqual(expected(True),
cluster.neoctl.tweakPartitionTable(drop_list, dry_run)) cluster.neoctl.tweakPartitionTable(drop_list, dry_run))
......
...@@ -53,7 +53,7 @@ extras_require = { ...@@ -53,7 +53,7 @@ extras_require = {
'master': [], 'master': [],
'storage-sqlite': [], 'storage-sqlite': [],
'storage-mysqldb': ['mysqlclient'], 'storage-mysqldb': ['mysqlclient'],
'storage-importer': zodb_require + ['msgpack>=0.5.6', 'setproctitle'], 'storage-importer': zodb_require + ['setproctitle'],
} }
extras_require['tests'] = ['coverage', 'zope.testing', 'psutil>=2', extras_require['tests'] = ['coverage', 'zope.testing', 'psutil>=2',
'neoppod[%s]' % ', '.join(extras_require)] 'neoppod[%s]' % ', '.join(extras_require)]
...@@ -109,6 +109,7 @@ setup( ...@@ -109,6 +109,7 @@ setup(
], ],
}, },
install_requires = [ install_requires = [
'msgpack>=0.5.6',
'python-dateutil', # neolog --from 'python-dateutil', # neolog --from
], ],
extras_require = extras_require, extras_require = extras_require,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment