Commit 04f72a4c authored by Julien Muchembled

New feature to check that partitions are replicated properly

This includes an API change to Node.isIdentified, which now tells whether
identification packets have been exchanged or not.
All handlers must be updated to implement '_acceptIdentification' instead of
overriding EventHandler.acceptIdentification; this patch only does so for
StorageOperationHandler.
parent 2241c3a1
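To illustrate the new contract (a minimal sketch only: ExampleHandler and the
on_peer_identified hook are hypothetical and not part of this patch), a handler
now defines _acceptIdentification and lets the base class verify the node type
and mark the node as identified:

    from neo.lib.handler import EventHandler

    class ExampleHandler(EventHandler):

        def _acceptIdentification(self, node, *args):
            # Called by EventHandler.acceptIdentification once the peer's node
            # type has been checked and node.setIdentified() has been called.
            # 'node' is the identified neo.lib.node.Node; 'args' carries the
            # remaining fields of the AcceptIdentification packet.
            self.app.on_peer_identified(node)  # hypothetical application hook

Compare with StorageOperationHandler._acceptIdentification further down in this
diff, which simply notifies the replicator and the checker.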
......@@ -133,6 +133,11 @@ RC - Review output of pylint (CODE)
be done ? hope to find a storage with valid checksum ? assume that data
is correct in storage but was altered when it travelled through network
as we loaded it ?).
- Check replicas: (HIGH AVAILABILITY)
- Automatically tell corrupted cells to fix their data when a good source
is known.
- Add an option to also check all rows of trans/obj/data, instead of only
keys (trans.tid & obj.{tid,oid}).
Master
- Master node data redundancy (HIGH AVAILABILITY)
......
......@@ -83,6 +83,7 @@ class AdminEventHandler(EventHandler):
addPendingNodes = forward_ask(Packets.AddPendingNodes)
setClusterState = forward_ask(Packets.SetClusterState)
checkReplicas = forward_ask(Packets.CheckReplicas)
class MasterEventHandler(EventHandler):
......
......@@ -15,6 +15,7 @@
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from functools import wraps
import neo.lib
from .protocol import (
NodeStates, Packets, ErrorCodes, Errors, BrokenNodeDisallowedError,
......@@ -121,6 +122,19 @@ class EventHandler(object):
# Packet handlers.
def acceptIdentification(self, conn, node_type, *args):
try:
acceptIdentification = self._acceptIdentification
except AttributeError:
raise UnexpectedPacketError('no handler found')
node = self.app.nm.getByAddress(conn.getAddress())
assert node.getConnection() is conn, (node.getConnection(), conn)
if node.getType() == node_type:
node.setIdentified()
acceptIdentification(node, *args)
return
conn.close()
def ping(self, conn):
conn.answer(Packets.Pong())
......
......@@ -37,6 +37,7 @@ class Node(object):
self._uuid = uuid
self._manager = manager
self._last_state_change = time()
self._identified = False
manager.add(self)
def notify(self, packet):
......@@ -98,6 +99,7 @@ class Node(object):
"""
assert self._connection is not None
del self._connection
self._identified = False
self._manager._updateIdentified(self)
def setConnection(self, connection, force=None):
......@@ -113,6 +115,8 @@ class Node(object):
conn = self._connection
if conn is None:
self._connection = connection
if connection.isServer():
self.setIdentified()
else:
assert force is not None, \
attributeTracker.whoSet(self, '_connection')
......@@ -127,7 +131,11 @@ class Node(object):
if not force or conn.getPeerId() is not None or \
type(conn.getHandler()) is not type(connection.getHandler()):
raise ProtocolError("already connected")
conn.setOnClose(lambda: setattr(self, '_connection', connection))
def on_closed():
self._connection = connection
assert connection.isServer()
self.setIdentified()
conn.setOnClose(on_closed)
conn.close()
assert not connection.isClosed(), connection
connection.setOnClose(self.onConnectionClosed)
......@@ -147,11 +155,15 @@ class Node(object):
return self._connection is not None and (connecting or
not self._connection.connecting)
def setIdentified(self):
assert self._connection is not None
self._identified = True
def isIdentified(self):
"""
Returns True is the node is connected and identified
Returns True if identification packets have been exchanged
"""
return self._connection is not None and self._uuid is not None
return self._identified
def __repr__(self):
return '<%s(uuid=%s, address=%s, state=%s, connection=%r) at %x>' % (
......@@ -396,7 +408,10 @@ class NodeManager(object):
def _updateIdentified(self, node):
uuid = node.getUUID()
if node.isIdentified():
if uuid:
# XXX: It's probably a bug to include connecting nodes but there's
# no API yet to update manager when connection is established.
if node.isConnected(connecting=True):
self._identified_dict[uuid] = node
else:
self._identified_dict.pop(uuid, None)
......
......@@ -25,7 +25,7 @@ from struct import Struct
from .util import Enum, getAddressType
# The protocol version (major, minor).
PROTOCOL_VERSION = (6, 1)
PROTOCOL_VERSION = (7, 1)
# Size restrictions.
MIN_PACKET_SIZE = 10
......@@ -49,6 +49,7 @@ class ErrorCodes(Enum):
BROKEN_NODE = Enum.Item(5)
ALREADY_PENDING = Enum.Item(7)
REPLICATION_ERROR = Enum.Item(8)
CHECKING_ERROR = Enum.Item(9)
ErrorCodes = ErrorCodes()
class ClusterStates(Enum):
......@@ -83,6 +84,7 @@ class CellStates(Enum):
OUT_OF_DATE = Enum.Item(2)
FEEDING = Enum.Item(3)
DISCARDED = Enum.Item(4)
CORRUPTED = Enum.Item(5)
CellStates = CellStates()
class LockState(Enum):
......@@ -108,6 +110,7 @@ cell_state_prefix_dict = {
CellStates.OUT_OF_DATE: 'O',
CellStates.FEEDING: 'F',
CellStates.DISCARDED: 'D',
CellStates.CORRUPTED: 'C',
}
# Other constants.
......@@ -1239,6 +1242,35 @@ class Pack(Packet):
PBoolean('status'),
)
class CheckReplicas(Packet):
"""
ctl -> A
A -> M
"""
_fmt = PStruct('check_replicas',
PDict('partition_dict',
PNumber('partition'),
PUUID('source'),
),
PTID('min_tid'),
PTID('max_tid'),
)
_answer = Error
class CheckPartition(Packet):
"""
M -> S
"""
_fmt = PStruct('check_partition',
PNumber('partition'),
PStruct('source',
PString('upstream_name'),
PAddress('address'),
),
PTID('min_tid'),
PTID('max_tid'),
)
class CheckTIDRange(Packet):
"""
Ask some stats about a range of transactions.
......@@ -1251,15 +1283,13 @@ class CheckTIDRange(Packet):
S -> S
"""
_fmt = PStruct('ask_check_tid_range',
PNumber('partition'),
PNumber('length'),
PTID('min_tid'),
PTID('max_tid'),
PNumber('length'),
PNumber('partition'),
)
_answer = PStruct('answer_check_tid_range',
PTID('min_tid'),
PNumber('length'),
PNumber('count'),
PChecksum('checksum'),
PTID('max_tid'),
......@@ -1277,22 +1307,30 @@ class CheckSerialRange(Packet):
S -> S
"""
_fmt = PStruct('ask_check_serial_range',
POID('min_oid'),
PTID('min_serial'),
PTID('max_tid'),
PNumber('length'),
PNumber('partition'),
PNumber('length'),
PTID('min_tid'),
PTID('max_tid'),
POID('min_oid'),
)
_answer = PStruct('answer_check_serial_range',
POID('min_oid'),
PTID('min_serial'),
PNumber('length'),
PNumber('count'),
PChecksum('tid_checksum'),
PTID('max_tid'),
PChecksum('oid_checksum'),
POID('max_oid'),
PChecksum('serial_checksum'),
PTID('max_serial'),
)
class PartitionCorrupted(Packet):
"""
S -> M
"""
_fmt = PStruct('partition_corrupted',
PNumber('partition'),
PList('cell_list',
PUUID('uuid'),
),
)
class LastTransaction(Packet):
......@@ -1601,10 +1639,16 @@ class Packets(dict):
TIDListFrom)
AskPack, AnswerPack = register(
Pack, ignore_when_closed=False)
CheckReplicas = register(
CheckReplicas)
CheckPartition = register(
CheckPartition)
AskCheckTIDRange, AnswerCheckTIDRange = register(
CheckTIDRange)
AskCheckSerialRange, AnswerCheckSerialRange = register(
CheckSerialRange)
NotifyPartitionCorrupted = register(
PartitionCorrupted)
NotifyReady = register(
NotifyReady)
AskLastTransaction, AnswerLastTransaction = register(
......
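For reference, the message flow implied by the packet docstrings above, written
as packet constructions (a sketch only: the UUID, address and TID values are
placeholders, and only construction is shown, not sending):

    from neo.lib.protocol import Packets, ZERO_TID, MAX_TID

    source_uuid = '\x00' * 16               # placeholder storage UUID
    bad_uuid = '\x01' * 16                  # placeholder corrupted-cell UUID
    source_address = ('127.0.0.1', 20000)   # placeholder (host, port)

    # ctl -> A -> M: ask to check partition 0 against the cell of source_uuid
    # (None instead of a UUID lets the master pick a source itself).
    check_replicas = Packets.CheckReplicas({0: source_uuid}, ZERO_TID, MAX_TID)

    # M -> S: tell one storage to check partition 0; an empty upstream name
    # means the reference cell belongs to the same cluster.
    check_partition = Packets.CheckPartition(
        0, ('', source_address), ZERO_TID, MAX_TID)

    # S -> M: report the cells whose digests did not match the reference.
    corrupted = Packets.NotifyPartitionCorrupted(0, [bad_uuid])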
......@@ -34,7 +34,7 @@ class Cell(object):
def __init__(self, node, state = CellStates.UP_TO_DATE):
self.node = node
self.setState(state)
self.state = state
def __repr__(self):
return "<Cell(uuid=%s, address=%s, state=%s)>" % (
......@@ -59,6 +59,13 @@ class Cell(object):
def isFeeding(self):
return self.state == CellStates.FEEDING
def isCorrupted(self):
return self.state == CellStates.CORRUPTED
def isReadable(self):
return self.state == CellStates.UP_TO_DATE or \
self.state == CellStates.FEEDING
def getNode(self):
return self.node
......@@ -122,6 +129,12 @@ class PartitionTable(object):
except IndexError:
return False
def getNodeSet(self):
return set(x.getNode() for row in self.partition_list for x in row)
def getConnectedNodeList(self):
return [node for node in self.getNodeSet() if node.isConnected()]
def getNodeList(self):
"""Return all used nodes."""
return [node for node, count in self.count_dict.iteritems() \
......@@ -129,8 +142,7 @@ class PartitionTable(object):
def getCellList(self, offset, readable=False):
if readable:
return [cell for cell in self.partition_list[offset]
if not cell.isOutOfDate()]
return filter(Cell.isReadable, self.partition_list[offset])
return list(self.partition_list[offset])
def getPartition(self, oid_or_tid):
......@@ -280,7 +292,7 @@ class PartitionTable(object):
return False
for row in self.partition_list:
for cell in row:
if not cell.isOutOfDate() and cell.getNode().isRunning():
if cell.isReadable() and cell.getNode().isRunning():
break
else:
return False
......
......@@ -279,7 +279,7 @@ class BackupApplication(object):
primary = primary_node is node
result = None if primary else app.pt.setUpToDate(node, offset)
if app.getClusterState() == ClusterStates.BACKINGUP:
assert not cell.isOutOfDate()
assert cell.isReadable()
if result: # was out-of-date
max_tid, = [x.backup_tid for x in cell_list
if x.getNode() is primary_node]
......
......@@ -15,6 +15,7 @@
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import random
import neo
from . import MasterHandler
......@@ -162,3 +163,48 @@ class AdministrationHandler(MasterHandler):
# broadcast the new partition table
app.broadcastPartitionChanges(cell_list)
conn.answer(Errors.Ack('Nodes added: %s' % (uuids, )))
def checkReplicas(self, conn, partition_dict, min_tid, max_tid):
app = self.app
pt = app.pt
backingup = app.cluster_state == ClusterStates.BACKINGUP
if not max_tid:
max_tid = pt.getCheckTid(partition_dict) if backingup else \
app.getLastTransaction()
if min_tid > max_tid:
neo.lib.logging.warning("nothing to check: min_tid=%s > max_tid=%s",
dump(min_tid), dump(max_tid))
else:
getByUUID = app.nm.getByUUID
node_set = set()
for offset, source in partition_dict.iteritems():
# XXX: For the moment, code checking replicas is unable to fix
# corrupted partitions (when a good cell is known)
# so only check readable ones.
# (see also Checker._nextPartition of storage)
cell_list = pt.getCellList(offset, True)
#cell_list = [cell for cell in pt.getCellList(offset)
# if not cell.isOutOfDate()]
if len(cell_list) + (backingup and not source) <= 1:
continue
for cell in cell_list:
node = cell.getNode()
if node in node_set:
break
else:
node_set.add(node)
if source:
source = '', getByUUID(source).getAddress()
else:
readable = [cell for cell in cell_list if cell.isReadable()]
if 1 == len(readable) < len(cell_list):
source = '', readable[0].getAddress()
elif backingup:
source = app.backup_app.name, random.choice(
app.backup_app.pt.getCellList(offset, readable=True)
).getAddress()
else:
source = '', None
node.getConnection().notify(Packets.CheckPartition(
offset, source, min_tid, max_tid))
conn.answer(Errors.Ack(''))
......@@ -16,7 +16,7 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import neo.lib
from neo.lib.protocol import ClusterStates, Packets, ProtocolError
from neo.lib.protocol import CellStates, ClusterStates, Packets, ProtocolError
from neo.lib.exception import OperationFailure
from neo.lib.util import dump
from neo.lib.connector import ConnectorConnectionClosedException
......@@ -72,6 +72,17 @@ class StorageServiceHandler(BaseServiceHandler):
# transaction locked on this storage node
self.app.tm.lock(ttid, conn.getUUID())
def notifyPartitionCorrupted(self, conn, partition, cell_list):
change_list = []
for cell in self.app.pt.getCellList(partition):
if cell.getUUID() in cell_list:
cell.setState(CellStates.CORRUPTED)
change_list.append((partition, cell.getUUID(),
CellStates.CORRUPTED))
self.app.broadcastPartitionChanges(change_list)
if not self.app.pt.operational():
raise OperationFailure('cannot continue operation')
def notifyReplicationDone(self, conn, offset, tid):
app = self.app
node = app.nm.getByUUID(conn.getUUID())
......
......@@ -25,12 +25,13 @@ class Cell(neo.lib.pt.Cell):
replicating = ZERO_TID
def setState(self, state):
readable = self.isReadable()
super(Cell, self).setState(state)
if readable and not self.isReadable():
try:
if CellStates.OUT_OF_DATE == state != self.state:
del self.backup_tid, self.replicating
except AttributeError:
pass
return super(Cell, self).setState(state)
neo.lib.pt.Cell = Cell
......@@ -196,7 +197,7 @@ class PartitionTable(neo.lib.pt.PartitionTable):
CellStates.OUT_OF_DATE))
node_count += 1
elif node_count + 1 < max_count:
if feeding_cell is not None or max_cell.isOutOfDate():
if feeding_cell is not None or not max_cell.isReadable():
# If there is a feeding cell already or it is
# out-of-date, just drop the node.
row.remove(max_cell)
......@@ -239,10 +240,10 @@ class PartitionTable(neo.lib.pt.PartitionTable):
else:
# Remove an excessive feeding cell.
removed_cell_list.append(cell)
elif cell.isOutOfDate():
out_of_date_cell_list.append(cell)
else:
elif cell.isUpToDate():
up_to_date_cell_list.append(cell)
else:
out_of_date_cell_list.append(cell)
# If all cells are up-to-date, a feeding cell is not required.
if len(out_of_date_cell_list) == 0 and feeding_cell is not None:
......@@ -311,7 +312,7 @@ class PartitionTable(neo.lib.pt.PartitionTable):
lost = lost_node
cell_list = []
for cell in row:
if not cell.isOutOfDate():
if cell.isReadable():
if cell.getNode().isRunning():
lost = None
else :
......@@ -330,7 +331,7 @@ class PartitionTable(neo.lib.pt.PartitionTable):
yield offset, cell
break
def getUpToDateCellNodeSet(self):
def getReadableCellNodeSet(self):
"""
Return a set of all nodes which are part of at least one readable
cell.
......@@ -338,17 +339,7 @@ class PartitionTable(neo.lib.pt.PartitionTable):
return set(cell.getNode()
for row in self.partition_list
for cell in row
if not cell.isOutOfDate())
def getOutOfDateCellNodeSet(self):
"""
Return a set of all nodes which are part of at least one OUT OF DATE
partition.
"""
return set(cell.getNode()
for row in self.partition_list
for cell in row
if cell.isOutOfDate())
if cell.isReadable())
def setBackupTidDict(self, backup_tid_dict):
for row in self.partition_list:
......@@ -358,8 +349,16 @@ class PartitionTable(neo.lib.pt.PartitionTable):
def getBackupTid(self):
try:
return min(max(cell.backup_tid for cell in row
if not cell.isOutOfDate())
return min(max(cell.backup_tid for cell in row if cell.isReadable())
for row in self.partition_list)
except ValueError:
return ZERO_TID
def getCheckTid(self, partition_list):
try:
return min(min(cell.backup_tid
for cell in self.partition_list[offset]
if cell.isReadable())
for offset in partition_list)
except ValueError:
return ZERO_TID
......@@ -65,39 +65,39 @@ class RecoveryManager(MasterHandler):
if pt.filled():
# A partition table exists, we are starting an existing
# cluster.
partition_node_set = pt.getUpToDateCellNodeSet()
partition_node_set = pt.getReadableCellNodeSet()
pending_node_set = set(x for x in partition_node_set
if x.isPending())
if app._startup_allowed or \
partition_node_set == pending_node_set:
allowed_node_set = pending_node_set
extra_node_set = pt.getOutOfDateCellNodeSet()
node_list = pt.getConnectedNodeList
elif app._startup_allowed:
# No partition table and admin allowed startup, we are
# creating a new cluster out of all pending nodes.
allowed_node_set = set(app.nm.getStorageList(
only_identified=True))
extra_node_set = set()
node_list = lambda: allowed_node_set
if allowed_node_set:
for node in allowed_node_set:
assert node.isPending(), node
if node.getConnection().isPending():
break
else:
allowed_node_set |= extra_node_set
node_list = node_list()
break
neo.lib.logging.info('startup allowed')
for node in allowed_node_set:
for node in node_list:
node.setRunning()
app.broadcastNodesInformation(allowed_node_set)
app.broadcastNodesInformation(node_list)
if pt.getID() is None:
neo.lib.logging.info('creating a new partition table')
# reset IDs generators & build new partition with running nodes
app.tm.setLastOID(ZERO_OID)
pt.make(allowed_node_set)
pt.make(node_list)
self._broadcastPartitionTable(pt.getID(), pt.getRowList())
elif app.backup_tid:
pt.setBackupTidDict(self.backup_tid_dict)
......
......@@ -17,7 +17,7 @@
from .neoctl import NeoCTL, NotReadyException
from neo.lib.util import bin, dump
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, ZERO_TID
action_dict = {
'print': {
......@@ -30,6 +30,7 @@ action_dict = {
'node': 'setNodeState',
'cluster': 'setClusterState',
},
'check': 'checkReplicas',
'start': 'startCluster',
'add': 'enableStorageList',
'drop': 'dropNode',
......@@ -187,6 +188,33 @@ class TerminalNeoCTL(object):
"""
return self.formatUUID(self.neoctl.getPrimary())
def checkReplicas(self, params):
"""
Parameters: [partition]:[reference] ... [min_tid [max_tid]]
"""
partition_dict = {}
params = iter(params)
min_tid = ZERO_TID
max_tid = None
for p in params:
try:
partition, source = p.split(':')
except ValueError:
min_tid = p64(p)
try:
max_tid = p64(params.next())
except StopIteration:
pass
break
source = bin(source) if source else None
if partition:
partition_dict[int(partition)] = source
else:
assert not partition_dict
np = len(self.neoctl.getPartitionRowList()[1])
partition_dict = dict.fromkeys(xrange(np), source)
self.neoctl.checkReplicas(partition_dict, min_tid, max_tid)
class Application(object):
"""The storage node application."""
......
......@@ -163,3 +163,8 @@ class NeoCTL(object):
raise RuntimeError(response)
return response[1]
def checkReplicas(self, *args):
response = self.__ask(Packets.CheckReplicas(*args))
if response[0] != Packets.Error or response[1] != ErrorCodes.ACK:
raise RuntimeError(response)
return response[2]
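A hedged usage sketch of the new client-side call (assuming neoctl is an
already-connected NeoCTL instance; the partition numbers and UUID are
placeholders), matching how the functional tests below drive it:

    from neo.lib.protocol import ZERO_TID

    source_uuid = '\x00' * 16  # placeholder UUID of a known-good storage node

    # Check partitions 0 and 1 over the whole history; passing None as max_tid
    # lets the master use the last committed transaction (or the backup TID
    # when the cluster is backing up).
    neoctl.checkReplicas({0: None, 1: None}, ZERO_TID, None)

    # Same, but compare partition 0 against the cell of the given storage node.
    neoctl.checkReplicas({0: source_uuid}, ZERO_TID, None)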
......@@ -28,6 +28,7 @@ from neo.lib.connector import getConnectorHandler
from neo.lib.pt import PartitionTable
from neo.lib.util import dump
from neo.lib.bootstrap import BootstrapManager
from .checker import Checker
from .database import buildDatabaseManager
from .exception import AlreadyPendingError
from .handlers import identification, verification, initialization
......@@ -66,6 +67,7 @@ class Application(object):
# partitions.
self.pt = None
self.checker = Checker(self)
self.replicator = Replicator(self)
self.listening_conn = None
self.master_conn = None
......@@ -207,6 +209,8 @@ class Application(object):
neo.lib.logging.error('operation stopped: %s', msg)
except PrimaryFailure, msg:
neo.lib.logging.error('primary master is down: %s', msg)
finally:
self.checker = Checker(self)
def connectToPrimary(self):
"""Find a primary master node, and connect to it.
......@@ -369,6 +373,11 @@ class Application(object):
return
self.task_queue.appendleft(iterator)
def closeClient(self, connection):
if connection is not self.replicator.getCurrentConnection() and \
connection not in self.checker.conn_dict:
connection.closeClient()
def shutdown(self, erase=False):
"""Close all connections and exit"""
for c in self.em.getConnectionList():
......
##############################################################################
#
# Copyright (c) 2011 Nexedi SARL and Contributors. All Rights Reserved.
# Julien Muchembled <jm@nexedi.com>
#
# WARNING: This program as such is intended to be used by professional
# programmers who take the whole responsibility of assessing all potential
# consequences resulting from its eventual inadequacies and bugs
# End users who are looking for a ready-to-use solution with commercial
# guarantees and support are strongly advised to contract a Free Software
# Service Company
#
# This program is Free Software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
##############################################################################
from collections import deque
from functools import wraps
import neo.lib
from neo.lib.connection import ClientConnection
from neo.lib.connector import ConnectorConnectionClosedException
from neo.lib.protocol import NodeTypes, Packets, ZERO_OID
from neo.lib.util import add64, dump
from .handlers.storage import StorageOperationHandler
CHECK_COUNT = 4000
class Checker(object):
def __init__(self, app):
self.app = app
self.queue = deque()
self.conn_dict = {}
def __call__(self, partition, source, min_tid, max_tid):
self.queue.append((partition, source, min_tid, max_tid))
if not self.conn_dict:
self._nextPartition()
def _nextPartition(self):
app = self.app
def connect(node, uuid=app.uuid, name=app.name):
if node.getUUID() == app.uuid:
return
if node.isConnected(connecting=True):
conn = node.getConnection()
conn.asClient()
else:
conn = ClientConnection(app.em, StorageOperationHandler(app),
node=node, connector=app.connector_handler())
conn.ask(Packets.RequestIdentification(
NodeTypes.STORAGE, uuid, app.server, name))
self.conn_dict[conn] = node.isIdentified()
conn_set = set(self.conn_dict)
conn_set.discard(None)
try:
self.conn_dict.clear()
while True:
try:
partition, (name, source), min_tid, max_tid = \
self.queue.popleft()
except IndexError:
return
cell = app.pt.getCell(partition, app.uuid)
if cell is None or cell.isOutOfDate():
msg = "discarded or out-of-date"
else:
try:
for cell in app.pt.getCellList(partition):
# XXX: Ignore corrupted cells for the moment
# because we're still unable to fix them
# (see also AdministrationHandler of master)
if cell.isReadable(): #if not cell.isOutOfDate():
connect(cell.getNode())
if source:
node = app.nm.getByAddress(source)
if name:
source = app.nm.createStorage(address=source) \
if node is None else node
connect(source, None, name)
elif (node.getUUID() == app.uuid or
node.isConnected(connecting=True) and
node.getConnection() in self.conn_dict):
source = node
else:
msg = "unavailable source"
if self.conn_dict:
break
msg = "no replica"
except ConnectorConnectionClosedException:
msg = "connection closed"
finally:
conn_set.update(self.conn_dict)
self.conn_dict.clear()
neo.lib.logging.error(
"Failed to start checking partition %u (%s)",
partition, msg)
conn_set.difference_update(self.conn_dict)
finally:
for conn in conn_set:
app.closeClient(conn)
neo.lib.logging.debug("start checking partition %u from %s to %s",
partition, dump(min_tid), dump(max_tid))
self.min_tid = self.next_tid = min_tid
self.max_tid = max_tid
self.next_oid = None
self.partition = partition
self.source = source
args = partition, CHECK_COUNT, min_tid, max_tid
p = Packets.AskCheckTIDRange(*args)
for conn, identified in self.conn_dict.items():
self.conn_dict[conn] = conn.ask(p) if identified else None
self.conn_dict[None] = app.dm.checkTIDRange(*args)
def connected(self, node):
conn = node.getConnection()
if self.conn_dict.get(conn, self) is None:
self.conn_dict[conn] = conn.ask(Packets.AskCheckTIDRange(
self.partition, CHECK_COUNT, self.next_tid, self.max_tid))
def connectionLost(self, conn):
try:
del self.conn_dict[conn]
except KeyError:
return
if self.source is not None and self.source.getConnection() is conn:
del self.source
elif len(self.conn_dict) > 1:
neo.lib.logging.warning("node lost but keep up checking partition"
" %u", self.partition)
return
neo.lib.logging.warning("check of partition %u aborted", self.partition)
self._nextPartition()
def _nextRange(self):
if self.next_oid:
args = self.partition, CHECK_COUNT, self.next_tid, self.max_tid, \
self.next_oid
p = Packets.AskCheckSerialRange(*args)
check = self.app.dm.checkSerialRange
else:
args = self.partition, CHECK_COUNT, self.next_tid, self.max_tid
p = Packets.AskCheckTIDRange(*args)
check = self.app.dm.checkTIDRange
for conn in self.conn_dict.keys():
self.conn_dict[conn] = check(*args) if conn is None else conn.ask(p)
def checkRange(self, conn, *args):
if self.conn_dict.get(conn, self) != conn.getPeerId():
# Ignore answers to old requests,
# because we did nothing to cancel them.
neo.lib.logging.info("ignored AnswerCheck*Range%r", args)
return
self.conn_dict[conn] = args
answer_set = set(self.conn_dict.itervalues())
if len(answer_set) > 1:
for answer in answer_set:
if type(answer) is not tuple:
return
# TODO: Automatically tell corrupted cells to fix their data
# if we know a good source.
# For the moment, tell master to put them in CORRUPTED state
# and keep up checking if useful.
uuid = self.app.uuid
args = None if self.source is None else self.conn_dict[
None if self.source.getUUID() == uuid
else self.source.getConnection()]
uuid_list = []
for conn, answer in self.conn_dict.items():
if answer != args:
del self.conn_dict[conn]
if conn is None:
uuid_list.append(uuid)
else:
uuid_list.append(conn.getUUID())
self.app.closeClient(conn)
p = Packets.NotifyPartitionCorrupted(self.partition, uuid_list)
self.app.master_conn.notify(p)
if len(self.conn_dict) <= 1:
neo.lib.logging.warning("check of partition %u aborted",
self.partition)
self.queue.clear()
self._nextPartition()
return
try:
count, _, max_tid = args
except ValueError:
count, _, self.next_tid, _, max_oid = args
if count < CHECK_COUNT:
neo.lib.logging.debug("partition %u checked from %s to %s",
self.partition, dump(self.min_tid), dump(self.max_tid))
self._nextPartition()
return
self.next_oid = add64(max_oid, 1)
else:
(count, _, max_tid), = answer_set
if count < CHECK_COUNT:
self.next_tid = self.min_tid
self.next_oid = ZERO_OID
else:
self.next_tid = add64(max_tid, 1)
self._nextRange()
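The comparison at the heart of checkRange, reduced to a standalone sketch
(find_corrupted is illustrative, not code from this patch): every queried cell
plus the local database answers a (count, checksum, max) tuple for the same key
range; any cell whose tuple differs from the reference is reported to the
master as corrupted, and when no trusted source was given a disagreement cannot
be pinned on one cell, so all disagreeing cells are reported.

    def find_corrupted(answer_dict, reference_key=None):
        # answer_dict maps a cell identifier to the (count, checksum, max_tid)
        # tuple it answered for the checked range; reference_key identifies
        # the trusted source, or is None when no source is known.
        if len(set(answer_dict.itervalues())) <= 1:
            return []  # all digests agree, nothing to report
        reference = answer_dict.get(reference_key)
        return [cell for cell, answer in answer_dict.iteritems()
                if answer != reference]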
......@@ -532,7 +532,7 @@ class DatabaseManager(object):
"""
raise NotImplementedError
def checkTIDRange(self, min_tid, max_tid, length, partition):
def checkTIDRange(self, partition, length, min_tid, max_tid):
"""
Generate a digest from the transaction list.
min_tid (packed)
......@@ -549,12 +549,12 @@ class DatabaseManager(object):
"""
raise NotImplementedError
def checkSerialRange(self, min_oid, min_serial, max_tid, length, partition):
def checkSerialRange(self, partition, length, min_tid, max_tid, min_oid):
"""
Generate a digest from the object list.
min_oid (packed)
OID at which verification starts.
min_serial (packed)
min_tid (packed)
Serial of min_oid object at which search should start.
length
Maximum number of records to include in result.
......
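With the new argument order, both methods take the partition first, mirroring
the reordered Ask packets. A sketch of typical calls (dm stands for any
DatabaseManager instance and the values are placeholders), consistent with the
MySQL and SQLite implementations below:

    from neo.lib.protocol import ZERO_TID, ZERO_OID, MAX_TID

    # Digest up to 4000 transactions of partition 0:
    count, tid_checksum, max_tid = dm.checkTIDRange(0, 4000, ZERO_TID, MAX_TID)

    # Digest up to 4000 object records of the same partition, starting from
    # the very first (tid, oid) key:
    (count, tid_checksum, max_tid,
     oid_checksum, max_oid) = dm.checkSerialRange(
        0, 4000, ZERO_TID, MAX_TID, ZERO_OID)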
......@@ -702,7 +702,7 @@ class MySQLDatabaseManager(DatabaseManager):
data_id_set.discard(None)
self._pruneData(data_id_set)
def checkTIDRange(self, min_tid, max_tid, length, partition):
def checkTIDRange(self, partition, length, min_tid, max_tid):
count, tid_checksum, max_tid = self.query(
"""SELECT COUNT(*), SHA1(GROUP_CONCAT(tid SEPARATOR ",")), MAX(tid)
FROM (SELECT tid FROM trans
......@@ -713,30 +713,30 @@ class MySQLDatabaseManager(DatabaseManager):
'partition': partition,
'min_tid': util.u64(min_tid),
'max_tid': util.u64(max_tid),
'limit': '' if length is None else 'LIMIT %(length)d' % length,
'limit': '' if length is None else 'LIMIT %u' % length,
})[0]
if count:
return count, a2b_hex(tid_checksum), util.p64(max_tid)
return 0, ZERO_HASH, ZERO_TID
def checkSerialRange(self, min_oid, min_serial, max_tid, length, partition):
def checkSerialRange(self, partition, length, min_tid, max_tid, min_oid):
u64 = util.u64
# We don't ask MySQL to compute everything (like in checkTIDRange)
# because it's difficult to get the last serial _for the last oid_.
# We would need a function (that could be named 'LAST') that returns the
# last grouped value, instead of the greatest one.
r = self.query(
"""SELECT oid, tid
"""SELECT tid, oid
FROM obj
WHERE partition = %(partition)s
AND tid <= %(max_tid)d
AND (oid > %(min_oid)d OR
oid = %(min_oid)d AND tid >= %(min_tid)d)
ORDER BY oid ASC, tid ASC %(limit)s""" % {
AND (tid > %(min_tid)d OR
tid = %(min_tid)d AND oid >= %(min_oid)d)
ORDER BY tid, oid %(limit)s""" % {
'min_oid': u64(min_oid),
'min_tid': u64(min_serial),
'min_tid': u64(min_tid),
'max_tid': u64(max_tid),
'limit': '' if length is None else 'LIMIT %(length)d' % length,
'limit': '' if length is None else 'LIMIT %u' % length,
'partition': partition,
})
if r:
......@@ -746,4 +746,4 @@ class MySQLDatabaseManager(DatabaseManager):
p64(r[-1][0]),
sha1(','.join(str(x[1]) for x in r)).digest(),
p64(r[-1][1]))
return 0, ZERO_HASH, ZERO_OID, ZERO_HASH, ZERO_TID
return 0, ZERO_HASH, ZERO_TID, ZERO_HASH, ZERO_OID
......@@ -595,7 +595,7 @@ class SQLiteDatabaseManager(DatabaseManager):
data_id_set.discard(None)
self._pruneData(data_id_set)
def checkTIDRange(self, min_tid, max_tid, length, partition):
def checkTIDRange(self, partition, length, min_tid, max_tid):
count, tids, max_tid = self.query("""\
SELECT COUNT(*), GROUP_CONCAT(tid), MAX(tid)
FROM (SELECT tid FROM trans
......@@ -607,20 +607,19 @@ class SQLiteDatabaseManager(DatabaseManager):
return count, sha1(tids).digest(), util.p64(max_tid)
return 0, ZERO_HASH, ZERO_TID
def checkSerialRange(self, min_oid, min_serial, max_tid, length, partition):
def checkSerialRange(self, partition, length, min_tid, max_tid, min_oid):
u64 = util.u64
# We don't ask SQLite to compute everything (like in checkTIDRange)
# because it's difficult to get the last serial _for the last oid_.
# We would need a function (that could be named 'LAST') that returns the
# last grouped value, instead of the greatest one.
min_oid = u64(min_oid)
min_tid = u64(min_tid)
r = self.query("""\
SELECT oid, tid
SELECT tid, oid
FROM obj
WHERE partition=? AND tid<=?
AND (oid>? OR oid=? AND tid>=?)
ORDER BY oid ASC, tid ASC LIMIT ?""",
(partition, u64(max_tid), min_oid, min_oid, u64(min_serial),
WHERE partition=? AND tid<=? AND (tid>? OR tid=? AND oid>=?)
ORDER BY tid, oid LIMIT ?""",
(partition, u64(max_tid), min_tid, min_tid, u64(min_oid),
-1 if length is None else length)).fetchall()
if r:
p64 = util.p64
......@@ -629,4 +628,4 @@ class SQLiteDatabaseManager(DatabaseManager):
p64(r[-1][0]),
sha1(','.join(str(x[1]) for x in r)).digest(),
p64(r[-1][1]))
return 0, ZERO_HASH, ZERO_OID, ZERO_HASH, ZERO_TID
return 0, ZERO_HASH, ZERO_TID, ZERO_HASH, ZERO_OID
......@@ -72,3 +72,6 @@ class MasterOperationHandler(BaseMasterHandler):
def askTruncate(self, conn, tid):
self.app.dm.truncate(tid)
conn.answer(Packets.AnswerTruncate())
def checkPartition(self, conn, *args):
self.app.checker(*args)
......@@ -25,26 +25,42 @@ from neo.lib.protocol import Errors, NodeStates, Packets, \
from neo.lib.util import add64
def checkConnectionIsReplicatorConnection(func):
def decorator(self, conn, *args, **kw):
def wrapper(self, conn, *args, **kw):
assert self.app.replicator.getCurrentConnection() is conn
return func(self, conn, *args, **kw)
return wraps(func)(decorator)
return wraps(func)(wrapper)
def checkFeedingConnection(check):
def decorator(func):
def wrapper(self, conn, partition, *args, **kw):
app = self.app
cell = app.pt.getCell(partition, app.uuid)
if cell is None or (cell.isOutOfDate() if check else
not cell.isReadable()):
p = Errors.CheckingError if check else Errors.ReplicationError
return conn.answer(p("partition %u not readable" % partition))
conn.asServer()
return func(self, conn, partition, *args, **kw)
return wraps(func)(wrapper)
return decorator
class StorageOperationHandler(EventHandler):
"""This class handles events for replications."""
def connectionLost(self, conn, new_state):
if self.app.listening_conn and conn.isClient():
app = self.app
if app.listening_conn and conn.isClient():
# XXX: Connection and Node should be merged.
uuid = conn.getUUID()
if uuid:
node = self.app.nm.getByUUID(uuid)
node = app.nm.getByUUID(uuid)
else:
node = self.app.nm.getByAddress(conn.getAddress())
node = app.nm.getByAddress(conn.getAddress())
node.setState(NodeStates.DOWN)
replicator = self.app.replicator
replicator = app.replicator
if replicator.current_node is node:
replicator.abort()
app.checker.connectionLost(conn)
# Client
......@@ -52,10 +68,9 @@ class StorageOperationHandler(EventHandler):
if self.app.listening_conn:
self.app.replicator.abort()
@checkConnectionIsReplicatorConnection
def acceptIdentification(self, conn, node_type,
uuid, num_partitions, num_replicas, your_uuid):
self.app.replicator.fetchTransactions()
def _acceptIdentification(self, node, *args):
self.app.replicator.connected(node)
self.app.checker.connected(node)
@checkConnectionIsReplicatorConnection
def answerFetchTransactions(self, conn, pack_tid, next_tid, tid_list):
......@@ -105,33 +120,53 @@ class StorageOperationHandler(EventHandler):
def replicationError(self, conn, message):
self.app.replicator.abort('source message: ' + message)
def checkingError(self, conn, message):
try:
self.app.checker.connectionLost(conn)
finally:
self.app.closeClient(conn)
@property
def answerCheckTIDRange(self):
return self.app.checker.checkRange
@property
def answerCheckSerialRange(self):
return self.app.checker.checkRange
# Server (all methods must set connection as server so that it isn't closed
# if client tasks are finished)
def askCheckTIDRange(self, conn, min_tid, max_tid, length, partition):
conn.asServer()
count, tid_checksum, max_tid = self.app.dm.checkTIDRange(min_tid,
max_tid, length, partition)
conn.answer(Packets.AnswerCheckTIDRange(min_tid, length,
count, tid_checksum, max_tid))
@checkFeedingConnection(check=True)
def askCheckTIDRange(self, conn, *args):
msg_id = conn.getPeerId()
conn = weakref.proxy(conn)
def check():
r = self.app.dm.checkTIDRange(*args)
try:
conn.answer(Packets.AnswerCheckTIDRange(*r), msg_id)
except (weakref.ReferenceError, ConnectorConnectionClosedException):
pass
yield
self.app.newTask(check())
def askCheckSerialRange(self, conn, min_oid, min_serial, max_tid, length,
partition):
conn.asServer()
count, oid_checksum, max_oid, serial_checksum, max_serial = \
self.app.dm.checkSerialRange(min_oid, min_serial, max_tid, length,
partition)
conn.answer(Packets.AnswerCheckSerialRange(min_oid, min_serial, length,
count, oid_checksum, max_oid, serial_checksum, max_serial))
@checkFeedingConnection(check=True)
def askCheckSerialRange(self, conn, *args):
msg_id = conn.getPeerId()
conn = weakref.proxy(conn)
def check():
r = self.app.dm.checkSerialRange(*args)
try:
conn.answer(Packets.AnswerCheckSerialRange(*r), msg_id)
except (weakref.ReferenceError, ConnectorConnectionClosedException):
pass
yield
self.app.newTask(check())
@checkFeedingConnection(check=False)
def askFetchTransactions(self, conn, partition, length, min_tid, max_tid,
tid_list):
app = self.app
cell = app.pt.getCell(partition, app.uuid)
if cell is None or cell.isOutOfDate():
return conn.answer(Errors.ReplicationError(
"partition %u not readable" % partition))
conn.asServer()
msg_id = conn.getPeerId()
conn = weakref.proxy(conn)
peer_tid_set = set(tid_list)
......@@ -162,14 +197,10 @@ class StorageOperationHandler(EventHandler):
pass
app.newTask(push())
@checkFeedingConnection(check=False)
def askFetchObjects(self, conn, partition, length, min_tid, max_tid,
min_oid, object_dict):
app = self.app
cell = app.pt.getCell(partition, app.uuid)
if cell is None or cell.isOutOfDate():
return conn.answer(Errors.ReplicationError(
"partition %u not readable" % partition))
conn.asServer()
msg_id = conn.getPeerId()
conn = weakref.proxy(conn)
dm = app.dm
......
......@@ -132,7 +132,7 @@ class Replicator(object):
outdated_list = []
for offset in xrange(pt.getPartitions()):
for cell in pt.getCellList(offset):
if cell.getUUID() == uuid:
if cell.getUUID() == uuid and not cell.isCorrupted():
self.partition_dict[offset] = p = Partition()
if cell.isOutOfDate():
outdated_list.append(offset)
......@@ -154,17 +154,25 @@ class Replicator(object):
abort = False
added_list = []
app = self.app
last_tid, last_trans_dict, last_obj_dict = app.dm.getLastTIDs()
for offset, uuid, state in cell_list:
if uuid == app.uuid:
if state == CellStates.DISCARDED:
if state in (CellStates.DISCARDED, CellStates.CORRUPTED):
try:
del self.partition_dict[offset]
except KeyError:
continue
self.replicate_dict.pop(offset, None)
self.source_dict.pop(offset, None)
abort = abort or self.current_partition == offset
elif state == CellStates.OUT_OF_DATE:
assert offset not in self.partition_dict
self.partition_dict[offset] = p = Partition()
p.next_trans = p.next_obj = ZERO_TID
try:
p.next_trans = add64(last_trans_dict[offset], 1)
except KeyError:
p.next_trans = ZERO_TID
p.next_obj = last_obj_dict.get(offset, ZERO_TID)
p.max_ttid = INVALID_TID
added_list.append(offset)
if added_list:
......@@ -212,11 +220,10 @@ class Replicator(object):
self.current_partition = offset
previous_node = self.current_node
self.current_node = node
if node.isConnected():
if node.isConnected(connecting=True):
if node.isIdentified():
node.getConnection().asClient()
self.fetchTransactions()
if node is previous_node:
return
else:
assert name or node.getUUID() != app.uuid, "loopback connection"
conn = ClientConnection(app.em, StorageOperationHandler(app),
......@@ -224,7 +231,11 @@ class Replicator(object):
conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE,
None if name else app.uuid, app.server, name or app.name))
if previous_node is not None and previous_node.isConnected():
previous_node.getConnection().closeClient()
app.closeClient(previous_node.getConnection())
def connected(self, node):
if self.current_node is node and self.current_partition is not None:
self.fetchTransactions()
def fetchTransactions(self, min_tid=None):
offset = self.current_partition
......
......@@ -15,10 +15,11 @@
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from binascii import a2b_hex
import unittest
from mock import Mock
from neo.lib.util import dump, p64, u64
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, MAX_TID
from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID, MAX_TID
from .. import NeoUnitTestBase
from neo.lib.exception import DatabaseFailure
......@@ -499,10 +500,6 @@ class StorageDBTests(NeoUnitTestBase):
def test_getReplicationTIDList(self):
self.setNumPartitions(2, True)
tid1, tid2, tid3, tid4 = self._storeTransactions(4)
# get tids
# - all
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 0)
self.checkSet(result, [tid1, tid3])
# - one partition
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 0)
self.checkSet(result, [tid1, tid3])
......@@ -519,6 +516,37 @@ class StorageDBTests(NeoUnitTestBase):
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 1, 0)
self.checkSet(result, [tid1])
def test_checkRange(self):
def check(trans, obj, *args):
self.assertEqual(trans, self.db.checkTIDRange(*args))
self.assertEqual(obj, self.db.checkSerialRange(*(args+(ZERO_OID,))))
self.setNumPartitions(2, True)
tid1, tid2, tid3, tid4 = self._storeTransactions(4)
z = 0, ZERO_HASH, ZERO_TID, ZERO_HASH, ZERO_OID
# - one partition
check((2, a2b_hex('84320eb8dbbe583f67055c15155ab6794f11654d'), tid3),
z,
0, 10, ZERO_TID, MAX_TID)
# - another partition
check((2, a2b_hex('1f02f98cf775a9e0ce9252ff5972dce728c4ddb0'), tid4),
(4, a2b_hex('e5b47bddeae2096220298df686737d939a27d736'), tid4,
a2b_hex('1e9093698424b5370e19acd2d5fc20dcd56a32cd'), p64(1)),
1, 10, ZERO_TID, MAX_TID)
self.assertEqual(
(3, a2b_hex('b85e2d4914e22b5ad3b82b312b3dc405dc17dcb8'), tid4,
a2b_hex('1b6d73ecdc064595fe915a5c26da06b195caccaa'), p64(1)),
self.db.checkSerialRange(1, 10, ZERO_TID, MAX_TID, p64(2)))
# - min_tid is inclusive
check((1, a2b_hex('da4b9237bacccdf19c0760cab7aec4a8359010b0'), tid3),
z,
0, 10, tid3, MAX_TID)
# - max tid is inclusive
x = 1, a2b_hex('b6589fc6ab0dc82cf12099d1c2d40ab994e8410c'), tid1
check(x, z, 0, 10, ZERO_TID, tid2)
# - limit
y = 1, a2b_hex('356a192b7913b04c54574d18c28d46e6395428ab'), tid2
check(y, x + y[1:], 1, 1, ZERO_TID, MAX_TID)
def test_findUndoTID(self):
self.setNumPartitions(4, True)
db = self.db
......
......@@ -638,71 +638,6 @@ class ProtocolTests(NeoUnitTestBase):
def test_AnswerTIDsFrom(self):
self._test_AnswerTIDs(Packets.AnswerTIDsFrom)
def test_AskCheckTIDRange(self):
min_tid = self.getNextTID()
max_tid = self.getNextTID()
length = 2
partition = 4
p = Packets.AskCheckTIDRange(min_tid, max_tid, length, partition)
p_min_tid, p_max_tid, p_length, p_partition = p.decode()
self.assertEqual(p_min_tid, min_tid)
self.assertEqual(p_max_tid, max_tid)
self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition)
def test_AnswerCheckTIDRange(self):
min_tid = self.getNextTID()
length = 2
count = 1
tid_checksum = "3" * 20
max_tid = self.getNextTID()
p = Packets.AnswerCheckTIDRange(min_tid, length, count, tid_checksum,
max_tid)
p_min_tid, p_length, p_count, p_tid_checksum, p_max_tid = p.decode()
self.assertEqual(p_min_tid, min_tid)
self.assertEqual(p_length, length)
self.assertEqual(p_count, count)
self.assertEqual(p_tid_checksum, tid_checksum)
self.assertEqual(p_max_tid, max_tid)
def test_AskCheckSerialRange(self):
min_oid = self.getOID(1)
min_serial = self.getNextTID()
max_tid = self.getNextTID()
length = 2
partition = 4
p = Packets.AskCheckSerialRange(min_oid, min_serial, max_tid, length,
partition)
p_min_oid, p_min_serial, p_max_tid, p_length, p_partition = p.decode()
self.assertEqual(p_min_oid, min_oid)
self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_max_tid, max_tid)
self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition)
def test_AnswerCheckSerialRange(self):
min_oid = self.getOID(1)
min_serial = self.getNextTID()
length = 2
count = 1
oid_checksum = "4" * 20
max_oid = self.getOID(5)
tid_checksum = "5" * 20
max_serial = self.getNextTID()
p = Packets.AnswerCheckSerialRange(min_oid, min_serial, length, count,
oid_checksum, max_oid, tid_checksum, max_serial)
p_min_oid, p_min_serial, p_length, p_count, p_oid_checksum, \
p_max_oid, p_tid_checksum, p_max_serial = p.decode()
self.assertEqual(p_min_oid, min_oid)
self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_length, length)
self.assertEqual(p_count, count)
self.assertEqual(p_oid_checksum, oid_checksum)
self.assertEqual(p_max_oid, max_oid)
self.assertEqual(p_tid_checksum, tid_checksum)
self.assertEqual(p_max_serial, max_serial)
def test_AskPack(self):
tid = self.getNextTID()
p = Packets.AskPack(tid)
......
......@@ -715,7 +715,7 @@ class NEOCluster(object):
txn = transaction.Transaction()
storage.tpc_begin(txn, tid(i))
for o in oid_list:
storage.store(p64(o), tid_dict.get(o), repr((i, o)), '', txn)
storage.store(oid(o), tid_dict.get(o), repr((i, o)), '', txn)
storage.tpc_vote(txn)
i = storage.tpc_finish(txn)
for o in oid_list:
......@@ -788,6 +788,7 @@ def predictable_random(seed=None):
# Because we have 2 running threads when client works, we can't
# patch neo.client.pool (and cluster should have 1 storage).
from neo.master import backup_app
from neo.master.handlers import administration
from neo.storage import replicator
def decorator(wrapped):
def wrapper(*args, **kw):
......@@ -800,10 +801,12 @@ def predictable_random(seed=None):
if node_type == NodeTypes.CLIENT else
UUID_NAMESPACES[node_type] + ''.join(
chr(r.randrange(256)) for _ in xrange(15)))
backup_app.random = replicator.random = r
administration.random = backup_app.random = replicator.random \
= r
return wrapped(*args, **kw)
finally:
del MasterApplication.getNewUUID
backup_app.random = replicator.random = random
administration.random = backup_app.random = replicator.random \
= random
return wraps(wrapped)(wrapper)
return decorator
......@@ -23,11 +23,13 @@ import threading
import transaction
import unittest
import neo.lib
from neo.storage.checker import CHECK_COUNT
from neo.storage.transactions import TransactionManager, \
DelayedError, ConflictError
from neo.lib.connection import MTClientConnection
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
ZERO_OID, ZERO_TID, MAX_TID
from neo.lib.util import p64
from . import NEOCluster, NEOThreadedTest, Patch, predictable_random
from neo.client.pool import CELL_CONNECTED, CELL_GOOD
......@@ -36,8 +38,9 @@ class ReplicationTests(NEOThreadedTest):
def checksumPartition(self, storage, partition):
dm = storage.dm
args = ZERO_TID, MAX_TID, None, partition
return dm.checkTIDRange(*args), dm.checkSerialRange(ZERO_TID, *args)
args = partition, None, ZERO_TID, MAX_TID
return dm.checkTIDRange(*args), \
dm.checkSerialRange(min_oid=ZERO_OID, *args)
def checkPartitionReplicated(self, source, destination, partition):
self.assertEqual(self.checksumPartition(source, partition),
......@@ -60,25 +63,28 @@ class ReplicationTests(NEOThreadedTest):
return checked
def testBackupNormalCase(self):
upstream = NEOCluster(partitions=7, replicas=1, storage_count=3)
np = 7
nr = 2
check_dict = dict.fromkeys(xrange(np))
upstream = NEOCluster(partitions=np, replicas=nr-1, storage_count=3)
try:
upstream.start()
importZODB = upstream.importZODB()
importZODB(3)
upstream.client.setPoll(0)
backup = NEOCluster(partitions=7, replicas=1, storage_count=5,
backup = NEOCluster(partitions=np, replicas=nr-1, storage_count=5,
upstream=upstream)
try:
backup.start()
# Initialize & catch up.
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
backup.tic()
self.assertEqual(14, self.checkBackup(backup))
self.assertEqual(np*nr, self.checkBackup(backup))
# Normal case, following upstream cluster closely.
importZODB(17)
upstream.client.setPoll(0)
backup.tic()
self.assertEqual(14, self.checkBackup(backup))
self.assertEqual(np*nr, self.checkBackup(backup))
# Check that a backup cluster can be restarted.
finally:
backup.stop()
......@@ -90,11 +96,13 @@ class ReplicationTests(NEOThreadedTest):
importZODB(17)
upstream.client.setPoll(0)
backup.tic()
self.assertEqual(14, self.checkBackup(backup))
self.assertEqual(np*nr, self.checkBackup(backup))
backup.neoctl.checkReplicas(check_dict, ZERO_TID, None)
backup.tic()
# Stop backing up, nothing truncated.
backup.neoctl.setClusterState(ClusterStates.STOPPING_BACKUP)
backup.tic()
self.assertEqual(14, self.checkBackup(backup))
self.assertEqual(np*nr, self.checkBackup(backup))
self.assertEqual(backup.neoctl.getClusterState(),
ClusterStates.RUNNING)
finally:
......@@ -110,6 +118,8 @@ class ReplicationTests(NEOThreadedTest):
- primary storage disconnected from backup master
- non-primary storage disconnected from backup master
"""
np = 4
check_dict = dict.fromkeys(xrange(np))
from neo.master.backup_app import random
def fetchObjects(orig, min_tid=None, min_oid=ZERO_OID):
if min_tid is None:
......@@ -124,11 +134,11 @@ class ReplicationTests(NEOThreadedTest):
node_list.remove(txn.getNode())
node_list[0].getConnection().close()
return orig(txn)
upstream = NEOCluster(partitions=4, replicas=0, storage_count=1)
upstream = NEOCluster(partitions=np, replicas=0, storage_count=1)
try:
upstream.start()
importZODB = upstream.importZODB(random=random)
backup = NEOCluster(partitions=4, replicas=2, storage_count=4,
backup = NEOCluster(partitions=np, replicas=2, storage_count=4,
upstream=upstream)
try:
backup.start()
......@@ -160,8 +170,10 @@ class ReplicationTests(NEOThreadedTest):
finally:
del p
upstream.client.setPoll(0)
if event > 5:
backup.neoctl.checkReplicas(check_dict, ZERO_TID, None)
backup.tic()
self.assertEqual(12, self.checkBackup(backup))
self.assertEqual(np*3, self.checkBackup(backup))
finally:
backup.stop()
finally:
......@@ -196,12 +208,13 @@ class ReplicationTests(NEOThreadedTest):
# default for performance reason
orig.im_self.dropPartitions((offset,))
return orig(ptid, cell_list)
cluster = NEOCluster(partitions=3, replicas=1, storage_count=3)
np = 3
cluster = NEOCluster(partitions=np, replicas=1, storage_count=3)
s0, s1, s2 = cluster.storage_list
for delayed in Packets.AskFetchTransactions, Packets.AskFetchObjects:
try:
cluster.start([s0])
cluster.populate([range(6)] * 3)
cluster.populate([range(np*2)] * np)
cluster.client.setPoll(0)
s1.start()
s2.start()
......@@ -223,6 +236,50 @@ class ReplicationTests(NEOThreadedTest):
cluster.stop()
cluster.reset(True)
def testCheckReplicas(self):
from neo.storage import checker
def corrupt(offset):
s0, s1, s2 = (storage_dict[cell.getUUID()]
for cell in cluster.master.pt.getCellList(offset, True))
s1.dm.deleteObject(p64(np+offset), p64(corrupt_tid))
return s0.uuid
def check(expected_state, expected_count):
self.assertEqual(expected_count, len([None
for row in cluster.neoctl.getPartitionRowList()[1]
for cell in row[1]
if cell[1] == CellStates.CORRUPTED]))
self.assertEqual(expected_state, cluster.neoctl.getClusterState())
np = 5
tid_count = np * 3
corrupt_tid = tid_count // 2
check_dict = dict.fromkeys(xrange(np))
cluster = NEOCluster(partitions=np, replicas=2, storage_count=3)
try:
checker.CHECK_COUNT = 2
cluster.start()
cluster.populate([range(np*2)] * tid_count)
cluster.client.setPoll(0)
storage_dict = dict((x.uuid, x) for x in cluster.storage_list)
cluster.neoctl.checkReplicas(check_dict, ZERO_TID, None)
cluster.tic()
check(ClusterStates.RUNNING, 0)
source = corrupt(0)
cluster.neoctl.checkReplicas(check_dict, p64(corrupt_tid+1), None)
cluster.tic()
check(ClusterStates.RUNNING, 0)
cluster.neoctl.checkReplicas({0: source}, ZERO_TID, None)
cluster.tic()
check(ClusterStates.RUNNING, 1)
corrupt(1)
cluster.neoctl.checkReplicas(check_dict, p64(corrupt_tid+1), None)
cluster.tic()
check(ClusterStates.RUNNING, 1)
cluster.neoctl.checkReplicas(check_dict, ZERO_TID, None)
cluster.tic()
check(ClusterStates.VERIFYING, 4)
finally:
checker.CHECK_COUNT = CHECK_COUNT
cluster.stop()
if __name__ == "__main__":
unittest.main()