Commit 03077c10 authored by Kirill Smelkov's avatar Kirill Smelkov

Merge remote-tracking branch 'origin/master' into t

* origin/master: (23 commits)
  mysql: more index hints
  Release version 1.8
  README: update URLs
  README: update wrt added support for RocksDB and recent ZODB
  storage: update DatabaseManager.getLastTID docstring
  neolog: new --decompress option
  doc: update TODO about missing invalidations in read-only mode
  mysql: remove obsolete comment about broken PARTITIONing support
  qa: make ClusterPdb compatible with the simple pdb of neo.tests
  client: fix NameError when a secondary master reports that it's not the primary
  storage: new --disable-drop-partitions option
  qa: add testDropPartitions
  Better use of __import__
  qa: update list of excluded tests in testSSL
  master: improve algorithm to tweak the partition table
  storage: ignore unassigned partitions when looking for last oids/tids
  neolog: new option to hide the node column
  Remove packet timeouts
  Use TCP keepalives instead of applicative pings
  Remove unused 'on_timeout' feature on connections
  ...
parents c4d3957f 0868de70
Change History Change History
============== ==============
1.8 (2017-07-04)
----------------
This release mainly stabilizes NEO when it is used with several storage nodes,
fixing many race conditions involving events like transactional operations
(read/write, conflict resolution...), replication, partition table tweaking,
and all kinds of failures (node crashes, network cuts...). This includes a
rework of conflict resolution, to implement the long-awaited deadlock avoidance
(it was a limitation caused by object-level locking).
Similarly, having spare master nodes is not an experimental feature anymore:
the `election` (of the primary master) has been reimplemented, and it now
happens during the RECOVERING phase. This comes with a change about node
states: BROKEN/HIDDEN/UNKNOWN are removed, DOWN is renamed into UNKNOWN,
and TEMPORARILY_DOWN into DOWN.
And still for more resiliency, the new algorithm to tweak the partition table
is better at minimizing the amount of replication, and it does not discard
readable cells too quickly anymore: a partition can now have multiple FEEDING
cells, to avoid going below the wanted level of replication.
Other changes:
- General:
- Packet timeouts have been removed.
TCP keepalives are used instead of applicative pings.
- Connection handshake between nodes is reviewed to make sure that they
speak the same protocol before doing anything else, and report clearer
error messages otherwise. A dangerous bug was that there was no protocol
version check between neoctl and the admin node.
- Proper handling of incoming packets for closed/aborted connections.
- An exception while processing an answer could leave the handler switcher
in the bad state.
- In STOPPING cluster state, really wait for all transaction to be finished.
- Several issues when undoing transactions with conflict resolutions
have been fixed.
- Delayed connection acceptation when the storage node is ready.
- Client:
- Added support for `zodburi`_.
- Fix load error during conflict resolution in case of late invalidation.
- Do not wait tpc_vote to start resolving conflicts.
- Fix harmless 'unexpected ... AnswerRequestIdentification' exceptions.
- Storage:
- New --disable-drop-partitions option, which is useful for big databases
because the current code to delete data of discarded cells is inefficient
(this option should disappear in the future).
- Prevent 2 nodes from working with the same database.
- Discard answers from aborted replications.
In some cases, this led to data corruption or crashes.
- MySQL backend:
- Added support for RocksDB.
- Do not flood logs when retrying to connect non-stop.
- Do not retry a failing query forever.
- By default, do not retry to connect to the server automatically.
- Tools:
- neolog: new --decompress option.
- neolog: new option to hide the node column.
- neoctl: make the identification of the primary master easier with
'print node'.
- A lot of improvements for developers and debugging.
.. _zodburi: https://docs.pylonsproject.org/projects/zodburi
1.7.1 (2017-01-18) 1.7.1 (2017-01-18)
------------------ ------------------
......
...@@ -16,7 +16,7 @@ A NEO cluster is composed of the following types of nodes: ...@@ -16,7 +16,7 @@ A NEO cluster is composed of the following types of nodes:
Stores data, preserving history. All available storage nodes are in use Stores data, preserving history. All available storage nodes are in use
simultaneously. This offers redundancy and data distribution. simultaneously. This offers redundancy and data distribution.
Available backends: MySQL (InnoDB or TokuDB), SQLite Available backends: MySQL (InnoDB, RocksDB or TokuDB), SQLite
- "admin" nodes (mandatory for startup, optional after) - "admin" nodes (mandatory for startup, optional after)
...@@ -38,8 +38,8 @@ Any ZODB like FileStorage can be converted to NEO instantaneously, ...@@ -38,8 +38,8 @@ Any ZODB like FileStorage can be converted to NEO instantaneously,
which means the database is operational before all data are imported. which means the database is operational before all data are imported.
There's also a tool to convert back to FileStorage. There's also a tool to convert back to FileStorage.
See also http://www.neoppod.org/links for more detailed information about For more detailed information about features related to scalability,
features related to scalability. see the `Architecture and Characteristics` section of https://neo.nexedi.com/.
Requirements Requirements
============ ============
...@@ -52,7 +52,7 @@ Requirements ...@@ -52,7 +52,7 @@ Requirements
- MySQLdb: https://github.com/PyMySQL/mysqlclient-python - MySQLdb: https://github.com/PyMySQL/mysqlclient-python
- For client nodes: ZODB 3.10.x - For client nodes: ZODB 3.10.x or later
Installation Installation
============ ============
...@@ -199,7 +199,7 @@ Developers ...@@ -199,7 +199,7 @@ Developers
========== ==========
Developers interested in NEO may refer to Developers interested in NEO may refer to
`NEO Web site <http://www.neoppod.org/>`_ and subscribe to following mailing `NEO Web site <https://neo.nexedi.com/>`_ and subscribe to following mailing
lists: lists:
- `neo-users <http://mail.tiolive.com/mailman/listinfo/neo-users>`_: - `neo-users <http://mail.tiolive.com/mailman/listinfo/neo-users>`_:
...@@ -213,4 +213,4 @@ https://www.erp5.com/quality/integration/P-ERP5.Com.Unit%20Tests/Base_viewListMo ...@@ -213,4 +213,4 @@ https://www.erp5.com/quality/integration/P-ERP5.Com.Unit%20Tests/Base_viewListMo
Commercial Support Commercial Support
================== ==================
Nexedi provides commercial support for NEO: http://www.nexedi.com/ Nexedi provides commercial support for NEO: https://www.nexedi.com/
...@@ -84,6 +84,8 @@ ...@@ -84,6 +84,8 @@
keys (trans.tid & obj.{tid,oid}). keys (trans.tid & obj.{tid,oid}).
Master Master
- Implement back-channel for invalidations in read-only mode,
so that clients of backup clusters are notified of new data.
- Master node data redundancy (HIGH AVAILABILITY) - Master node data redundancy (HIGH AVAILABILITY)
Secondary master nodes should replicate primary master data (ie, primary Secondary master nodes should replicate primary master data (ie, primary
master should inform them of such changes). master should inform them of such changes).
......
...@@ -40,7 +40,7 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -40,7 +40,7 @@ class PrimaryNotificationsHandler(MTEventHandler):
try: try:
super(PrimaryNotificationsHandler, self).notPrimaryMaster(*args) super(PrimaryNotificationsHandler, self).notPrimaryMaster(*args)
except PrimaryElected, e: except PrimaryElected, e:
app.primary_master_node, = e.args self.app.primary_master_node, = e.args
def _acceptIdentification(self, node, num_partitions, num_replicas): def _acceptIdentification(self, node, num_partitions, num_replicas):
self.app.pt = PartitionTable(num_partitions, num_replicas) self.app.pt = PartitionTable(num_partitions, num_replicas)
......
...@@ -44,7 +44,6 @@ class ConnectionPool(object): ...@@ -44,7 +44,6 @@ class ConnectionPool(object):
app = self.app app = self.app
if app.master_conn is None: if app.master_conn is None:
raise NEOPrimaryMasterLost raise NEOPrimaryMasterLost
logging.debug('trying to connect to %s - %s', node, node.getState())
conn = MTClientConnection(app, app.storage_event_handler, node, conn = MTClientConnection(app, app.storage_event_handler, node,
dispatcher=app.dispatcher) dispatcher=app.dispatcher)
p = Packets.RequestIdentification(NodeTypes.CLIENT, p = Packets.RequestIdentification(NodeTypes.CLIENT,
......
...@@ -101,7 +101,7 @@ if IF == 'pdb': ...@@ -101,7 +101,7 @@ if IF == 'pdb':
def __init__(self, bp_list): def __init__(self, bp_list):
self._lock = threading.Lock() self._lock = threading.Lock()
for o, name in bp_list: for o, name in bp_list:
o = __import__(o, fromlist=1) o = __import__(o, fromlist=('*',), level=0)
x = name.split('.') x = name.split('.')
name = x.pop() name = x.pop()
for x in x: for x in x:
......
...@@ -97,6 +97,9 @@ class ConfigurationManager(object): ...@@ -97,6 +97,9 @@ class ConfigurationManager(object):
bind = self.__get('bind') bind = self.__get('bind')
return parseNodeAddress(bind, 0) return parseNodeAddress(bind, 0)
def getDisableDropPartitions(self):
return self.__get('disable_drop_partitions', True)
def getDatabase(self): def getDatabase(self):
return self.__get('database') return self.__get('database')
......
This diff is collapsed.
...@@ -57,6 +57,18 @@ class SocketConnector(object): ...@@ -57,6 +57,18 @@ class SocketConnector(object):
self.socket_fd = s.fileno() self.socket_fd = s.fileno()
# always use non-blocking sockets # always use non-blocking sockets
s.setblocking(0) s.setblocking(0)
# TCP keepalive, enabled on both sides to detect:
# - remote host crash
# - network failure
# They're more efficient than applicative pings and we don't want
# to consider the connection dead if the remote node is busy.
# The following 3 lines are specific to Linux. It seems that OSX
# has similar options (TCP_KEEPALIVE/TCP_KEEPINTVL/TCP_KEEPCNT),
# and Windows has SIO_KEEPALIVE_VALS (fixed count of 10).
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 60)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 10)
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
# disable Nagle algorithm to reduce latency # disable Nagle algorithm to reduce latency
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.queued = [ENCODED_VERSION] self.queued = [ENCODED_VERSION]
......
...@@ -194,8 +194,6 @@ class EventHandler(object): ...@@ -194,8 +194,6 @@ class EventHandler(object):
conn.answer(Packets.Pong()) conn.answer(Packets.Pong())
def pong(self, conn): def pong(self, conn):
# Ignore PONG packets. The only purpose of ping/pong packets is
# to test/maintain underlying connection.
pass pass
def closeClient(self, conn): def closeClient(self, conn):
......
...@@ -174,8 +174,9 @@ class AdministrationHandler(MasterHandler): ...@@ -174,8 +174,9 @@ class AdministrationHandler(MasterHandler):
ClusterStates.BACKINGUP): ClusterStates.BACKINGUP):
raise ProtocolError('Can not tweak partition table in %s state' raise ProtocolError('Can not tweak partition table in %s state'
% state) % state)
app.broadcastPartitionChanges(app.pt.tweak( app.broadcastPartitionChanges(app.pt.tweak([node
map(app.nm.getByUUID, uuid_list))) for node in app.nm.getStorageList()
if node.getUUID() in uuid_list or not node.isRunning()]))
conn.answer(Errors.Ack('')) conn.answer(Errors.Ack(''))
def truncate(self, conn, tid): def truncate(self, conn, tid):
......
...@@ -69,7 +69,7 @@ class ClientServiceHandler(MasterHandler): ...@@ -69,7 +69,7 @@ class ClientServiceHandler(MasterHandler):
if tid: if tid:
p = Packets.AskLockInformation(ttid, tid) p = Packets.AskLockInformation(ttid, tid)
for node in node_list: for node in node_list:
node.ask(p, timeout=60) # NOTE node.ask(p)
# NOTE continues in onTransactionCommitted # NOTE continues in onTransactionCommitted
......
This diff is collapsed.
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
import bz2, gzip, errno, optparse, os, signal, sqlite3, sys, time import bz2, gzip, errno, optparse, os, signal, sqlite3, sys, time
from bisect import insort from bisect import insort
from logging import getLevelName from logging import getLevelName
from zlib import decompress
comp_dict = dict(bz2=bz2.BZ2File, gz=gzip.GzipFile) comp_dict = dict(bz2=bz2.BZ2File, gz=gzip.GzipFile)
...@@ -28,11 +29,12 @@ class Log(object): ...@@ -28,11 +29,12 @@ class Log(object):
_log_id = _packet_id = -1 _log_id = _packet_id = -1
_protocol_date = None _protocol_date = None
def __init__(self, db_path, decode_all=False, date_format=None, def __init__(self, db_path, decode=0, date_format=None,
filter_from=None, node_list=None): filter_from=None, node_column=True, node_list=None):
self._date_format = '%F %T' if date_format is None else date_format self._date_format = '%F %T' if date_format is None else date_format
self._decode_all = decode_all self._decode = decode
self._filter_from = filter_from self._filter_from = filter_from
self._node_column = node_column
self._node_list = node_list self._node_list = node_list
name = os.path.basename(db_path) name = os.path.basename(db_path)
try: try:
...@@ -93,6 +95,30 @@ class Log(object): ...@@ -93,6 +95,30 @@ class Log(object):
exec bz2.decompress(text) in g exec bz2.decompress(text) in g
for x in 'uuid_str', 'Packets', 'PacketMalformedError': for x in 'uuid_str', 'Packets', 'PacketMalformedError':
setattr(self, x, g[x]) setattr(self, x, g[x])
x = {}
if self._decode > 1:
PStruct = g['PStruct']
PBoolean = g['PBoolean']
def hasData(item):
items = item._items
for i, item in enumerate(items):
if isinstance(item, PStruct):
j = hasData(item)
if j:
return (i,) + j
elif (isinstance(item, PBoolean)
and item._name == 'compression'
and i + 2 < len(items)
and items[i+2]._name == 'data'):
return i,
for p in self.Packets.itervalues():
if p._fmt is not None:
path = hasData(p._fmt)
if path:
assert not hasattr(p, '_neolog'), p
x[p._code] = path
self._getDataPath = x.get
try: try:
self._next_protocol, = q("SELECT date FROM protocol WHERE date>?", self._next_protocol, = q("SELECT date FROM protocol WHERE date>?",
(date,)).next() (date,)).next()
...@@ -109,7 +135,8 @@ class Log(object): ...@@ -109,7 +135,8 @@ class Log(object):
d = int(date) d = int(date)
prefix = '%s.%04u ' % (time.strftime(prefix, time.localtime(d)), prefix = '%s.%04u ' % (time.strftime(prefix, time.localtime(d)),
int((date - d) * 10000)) int((date - d) * 10000))
prefix += '%-9s %-10s ' % (levelname, name) prefix += ('%-9s %-10s ' % (levelname, name) if self._node_column else
'%-9s ' % levelname)
for msg in msg_list: for msg in msg_list:
print prefix + msg print prefix + msg
...@@ -126,7 +153,7 @@ class Log(object): ...@@ -126,7 +153,7 @@ class Log(object):
msg = ['#0x%04x %-30s %s' % (msg_id, msg, peer)] msg = ['#0x%04x %-30s %s' % (msg_id, msg, peer)]
if body is not None: if body is not None:
log = getattr(p, '_neolog', None) log = getattr(p, '_neolog', None)
if log or self._decode_all: if log or self._decode:
p = p() p = p()
p._id = msg_id p._id = msg_id
p._body = body p._body = body
...@@ -138,10 +165,28 @@ class Log(object): ...@@ -138,10 +165,28 @@ class Log(object):
if log: if log:
args, extra = log(*args) args, extra = log(*args)
msg += extra msg += extra
if args and self._decode_all: else:
path = self._getDataPath(code)
if path:
args = self._decompress(args, path)
if args and self._decode:
msg[0] += ' \t| ' + repr(args) msg[0] += ' \t| ' + repr(args)
return date, name, 'PACKET', msg return date, name, 'PACKET', msg
def _decompress(self, args, path):
if args:
args = list(args)
i = path[0]
path = path[1:]
if path:
args[i] = self._decompress(args[i], path)
else:
data = args[i+2]
if args[i]:
data = decompress(data)
args[i:i+3] = (len(data), data),
return tuple(args)
def emit_many(log_list): def emit_many(log_list):
log_list = [(log, iter(log).next) for log in log_list] log_list = [(log, iter(log).next) for log in log_list]
...@@ -179,7 +224,9 @@ def emit_many(log_list): ...@@ -179,7 +224,9 @@ def emit_many(log_list):
def main(): def main():
parser = optparse.OptionParser() parser = optparse.OptionParser()
parser.add_option('-a', '--all', action="store_true", parser.add_option('-a', '--all', action="store_true",
help='decode all packets') help='decode body of packets')
parser.add_option('-A', '--decompress', action="store_true",
help='decompress data when decode body of packets (implies --all)')
parser.add_option('-d', '--date', metavar='FORMAT', parser.add_option('-d', '--date', metavar='FORMAT',
help='custom date format, according to strftime(3)') help='custom date format, according to strftime(3)')
parser.add_option('-f', '--follow', action="store_true", parser.add_option('-f', '--follow', action="store_true",
...@@ -189,7 +236,8 @@ def main(): ...@@ -189,7 +236,8 @@ def main():
' seconds (see -s)', metavar='PID') ' seconds (see -s)', metavar='PID')
parser.add_option('-n', '--node', action="append", parser.add_option('-n', '--node', action="append",
help='only show log entries from the given node' help='only show log entries from the given node'
' (only useful for logs produced by threaded tests)') ' (only useful for logs produced by threaded tests),'
" special value '-' hides the column")
parser.add_option('-s', '--sleep-interval', type="float", default=1, parser.add_option('-s', '--sleep-interval', type="float", default=1,
help='with -f, sleep for approximately N seconds (default 1.0)' help='with -f, sleep for approximately N seconds (default 1.0)'
' between iterations', metavar='N') ' between iterations', metavar='N')
...@@ -204,8 +252,15 @@ def main(): ...@@ -204,8 +252,15 @@ def main():
filter_from = options.filter_from filter_from = options.filter_from
if filter_from and filter_from < 0: if filter_from and filter_from < 0:
filter_from += time.time() filter_from += time.time()
log_list = [Log(db_path, options.all, options.date, filter_from, node_list = options.node or []
options.node) try:
node_list.remove('-')
node_column = False
except ValueError:
node_column = True
log_list = [Log(db_path,
2 if options.decompress else 1 if options.all else 0,
options.date, filter_from, node_column, node_list)
for db_path in args] for db_path in args]
if options.follow: if options.follow:
try: try:
......
...@@ -30,6 +30,11 @@ parser.add_option('-d', '--database', help = 'database connections string') ...@@ -30,6 +30,11 @@ parser.add_option('-d', '--database', help = 'database connections string')
parser.add_option('-e', '--engine', help = 'database engine') parser.add_option('-e', '--engine', help = 'database engine')
parser.add_option('-w', '--wait', help='seconds to wait for backend to be ' parser.add_option('-w', '--wait', help='seconds to wait for backend to be '
'available, before erroring-out (-1 = infinite)', type='float', default=0) 'available, before erroring-out (-1 = infinite)', type='float', default=0)
parser.add_option('--disable-drop-partitions', action='store_true',
help = 'do not delete data of discarded cells, which is'
' useful for big databases because the current'
' implementation is inefficient (this option should'
' disappear in the future)')
parser.add_option('--reset', action='store_true', parser.add_option('--reset', action='store_true',
help='remove an existing database if any, and exit') help='remove an existing database if any, and exit')
......
...@@ -42,7 +42,6 @@ from neo.tests.benchmark import BenchmarkRunner ...@@ -42,7 +42,6 @@ from neo.tests.benchmark import BenchmarkRunner
# each of them have to import its TestCase classes # each of them have to import its TestCase classes
UNIT_TEST_MODULES = [ UNIT_TEST_MODULES = [
# generic parts # generic parts
'neo.tests.testConnection',
'neo.tests.testHandler', 'neo.tests.testHandler',
'neo.tests.testNodes', 'neo.tests.testNodes',
'neo.tests.testUtil', 'neo.tests.testUtil',
...@@ -174,7 +173,7 @@ class NeoTestRunner(unittest.TextTestResult): ...@@ -174,7 +173,7 @@ class NeoTestRunner(unittest.TextTestResult):
exclude != fnmatchcase(test_module, only)): exclude != fnmatchcase(test_module, only)):
continue continue
try: try:
test_module = __import__(test_module, globals(), locals(), ['*']) test_module = __import__(test_module, fromlist=('*',), level=0)
except ImportError, err: except ImportError, err:
self.failedImports[test_module] = err self.failedImports[test_module] = err
print "Import of %s failed : %s" % (test_module, err) print "Import of %s failed : %s" % (test_module, err)
......
...@@ -48,6 +48,7 @@ class Application(BaseApplication): ...@@ -48,6 +48,7 @@ class Application(BaseApplication):
self.dm = buildDatabaseManager(config.getAdapter(), self.dm = buildDatabaseManager(config.getAdapter(),
(config.getDatabase(), config.getEngine(), config.getWait()), (config.getDatabase(), config.getEngine(), config.getWait()),
) )
self.disable_drop_partitions = config.getDisableDropPartitions()
# load master nodes # load master nodes
for master_address in config.getMasters(): for master_address in config.getMasters():
......
...@@ -29,8 +29,7 @@ def getAdapterKlass(name): ...@@ -29,8 +29,7 @@ def getAdapterKlass(name):
module, name = DATABASE_MANAGER_DICT[name or 'MySQL'].split('.') module, name = DATABASE_MANAGER_DICT[name or 'MySQL'].split('.')
except KeyError: except KeyError:
raise DatabaseFailure('Cannot find a database adapter <%s>' % name) raise DatabaseFailure('Cannot find a database adapter <%s>' % name)
module = getattr(__import__(__name__, fromlist=[module], level=1), module) return getattr(__import__(module, globals(), level=1), name)
return getattr(module, name)
def buildDatabaseManager(name, args=(), kw={}): def buildDatabaseManager(name, args=(), kw={}):
return getAdapterKlass(name)(*args, **kw) return getAdapterKlass(name)(*args, **kw)
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import struct, threading import os, errno, socket, struct, sys, threading
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from functools import wraps from functools import wraps
...@@ -57,6 +57,10 @@ class DatabaseManager(object): ...@@ -57,6 +57,10 @@ class DatabaseManager(object):
ENGINES = () ENGINES = ()
UNSAFE = False UNSAFE = False
__lock = None
LOCK = "neostorage"
LOCKED = "error: database is locked"
_deferred = 0 _deferred = 0
_duplicating = _repairing = None _duplicating = _repairing = None
...@@ -86,6 +90,7 @@ class DatabaseManager(object): ...@@ -86,6 +90,7 @@ class DatabaseManager(object):
def _duplicate(self): def _duplicate(self):
cls = self.__class__ cls = self.__class__
db = cls.__new__(cls) db = cls.__new__(cls)
db.LOCK = None
db._duplicating = self db._duplicating = self
try: try:
db._connect() db._connect()
...@@ -104,6 +109,26 @@ class DatabaseManager(object): ...@@ -104,6 +109,26 @@ class DatabaseManager(object):
def _connect(self): def _connect(self):
"""Connect to the database""" """Connect to the database"""
def lock(self, db_path):
if self.LOCK:
assert self.__lock is None, self.__lock
# For platforms that don't support anonymous sockets,
# we can either use zc.lockfile or an empty SQLite db
# (with BEGIN EXCLUSIVE).
try:
stat = os.stat(db_path)
except OSError as e:
if e.errno != errno.ENOENT:
raise
return # in-memory or temporary database
s = self.__lock = socket.socket(socket.AF_UNIX)
try:
s.bind('\0%s:%s:%s' % (self.LOCK, stat.st_dev, stat.st_ino))
except socket.error as e:
if e.errno != errno.EADDRINUSE:
raise
sys.exit(self.LOCKED)
@abstract @abstract
def erase(self): def erase(self):
"""""" """"""
...@@ -154,6 +179,9 @@ class DatabaseManager(object): ...@@ -154,6 +179,9 @@ class DatabaseManager(object):
def close(self): def close(self):
self._deferredCommit() self._deferredCommit()
self._close() self._close()
if self.__lock:
self.__lock.close()
del self.__lock
def _commit(self): def _commit(self):
"""Backend-specific code to commit the pending changes""" """Backend-specific code to commit the pending changes"""
...@@ -301,10 +329,23 @@ class DatabaseManager(object): ...@@ -301,10 +329,23 @@ class DatabaseManager(object):
Required only to import a DB using Importer backend. Required only to import a DB using Importer backend.
max_tid must be in unpacked format. max_tid must be in unpacked format.
Data from unassigned partitions must be ignored.
This is important because there may remain data from cells that have
been discarded, either due to --disable-drop-partitions option,
or in the future when dropping partitions is done in background
(because this is an expensive operation).
XXX: Given the TODO comment in getLastIDs, getting ids
from readable partitions should be enough.
""" """
def _getLastIDs(self): def _getLastIDs(self):
"""""" """Return (trans, obj, max(oid)) where
both 'trans' and 'obj' are {partition: max(tid)}
Same as in getLastTID: data from unassigned partitions must be ignored.
"""
@requires(_getLastIDs) @requires(_getLastIDs)
def getLastIDs(self): def getLastIDs(self):
......
...@@ -29,6 +29,7 @@ import os ...@@ -29,6 +29,7 @@ import os
import re import re
import string import string
import struct import struct
import sys
import time import time
from . import LOG_QUERIES from . import LOG_QUERIES
...@@ -52,9 +53,6 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -52,9 +53,6 @@ class MySQLDatabaseManager(DatabaseManager):
ENGINES = "InnoDB", "RocksDB", "TokuDB" ENGINES = "InnoDB", "RocksDB", "TokuDB"
_engine = ENGINES[0] # default engine _engine = ENGINES[0] # default engine
# Disabled even on MySQL 5.1-5.5 and MariaDB 5.2-5.3 because
# 'select count(*) from obj' sometimes returns incorrect values
# (tested with testOudatedCellsOnDownStorage).
_use_partition = False _use_partition = False
_max_allowed_packet = 32769 * 1024 _max_allowed_packet = 32769 * 1024
...@@ -102,9 +100,17 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -102,9 +100,17 @@ class MySQLDatabaseManager(DatabaseManager):
conn.autocommit(False) conn.autocommit(False)
conn.query("SET SESSION group_concat_max_len = %u" % (2**32-1)) conn.query("SET SESSION group_concat_max_len = %u" % (2**32-1))
conn.set_sql_mode("TRADITIONAL,NO_ENGINE_SUBSTITUTION") conn.set_sql_mode("TRADITIONAL,NO_ENGINE_SUBSTITUTION")
conn.query("SHOW VARIABLES WHERE variable_name='max_allowed_packet'") def query(sql):
r = conn.store_result() conn.query(sql)
(name, value), = r.fetch_row(r.num_rows()) r = conn.store_result()
return r.fetch_row(r.num_rows())
if self.LOCK:
(locked,), = query("SELECT GET_LOCK('%s.%s', 0)"
% (self.db, self.LOCK))
if not locked:
sys.exit(self.LOCKED)
(name, value), = query(
"SHOW VARIABLES WHERE variable_name='max_allowed_packet'")
if int(value) < self._max_allowed_packet: if int(value) < self._max_allowed_packet:
raise DatabaseFailure("Global variable %r is too small." raise DatabaseFailure("Global variable %r is too small."
" Minimal value must be %uk." " Minimal value must be %uk."
...@@ -304,21 +310,37 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -304,21 +310,37 @@ class MySQLDatabaseManager(DatabaseManager):
return self.query("SELECT rid, state FROM pt WHERE nid=%u" % nid) return self.query("SELECT rid, state FROM pt WHERE nid=%u" % nid)
return self.query("SELECT * FROM pt") return self.query("SELECT * FROM pt")
def _getAssignedPartitionList(self):
nid = self.getUUID()
if nid is None:
return ()
return [p for p, in self.query("SELECT rid FROM pt WHERE nid=%s" % nid)]
def _sqlmax(self, sql, arg_list):
q = self.query
x = [x for x in arg_list for x, in q(sql % x) if x is not None]
if x: return max(x)
def getLastTID(self, max_tid): def getLastTID(self, max_tid):
return self.query("SELECT MAX(t) FROM (SELECT MAX(tid) as t FROM trans" return self._sqlmax(
" WHERE tid<=%s GROUP BY `partition`) as t" % max_tid)[0][0] "SELECT MAX(tid) as t FROM trans FORCE INDEX (PRIMARY)"
" WHERE tid<=%s and `partition`=%%s" % max_tid,
self._getAssignedPartitionList())
def _getLastIDs(self): def _getLastIDs(self):
offset_list = self._getAssignedPartitionList()
p64 = util.p64 p64 = util.p64
q = self.query q = self.query
trans = {partition: p64(tid) sql = ("SELECT MAX(tid) FROM %s FORCE INDEX (PRIMARY)"
for partition, tid in q("SELECT `partition`, MAX(tid)" " WHERE `partition`=%s")
" FROM trans GROUP BY `partition`")} trans, obj = ({partition: p64(tid)
obj = {partition: p64(tid) for partition in offset_list
for partition, tid in q("SELECT `partition`, MAX(tid)" for tid, in q(sql % (t, partition))
" FROM obj GROUP BY `partition`")} if tid is not None}
oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj" for t in ('trans', 'obj'))
" GROUP BY `partition`) as t")[0][0] oid = self._sqlmax(
"SELECT MAX(oid) FROM obj FORCE INDEX (`partition`)"
" WHERE `partition`=%s", offset_list)
return trans, obj, None if oid is None else p64(oid) return trans, obj, None if oid is None else p64(oid)
def _getUnfinishedTIDDict(self): def _getUnfinishedTIDDict(self):
...@@ -337,7 +359,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -337,7 +359,7 @@ class MySQLDatabaseManager(DatabaseManager):
def getLastObjectTID(self, oid): def getLastObjectTID(self, oid):
oid = util.u64(oid) oid = util.u64(oid)
r = self.query("SELECT tid FROM obj" r = self.query("SELECT tid FROM obj FORCE INDEX(`partition`)"
" WHERE `partition`=%d AND oid=%d" " WHERE `partition`=%d AND oid=%d"
" ORDER BY tid DESC LIMIT 1" " ORDER BY tid DESC LIMIT 1"
% (self._getReadablePartition(oid), oid)) % (self._getReadablePartition(oid), oid))
...@@ -358,7 +380,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -358,7 +380,8 @@ class MySQLDatabaseManager(DatabaseManager):
q = self.query q = self.query
partition = self._getReadablePartition(oid) partition = self._getReadablePartition(oid)
sql = ('SELECT tid, compression, data.hash, value, value_tid' sql = ('SELECT tid, compression, data.hash, value, value_tid'
' FROM obj LEFT JOIN data ON (obj.data_id = data.id)' ' FROM obj FORCE INDEX(`partition`)'
' LEFT JOIN data ON (obj.data_id = data.id)'
' WHERE `partition` = %d AND oid = %d') % (partition, oid) ' WHERE `partition` = %d AND oid = %d') % (partition, oid)
if before_tid is not None: if before_tid is not None:
sql += ' AND tid < %d ORDER BY tid DESC LIMIT 1' % before_tid sql += ' AND tid < %d ORDER BY tid DESC LIMIT 1' % before_tid
...@@ -414,7 +437,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -414,7 +437,8 @@ class MySQLDatabaseManager(DatabaseManager):
for partition in offset_list: for partition in offset_list:
where = " WHERE `partition`=%d" % partition where = " WHERE `partition`=%d" % partition
data_id_list = [x for x, in data_id_list = [x for x, in
q("SELECT DISTINCT data_id FROM obj USE INDEX(PRIMARY)" + where) q("SELECT DISTINCT data_id FROM obj FORCE INDEX(PRIMARY)"
+ where)
if x] if x]
if not self._use_partition: if not self._use_partition:
q("DELETE FROM obj" + where) q("DELETE FROM obj" + where)
...@@ -578,7 +602,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -578,7 +602,7 @@ class MySQLDatabaseManager(DatabaseManager):
del _structLL del _structLL
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
sql = ('SELECT tid, value_tid FROM obj' sql = ('SELECT tid, value_tid FROM obj FORCE INDEX(`partition`)'
' WHERE `partition` = %d AND oid = %d' ' WHERE `partition` = %d AND oid = %d'
) % (self._getReadablePartition(oid), oid) ) % (self._getReadablePartition(oid), oid)
if tid is not None: if tid is not None:
...@@ -669,7 +693,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -669,7 +693,8 @@ class MySQLDatabaseManager(DatabaseManager):
p64 = util.p64 p64 = util.p64
r = self.query("SELECT tid, IF(compression < 128, LENGTH(value)," r = self.query("SELECT tid, IF(compression < 128, LENGTH(value),"
" CAST(CONV(HEX(SUBSTR(value, 5, 4)), 16, 10) AS INT))" " CAST(CONV(HEX(SUBSTR(value, 5, 4)), 16, 10) AS INT))"
" FROM obj LEFT JOIN data ON (obj.data_id = data.id)" " FROM obj FORCE INDEX(`partition`)"
" LEFT JOIN data ON (obj.data_id = data.id)"
" WHERE `partition` = %d AND oid = %d AND tid >= %d" " WHERE `partition` = %d AND oid = %d AND tid >= %d"
" ORDER BY tid DESC LIMIT %d, %d" % " ORDER BY tid DESC LIMIT %d, %d" %
(self._getReadablePartition(oid), oid, (self._getReadablePartition(oid), oid,
...@@ -682,7 +707,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -682,7 +707,7 @@ class MySQLDatabaseManager(DatabaseManager):
u64 = util.u64 u64 = util.u64
p64 = util.p64 p64 = util.p64
min_tid = u64(min_tid) min_tid = u64(min_tid)
r = self.query('SELECT tid, oid FROM obj' r = self.query('SELECT tid, oid FROM obj FORCE INDEX(PRIMARY)'
' WHERE `partition` = %d AND tid <= %d' ' WHERE `partition` = %d AND tid <= %d'
' AND (tid = %d AND %d <= oid OR %d < tid)' ' AND (tid = %d AND %d <= oid OR %d < tid)'
' ORDER BY tid ASC, oid ASC LIMIT %d' % ( ' ORDER BY tid ASC, oid ASC LIMIT %d' % (
...@@ -751,7 +776,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -751,7 +776,8 @@ class MySQLDatabaseManager(DatabaseManager):
q = self.query q = self.query
self._setPackTID(tid) self._setPackTID(tid)
for count, oid, max_serial in q("SELECT COUNT(*) - 1, oid, MAX(tid)" for count, oid, max_serial in q("SELECT COUNT(*) - 1, oid, MAX(tid)"
" FROM obj WHERE tid <= %d GROUP BY oid" " FROM obj FORCE INDEX(`partition`)"
" WHERE tid <= %d GROUP BY oid"
% tid): % tid):
partition = getPartition(oid) partition = getPartition(oid)
if q("SELECT 1 FROM obj WHERE `partition` = %d" if q("SELECT 1 FROM obj WHERE `partition` = %d"
...@@ -801,7 +827,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -801,7 +827,7 @@ class MySQLDatabaseManager(DatabaseManager):
# last grouped value, instead of the greatest one. # last grouped value, instead of the greatest one.
r = self.query( r = self.query(
"""SELECT tid, oid """SELECT tid, oid
FROM obj FROM obj FORCE INDEX(PRIMARY)
WHERE `partition` = %(partition)s WHERE `partition` = %(partition)s
AND tid <= %(max_tid)d AND tid <= %(max_tid)d
AND (tid > %(min_tid)d OR AND (tid > %(min_tid)d OR
......
...@@ -78,6 +78,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -78,6 +78,7 @@ class SQLiteDatabaseManager(DatabaseManager):
def _connect(self): def _connect(self):
logging.info('connecting to SQLite database %r', self.db) logging.info('connecting to SQLite database %r', self.db)
self.conn = sqlite3.connect(self.db, check_same_thread=False) self.conn = sqlite3.connect(self.db, check_same_thread=False)
self.lock(self.db)
if self.UNSAFE: if self.UNSAFE:
q = self.query q = self.query
q("PRAGMA synchronous = OFF") q("PRAGMA synchronous = OFF")
...@@ -243,20 +244,25 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -243,20 +244,25 @@ class SQLiteDatabaseManager(DatabaseManager):
# each partition (and finish in Python with max() for getLastTID). # each partition (and finish in Python with max() for getLastTID).
def getLastTID(self, max_tid): def getLastTID(self, max_tid):
return self.query("SELECT MAX(tid) FROM trans WHERE tid<=?", return self.query(
(max_tid,)).next()[0] "SELECT MAX(tid) FROM pt, trans"
" WHERE nid=? AND rid=partition AND tid<=?",
(self.getUUID(), max_tid,)).next()[0]
def _getLastIDs(self): def _getLastIDs(self):
p64 = util.p64 p64 = util.p64
q = self.query q = self.query
args = self.getUUID(),
trans = {partition: p64(tid) trans = {partition: p64(tid)
for partition, tid in q("SELECT partition, MAX(tid)" for partition, tid in q(
" FROM trans GROUP BY partition")} "SELECT partition, MAX(tid) FROM pt, trans"
" WHERE nid=? AND rid=partition GROUP BY partition", args)}
obj = {partition: p64(tid) obj = {partition: p64(tid)
for partition, tid in q("SELECT partition, MAX(tid)" for partition, tid in q(
" FROM obj GROUP BY partition")} "SELECT partition, MAX(tid) FROM pt, obj"
oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj" " WHERE nid=? AND rid=partition GROUP BY partition", args)}
" GROUP BY partition) as t").next()[0] oid = q("SELECT MAX(oid) oid FROM pt, obj"
" WHERE nid=? AND rid=partition", args).next()[0]
return trans, obj, None if oid is None else p64(oid) return trans, obj, None if oid is None else p64(oid)
def _getUnfinishedTIDDict(self): def _getUnfinishedTIDDict(self):
......
...@@ -38,6 +38,9 @@ class InitializationHandler(BaseMasterHandler): ...@@ -38,6 +38,9 @@ class InitializationHandler(BaseMasterHandler):
# delete objects database # delete objects database
dm = app.dm dm = app.dm
if unassigned_set: if unassigned_set:
if app.disable_drop_partitions:
logging.info("don't drop data for partitions %r", unassigned_set)
else:
logging.debug('drop data for partitions %r', unassigned_set) logging.debug('drop data for partitions %r', unassigned_set)
dm.dropPartitions(unassigned_set) dm.dropPartitions(unassigned_set)
......
...@@ -46,7 +46,6 @@ class StorageOperationHandler(EventHandler): ...@@ -46,7 +46,6 @@ class StorageOperationHandler(EventHandler):
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
app = self.app app = self.app
if app.operational and conn.isClient(): if app.operational and conn.isClient():
# XXX: Connection and Node should merged.
uuid = conn.getUUID() uuid = conn.getUUID()
if uuid: if uuid:
node = app.nm.getByUUID(uuid) node = app.nm.getByUUID(uuid)
......
...@@ -356,6 +356,7 @@ class Replicator(object): ...@@ -356,6 +356,7 @@ class Replicator(object):
self.fetchTransactions() self.fetchTransactions()
def fetchTransactions(self, min_tid=None): def fetchTransactions(self, min_tid=None):
assert self.current_node.getConnection().isClient(), self.current_node
offset = self.current_partition offset = self.current_partition
p = self.partition_dict[offset] p = self.partition_dict[offset]
if min_tid: if min_tid:
......
...@@ -190,6 +190,11 @@ class NeoTestBase(unittest.TestCase): ...@@ -190,6 +190,11 @@ class NeoTestBase(unittest.TestCase):
"Mock objects can't be compared with '==' or '!='" "Mock objects can't be compared with '==' or '!='"
return super(NeoTestBase, self).assertEqual(first, second, msg=msg) return super(NeoTestBase, self).assertEqual(first, second, msg=msg)
def assertPartitionTable(self, pt, expected, key=None):
self.assertEqual(
expected if isinstance(expected, str) else '|'.join(expected),
'|'.join(pt._formatRows(sorted(pt.count_dict, key=key))))
class NeoUnitTestBase(NeoTestBase): class NeoUnitTestBase(NeoTestBase):
""" Base class for neo tests, implements common checks """ """ Base class for neo tests, implements common checks """
...@@ -217,7 +222,8 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -217,7 +222,8 @@ class NeoUnitTestBase(NeoTestBase):
temp_dir = getTempDirectory() temp_dir = getTempDirectory()
for i in xrange(number): for i in xrange(number):
try: try:
os.remove(os.path.join(temp_dir, 'test_neo%s.sqlite' % i)) os.remove(os.path.join(temp_dir,
'%s%s.sqlite' % (prefix, i)))
except OSError, e: except OSError, e:
if e.errno != errno.ENOENT: if e.errno != errno.ENOENT:
raise raise
......
...@@ -104,7 +104,7 @@ class ClusterPdb(object): ...@@ -104,7 +104,7 @@ class ClusterPdb(object):
def broken_peer(self): def broken_peer(self):
return self._getLastPdb(os.getpid()) is None return self._getLastPdb(os.getpid()) is None
def __call__(self, max_count=None, depth=0, text=None): def __call__(self, depth=0, max_count=None, gui=False):
depth += 1 depth += 1
if max_count: if max_count:
frame = sys._getframe(depth) frame = sys._getframe(depth)
...@@ -113,13 +113,8 @@ class ClusterPdb(object): ...@@ -113,13 +113,8 @@ class ClusterPdb(object):
self._count_dict[key] = count = 1 + self._count_dict.get(key, 0) self._count_dict[key] = count = 1 + self._count_dict.get(key, 0)
if max_count < count: if max_count < count:
return return
if not text: if gui:
try:
import rpdb2 import rpdb2
except ImportError:
if text is not None:
raise
else:
if rpdb2.g_debugger is None: if rpdb2.g_debugger is None:
rpdb2_CStateManager = rpdb2.CStateManager rpdb2_CStateManager = rpdb2.CStateManager
def CStateManager(*args, **kw): def CStateManager(*args, **kw):
......
...@@ -37,10 +37,11 @@ from neo.lib import logging ...@@ -37,10 +37,11 @@ from neo.lib import logging
from neo.lib.protocol import ClusterStates, NodeTypes, CellStates, NodeStates, \ from neo.lib.protocol import ClusterStates, NodeTypes, CellStates, NodeStates, \
UUID_NAMESPACES UUID_NAMESPACES
from neo.lib.util import dump from neo.lib.util import dump
from .. import ADDRESS_TYPE, DB_SOCKET, DB_USER, IP_VERSION_FORMAT_DICT, SSL, \ from .. import (ADDRESS_TYPE, DB_SOCKET, DB_USER, IP_VERSION_FORMAT_DICT, SSL,
buildUrlFromString, cluster, getTempDirectory, NeoTestBase, setupMySQLdb buildUrlFromString, cluster, getTempDirectory, NeoTestBase, Patch,
setupMySQLdb)
from neo.client.Storage import Storage from neo.client.Storage import Storage
from neo.storage.database import buildDatabaseManager from neo.storage.database import manager, buildDatabaseManager
try: try:
coverage = sys.modules['neo.scripts.runner'].coverage coverage = sys.modules['neo.scripts.runner'].coverage
...@@ -124,7 +125,7 @@ class NEOProcess(object): ...@@ -124,7 +125,7 @@ class NEOProcess(object):
def __init__(self, command, uuid, arg_dict): def __init__(self, command, uuid, arg_dict):
try: try:
__import__('neo.scripts.' + command) __import__('neo.scripts.' + command, level=0)
except ImportError: except ImportError:
raise NotFound, '%s not found' % (command) raise NotFound, '%s not found' % (command)
self.command = command self.command = command
...@@ -491,7 +492,8 @@ class NEOCluster(object): ...@@ -491,7 +492,8 @@ class NEOCluster(object):
def getSQLConnection(self, db): def getSQLConnection(self, db):
assert db is not None and db in self.db_list assert db is not None and db in self.db_list
return buildDatabaseManager(self.adapter, (self.db_template(db),)) with Patch(manager.DatabaseManager, LOCK=None):
return buildDatabaseManager(self.adapter, (self.db_template(db),))
def getMasterProcessList(self): def getMasterProcessList(self):
return self.process_dict.get(NodeTypes.MASTER) return self.process_dict.get(NodeTypes.MASTER)
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import random, time, unittest
from collections import defaultdict from collections import defaultdict
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib import logging
from neo.lib.protocol import NodeStates, CellStates from neo.lib.protocol import NodeStates, CellStates
from neo.lib.pt import PartitionTableException from neo.lib.pt import PartitionTableException
from neo.master.pt import PartitionTable from neo.master.pt import PartitionTable
...@@ -45,7 +46,7 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -45,7 +46,7 @@ class MasterPartitionTableTests(NeoUnitTestBase):
self.assertEqual(len(pt.getRow(x)), 0) self.assertEqual(len(pt.getRow(x)), 0)
self.assertFalse(pt.operational()) self.assertFalse(pt.operational())
self.assertFalse(pt.filled()) self.assertFalse(pt.filled())
self.assertRaises(RuntimeError, pt.make, []) self.assertRaises(AssertionError, pt.make, [])
self.assertFalse(pt.operational()) self.assertFalse(pt.operational())
self.assertFalse(pt.filled()) self.assertFalse(pt.filled())
...@@ -132,77 +133,35 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -132,77 +133,35 @@ class MasterPartitionTableTests(NeoUnitTestBase):
(1, 2, CellStates.DISCARDED), (1, 2, CellStates.DISCARDED),
(2, 2, CellStates.DISCARDED)]) (2, 2, CellStates.DISCARDED)])
pt._setCell(0, sn[0], CellStates.UP_TO_DATE)
self.assertEqual(self.tweak(pt), [(2, 3, CellStates.FEEDING)]) self.assertEqual(self.tweak(pt), [(2, 3, CellStates.FEEDING)])
def test_16_make(self): def test_16_make(self):
num_partitions = 5 node_list = [self.createStorage(
num_replicas = 1 ("127.0.0.1", 19000 + i), self.getStorageUUID(),
pt = PartitionTable(num_partitions, num_replicas) NodeStates.RUNNING)
# add nodes for i in xrange(4)]
uuid1 = self.getStorageUUID() for np, nr, expected in (
server1 = ("127.0.0.1", 19001) (3, 0, 'U..|.U.|..U'),
sn1 = self.createStorage(server1, uuid1, NodeStates.RUNNING) (5, 1, 'UU..|..UU|UU..|..UU|UU..'),
# add not running node (9, 2, 'UUU.|UU.U|U.UU|.UUU|UUU.|UU.U|U.UU|.UUU|UUU.'),
uuid2 = self.getStorageUUID() ):
server2 = ("127.0.0.2", 19001) pt = PartitionTable(np, nr)
sn2 = self.createStorage(server2, uuid2) pt.make(node_list)
sn2.setState(NodeStates.DOWN) self.assertPartitionTable(pt, expected)
# add node without uuid self.assertTrue(pt.filled())
server3 = ("127.0.0.3", 19001) self.assertTrue(pt.operational())
sn3 = self.createStorage(server3, None, NodeStates.RUNNING) # create a pt with less nodes
# add clear node pt.clear()
uuid4 = self.getStorageUUID() self.assertFalse(pt.filled())
server4 = ("127.0.0.4", 19001) self.assertFalse(pt.operational())
sn4 = self.createStorage(server4, uuid4, NodeStates.RUNNING) pt.make(node_list[:1])
uuid5 = self.getStorageUUID() self.assertPartitionTable(pt, '|'.join('U' * np))
server5 = ("127.0.0.5", 1900) self.assertTrue(pt.filled())
sn5 = self.createStorage(server5, uuid5, NodeStates.RUNNING) self.assertTrue(pt.operational())
# make the table
pt.make([sn1, sn2, sn3, sn4, sn5])
# check it's ok, only running nodes and node with uuid
# must be present
for x in xrange(num_partitions):
cells = pt.getCellList(x)
self.assertEqual(len(cells), 2)
nodes = [x.getNode() for x in cells]
for node in nodes:
self.assertTrue(node in (sn1, sn4, sn5))
self.assertTrue(node not in (sn2, sn3))
self.assertTrue(pt.filled())
self.assertTrue(pt.operational())
# create a pt with less nodes
pt.clear()
self.assertFalse(pt.filled())
self.assertFalse(pt.operational())
pt.make([sn1])
# check it's ok
for x in xrange(num_partitions):
cells = pt.getCellList(x)
self.assertEqual(len(cells), 1)
nodes = [x.getNode() for x in cells]
for node in nodes:
self.assertEqual(node, sn1)
self.assertTrue(pt.filled())
self.assertTrue(pt.operational())
def _pt_states(self, pt):
node_dict = defaultdict(list)
for offset, row in enumerate(pt.partition_list):
for cell in row:
state_list = node_dict[cell.getNode()]
if state_list:
self.assertTrue(state_list[-1][0] < offset)
state_list.append((offset, str(cell.getState())[0]))
return map(dict, sorted(node_dict.itervalues()))
def checkPT(self, pt, exclude_empty=False):
new_pt = PartitionTable(pt.np, pt.nr)
new_pt.make(node for node, count in pt.count_dict.iteritems()
if count or not exclude_empty)
self.assertEqual(self._pt_states(pt), self._pt_states(new_pt))
def update(self, pt, change_list=None): def update(self, pt, change_list=None):
offset_list = range(pt.np) offset_list = xrange(pt.np)
for node in pt.count_dict: for node in pt.count_dict:
pt.updatable(node.getUUID(), offset_list) pt.updatable(node.getUUID(), offset_list)
if change_list is None: if change_list is None:
...@@ -215,9 +174,11 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -215,9 +174,11 @@ class MasterPartitionTableTests(NeoUnitTestBase):
for offset, uuid, state in change_list: for offset, uuid, state in change_list:
if state is CellStates.OUT_OF_DATE: if state is CellStates.OUT_OF_DATE:
pt.setUpToDate(node_dict[uuid], offset) pt.setUpToDate(node_dict[uuid], offset)
pt.log()
def tweak(self, pt, drop_list=()): def tweak(self, pt, drop_list=()):
change_list = pt.tweak(drop_list) change_list = pt.tweak(drop_list)
pt.log()
self.assertFalse(pt.tweak(drop_list)) self.assertFalse(pt.tweak(drop_list))
return change_list return change_list
...@@ -225,6 +186,7 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -225,6 +186,7 @@ class MasterPartitionTableTests(NeoUnitTestBase):
sn = [self.createStorage(None, i + 1, NodeStates.RUNNING) sn = [self.createStorage(None, i + 1, NodeStates.RUNNING)
for i in xrange(5)] for i in xrange(5)]
pt = PartitionTable(5, 2) pt = PartitionTable(5, 2)
pt.setID(1)
# part 0 # part 0
pt._setCell(0, sn[0], CellStates.DISCARDED) pt._setCell(0, sn[0], CellStates.DISCARDED)
pt._setCell(0, sn[1], CellStates.UP_TO_DATE) pt._setCell(0, sn[1], CellStates.UP_TO_DATE)
...@@ -246,45 +208,108 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -246,45 +208,108 @@ class MasterPartitionTableTests(NeoUnitTestBase):
pt._setCell(4, sn[4], CellStates.UP_TO_DATE) pt._setCell(4, sn[4], CellStates.UP_TO_DATE)
count_dict = defaultdict(int) count_dict = defaultdict(int)
self.assertPartitionTable(pt, (
'.U...',
'FFO..',
'FUU..',
'UUUU.',
'U...U'))
change_list = self.tweak(pt) change_list = self.tweak(pt)
self.assertPartitionTable(pt, (
'.UO.O',
'UU.O.',
'UFU.O',
'.UUU.',
'U..OU'))
for offset, uuid, state in change_list: for offset, uuid, state in change_list:
count_dict[state] += 1 count_dict[state] += 1
self.assertEqual(count_dict, {CellStates.DISCARDED: 3, self.assertEqual(count_dict, {CellStates.DISCARDED: 2,
CellStates.FEEDING: 1,
CellStates.OUT_OF_DATE: 5, CellStates.OUT_OF_DATE: 5,
CellStates.UP_TO_DATE: 3}) CellStates.UP_TO_DATE: 3})
self.update(pt, change_list) self.update(pt)
self.checkPT(pt) self.assertPartitionTable(pt, (
'.UU.U',
'UU.U.',
'U.U.U',
'.UUU.',
'U..UU'))
self.assertRaises(PartitionTableException, pt.dropNodeList, sn[1:4]) self.assertRaises(PartitionTableException, pt.dropNodeList, sn[1:4])
self.assertEqual(6, len(pt.dropNodeList(sn[1:3], True))) self.assertEqual(6, len(pt.dropNodeList(sn[1:3], True)))
self.assertEqual(3, len(pt.dropNodeList([sn[1]]))) self.assertEqual(3, len(pt.dropNodeList([sn[1]])))
pt.addNodeList([sn[1]]) pt.addNodeList([sn[1]])
self.assertPartitionTable(pt, (
'..U.U',
'U..U.',
'U.U.U',
'..UU.',
'U..UU'))
change_list = self.tweak(pt) change_list = self.tweak(pt)
self.assertPartitionTable(pt, (
'.OU.U',
'UO.U.',
'U.U.U',
'.OUU.',
'U..UU'))
self.assertEqual(3, len(change_list)) self.assertEqual(3, len(change_list))
self.update(pt, change_list) self.update(pt, change_list)
self.checkPT(pt)
for np, i in (12, 0), (12, 1), (13, 2): for np, i, expected in (
(12, 0, ('U...|.U..|..U.|...U|'
'U...|.U..|..U.|...U|'
'U...|.U..|..U.|...U',)),
(12, 1, ('UU...|..UU.|U...U|.UU..|...UU|'
'UU...|..UU.|U...U|.UU..|...UU|'
'UU...|..UU.',)),
(13, 2, ('U.UU.|.U.UU|UUU..|..UUU|UU..U|'
'U.UU.|.U.UU|UUU..|..UUU|UU..U|'
'U.UU.|.U.UU|UUU..',
'UUU..|U..UU|.UUU.|UU..U|..UUU|'
'UUU..|U..UU|.UUU.|UU..U|..UUU|'
'UUU..|U..UU|.UUU.')),
):
pt = PartitionTable(np, i) pt = PartitionTable(np, i)
i += 1 i += 1
pt.make(sn[:i]) pt.make(sn[:i])
pt.log()
for n in sn[i:i+3]: for n in sn[i:i+3]:
self.assertEqual([n], pt.addNodeList([n])) self.assertEqual([n], pt.addNodeList([n]))
self.update(pt, self.tweak(pt)) self.update(pt, self.tweak(pt))
self.checkPT(pt) self.assertPartitionTable(pt, expected[0])
pt.clear() pt.clear()
pt.make(sn[:i]) pt.make(sn[:i])
for n in sn[i:i+3]: for n in sn[i:i+3]:
self.assertEqual([n], pt.addNodeList([n])) self.assertEqual([n], pt.addNodeList([n]))
self.tweak(pt) self.tweak(pt)
self.update(pt) self.update(pt)
self.checkPT(pt) self.assertPartitionTable(pt, expected[-1])
pt = PartitionTable(7, 0) pt = PartitionTable(7, 0)
pt.make(sn[:1]) pt.make(sn[:1])
pt.addNodeList(sn[1:3]) pt.addNodeList(sn[1:3])
self.assertPartitionTable(pt, 'U..|U..|U..|U..|U..|U..|U..')
self.update(pt, self.tweak(pt, sn[:1])) self.update(pt, self.tweak(pt, sn[:1]))
self.checkPT(pt, True) self.assertPartitionTable(pt, '.U.|..U|.U.|..U|.U.|..U|.U.')
def test_18_tweak(self):
s = repr(time.time())
logging.info("using seed %r", s)
r = random.Random(s)
sn_count = 11
sn = [self.createStorage(None, i + 1, NodeStates.RUNNING)
for i in xrange(sn_count)]
pt = PartitionTable(1000, 2)
pt.setID(1)
for offset in xrange(pt.np):
state = CellStates.UP_TO_DATE
k = r.randrange(1, sn_count)
for s in r.sample(sn, k):
pt._setCell(offset, s, state)
if k * r.random() < 1:
state = CellStates.OUT_OF_DATE
pt.log()
self.tweak(pt)
self.update(pt)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -131,6 +131,15 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -131,6 +131,15 @@ class StorageDBTests(NeoUnitTestBase):
def checkSet(self, list1, list2): def checkSet(self, list1, list2):
self.assertEqual(set(list1), set(list2)) self.assertEqual(set(list1), set(list2))
def _test_lockDatabase_open(self):
raise NotImplementedError
def test_lockDatabase(self):
db = self._test_lockDatabase_open()
self.assertRaises(SystemExit, self._test_lockDatabase_open)
db.close()
self._test_lockDatabase_open().close()
def test_getUnfinishedTIDDict(self): def test_getUnfinishedTIDDict(self):
tid1, tid2, tid3, tid4 = self.getTIDs(4) tid1, tid2, tid3, tid4 = self.getTIDs(4)
oid1, oid2 = self.getOIDs(2) oid1, oid2 = self.getOIDs(2)
......
...@@ -29,11 +29,13 @@ class StorageMySQLdbTests(StorageDBTests): ...@@ -29,11 +29,13 @@ class StorageMySQLdbTests(StorageDBTests):
engine = None engine = None
def getDB(self, reset=0): def _test_lockDatabase_open(self):
self.prepareDatabase(number=1, prefix=DB_PREFIX) self.prepareDatabase(number=1, prefix=DB_PREFIX)
# db manager
database = '%s@%s0%s' % (DB_USER, DB_PREFIX, DB_SOCKET) database = '%s@%s0%s' % (DB_USER, DB_PREFIX, DB_SOCKET)
db = MySQLDatabaseManager(database, self.engine) return MySQLDatabaseManager(database, self.engine)
def getDB(self, reset=0):
db = self._test_lockDatabase_open()
self.assertEqual(db.db, DB_PREFIX + '0') self.assertEqual(db.db, DB_PREFIX + '0')
self.assertEqual(db.user, DB_USER) self.assertEqual(db.user, DB_USER)
try: try:
...@@ -129,11 +131,13 @@ class StorageMySQLdbTests(StorageDBTests): ...@@ -129,11 +131,13 @@ class StorageMySQLdbTests(StorageDBTests):
class StorageMySQLdbRocksDBTests(StorageMySQLdbTests): class StorageMySQLdbRocksDBTests(StorageMySQLdbTests):
engine = "RocksDB" engine = "RocksDB"
test_lockDatabase = None
class StorageMySQLdbTokuDBTests(StorageMySQLdbTests): class StorageMySQLdbTokuDBTests(StorageMySQLdbTests):
engine = "TokuDB" engine = "TokuDB"
test_lockDatabase = None
del StorageDBTests del StorageDBTests
......
...@@ -14,17 +14,29 @@ ...@@ -14,17 +14,29 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import os, unittest
from .. import getTempDirectory, DB_PREFIX
from .testStorageDBTests import StorageDBTests from .testStorageDBTests import StorageDBTests
from neo.storage.database.sqlite import SQLiteDatabaseManager from neo.storage.database.sqlite import SQLiteDatabaseManager
class StorageSQLiteTests(StorageDBTests): class StorageSQLiteTests(StorageDBTests):
def _test_lockDatabase_open(self):
db = os.path.join(getTempDirectory(), DB_PREFIX + '0.sqlite')
return SQLiteDatabaseManager(db)
def getDB(self, reset=0): def getDB(self, reset=0):
db = SQLiteDatabaseManager(':memory:') db = SQLiteDatabaseManager(':memory:')
db.setup(reset) db.setup(reset)
return db return db
def test_lockDatabase(self):
super(StorageSQLiteTests, self).test_lockDatabase()
# No lock on temporary databases.
db = self.getDB()
self.getDB().close()
db.close()
del StorageDBTests del StorageDBTests
if __name__ == "__main__": if __name__ == "__main__":
......
# -*- coding: utf-8 -*-
#
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from time import time
from .mock import Mock
from neo.lib import connection, logging
from neo.lib.connection import BaseConnection, ClientConnection, \
MTClientConnection, CRITICAL_TIMEOUT
from neo.lib.handler import EventHandler
from neo.lib.protocol import ENCODED_VERSION, Packets
from . import NeoUnitTestBase, Patch
connector_cpt = 0
class DummyConnector(Mock):
def __init__(self, addr, s=None):
logging.info("initializing connector")
global connector_cpt
self.desc = connector_cpt
connector_cpt += 1
self.packet_cpt = 0
self.addr = addr
Mock.__init__(self)
def getAddress(self):
return self.addr
def getDescriptor(self):
return self.desc
accept = getError = makeClientConnection = makeListeningConnection = \
receive = send = lambda *args, **kw: None
dummy_connector = Patch(BaseConnection,
ConnectorClass=lambda orig, self, *args, **kw: DummyConnector(*args, **kw))
class ConnectionTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
self.app = Mock({'__repr__': 'Fake App'})
self.app.ssl = None
self.em = self.app.em = Mock({'__repr__': 'Fake Em'})
self.handler = Mock({'__repr__': 'Fake Handler'})
self.address = ("127.0.0.7", 93413)
self.node = Mock({'getAddress': self.address})
def _makeClientConnection(self):
with dummy_connector:
conn = ClientConnection(self.app, self.handler, self.node)
self.connector = conn.connector
return conn
def testTimeout(self):
# NOTE: This method uses ping/pong packets only because MT connections
# don't accept any other packet without specifying a queue.
self.handler = EventHandler(self.app)
conn = self._makeClientConnection()
conn.read_buf.append(ENCODED_VERSION)
use_case_list = (
# (a) For a single packet sent at T,
# the limit time for the answer is T + (1 * CRITICAL_TIMEOUT)
((), (1., 1)),
# (b) Same as (a), even if send another packet at (T + CT/2).
# But receiving a packet (at T + CT - ε) resets the timeout
# (which means the limit for the 2nd one is T + 2*CT)
((.5, None), (1., 1, 2., 3)),
# (c) Same as (b) with a first answer at well before the limit
# (T' = T + CT/2). The limit for the second one is T' + CT.
((.1, None, .5, 3), (1.5, 1)),
)
def set_time(t):
connection.time = lambda: int(CRITICAL_TIMEOUT * (1000 + t))
closed = []
conn.close = lambda: closed.append(connection.time())
def answer(packet_id):
p = Packets.Pong()
p.setId(packet_id)
conn.connector.receive = lambda read_buf: \
read_buf.append(''.join(p.encode()))
conn.readable()
checkTimeout()
conn.process()
def checkTimeout():
timeout = conn.getTimeout()
if timeout and timeout <= connection.time():
conn.onTimeout()
try:
for use_case, expected in use_case_list:
i = iter(use_case)
conn.cur_id = 1 # XXX -> conn._reset() ?
set_time(0)
# No timeout when no pending request
self.assertEqual(conn._handlers.getNextTimeout(), None)
conn.ask(Packets.Ping())
for t in i:
set_time(t)
checkTimeout()
packet_id = i.next()
if packet_id is None:
conn.ask(Packets.Ping())
else:
answer(packet_id)
i = iter(expected)
for t in i:
set_time(t - .1)
checkTimeout()
set_time(t)
# this test method relies on the fact that only
# conn.close is called in case of a timeout
checkTimeout()
self.assertEqual(closed.pop(), connection.time())
answer(i.next())
self.assertFalse(conn.isPending())
self.assertFalse(closed)
finally:
connection.time = time
class MTConnectionTests(ConnectionTests):
# XXX: here we test non-client-connection-related things too, which
# duplicates test suite work... Should be fragmented into finer-grained
# test classes.
def setUp(self):
super(MTConnectionTests, self).setUp()
self.dispatcher = Mock({'__repr__': 'Fake Dispatcher'})
def _makeClientConnection(self):
with dummy_connector:
conn = MTClientConnection(self.app, self.handler, self.node,
dispatcher=self.dispatcher)
self.connector = conn.connector
return conn
def test_MTClientConnectionQueueParameter(self):
ask = self._makeClientConnection().ask
packet = Packets.AskPrimary() # Any non-Ping simple "ask" packet
# One cannot "ask" anything without a queue
self.assertRaises(TypeError, ask, packet)
ask(packet, queue=object())
# ... except Ping
ask(Packets.Ping())
if __name__ == '__main__':
unittest.main()
...@@ -1062,11 +1062,11 @@ class NEOThreadedTest(NeoTestBase): ...@@ -1062,11 +1062,11 @@ class NEOThreadedTest(NeoTestBase):
with Patch(client, _getFinalTID=lambda *_: None): with Patch(client, _getFinalTID=lambda *_: None):
self.assertRaises(ConnectionClosed, txn.commit) self.assertRaises(ConnectionClosed, txn.commit)
def assertPartitionTable(self, cluster, stats, pt_node=None): def assertPartitionTable(self, cluster, expected, pt_node=None):
pt = (pt_node or cluster.admin).pt
index = [x.uuid for x in cluster.storage_list].index index = [x.uuid for x in cluster.storage_list].index
self.assertEqual(stats, '|'.join(pt._formatRows(sorted( super(NEOThreadedTest, self).assertPartitionTable(
pt.count_dict, key=lambda x: index(x.getUUID()))))) (pt_node or cluster.admin).pt, expected,
lambda x: index(x.getUUID()))
@staticmethod @staticmethod
def noConnection(jar, storage): def noConnection(jar, storage):
......
...@@ -35,7 +35,7 @@ from neo.lib.exception import DatabaseFailure, StoppedOperation ...@@ -35,7 +35,7 @@ from neo.lib.exception import DatabaseFailure, StoppedOperation
from neo.lib.handler import DelayEvent from neo.lib.handler import DelayEvent
from neo.lib import logging from neo.lib import logging
from neo.lib.protocol import (CellStates, ClusterStates, NodeStates, NodeTypes, from neo.lib.protocol import (CellStates, ClusterStates, NodeStates, NodeTypes,
Packets, Packet, uuid_str, ZERO_OID, ZERO_TID) Packets, Packet, uuid_str, ZERO_OID, ZERO_TID, MAX_TID)
from .. import expectedFailure, unpickle_state, Patch, TransactionalResource from .. import expectedFailure, unpickle_state, Patch, TransactionalResource
from . import ClientApplication, ConnectionFilter, LockLock, NEOThreadedTest, \ from . import ClientApplication, ConnectionFilter, LockLock, NEOThreadedTest, \
RandomConflictDict, ThreadId, with_cluster RandomConflictDict, ThreadId, with_cluster
...@@ -1350,19 +1350,6 @@ class Test(NEOThreadedTest): ...@@ -1350,19 +1350,6 @@ class Test(NEOThreadedTest):
poll(0) poll(0)
self.assertIs(client.connector, None) self.assertIs(client.connector, None)
def testConnectionTimeout(self):
with self.getLoopbackConnection() as conn:
conn.KEEP_ALIVE
def onTimeout(orig):
conn.idle()
orig()
with Patch(conn, KEEP_ALIVE=0):
while conn.connecting:
conn.em.poll(1)
with Patch(conn, onTimeout=onTimeout):
conn.em.poll(1)
self.assertFalse(conn.isClosed())
@with_cluster() @with_cluster()
def testClientDisconnectedFromMaster(self, cluster): def testClientDisconnectedFromMaster(self, cluster):
def disconnect(conn, packet): def disconnect(conn, packet):
...@@ -2061,7 +2048,7 @@ class Test(NEOThreadedTest): ...@@ -2061,7 +2048,7 @@ class Test(NEOThreadedTest):
if (isinstance(packet, Packets.AnswerStoreObject) if (isinstance(packet, Packets.AnswerStoreObject)
and packet.decode()[0]): and packet.decode()[0]):
conn, = cluster.client.getConnectionList(app) conn, = cluster.client.getConnectionList(app)
kw = conn._handlers._pending[0][0][packet._id][3] kw = conn._handlers._pending[0][0][packet._id][1]
return 1 == u64(kw['oid']) and delay_conflict[app.uuid].pop() return 1 == u64(kw['oid']) and delay_conflict[app.uuid].pop()
def writeA(orig, txn_context, oid, serial, data): def writeA(orig, txn_context, oid, serial, data):
if u64(oid) == 1: if u64(oid) == 1:
...@@ -2335,6 +2322,34 @@ class Test(NEOThreadedTest): ...@@ -2335,6 +2322,34 @@ class Test(NEOThreadedTest):
self.assertFalse(m1.primary) self.assertFalse(m1.primary)
self.assertTrue(m1.is_alive()) self.assertTrue(m1.is_alive())
@with_cluster(partitions=2, storage_count=2)
def testStorageBackendLastIDs(self, cluster):
"""
Check that getLastIDs/getLastTID ignore data from unassigned partitions.
XXX: this kind of test should not be reexecuted with SSL
"""
cluster.sortStorageList()
t, c = cluster.getTransaction()
c.root()[''] = PCounter()
t.commit()
big_id_list = ('\x7c' * 8, '\x7e' * 8), ('\x7b' * 8, '\x7d' * 8)
for i in 0, 1:
dm = cluster.storage_list[i].dm
expected = dm.getLastTID(u64(MAX_TID)), dm.getLastIDs()
oid, tid = big_id_list[i]
for j, expected in (
(1 - i, (dm.getLastTID(u64(MAX_TID)), dm.getLastIDs())),
(i, (u64(tid), (tid, {}, {}, oid)))):
oid, tid = big_id_list[j]
# Somehow we abuse 'storeTransaction' because we ask it to
# write data for unassigned partitions. This is not checked
# so for the moment, the test works.
dm.storeTransaction(tid, ((oid, None, None),),
((oid,), '', '', '', 0, tid), False)
self.assertEqual(expected,
(dm.getLastTID(u64(MAX_TID)), dm.getLastIDs()))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -29,7 +29,6 @@ from neo.storage.checker import CHECK_COUNT ...@@ -29,7 +29,6 @@ from neo.storage.checker import CHECK_COUNT
from neo.storage.replicator import Replicator from neo.storage.replicator import Replicator
from neo.lib.connector import SocketConnector from neo.lib.connector import SocketConnector
from neo.lib.connection import ClientConnection from neo.lib.connection import ClientConnection
from neo.lib.event import EventManager
from neo.lib.protocol import CellStates, ClusterStates, Packets, \ from neo.lib.protocol import CellStates, ClusterStates, Packets, \
ZERO_OID, ZERO_TID, MAX_TID, uuid_str ZERO_OID, ZERO_TID, MAX_TID, uuid_str
from neo.lib.util import p64, u64 from neo.lib.util import p64, u64
...@@ -283,35 +282,6 @@ class ReplicationTests(NEOThreadedTest): ...@@ -283,35 +282,6 @@ class ReplicationTests(NEOThreadedTest):
self.assertEqual(backup.last_tid, upstream.last_tid) self.assertEqual(backup.last_tid, upstream.last_tid)
self.assertEqual(np*3, self.checkBackup(backup)) self.assertEqual(np*3, self.checkBackup(backup))
@backup_test()
def testBackupUpstreamMasterDead(self, backup):
"""Check proper behaviour when upstream master is unreachable
More generally, this checks that when a handler raises when a connection
is closed voluntarily, the connection is in a consistent state and can
be, for example, closed again after the exception is caught, without
assertion failure.
"""
conn, = backup.master.getConnectionList(backup.upstream.master)
# trigger ping
self.assertFalse(conn.isPending())
conn.onTimeout()
self.assertTrue(conn.isPending())
# force ping to have expired
# connection will be closed before upstream master has time
# to answer
def _poll(orig, self, blocking):
if backup.master.em is self:
p.revert()
conn._next_timeout = 0
conn.onTimeout()
else:
orig(self, blocking)
with Patch(EventManager, _poll=_poll) as p:
self.tic()
new_conn, = backup.master.getConnectionList(backup.upstream.master)
self.assertIsNot(new_conn, conn)
@backup_test() @backup_test()
def testBackupUpstreamStorageDead(self, backup): def testBackupUpstreamStorageDead(self, backup):
upstream = backup.upstream upstream = backup.upstream
...@@ -334,7 +304,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -334,7 +304,7 @@ class ReplicationTests(NEOThreadedTest):
self.tic(check_timeout=(backup.storage,)) self.tic(check_timeout=(backup.storage,))
# 2nd failed, 3rd deferred # 2nd failed, 3rd deferred
self.assertEqual(count[0], 4) self.assertEqual(count[0], 4)
self.assertTrue(t <= time.time()) self.assertLessEqual(t, time.time())
@backup_test() @backup_test()
def testBackupDelayedUnlockTransaction(self, backup): def testBackupDelayedUnlockTransaction(self, backup):
...@@ -406,13 +376,13 @@ class ReplicationTests(NEOThreadedTest): ...@@ -406,13 +376,13 @@ class ReplicationTests(NEOThreadedTest):
s2.start() s2.start()
self.tic() self.tic()
cluster.enableStorageList([s2]) cluster.enableStorageList([s2])
# 2 UP_TO_DATE cells should become FEEDING, # 2 UP_TO_DATE cells become FEEDING:
# and be dropped only when the replication is done, # they are dropped only when the replication is done,
# so that 1 storage can still die without data loss. # so that 1 storage can still die without data loss.
with Patch(s0.dm, changePartitionTable=changePartitionTable): with Patch(s0.dm, changePartitionTable=changePartitionTable):
cluster.neoctl.tweakPartitionTable() cluster.neoctl.tweakPartitionTable()
self.tic() self.tic()
expectedFailure(self.assertEqual)(cluster.neoctl.getClusterState(), self.assertEqual(cluster.neoctl.getClusterState(),
ClusterStates.RUNNING) ClusterStates.RUNNING)
@with_cluster(start_cluster=0, partitions=3, replicas=1, storage_count=3) @with_cluster(start_cluster=0, partitions=3, replicas=1, storage_count=3)
...@@ -625,6 +595,31 @@ class ReplicationTests(NEOThreadedTest): ...@@ -625,6 +595,31 @@ class ReplicationTests(NEOThreadedTest):
with s0.dm.replicated(1): with s0.dm.replicated(1):
self.assertFalse(s0.dm.getObject(ob._p_oid, tid2)) self.assertFalse(s0.dm.getObject(ob._p_oid, tid2))
@with_cluster(start_cluster=0, storage_count=2, partitions=2)
def testDropPartitions(self, cluster, disable=False):
s0, s1 = cluster.storage_list
cluster.start(storage_list=(s0,))
t, c = cluster.getTransaction()
c.root()[''] = PCounter()
t.commit()
s1.start()
self.tic()
self.assertEqual(3, s0.sqlCount('obj'))
cluster.enableStorageList((s1,))
cluster.neoctl.tweakPartitionTable()
self.tic()
self.assertEqual(1, s1.sqlCount('obj'))
# Deletion should start as soon as the cell is discarded, as a
# background task, instead of doing it during initialization.
count = s0.sqlCount('obj')
s0.stop()
cluster.join((s0,))
s0.resetNode()
s0.start()
self.tic()
self.assertEqual(2, s0.sqlCount('obj'))
expectedFailure(self.assertEqual)(2, count)
@with_cluster(start_cluster=0, replicas=1) @with_cluster(start_cluster=0, replicas=1)
def testResumingReplication(self, cluster): def testResumingReplication(self, cluster):
if 1: if 1:
......
...@@ -34,8 +34,8 @@ class SSLMixin: ...@@ -34,8 +34,8 @@ class SSLMixin:
class SSLTests(SSLMixin, test.Test): class SSLTests(SSLMixin, test.Test):
# exclude expected failures # exclude expected failures
testDeadlockAvoidance = None # XXX why this fails? testStorageDataLock2 = None # XXX why this fails?
testUndoConflict = testUndoConflictDuringStore = None # XXX why this fails? testUndoConflictDuringStore = None # XXX why this fails?
def testAbortConnection(self, after_handshake=1): def testAbortConnection(self, after_handshake=1):
with self.getLoopbackConnection() as conn: with self.getLoopbackConnection() as conn:
......
...@@ -16,7 +16,7 @@ Topic :: Software Development :: Libraries :: Python Modules ...@@ -16,7 +16,7 @@ Topic :: Software Development :: Libraries :: Python Modules
mock = 'neo/tests/mock.py' mock = 'neo/tests/mock.py'
if not os.path.exists(mock): if not os.path.exists(mock):
import cStringIO, hashlib,subprocess, urllib, zipfile import cStringIO, hashlib, subprocess, urllib, zipfile
x = 'pythonmock-0.1.0.zip' x = 'pythonmock-0.1.0.zip'
try: try:
x = subprocess.check_output(('git', 'cat-file', 'blob', x)) x = subprocess.check_output(('git', 'cat-file', 'blob', x))
...@@ -24,8 +24,9 @@ if not os.path.exists(mock): ...@@ -24,8 +24,9 @@ if not os.path.exists(mock):
x = urllib.urlopen( x = urllib.urlopen(
'http://downloads.sf.net/sourceforge/python-mock/' + x).read() 'http://downloads.sf.net/sourceforge/python-mock/' + x).read()
mock_py = zipfile.ZipFile(cStringIO.StringIO(x)).read('mock.py') mock_py = zipfile.ZipFile(cStringIO.StringIO(x)).read('mock.py')
if hashlib.md5(mock_py).hexdigest() != '79f42f390678e5195d9ce4ae43bd18ec': if (hashlib.sha256(mock_py).hexdigest() !=
raise EnvironmentError("MD5 checksum mismatch downloading 'mock.py'") 'c6ed26e4312ed82160016637a9b6f8baa71cf31a67c555d44045a1ef1d60d1bc'):
raise EnvironmentError("SHA checksum mismatch downloading 'mock.py'")
open(mock, 'w').write(mock_py) open(mock, 'w').write(mock_py)
zodb_require = ['ZODB3>=3.10dev'] zodb_require = ['ZODB3>=3.10dev']
...@@ -59,11 +60,11 @@ else: ...@@ -59,11 +60,11 @@ else:
setup( setup(
name = 'neoppod', name = 'neoppod',
version = '1.7.1', version = '1.8',
description = __doc__.strip(), description = __doc__.strip(),
author = 'Nexedi SA', author = 'Nexedi SA',
author_email = 'neo-dev@erp5.org', author_email = 'neo-dev@erp5.org',
url = 'http://www.neoppod.org/', url = 'https://neo.nexedi.com/',
license = 'GPL 2+', license = 'GPL 2+',
platforms = ["any"], platforms = ["any"],
classifiers=classifiers.splitlines(), classifiers=classifiers.splitlines(),
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment