Commit 04f72a4c authored by Julien Muchembled's avatar Julien Muchembled

New feature to check that partitions are replicated properly

This includes an API change of Node.isIdentified, which now tells whether
identification packets have been exchanged or not.
All handlers must be updated to implement '_acceptIdentification' instead of
overriding EventHandler.acceptIdentification: this patch only does it for
StorageOperationHandler
parent 2241c3a1
...@@ -133,6 +133,11 @@ RC - Review output of pylint (CODE) ...@@ -133,6 +133,11 @@ RC - Review output of pylint (CODE)
be done ? hope to find a storage with valid checksum ? assume that data be done ? hope to find a storage with valid checksum ? assume that data
is correct in storage but was altered when it travelled through network is correct in storage but was altered when it travelled through network
as we loaded it ?). as we loaded it ?).
- Check replicas: (HIGH AVAILABILITY)
- Automatically tell corrupted cells to fix their data when a good source
is known.
- Add an option to also check all rows of trans/obj/data, instead of only
keys (trans.tid & obj.{tid,oid}).
Master Master
- Master node data redundancy (HIGH AVAILABILITY) - Master node data redundancy (HIGH AVAILABILITY)
......
...@@ -83,6 +83,7 @@ class AdminEventHandler(EventHandler): ...@@ -83,6 +83,7 @@ class AdminEventHandler(EventHandler):
addPendingNodes = forward_ask(Packets.AddPendingNodes) addPendingNodes = forward_ask(Packets.AddPendingNodes)
setClusterState = forward_ask(Packets.SetClusterState) setClusterState = forward_ask(Packets.SetClusterState)
checkReplicas = forward_ask(Packets.CheckReplicas)
class MasterEventHandler(EventHandler): class MasterEventHandler(EventHandler):
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from functools import wraps
import neo.lib import neo.lib
from .protocol import ( from .protocol import (
NodeStates, Packets, ErrorCodes, Errors, BrokenNodeDisallowedError, NodeStates, Packets, ErrorCodes, Errors, BrokenNodeDisallowedError,
...@@ -121,6 +122,19 @@ class EventHandler(object): ...@@ -121,6 +122,19 @@ class EventHandler(object):
# Packet handlers. # Packet handlers.
def acceptIdentification(self, conn, node_type, *args):
try:
acceptIdentification = self._acceptIdentification
except AttributeError:
raise UnexpectedPacketError('no handler found')
node = self.app.nm.getByAddress(conn.getAddress())
assert node.getConnection() is conn, (node.getConnection(), conn)
if node.getType() == node_type:
node.setIdentified()
acceptIdentification(node, *args)
return
conn.close()
def ping(self, conn): def ping(self, conn):
conn.answer(Packets.Pong()) conn.answer(Packets.Pong())
......
...@@ -37,6 +37,7 @@ class Node(object): ...@@ -37,6 +37,7 @@ class Node(object):
self._uuid = uuid self._uuid = uuid
self._manager = manager self._manager = manager
self._last_state_change = time() self._last_state_change = time()
self._identified = False
manager.add(self) manager.add(self)
def notify(self, packet): def notify(self, packet):
...@@ -98,6 +99,7 @@ class Node(object): ...@@ -98,6 +99,7 @@ class Node(object):
""" """
assert self._connection is not None assert self._connection is not None
del self._connection del self._connection
self._identified = False
self._manager._updateIdentified(self) self._manager._updateIdentified(self)
def setConnection(self, connection, force=None): def setConnection(self, connection, force=None):
...@@ -113,6 +115,8 @@ class Node(object): ...@@ -113,6 +115,8 @@ class Node(object):
conn = self._connection conn = self._connection
if conn is None: if conn is None:
self._connection = connection self._connection = connection
if connection.isServer():
self.setIdentified()
else: else:
assert force is not None, \ assert force is not None, \
attributeTracker.whoSet(self, '_connection') attributeTracker.whoSet(self, '_connection')
...@@ -127,7 +131,11 @@ class Node(object): ...@@ -127,7 +131,11 @@ class Node(object):
if not force or conn.getPeerId() is not None or \ if not force or conn.getPeerId() is not None or \
type(conn.getHandler()) is not type(connection.getHandler()): type(conn.getHandler()) is not type(connection.getHandler()):
raise ProtocolError("already connected") raise ProtocolError("already connected")
conn.setOnClose(lambda: setattr(self, '_connection', connection)) def on_closed():
self._connection = connection
assert connection.isServer()
self.setIdentified()
conn.setOnClose(on_closed)
conn.close() conn.close()
assert not connection.isClosed(), connection assert not connection.isClosed(), connection
connection.setOnClose(self.onConnectionClosed) connection.setOnClose(self.onConnectionClosed)
...@@ -147,11 +155,15 @@ class Node(object): ...@@ -147,11 +155,15 @@ class Node(object):
return self._connection is not None and (connecting or return self._connection is not None and (connecting or
not self._connection.connecting) not self._connection.connecting)
def setIdentified(self):
assert self._connection is not None
self._identified = True
def isIdentified(self): def isIdentified(self):
""" """
Returns True is the node is connected and identified Returns True if identification packets have been exchanged
""" """
return self._connection is not None and self._uuid is not None return self._identified
def __repr__(self): def __repr__(self):
return '<%s(uuid=%s, address=%s, state=%s, connection=%r) at %x>' % ( return '<%s(uuid=%s, address=%s, state=%s, connection=%r) at %x>' % (
...@@ -396,7 +408,10 @@ class NodeManager(object): ...@@ -396,7 +408,10 @@ class NodeManager(object):
def _updateIdentified(self, node): def _updateIdentified(self, node):
uuid = node.getUUID() uuid = node.getUUID()
if node.isIdentified(): if uuid:
# XXX: It's probably a bug to include connecting nodes but there's
# no API yet to update manager when connection is established.
if node.isConnected(connecting=True):
self._identified_dict[uuid] = node self._identified_dict[uuid] = node
else: else:
self._identified_dict.pop(uuid, None) self._identified_dict.pop(uuid, None)
......
...@@ -25,7 +25,7 @@ from struct import Struct ...@@ -25,7 +25,7 @@ from struct import Struct
from .util import Enum, getAddressType from .util import Enum, getAddressType
# The protocol version (major, minor). # The protocol version (major, minor).
PROTOCOL_VERSION = (6, 1) PROTOCOL_VERSION = (7, 1)
# Size restrictions. # Size restrictions.
MIN_PACKET_SIZE = 10 MIN_PACKET_SIZE = 10
...@@ -49,6 +49,7 @@ class ErrorCodes(Enum): ...@@ -49,6 +49,7 @@ class ErrorCodes(Enum):
BROKEN_NODE = Enum.Item(5) BROKEN_NODE = Enum.Item(5)
ALREADY_PENDING = Enum.Item(7) ALREADY_PENDING = Enum.Item(7)
REPLICATION_ERROR = Enum.Item(8) REPLICATION_ERROR = Enum.Item(8)
CHECKING_ERROR = Enum.Item(9)
ErrorCodes = ErrorCodes() ErrorCodes = ErrorCodes()
class ClusterStates(Enum): class ClusterStates(Enum):
...@@ -83,6 +84,7 @@ class CellStates(Enum): ...@@ -83,6 +84,7 @@ class CellStates(Enum):
OUT_OF_DATE = Enum.Item(2) OUT_OF_DATE = Enum.Item(2)
FEEDING = Enum.Item(3) FEEDING = Enum.Item(3)
DISCARDED = Enum.Item(4) DISCARDED = Enum.Item(4)
CORRUPTED = Enum.Item(5)
CellStates = CellStates() CellStates = CellStates()
class LockState(Enum): class LockState(Enum):
...@@ -108,6 +110,7 @@ cell_state_prefix_dict = { ...@@ -108,6 +110,7 @@ cell_state_prefix_dict = {
CellStates.OUT_OF_DATE: 'O', CellStates.OUT_OF_DATE: 'O',
CellStates.FEEDING: 'F', CellStates.FEEDING: 'F',
CellStates.DISCARDED: 'D', CellStates.DISCARDED: 'D',
CellStates.CORRUPTED: 'C',
} }
# Other constants. # Other constants.
...@@ -1239,6 +1242,35 @@ class Pack(Packet): ...@@ -1239,6 +1242,35 @@ class Pack(Packet):
PBoolean('status'), PBoolean('status'),
) )
class CheckReplicas(Packet):
"""
ctl -> A
A -> M
"""
_fmt = PStruct('check_replicas',
PDict('partition_dict',
PNumber('partition'),
PUUID('source'),
),
PTID('min_tid'),
PTID('max_tid'),
)
_answer = Error
class CheckPartition(Packet):
"""
M -> S
"""
_fmt = PStruct('check_partition',
PNumber('partition'),
PStruct('source',
PString('upstream_name'),
PAddress('address'),
),
PTID('min_tid'),
PTID('max_tid'),
)
class CheckTIDRange(Packet): class CheckTIDRange(Packet):
""" """
Ask some stats about a range of transactions. Ask some stats about a range of transactions.
...@@ -1251,15 +1283,13 @@ class CheckTIDRange(Packet): ...@@ -1251,15 +1283,13 @@ class CheckTIDRange(Packet):
S -> S S -> S
""" """
_fmt = PStruct('ask_check_tid_range', _fmt = PStruct('ask_check_tid_range',
PNumber('partition'),
PNumber('length'),
PTID('min_tid'), PTID('min_tid'),
PTID('max_tid'), PTID('max_tid'),
PNumber('length'),
PNumber('partition'),
) )
_answer = PStruct('answer_check_tid_range', _answer = PStruct('answer_check_tid_range',
PTID('min_tid'),
PNumber('length'),
PNumber('count'), PNumber('count'),
PChecksum('checksum'), PChecksum('checksum'),
PTID('max_tid'), PTID('max_tid'),
...@@ -1277,22 +1307,30 @@ class CheckSerialRange(Packet): ...@@ -1277,22 +1307,30 @@ class CheckSerialRange(Packet):
S -> S S -> S
""" """
_fmt = PStruct('ask_check_serial_range', _fmt = PStruct('ask_check_serial_range',
POID('min_oid'),
PTID('min_serial'),
PTID('max_tid'),
PNumber('length'),
PNumber('partition'), PNumber('partition'),
PNumber('length'),
PTID('min_tid'),
PTID('max_tid'),
POID('min_oid'),
) )
_answer = PStruct('answer_check_serial_range', _answer = PStruct('answer_check_serial_range',
POID('min_oid'),
PTID('min_serial'),
PNumber('length'),
PNumber('count'), PNumber('count'),
PChecksum('tid_checksum'),
PTID('max_tid'),
PChecksum('oid_checksum'), PChecksum('oid_checksum'),
POID('max_oid'), POID('max_oid'),
PChecksum('serial_checksum'), )
PTID('max_serial'),
class PartitionCorrupted(Packet):
"""
S -> M
"""
_fmt = PStruct('partition_corrupted',
PNumber('partition'),
PList('cell_list',
PUUID('uuid'),
),
) )
class LastTransaction(Packet): class LastTransaction(Packet):
...@@ -1601,10 +1639,16 @@ class Packets(dict): ...@@ -1601,10 +1639,16 @@ class Packets(dict):
TIDListFrom) TIDListFrom)
AskPack, AnswerPack = register( AskPack, AnswerPack = register(
Pack, ignore_when_closed=False) Pack, ignore_when_closed=False)
CheckReplicas = register(
CheckReplicas)
CheckPartition = register(
CheckPartition)
AskCheckTIDRange, AnswerCheckTIDRange = register( AskCheckTIDRange, AnswerCheckTIDRange = register(
CheckTIDRange) CheckTIDRange)
AskCheckSerialRange, AnswerCheckSerialRange = register( AskCheckSerialRange, AnswerCheckSerialRange = register(
CheckSerialRange) CheckSerialRange)
NotifyPartitionCorrupted = register(
PartitionCorrupted)
NotifyReady = register( NotifyReady = register(
NotifyReady) NotifyReady)
AskLastTransaction, AnswerLastTransaction = register( AskLastTransaction, AnswerLastTransaction = register(
......
...@@ -34,7 +34,7 @@ class Cell(object): ...@@ -34,7 +34,7 @@ class Cell(object):
def __init__(self, node, state = CellStates.UP_TO_DATE): def __init__(self, node, state = CellStates.UP_TO_DATE):
self.node = node self.node = node
self.setState(state) self.state = state
def __repr__(self): def __repr__(self):
return "<Cell(uuid=%s, address=%s, state=%s)>" % ( return "<Cell(uuid=%s, address=%s, state=%s)>" % (
...@@ -59,6 +59,13 @@ class Cell(object): ...@@ -59,6 +59,13 @@ class Cell(object):
def isFeeding(self): def isFeeding(self):
return self.state == CellStates.FEEDING return self.state == CellStates.FEEDING
def isCorrupted(self):
return self.state == CellStates.CORRUPTED
def isReadable(self):
return self.state == CellStates.UP_TO_DATE or \
self.state == CellStates.FEEDING
def getNode(self): def getNode(self):
return self.node return self.node
...@@ -122,6 +129,12 @@ class PartitionTable(object): ...@@ -122,6 +129,12 @@ class PartitionTable(object):
except IndexError: except IndexError:
return False return False
def getNodeSet(self):
return set(x.getNode() for row in self.partition_list for x in row)
def getConnectedNodeList(self):
return [node for node in self.getNodeSet() if node.isConnected()]
def getNodeList(self): def getNodeList(self):
"""Return all used nodes.""" """Return all used nodes."""
return [node for node, count in self.count_dict.iteritems() \ return [node for node, count in self.count_dict.iteritems() \
...@@ -129,8 +142,7 @@ class PartitionTable(object): ...@@ -129,8 +142,7 @@ class PartitionTable(object):
def getCellList(self, offset, readable=False): def getCellList(self, offset, readable=False):
if readable: if readable:
return [cell for cell in self.partition_list[offset] return filter(Cell.isReadable, self.partition_list[offset])
if not cell.isOutOfDate()]
return list(self.partition_list[offset]) return list(self.partition_list[offset])
def getPartition(self, oid_or_tid): def getPartition(self, oid_or_tid):
...@@ -280,7 +292,7 @@ class PartitionTable(object): ...@@ -280,7 +292,7 @@ class PartitionTable(object):
return False return False
for row in self.partition_list: for row in self.partition_list:
for cell in row: for cell in row:
if not cell.isOutOfDate() and cell.getNode().isRunning(): if cell.isReadable() and cell.getNode().isRunning():
break break
else: else:
return False return False
......
...@@ -279,7 +279,7 @@ class BackupApplication(object): ...@@ -279,7 +279,7 @@ class BackupApplication(object):
primary = primary_node is node primary = primary_node is node
result = None if primary else app.pt.setUpToDate(node, offset) result = None if primary else app.pt.setUpToDate(node, offset)
if app.getClusterState() == ClusterStates.BACKINGUP: if app.getClusterState() == ClusterStates.BACKINGUP:
assert not cell.isOutOfDate() assert cell.isReadable()
if result: # was out-of-date if result: # was out-of-date
max_tid, = [x.backup_tid for x in cell_list max_tid, = [x.backup_tid for x in cell_list
if x.getNode() is primary_node] if x.getNode() is primary_node]
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import random
import neo import neo
from . import MasterHandler from . import MasterHandler
...@@ -162,3 +163,48 @@ class AdministrationHandler(MasterHandler): ...@@ -162,3 +163,48 @@ class AdministrationHandler(MasterHandler):
# broadcast the new partition table # broadcast the new partition table
app.broadcastPartitionChanges(cell_list) app.broadcastPartitionChanges(cell_list)
conn.answer(Errors.Ack('Nodes added: %s' % (uuids, ))) conn.answer(Errors.Ack('Nodes added: %s' % (uuids, )))
def checkReplicas(self, conn, partition_dict, min_tid, max_tid):
app = self.app
pt = app.pt
backingup = app.cluster_state == ClusterStates.BACKINGUP
if not max_tid:
max_tid = pt.getCheckTid(partition_dict) if backingup else \
app.getLastTransaction()
if min_tid > max_tid:
neo.lib.logging.warning("nothing to check: min_tid=%s > max_tid=%s",
dump(min_tid), dump(max_tid))
else:
getByUUID = app.nm.getByUUID
node_set = set()
for offset, source in partition_dict.iteritems():
# XXX: For the moment, code checking replicas is unable to fix
# corrupted partitions (when a good cell is known)
# so only check readable ones.
# (see also Checker._nextPartition of storage)
cell_list = pt.getCellList(offset, True)
#cell_list = [cell for cell in pt.getCellList(offset)
# if not cell.isOutOfDate()]
if len(cell_list) + (backingup and not source) <= 1:
continue
for cell in cell_list:
node = cell.getNode()
if node in node_set:
break
else:
node_set.add(node)
if source:
source = '', getByUUID(source).getAddress()
else:
readable = [cell for cell in cell_list if cell.isReadable()]
if 1 == len(readable) < len(cell_list):
source = '', readable[0].getAddress()
elif backingup:
source = app.backup_app.name, random.choice(
app.backup_app.pt.getCellList(offset, readable=True)
).getAddress()
else:
source = '', None
node.getConnection().notify(Packets.CheckPartition(
offset, source, min_tid, max_tid))
conn.answer(Errors.Ack(''))
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import neo.lib import neo.lib
from neo.lib.protocol import ClusterStates, Packets, ProtocolError from neo.lib.protocol import CellStates, ClusterStates, Packets, ProtocolError
from neo.lib.exception import OperationFailure from neo.lib.exception import OperationFailure
from neo.lib.util import dump from neo.lib.util import dump
from neo.lib.connector import ConnectorConnectionClosedException from neo.lib.connector import ConnectorConnectionClosedException
...@@ -72,6 +72,17 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -72,6 +72,17 @@ class StorageServiceHandler(BaseServiceHandler):
# transaction locked on this storage node # transaction locked on this storage node
self.app.tm.lock(ttid, conn.getUUID()) self.app.tm.lock(ttid, conn.getUUID())
def notifyPartitionCorrupted(self, conn, partition, cell_list):
change_list = []
for cell in self.app.pt.getCellList(partition):
if cell.getUUID() in cell_list:
cell.setState(CellStates.CORRUPTED)
change_list.append((partition, cell.getUUID(),
CellStates.CORRUPTED))
self.app.broadcastPartitionChanges(change_list)
if not self.app.pt.operational():
raise OperationFailure('cannot continue operation')
def notifyReplicationDone(self, conn, offset, tid): def notifyReplicationDone(self, conn, offset, tid):
app = self.app app = self.app
node = app.nm.getByUUID(conn.getUUID()) node = app.nm.getByUUID(conn.getUUID())
......
...@@ -25,12 +25,13 @@ class Cell(neo.lib.pt.Cell): ...@@ -25,12 +25,13 @@ class Cell(neo.lib.pt.Cell):
replicating = ZERO_TID replicating = ZERO_TID
def setState(self, state): def setState(self, state):
readable = self.isReadable()
super(Cell, self).setState(state)
if readable and not self.isReadable():
try: try:
if CellStates.OUT_OF_DATE == state != self.state:
del self.backup_tid, self.replicating del self.backup_tid, self.replicating
except AttributeError: except AttributeError:
pass pass
return super(Cell, self).setState(state)
neo.lib.pt.Cell = Cell neo.lib.pt.Cell = Cell
...@@ -196,7 +197,7 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -196,7 +197,7 @@ class PartitionTable(neo.lib.pt.PartitionTable):
CellStates.OUT_OF_DATE)) CellStates.OUT_OF_DATE))
node_count += 1 node_count += 1
elif node_count + 1 < max_count: elif node_count + 1 < max_count:
if feeding_cell is not None or max_cell.isOutOfDate(): if feeding_cell is not None or not max_cell.isReadable():
# If there is a feeding cell already or it is # If there is a feeding cell already or it is
# out-of-date, just drop the node. # out-of-date, just drop the node.
row.remove(max_cell) row.remove(max_cell)
...@@ -239,10 +240,10 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -239,10 +240,10 @@ class PartitionTable(neo.lib.pt.PartitionTable):
else: else:
# Remove an excessive feeding cell. # Remove an excessive feeding cell.
removed_cell_list.append(cell) removed_cell_list.append(cell)
elif cell.isOutOfDate(): elif cell.isUpToDate():
out_of_date_cell_list.append(cell)
else:
up_to_date_cell_list.append(cell) up_to_date_cell_list.append(cell)
else:
out_of_date_cell_list.append(cell)
# If all cells are up-to-date, a feeding cell is not required. # If all cells are up-to-date, a feeding cell is not required.
if len(out_of_date_cell_list) == 0 and feeding_cell is not None: if len(out_of_date_cell_list) == 0 and feeding_cell is not None:
...@@ -311,7 +312,7 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -311,7 +312,7 @@ class PartitionTable(neo.lib.pt.PartitionTable):
lost = lost_node lost = lost_node
cell_list = [] cell_list = []
for cell in row: for cell in row:
if not cell.isOutOfDate(): if cell.isReadable():
if cell.getNode().isRunning(): if cell.getNode().isRunning():
lost = None lost = None
else : else :
...@@ -330,7 +331,7 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -330,7 +331,7 @@ class PartitionTable(neo.lib.pt.PartitionTable):
yield offset, cell yield offset, cell
break break
def getUpToDateCellNodeSet(self): def getReadableCellNodeSet(self):
""" """
Return a set of all nodes which are part of at least one UP TO DATE Return a set of all nodes which are part of at least one UP TO DATE
partition. partition.
...@@ -338,17 +339,7 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -338,17 +339,7 @@ class PartitionTable(neo.lib.pt.PartitionTable):
return set(cell.getNode() return set(cell.getNode()
for row in self.partition_list for row in self.partition_list
for cell in row for cell in row
if not cell.isOutOfDate()) if cell.isReadable())
def getOutOfDateCellNodeSet(self):
"""
Return a set of all nodes which are part of at least one OUT OF DATE
partition.
"""
return set(cell.getNode()
for row in self.partition_list
for cell in row
if cell.isOutOfDate())
def setBackupTidDict(self, backup_tid_dict): def setBackupTidDict(self, backup_tid_dict):
for row in self.partition_list: for row in self.partition_list:
...@@ -358,8 +349,16 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -358,8 +349,16 @@ class PartitionTable(neo.lib.pt.PartitionTable):
def getBackupTid(self): def getBackupTid(self):
try: try:
return min(max(cell.backup_tid for cell in row return min(max(cell.backup_tid for cell in row if cell.isReadable())
if not cell.isOutOfDate())
for row in self.partition_list) for row in self.partition_list)
except ValueError: except ValueError:
return ZERO_TID return ZERO_TID
def getCheckTid(self, partition_list):
try:
return min(min(cell.backup_tid
for cell in self.partition_list[offset]
if cell.isReadable())
for offset in partition_list)
except ValueError:
return ZERO_TID
...@@ -65,39 +65,39 @@ class RecoveryManager(MasterHandler): ...@@ -65,39 +65,39 @@ class RecoveryManager(MasterHandler):
if pt.filled(): if pt.filled():
# A partition table exists, we are starting an existing # A partition table exists, we are starting an existing
# cluster. # cluster.
partition_node_set = pt.getUpToDateCellNodeSet() partition_node_set = pt.getReadableCellNodeSet()
pending_node_set = set(x for x in partition_node_set pending_node_set = set(x for x in partition_node_set
if x.isPending()) if x.isPending())
if app._startup_allowed or \ if app._startup_allowed or \
partition_node_set == pending_node_set: partition_node_set == pending_node_set:
allowed_node_set = pending_node_set allowed_node_set = pending_node_set
extra_node_set = pt.getOutOfDateCellNodeSet() node_list = pt.getConnectedNodeList
elif app._startup_allowed: elif app._startup_allowed:
# No partition table and admin allowed startup, we are # No partition table and admin allowed startup, we are
# creating a new cluster out of all pending nodes. # creating a new cluster out of all pending nodes.
allowed_node_set = set(app.nm.getStorageList( allowed_node_set = set(app.nm.getStorageList(
only_identified=True)) only_identified=True))
extra_node_set = set() node_list = lambda: allowed_node_set
if allowed_node_set: if allowed_node_set:
for node in allowed_node_set: for node in allowed_node_set:
assert node.isPending(), node assert node.isPending(), node
if node.getConnection().isPending(): if node.getConnection().isPending():
break break
else: else:
allowed_node_set |= extra_node_set node_list = node_list()
break break
neo.lib.logging.info('startup allowed') neo.lib.logging.info('startup allowed')
for node in allowed_node_set: for node in node_list:
node.setRunning() node.setRunning()
app.broadcastNodesInformation(allowed_node_set) app.broadcastNodesInformation(node_list)
if pt.getID() is None: if pt.getID() is None:
neo.lib.logging.info('creating a new partition table') neo.lib.logging.info('creating a new partition table')
# reset IDs generators & build new partition with running nodes # reset IDs generators & build new partition with running nodes
app.tm.setLastOID(ZERO_OID) app.tm.setLastOID(ZERO_OID)
pt.make(allowed_node_set) pt.make(node_list)
self._broadcastPartitionTable(pt.getID(), pt.getRowList()) self._broadcastPartitionTable(pt.getID(), pt.getRowList())
elif app.backup_tid: elif app.backup_tid:
pt.setBackupTidDict(self.backup_tid_dict) pt.setBackupTidDict(self.backup_tid_dict)
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
from .neoctl import NeoCTL, NotReadyException from .neoctl import NeoCTL, NotReadyException
from neo.lib.util import bin, dump from neo.lib.util import bin, dump
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, ZERO_TID
action_dict = { action_dict = {
'print': { 'print': {
...@@ -30,6 +30,7 @@ action_dict = { ...@@ -30,6 +30,7 @@ action_dict = {
'node': 'setNodeState', 'node': 'setNodeState',
'cluster': 'setClusterState', 'cluster': 'setClusterState',
}, },
'check': 'checkReplicas',
'start': 'startCluster', 'start': 'startCluster',
'add': 'enableStorageList', 'add': 'enableStorageList',
'drop': 'dropNode', 'drop': 'dropNode',
...@@ -187,6 +188,33 @@ class TerminalNeoCTL(object): ...@@ -187,6 +188,33 @@ class TerminalNeoCTL(object):
""" """
return self.formatUUID(self.neoctl.getPrimary()) return self.formatUUID(self.neoctl.getPrimary())
def checkReplicas(self, params):
"""
Parameters: [partition]:[reference] ... [min_tid [max_tid]]
"""
partition_dict = {}
params = iter(params)
min_tid = ZERO_TID
max_tid = None
for p in params:
try:
partition, source = p.split(':')
except ValueError:
min_tid = p64(p)
try:
max_tid = p64(params.next())
except StopIteration:
pass
break
source = bin(source) if source else None
if partition:
partition_dict[int(partition)] = source
else:
assert not partition_dict
np = len(self.neoctl.getPartitionRowList()[1])
partition_dict = dict.fromkeys(xrange(np), source)
self.neoctl.checkReplicas(partition_dict, min_tid, max_tid)
class Application(object): class Application(object):
"""The storage node application.""" """The storage node application."""
......
...@@ -163,3 +163,8 @@ class NeoCTL(object): ...@@ -163,3 +163,8 @@ class NeoCTL(object):
raise RuntimeError(response) raise RuntimeError(response)
return response[1] return response[1]
def checkReplicas(self, *args):
response = self.__ask(Packets.CheckReplicas(*args))
if response[0] != Packets.Error or response[1] != ErrorCodes.ACK:
raise RuntimeError(response)
return response[2]
...@@ -28,6 +28,7 @@ from neo.lib.connector import getConnectorHandler ...@@ -28,6 +28,7 @@ from neo.lib.connector import getConnectorHandler
from neo.lib.pt import PartitionTable from neo.lib.pt import PartitionTable
from neo.lib.util import dump from neo.lib.util import dump
from neo.lib.bootstrap import BootstrapManager from neo.lib.bootstrap import BootstrapManager
from .checker import Checker
from .database import buildDatabaseManager from .database import buildDatabaseManager
from .exception import AlreadyPendingError from .exception import AlreadyPendingError
from .handlers import identification, verification, initialization from .handlers import identification, verification, initialization
...@@ -66,6 +67,7 @@ class Application(object): ...@@ -66,6 +67,7 @@ class Application(object):
# partitions. # partitions.
self.pt = None self.pt = None
self.checker = Checker(self)
self.replicator = Replicator(self) self.replicator = Replicator(self)
self.listening_conn = None self.listening_conn = None
self.master_conn = None self.master_conn = None
...@@ -207,6 +209,8 @@ class Application(object): ...@@ -207,6 +209,8 @@ class Application(object):
neo.lib.logging.error('operation stopped: %s', msg) neo.lib.logging.error('operation stopped: %s', msg)
except PrimaryFailure, msg: except PrimaryFailure, msg:
neo.lib.logging.error('primary master is down: %s', msg) neo.lib.logging.error('primary master is down: %s', msg)
finally:
self.checker = Checker(self)
def connectToPrimary(self): def connectToPrimary(self):
"""Find a primary master node, and connect to it. """Find a primary master node, and connect to it.
...@@ -369,6 +373,11 @@ class Application(object): ...@@ -369,6 +373,11 @@ class Application(object):
return return
self.task_queue.appendleft(iterator) self.task_queue.appendleft(iterator)
def closeClient(self, connection):
if connection is not self.replicator.getCurrentConnection() and \
connection not in self.checker.conn_dict:
connection.closeClient()
def shutdown(self, erase=False): def shutdown(self, erase=False):
"""Close all connections and exit""" """Close all connections and exit"""
for c in self.em.getConnectionList(): for c in self.em.getConnectionList():
......
##############################################################################
#
# Copyright (c) 2011 Nexedi SARL and Contributors. All Rights Reserved.
# Julien Muchembled <jm@nexedi.com>
#
# WARNING: This program as such is intended to be used by professional
# programmers who take the whole responsibility of assessing all potential
# consequences resulting from its eventual inadequacies and bugs
# End users who are looking for a ready-to-use solution with commercial
# guarantees and support are strongly advised to contract a Free Software
# Service Company
#
# This program is Free Software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
##############################################################################
from collections import deque
from functools import wraps
import neo.lib
from neo.lib.connection import ClientConnection
from neo.lib.connector import ConnectorConnectionClosedException
from neo.lib.protocol import NodeTypes, Packets, ZERO_OID
from neo.lib.util import add64, dump
from .handlers.storage import StorageOperationHandler
CHECK_COUNT = 4000
class Checker(object):
def __init__(self, app):
self.app = app
self.queue = deque()
self.conn_dict = {}
def __call__(self, partition, source, min_tid, max_tid):
self.queue.append((partition, source, min_tid, max_tid))
if not self.conn_dict:
self._nextPartition()
def _nextPartition(self):
app = self.app
def connect(node, uuid=app.uuid, name=app.name):
if node.getUUID() == app.uuid:
return
if node.isConnected(connecting=True):
conn = node.getConnection()
conn.asClient()
else:
conn = ClientConnection(app.em, StorageOperationHandler(app),
node=node, connector=app.connector_handler())
conn.ask(Packets.RequestIdentification(
NodeTypes.STORAGE, uuid, app.server, name))
self.conn_dict[conn] = node.isIdentified()
conn_set = set(self.conn_dict)
conn_set.discard(None)
try:
self.conn_dict.clear()
while True:
try:
partition, (name, source), min_tid, max_tid = \
self.queue.popleft()
except IndexError:
return
cell = app.pt.getCell(partition, app.uuid)
if cell is None or cell.isOutOfDate():
msg = "discarded or out-of-date"
else:
try:
for cell in app.pt.getCellList(partition):
# XXX: Ignore corrupted cells for the moment
# because we're still unable to fix them
# (see also AdministrationHandler of master)
if cell.isReadable(): #if not cell.isOutOfDate():
connect(cell.getNode())
if source:
node = app.nm.getByAddress(source)
if name:
source = app.nm.createStorage(address=source) \
if node is None else node
connect(source, None, name)
elif (node.getUUID() == app.uuid or
node.isConnected(connecting=True) and
node.getConnection() in self.conn_dict):
source = node
else:
msg = "unavailable source"
if self.conn_dict:
break
msg = "no replica"
except ConnectorConnectionClosedException:
msg = "connection closed"
finally:
conn_set.update(self.conn_dict)
self.conn_dict.clear()
neo.lib.logging.error(
"Failed to start checking partition %u (%s)",
partition, msg)
conn_set.difference_update(self.conn_dict)
finally:
for conn in conn_set:
app.closeClient(conn)
neo.lib.logging.debug("start checking partition %u from %s to %s",
partition, dump(min_tid), dump(max_tid))
self.min_tid = self.next_tid = min_tid
self.max_tid = max_tid
self.next_oid = None
self.partition = partition
self.source = source
args = partition, CHECK_COUNT, min_tid, max_tid
p = Packets.AskCheckTIDRange(*args)
for conn, identified in self.conn_dict.items():
self.conn_dict[conn] = conn.ask(p) if identified else None
self.conn_dict[None] = app.dm.checkTIDRange(*args)
def connected(self, node):
conn = node.getConnection()
if self.conn_dict.get(conn, self) is None:
self.conn_dict[conn] = conn.ask(Packets.AskCheckTIDRange(
self.partition, CHECK_COUNT, self.next_tid, self.max_tid))
def connectionLost(self, conn):
try:
del self.conn_dict[conn]
except KeyError:
return
if self.source is not None and self.source.getConnection() is conn:
del self.source
elif len(self.conn_dict) > 1:
neo.lib.logging.warning("node lost but keep up checking partition"
" %u", self.partition)
return
neo.lib.logging.warning("check of partition %u aborted", self.partition)
self._nextPartition()
def _nextRange(self):
if self.next_oid:
args = self.partition, CHECK_COUNT, self.next_tid, self.max_tid, \
self.next_oid
p = Packets.AskCheckSerialRange(*args)
check = self.app.dm.checkSerialRange
else:
args = self.partition, CHECK_COUNT, self.next_tid, self.max_tid
p = Packets.AskCheckTIDRange(*args)
check = self.app.dm.checkTIDRange
for conn in self.conn_dict.keys():
self.conn_dict[conn] = check(*args) if conn is None else conn.ask(p)
def checkRange(self, conn, *args):
if self.conn_dict.get(conn, self) != conn.getPeerId():
# Ignore answers to old requests,
# because we did nothing to cancel them.
neo.lib.logging.info("ignored AnswerCheck*Range%r", args)
return
self.conn_dict[conn] = args
answer_set = set(self.conn_dict.itervalues())
if len(answer_set) > 1:
for answer in answer_set:
if type(answer) is not tuple:
return
# TODO: Automatically tell corrupted cells to fix their data
# if we know a good source.
# For the moment, tell master to put them in CORRUPTED state
# and keep up checking if useful.
uuid = self.app.uuid
args = None if self.source is None else self.conn_dict[
None if self.source.getUUID() == uuid
else self.source.getConnection()]
uuid_list = []
for conn, answer in self.conn_dict.items():
if answer != args:
del self.conn_dict[conn]
if conn is None:
uuid_list.append(uuid)
else:
uuid_list.append(conn.getUUID())
self.app.closeClient(conn)
p = Packets.NotifyPartitionCorrupted(self.partition, uuid_list)
self.app.master_conn.notify(p)
if len(self.conn_dict) <= 1:
neo.lib.logging.warning("check of partition %u aborted",
self.partition)
self.queue.clear()
self._nextPartition()
return
try:
count, _, max_tid = args
except ValueError:
count, _, self.next_tid, _, max_oid = args
if count < CHECK_COUNT:
neo.lib.logging.debug("partition %u checked from %s to %s",
self.partition, dump(self.min_tid), dump(self.max_tid))
self._nextPartition()
return
self.next_oid = add64(max_oid, 1)
else:
(count, _, max_tid), = answer_set
if count < CHECK_COUNT:
self.next_tid = self.min_tid
self.next_oid = ZERO_OID
else:
self.next_tid = add64(max_tid, 1)
self._nextRange()
...@@ -532,7 +532,7 @@ class DatabaseManager(object): ...@@ -532,7 +532,7 @@ class DatabaseManager(object):
""" """
raise NotImplementedError raise NotImplementedError
def checkTIDRange(self, min_tid, max_tid, length, partition): def checkTIDRange(self, partition, length, min_tid, max_tid):
""" """
Generate a diggest from transaction list. Generate a diggest from transaction list.
min_tid (packed) min_tid (packed)
...@@ -549,12 +549,12 @@ class DatabaseManager(object): ...@@ -549,12 +549,12 @@ class DatabaseManager(object):
""" """
raise NotImplementedError raise NotImplementedError
def checkSerialRange(self, min_oid, min_serial, max_tid, length, partition): def checkSerialRange(self, partition, length, min_tid, max_tid, min_oid):
""" """
Generate a diggest from object list. Generate a diggest from object list.
min_oid (packed) min_oid (packed)
OID at which verification starts. OID at which verification starts.
min_serial (packed) min_tid (packed)
Serial of min_oid object at which search should start. Serial of min_oid object at which search should start.
length length
Maximum number of records to include in result. Maximum number of records to include in result.
......
...@@ -702,7 +702,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -702,7 +702,7 @@ class MySQLDatabaseManager(DatabaseManager):
data_id_set.discard(None) data_id_set.discard(None)
self._pruneData(data_id_set) self._pruneData(data_id_set)
def checkTIDRange(self, min_tid, max_tid, length, partition): def checkTIDRange(self, partition, length, min_tid, max_tid):
count, tid_checksum, max_tid = self.query( count, tid_checksum, max_tid = self.query(
"""SELECT COUNT(*), SHA1(GROUP_CONCAT(tid SEPARATOR ",")), MAX(tid) """SELECT COUNT(*), SHA1(GROUP_CONCAT(tid SEPARATOR ",")), MAX(tid)
FROM (SELECT tid FROM trans FROM (SELECT tid FROM trans
...@@ -713,30 +713,30 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -713,30 +713,30 @@ class MySQLDatabaseManager(DatabaseManager):
'partition': partition, 'partition': partition,
'min_tid': util.u64(min_tid), 'min_tid': util.u64(min_tid),
'max_tid': util.u64(max_tid), 'max_tid': util.u64(max_tid),
'limit': '' if length is None else 'LIMIT %(length)d' % length, 'limit': '' if length is None else 'LIMIT %u' % length,
})[0] })[0]
if count: if count:
return count, a2b_hex(tid_checksum), util.p64(max_tid) return count, a2b_hex(tid_checksum), util.p64(max_tid)
return 0, ZERO_HASH, ZERO_TID return 0, ZERO_HASH, ZERO_TID
def checkSerialRange(self, min_oid, min_serial, max_tid, length, partition): def checkSerialRange(self, partition, length, min_tid, max_tid, min_oid):
u64 = util.u64 u64 = util.u64
# We don't ask MySQL to compute everything (like in checkTIDRange) # We don't ask MySQL to compute everything (like in checkTIDRange)
# because it's difficult to get the last serial _for the last oid_. # because it's difficult to get the last serial _for the last oid_.
# We would need a function (that could be named 'LAST') that returns the # We would need a function (that could be named 'LAST') that returns the
# last grouped value, instead of the greatest one. # last grouped value, instead of the greatest one.
r = self.query( r = self.query(
"""SELECT oid, tid """SELECT tid, oid
FROM obj FROM obj
WHERE partition = %(partition)s WHERE partition = %(partition)s
AND tid <= %(max_tid)d AND tid <= %(max_tid)d
AND (oid > %(min_oid)d OR AND (tid > %(min_tid)d OR
oid = %(min_oid)d AND tid >= %(min_tid)d) tid = %(min_tid)d AND oid >= %(min_oid)d)
ORDER BY oid ASC, tid ASC %(limit)s""" % { ORDER BY tid, oid %(limit)s""" % {
'min_oid': u64(min_oid), 'min_oid': u64(min_oid),
'min_tid': u64(min_serial), 'min_tid': u64(min_tid),
'max_tid': u64(max_tid), 'max_tid': u64(max_tid),
'limit': '' if length is None else 'LIMIT %(length)d' % length, 'limit': '' if length is None else 'LIMIT %u' % length,
'partition': partition, 'partition': partition,
}) })
if r: if r:
...@@ -746,4 +746,4 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -746,4 +746,4 @@ class MySQLDatabaseManager(DatabaseManager):
p64(r[-1][0]), p64(r[-1][0]),
sha1(','.join(str(x[1]) for x in r)).digest(), sha1(','.join(str(x[1]) for x in r)).digest(),
p64(r[-1][1])) p64(r[-1][1]))
return 0, ZERO_HASH, ZERO_OID, ZERO_HASH, ZERO_TID return 0, ZERO_HASH, ZERO_TID, ZERO_HASH, ZERO_OID
...@@ -595,7 +595,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -595,7 +595,7 @@ class SQLiteDatabaseManager(DatabaseManager):
data_id_set.discard(None) data_id_set.discard(None)
self._pruneData(data_id_set) self._pruneData(data_id_set)
def checkTIDRange(self, min_tid, max_tid, length, partition): def checkTIDRange(self, partition, length, min_tid, max_tid):
count, tids, max_tid = self.query("""\ count, tids, max_tid = self.query("""\
SELECT COUNT(*), GROUP_CONCAT(tid), MAX(tid) SELECT COUNT(*), GROUP_CONCAT(tid), MAX(tid)
FROM (SELECT tid FROM trans FROM (SELECT tid FROM trans
...@@ -607,20 +607,19 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -607,20 +607,19 @@ class SQLiteDatabaseManager(DatabaseManager):
return count, sha1(tids).digest(), util.p64(max_tid) return count, sha1(tids).digest(), util.p64(max_tid)
return 0, ZERO_HASH, ZERO_TID return 0, ZERO_HASH, ZERO_TID
def checkSerialRange(self, min_oid, min_serial, max_tid, length, partition): def checkSerialRange(self, partition, length, min_tid, max_tid, min_oid):
u64 = util.u64 u64 = util.u64
# We don't ask MySQL to compute everything (like in checkTIDRange) # We don't ask MySQL to compute everything (like in checkTIDRange)
# because it's difficult to get the last serial _for the last oid_. # because it's difficult to get the last serial _for the last oid_.
# We would need a function (that could be named 'LAST') that returns the # We would need a function (that could be named 'LAST') that returns the
# last grouped value, instead of the greatest one. # last grouped value, instead of the greatest one.
min_oid = u64(min_oid) min_tid = u64(min_tid)
r = self.query("""\ r = self.query("""\
SELECT oid, tid SELECT tid, oid
FROM obj FROM obj
WHERE partition=? AND tid<=? WHERE partition=? AND tid<=? AND (tid>? OR tid=? AND oid>=?)
AND (oid>? OR oid=? AND tid>=?) ORDER BY tid, oid LIMIT ?""",
ORDER BY oid ASC, tid ASC LIMIT ?""", (partition, u64(max_tid), min_tid, min_tid, u64(min_oid),
(partition, u64(max_tid), min_oid, min_oid, u64(min_serial),
-1 if length is None else length)).fetchall() -1 if length is None else length)).fetchall()
if r: if r:
p64 = util.p64 p64 = util.p64
...@@ -629,4 +628,4 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -629,4 +628,4 @@ class SQLiteDatabaseManager(DatabaseManager):
p64(r[-1][0]), p64(r[-1][0]),
sha1(','.join(str(x[1]) for x in r)).digest(), sha1(','.join(str(x[1]) for x in r)).digest(),
p64(r[-1][1])) p64(r[-1][1]))
return 0, ZERO_HASH, ZERO_OID, ZERO_HASH, ZERO_TID return 0, ZERO_HASH, ZERO_TID, ZERO_HASH, ZERO_OID
...@@ -72,3 +72,6 @@ class MasterOperationHandler(BaseMasterHandler): ...@@ -72,3 +72,6 @@ class MasterOperationHandler(BaseMasterHandler):
def askTruncate(self, conn, tid): def askTruncate(self, conn, tid):
self.app.dm.truncate(tid) self.app.dm.truncate(tid)
conn.answer(Packets.AnswerTruncate()) conn.answer(Packets.AnswerTruncate())
def checkPartition(self, conn, *args):
self.app.checker(*args)
...@@ -25,26 +25,42 @@ from neo.lib.protocol import Errors, NodeStates, Packets, \ ...@@ -25,26 +25,42 @@ from neo.lib.protocol import Errors, NodeStates, Packets, \
from neo.lib.util import add64 from neo.lib.util import add64
def checkConnectionIsReplicatorConnection(func): def checkConnectionIsReplicatorConnection(func):
def decorator(self, conn, *args, **kw): def wrapper(self, conn, *args, **kw):
assert self.app.replicator.getCurrentConnection() is conn assert self.app.replicator.getCurrentConnection() is conn
return func(self, conn, *args, **kw) return func(self, conn, *args, **kw)
return wraps(func)(decorator) return wraps(func)(wrapper)
def checkFeedingConnection(check):
def decorator(func):
def wrapper(self, conn, partition, *args, **kw):
app = self.app
cell = app.pt.getCell(partition, app.uuid)
if cell is None or (cell.isOutOfDate() if check else
not cell.isReadable()):
p = Errors.CheckingError if check else Errors.ReplicationError
return conn.answer(p("partition %u not readable" % partition))
conn.asServer()
return func(self, conn, partition, *args, **kw)
return wraps(func)(wrapper)
return decorator
class StorageOperationHandler(EventHandler): class StorageOperationHandler(EventHandler):
"""This class handles events for replications.""" """This class handles events for replications."""
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
if self.app.listening_conn and conn.isClient(): app = self.app
if app.listening_conn and conn.isClient():
# XXX: Connection and Node should merged. # XXX: Connection and Node should merged.
uuid = conn.getUUID() uuid = conn.getUUID()
if uuid: if uuid:
node = self.app.nm.getByUUID(uuid) node = app.nm.getByUUID(uuid)
else: else:
node = self.app.nm.getByAddress(conn.getAddress()) node = app.nm.getByAddress(conn.getAddress())
node.setState(NodeStates.DOWN) node.setState(NodeStates.DOWN)
replicator = self.app.replicator replicator = app.replicator
if replicator.current_node is node: if replicator.current_node is node:
replicator.abort() replicator.abort()
app.checker.connectionLost(conn)
# Client # Client
...@@ -52,10 +68,9 @@ class StorageOperationHandler(EventHandler): ...@@ -52,10 +68,9 @@ class StorageOperationHandler(EventHandler):
if self.app.listening_conn: if self.app.listening_conn:
self.app.replicator.abort() self.app.replicator.abort()
@checkConnectionIsReplicatorConnection def _acceptIdentification(self, node, *args):
def acceptIdentification(self, conn, node_type, self.app.replicator.connected(node)
uuid, num_partitions, num_replicas, your_uuid): self.app.checker.connected(node)
self.app.replicator.fetchTransactions()
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def answerFetchTransactions(self, conn, pack_tid, next_tid, tid_list): def answerFetchTransactions(self, conn, pack_tid, next_tid, tid_list):
...@@ -105,33 +120,53 @@ class StorageOperationHandler(EventHandler): ...@@ -105,33 +120,53 @@ class StorageOperationHandler(EventHandler):
def replicationError(self, conn, message): def replicationError(self, conn, message):
self.app.replicator.abort('source message: ' + message) self.app.replicator.abort('source message: ' + message)
def checkingError(self, conn, message):
try:
self.app.checker.connectionLost(conn)
finally:
self.app.closeClient(conn)
@property
def answerCheckTIDRange(self):
return self.app.checker.checkRange
@property
def answerCheckSerialRange(self):
return self.app.checker.checkRange
# Server (all methods must set connection as server so that it isn't closed # Server (all methods must set connection as server so that it isn't closed
# if client tasks are finished) # if client tasks are finished)
def askCheckTIDRange(self, conn, min_tid, max_tid, length, partition): @checkFeedingConnection(check=True)
conn.asServer() def askCheckTIDRange(self, conn, *args):
count, tid_checksum, max_tid = self.app.dm.checkTIDRange(min_tid, msg_id = conn.getPeerId()
max_tid, length, partition) conn = weakref.proxy(conn)
conn.answer(Packets.AnswerCheckTIDRange(min_tid, length, def check():
count, tid_checksum, max_tid)) r = self.app.dm.checkTIDRange(*args)
try:
conn.answer(Packets.AnswerCheckTIDRange(*r), msg_id)
except (weakref.ReferenceError, ConnectorConnectionClosedException):
pass
yield
self.app.newTask(check())
def askCheckSerialRange(self, conn, min_oid, min_serial, max_tid, length, @checkFeedingConnection(check=True)
partition): def askCheckSerialRange(self, conn, *args):
conn.asServer() msg_id = conn.getPeerId()
count, oid_checksum, max_oid, serial_checksum, max_serial = \ conn = weakref.proxy(conn)
self.app.dm.checkSerialRange(min_oid, min_serial, max_tid, length, def check():
partition) r = self.app.dm.checkSerialRange(*args)
conn.answer(Packets.AnswerCheckSerialRange(min_oid, min_serial, length, try:
count, oid_checksum, max_oid, serial_checksum, max_serial)) conn.answer(Packets.AnswerCheckSerialRange(*r), msg_id)
except (weakref.ReferenceError, ConnectorConnectionClosedException):
pass
yield
self.app.newTask(check())
@checkFeedingConnection(check=False)
def askFetchTransactions(self, conn, partition, length, min_tid, max_tid, def askFetchTransactions(self, conn, partition, length, min_tid, max_tid,
tid_list): tid_list):
app = self.app app = self.app
cell = app.pt.getCell(partition, app.uuid)
if cell is None or cell.isOutOfDate():
return conn.answer(Errors.ReplicationError(
"partition %u not readable" % partition))
conn.asServer()
msg_id = conn.getPeerId() msg_id = conn.getPeerId()
conn = weakref.proxy(conn) conn = weakref.proxy(conn)
peer_tid_set = set(tid_list) peer_tid_set = set(tid_list)
...@@ -162,14 +197,10 @@ class StorageOperationHandler(EventHandler): ...@@ -162,14 +197,10 @@ class StorageOperationHandler(EventHandler):
pass pass
app.newTask(push()) app.newTask(push())
@checkFeedingConnection(check=False)
def askFetchObjects(self, conn, partition, length, min_tid, max_tid, def askFetchObjects(self, conn, partition, length, min_tid, max_tid,
min_oid, object_dict): min_oid, object_dict):
app = self.app app = self.app
cell = app.pt.getCell(partition, app.uuid)
if cell is None or cell.isOutOfDate():
return conn.answer(Errors.ReplicationError(
"partition %u not readable" % partition))
conn.asServer()
msg_id = conn.getPeerId() msg_id = conn.getPeerId()
conn = weakref.proxy(conn) conn = weakref.proxy(conn)
dm = app.dm dm = app.dm
......
...@@ -132,7 +132,7 @@ class Replicator(object): ...@@ -132,7 +132,7 @@ class Replicator(object):
outdated_list = [] outdated_list = []
for offset in xrange(pt.getPartitions()): for offset in xrange(pt.getPartitions()):
for cell in pt.getCellList(offset): for cell in pt.getCellList(offset):
if cell.getUUID() == uuid: if cell.getUUID() == uuid and not cell.isCorrupted():
self.partition_dict[offset] = p = Partition() self.partition_dict[offset] = p = Partition()
if cell.isOutOfDate(): if cell.isOutOfDate():
outdated_list.append(offset) outdated_list.append(offset)
...@@ -154,17 +154,25 @@ class Replicator(object): ...@@ -154,17 +154,25 @@ class Replicator(object):
abort = False abort = False
added_list = [] added_list = []
app = self.app app = self.app
last_tid, last_trans_dict, last_obj_dict = app.dm.getLastTIDs()
for offset, uuid, state in cell_list: for offset, uuid, state in cell_list:
if uuid == app.uuid: if uuid == app.uuid:
if state == CellStates.DISCARDED: if state in (CellStates.DISCARDED, CellStates.CORRUPTED):
try:
del self.partition_dict[offset] del self.partition_dict[offset]
except KeyError:
continue
self.replicate_dict.pop(offset, None) self.replicate_dict.pop(offset, None)
self.source_dict.pop(offset, None) self.source_dict.pop(offset, None)
abort = abort or self.current_partition == offset abort = abort or self.current_partition == offset
elif state == CellStates.OUT_OF_DATE: elif state == CellStates.OUT_OF_DATE:
assert offset not in self.partition_dict assert offset not in self.partition_dict
self.partition_dict[offset] = p = Partition() self.partition_dict[offset] = p = Partition()
p.next_trans = p.next_obj = ZERO_TID try:
p.next_trans = add64(last_trans_dict[offset], 1)
except KeyError:
p.next_trans = ZERO_TID
p.next_obj = last_obj_dict.get(offset, ZERO_TID)
p.max_ttid = INVALID_TID p.max_ttid = INVALID_TID
added_list.append(offset) added_list.append(offset)
if added_list: if added_list:
...@@ -212,11 +220,10 @@ class Replicator(object): ...@@ -212,11 +220,10 @@ class Replicator(object):
self.current_partition = offset self.current_partition = offset
previous_node = self.current_node previous_node = self.current_node
self.current_node = node self.current_node = node
if node.isConnected(): if node.isConnected(connecting=True):
if node.isIdentified():
node.getConnection().asClient() node.getConnection().asClient()
self.fetchTransactions() self.fetchTransactions()
if node is previous_node:
return
else: else:
assert name or node.getUUID() != app.uuid, "loopback connection" assert name or node.getUUID() != app.uuid, "loopback connection"
conn = ClientConnection(app.em, StorageOperationHandler(app), conn = ClientConnection(app.em, StorageOperationHandler(app),
...@@ -224,7 +231,11 @@ class Replicator(object): ...@@ -224,7 +231,11 @@ class Replicator(object):
conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE, conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE,
None if name else app.uuid, app.server, name or app.name)) None if name else app.uuid, app.server, name or app.name))
if previous_node is not None and previous_node.isConnected(): if previous_node is not None and previous_node.isConnected():
previous_node.getConnection().closeClient() app.closeClient(previous_node.getConnection())
def connected(self, node):
if self.current_node is node and self.current_partition is not None:
self.fetchTransactions()
def fetchTransactions(self, min_tid=None): def fetchTransactions(self, min_tid=None):
offset = self.current_partition offset = self.current_partition
......
...@@ -15,10 +15,11 @@ ...@@ -15,10 +15,11 @@
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from binascii import a2b_hex
import unittest import unittest
from mock import Mock from mock import Mock
from neo.lib.util import dump, p64, u64 from neo.lib.util import dump, p64, u64
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, MAX_TID from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID, MAX_TID
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.exception import DatabaseFailure from neo.lib.exception import DatabaseFailure
...@@ -499,10 +500,6 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -499,10 +500,6 @@ class StorageDBTests(NeoUnitTestBase):
def test_getReplicationTIDList(self): def test_getReplicationTIDList(self):
self.setNumPartitions(2, True) self.setNumPartitions(2, True)
tid1, tid2, tid3, tid4 = self._storeTransactions(4) tid1, tid2, tid3, tid4 = self._storeTransactions(4)
# get tids
# - all
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 0)
self.checkSet(result, [tid1, tid3])
# - one partition # - one partition
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 0) result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 0)
self.checkSet(result, [tid1, tid3]) self.checkSet(result, [tid1, tid3])
...@@ -519,6 +516,37 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -519,6 +516,37 @@ class StorageDBTests(NeoUnitTestBase):
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 1, 0) result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 1, 0)
self.checkSet(result, [tid1]) self.checkSet(result, [tid1])
def test_checkRange(self):
def check(trans, obj, *args):
self.assertEqual(trans, self.db.checkTIDRange(*args))
self.assertEqual(obj, self.db.checkSerialRange(*(args+(ZERO_OID,))))
self.setNumPartitions(2, True)
tid1, tid2, tid3, tid4 = self._storeTransactions(4)
z = 0, ZERO_HASH, ZERO_TID, ZERO_HASH, ZERO_OID
# - one partition
check((2, a2b_hex('84320eb8dbbe583f67055c15155ab6794f11654d'), tid3),
z,
0, 10, ZERO_TID, MAX_TID)
# - another partition
check((2, a2b_hex('1f02f98cf775a9e0ce9252ff5972dce728c4ddb0'), tid4),
(4, a2b_hex('e5b47bddeae2096220298df686737d939a27d736'), tid4,
a2b_hex('1e9093698424b5370e19acd2d5fc20dcd56a32cd'), p64(1)),
1, 10, ZERO_TID, MAX_TID)
self.assertEqual(
(3, a2b_hex('b85e2d4914e22b5ad3b82b312b3dc405dc17dcb8'), tid4,
a2b_hex('1b6d73ecdc064595fe915a5c26da06b195caccaa'), p64(1)),
self.db.checkSerialRange(1, 10, ZERO_TID, MAX_TID, p64(2)))
# - min_tid is inclusive
check((1, a2b_hex('da4b9237bacccdf19c0760cab7aec4a8359010b0'), tid3),
z,
0, 10, tid3, MAX_TID)
# - max tid is inclusive
x = 1, a2b_hex('b6589fc6ab0dc82cf12099d1c2d40ab994e8410c'), tid1
check(x, z, 0, 10, ZERO_TID, tid2)
# - limit
y = 1, a2b_hex('356a192b7913b04c54574d18c28d46e6395428ab'), tid2
check(y, x + y[1:], 1, 1, ZERO_TID, MAX_TID)
def test_findUndoTID(self): def test_findUndoTID(self):
self.setNumPartitions(4, True) self.setNumPartitions(4, True)
db = self.db db = self.db
......
...@@ -638,71 +638,6 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -638,71 +638,6 @@ class ProtocolTests(NeoUnitTestBase):
def test_AnswerTIDsFrom(self): def test_AnswerTIDsFrom(self):
self._test_AnswerTIDs(Packets.AnswerTIDsFrom) self._test_AnswerTIDs(Packets.AnswerTIDsFrom)
def test_AskCheckTIDRange(self):
min_tid = self.getNextTID()
max_tid = self.getNextTID()
length = 2
partition = 4
p = Packets.AskCheckTIDRange(min_tid, max_tid, length, partition)
p_min_tid, p_max_tid, p_length, p_partition = p.decode()
self.assertEqual(p_min_tid, min_tid)
self.assertEqual(p_max_tid, max_tid)
self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition)
def test_AnswerCheckTIDRange(self):
min_tid = self.getNextTID()
length = 2
count = 1
tid_checksum = "3" * 20
max_tid = self.getNextTID()
p = Packets.AnswerCheckTIDRange(min_tid, length, count, tid_checksum,
max_tid)
p_min_tid, p_length, p_count, p_tid_checksum, p_max_tid = p.decode()
self.assertEqual(p_min_tid, min_tid)
self.assertEqual(p_length, length)
self.assertEqual(p_count, count)
self.assertEqual(p_tid_checksum, tid_checksum)
self.assertEqual(p_max_tid, max_tid)
def test_AskCheckSerialRange(self):
min_oid = self.getOID(1)
min_serial = self.getNextTID()
max_tid = self.getNextTID()
length = 2
partition = 4
p = Packets.AskCheckSerialRange(min_oid, min_serial, max_tid, length,
partition)
p_min_oid, p_min_serial, p_max_tid, p_length, p_partition = p.decode()
self.assertEqual(p_min_oid, min_oid)
self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_max_tid, max_tid)
self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition)
def test_AnswerCheckSerialRange(self):
min_oid = self.getOID(1)
min_serial = self.getNextTID()
length = 2
count = 1
oid_checksum = "4" * 20
max_oid = self.getOID(5)
tid_checksum = "5" * 20
max_serial = self.getNextTID()
p = Packets.AnswerCheckSerialRange(min_oid, min_serial, length, count,
oid_checksum, max_oid, tid_checksum, max_serial)
p_min_oid, p_min_serial, p_length, p_count, p_oid_checksum, \
p_max_oid, p_tid_checksum, p_max_serial = p.decode()
self.assertEqual(p_min_oid, min_oid)
self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_length, length)
self.assertEqual(p_count, count)
self.assertEqual(p_oid_checksum, oid_checksum)
self.assertEqual(p_max_oid, max_oid)
self.assertEqual(p_tid_checksum, tid_checksum)
self.assertEqual(p_max_serial, max_serial)
def test_AskPack(self): def test_AskPack(self):
tid = self.getNextTID() tid = self.getNextTID()
p = Packets.AskPack(tid) p = Packets.AskPack(tid)
......
...@@ -715,7 +715,7 @@ class NEOCluster(object): ...@@ -715,7 +715,7 @@ class NEOCluster(object):
txn = transaction.Transaction() txn = transaction.Transaction()
storage.tpc_begin(txn, tid(i)) storage.tpc_begin(txn, tid(i))
for o in oid_list: for o in oid_list:
storage.store(p64(o), tid_dict.get(o), repr((i, o)), '', txn) storage.store(oid(o), tid_dict.get(o), repr((i, o)), '', txn)
storage.tpc_vote(txn) storage.tpc_vote(txn)
i = storage.tpc_finish(txn) i = storage.tpc_finish(txn)
for o in oid_list: for o in oid_list:
...@@ -788,6 +788,7 @@ def predictable_random(seed=None): ...@@ -788,6 +788,7 @@ def predictable_random(seed=None):
# Because we have 2 running threads when client works, we can't # Because we have 2 running threads when client works, we can't
# patch neo.client.pool (and cluster should have 1 storage). # patch neo.client.pool (and cluster should have 1 storage).
from neo.master import backup_app from neo.master import backup_app
from neo.master.handlers import administration
from neo.storage import replicator from neo.storage import replicator
def decorator(wrapped): def decorator(wrapped):
def wrapper(*args, **kw): def wrapper(*args, **kw):
...@@ -800,10 +801,12 @@ def predictable_random(seed=None): ...@@ -800,10 +801,12 @@ def predictable_random(seed=None):
if node_type == NodeTypes.CLIENT else if node_type == NodeTypes.CLIENT else
UUID_NAMESPACES[node_type] + ''.join( UUID_NAMESPACES[node_type] + ''.join(
chr(r.randrange(256)) for _ in xrange(15))) chr(r.randrange(256)) for _ in xrange(15)))
backup_app.random = replicator.random = r administration.random = backup_app.random = replicator.random \
= r
return wrapped(*args, **kw) return wrapped(*args, **kw)
finally: finally:
del MasterApplication.getNewUUID del MasterApplication.getNewUUID
backup_app.random = replicator.random = random administration.random = backup_app.random = replicator.random \
= random
return wraps(wrapped)(wrapper) return wraps(wrapped)(wrapper)
return decorator return decorator
...@@ -23,11 +23,13 @@ import threading ...@@ -23,11 +23,13 @@ import threading
import transaction import transaction
import unittest import unittest
import neo.lib import neo.lib
from neo.storage.checker import CHECK_COUNT
from neo.storage.transactions import TransactionManager, \ from neo.storage.transactions import TransactionManager, \
DelayedError, ConflictError DelayedError, ConflictError
from neo.lib.connection import MTClientConnection from neo.lib.connection import MTClientConnection
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \ from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
ZERO_OID, ZERO_TID, MAX_TID ZERO_OID, ZERO_TID, MAX_TID
from neo.lib.util import p64
from . import NEOCluster, NEOThreadedTest, Patch, predictable_random from . import NEOCluster, NEOThreadedTest, Patch, predictable_random
from neo.client.pool import CELL_CONNECTED, CELL_GOOD from neo.client.pool import CELL_CONNECTED, CELL_GOOD
...@@ -36,8 +38,9 @@ class ReplicationTests(NEOThreadedTest): ...@@ -36,8 +38,9 @@ class ReplicationTests(NEOThreadedTest):
def checksumPartition(self, storage, partition): def checksumPartition(self, storage, partition):
dm = storage.dm dm = storage.dm
args = ZERO_TID, MAX_TID, None, partition args = partition, None, ZERO_TID, MAX_TID
return dm.checkTIDRange(*args), dm.checkSerialRange(ZERO_TID, *args) return dm.checkTIDRange(*args), \
dm.checkSerialRange(min_oid=ZERO_OID, *args)
def checkPartitionReplicated(self, source, destination, partition): def checkPartitionReplicated(self, source, destination, partition):
self.assertEqual(self.checksumPartition(source, partition), self.assertEqual(self.checksumPartition(source, partition),
...@@ -60,25 +63,28 @@ class ReplicationTests(NEOThreadedTest): ...@@ -60,25 +63,28 @@ class ReplicationTests(NEOThreadedTest):
return checked return checked
def testBackupNormalCase(self): def testBackupNormalCase(self):
upstream = NEOCluster(partitions=7, replicas=1, storage_count=3) np = 7
nr = 2
check_dict = dict.fromkeys(xrange(np))
upstream = NEOCluster(partitions=np, replicas=nr-1, storage_count=3)
try: try:
upstream.start() upstream.start()
importZODB = upstream.importZODB() importZODB = upstream.importZODB()
importZODB(3) importZODB(3)
upstream.client.setPoll(0) upstream.client.setPoll(0)
backup = NEOCluster(partitions=7, replicas=1, storage_count=5, backup = NEOCluster(partitions=np, replicas=nr-1, storage_count=5,
upstream=upstream) upstream=upstream)
try: try:
backup.start() backup.start()
# Initialize & catch up. # Initialize & catch up.
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP) backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
backup.tic() backup.tic()
self.assertEqual(14, self.checkBackup(backup)) self.assertEqual(np*nr, self.checkBackup(backup))
# Normal case, following upstream cluster closely. # Normal case, following upstream cluster closely.
importZODB(17) importZODB(17)
upstream.client.setPoll(0) upstream.client.setPoll(0)
backup.tic() backup.tic()
self.assertEqual(14, self.checkBackup(backup)) self.assertEqual(np*nr, self.checkBackup(backup))
# Check that a backup cluster can be restarted. # Check that a backup cluster can be restarted.
finally: finally:
backup.stop() backup.stop()
...@@ -90,11 +96,13 @@ class ReplicationTests(NEOThreadedTest): ...@@ -90,11 +96,13 @@ class ReplicationTests(NEOThreadedTest):
importZODB(17) importZODB(17)
upstream.client.setPoll(0) upstream.client.setPoll(0)
backup.tic() backup.tic()
self.assertEqual(14, self.checkBackup(backup)) self.assertEqual(np*nr, self.checkBackup(backup))
backup.neoctl.checkReplicas(check_dict, ZERO_TID, None)
backup.tic()
# Stop backing up, nothing truncated. # Stop backing up, nothing truncated.
backup.neoctl.setClusterState(ClusterStates.STOPPING_BACKUP) backup.neoctl.setClusterState(ClusterStates.STOPPING_BACKUP)
backup.tic() backup.tic()
self.assertEqual(14, self.checkBackup(backup)) self.assertEqual(np*nr, self.checkBackup(backup))
self.assertEqual(backup.neoctl.getClusterState(), self.assertEqual(backup.neoctl.getClusterState(),
ClusterStates.RUNNING) ClusterStates.RUNNING)
finally: finally:
...@@ -110,6 +118,8 @@ class ReplicationTests(NEOThreadedTest): ...@@ -110,6 +118,8 @@ class ReplicationTests(NEOThreadedTest):
- primary storage disconnected from backup master - primary storage disconnected from backup master
- non-primary storage disconnected from backup master - non-primary storage disconnected from backup master
""" """
np = 4
check_dict = dict.fromkeys(xrange(np))
from neo.master.backup_app import random from neo.master.backup_app import random
def fetchObjects(orig, min_tid=None, min_oid=ZERO_OID): def fetchObjects(orig, min_tid=None, min_oid=ZERO_OID):
if min_tid is None: if min_tid is None:
...@@ -124,11 +134,11 @@ class ReplicationTests(NEOThreadedTest): ...@@ -124,11 +134,11 @@ class ReplicationTests(NEOThreadedTest):
node_list.remove(txn.getNode()) node_list.remove(txn.getNode())
node_list[0].getConnection().close() node_list[0].getConnection().close()
return orig(txn) return orig(txn)
upstream = NEOCluster(partitions=4, replicas=0, storage_count=1) upstream = NEOCluster(partitions=np, replicas=0, storage_count=1)
try: try:
upstream.start() upstream.start()
importZODB = upstream.importZODB(random=random) importZODB = upstream.importZODB(random=random)
backup = NEOCluster(partitions=4, replicas=2, storage_count=4, backup = NEOCluster(partitions=np, replicas=2, storage_count=4,
upstream=upstream) upstream=upstream)
try: try:
backup.start() backup.start()
...@@ -160,8 +170,10 @@ class ReplicationTests(NEOThreadedTest): ...@@ -160,8 +170,10 @@ class ReplicationTests(NEOThreadedTest):
finally: finally:
del p del p
upstream.client.setPoll(0) upstream.client.setPoll(0)
if event > 5:
backup.neoctl.checkReplicas(check_dict, ZERO_TID, None)
backup.tic() backup.tic()
self.assertEqual(12, self.checkBackup(backup)) self.assertEqual(np*3, self.checkBackup(backup))
finally: finally:
backup.stop() backup.stop()
finally: finally:
...@@ -196,12 +208,13 @@ class ReplicationTests(NEOThreadedTest): ...@@ -196,12 +208,13 @@ class ReplicationTests(NEOThreadedTest):
# default for performance reason # default for performance reason
orig.im_self.dropPartitions((offset,)) orig.im_self.dropPartitions((offset,))
return orig(ptid, cell_list) return orig(ptid, cell_list)
cluster = NEOCluster(partitions=3, replicas=1, storage_count=3) np = 3
cluster = NEOCluster(partitions=np, replicas=1, storage_count=3)
s0, s1, s2 = cluster.storage_list s0, s1, s2 = cluster.storage_list
for delayed in Packets.AskFetchTransactions, Packets.AskFetchObjects: for delayed in Packets.AskFetchTransactions, Packets.AskFetchObjects:
try: try:
cluster.start([s0]) cluster.start([s0])
cluster.populate([range(6)] * 3) cluster.populate([range(np*2)] * np)
cluster.client.setPoll(0) cluster.client.setPoll(0)
s1.start() s1.start()
s2.start() s2.start()
...@@ -223,6 +236,50 @@ class ReplicationTests(NEOThreadedTest): ...@@ -223,6 +236,50 @@ class ReplicationTests(NEOThreadedTest):
cluster.stop() cluster.stop()
cluster.reset(True) cluster.reset(True)
def testCheckReplicas(self):
from neo.storage import checker
def corrupt(offset):
s0, s1, s2 = (storage_dict[cell.getUUID()]
for cell in cluster.master.pt.getCellList(offset, True))
s1.dm.deleteObject(p64(np+offset), p64(corrupt_tid))
return s0.uuid
def check(expected_state, expected_count):
self.assertEqual(expected_count, len([None
for row in cluster.neoctl.getPartitionRowList()[1]
for cell in row[1]
if cell[1] == CellStates.CORRUPTED]))
self.assertEqual(expected_state, cluster.neoctl.getClusterState())
np = 5
tid_count = np * 3
corrupt_tid = tid_count // 2
check_dict = dict.fromkeys(xrange(np))
cluster = NEOCluster(partitions=np, replicas=2, storage_count=3)
try:
checker.CHECK_COUNT = 2
cluster.start()
cluster.populate([range(np*2)] * tid_count)
cluster.client.setPoll(0)
storage_dict = dict((x.uuid, x) for x in cluster.storage_list)
cluster.neoctl.checkReplicas(check_dict, ZERO_TID, None)
cluster.tic()
check(ClusterStates.RUNNING, 0)
source = corrupt(0)
cluster.neoctl.checkReplicas(check_dict, p64(corrupt_tid+1), None)
cluster.tic()
check(ClusterStates.RUNNING, 0)
cluster.neoctl.checkReplicas({0: source}, ZERO_TID, None)
cluster.tic()
check(ClusterStates.RUNNING, 1)
corrupt(1)
cluster.neoctl.checkReplicas(check_dict, p64(corrupt_tid+1), None)
cluster.tic()
check(ClusterStates.RUNNING, 1)
cluster.neoctl.checkReplicas(check_dict, ZERO_TID, None)
cluster.tic()
check(ClusterStates.VERIFYING, 4)
finally:
checker.CHECK_COUNT = CHECK_COUNT
cluster.stop()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment