Commit 03077c10 authored by Kirill Smelkov's avatar Kirill Smelkov

Merge remote-tracking branch 'origin/master' into t

* origin/master: (23 commits)
  mysql: more index hints
  Release version 1.8
  README: update URLs
  README: update wrt added support for RocksDB and recent ZODB
  storage: update DatabaseManager.getLastTID docstring
  neolog: new --decompress option
  doc: update TODO about missing invalidations in read-only mode
  mysql: remove obsolete comment about broken PARTITIONing support
  qa: make ClusterPdb compatible with the simple pdb of neo.tests
  client: fix NameError when a secondary master reports that it's not the primary
  storage: new --disable-drop-partitions option
  qa: add testDropPartitions
  Better use of __import__
  qa: update list of excluded tests in testSSL
  master: improve algorithm to tweak the partition table
  storage: ignore unassigned partitions when looking for last oids/tids
  neolog: new option to hide the node column
  Remove packet timeouts
  Use TCP keepalives instead of applicative pings
  Remove unused 'on_timeout' feature on connections
  ...
parents c4d3957f 0868de70
Change History
==============
1.8 (2017-07-04)
----------------
This release mainly stabilizes NEO when it is used with several storage nodes,
fixing many race conditions involving events like transactional operations
(read/write, conflict resolution...), replication, partition table tweaking,
and all kinds of failures (node crashes, network cuts...). This includes a
rework of conflict resolution, to implement the long-awaited deadlock avoidance
(it was a limitation caused by object-level locking).

Similarly, having spare master nodes is not an experimental feature anymore:
the `election` (of the primary master) has been reimplemented, and it now
happens during the RECOVERING phase. This comes with a change to node
states: BROKEN/HIDDEN/UNKNOWN are removed, DOWN is renamed to UNKNOWN,
and TEMPORARILY_DOWN to DOWN.

Also for more resiliency, the new algorithm to tweak the partition table
is better at minimizing the amount of replication, and it no longer discards
readable cells too quickly: a partition can now have multiple FEEDING
cells, to avoid going below the wanted level of replication.

Other changes:

- General:

  - Packet timeouts have been removed.
    TCP keepalives are used instead of applicative pings.
  - The connection handshake between nodes has been reviewed to make sure that
    they speak the same protocol before doing anything else, and to report
    clearer error messages otherwise. A dangerous bug was that there was no
    protocol version check between neoctl and the admin node.
  - Proper handling of incoming packets for closed/aborted connections.
  - An exception while processing an answer could leave the handler switcher
    in a bad state.
  - In the STOPPING cluster state, really wait for all transactions to be
    finished.
  - Several issues when undoing transactions with conflict resolution
    have been fixed.
  - Connection acceptance is now delayed until the storage node is ready.

- Client:

  - Added support for `zodburi`_.
  - Fix a load error during conflict resolution in case of late invalidation.
  - Do not wait for tpc_vote to start resolving conflicts.
  - Fix harmless 'unexpected ... AnswerRequestIdentification' exceptions.

- Storage:

  - New --disable-drop-partitions option, which is useful for big databases
    because the current code to delete data of discarded cells is inefficient
    (this option should disappear in the future).
  - Prevent 2 nodes from working with the same database.
  - Discard answers from aborted replications.
    In some cases, this led to data corruption or crashes.

- MySQL backend:

  - Added support for RocksDB.
  - Do not flood logs when retrying to connect non-stop.
  - Do not retry a failing query forever.
  - By default, do not retry to connect to the server automatically.

- Tools:

  - neolog: new --decompress option.
  - neolog: new option to hide the node column.
  - neoctl: make the identification of the primary master easier with
    'print node'.

- A lot of improvements for developers and debugging.

.. _zodburi: https://docs.pylonsproject.org/projects/zodburi
1.7.1 (2017-01-18)
------------------
......
...@@ -16,7 +16,7 @@ A NEO cluster is composed of the following types of nodes: ...@@ -16,7 +16,7 @@ A NEO cluster is composed of the following types of nodes:
Stores data, preserving history. All available storage nodes are in use Stores data, preserving history. All available storage nodes are in use
simultaneously. This offers redundancy and data distribution. simultaneously. This offers redundancy and data distribution.
Available backends: MySQL (InnoDB or TokuDB), SQLite Available backends: MySQL (InnoDB, RocksDB or TokuDB), SQLite
- "admin" nodes (mandatory for startup, optional after) - "admin" nodes (mandatory for startup, optional after)
...@@ -38,8 +38,8 @@ Any ZODB like FileStorage can be converted to NEO instantaneously, ...@@ -38,8 +38,8 @@ Any ZODB like FileStorage can be converted to NEO instantaneously,
which means the database is operational before all data are imported. which means the database is operational before all data are imported.
There's also a tool to convert back to FileStorage. There's also a tool to convert back to FileStorage.
See also http://www.neoppod.org/links for more detailed information about For more detailed information about features related to scalability,
features related to scalability. see the `Architecture and Characteristics` section of https://neo.nexedi.com/.
Requirements Requirements
============ ============
...@@ -52,7 +52,7 @@ Requirements ...@@ -52,7 +52,7 @@ Requirements
- MySQLdb: https://github.com/PyMySQL/mysqlclient-python - MySQLdb: https://github.com/PyMySQL/mysqlclient-python
- For client nodes: ZODB 3.10.x - For client nodes: ZODB 3.10.x or later
Installation Installation
============ ============
...@@ -199,7 +199,7 @@ Developers ...@@ -199,7 +199,7 @@ Developers
========== ==========
Developers interested in NEO may refer to Developers interested in NEO may refer to
`NEO Web site <http://www.neoppod.org/>`_ and subscribe to following mailing `NEO Web site <https://neo.nexedi.com/>`_ and subscribe to following mailing
lists: lists:
- `neo-users <http://mail.tiolive.com/mailman/listinfo/neo-users>`_: - `neo-users <http://mail.tiolive.com/mailman/listinfo/neo-users>`_:
...@@ -213,4 +213,4 @@ https://www.erp5.com/quality/integration/P-ERP5.Com.Unit%20Tests/Base_viewListMo ...@@ -213,4 +213,4 @@ https://www.erp5.com/quality/integration/P-ERP5.Com.Unit%20Tests/Base_viewListMo
Commercial Support Commercial Support
================== ==================
Nexedi provides commercial support for NEO: http://www.nexedi.com/ Nexedi provides commercial support for NEO: https://www.nexedi.com/
...@@ -84,6 +84,8 @@ ...@@ -84,6 +84,8 @@
keys (trans.tid & obj.{tid,oid}). keys (trans.tid & obj.{tid,oid}).
Master Master
- Implement back-channel for invalidations in read-only mode,
so that clients of backup clusters are notified of new data.
- Master node data redundancy (HIGH AVAILABILITY) - Master node data redundancy (HIGH AVAILABILITY)
Secondary master nodes should replicate primary master data (ie, primary Secondary master nodes should replicate primary master data (ie, primary
master should inform them of such changes). master should inform them of such changes).
......
...@@ -40,7 +40,7 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -40,7 +40,7 @@ class PrimaryNotificationsHandler(MTEventHandler):
try: try:
super(PrimaryNotificationsHandler, self).notPrimaryMaster(*args) super(PrimaryNotificationsHandler, self).notPrimaryMaster(*args)
except PrimaryElected, e: except PrimaryElected, e:
app.primary_master_node, = e.args self.app.primary_master_node, = e.args
def _acceptIdentification(self, node, num_partitions, num_replicas): def _acceptIdentification(self, node, num_partitions, num_replicas):
self.app.pt = PartitionTable(num_partitions, num_replicas) self.app.pt = PartitionTable(num_partitions, num_replicas)
......
...@@ -44,7 +44,6 @@ class ConnectionPool(object): ...@@ -44,7 +44,6 @@ class ConnectionPool(object):
app = self.app app = self.app
if app.master_conn is None: if app.master_conn is None:
raise NEOPrimaryMasterLost raise NEOPrimaryMasterLost
logging.debug('trying to connect to %s - %s', node, node.getState())
conn = MTClientConnection(app, app.storage_event_handler, node, conn = MTClientConnection(app, app.storage_event_handler, node,
dispatcher=app.dispatcher) dispatcher=app.dispatcher)
p = Packets.RequestIdentification(NodeTypes.CLIENT, p = Packets.RequestIdentification(NodeTypes.CLIENT,
......
...@@ -101,7 +101,7 @@ if IF == 'pdb': ...@@ -101,7 +101,7 @@ if IF == 'pdb':
def __init__(self, bp_list): def __init__(self, bp_list):
self._lock = threading.Lock() self._lock = threading.Lock()
for o, name in bp_list: for o, name in bp_list:
o = __import__(o, fromlist=1) o = __import__(o, fromlist=('*',), level=0)
x = name.split('.') x = name.split('.')
name = x.pop() name = x.pop()
for x in x: for x in x:
......
...@@ -97,6 +97,9 @@ class ConfigurationManager(object): ...@@ -97,6 +97,9 @@ class ConfigurationManager(object):
bind = self.__get('bind') bind = self.__get('bind')
return parseNodeAddress(bind, 0) return parseNodeAddress(bind, 0)
def getDisableDropPartitions(self):
return self.__get('disable_drop_partitions', True)
def getDatabase(self): def getDatabase(self):
return self.__get('database') return self.__get('database')
......
...@@ -23,16 +23,11 @@ from .locking import RLock ...@@ -23,16 +23,11 @@ from .locking import RLock
from .protocol import uuid_str, Errors, PacketMalformedError, Packets from .protocol import uuid_str, Errors, PacketMalformedError, Packets
from .util import dummy_read_buffer, ReadBuffer from .util import dummy_read_buffer, ReadBuffer
CRITICAL_TIMEOUT = 30
class ConnectionClosed(Exception): class ConnectionClosed(Exception):
pass pass
class HandlerSwitcher(object): class HandlerSwitcher(object):
_is_handling = False _is_handling = False
_next_timeout = None
_next_timeout_msg_id = None
_next_on_timeout = None
_pending = ({}, None), # ( {msgid -> (answer_klass, timeout, on_timeout, kw)}, _pending = ({}, None), # ( {msgid -> (answer_klass, timeout, on_timeout, kw)},
# handler ) # handler )
...@@ -55,7 +50,7 @@ class HandlerSwitcher(object): ...@@ -55,7 +50,7 @@ class HandlerSwitcher(object):
while request_dict: while request_dict:
msg_id, request = request_dict.popitem() msg_id, request = request_dict.popitem()
p.setId(msg_id) p.setId(msg_id)
handler.packetReceived(conn, p, request[3]) handler.packetReceived(conn, p, request[1])
if len(self._pending) == 1: if len(self._pending) == 1:
break break
del self._pending[0] del self._pending[0]
...@@ -67,7 +62,7 @@ class HandlerSwitcher(object): ...@@ -67,7 +62,7 @@ class HandlerSwitcher(object):
""" Return the last (may be unapplied) handler registered """ """ Return the last (may be unapplied) handler registered """
return self._pending[-1][1] return self._pending[-1][1]
def emit(self, request, timeout, on_timeout, kw={}): def emit(self, request, kw={}):
# register the request in the current handler # register the request in the current handler
_pending = self._pending _pending = self._pending
if self._is_handling: if self._is_handling:
...@@ -82,26 +77,7 @@ class HandlerSwitcher(object): ...@@ -82,26 +77,7 @@ class HandlerSwitcher(object):
answer_class = request.getAnswerClass() answer_class = request.getAnswerClass()
assert answer_class is not None, "Not a request" assert answer_class is not None, "Not a request"
assert msg_id not in request_dict, "Packet id already expected" assert msg_id not in request_dict, "Packet id already expected"
next_timeout = self._next_timeout request_dict[msg_id] = answer_class, kw
if next_timeout is None or timeout < next_timeout:
self._next_timeout = timeout
self._next_timeout_msg_id = msg_id
self._next_on_timeout = on_timeout
request_dict[msg_id] = answer_class, timeout, on_timeout, kw
def getNextTimeout(self):
return self._next_timeout
def timeout(self, connection):
msg_id = self._next_timeout_msg_id
if self._next_on_timeout is not None:
self._next_on_timeout(connection, msg_id)
if self._next_timeout_msg_id != msg_id:
# on_timeout sent a packet with a smaller timeout
# so keep the connection open
return
# Notify that a timeout occurred
return msg_id
def handle(self, connection, packet): def handle(self, connection, packet):
assert not self._is_handling assert not self._is_handling
...@@ -128,7 +104,7 @@ class HandlerSwitcher(object): ...@@ -128,7 +104,7 @@ class HandlerSwitcher(object):
request_dict, handler = pending[0] request_dict, handler = pending[0]
# checkout the expected answer class # checkout the expected answer class
try: try:
klass, _, _, kw = request_dict.pop(msg_id) klass, kw = request_dict.pop(msg_id)
except KeyError: except KeyError:
klass = None klass = None
kw = {} kw = {}
...@@ -147,18 +123,6 @@ class HandlerSwitcher(object): ...@@ -147,18 +123,6 @@ class HandlerSwitcher(object):
del pending[0] del pending[0]
logging.debug('Apply handler %r on %r', pending[0][1], logging.debug('Apply handler %r on %r', pending[0][1],
connection) connection)
if msg_id == self._next_timeout_msg_id:
self._updateNextTimeout()
def _updateNextTimeout(self):
# Find next timeout and its msg_id
next_timeout = None
for pending in self._pending:
for msg_id, (_, timeout, on_timeout, _) in pending[0].iteritems():
if not next_timeout or timeout < next_timeout[0]:
next_timeout = timeout, msg_id, on_timeout
self._next_timeout, self._next_timeout_msg_id, self._next_on_timeout = \
next_timeout or (None, None, None)
def setHandler(self, handler): def setHandler(self, handler):
can_apply = len(self._pending) == 1 and not self._pending[0][0] can_apply = len(self._pending) == 1 and not self._pending[0][0]
...@@ -176,24 +140,33 @@ class BaseConnection(object): ...@@ -176,24 +140,33 @@ class BaseConnection(object):
About timeouts: About timeouts:
Timeout are mainly per-connection instead of per-packet.
The idea is that most of time, packets are received and processed
sequentially, so if it takes a long for a peer to process a packet,
following packets would just be enqueued.
What really matters is that the peer makes progress in its work.
As long as we receive an answer, we consider it's still alive and
it may just have started to process the following request. So we reset
timeouts.
There is anyway nothing more we could do, because processing of a packet
may be delayed in a very unpredictable way depending of previously
received packets on peer side.
Even ourself may be slow to receive a packet. We must not timeout for
an answer that is already in our incoming buffer (read_buf or _queue).
Timeouts in HandlerSwitcher are only there to prioritize some packets.

In the past, ask() took a timeout parameter as a way to close the
connection if the remote node was too long to reply, with the idea
that something went wrong. There was no known bug but this feature was
actually a bad idea.
It is impossible to test whether the remote node is in good state or
not. The experience shows that timeouts were always triggered because
the remote nodes were simply too slow. Waiting remains the best option
and anything else would only make things worse.
The only case where it could make sense to react on a slow request is
when there is redundancy, more exactly for read requests to storage
nodes when there are replicas. A client node could resend its request
to another node, _without_ breaking the first connection (then wait for
the first reply and ignore the other).
The previous timeout implementation (before May 2017) was not well
suited to support the above use case so most of the code has been
removed, but it may contain some interesting parts.
Currently, since applicative pings have been replaced by TCP
keepalives, timeouts are only used for 2 things:
- to avoid reconnecting too fast
- to close idle client connections
""" """
from .connector import SocketConnector as ConnectorClass from .connector import SocketConnector as ConnectorClass
KEEP_ALIVE = 60
def __init__(self, event_manager, handler, connector, addr=None): def __init__(self, event_manager, handler, connector, addr=None):
assert connector is not None, "Need a low-level connector" assert connector is not None, "Need a low-level connector"
...@@ -294,9 +267,6 @@ class BaseConnection(object): ...@@ -294,9 +267,6 @@ class BaseConnection(object):
""" """
return attributeTracker.whoSet(self, 'connector') return attributeTracker.whoSet(self, 'connector')
def idle(self):
pass
attributeTracker.track(BaseConnection) attributeTracker.track(BaseConnection)
...@@ -340,9 +310,8 @@ class Connection(BaseConnection): ...@@ -340,9 +310,8 @@ class Connection(BaseConnection):
client = False client = False
server = False server = False
peer_id = None peer_id = None
_next_timeout = None
_parser_state = None _parser_state = None
_timeout = 0 _timeout = None
def __init__(self, event_manager, *args, **kw): def __init__(self, event_manager, *args, **kw):
BaseConnection.__init__(self, event_manager, *args, **kw) BaseConnection.__init__(self, event_manager, *args, **kw)
...@@ -376,10 +345,11 @@ class Connection(BaseConnection): ...@@ -376,10 +345,11 @@ class Connection(BaseConnection):
def asClient(self): def asClient(self):
# TODO adjust .cur_id % 2 to be as client # TODO adjust .cur_id % 2 to be as client
try: try:
del self.idle del self._timeout
assert self.client
except AttributeError: except AttributeError:
self.client = True self.client = True
else:
assert self.client
def asServer(self): def asServer(self):
# TODO adjust .cur_id % 2 to be as server # TODO adjust .cur_id % 2 to be as server
...@@ -387,15 +357,21 @@ class Connection(BaseConnection): ...@@ -387,15 +357,21 @@ class Connection(BaseConnection):
def _closeClient(self): def _closeClient(self):
if self.server: if self.server:
del self.idle del self._timeout
self.client = False self.client = False
self.send(Packets.CloseClient()) self.send(Packets.CloseClient())
else: else:
self.close() self.close()
def closeClient(self): def closeClient(self):
# Currently, the only usage that is really useful is between a backup
# storage node and an upstream one, to avoid:
# - maintaining many connections for nothing when there's no write
# activity for a long time (and waste resources with keepalives)
# - reconnecting too often (i.e. be reactive) when there's moderate
# activity (think of a timer with a period of 1 minute)
if self.connector is not None and self.client: if self.connector is not None and self.client:
self.idle = self._closeClient self._timeout = time() + 100
def isAborted(self): def isAborted(self):
return self.aborted return self.aborted
...@@ -418,34 +394,13 @@ class Connection(BaseConnection): ...@@ -418,34 +394,13 @@ class Connection(BaseConnection):
self.cur_id = (next_id + 2) & 0xffffffff self.cur_id = (next_id + 2) & 0xffffffff
return next_id return next_id
def updateTimeout(self, t=None):
if not self._queue:
if not t:
t = self._next_timeout - self._timeout
self._timeout = self._handlers.getNextTimeout() or self.KEEP_ALIVE
self._next_timeout = t + self._timeout
def getTimeout(self): def getTimeout(self):
if not self._queue: if not self._queue:
return self._next_timeout return self._timeout
def onTimeout(self): def onTimeout(self):
handlers = self._handlers assert self._timeout
if handlers.isPending(): self._closeClient()
# It is possible that another thread used ask() while getting a
# timeout from epoll, so we must check again the value of
# _next_timeout (we know that _queue is still empty).
# Although this test is only useful for MTClientConnection,
# it's not worth complicating the code more.
if self._next_timeout <= time():
msg_id = handlers.timeout(self)
if msg_id is None:
self._next_timeout = time() + self._timeout
else:
logging.info('timeout for #0x%08x with %r', msg_id, self)
self.close()
else:
self.idle()
def abort(self): def abort(self):
"""Abort dealing with this connection.""" """Abort dealing with this connection."""
...@@ -514,7 +469,6 @@ class Connection(BaseConnection): ...@@ -514,7 +469,6 @@ class Connection(BaseConnection):
def readable(self): def readable(self):
"""Called when self is readable.""" """Called when self is readable."""
# last known remote activity # last known remote activity
self._next_timeout = time() + self._timeout
try: try:
try: try:
if self.connector.receive(self.read_buf): if self.connector.receive(self.read_buf):
...@@ -545,10 +499,7 @@ class Connection(BaseConnection): ...@@ -545,10 +499,7 @@ class Connection(BaseConnection):
Process a pending packet. Process a pending packet.
""" """
# check out packet and process it with current handler # check out packet and process it with current handler
try: self._handlers.handle(self, self._queue.pop(0))
self._handlers.handle(self, self._queue.pop(0))
finally:
self.updateTimeout()
def pending(self): def pending(self):
connector = self.connector connector = self.connector
...@@ -605,7 +556,7 @@ class Connection(BaseConnection): ...@@ -605,7 +556,7 @@ class Connection(BaseConnection):
packet.setId(self._getNextId() if msg_id is None else msg_id) packet.setId(self._getNextId() if msg_id is None else msg_id)
self._addPacket(packet) self._addPacket(packet)
def ask(self, packet, timeout=CRITICAL_TIMEOUT, on_timeout=None, **kw): def ask(self, packet, **kw):
""" """
Send a packet with a new ID and register the expectation of an answer Send a packet with a new ID and register the expectation of an answer
""" """
...@@ -614,14 +565,7 @@ class Connection(BaseConnection): ...@@ -614,14 +565,7 @@ class Connection(BaseConnection):
msg_id = self._getNextId() msg_id = self._getNextId()
packet.setId(msg_id) packet.setId(msg_id)
self._addPacket(packet) self._addPacket(packet)
handlers = self._handlers self._handlers.emit(packet, kw)
t = None if handlers.isPending() else time()
handlers.emit(packet, timeout, on_timeout, kw)
if not self._queue:
next_timeout = self._next_timeout
self.updateTimeout(t)
if self._next_timeout < next_timeout:
self.em.wakeup()
return msg_id return msg_id
def answer(self, packet): def answer(self, packet):
...@@ -634,9 +578,6 @@ class Connection(BaseConnection): ...@@ -634,9 +578,6 @@ class Connection(BaseConnection):
packet.setId(self.peer_id) packet.setId(self.peer_id)
self._addPacket(packet) self._addPacket(packet)
def idle(self):
self.ask(Packets.Ping())
def _connected(self): def _connected(self):
self.connecting = False self.connecting = False
self.getHandler().connectionCompleted(self) self.getHandler().connectionCompleted(self)
...@@ -688,7 +629,6 @@ class ClientConnection(Connection): ...@@ -688,7 +629,6 @@ class ClientConnection(Connection):
def _maybeConnected(self): def _maybeConnected(self):
self.writable = self.lockWrapper(super(ClientConnection, self).writable) self.writable = self.lockWrapper(super(ClientConnection, self).writable)
self.updateTimeout(time())
if self._ssl: if self._ssl:
self.connector.ssl(self._ssl, self._connected) self.connector.ssl(self._ssl, self._connected)
else: else:
...@@ -698,20 +638,12 @@ class ClientConnection(Connection): ...@@ -698,20 +638,12 @@ class ClientConnection(Connection):
class ServerConnection(Connection): class ServerConnection(Connection):
"""A connection from a remote node to this node.""" """A connection from a remote node to this node."""
# Both server and client must check the connection, in case:
# - the remote crashed brutally (i.e. without closing TCP connections)
# - or packets sent by the remote are dropped (network failure)
# Use different timeout so that in normal condition, server never has to
# ping the client. Otherwise, it would do it about half of the time.
KEEP_ALIVE = Connection.KEEP_ALIVE + 5
server = True server = True
cur_id = 0 # cur_id % 2 is 0 for server initated "streams" cur_id = 0 # cur_id % 2 is 0 for server initated "streams"
def __init__(self, *args, **kw): def __init__(self, *args, **kw):
Connection.__init__(self, *args, **kw) Connection.__init__(self, *args, **kw)
self.em.register(self) self.em.register(self)
self.updateTimeout(time())
class MTConnectionType(type): class MTConnectionType(type):
...@@ -770,14 +702,36 @@ class MTClientConnection(ClientConnection): ...@@ -770,14 +702,36 @@ class MTClientConnection(ClientConnection):
# Alias without lock (cheaper than super()) # Alias without lock (cheaper than super())
_ask = ClientConnection.ask.__func__ _ask = ClientConnection.ask.__func__
def ask(self, packet, timeout=CRITICAL_TIMEOUT, on_timeout=None, def ask(self, packet, queue=None, **kw):
queue=None, **kw):
with self.lock: with self.lock:
if queue is None: if queue is None:
if type(packet) is Packets.Ping: if type(packet) is Packets.Ping:
return self._ask(packet, timeout, on_timeout, **kw) return self._ask(packet, **kw)
raise TypeError('Only Ping packet can be asked' raise TypeError('Only Ping packet can be asked'
' without a queue, got a %r.' % packet) ' without a queue, got a %r.' % packet)
msg_id = self._ask(packet, timeout, on_timeout, **kw) msg_id = self._ask(packet, **kw)
self.dispatcher.register(self, msg_id, queue) self.dispatcher.register(self, msg_id, queue)
return msg_id return msg_id
# Currently, on connected connections, we only use timeouts for
# closeClient, which is never used for MTClientConnection.
# So we disable the logic completely as a precaution, and for performance.
# What is specific to MTClientConnection is that the poll thread must be
# woken up whenever the timeout is changed to a smaller value.
def closeClient(self):
# For example here, in addition to what the super method does,
# we may have to call `self.em.wakeup()`
raise NotImplementedError
def getTimeout(self):
pass
def onTimeout(self):
# It is possible that another thread manipulated the connection while
# getting a timeout from epoll. Only the poll thread fills _queue
# so we know that it is empty, but we may have to check timeout values
# again (i.e. compare time() with the result of getTimeout()).
raise NotImplementedError
###
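
As a rough sketch of what caller code looks like after these signature changes (the conn, mt_conn, packet and queue objects here are placeholders, not taken from this commit):

    # Requests no longer take timeout/on_timeout arguments.
    msg_id = conn.ask(Packets.AskLockInformation(ttid, tid))
    # Threaded client connections still require a queue...
    msg_id = mt_conn.ask(packet, queue=queue)
    # ...except for Ping, which may be asked without one.
    mt_conn.ask(Packets.Ping())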
...@@ -57,6 +57,18 @@ class SocketConnector(object): ...@@ -57,6 +57,18 @@ class SocketConnector(object):
self.socket_fd = s.fileno() self.socket_fd = s.fileno()
# always use non-blocking sockets # always use non-blocking sockets
s.setblocking(0) s.setblocking(0)
# TCP keepalive, enabled on both sides to detect:
# - remote host crash
# - network failure
# They're more efficient than applicative pings and we don't want
# to consider the connection dead if the remote node is busy.
# The following 3 lines are specific to Linux. It seems that OSX
# has similar options (TCP_KEEPALIVE/TCP_KEEPINTVL/TCP_KEEPCNT),
# and Windows has SIO_KEEPALIVE_VALS (fixed count of 10).
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 60)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 10)
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
# disable Nagle algorithm to reduce latency # disable Nagle algorithm to reduce latency
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.queued = [ENCODED_VERSION] self.queued = [ENCODED_VERSION]
......
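
The comment above notes that OSX and Windows expose similar knobs. As a minimal sketch of how the same tuning could be made portable (the Darwin-only TCP_KEEPALIVE constant and its 0x10 fallback value are assumptions, not part of this commit):

    import socket, sys

    def enable_keepalive(s, idle=60, interval=10, count=3):
        # Let the kernel probe the connection in the background.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
        if sys.platform.startswith('linux'):
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, idle)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval)
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, count)
        elif sys.platform == 'darwin':
            # Only the idle time is commonly tunable on macOS.
            s.setsockopt(socket.IPPROTO_TCP,
                         getattr(socket, 'TCP_KEEPALIVE', 0x10), idle)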
...@@ -194,8 +194,6 @@ class EventHandler(object): ...@@ -194,8 +194,6 @@ class EventHandler(object):
conn.answer(Packets.Pong()) conn.answer(Packets.Pong())
def pong(self, conn): def pong(self, conn):
# Ignore PONG packets. The only purpose of ping/pong packets is
# to test/maintain underlying connection.
pass pass
def closeClient(self, conn): def closeClient(self, conn):
......
...@@ -174,8 +174,9 @@ class AdministrationHandler(MasterHandler): ...@@ -174,8 +174,9 @@ class AdministrationHandler(MasterHandler):
ClusterStates.BACKINGUP): ClusterStates.BACKINGUP):
raise ProtocolError('Can not tweak partition table in %s state' raise ProtocolError('Can not tweak partition table in %s state'
% state) % state)
app.broadcastPartitionChanges(app.pt.tweak( app.broadcastPartitionChanges(app.pt.tweak([node
map(app.nm.getByUUID, uuid_list))) for node in app.nm.getStorageList()
if node.getUUID() in uuid_list or not node.isRunning()]))
conn.answer(Errors.Ack('')) conn.answer(Errors.Ack(''))
def truncate(self, conn, tid): def truncate(self, conn, tid):
......
...@@ -69,7 +69,7 @@ class ClientServiceHandler(MasterHandler): ...@@ -69,7 +69,7 @@ class ClientServiceHandler(MasterHandler):
if tid: if tid:
p = Packets.AskLockInformation(ttid, tid) p = Packets.AskLockInformation(ttid, tid)
for node in node_list: for node in node_list:
node.ask(p, timeout=60) # NOTE node.ask(p)
# NOTE continues in onTransactionCommitted # NOTE continues in onTransactionCommitted
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from collections import defaultdict from collections import Counter, defaultdict
import neo.lib.pt import neo.lib.pt
from neo.lib import logging from neo.lib import logging
from neo.lib.protocol import CellStates, ZERO_TID from neo.lib.protocol import CellStates, ZERO_TID
...@@ -43,16 +43,6 @@ class Cell(neo.lib.pt.Cell): ...@@ -43,16 +43,6 @@ class Cell(neo.lib.pt.Cell):
neo.lib.pt.Cell = Cell neo.lib.pt.Cell = Cell
class MappedNode(object):
def __init__(self, node):
self.node = node
self.assigned = set()
def __getattr__(self, attr):
return getattr(self.node, attr)
class PartitionTable(neo.lib.pt.PartitionTable): class PartitionTable(neo.lib.pt.PartitionTable):
"""This class manages a partition table for the primary master node""" """This class manages a partition table for the primary master node"""
...@@ -68,32 +58,14 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -68,32 +58,14 @@ class PartitionTable(neo.lib.pt.PartitionTable):
def make(self, node_list): def make(self, node_list):
"""Make a new partition table from scratch.""" """Make a new partition table from scratch."""
# start with the first PTID assert self._id is None and node_list, (self._id, node_list)
self._id = 1 for node in node_list:
# First, filter the list of nodes. assert node.isRunning() and node.getUUID() is not None, node
node_list = [n for n in node_list if n.isRunning() \ self.addNodeList(node_list)
and n.getUUID() is not None] self.tweak()
if len(node_list) == 0: for node, count in self.count_dict.items():
# Impossible. if not count:
raise RuntimeError, 'cannot make a partition table with an ' \ del self.count_dict[node]
'empty storage node list'
# Take it into account that the number of storage nodes may be less
# than the number of replicas.
repeats = min(self.nr + 1, len(node_list))
index = 0
for offset in xrange(self.np):
row = []
for _ in xrange(repeats):
node = node_list[index]
row.append(Cell(node))
self.count_dict[node] = self.count_dict.get(node, 0) + 1
index += 1
if index == len(node_list):
index = 0
self.partition_list[offset] = row
self.num_filled_rows = self.np
def dropNodeList(self, node_list, simulate=False): def dropNodeList(self, node_list, simulate=False):
partition_list = [] partition_list = []
...@@ -161,8 +133,9 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -161,8 +133,9 @@ class PartitionTable(neo.lib.pt.PartitionTable):
def setUpToDate(self, node, offset): def setUpToDate(self, node, offset):
"""Set a cell as up-to-date""" """Set a cell as up-to-date"""
uuid = node.getUUID() uuid = node.getUUID()
# check the partition is assigned and known as outdated # Check the partition is assigned and known as outdated.
for cell in self.getCellList(offset): row = self.partition_list[offset]
for cell in row:
if cell.getUUID() == uuid: if cell.getUUID() == uuid:
if cell.isOutOfDate() and cell.updatable: if cell.isOutOfDate() and cell.updatable:
break break
...@@ -170,17 +143,26 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -170,17 +143,26 @@ class PartitionTable(neo.lib.pt.PartitionTable):
else: else:
raise neo.lib.pt.PartitionTableException('Non-assigned partition') raise neo.lib.pt.PartitionTableException('Non-assigned partition')
# update the partition table # Update the partition table.
self._setCell(offset, node, CellStates.UP_TO_DATE) self._setCell(offset, node, CellStates.UP_TO_DATE)
cell_list = [(offset, uuid, CellStates.UP_TO_DATE)] cell_list = [(offset, uuid, CellStates.UP_TO_DATE)]
# If the partition contains a feeding cell, drop it now. # Do not keep too many feeding cells.
for feeding_cell in self.getCellList(offset): readable_list = filter(Cell.isReadable, row)
if feeding_cell.isFeeding(): iter_feeding = (cell.getNode() for cell in readable_list
node = feeding_cell.getNode() if cell.isFeeding())
self.removeCell(offset, node) # If all cells are readable, we can now drop all feeding cells.
cell_list.append((offset, node.getUUID(), CellStates.DISCARDED)) if len(readable_list) != len(row):
break # Else we normally discard at most 1 cell. In the case that cells
# became non-readable since the last tweak, we want to avoid going
# below the wanted number of replicas. Also first try to discard
# feeding cells from nodes that it was decided to drop.
iter_feeding = sorted(iter_feeding, key=lambda node: not all(
cell.isFeeding() for _, cell in self.iterNodeCell(node)
))[:max(0, len(readable_list) - self.nr)]
for node in iter_feeding:
self.removeCell(offset, node)
cell_list.append((offset, node.getUUID(), CellStates.DISCARDED))
return cell_list return cell_list
...@@ -196,87 +178,193 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -196,87 +178,193 @@ class PartitionTable(neo.lib.pt.PartitionTable):
def tweak(self, drop_list=()): def tweak(self, drop_list=()):
"""Optimize partition table """Optimize partition table
This is done by computing a minimal diff between current partition table
and what make() would do.
This reassigns cells in 3 ways:
- Discard cells of nodes listed in 'drop_list'. For partitions with too
few readable cells, some cells are instead marked as FEEDING. This is
a preliminary step to drop these nodes, otherwise the partition table
could become non-operational.
- Other nodes must have the same number of cells, off by 1.
- When a transaction creates new objects (oids are roughly allocated
sequentially), we expect better performance by maximizing the number
of involved nodes (i.e. parallelizing writes).
Examples of optimal partition tables with np=10, nr=1 and 5 nodes:
UU... ..UU.
..UU. U...U
U...U .UU..
.UU.. ...UU
...UU UU...
UU... ..UU.
..UU. U...U
U...U .UU..
.UU.. ...UU
...UU UU...
The above 2 PT only differ by permutation of nodes, and this method
plays on it to minimize the resulting amount of replication.
For performance reasons, this algorithm uses a heuristic.
When (np * nr) is not a multiple of the number of nodes, some nodes
have 1 extra cell compared to other. In such case, other optimal PT
could be considered by rotation of the partitions. Actually np times
more, but it's not worth it since they don't differ enough (if np is
big enough) and we don't already do an exhaustive search.
Example with np=3, nr=1 and 2 nodes:
U. .U U.
.U U. U.
U. U. .U
""" """
assigned_dict = {x: {} for x in self.count_dict} # Collect some data in a usable form for the rest of the method.
readable_list = [set() for x in xrange(self.np)] node_list = {node: {} for node in self.count_dict
if node not in drop_list}
drop_list = defaultdict(list)
for offset, row in enumerate(self.partition_list): for offset, row in enumerate(self.partition_list):
for cell in row: for cell in row:
if cell.isReadable(): cell_dict = node_list.get(cell.getNode())
readable_list[offset].add(cell) if cell_dict is None:
assigned_dict[cell.getNode()][offset] = cell drop_list[offset].append(cell)
pt = PartitionTable(self.np, self.nr)
drop_list = set(drop_list).intersection(assigned_dict)
node_set = {MappedNode(x) for x in assigned_dict
if x not in drop_list}
pt.make(node_set)
for offset, row in enumerate(pt.partition_list):
for cell in row:
if cell.isReadable():
cell.getNode().assigned.add(offset)
def map_nodes():
node_list = []
for node, assigned in assigned_dict.iteritems():
if node in drop_list:
yield node, frozenset()
continue
readable = {offset for offset, cell in assigned.iteritems()
if cell.isReadable()}
# the criterion on UUID is purely cosmetic
node_list.append((len(readable), len(assigned),
-node.getUUID(), readable, node))
node_list.sort(reverse=1)
for _, _, _, readable, node in node_list:
assigned = assigned_dict[node]
mapped = min(node_set, key=lambda m: (
len(m.assigned.symmetric_difference(assigned)),
len(m.assigned ^ readable)))
node_set.remove(mapped)
yield node, mapped.assigned
assert not node_set
changed_list = []
uptodate_set = set()
remove_dict = defaultdict(list)
for node, mapped in map_nodes():
uuid = node.getUUID()
assigned = assigned_dict[node]
for offset, cell in assigned.iteritems():
if offset in mapped:
if cell.isReadable():
uptodate_set.add(offset)
readable_list[offset].remove(cell)
if cell.isFeeding():
self.count_dict[node] += 1
state = CellStates.UP_TO_DATE
cell.setState(state)
changed_list.append((offset, uuid, state))
else: else:
if not cell.isFeeding(): cell_dict[offset] = cell
self.count_dict[node] -= 1 # The sort by node id is cosmetic, to prefer result like the first one
remove_dict[offset].append(cell) # in __doc__.
for offset in mapped.difference(assigned): node_list = sorted(node_list.iteritems(), key=lambda x: x[0].getUUID())
self.count_dict[node] += 1
state = CellStates.OUT_OF_DATE # Generate an optimal PT.
self.partition_list[offset].append(Cell(node, state)) node_count = len(node_list)
changed_list.append((offset, uuid, state)) repeats = min(self.nr + 1, node_count)
count_dict = self.count_dict.copy() x = [[] for _ in xrange(node_count)]
for offset, cell_list in remove_dict.iteritems(): i = 0
for offset in xrange(self.np):
for _ in xrange(repeats):
x[i % node_count].append(offset)
i += 1
option_dict = Counter(map(tuple, x))
# Strategies to find the "best" permutation of nodes.
def node_options():
# The second part of the key goes with the above cosmetic sort.
option_list = sorted(option_dict, key=lambda x: (-len(x), x))
# 1. Search for solution that does not cause extra replication.
# This is important because tweak() must do nothing if it's
# called a second time whereas the list of nodes hasn't changed.
result = []
for i, (_, cell_dict) in enumerate(node_list):
option = {offset for offset, cell in cell_dict.iteritems()
if not cell.isFeeding()}
x = filter(option.issubset, option_list)
if not x:
break
result.append((i, x))
else:
yield result
# 2. We have to move cells. Evaluating all options would have
# a complexity of O(node_count!), which is clearly too slow,
# so we use a heuristic.
# For each node, we compare the resulting amount of replication
# in the best (min_cost) and worst (max_cost) case, and we first
# iterate over nodes with the biggest difference. This minimizes
# the impact of bad allocation patterns for the last nodes.
result = []
np_complement = frozenset(xrange(self.np)).difference
for i, (_, cell_dict) in enumerate(node_list):
cost_list = []
for x, option in enumerate(option_list):
discard = [0, 0]
for offset in np_complement(option):
cell = cell_dict.get(offset)
if cell:
discard[cell.isReadable()] += 1
cost_list.append(((discard[1], discard[0]), x))
cost_list.sort()
min_cost = cost_list[0][0]
max_cost = cost_list[-1][0]
result.append((
min_cost[0] - max_cost[0],
min_cost[1] - max_cost[1],
i, [option_list[x[1]] for x in cost_list]))
result.sort()
yield result
# The main loop, which is where we evaluate options.
new = [] # the solution
stack = [] # data recursion
def options():
return iter(node_options[len(new)][-1])
for node_options in node_options(): # for each strategy
iter_option = options()
while 1:
try:
option = next(iter_option)
except StopIteration: # 1st strategy only
if new:
iter_option = stack.pop()
option_dict[new.pop()] += 1
continue
break
if option_dict[option]:
new.append(option)
if len(new) == len(node_list):
break
stack.append(iter_option)
iter_option = options()
option_dict[option] -= 1
if new:
break
else:
raise AssertionError
# Apply the solution.
if self._id is None:
self._id = 1
self.num_filled_rows = self.np
new_state = CellStates.UP_TO_DATE
else:
new_state = CellStates.OUT_OF_DATE
changed_list = []
outdated_list = [repeats] * self.np
discard_list = defaultdict(list)
for i, offset_list in enumerate(new):
node, cell_dict = node_list[node_options[i][-2]]
for offset in offset_list:
cell = cell_dict.pop(offset, None)
if cell is None:
self.count_dict[node] += 1
self.partition_list[offset].append(Cell(node, new_state))
changed_list.append((offset, node.getUUID(), new_state))
elif cell.isReadable():
if cell.isFeeding():
cell.setState(CellStates.UP_TO_DATE)
changed_list.append((offset, node.getUUID(),
CellStates.UP_TO_DATE))
outdated_list[offset] -= 1
for offset, cell in cell_dict.iteritems():
discard_list[offset].append(cell)
for offset, drop_list in drop_list.iteritems():
discard_list[offset] += drop_list
# We have sorted cells to discard in order to first deallocate nodes
# in drop_list, and have feeding cells in other nodes.
# The following loop also makes sure not to discard cells too quickly,
# by keeping a minimum of 'repeats' readable cells.
for offset, outdated in enumerate(outdated_list):
row = self.partition_list[offset] row = self.partition_list[offset]
feeding = None if offset in uptodate_set else min( for cell in discard_list[offset]:
readable_list[offset], key=lambda x: count_dict[x.getNode()]) if outdated and cell.isReadable():
for cell in cell_list: outdated -= 1
if cell is feeding:
count_dict[cell.getNode()] += 1
if cell.isFeeding(): if cell.isFeeding():
continue continue
state = CellStates.FEEDING state = CellStates.FEEDING
cell.setState(state) cell.setState(state)
else: else:
self.count_dict[cell.getNode()] -= 1
state = CellStates.DISCARDED state = CellStates.DISCARDED
row.remove(cell) row.remove(cell)
changed_list.append((offset, cell.getUUID(), state)) changed_list.append((offset, cell.getUUID(), state))
assert self.num_filled_rows == len(filter(None, self.partition_list))
assert self.operational(), changed_list
return changed_list return changed_list
def outdate(self, lost_node=None): def outdate(self, lost_node=None):
......
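
As an aside, the optimal layouts shown in the tweak() docstring above follow from a plain round-robin assignment of nodes to partitions. A minimal standalone sketch for np=10, nr=1 and 5 nodes (not part of the commit):

    np, nr, node_count = 10, 1, 5
    repeats = min(nr + 1, node_count)            # cells per partition
    rows = [['.'] * node_count for _ in range(np)]
    i = 0
    for offset in range(np):
        for _ in range(repeats):
            rows[offset][i % node_count] = 'U'   # assign the next node
            i += 1
    print('\n'.join(''.join(r) for r in rows))   # UU... / ..UU. / U...U / .UU.. / ...UU / ...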
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
import bz2, gzip, errno, optparse, os, signal, sqlite3, sys, time import bz2, gzip, errno, optparse, os, signal, sqlite3, sys, time
from bisect import insort from bisect import insort
from logging import getLevelName from logging import getLevelName
from zlib import decompress
comp_dict = dict(bz2=bz2.BZ2File, gz=gzip.GzipFile) comp_dict = dict(bz2=bz2.BZ2File, gz=gzip.GzipFile)
...@@ -28,11 +29,12 @@ class Log(object): ...@@ -28,11 +29,12 @@ class Log(object):
_log_id = _packet_id = -1 _log_id = _packet_id = -1
_protocol_date = None _protocol_date = None
def __init__(self, db_path, decode_all=False, date_format=None, def __init__(self, db_path, decode=0, date_format=None,
filter_from=None, node_list=None): filter_from=None, node_column=True, node_list=None):
self._date_format = '%F %T' if date_format is None else date_format self._date_format = '%F %T' if date_format is None else date_format
self._decode_all = decode_all self._decode = decode
self._filter_from = filter_from self._filter_from = filter_from
self._node_column = node_column
self._node_list = node_list self._node_list = node_list
name = os.path.basename(db_path) name = os.path.basename(db_path)
try: try:
...@@ -93,6 +95,30 @@ class Log(object): ...@@ -93,6 +95,30 @@ class Log(object):
exec bz2.decompress(text) in g exec bz2.decompress(text) in g
for x in 'uuid_str', 'Packets', 'PacketMalformedError': for x in 'uuid_str', 'Packets', 'PacketMalformedError':
setattr(self, x, g[x]) setattr(self, x, g[x])
x = {}
if self._decode > 1:
PStruct = g['PStruct']
PBoolean = g['PBoolean']
def hasData(item):
items = item._items
for i, item in enumerate(items):
if isinstance(item, PStruct):
j = hasData(item)
if j:
return (i,) + j
elif (isinstance(item, PBoolean)
and item._name == 'compression'
and i + 2 < len(items)
and items[i+2]._name == 'data'):
return i,
for p in self.Packets.itervalues():
if p._fmt is not None:
path = hasData(p._fmt)
if path:
assert not hasattr(p, '_neolog'), p
x[p._code] = path
self._getDataPath = x.get
try: try:
self._next_protocol, = q("SELECT date FROM protocol WHERE date>?", self._next_protocol, = q("SELECT date FROM protocol WHERE date>?",
(date,)).next() (date,)).next()
...@@ -109,7 +135,8 @@ class Log(object): ...@@ -109,7 +135,8 @@ class Log(object):
d = int(date) d = int(date)
prefix = '%s.%04u ' % (time.strftime(prefix, time.localtime(d)), prefix = '%s.%04u ' % (time.strftime(prefix, time.localtime(d)),
int((date - d) * 10000)) int((date - d) * 10000))
prefix += '%-9s %-10s ' % (levelname, name) prefix += ('%-9s %-10s ' % (levelname, name) if self._node_column else
'%-9s ' % levelname)
for msg in msg_list: for msg in msg_list:
print prefix + msg print prefix + msg
...@@ -126,7 +153,7 @@ class Log(object): ...@@ -126,7 +153,7 @@ class Log(object):
msg = ['#0x%04x %-30s %s' % (msg_id, msg, peer)] msg = ['#0x%04x %-30s %s' % (msg_id, msg, peer)]
if body is not None: if body is not None:
log = getattr(p, '_neolog', None) log = getattr(p, '_neolog', None)
if log or self._decode_all: if log or self._decode:
p = p() p = p()
p._id = msg_id p._id = msg_id
p._body = body p._body = body
...@@ -138,10 +165,28 @@ class Log(object): ...@@ -138,10 +165,28 @@ class Log(object):
if log: if log:
args, extra = log(*args) args, extra = log(*args)
msg += extra msg += extra
if args and self._decode_all: else:
path = self._getDataPath(code)
if path:
args = self._decompress(args, path)
if args and self._decode:
msg[0] += ' \t| ' + repr(args) msg[0] += ' \t| ' + repr(args)
return date, name, 'PACKET', msg return date, name, 'PACKET', msg
def _decompress(self, args, path):
if args:
args = list(args)
i = path[0]
path = path[1:]
if path:
args[i] = self._decompress(args[i], path)
else:
data = args[i+2]
if args[i]:
data = decompress(data)
args[i:i+3] = (len(data), data),
return tuple(args)
def emit_many(log_list): def emit_many(log_list):
log_list = [(log, iter(log).next) for log in log_list] log_list = [(log, iter(log).next) for log in log_list]
...@@ -179,7 +224,9 @@ def emit_many(log_list): ...@@ -179,7 +224,9 @@ def emit_many(log_list):
def main(): def main():
parser = optparse.OptionParser() parser = optparse.OptionParser()
parser.add_option('-a', '--all', action="store_true", parser.add_option('-a', '--all', action="store_true",
help='decode all packets') help='decode body of packets')
parser.add_option('-A', '--decompress', action="store_true",
help='decompress data when decode body of packets (implies --all)')
parser.add_option('-d', '--date', metavar='FORMAT', parser.add_option('-d', '--date', metavar='FORMAT',
help='custom date format, according to strftime(3)') help='custom date format, according to strftime(3)')
parser.add_option('-f', '--follow', action="store_true", parser.add_option('-f', '--follow', action="store_true",
...@@ -189,7 +236,8 @@ def main(): ...@@ -189,7 +236,8 @@ def main():
' seconds (see -s)', metavar='PID') ' seconds (see -s)', metavar='PID')
parser.add_option('-n', '--node', action="append", parser.add_option('-n', '--node', action="append",
help='only show log entries from the given node' help='only show log entries from the given node'
' (only useful for logs produced by threaded tests)') ' (only useful for logs produced by threaded tests),'
" special value '-' hides the column")
parser.add_option('-s', '--sleep-interval', type="float", default=1, parser.add_option('-s', '--sleep-interval', type="float", default=1,
help='with -f, sleep for approximately N seconds (default 1.0)' help='with -f, sleep for approximately N seconds (default 1.0)'
' between iterations', metavar='N') ' between iterations', metavar='N')
...@@ -204,8 +252,15 @@ def main(): ...@@ -204,8 +252,15 @@ def main():
filter_from = options.filter_from filter_from = options.filter_from
if filter_from and filter_from < 0: if filter_from and filter_from < 0:
filter_from += time.time() filter_from += time.time()
log_list = [Log(db_path, options.all, options.date, filter_from, node_list = options.node or []
options.node) try:
node_list.remove('-')
node_column = False
except ValueError:
node_column = True
log_list = [Log(db_path,
2 if options.decompress else 1 if options.all else 0,
options.date, filter_from, node_column, node_list)
for db_path in args] for db_path in args]
if options.follow: if options.follow:
try: try:
......
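
Putting the new neolog options together, hypothetical invocations could look like this (the log file and node names are made up):

    neolog -a neo.log                          # decode packet bodies
    neolog -A neo.log                          # also decompress embedded data (implies -a)
    neolog -n - neo.log                        # hide the node column
    neolog -n storage0 -n storage1 neo.log     # only show entries from these nodes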
...@@ -30,6 +30,11 @@ parser.add_option('-d', '--database', help = 'database connections string') ...@@ -30,6 +30,11 @@ parser.add_option('-d', '--database', help = 'database connections string')
parser.add_option('-e', '--engine', help = 'database engine') parser.add_option('-e', '--engine', help = 'database engine')
parser.add_option('-w', '--wait', help='seconds to wait for backend to be ' parser.add_option('-w', '--wait', help='seconds to wait for backend to be '
'available, before erroring-out (-1 = infinite)', type='float', default=0) 'available, before erroring-out (-1 = infinite)', type='float', default=0)
parser.add_option('--disable-drop-partitions', action='store_true',
help = 'do not delete data of discarded cells, which is'
' useful for big databases because the current'
' implementation is inefficient (this option should'
' disappear in the future)')
parser.add_option('--reset', action='store_true', parser.add_option('--reset', action='store_true',
help='remove an existing database if any, and exit') help='remove an existing database if any, and exit')
......
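
For illustration, a hypothetical neostorage invocation using the new flag together with options visible in this parser (the database string is made up and other required options, such as master addresses, are omitted):

    neostorage --disable-drop-partitions -d user@neo1 -e InnoDB ...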
...@@ -42,7 +42,6 @@ from neo.tests.benchmark import BenchmarkRunner ...@@ -42,7 +42,6 @@ from neo.tests.benchmark import BenchmarkRunner
# each of them have to import its TestCase classes # each of them have to import its TestCase classes
UNIT_TEST_MODULES = [ UNIT_TEST_MODULES = [
# generic parts # generic parts
'neo.tests.testConnection',
'neo.tests.testHandler', 'neo.tests.testHandler',
'neo.tests.testNodes', 'neo.tests.testNodes',
'neo.tests.testUtil', 'neo.tests.testUtil',
...@@ -174,7 +173,7 @@ class NeoTestRunner(unittest.TextTestResult): ...@@ -174,7 +173,7 @@ class NeoTestRunner(unittest.TextTestResult):
exclude != fnmatchcase(test_module, only)): exclude != fnmatchcase(test_module, only)):
continue continue
try: try:
test_module = __import__(test_module, globals(), locals(), ['*']) test_module = __import__(test_module, fromlist=('*',), level=0)
except ImportError, err: except ImportError, err:
self.failedImports[test_module] = err self.failedImports[test_module] = err
print "Import of %s failed : %s" % (test_module, err) print "Import of %s failed : %s" % (test_module, err)
......
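
For reference, the difference the fromlist argument makes (standard Python behaviour, illustrated here with os.path rather than a NEO module):

    top = __import__('os.path')                             # returns the top package: the 'os' module
    sub = __import__('os.path', fromlist=('*',), level=0)   # returns the submodule itself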
...@@ -48,6 +48,7 @@ class Application(BaseApplication): ...@@ -48,6 +48,7 @@ class Application(BaseApplication):
self.dm = buildDatabaseManager(config.getAdapter(), self.dm = buildDatabaseManager(config.getAdapter(),
(config.getDatabase(), config.getEngine(), config.getWait()), (config.getDatabase(), config.getEngine(), config.getWait()),
) )
self.disable_drop_partitions = config.getDisableDropPartitions()
# load master nodes # load master nodes
for master_address in config.getMasters(): for master_address in config.getMasters():
......
...@@ -29,8 +29,7 @@ def getAdapterKlass(name): ...@@ -29,8 +29,7 @@ def getAdapterKlass(name):
module, name = DATABASE_MANAGER_DICT[name or 'MySQL'].split('.') module, name = DATABASE_MANAGER_DICT[name or 'MySQL'].split('.')
except KeyError: except KeyError:
raise DatabaseFailure('Cannot find a database adapter <%s>' % name) raise DatabaseFailure('Cannot find a database adapter <%s>' % name)
module = getattr(__import__(__name__, fromlist=[module], level=1), module) return getattr(__import__(module, globals(), level=1), name)
return getattr(module, name)
def buildDatabaseManager(name, args=(), kw={}): def buildDatabaseManager(name, args=(), kw={}):
return getAdapterKlass(name)(*args, **kw) return getAdapterKlass(name)(*args, **kw)
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import struct, threading import os, errno, socket, struct, sys, threading
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from functools import wraps from functools import wraps
...@@ -57,6 +57,10 @@ class DatabaseManager(object): ...@@ -57,6 +57,10 @@ class DatabaseManager(object):
ENGINES = () ENGINES = ()
UNSAFE = False UNSAFE = False
__lock = None
LOCK = "neostorage"
LOCKED = "error: database is locked"
_deferred = 0 _deferred = 0
_duplicating = _repairing = None _duplicating = _repairing = None
...@@ -86,6 +90,7 @@ class DatabaseManager(object): ...@@ -86,6 +90,7 @@ class DatabaseManager(object):
def _duplicate(self): def _duplicate(self):
cls = self.__class__ cls = self.__class__
db = cls.__new__(cls) db = cls.__new__(cls)
db.LOCK = None
db._duplicating = self db._duplicating = self
try: try:
db._connect() db._connect()
...@@ -104,6 +109,26 @@ class DatabaseManager(object): ...@@ -104,6 +109,26 @@ class DatabaseManager(object):
def _connect(self): def _connect(self):
"""Connect to the database""" """Connect to the database"""
def lock(self, db_path):
if self.LOCK:
assert self.__lock is None, self.__lock
# For platforms that don't support anonymous sockets,
# we can either use zc.lockfile or an empty SQLite db
# (with BEGIN EXCLUSIVE).
try:
stat = os.stat(db_path)
except OSError as e:
if e.errno != errno.ENOENT:
raise
return # in-memory or temporary database
s = self.__lock = socket.socket(socket.AF_UNIX)
try:
s.bind('\0%s:%s:%s' % (self.LOCK, stat.st_dev, stat.st_ino))
except socket.error as e:
if e.errno != errno.EADDRINUSE:
raise
sys.exit(self.LOCKED)
@abstract @abstract
def erase(self): def erase(self):
"""""" """"""
...@@ -154,6 +179,9 @@ class DatabaseManager(object): ...@@ -154,6 +179,9 @@ class DatabaseManager(object):
def close(self): def close(self):
self._deferredCommit() self._deferredCommit()
self._close() self._close()
if self.__lock:
self.__lock.close()
del self.__lock
def _commit(self): def _commit(self):
"""Backend-specific code to commit the pending changes""" """Backend-specific code to commit the pending changes"""
...@@ -301,10 +329,23 @@ class DatabaseManager(object): ...@@ -301,10 +329,23 @@ class DatabaseManager(object):
Required only to import a DB using Importer backend. Required only to import a DB using Importer backend.
max_tid must be in unpacked format. max_tid must be in unpacked format.
Data from unassigned partitions must be ignored.
This is important because there may remain data from cells that have
been discarded, either due to --disable-drop-partitions option,
or in the future when dropping partitions is done in background
(because this is an expensive operation).
XXX: Given the TODO comment in getLastIDs, getting ids
from readable partitions should be enough.
""" """
def _getLastIDs(self): def _getLastIDs(self):
"""""" """Return (trans, obj, max(oid)) where
both 'trans' and 'obj' are {partition: max(tid)}
Same as in getLastTID: data from unassigned partitions must be ignored.
"""
@requires(_getLastIDs) @requires(_getLastIDs)
def getLastIDs(self): def getLastIDs(self):
......
...@@ -29,6 +29,7 @@ import os ...@@ -29,6 +29,7 @@ import os
import re import re
import string import string
import struct import struct
import sys
import time import time
from . import LOG_QUERIES from . import LOG_QUERIES
...@@ -52,9 +53,6 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -52,9 +53,6 @@ class MySQLDatabaseManager(DatabaseManager):
ENGINES = "InnoDB", "RocksDB", "TokuDB" ENGINES = "InnoDB", "RocksDB", "TokuDB"
_engine = ENGINES[0] # default engine _engine = ENGINES[0] # default engine
# Disabled even on MySQL 5.1-5.5 and MariaDB 5.2-5.3 because
# 'select count(*) from obj' sometimes returns incorrect values
# (tested with testOudatedCellsOnDownStorage).
_use_partition = False _use_partition = False
_max_allowed_packet = 32769 * 1024 _max_allowed_packet = 32769 * 1024
...@@ -102,9 +100,17 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -102,9 +100,17 @@ class MySQLDatabaseManager(DatabaseManager):
conn.autocommit(False) conn.autocommit(False)
conn.query("SET SESSION group_concat_max_len = %u" % (2**32-1)) conn.query("SET SESSION group_concat_max_len = %u" % (2**32-1))
conn.set_sql_mode("TRADITIONAL,NO_ENGINE_SUBSTITUTION") conn.set_sql_mode("TRADITIONAL,NO_ENGINE_SUBSTITUTION")
conn.query("SHOW VARIABLES WHERE variable_name='max_allowed_packet'") def query(sql):
r = conn.store_result() conn.query(sql)
(name, value), = r.fetch_row(r.num_rows()) r = conn.store_result()
return r.fetch_row(r.num_rows())
if self.LOCK:
(locked,), = query("SELECT GET_LOCK('%s.%s', 0)"
% (self.db, self.LOCK))
if not locked:
sys.exit(self.LOCKED)
(name, value), = query(
"SHOW VARIABLES WHERE variable_name='max_allowed_packet'")
if int(value) < self._max_allowed_packet: if int(value) < self._max_allowed_packet:
raise DatabaseFailure("Global variable %r is too small." raise DatabaseFailure("Global variable %r is too small."
" Minimal value must be %uk." " Minimal value must be %uk."
...@@ -304,21 +310,37 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -304,21 +310,37 @@ class MySQLDatabaseManager(DatabaseManager):
return self.query("SELECT rid, state FROM pt WHERE nid=%u" % nid) return self.query("SELECT rid, state FROM pt WHERE nid=%u" % nid)
return self.query("SELECT * FROM pt") return self.query("SELECT * FROM pt")
def _getAssignedPartitionList(self):
nid = self.getUUID()
if nid is None:
return ()
return [p for p, in self.query("SELECT rid FROM pt WHERE nid=%s" % nid)]
def _sqlmax(self, sql, arg_list):
q = self.query
x = [x for x in arg_list for x, in q(sql % x) if x is not None]
if x: return max(x)
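For readability, the dense comprehension in _sqlmax is equivalent to the following more explicit form (illustration only, same behaviour): run the query once per assigned partition and return the greatest non-NULL value, or None when every partition is empty.

    def _sqlmax_explicit(self, sql, arg_list):
        # Same semantics as _sqlmax above, spelled out step by step.
        values = []
        for partition in arg_list:
            for value, in self.query(sql % partition):
                if value is not None:
                    values.append(value)
        return max(values) if values else None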
def getLastTID(self, max_tid): def getLastTID(self, max_tid):
return self.query("SELECT MAX(t) FROM (SELECT MAX(tid) as t FROM trans" return self._sqlmax(
" WHERE tid<=%s GROUP BY `partition`) as t" % max_tid)[0][0] "SELECT MAX(tid) as t FROM trans FORCE INDEX (PRIMARY)"
" WHERE tid<=%s and `partition`=%%s" % max_tid,
self._getAssignedPartitionList())
def _getLastIDs(self): def _getLastIDs(self):
offset_list = self._getAssignedPartitionList()
p64 = util.p64 p64 = util.p64
q = self.query q = self.query
trans = {partition: p64(tid) sql = ("SELECT MAX(tid) FROM %s FORCE INDEX (PRIMARY)"
for partition, tid in q("SELECT `partition`, MAX(tid)" " WHERE `partition`=%s")
" FROM trans GROUP BY `partition`")} trans, obj = ({partition: p64(tid)
obj = {partition: p64(tid) for partition in offset_list
for partition, tid in q("SELECT `partition`, MAX(tid)" for tid, in q(sql % (t, partition))
" FROM obj GROUP BY `partition`")} if tid is not None}
oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj" for t in ('trans', 'obj'))
" GROUP BY `partition`) as t")[0][0] oid = self._sqlmax(
"SELECT MAX(oid) FROM obj FORCE INDEX (`partition`)"
" WHERE `partition`=%s", offset_list)
return trans, obj, None if oid is None else p64(oid) return trans, obj, None if oid is None else p64(oid)
def _getUnfinishedTIDDict(self): def _getUnfinishedTIDDict(self):
...@@ -337,7 +359,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -337,7 +359,7 @@ class MySQLDatabaseManager(DatabaseManager):
def getLastObjectTID(self, oid): def getLastObjectTID(self, oid):
oid = util.u64(oid) oid = util.u64(oid)
r = self.query("SELECT tid FROM obj" r = self.query("SELECT tid FROM obj FORCE INDEX(`partition`)"
" WHERE `partition`=%d AND oid=%d" " WHERE `partition`=%d AND oid=%d"
" ORDER BY tid DESC LIMIT 1" " ORDER BY tid DESC LIMIT 1"
% (self._getReadablePartition(oid), oid)) % (self._getReadablePartition(oid), oid))
...@@ -358,7 +380,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -358,7 +380,8 @@ class MySQLDatabaseManager(DatabaseManager):
q = self.query q = self.query
partition = self._getReadablePartition(oid) partition = self._getReadablePartition(oid)
sql = ('SELECT tid, compression, data.hash, value, value_tid' sql = ('SELECT tid, compression, data.hash, value, value_tid'
' FROM obj LEFT JOIN data ON (obj.data_id = data.id)' ' FROM obj FORCE INDEX(`partition`)'
' LEFT JOIN data ON (obj.data_id = data.id)'
' WHERE `partition` = %d AND oid = %d') % (partition, oid) ' WHERE `partition` = %d AND oid = %d') % (partition, oid)
if before_tid is not None: if before_tid is not None:
sql += ' AND tid < %d ORDER BY tid DESC LIMIT 1' % before_tid sql += ' AND tid < %d ORDER BY tid DESC LIMIT 1' % before_tid
...@@ -414,7 +437,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -414,7 +437,8 @@ class MySQLDatabaseManager(DatabaseManager):
for partition in offset_list: for partition in offset_list:
where = " WHERE `partition`=%d" % partition where = " WHERE `partition`=%d" % partition
data_id_list = [x for x, in data_id_list = [x for x, in
q("SELECT DISTINCT data_id FROM obj USE INDEX(PRIMARY)" + where) q("SELECT DISTINCT data_id FROM obj FORCE INDEX(PRIMARY)"
+ where)
if x] if x]
if not self._use_partition: if not self._use_partition:
q("DELETE FROM obj" + where) q("DELETE FROM obj" + where)
...@@ -578,7 +602,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -578,7 +602,7 @@ class MySQLDatabaseManager(DatabaseManager):
del _structLL del _structLL
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
sql = ('SELECT tid, value_tid FROM obj' sql = ('SELECT tid, value_tid FROM obj FORCE INDEX(`partition`)'
' WHERE `partition` = %d AND oid = %d' ' WHERE `partition` = %d AND oid = %d'
) % (self._getReadablePartition(oid), oid) ) % (self._getReadablePartition(oid), oid)
if tid is not None: if tid is not None:
...@@ -669,7 +693,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -669,7 +693,8 @@ class MySQLDatabaseManager(DatabaseManager):
p64 = util.p64 p64 = util.p64
r = self.query("SELECT tid, IF(compression < 128, LENGTH(value)," r = self.query("SELECT tid, IF(compression < 128, LENGTH(value),"
" CAST(CONV(HEX(SUBSTR(value, 5, 4)), 16, 10) AS INT))" " CAST(CONV(HEX(SUBSTR(value, 5, 4)), 16, 10) AS INT))"
" FROM obj LEFT JOIN data ON (obj.data_id = data.id)" " FROM obj FORCE INDEX(`partition`)"
" LEFT JOIN data ON (obj.data_id = data.id)"
" WHERE `partition` = %d AND oid = %d AND tid >= %d" " WHERE `partition` = %d AND oid = %d AND tid >= %d"
" ORDER BY tid DESC LIMIT %d, %d" % " ORDER BY tid DESC LIMIT %d, %d" %
(self._getReadablePartition(oid), oid, (self._getReadablePartition(oid), oid,
...@@ -682,7 +707,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -682,7 +707,7 @@ class MySQLDatabaseManager(DatabaseManager):
u64 = util.u64 u64 = util.u64
p64 = util.p64 p64 = util.p64
min_tid = u64(min_tid) min_tid = u64(min_tid)
r = self.query('SELECT tid, oid FROM obj' r = self.query('SELECT tid, oid FROM obj FORCE INDEX(PRIMARY)'
' WHERE `partition` = %d AND tid <= %d' ' WHERE `partition` = %d AND tid <= %d'
' AND (tid = %d AND %d <= oid OR %d < tid)' ' AND (tid = %d AND %d <= oid OR %d < tid)'
' ORDER BY tid ASC, oid ASC LIMIT %d' % ( ' ORDER BY tid ASC, oid ASC LIMIT %d' % (
...@@ -751,7 +776,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -751,7 +776,8 @@ class MySQLDatabaseManager(DatabaseManager):
q = self.query q = self.query
self._setPackTID(tid) self._setPackTID(tid)
for count, oid, max_serial in q("SELECT COUNT(*) - 1, oid, MAX(tid)" for count, oid, max_serial in q("SELECT COUNT(*) - 1, oid, MAX(tid)"
" FROM obj WHERE tid <= %d GROUP BY oid" " FROM obj FORCE INDEX(`partition`)"
" WHERE tid <= %d GROUP BY oid"
% tid): % tid):
partition = getPartition(oid) partition = getPartition(oid)
if q("SELECT 1 FROM obj WHERE `partition` = %d" if q("SELECT 1 FROM obj WHERE `partition` = %d"
...@@ -801,7 +827,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -801,7 +827,7 @@ class MySQLDatabaseManager(DatabaseManager):
# last grouped value, instead of the greatest one. # last grouped value, instead of the greatest one.
r = self.query( r = self.query(
"""SELECT tid, oid """SELECT tid, oid
FROM obj FROM obj FORCE INDEX(PRIMARY)
WHERE `partition` = %(partition)s WHERE `partition` = %(partition)s
AND tid <= %(max_tid)d AND tid <= %(max_tid)d
AND (tid > %(min_tid)d OR AND (tid > %(min_tid)d OR
......
...@@ -78,6 +78,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -78,6 +78,7 @@ class SQLiteDatabaseManager(DatabaseManager):
def _connect(self): def _connect(self):
logging.info('connecting to SQLite database %r', self.db) logging.info('connecting to SQLite database %r', self.db)
self.conn = sqlite3.connect(self.db, check_same_thread=False) self.conn = sqlite3.connect(self.db, check_same_thread=False)
self.lock(self.db)
if self.UNSAFE: if self.UNSAFE:
q = self.query q = self.query
q("PRAGMA synchronous = OFF") q("PRAGMA synchronous = OFF")
...@@ -243,20 +244,25 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -243,20 +244,25 @@ class SQLiteDatabaseManager(DatabaseManager):
# each partition (and finish in Python with max() for getLastTID). # each partition (and finish in Python with max() for getLastTID).
def getLastTID(self, max_tid): def getLastTID(self, max_tid):
return self.query("SELECT MAX(tid) FROM trans WHERE tid<=?", return self.query(
(max_tid,)).next()[0] "SELECT MAX(tid) FROM pt, trans"
" WHERE nid=? AND rid=partition AND tid<=?",
(self.getUUID(), max_tid,)).next()[0]
def _getLastIDs(self): def _getLastIDs(self):
p64 = util.p64 p64 = util.p64
q = self.query q = self.query
args = self.getUUID(),
trans = {partition: p64(tid) trans = {partition: p64(tid)
for partition, tid in q("SELECT partition, MAX(tid)" for partition, tid in q(
" FROM trans GROUP BY partition")} "SELECT partition, MAX(tid) FROM pt, trans"
" WHERE nid=? AND rid=partition GROUP BY partition", args)}
obj = {partition: p64(tid) obj = {partition: p64(tid)
for partition, tid in q("SELECT partition, MAX(tid)" for partition, tid in q(
" FROM obj GROUP BY partition")} "SELECT partition, MAX(tid) FROM pt, obj"
oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj" " WHERE nid=? AND rid=partition GROUP BY partition", args)}
" GROUP BY partition) as t").next()[0] oid = q("SELECT MAX(oid) oid FROM pt, obj"
" WHERE nid=? AND rid=partition", args).next()[0]
return trans, obj, None if oid is None else p64(oid) return trans, obj, None if oid is None else p64(oid)
def _getUnfinishedTIDDict(self): def _getUnfinishedTIDDict(self):
......
...@@ -38,6 +38,9 @@ class InitializationHandler(BaseMasterHandler): ...@@ -38,6 +38,9 @@ class InitializationHandler(BaseMasterHandler):
# delete objects database # delete objects database
dm = app.dm dm = app.dm
if unassigned_set: if unassigned_set:
if app.disable_drop_partitions:
logging.info("don't drop data for partitions %r", unassigned_set)
else:
logging.debug('drop data for partitions %r', unassigned_set) logging.debug('drop data for partitions %r', unassigned_set)
dm.dropPartitions(unassigned_set) dm.dropPartitions(unassigned_set)
......
...@@ -46,7 +46,6 @@ class StorageOperationHandler(EventHandler): ...@@ -46,7 +46,6 @@ class StorageOperationHandler(EventHandler):
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
app = self.app app = self.app
if app.operational and conn.isClient(): if app.operational and conn.isClient():
# XXX: Connection and Node should merged.
uuid = conn.getUUID() uuid = conn.getUUID()
if uuid: if uuid:
node = app.nm.getByUUID(uuid) node = app.nm.getByUUID(uuid)
......
...@@ -356,6 +356,7 @@ class Replicator(object): ...@@ -356,6 +356,7 @@ class Replicator(object):
self.fetchTransactions() self.fetchTransactions()
def fetchTransactions(self, min_tid=None): def fetchTransactions(self, min_tid=None):
assert self.current_node.getConnection().isClient(), self.current_node
offset = self.current_partition offset = self.current_partition
p = self.partition_dict[offset] p = self.partition_dict[offset]
if min_tid: if min_tid:
......
...@@ -190,6 +190,11 @@ class NeoTestBase(unittest.TestCase): ...@@ -190,6 +190,11 @@ class NeoTestBase(unittest.TestCase):
"Mock objects can't be compared with '==' or '!='" "Mock objects can't be compared with '==' or '!='"
return super(NeoTestBase, self).assertEqual(first, second, msg=msg) return super(NeoTestBase, self).assertEqual(first, second, msg=msg)
def assertPartitionTable(self, pt, expected, key=None):
self.assertEqual(
expected if isinstance(expected, str) else '|'.join(expected),
'|'.join(pt._formatRows(sorted(pt.count_dict, key=key))))
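The expected strings compared by this helper use a compact notation, inferred here from pt._formatRows and the tests below (so treat this reading as an assumption): one '|'-separated group per partition, one character per storage node, each character being the first letter of the cell state.

    # Hypothetical 3-partition, 3-node table:
    #   'U..|.U.|..U'
    # 'U' = UP_TO_DATE, 'O' = OUT_OF_DATE, 'F' = FEEDING, '.' = no cell.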
class NeoUnitTestBase(NeoTestBase): class NeoUnitTestBase(NeoTestBase):
""" Base class for neo tests, implements common checks """ """ Base class for neo tests, implements common checks """
...@@ -217,7 +222,8 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -217,7 +222,8 @@ class NeoUnitTestBase(NeoTestBase):
temp_dir = getTempDirectory() temp_dir = getTempDirectory()
for i in xrange(number): for i in xrange(number):
try: try:
os.remove(os.path.join(temp_dir, 'test_neo%s.sqlite' % i)) os.remove(os.path.join(temp_dir,
'%s%s.sqlite' % (prefix, i)))
except OSError, e: except OSError, e:
if e.errno != errno.ENOENT: if e.errno != errno.ENOENT:
raise raise
......
...@@ -104,7 +104,7 @@ class ClusterPdb(object): ...@@ -104,7 +104,7 @@ class ClusterPdb(object):
def broken_peer(self): def broken_peer(self):
return self._getLastPdb(os.getpid()) is None return self._getLastPdb(os.getpid()) is None
def __call__(self, max_count=None, depth=0, text=None): def __call__(self, depth=0, max_count=None, gui=False):
depth += 1 depth += 1
if max_count: if max_count:
frame = sys._getframe(depth) frame = sys._getframe(depth)
...@@ -113,13 +113,8 @@ class ClusterPdb(object): ...@@ -113,13 +113,8 @@ class ClusterPdb(object):
self._count_dict[key] = count = 1 + self._count_dict.get(key, 0) self._count_dict[key] = count = 1 + self._count_dict.get(key, 0)
if max_count < count: if max_count < count:
return return
if not text: if gui:
try:
import rpdb2 import rpdb2
except ImportError:
if text is not None:
raise
else:
if rpdb2.g_debugger is None: if rpdb2.g_debugger is None:
rpdb2_CStateManager = rpdb2.CStateManager rpdb2_CStateManager = rpdb2.CStateManager
def CStateManager(*args, **kw): def CStateManager(*args, **kw):
......
...@@ -37,10 +37,11 @@ from neo.lib import logging ...@@ -37,10 +37,11 @@ from neo.lib import logging
from neo.lib.protocol import ClusterStates, NodeTypes, CellStates, NodeStates, \ from neo.lib.protocol import ClusterStates, NodeTypes, CellStates, NodeStates, \
UUID_NAMESPACES UUID_NAMESPACES
from neo.lib.util import dump from neo.lib.util import dump
from .. import ADDRESS_TYPE, DB_SOCKET, DB_USER, IP_VERSION_FORMAT_DICT, SSL, \ from .. import (ADDRESS_TYPE, DB_SOCKET, DB_USER, IP_VERSION_FORMAT_DICT, SSL,
buildUrlFromString, cluster, getTempDirectory, NeoTestBase, setupMySQLdb buildUrlFromString, cluster, getTempDirectory, NeoTestBase, Patch,
setupMySQLdb)
from neo.client.Storage import Storage from neo.client.Storage import Storage
from neo.storage.database import buildDatabaseManager from neo.storage.database import manager, buildDatabaseManager
try: try:
coverage = sys.modules['neo.scripts.runner'].coverage coverage = sys.modules['neo.scripts.runner'].coverage
...@@ -124,7 +125,7 @@ class NEOProcess(object): ...@@ -124,7 +125,7 @@ class NEOProcess(object):
def __init__(self, command, uuid, arg_dict): def __init__(self, command, uuid, arg_dict):
try: try:
__import__('neo.scripts.' + command) __import__('neo.scripts.' + command, level=0)
except ImportError: except ImportError:
raise NotFound, '%s not found' % (command) raise NotFound, '%s not found' % (command)
self.command = command self.command = command
...@@ -491,7 +492,8 @@ class NEOCluster(object): ...@@ -491,7 +492,8 @@ class NEOCluster(object):
def getSQLConnection(self, db): def getSQLConnection(self, db):
assert db is not None and db in self.db_list assert db is not None and db in self.db_list
return buildDatabaseManager(self.adapter, (self.db_template(db),)) with Patch(manager.DatabaseManager, LOCK=None):
return buildDatabaseManager(self.adapter, (self.db_template(db),))
def getMasterProcessList(self): def getMasterProcessList(self):
return self.process_dict.get(NodeTypes.MASTER) return self.process_dict.get(NodeTypes.MASTER)
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import random, time, unittest
from collections import defaultdict from collections import defaultdict
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib import logging
from neo.lib.protocol import NodeStates, CellStates from neo.lib.protocol import NodeStates, CellStates
from neo.lib.pt import PartitionTableException from neo.lib.pt import PartitionTableException
from neo.master.pt import PartitionTable from neo.master.pt import PartitionTable
...@@ -45,7 +46,7 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -45,7 +46,7 @@ class MasterPartitionTableTests(NeoUnitTestBase):
self.assertEqual(len(pt.getRow(x)), 0) self.assertEqual(len(pt.getRow(x)), 0)
self.assertFalse(pt.operational()) self.assertFalse(pt.operational())
self.assertFalse(pt.filled()) self.assertFalse(pt.filled())
self.assertRaises(RuntimeError, pt.make, []) self.assertRaises(AssertionError, pt.make, [])
self.assertFalse(pt.operational()) self.assertFalse(pt.operational())
self.assertFalse(pt.filled()) self.assertFalse(pt.filled())
...@@ -132,77 +133,35 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -132,77 +133,35 @@ class MasterPartitionTableTests(NeoUnitTestBase):
(1, 2, CellStates.DISCARDED), (1, 2, CellStates.DISCARDED),
(2, 2, CellStates.DISCARDED)]) (2, 2, CellStates.DISCARDED)])
pt._setCell(0, sn[0], CellStates.UP_TO_DATE)
self.assertEqual(self.tweak(pt), [(2, 3, CellStates.FEEDING)]) self.assertEqual(self.tweak(pt), [(2, 3, CellStates.FEEDING)])
def test_16_make(self): def test_16_make(self):
num_partitions = 5 node_list = [self.createStorage(
num_replicas = 1 ("127.0.0.1", 19000 + i), self.getStorageUUID(),
pt = PartitionTable(num_partitions, num_replicas) NodeStates.RUNNING)
# add nodes for i in xrange(4)]
uuid1 = self.getStorageUUID() for np, nr, expected in (
server1 = ("127.0.0.1", 19001) (3, 0, 'U..|.U.|..U'),
sn1 = self.createStorage(server1, uuid1, NodeStates.RUNNING) (5, 1, 'UU..|..UU|UU..|..UU|UU..'),
# add not running node (9, 2, 'UUU.|UU.U|U.UU|.UUU|UUU.|UU.U|U.UU|.UUU|UUU.'),
uuid2 = self.getStorageUUID() ):
server2 = ("127.0.0.2", 19001) pt = PartitionTable(np, nr)
sn2 = self.createStorage(server2, uuid2) pt.make(node_list)
sn2.setState(NodeStates.DOWN) self.assertPartitionTable(pt, expected)
# add node without uuid self.assertTrue(pt.filled())
server3 = ("127.0.0.3", 19001) self.assertTrue(pt.operational())
sn3 = self.createStorage(server3, None, NodeStates.RUNNING) # create a pt with less nodes
# add clear node pt.clear()
uuid4 = self.getStorageUUID() self.assertFalse(pt.filled())
server4 = ("127.0.0.4", 19001) self.assertFalse(pt.operational())
sn4 = self.createStorage(server4, uuid4, NodeStates.RUNNING) pt.make(node_list[:1])
uuid5 = self.getStorageUUID() self.assertPartitionTable(pt, '|'.join('U' * np))
server5 = ("127.0.0.5", 1900) self.assertTrue(pt.filled())
sn5 = self.createStorage(server5, uuid5, NodeStates.RUNNING) self.assertTrue(pt.operational())
# make the table
pt.make([sn1, sn2, sn3, sn4, sn5])
# check it's ok, only running nodes and node with uuid
# must be present
for x in xrange(num_partitions):
cells = pt.getCellList(x)
self.assertEqual(len(cells), 2)
nodes = [x.getNode() for x in cells]
for node in nodes:
self.assertTrue(node in (sn1, sn4, sn5))
self.assertTrue(node not in (sn2, sn3))
self.assertTrue(pt.filled())
self.assertTrue(pt.operational())
# create a pt with less nodes
pt.clear()
self.assertFalse(pt.filled())
self.assertFalse(pt.operational())
pt.make([sn1])
# check it's ok
for x in xrange(num_partitions):
cells = pt.getCellList(x)
self.assertEqual(len(cells), 1)
nodes = [x.getNode() for x in cells]
for node in nodes:
self.assertEqual(node, sn1)
self.assertTrue(pt.filled())
self.assertTrue(pt.operational())
def _pt_states(self, pt):
node_dict = defaultdict(list)
for offset, row in enumerate(pt.partition_list):
for cell in row:
state_list = node_dict[cell.getNode()]
if state_list:
self.assertTrue(state_list[-1][0] < offset)
state_list.append((offset, str(cell.getState())[0]))
return map(dict, sorted(node_dict.itervalues()))
def checkPT(self, pt, exclude_empty=False):
new_pt = PartitionTable(pt.np, pt.nr)
new_pt.make(node for node, count in pt.count_dict.iteritems()
if count or not exclude_empty)
self.assertEqual(self._pt_states(pt), self._pt_states(new_pt))
def update(self, pt, change_list=None): def update(self, pt, change_list=None):
offset_list = range(pt.np) offset_list = xrange(pt.np)
for node in pt.count_dict: for node in pt.count_dict:
pt.updatable(node.getUUID(), offset_list) pt.updatable(node.getUUID(), offset_list)
if change_list is None: if change_list is None:
...@@ -215,9 +174,11 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -215,9 +174,11 @@ class MasterPartitionTableTests(NeoUnitTestBase):
for offset, uuid, state in change_list: for offset, uuid, state in change_list:
if state is CellStates.OUT_OF_DATE: if state is CellStates.OUT_OF_DATE:
pt.setUpToDate(node_dict[uuid], offset) pt.setUpToDate(node_dict[uuid], offset)
pt.log()
def tweak(self, pt, drop_list=()): def tweak(self, pt, drop_list=()):
change_list = pt.tweak(drop_list) change_list = pt.tweak(drop_list)
pt.log()
self.assertFalse(pt.tweak(drop_list)) self.assertFalse(pt.tweak(drop_list))
return change_list return change_list
...@@ -225,6 +186,7 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -225,6 +186,7 @@ class MasterPartitionTableTests(NeoUnitTestBase):
sn = [self.createStorage(None, i + 1, NodeStates.RUNNING) sn = [self.createStorage(None, i + 1, NodeStates.RUNNING)
for i in xrange(5)] for i in xrange(5)]
pt = PartitionTable(5, 2) pt = PartitionTable(5, 2)
pt.setID(1)
# part 0 # part 0
pt._setCell(0, sn[0], CellStates.DISCARDED) pt._setCell(0, sn[0], CellStates.DISCARDED)
pt._setCell(0, sn[1], CellStates.UP_TO_DATE) pt._setCell(0, sn[1], CellStates.UP_TO_DATE)
...@@ -246,45 +208,108 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -246,45 +208,108 @@ class MasterPartitionTableTests(NeoUnitTestBase):
pt._setCell(4, sn[4], CellStates.UP_TO_DATE) pt._setCell(4, sn[4], CellStates.UP_TO_DATE)
count_dict = defaultdict(int) count_dict = defaultdict(int)
self.assertPartitionTable(pt, (
'.U...',
'FFO..',
'FUU..',
'UUUU.',
'U...U'))
change_list = self.tweak(pt) change_list = self.tweak(pt)
self.assertPartitionTable(pt, (
'.UO.O',
'UU.O.',
'UFU.O',
'.UUU.',
'U..OU'))
for offset, uuid, state in change_list: for offset, uuid, state in change_list:
count_dict[state] += 1 count_dict[state] += 1
self.assertEqual(count_dict, {CellStates.DISCARDED: 3, self.assertEqual(count_dict, {CellStates.DISCARDED: 2,
CellStates.FEEDING: 1,
CellStates.OUT_OF_DATE: 5, CellStates.OUT_OF_DATE: 5,
CellStates.UP_TO_DATE: 3}) CellStates.UP_TO_DATE: 3})
self.update(pt, change_list) self.update(pt)
self.checkPT(pt) self.assertPartitionTable(pt, (
'.UU.U',
'UU.U.',
'U.U.U',
'.UUU.',
'U..UU'))
self.assertRaises(PartitionTableException, pt.dropNodeList, sn[1:4]) self.assertRaises(PartitionTableException, pt.dropNodeList, sn[1:4])
self.assertEqual(6, len(pt.dropNodeList(sn[1:3], True))) self.assertEqual(6, len(pt.dropNodeList(sn[1:3], True)))
self.assertEqual(3, len(pt.dropNodeList([sn[1]]))) self.assertEqual(3, len(pt.dropNodeList([sn[1]])))
pt.addNodeList([sn[1]]) pt.addNodeList([sn[1]])
self.assertPartitionTable(pt, (
'..U.U',
'U..U.',
'U.U.U',
'..UU.',
'U..UU'))
change_list = self.tweak(pt) change_list = self.tweak(pt)
self.assertPartitionTable(pt, (
'.OU.U',
'UO.U.',
'U.U.U',
'.OUU.',
'U..UU'))
self.assertEqual(3, len(change_list)) self.assertEqual(3, len(change_list))
self.update(pt, change_list) self.update(pt, change_list)
self.checkPT(pt)
for np, i in (12, 0), (12, 1), (13, 2): for np, i, expected in (
(12, 0, ('U...|.U..|..U.|...U|'
'U...|.U..|..U.|...U|'
'U...|.U..|..U.|...U',)),
(12, 1, ('UU...|..UU.|U...U|.UU..|...UU|'
'UU...|..UU.|U...U|.UU..|...UU|'
'UU...|..UU.',)),
(13, 2, ('U.UU.|.U.UU|UUU..|..UUU|UU..U|'
'U.UU.|.U.UU|UUU..|..UUU|UU..U|'
'U.UU.|.U.UU|UUU..',
'UUU..|U..UU|.UUU.|UU..U|..UUU|'
'UUU..|U..UU|.UUU.|UU..U|..UUU|'
'UUU..|U..UU|.UUU.')),
):
pt = PartitionTable(np, i) pt = PartitionTable(np, i)
i += 1 i += 1
pt.make(sn[:i]) pt.make(sn[:i])
pt.log()
for n in sn[i:i+3]: for n in sn[i:i+3]:
self.assertEqual([n], pt.addNodeList([n])) self.assertEqual([n], pt.addNodeList([n]))
self.update(pt, self.tweak(pt)) self.update(pt, self.tweak(pt))
self.checkPT(pt) self.assertPartitionTable(pt, expected[0])
pt.clear() pt.clear()
pt.make(sn[:i]) pt.make(sn[:i])
for n in sn[i:i+3]: for n in sn[i:i+3]:
self.assertEqual([n], pt.addNodeList([n])) self.assertEqual([n], pt.addNodeList([n]))
self.tweak(pt) self.tweak(pt)
self.update(pt) self.update(pt)
self.checkPT(pt) self.assertPartitionTable(pt, expected[-1])
pt = PartitionTable(7, 0) pt = PartitionTable(7, 0)
pt.make(sn[:1]) pt.make(sn[:1])
pt.addNodeList(sn[1:3]) pt.addNodeList(sn[1:3])
self.assertPartitionTable(pt, 'U..|U..|U..|U..|U..|U..|U..')
self.update(pt, self.tweak(pt, sn[:1])) self.update(pt, self.tweak(pt, sn[:1]))
self.checkPT(pt, True) self.assertPartitionTable(pt, '.U.|..U|.U.|..U|.U.|..U|.U.')
def test_18_tweak(self):
s = repr(time.time())
logging.info("using seed %r", s)
r = random.Random(s)
sn_count = 11
sn = [self.createStorage(None, i + 1, NodeStates.RUNNING)
for i in xrange(sn_count)]
pt = PartitionTable(1000, 2)
pt.setID(1)
for offset in xrange(pt.np):
state = CellStates.UP_TO_DATE
k = r.randrange(1, sn_count)
for s in r.sample(sn, k):
pt._setCell(offset, s, state)
if k * r.random() < 1:
state = CellStates.OUT_OF_DATE
pt.log()
self.tweak(pt)
self.update(pt)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -131,6 +131,15 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -131,6 +131,15 @@ class StorageDBTests(NeoUnitTestBase):
def checkSet(self, list1, list2): def checkSet(self, list1, list2):
self.assertEqual(set(list1), set(list2)) self.assertEqual(set(list1), set(list2))
def _test_lockDatabase_open(self):
raise NotImplementedError
def test_lockDatabase(self):
db = self._test_lockDatabase_open()
self.assertRaises(SystemExit, self._test_lockDatabase_open)
db.close()
self._test_lockDatabase_open().close()
def test_getUnfinishedTIDDict(self): def test_getUnfinishedTIDDict(self):
tid1, tid2, tid3, tid4 = self.getTIDs(4) tid1, tid2, tid3, tid4 = self.getTIDs(4)
oid1, oid2 = self.getOIDs(2) oid1, oid2 = self.getOIDs(2)
......
...@@ -29,11 +29,13 @@ class StorageMySQLdbTests(StorageDBTests): ...@@ -29,11 +29,13 @@ class StorageMySQLdbTests(StorageDBTests):
engine = None engine = None
def getDB(self, reset=0): def _test_lockDatabase_open(self):
self.prepareDatabase(number=1, prefix=DB_PREFIX) self.prepareDatabase(number=1, prefix=DB_PREFIX)
# db manager
database = '%s@%s0%s' % (DB_USER, DB_PREFIX, DB_SOCKET) database = '%s@%s0%s' % (DB_USER, DB_PREFIX, DB_SOCKET)
db = MySQLDatabaseManager(database, self.engine) return MySQLDatabaseManager(database, self.engine)
def getDB(self, reset=0):
db = self._test_lockDatabase_open()
self.assertEqual(db.db, DB_PREFIX + '0') self.assertEqual(db.db, DB_PREFIX + '0')
self.assertEqual(db.user, DB_USER) self.assertEqual(db.user, DB_USER)
try: try:
...@@ -129,11 +131,13 @@ class StorageMySQLdbTests(StorageDBTests): ...@@ -129,11 +131,13 @@ class StorageMySQLdbTests(StorageDBTests):
class StorageMySQLdbRocksDBTests(StorageMySQLdbTests): class StorageMySQLdbRocksDBTests(StorageMySQLdbTests):
engine = "RocksDB" engine = "RocksDB"
test_lockDatabase = None
class StorageMySQLdbTokuDBTests(StorageMySQLdbTests): class StorageMySQLdbTokuDBTests(StorageMySQLdbTests):
engine = "TokuDB" engine = "TokuDB"
test_lockDatabase = None
del StorageDBTests del StorageDBTests
......
...@@ -14,17 +14,29 @@ ...@@ -14,17 +14,29 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import os, unittest
from .. import getTempDirectory, DB_PREFIX
from .testStorageDBTests import StorageDBTests from .testStorageDBTests import StorageDBTests
from neo.storage.database.sqlite import SQLiteDatabaseManager from neo.storage.database.sqlite import SQLiteDatabaseManager
class StorageSQLiteTests(StorageDBTests): class StorageSQLiteTests(StorageDBTests):
def _test_lockDatabase_open(self):
db = os.path.join(getTempDirectory(), DB_PREFIX + '0.sqlite')
return SQLiteDatabaseManager(db)
def getDB(self, reset=0): def getDB(self, reset=0):
db = SQLiteDatabaseManager(':memory:') db = SQLiteDatabaseManager(':memory:')
db.setup(reset) db.setup(reset)
return db return db
def test_lockDatabase(self):
super(StorageSQLiteTests, self).test_lockDatabase()
# No lock on temporary databases.
db = self.getDB()
self.getDB().close()
db.close()
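Temporary and in-memory SQLite databases are not locked because lock() stats the database path and returns early when nothing exists on disk. A tiny check of that assumption:

    import errno, os
    try:
        os.stat(':memory:')              # no such file on disk
    except OSError as e:
        assert e.errno == errno.ENOENT   # lock() returns without locking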
del StorageDBTests del StorageDBTests
if __name__ == "__main__": if __name__ == "__main__":
......
# -*- coding: utf-8 -*-
#
# Copyright (C) 2009-2017 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from time import time
from .mock import Mock
from neo.lib import connection, logging
from neo.lib.connection import BaseConnection, ClientConnection, \
MTClientConnection, CRITICAL_TIMEOUT
from neo.lib.handler import EventHandler
from neo.lib.protocol import ENCODED_VERSION, Packets
from . import NeoUnitTestBase, Patch
connector_cpt = 0
class DummyConnector(Mock):
def __init__(self, addr, s=None):
logging.info("initializing connector")
global connector_cpt
self.desc = connector_cpt
connector_cpt += 1
self.packet_cpt = 0
self.addr = addr
Mock.__init__(self)
def getAddress(self):
return self.addr
def getDescriptor(self):
return self.desc
accept = getError = makeClientConnection = makeListeningConnection = \
receive = send = lambda *args, **kw: None
dummy_connector = Patch(BaseConnection,
ConnectorClass=lambda orig, self, *args, **kw: DummyConnector(*args, **kw))
class ConnectionTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
self.app = Mock({'__repr__': 'Fake App'})
self.app.ssl = None
self.em = self.app.em = Mock({'__repr__': 'Fake Em'})
self.handler = Mock({'__repr__': 'Fake Handler'})
self.address = ("127.0.0.7", 93413)
self.node = Mock({'getAddress': self.address})
def _makeClientConnection(self):
with dummy_connector:
conn = ClientConnection(self.app, self.handler, self.node)
self.connector = conn.connector
return conn
def testTimeout(self):
# NOTE: This method uses ping/pong packets only because MT connections
# don't accept any other packet without specifying a queue.
self.handler = EventHandler(self.app)
conn = self._makeClientConnection()
conn.read_buf.append(ENCODED_VERSION)
use_case_list = (
# (a) For a single packet sent at T,
# the limit time for the answer is T + (1 * CRITICAL_TIMEOUT)
((), (1., 1)),
# (b) Same as (a), even if send another packet at (T + CT/2).
# But receiving a packet (at T + CT - ε) resets the timeout
# (which means the limit for the 2nd one is T + 2*CT)
((.5, None), (1., 1, 2., 3)),
# (c) Same as (b) with a first answer at well before the limit
# (T' = T + CT/2). The limit for the second one is T' + CT.
((.1, None, .5, 3), (1.5, 1)),
)
def set_time(t):
connection.time = lambda: int(CRITICAL_TIMEOUT * (1000 + t))
closed = []
conn.close = lambda: closed.append(connection.time())
def answer(packet_id):
p = Packets.Pong()
p.setId(packet_id)
conn.connector.receive = lambda read_buf: \
read_buf.append(''.join(p.encode()))
conn.readable()
checkTimeout()
conn.process()
def checkTimeout():
timeout = conn.getTimeout()
if timeout and timeout <= connection.time():
conn.onTimeout()
try:
for use_case, expected in use_case_list:
i = iter(use_case)
conn.cur_id = 1 # XXX -> conn._reset() ?
set_time(0)
# No timeout when no pending request
self.assertEqual(conn._handlers.getNextTimeout(), None)
conn.ask(Packets.Ping())
for t in i:
set_time(t)
checkTimeout()
packet_id = i.next()
if packet_id is None:
conn.ask(Packets.Ping())
else:
answer(packet_id)
i = iter(expected)
for t in i:
set_time(t - .1)
checkTimeout()
set_time(t)
# this test method relies on the fact that only
# conn.close is called in case of a timeout
checkTimeout()
self.assertEqual(closed.pop(), connection.time())
answer(i.next())
self.assertFalse(conn.isPending())
self.assertFalse(closed)
finally:
connection.time = time
class MTConnectionTests(ConnectionTests):
# XXX: here we test non-client-connection-related things too, which
# duplicates test suite work... Should be fragmented into finer-grained
# test classes.
def setUp(self):
super(MTConnectionTests, self).setUp()
self.dispatcher = Mock({'__repr__': 'Fake Dispatcher'})
def _makeClientConnection(self):
with dummy_connector:
conn = MTClientConnection(self.app, self.handler, self.node,
dispatcher=self.dispatcher)
self.connector = conn.connector
return conn
def test_MTClientConnectionQueueParameter(self):
ask = self._makeClientConnection().ask
packet = Packets.AskPrimary() # Any non-Ping simple "ask" packet
# One cannot "ask" anything without a queue
self.assertRaises(TypeError, ask, packet)
ask(packet, queue=object())
# ... except Ping
ask(Packets.Ping())
if __name__ == '__main__':
unittest.main()
...@@ -1062,11 +1062,11 @@ class NEOThreadedTest(NeoTestBase): ...@@ -1062,11 +1062,11 @@ class NEOThreadedTest(NeoTestBase):
with Patch(client, _getFinalTID=lambda *_: None): with Patch(client, _getFinalTID=lambda *_: None):
self.assertRaises(ConnectionClosed, txn.commit) self.assertRaises(ConnectionClosed, txn.commit)
def assertPartitionTable(self, cluster, stats, pt_node=None): def assertPartitionTable(self, cluster, expected, pt_node=None):
pt = (pt_node or cluster.admin).pt
index = [x.uuid for x in cluster.storage_list].index index = [x.uuid for x in cluster.storage_list].index
self.assertEqual(stats, '|'.join(pt._formatRows(sorted( super(NEOThreadedTest, self).assertPartitionTable(
pt.count_dict, key=lambda x: index(x.getUUID()))))) (pt_node or cluster.admin).pt, expected,
lambda x: index(x.getUUID()))
@staticmethod @staticmethod
def noConnection(jar, storage): def noConnection(jar, storage):
......
...@@ -35,7 +35,7 @@ from neo.lib.exception import DatabaseFailure, StoppedOperation ...@@ -35,7 +35,7 @@ from neo.lib.exception import DatabaseFailure, StoppedOperation
from neo.lib.handler import DelayEvent from neo.lib.handler import DelayEvent
from neo.lib import logging from neo.lib import logging
from neo.lib.protocol import (CellStates, ClusterStates, NodeStates, NodeTypes, from neo.lib.protocol import (CellStates, ClusterStates, NodeStates, NodeTypes,
Packets, Packet, uuid_str, ZERO_OID, ZERO_TID) Packets, Packet, uuid_str, ZERO_OID, ZERO_TID, MAX_TID)
from .. import expectedFailure, unpickle_state, Patch, TransactionalResource from .. import expectedFailure, unpickle_state, Patch, TransactionalResource
from . import ClientApplication, ConnectionFilter, LockLock, NEOThreadedTest, \ from . import ClientApplication, ConnectionFilter, LockLock, NEOThreadedTest, \
RandomConflictDict, ThreadId, with_cluster RandomConflictDict, ThreadId, with_cluster
...@@ -1350,19 +1350,6 @@ class Test(NEOThreadedTest): ...@@ -1350,19 +1350,6 @@ class Test(NEOThreadedTest):
poll(0) poll(0)
self.assertIs(client.connector, None) self.assertIs(client.connector, None)
def testConnectionTimeout(self):
with self.getLoopbackConnection() as conn:
conn.KEEP_ALIVE
def onTimeout(orig):
conn.idle()
orig()
with Patch(conn, KEEP_ALIVE=0):
while conn.connecting:
conn.em.poll(1)
with Patch(conn, onTimeout=onTimeout):
conn.em.poll(1)
self.assertFalse(conn.isClosed())
@with_cluster() @with_cluster()
def testClientDisconnectedFromMaster(self, cluster): def testClientDisconnectedFromMaster(self, cluster):
def disconnect(conn, packet): def disconnect(conn, packet):
...@@ -2061,7 +2048,7 @@ class Test(NEOThreadedTest): ...@@ -2061,7 +2048,7 @@ class Test(NEOThreadedTest):
if (isinstance(packet, Packets.AnswerStoreObject) if (isinstance(packet, Packets.AnswerStoreObject)
and packet.decode()[0]): and packet.decode()[0]):
conn, = cluster.client.getConnectionList(app) conn, = cluster.client.getConnectionList(app)
kw = conn._handlers._pending[0][0][packet._id][3] kw = conn._handlers._pending[0][0][packet._id][1]
return 1 == u64(kw['oid']) and delay_conflict[app.uuid].pop() return 1 == u64(kw['oid']) and delay_conflict[app.uuid].pop()
def writeA(orig, txn_context, oid, serial, data): def writeA(orig, txn_context, oid, serial, data):
if u64(oid) == 1: if u64(oid) == 1:
...@@ -2335,6 +2322,34 @@ class Test(NEOThreadedTest): ...@@ -2335,6 +2322,34 @@ class Test(NEOThreadedTest):
self.assertFalse(m1.primary) self.assertFalse(m1.primary)
self.assertTrue(m1.is_alive()) self.assertTrue(m1.is_alive())
@with_cluster(partitions=2, storage_count=2)
def testStorageBackendLastIDs(self, cluster):
"""
Check that getLastIDs/getLastTID ignore data from unassigned partitions.
XXX: this kind of test should not be reexecuted with SSL
"""
cluster.sortStorageList()
t, c = cluster.getTransaction()
c.root()[''] = PCounter()
t.commit()
big_id_list = ('\x7c' * 8, '\x7e' * 8), ('\x7b' * 8, '\x7d' * 8)
for i in 0, 1:
dm = cluster.storage_list[i].dm
expected = dm.getLastTID(u64(MAX_TID)), dm.getLastIDs()
oid, tid = big_id_list[i]
for j, expected in (
(1 - i, (dm.getLastTID(u64(MAX_TID)), dm.getLastIDs())),
(i, (u64(tid), (tid, {}, {}, oid)))):
oid, tid = big_id_list[j]
# Somehow we abuse 'storeTransaction' because we ask it to
# write data for unassigned partitions. This is not checked yet,
# so for the moment the test works.
dm.storeTransaction(tid, ((oid, None, None),),
((oid,), '', '', '', 0, tid), False)
self.assertEqual(expected,
(dm.getLastTID(u64(MAX_TID)), dm.getLastIDs()))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -29,7 +29,6 @@ from neo.storage.checker import CHECK_COUNT ...@@ -29,7 +29,6 @@ from neo.storage.checker import CHECK_COUNT
from neo.storage.replicator import Replicator from neo.storage.replicator import Replicator
from neo.lib.connector import SocketConnector from neo.lib.connector import SocketConnector
from neo.lib.connection import ClientConnection from neo.lib.connection import ClientConnection
from neo.lib.event import EventManager
from neo.lib.protocol import CellStates, ClusterStates, Packets, \ from neo.lib.protocol import CellStates, ClusterStates, Packets, \
ZERO_OID, ZERO_TID, MAX_TID, uuid_str ZERO_OID, ZERO_TID, MAX_TID, uuid_str
from neo.lib.util import p64, u64 from neo.lib.util import p64, u64
...@@ -283,35 +282,6 @@ class ReplicationTests(NEOThreadedTest): ...@@ -283,35 +282,6 @@ class ReplicationTests(NEOThreadedTest):
self.assertEqual(backup.last_tid, upstream.last_tid) self.assertEqual(backup.last_tid, upstream.last_tid)
self.assertEqual(np*3, self.checkBackup(backup)) self.assertEqual(np*3, self.checkBackup(backup))
@backup_test()
def testBackupUpstreamMasterDead(self, backup):
"""Check proper behaviour when upstream master is unreachable
More generally, this checks that when a handler raises when a connection
is closed voluntarily, the connection is in a consistent state and can
be, for example, closed again after the exception is caught, without
assertion failure.
"""
conn, = backup.master.getConnectionList(backup.upstream.master)
# trigger ping
self.assertFalse(conn.isPending())
conn.onTimeout()
self.assertTrue(conn.isPending())
# force ping to have expired
# connection will be closed before upstream master has time
# to answer
def _poll(orig, self, blocking):
if backup.master.em is self:
p.revert()
conn._next_timeout = 0
conn.onTimeout()
else:
orig(self, blocking)
with Patch(EventManager, _poll=_poll) as p:
self.tic()
new_conn, = backup.master.getConnectionList(backup.upstream.master)
self.assertIsNot(new_conn, conn)
@backup_test() @backup_test()
def testBackupUpstreamStorageDead(self, backup): def testBackupUpstreamStorageDead(self, backup):
upstream = backup.upstream upstream = backup.upstream
...@@ -334,7 +304,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -334,7 +304,7 @@ class ReplicationTests(NEOThreadedTest):
self.tic(check_timeout=(backup.storage,)) self.tic(check_timeout=(backup.storage,))
# 2nd failed, 3rd deferred # 2nd failed, 3rd deferred
self.assertEqual(count[0], 4) self.assertEqual(count[0], 4)
self.assertTrue(t <= time.time()) self.assertLessEqual(t, time.time())
@backup_test() @backup_test()
def testBackupDelayedUnlockTransaction(self, backup): def testBackupDelayedUnlockTransaction(self, backup):
...@@ -406,13 +376,13 @@ class ReplicationTests(NEOThreadedTest): ...@@ -406,13 +376,13 @@ class ReplicationTests(NEOThreadedTest):
s2.start() s2.start()
self.tic() self.tic()
cluster.enableStorageList([s2]) cluster.enableStorageList([s2])
# 2 UP_TO_DATE cells should become FEEDING, # 2 UP_TO_DATE cells become FEEDING:
# and be dropped only when the replication is done, # they are dropped only when the replication is done,
# so that 1 storage can still die without data loss. # so that 1 storage can still die without data loss.
with Patch(s0.dm, changePartitionTable=changePartitionTable): with Patch(s0.dm, changePartitionTable=changePartitionTable):
cluster.neoctl.tweakPartitionTable() cluster.neoctl.tweakPartitionTable()
self.tic() self.tic()
expectedFailure(self.assertEqual)(cluster.neoctl.getClusterState(), self.assertEqual(cluster.neoctl.getClusterState(),
ClusterStates.RUNNING) ClusterStates.RUNNING)
@with_cluster(start_cluster=0, partitions=3, replicas=1, storage_count=3) @with_cluster(start_cluster=0, partitions=3, replicas=1, storage_count=3)
...@@ -625,6 +595,31 @@ class ReplicationTests(NEOThreadedTest): ...@@ -625,6 +595,31 @@ class ReplicationTests(NEOThreadedTest):
with s0.dm.replicated(1): with s0.dm.replicated(1):
self.assertFalse(s0.dm.getObject(ob._p_oid, tid2)) self.assertFalse(s0.dm.getObject(ob._p_oid, tid2))
@with_cluster(start_cluster=0, storage_count=2, partitions=2)
def testDropPartitions(self, cluster, disable=False):
s0, s1 = cluster.storage_list
cluster.start(storage_list=(s0,))
t, c = cluster.getTransaction()
c.root()[''] = PCounter()
t.commit()
s1.start()
self.tic()
self.assertEqual(3, s0.sqlCount('obj'))
cluster.enableStorageList((s1,))
cluster.neoctl.tweakPartitionTable()
self.tic()
self.assertEqual(1, s1.sqlCount('obj'))
# Deletion should start as soon as the cell is discarded, as a
# background task, instead of doing it during initialization.
count = s0.sqlCount('obj')
s0.stop()
cluster.join((s0,))
s0.resetNode()
s0.start()
self.tic()
self.assertEqual(2, s0.sqlCount('obj'))
expectedFailure(self.assertEqual)(2, count)
@with_cluster(start_cluster=0, replicas=1) @with_cluster(start_cluster=0, replicas=1)
def testResumingReplication(self, cluster): def testResumingReplication(self, cluster):
if 1: if 1:
......
...@@ -34,8 +34,8 @@ class SSLMixin: ...@@ -34,8 +34,8 @@ class SSLMixin:
class SSLTests(SSLMixin, test.Test): class SSLTests(SSLMixin, test.Test):
# exclude expected failures # exclude expected failures
testDeadlockAvoidance = None # XXX why this fails? testStorageDataLock2 = None # XXX why this fails?
testUndoConflict = testUndoConflictDuringStore = None # XXX why this fails? testUndoConflictDuringStore = None # XXX why this fails?
def testAbortConnection(self, after_handshake=1): def testAbortConnection(self, after_handshake=1):
with self.getLoopbackConnection() as conn: with self.getLoopbackConnection() as conn:
......
...@@ -16,7 +16,7 @@ Topic :: Software Development :: Libraries :: Python Modules ...@@ -16,7 +16,7 @@ Topic :: Software Development :: Libraries :: Python Modules
mock = 'neo/tests/mock.py' mock = 'neo/tests/mock.py'
if not os.path.exists(mock): if not os.path.exists(mock):
import cStringIO, hashlib,subprocess, urllib, zipfile import cStringIO, hashlib, subprocess, urllib, zipfile
x = 'pythonmock-0.1.0.zip' x = 'pythonmock-0.1.0.zip'
try: try:
x = subprocess.check_output(('git', 'cat-file', 'blob', x)) x = subprocess.check_output(('git', 'cat-file', 'blob', x))
...@@ -24,8 +24,9 @@ if not os.path.exists(mock): ...@@ -24,8 +24,9 @@ if not os.path.exists(mock):
x = urllib.urlopen( x = urllib.urlopen(
'http://downloads.sf.net/sourceforge/python-mock/' + x).read() 'http://downloads.sf.net/sourceforge/python-mock/' + x).read()
mock_py = zipfile.ZipFile(cStringIO.StringIO(x)).read('mock.py') mock_py = zipfile.ZipFile(cStringIO.StringIO(x)).read('mock.py')
if hashlib.md5(mock_py).hexdigest() != '79f42f390678e5195d9ce4ae43bd18ec': if (hashlib.sha256(mock_py).hexdigest() !=
raise EnvironmentError("MD5 checksum mismatch downloading 'mock.py'") 'c6ed26e4312ed82160016637a9b6f8baa71cf31a67c555d44045a1ef1d60d1bc'):
raise EnvironmentError("SHA checksum mismatch downloading 'mock.py'")
open(mock, 'w').write(mock_py) open(mock, 'w').write(mock_py)
zodb_require = ['ZODB3>=3.10dev'] zodb_require = ['ZODB3>=3.10dev']
...@@ -59,11 +60,11 @@ else: ...@@ -59,11 +60,11 @@ else:
setup( setup(
name = 'neoppod', name = 'neoppod',
version = '1.7.1', version = '1.8',
description = __doc__.strip(), description = __doc__.strip(),
author = 'Nexedi SA', author = 'Nexedi SA',
author_email = 'neo-dev@erp5.org', author_email = 'neo-dev@erp5.org',
url = 'http://www.neoppod.org/', url = 'https://neo.nexedi.com/',
license = 'GPL 2+', license = 'GPL 2+',
platforms = ["any"], platforms = ["any"],
classifiers=classifiers.splitlines(), classifiers=classifiers.splitlines(),
......