...
 
......@@ -14,34 +14,72 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from neo.lib import logging, protocol
from neo.lib import logging
from neo.lib.handler import EventHandler
from neo.lib.protocol import uuid_str, Packets
from neo.lib.protocol import uuid_str, \
NodeTypes, NotReadyError, Packets, ProtocolError
from neo.lib.pt import PartitionTable
from neo.lib.exception import PrimaryFailure
def check_primary_master(func):
def wrapper(self, *args, **kw):
if self.app.master_conn is not None:
return func(self, *args, **kw)
raise protocol.NotReadyError('Not connected to a primary master.')
return wrapper
def forward_ask(klass):
return check_primary_master(lambda self, conn, *args, **kw:
self.app.master_conn.ask(klass(*args, **kw),
conn=conn, msg_id=conn.getPeerId()))
NOT_CONNECTED_MESSAGE = 'Not connected to a primary master.'
def AdminEventHandlerType(name, bases, d):
def check_connection(func):
return lambda self, conn, *args, **kw: \
self._checkConnection(conn) and func(self, conn, *args, **kw)
def forward_ask(klass):
return lambda self, conn, *args: self.app.master_conn.ask(
klass(*args), conn=conn, msg_id=conn.getPeerId())
del d['__metaclass__']
for x in (
Packets.AddPendingNodes,
Packets.AskLastIDs,
Packets.AskLastTransaction,
Packets.AskRecovery,
Packets.CheckReplicas,
Packets.Repair,
Packets.SetClusterState,
Packets.SetNodeState,
Packets.SetNumReplicas,
Packets.Truncate,
Packets.TweakPartitionTable,
):
d[x.handler_method_name] = forward_ask(x)
return type(name, bases, {k: v if k[0] == '_' else check_connection(v)
for k, v in d.iteritems()})
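# Illustrative sketch (hypothetical names): the pattern used by
# AdminEventHandlerType above. In Py2, a function assigned to
# __metaclass__ is called with (name, bases, dict) and may return a
# transformed class; public methods get a guard, private ones do not.
def GuardedType(name, bases, d):
    def guarded(func):
        def wrapper(self, *args, **kw):
            self._check()  # e.g. raise NotReadyError when not connected
            return func(self, *args, **kw)
        return wrapper
    del d['__metaclass__']
    return type(name, bases, {k: v if k[0] == '_' else guarded(v)
                              for k, v in d.iteritems()})
class Guarded(object):
    __metaclass__ = GuardedType  # Py3 would use class Guarded(metaclass=...)
    def _check(self):
        pass
    def ping(self):
        return 'pong'  # wrapped: _check() runs first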
class AdminEventHandler(EventHandler):
"""This class deals with events for administrating cluster."""
@check_primary_master
__metaclass__ = AdminEventHandlerType
def _checkConnection(self, conn):
if self.app.master_conn is None:
raise NotReadyError(NOT_CONNECTED_MESSAGE)
return True
def requestIdentification(self, conn, node_type, uuid, address, name, *_):
if node_type != NodeTypes.ADMIN:
raise ProtocolError("reject non-admin node")
app = self.app
try:
backup = app.backup_dict[name]
except KeyError:
raise ProtocolError("unknown backup cluster %r" % name)
if backup.conn is not None:
raise ProtocolError("already connected")
backup.conn = conn
conn.setHandler(app.backup_handler)
conn.answer(Packets.AcceptIdentification(
NodeTypes.ADMIN, None, None))
def askPartitionList(self, conn, min_offset, max_offset, uuid):
logging.info("ask partition list from %s to %s for %s",
min_offset, max_offset, uuid_str(uuid))
self.app.sendPartitionTable(conn, min_offset, max_offset, uuid)
@check_primary_master
def askNodeList(self, conn, node_type):
if node_type is None:
node_type = 'all'
......@@ -54,37 +92,25 @@ class AdminEventHandler(EventHandler):
p = Packets.AnswerNodeList(node_information_list)
conn.answer(p)
@check_primary_master
def askClusterState(self, conn):
conn.answer(Packets.AnswerClusterState(self.app.cluster_state))
@check_primary_master
def askPrimary(self, conn):
master_node = self.app.master_node
conn.answer(Packets.AnswerPrimary(master_node.getUUID()))
@check_primary_master
def flushLog(self, conn):
self.app.master_conn.send(Packets.FlushLog())
super(AdminEventHandler, self).flushLog(conn)
askLastIDs = forward_ask(Packets.AskLastIDs)
askLastTransaction = forward_ask(Packets.AskLastTransaction)
addPendingNodes = forward_ask(Packets.AddPendingNodes)
askRecovery = forward_ask(Packets.AskRecovery)
tweakPartitionTable = forward_ask(Packets.TweakPartitionTable)
setClusterState = forward_ask(Packets.SetClusterState)
setNodeState = forward_ask(Packets.SetNodeState)
setNumReplicas = forward_ask(Packets.SetNumReplicas)
checkReplicas = forward_ask(Packets.CheckReplicas)
truncate = forward_ask(Packets.Truncate)
repair = forward_ask(Packets.Repair)
def askMonitorInformation(self, conn):
self.app.askMonitorInformation(conn)
class MasterEventHandler(EventHandler):
""" This class is just used to dispatch message to right handler"""
def _connectionLost(self, conn):
def connectionClosed(self, conn):
app = self.app
if app.listening_conn: # if running
assert app.master_conn in (conn, None)
......@@ -93,38 +119,108 @@ class MasterEventHandler(EventHandler):
app.uuid = None
raise PrimaryFailure
def connectionFailed(self, conn):
self._connectionLost(conn)
def connectionClosed(self, conn):
self._connectionLost(conn)
def dispatch(self, conn, packet, kw={}):
if 'conn' in kw:
# expected answer
if packet.isResponse():
packet.setId(kw['msg_id'])
kw['conn'].answer(packet)
else:
self.app.request_handler.dispatch(conn, packet, kw)
else:
# unexpected answers and notifications
forward = kw.get('conn')
if forward is None:
super(MasterEventHandler, self).dispatch(conn, packet, kw)
else:
forward.send(packet, kw['msg_id'])
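# Note, as read from the code above: answers to requests that were
# forwarded on behalf of an admin client carry 'conn'/'msg_id' in kw
# and are relayed back to that client under its original message id;
# everything else is dispatched locally.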
def answerClusterState(self, conn, state):
self.app.cluster_state = state
self.app.updateMonitorInformation(None, cluster_state=state)
notifyClusterInformation = answerClusterState
def sendPartitionTable(self, conn, ptid, num_replicas, row_list):
pt = self.app.pt = object.__new__(PartitionTable)
pt.load(ptid, num_replicas, row_list, self.app.nm)
app = self.app
app.pt = object.__new__(PartitionTable)
app.pt.load(ptid, num_replicas, row_list, app.nm)
app.partitionTableUpdated()
def notifyPartitionChanges(self, conn, ptid, num_replicas, cell_list):
self.app.pt.update(ptid, num_replicas, cell_list, self.app.nm)
app = self.app
app.pt.update(ptid, num_replicas, cell_list, app.nm)
app.partitionTableUpdated()
def notifyNodeInformation(self, *args):
super(MasterEventHandler, self).notifyNodeInformation(*args)
self.app.partitionTableUpdated()
def notifyClusterInformation(self, conn, cluster_state):
self.app.cluster_state = cluster_state
def notifyUpstreamAdmin(self, conn, addr):
app = self.app
node = app.upstream_admin
if node is None:
node = app.upstream_admin = app.nm.createAdmin()
elif node.getAddress() == addr:
return
node.setAddress(addr)
if app.upstream_admin_conn:
app.upstream_admin_conn.close()
else:
app.connectToUpstreamAdmin()
def answerLastTransaction(self, conn, ltid):
app = self.app
app.ltid = ltid
app.maybeNotify(None)
def answerRecovery(self, name, ptid, backup_tid, truncate_tid):
self.app.backup_tid = backup_tid
def monitor(func):
def wrapper(self, conn, *args, **kw):
for name, backup in self.app.backup_dict.iteritems():
if backup.conn is conn:
return func(self, name, *args, **kw)
raise AssertionError
return wrapper
class BackupHandler(EventHandler):
@monitor
def connectionClosed(self, name):
app = self.app
old = app.backup_dict[name]
new = app.backup_dict[name] = old.__class__()
new.max_lag = old.max_lag
app.maybeNotify(name)
@monitor
def notifyMonitorInformation(self, name, info):
self.app.updateMonitorInformation(name, **info)
class MasterRequestEventHandler(EventHandler):
""" This class handle all answer from primary master node"""
# XXX: to be deleted ?
@monitor
def answerRecovery(self, name, ptid, backup_tid, truncate_tid):
self.app.backup_dict[name].backup_tid = backup_tid
@monitor
def answerLastTransaction(self, name, ltid):
app = self.app
app.backup_dict[name].ltid = ltid
app.maybeNotify(name)
class UpstreamAdminHandler(AdminEventHandler):
def _checkConnection(self, conn):
assert conn is self.app.upstream_admin_conn
return super(UpstreamAdminHandler, self)._checkConnection(conn)
def connectionClosed(self, conn):
app = self.app
if conn is app.upstream_admin_conn:
app.connectToUpstreamAdmin()
connectionFailed = connectionClosed
def connectionCompleted(self, conn):
super(UpstreamAdminHandler, self).connectionCompleted(conn)
conn.ask(Packets.RequestIdentification(NodeTypes.ADMIN,
None, None, self.app.name, None, {}))
def _acceptIdentification(self, node):
node.send(Packets.NotifyMonitorInformation({
'cluster_state': self.app.cluster_state,
'down': self.app.down,
'pt_summary': self.app.pt_summary,
}))
......@@ -13,6 +13,13 @@
##############################################################################
def patch():
# For msgpack & Py2/ZODB5.
try:
from zodbpickle import binary
binary._pack = bytes.__str__
except ImportError:
pass
from hashlib import md5
from ZODB.Connection import Connection
......@@ -87,4 +94,4 @@ def patch():
patch()
import app # set up signal handlers early enough to do it in the main thread
from . import app # set up signal handlers early enough to do it in the main thread
# -*- coding: utf-8 -*-
#
# Copyright (C) 2006-2019 Nexedi SA
#
......@@ -45,8 +46,7 @@ class PrimaryNotificationsHandler(MTEventHandler):
# Either we're connecting or we already know the last tid
# via invalidations.
assert app.master_conn is None, app.master_conn
app._cache_lock_acquire()
try:
with app._cache_lock:
if app_last_tid < ltid:
app._cache.clear_current()
# In the past, we tried not to invalidate the
......@@ -60,9 +60,7 @@ class PrimaryNotificationsHandler(MTEventHandler):
app._cache.clear()
# Make sure a parallel load won't refill the cache
# with garbage.
app._loading_oid = app._loading_invalidated = None
finally:
app._cache_lock_release()
app._loading.clear()
db = app.getDB()
db is None or db.invalidateCache()
app.last_tid = ltid
......@@ -70,21 +68,22 @@ class PrimaryNotificationsHandler(MTEventHandler):
def answerTransactionFinished(self, conn, _, tid, callback, cache_dict):
app = self.app
app.last_tid = tid
# Update cache
cache = app._cache
app._cache_lock_acquire()
try:
invalidate = cache.invalidate
loading_get = app._loading.get
with app._cache_lock:
for oid, data in cache_dict.iteritems():
# Update ex-latest value in cache
cache.invalidate(oid, tid)
invalidate(oid, tid)
loading = loading_get(oid)
if loading:
loading[1].append(tid)
if data is not None:
# Store in cache with no next_tid
cache.store(oid, data, tid, None)
if callback is not None:
callback(tid)
finally:
app._cache_lock_release()
app.last_tid = tid # see comment in invalidateObjects
def connectionClosed(self, conn):
app = self.app
......@@ -112,20 +111,24 @@ class PrimaryNotificationsHandler(MTEventHandler):
app = self.app
if app.ignore_invalidations:
return
app.last_tid = tid
app._cache_lock_acquire()
try:
with app._cache_lock:
invalidate = app._cache.invalidate
loading = app._loading_oid
loading_get = app._loading.get
for oid in oid_list:
invalidate(oid, tid)
if oid == loading:
app._loading_invalidated.append(tid)
loading = loading_get(oid)
if loading:
loading[1].append(tid)
db = app.getDB()
if db is not None:
db.invalidate(tid, oid_list)
finally:
app._cache_lock_release()
# ZODB<5: Update before releasing the lock so that app.load
# asks the last serial (with respect to already processed
# invalidations by Connection._setstate).
# ZODB≥5: Update after db.invalidate because the MVCC
# adapter starts at the greatest TID between
# IStorage.lastTransaction and processed invalidations.
app.last_tid = tid
def sendPartitionTable(self, conn, ptid, num_replicas, row_list):
pt = self.app.pt = object.__new__(PartitionTable)
......@@ -146,13 +149,6 @@ class PrimaryNotificationsHandler(MTEventHandler):
if node and node.isConnected():
node.getConnection().close()
def notifyDeadlock(self, conn, ttid, locking_tid):
for txn_context in self.app.txn_contexts():
if txn_context.ttid == ttid:
txn_context.conflict_dict[None] = locking_tid
txn_context.wakeup(conn)
break
class PrimaryAnswersHandler(AnswerBaseHandler):
""" Handle that process expected packets from the primary master """
......@@ -160,8 +156,7 @@ class PrimaryAnswersHandler(AnswerBaseHandler):
self.app.setHandlerData(ttid)
def answerNewOIDs(self, conn, oid_list):
oid_list.reverse()
self.app.new_oid_list = oid_list
self.app.new_oids = iter(oid_list)
def incompleteTransaction(self, conn, message):
raise NEOStorageError("storage nodes for which vote failed cannot be"
......
......@@ -28,11 +28,24 @@ from ..transactions import Transaction
from ..exception import NEOStorageError, NEOStorageNotFoundError
from ..exception import NEOStorageReadRetry, NEOStorageDoesNotExistError
@apply
class _DeadlockPacket(object):
handler_method_name = 'notifyDeadlock'
_args = ()
getId = int
class StorageEventHandler(MTEventHandler):
def _acceptIdentification(*args):
pass
def notifyDeadlock(self, conn, ttid, oid):
for txn_context in self.app.txn_contexts():
if txn_context.ttid == ttid:
txn_context.queue.put((conn, _DeadlockPacket, {'oid': oid}))
break
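# Sketch of the pseudo-packet pattern used by _DeadlockPacket above
# (hypothetical names): the dispatcher only needs handler_method_name,
# _args and getId(), so a plain instance can be queued like a decoded
# network packet, waking up the worker thread without network traffic.
class _FakePacket(object):
    handler_method_name = 'notifyDeadlock'
    _args = ()
    getId = int  # _FakePacket.getId() == 0, a harmless dummy message id
_FakePacket = _FakePacket()  # equivalent to decorating the class with @apply
# A queue consumer then effectively performs, for (conn, packet, kw):
#   getattr(handler, packet.handler_method_name)(conn, *packet._args, **kw)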
class StorageBootstrapHandler(AnswerBaseHandler):
""" Handler used when connecting to a storage node """
......@@ -72,22 +85,28 @@ class StorageAnswersHandler(AnswerBaseHandler):
answerCheckCurrentSerial = answerStoreObject
def answerRebaseTransaction(self, conn, oid_list):
def notifyDeadlock(self, conn, oid):
# To avoid a possible deadlock, storage nodes are waiting for our
# lock to be cancelled, so that a transaction that started earlier
# can complete. This is done by acquiring the lock again.
txn_context = self.app.getHandlerData()
if txn_context.stored:
return
ttid = txn_context.ttid
queue = txn_context.queue
logging.info('Deadlock avoidance triggered for %s:%s',
dump(oid), dump(ttid))
assert conn.getUUID() not in txn_context.data_dict.get(oid, ((),))[-1]
try:
for oid in oid_list:
# We could have an extra parameter to tell the storage if we
# still have the data, and in this case revert what was done
# in Transaction.written. This would save bandwidth in case of
# conflict.
conn.ask(Packets.AskRebaseObject(ttid, oid),
queue=queue, oid=oid)
# We could have an extra parameter to tell the storage if we
# still have the data, and in this case revert what was done
# in Transaction.written. This would save bandwidth in case of
# conflict.
conn.ask(Packets.AskRelockObject(ttid, oid),
queue=txn_context.queue, oid=oid)
except ConnectionClosed:
txn_context.conn_dict[conn.getUUID()] = None
def answerRebaseObject(self, conn, conflict, oid):
def answerRelockObject(self, conn, conflict, oid):
if conflict:
txn_context = self.app.getHandlerData()
serial, conflict, data = conflict
......@@ -105,7 +124,7 @@ class StorageAnswersHandler(AnswerBaseHandler):
assert oid in txn_context.data_dict
if serial <= txn_context.conflict_dict.get(oid, ''):
# Another node already reported this conflict or a newer,
# by answering to this rebase or to the previous store.
# by answering to this relock or to the previous store.
return
# A node has not answered a previous store yet. Do not wait for
# it to report the conflict because it may fail before doing so.
......@@ -119,8 +138,8 @@ class StorageAnswersHandler(AnswerBaseHandler):
compression, checksum, data = data
if checksum != makeChecksum(data):
raise NEOStorageError(
'wrong checksum while getting back data for'
' object %s during rebase of transaction %s'
'wrong checksum while getting back data for %s:%s'
' (deadlock avoidance)'
% (dump(oid), dump(txn_context.ttid)))
data = decompress_list[compression](data)
size = len(data)
......
......@@ -22,19 +22,12 @@ from neo.lib.protocol import Packets
from neo.lib.util import dump
from .exception import NEOStorageError
@apply
class _WakeupPacket(object):
handler_method_name = 'pong'
decode = tuple
getId = int
class Transaction(object):
cache_size = 0 # size of data in cache_dict
data_size = 0 # size of data in data_dict
error = None
locking_tid = None
stored = False
voted = False
ttid = None # XXX: useless, except for testBackupReadOnlyAccess
lockless_dict = None # {partition: {uuid}}
......@@ -50,21 +43,18 @@ class Transaction(object):
self.conflict_dict = {} # {oid: serial}
# resolved conflicts
self.resolved_dict = {} # {oid: serial}
# involved storage nodes; connection is None is connection was lost
# involved storage nodes; connection is None if connection was lost
self.conn_dict = {} # {node_id: connection}
def __repr__(self):
error = self.error
return ("<%s ttid=%s locking_tid=%s voted=%u"
return ("<%s ttid=%s voted=%u"
" #queue=%s #writing=%s #written=%s%s>") % (
self.__class__.__name__,
dump(self.ttid), dump(self.locking_tid), self.voted,
dump(self.ttid), self.voted,
len(self.queue._queue), len(self.data_dict), len(self.cache_dict),
' error=%r' % error if error else '')
def wakeup(self, conn):
self.queue.put((conn, _WakeupPacket, {}))
def write(self, app, packet, object_id, **kw):
uuid_list = []
pt = app.pt
......@@ -78,14 +68,6 @@ class Transaction(object):
conn = conn_dict[uuid]
except KeyError:
conn = conn_dict[uuid] = app.getStorageConnection(node)
if self.locking_tid and 'oid' in kw:
# A deadlock happened but this node is not aware of it.
# Tell it to write-lock with the same locking tid as
# for the other nodes. The condition on kw is to
# distinguish whether we're writing an oid or
# transaction metadata.
conn.ask(Packets.AskRebaseTransaction(
self.ttid, self.locking_tid), queue=self.queue)
conn.ask(packet, queue=self.queue, **kw)
uuid_list.append(uuid)
except AttributeError:
......@@ -117,18 +99,19 @@ class Transaction(object):
return
if lockless:
if lockless != serial: # late lockless write
# Oops! We shouldn't have executed the above 'remove'. Re-add.
assert lockless < serial, (lockless, serial)
uuid_list.append(uuid)
return
# It's safe to do this after the above excepts: either the cell is
# It's safe to do this after the above except: either the cell is
# already marked as lockless or the node will be reported as failed.
lockless = self.lockless_dict
if not lockless:
lockless = self.lockless_dict = defaultdict(set)
lockless[app.pt.getPartition(oid)].add(uuid)
if oid in self.conflict_dict:
# In the case of a rebase, uuid_list may not contain the id
# of the node reporting a conflict.
# In case of deadlock avoidance, uuid_list may not contain the
# id of the node reporting a conflict.
return
if uuid_list:
return
......
#
# Copyright (C) 2017-2019 Nexedi SA
# Copyright (C) 2017-2021 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
......@@ -17,7 +17,7 @@
URI format:
neo://name@master1,master2,...,masterN?options
neo(s)://[credentials@]master1,master2,...,masterN/name?options
"""
import ZODB.config
......@@ -27,6 +27,9 @@ from cStringIO import StringIO
from collections import OrderedDict
from urlparse import urlsplit, parse_qsl
# _credopts defines which options correspond to credentials
_credopts = {'ca', 'cert', 'key'}
# neo_zconf_options returns set of zconfig options supported by NEO storage
def neo_zconf_options():
neo_schema = """<schema>
......@@ -41,6 +44,8 @@ def neo_zconf_options():
options = {k for k, _ in neo_storage_zconf}
assert 'master_nodes' in options
assert 'name' in options
for opt in _credopts:
assert opt in options, opt
return options
......@@ -52,27 +57,46 @@ def canonical_opt_name(name):
def _resolve_uri(uri):
scheme, netloc, path, query, frag = urlsplit(uri)
if scheme != "neo":
raise ValueError("invalid uri: %s : expected neo:// scheme" % uri)
if path != "":
raise ValueError("invalid uri: %s : non-empty path" % uri)
if scheme not in ("neo", "neos"):
raise ValueError("invalid uri: %s : expected neo:// or neos:// scheme" % uri)
if frag != "":
raise ValueError("invalid uri: %s : non-empty fragment" % uri)
# extract master list and name from netloc
name, masterloc = netloc.split('@', 1)
# name is given as path
if path.startswith("/"):
path = path[1:]
name = path
if name == '':
raise ValueError("invalid uri: %s : cluster name not specified" % uri)
# extract master list and credentials from netloc
cred, masterloc = '', netloc
if '@' in netloc:
cred, masterloc = netloc.split('@', 1)
master_list = masterloc.split(',')
neokw = OrderedDict()
neokw['master_nodes'] = ' '.join(master_list)
neokw['name'] = name
# parse credentials
if cred:
if scheme != "neos":
raise ValueError("invalid uri: %s : credentials can be specified only with neos:// scheme" % uri)
# ca=ca.crt;cert=my.crt;key=my.key
for k, v in OrderedDict(parse_qsl(cred)).items():
if k not in _credopts:
raise ValueError("invalid uri: %s : unexpected credential %s" % (uri, k))
neokw[k] = v
# get options from query: only those that are defined by NEO schema go to
# storage - rest are returned as database options
dbkw = {}
neo_options = neo_zconf_options()
for k, v in OrderedDict(parse_qsl(query)).items():
if k in neo_options:
if k in _credopts:
raise ValueError("invalid uri: %s : option %s must be in credentials" % (uri, k))
elif k in neo_options:
neokw[k] = v
else:
# it might be option for storage, but not in canonical form e.g.
......
......@@ -13,16 +13,20 @@ import sys
def app_set():
try:
return sys.modules['neo.lib.threaded_app'].app_set
app_set = sys.modules['neo.lib.threaded_app'].app_set
except KeyError:
f = sys._getframe(4)
try:
while f.f_code.co_name != 'run' or \
f.f_locals.get('self').__class__.__name__ != 'Application':
f = f.f_back
return f.f_locals['self'],
except AttributeError:
return ()
pass
else:
if app_set:
return app_set
f = sys._getframe(4)
try:
while f.f_code.co_name != 'run' or \
f.f_locals.get('self').__class__.__name__ != 'Application':
f = f.f_back
return f.f_locals['self'],
except AttributeError:
return ()
def defer(task):
def wrapper():
......@@ -197,8 +201,7 @@ elif IF == 'trace-cache':
@defer
def profile(app):
app._cache_lock_acquire()
try:
with app._cache_lock:
cache = app._cache
if type(cache) is ClientCache:
app._cache = CacheTracer(cache, '%s-%s.neo-cache-trace' %
......@@ -206,5 +209,3 @@ elif IF == 'trace-cache':
app._cache.clear()
else:
app._cache = cache.close()
finally:
app._cache_lock_release()
......@@ -57,8 +57,9 @@ class BaseApplication(object):
_('s', 'section', default=section,
help='specify a configuration section')
_('c', 'cluster', required=True, help='the cluster name')
_('m', 'masters', default=masters, parse=util.parseMasterList,
help='master node list')
_('m', 'masters', parse=util.parseMasterList,
help='space-separated list of master node addresses',
**{'default': masters} if masters else {'type': lambda x: x or ''})
_('b', 'bind', default=bind,
parse=lambda x: util.parseNodeAddress(x, 0),
help='the local address to bind to')
......
......@@ -26,15 +26,14 @@ class BootstrapManager(EventHandler):
Manage the bootstrap stage: look up the primary master, then connect to it
"""
def __init__(self, app, node_type, server=None, devpath=(), new_nid=()):
def __init__(self, app, node_type, server=None, **extra):
"""
Manage the bootstrap stage of a non-master node: look up the
primary master node, connect to it, then return when the master
node is ready.
"""
self.server = server
self.devpath = devpath
self.new_nid = new_nid
self.extra = extra
self.node_type = node_type
app.nm.reset()
......@@ -43,7 +42,7 @@ class BootstrapManager(EventHandler):
def connectionCompleted(self, conn):
EventHandler.connectionCompleted(self, conn)
conn.ask(Packets.RequestIdentification(self.node_type, self.uuid,
self.server, self.app.name, None, self.devpath, self.new_nid))
self.server, self.app.name, None, self.extra))
def connectionFailed(self, conn):
EventHandler.connectionFailed(self, conn)
......
......@@ -18,6 +18,15 @@ import argparse, os, sys
from functools import wraps
from ConfigParser import SafeConfigParser
class _DefaultList(list):
"""
Special list type for default values of 'append' argparse actions,
so that the parser restarts from an empty list when the option is
used on the command-line.
"""
def __copy__(self):
return []
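# Illustrative usage (hypothetical option name): argparse's 'append'
# action copies the default with copy.copy() before appending, and
# copy.copy() honours __copy__, so explicit command-line values replace
# the default instead of extending it:
#   parser.add_argument('-m', action='append', default=_DefaultList(['x']))
#   parser.parse_args([]).m            # -> ['x'] (default preserved)
#   parser.parse_args(['-m', 'y']).m   # -> ['y'], not ['x', 'y']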
class _Required(object):
......@@ -30,6 +39,8 @@ class _Required(object):
class _Option(object):
multiple = False
def __init__(self, *args, **kw):
if len(args) > 1:
self.short, self.name = args
......@@ -51,7 +62,12 @@ class _Option(object):
action.required = _Required(option_list, self.name)
def fromConfigFile(self, cfg, section):
return self(cfg.get(section, self.name.replace('-', '_')))
value = cfg.get(section, self.name.replace('-', '_'))
if self.multiple:
return [self(value)
for value in value.splitlines()
if value]
return self(value)
@staticmethod
def parse(value):
......@@ -66,7 +82,7 @@ class BoolOption(_Option):
return value
def fromConfigFile(self, cfg, section):
return cfg.getboolean(section, self.name)
return cfg.getboolean(section, self.name.replace('-', '_'))
class Option(_Option):
......@@ -81,6 +97,11 @@ class Option(_Option):
kw[x] = getattr(self, x)
except AttributeError:
pass
if self.multiple:
kw['action'] = 'append'
default = kw.get('default')
if default:
kw['default'] = _DefaultList(default)
return kw
@staticmethod
......@@ -132,9 +153,6 @@ class OptionGroup(object):
class Argument(Option):
def __init__(self, name, **kw):
super(Argument, self).__init__(name, **kw)
def _asArgparse(self, parser, option_list):
kw = {'help': self.help, 'type': self}
for x in 'default', 'metavar', 'nargs', 'choices':
......
......@@ -16,12 +16,30 @@
from functools import wraps
from time import time
import msgpack
from msgpack.exceptions import OutOfData, UnpackValueError
from . import attributeTracker, logging
from .connector import ConnectorException, ConnectorDelayedConnection
from .locking import RLock
from .protocol import uuid_str, Errors, PacketMalformedError, Packets
from .util import dummy_read_buffer, ReadBuffer
from .protocol import uuid_str, Errors, PacketMalformedError, Packets, \
Unpacker
try:
msgpack.Unpacker().read_bytes(1)
except OutOfData: # fallback implementation
# monkey-patch for upstream issues 259 & 352
def read_bytes(self, n):
ret = self._read(min(n, len(self._buffer) - self._buff_i))
self._consume()
return ret
msgpack.Unpacker.read_bytes = read_bytes
del read_bytes
@apply
class dummy_read_buffer(msgpack.Unpacker):
def feed(self, _):
pass
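# Note, a reading of the close() path further below: a closed connection
# swaps its read_buf for this dummy_read_buffer, so any data still
# arriving is fed into a no-op and discarded instead of being parsed.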
class ConnectionClosed(Exception):
pass
......@@ -291,7 +309,7 @@ class ListeningConnection(BaseConnection):
# message.
else:
conn._connected()
self.em.addWriter(conn) # for ENCODED_VERSION
self.em.addWriter(conn) # for HANDSHAKE_PACKET
def getAddress(self):
return self.connector.getAddress()
......@@ -310,12 +328,12 @@ class Connection(BaseConnection):
client = False
server = False
peer_id = None
_parser_state = None
_total_unpacked = 0
_timeout = None
def __init__(self, event_manager, *args, **kw):
BaseConnection.__init__(self, event_manager, *args, **kw)
self.read_buf = ReadBuffer()
self.read_buf = Unpacker()
self.cur_id = 0
self.aborted = False
self.uuid = None
......@@ -425,42 +443,39 @@ class Connection(BaseConnection):
self._closure()
def _parse(self):
read = self.read_buf.read
version = read(4)
if version is None:
return
from .protocol import (ENCODED_VERSION, MAX_PACKET_SIZE,
PACKET_HEADER_FORMAT, Packets)
if version != ENCODED_VERSION:
logging.warning('Protocol version mismatch with %r', self)
from .protocol import HANDSHAKE_PACKET, MAGIC_SIZE, Packets
read_buf = self.read_buf
handshake = read_buf.read_bytes(len(HANDSHAKE_PACKET))
if handshake != HANDSHAKE_PACKET:
if HANDSHAKE_PACKET.startswith(handshake): # unlikely so tested last
# Not enough data and there's no API to know it in advance.
# Put it back.
read_buf.feed(handshake)
return
if HANDSHAKE_PACKET.startswith(handshake[:MAGIC_SIZE]):
logging.warning('Protocol version mismatch with %r', self)
else:
logging.debug('Rejecting non-NEO %r', self)
raise ConnectorException
header_size = PACKET_HEADER_FORMAT.size
unpack = PACKET_HEADER_FORMAT.unpack
read_next = read_buf.next
read_pos = read_buf.tell
def parse():
state = self._parser_state
if state is None:
header = read(header_size)
if header is None:
return
msg_id, msg_type, msg_len = unpack(header)
try:
packet_klass = Packets[msg_type]
except KeyError:
raise PacketMalformedError('Unknown packet type')
if msg_len > MAX_PACKET_SIZE:
raise PacketMalformedError('message too big (%d)' % msg_len)
else:
msg_id, packet_klass, msg_len = state
data = read(msg_len)
if data is None:
# Not enough.
if state is None:
self._parser_state = msg_id, packet_klass, msg_len
else:
self._parser_state = None
packet = packet_klass()
packet.setContent(msg_id, data)
return packet
try:
msg_id, msg_type, args = read_next()
except StopIteration:
return
except UnpackValueError as e:
raise PacketMalformedError(str(e))
try:
packet_klass = Packets[msg_type]
except KeyError:
raise PacketMalformedError('Unknown packet type')
pos = read_pos()
packet = packet_klass(*args)
packet.setId(msg_id)
packet.size = pos - self._total_unpacked
self._total_unpacked = pos
return packet
self._parse = parse
return parse()
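# Note on the rebinding above: after a successful handshake, self._parse
# is replaced by the inner parse() closure, so the handshake check and
# the Unpacker attribute lookups happen only once per connection.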
......@@ -513,7 +528,7 @@ class Connection(BaseConnection):
def close(self):
if self.connector is None:
assert self._on_close is None
assert not self.read_buf
assert not self.read_buf.read_bytes(1)
assert not self.isPending()
return
# process the network events with the last registered handler to
......@@ -524,7 +539,7 @@ class Connection(BaseConnection):
if self._on_close is not None:
self._on_close()
self._on_close = None
self.read_buf.clear()
self.read_buf = dummy_read_buffer
try:
if self.connecting:
handler.connectionFailed(self)
......
......@@ -19,7 +19,7 @@ import ssl
import errno
from time import time
from . import logging
from .protocol import ENCODED_VERSION
from .protocol import HANDSHAKE_PACKET
# Global connector registry.
# Fill by calling registerConnectorHandler.
......@@ -74,15 +74,14 @@ class SocketConnector(object):
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
# disable Nagle algorithm to reduce latency
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.queued = [ENCODED_VERSION]
self.queue_size = len(ENCODED_VERSION)
self.queued = [HANDSHAKE_PACKET]
self.queue_size = len(HANDSHAKE_PACKET)
return self
def queue(self, data):
was_empty = not self.queued
self.queued += data
for data in data:
self.queue_size += len(data)
self.queued.append(data)
self.queue_size += len(data)
return was_empty
def _error(self, op, exc=None):
......@@ -136,8 +135,7 @@ class SocketConnector(object):
def ssl(self, ssl, on_handshake_done=None):
self.socket = ssl.wrap_socket(self.socket,
server_side=self.is_server,
do_handshake_on_connect=False,
suppress_ragged_eofs=False)
do_handshake_on_connect=False)
self.__class__ = self.SSLHandshakeConnectorClass
self.on_handshake_done = on_handshake_done
self.queued or self.queued.append('')
......@@ -172,7 +170,7 @@ class SocketConnector(object):
except socket.error, e:
self._error('recv', e)
if data:
read_buf.append(data)
read_buf.feed(data)
return
self._error('recv')
......@@ -278,7 +276,11 @@ class _SSL:
def receive(self, read_buf):
try:
while 1:
read_buf.append(self.socket.recv(4096))
data = self.socket.recv(4096)
if not data:
self._error('recv', None)
return
read_buf.feed(data)
except ssl.SSLWantReadError:
pass
except socket.error, e:
......
......@@ -23,7 +23,7 @@ NOBODY = []
class _ConnectionClosed(object):
handler_method_name = 'connectionClosed'
decode = tuple
_args = ()
class getId(object):
def __eq__(self, other):
......
......@@ -71,7 +71,7 @@ class EventHandler(object):
method = getattr(self, packet.handler_method_name)
except AttributeError:
raise UnexpectedPacketError('no handler found')
args = packet.decode() or ()
args = packet._args
method(conn, *args, **kw)
except DelayEvent, e:
assert not kw, kw
......@@ -79,9 +79,6 @@ class EventHandler(object):
except UnexpectedPacketError, e:
if not conn.isClosed():
self.__unexpectedPacket(conn, packet, *e.args)
except PacketMalformedError, e:
logging.error('malformed packet from %r: %s', conn, e)
conn.close()
except NotReadyError, message:
if not conn.isClosed():
if not message.args:
......@@ -325,8 +322,8 @@ class EventQueue(object):
# Stable sort when 2 keys are equal.
# XXX: Is it really useful to keep events with same key ordered
# chronologically ? The caller could use more specific keys. For
# write-locks (by the storage node), the locking tid seems enough.
# chronologically ? The caller could use more specific keys.
# For write-locks (by the storage node), the ttid seems enough.
sortQueuedEvents = (lambda key=itemgetter(0): lambda self:
self._event_queue.sort(key=key))()
......@@ -337,14 +334,6 @@ class EventQueue(object):
if key is not None:
self.sortQueuedEvents()
def sortAndExecuteQueuedEvents(self):
if self._executing_event < 0:
self.sortQueuedEvents()
self.executeQueuedEvents()
else:
# We can't sort events when they're being processed.
self._executing_event = 1
def executeQueuedEvents(self):
# Not reentrant. When processing a queued event, calling this method
# only tells the caller to retry all events from the beginning, because
......@@ -372,10 +361,6 @@ class EventQueue(object):
while done:
del queue[done.pop()]
self._executing_event = 0
# What sortAndExecuteQueuedEvents could not do immediately
# is done here:
if event[0] is not None:
self.sortQueuedEvents()
self._executing_event = -1
def logQueuedEvents(self):
......
......@@ -74,7 +74,7 @@ def implements(obj, ignore=()):
assert not wrong_signature, wrong_signature
return obj
def _set_code(func):
def _stub(func):
args, varargs, varkw, _ = inspect.getargspec(func)
if varargs:
args.append("*" + varargs)
......@@ -82,16 +82,25 @@ def _set_code(func):
args.append("**" + varkw)
exec "def %s(%s): raise NotImplementedError\nf = %s" % (
func.__name__, ",".join(args), func.__name__)
func.func_code = f.func_code
return f
def abstract(func):
_set_code(func)
func.__abstract__ = 1
return func
f = _stub(func)
f.__abstract__ = 1
f.__defaults__ = func.__defaults__
f.__doc__ = func.__doc__
return f
def requires(*args):
for func in args:
_set_code(func)
# Tolerate useless abstract decoration on required method (e.g. it
# simplifies the implementation of a fallback decorator), but remove
# marker since it does not need to be implemented if it's required
# by a method that is overridden.
try:
del func.__abstract__
except AttributeError:
func.__code__ = _stub(func).__code__
def decorator(func):
func.__requires__ = args
return func
......
......@@ -152,7 +152,8 @@ class NEOLogger(Logger):
def _setup(self, filename=None, reset=False):
from . import protocol as p
global uuid_str
global packb, uuid_str
packb = p.packb
uuid_str = p.uuid_str
if self._db is not None:
self._db.close()
......@@ -166,7 +167,11 @@ class NEOLogger(Logger):
q("PRAGMA synchronous = OFF")
if 1: # Not only when logging everything,
# but also for interoperability with logrotate.
q("PRAGMA journal_mode = MEMORY")
# Close the returned cursor to work around the following
# OperationalError with PyPy:
# cannot commit transaction - SQL statements in progress
# XXX: Is it enough?
q("PRAGMA journal_mode = MEMORY").close()
for t, columns in (('log', (
"level INTEGER NOT NULL",
"pathname TEXT",
......@@ -250,7 +255,7 @@ class NEOLogger(Logger):
'>' if r.outgoing else '<', uuid_str(r.uuid), ip, port)
msg = r.msg
if msg is not None:
msg = buffer(msg)
msg = buffer(msg if type(msg) is bytes else packb(msg))
q = "INSERT INTO packet VALUES (?,?,?,?,?,?)"
x = [r.created, nid, r.msg_id, r.code, peer, msg]
else:
......@@ -299,9 +304,14 @@ class NEOLogger(Logger):
def packet(self, connection, packet, outgoing):
if self._db is not None:
body = packet._body
if self._max_packet and self._max_packet < len(body):
body = None
if self._max_packet and self._max_packet < packet.size:
args = None
else:
args = packet._args
try:
hash(args)
except TypeError:
args = packb(args)
self._queue(PacketRecord(
created=time(),
msg_id=packet._id,
......@@ -309,7 +319,7 @@ class NEOLogger(Logger):
outgoing=outgoing,
uuid=connection.getUUID(),
addr=connection.getAddress(),
msg=body))
msg=args))
def node(self, *cluster_nid):
name = self.name and str(self.name)
......
......@@ -28,7 +28,7 @@ class Node(object):
_connection = None
_identified = False
devpath = ()
extra = {}
id_timestamp = None
def __init__(self, manager, address=None, uuid=None, state=NodeStates.DOWN):
......@@ -393,6 +393,12 @@ class NodeManager(EventQueue):
raise NotReadyError('unknown by master')
return node
def createMasters(self, master_nodes):
for address in master_nodes:
self.createMaster(address=address)
if not self.getMasterList():
raise ValueError("At least one master must be defined")
def _createNode(self, klass, address=None, uuid=None, **kw):
by_address = self.getByAddress(address)
by_uuid = self.getByUUID(uuid)
......
......@@ -53,10 +53,7 @@ class ThreadedApplication(BaseApplication):
self.name = name
self.dispatcher = Dispatcher()
self.master_conn = None
# load master node list
for address in master_nodes:
self.nm.createMaster(address=address)
self.nm.createMasters(master_nodes)
# Internal attribute distinct between thread
self._thread_container = ThreadContainer()
......
......@@ -39,7 +39,8 @@ nextafter()
TID_LOW_OVERFLOW = 2**32
TID_LOW_MAX = TID_LOW_OVERFLOW - 1
SECOND_PER_TID_LOW = 60.0 / TID_LOW_OVERFLOW
SECOND_FROM_UINT32 = 60. / TID_LOW_OVERFLOW
MICRO_FROM_UINT32 = 1e6 / TID_LOW_OVERFLOW
TID_CHUNK_RULES = (
(-1900, 0),
(-1, 12),
......@@ -52,7 +53,7 @@ def tidFromTime(tm):
gmt = gmtime(tm)
return packTID(
(gmt.tm_year, gmt.tm_mon, gmt.tm_mday, gmt.tm_hour, gmt.tm_min),
int((gmt.tm_sec + (tm - int(tm))) / SECOND_PER_TID_LOW))
int((gmt.tm_sec + (tm - int(tm))) / SECOND_FROM_UINT32))
def packTID(higher, lower):
"""
......@@ -95,15 +96,10 @@ def unpackTID(ptid):
higher.reverse()
return (tuple(higher), lower)
def timeStringFromTID(ptid):
"""
Return a string in the format "yyyy-mm-dd hh:mm:ss.ssssss" from a TID
"""
higher, lower = unpackTID(ptid)
seconds = lower * SECOND_PER_TID_LOW
return '%04d-%02d-%02d %02d:%02d:%09.6f' % (higher[0], higher[1], higher[2],
higher[3], higher[4], seconds)
def datetimeFromTID(tid):
higher, lower = unpackTID(tid)
seconds, lower = divmod(lower * 60, TID_LOW_OVERFLOW)
return datetime(*(higher + (seconds, int(lower * MICRO_FROM_UINT32))))
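# Worked example (illustrative): a TID packs (year, month, day, hour,
# minute) into the high 32 bits and a fraction of a minute into the low
# 32 bits, one unit being 60/2**32 second, so 45s into the minute maps
# to int(45 / SECOND_FROM_UINT32) == 0xC0000000 units:
#   tid = packTID((2021, 3, 2, 10, 30), int(45 / SECOND_FROM_UINT32))
#   unpackTID(tid)        # -> ((2021, 3, 2, 10, 30), 0xC0000000)
#   datetimeFromTID(tid)  # -> datetime(2021, 3, 2, 10, 30, 45)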
def addTID(ptid, offset):
"""
......@@ -162,69 +158,9 @@ def parseNodeAddress(address, port_opt=None):
return socket.