Commit 8e3c7b01 authored by Julien Muchembled's avatar Julien Muchembled

Implements backup using specialised storage nodes and relying on replication

Replication is also fully reimplemented:
- It is not done anymore on whole partitions.
- It runs at lowest priority not to degrades performance for client nodes.

Schema of MySQL table is changed to optimize storage layout: rows are now
grouped by age, for good partial replication performance.
This certainly also speeds up simple loads/stores.
parent 75d83690
......@@ -111,42 +111,17 @@ RC - Review output of pylint (CODE)
consider using query(request, args) instead of query(request % args)
- Make listening address and port optionnal, and if they are not provided
listen on all interfaces on any available port.
- Replication throttling (HIGH AVAILABILITY)
In its current implementation, replication runs at full speed, which
degrades performance for client nodes. Replication should allow
throttling, and that throttling should be configurable.
See "Replication pipelining".
- Make replication speed configurable (HIGH AVAILABILITY)
In its current implementation, replication runs at lowest priority, not to
degrades performance for client nodes. But when there's only 1 storage
left for a partition, it may be wanted to guarantee a minimum speed to
avoid complete data loss if another failure happens too early.
- Pack segmentation & throttling (HIGH AVAILABILITY)
In its current implementation, pack runs in one call on all storage nodes
at the same time, which lcoks down the whole cluster. This task should
be split in chunks and processed in "background" on storage nodes.
Packing throttling should probably be at the lowest possible priority
(below interactive use and below replication).
- Replication pipelining (SPEED)
Replication work currently with too many exchanges between replicating
storage, and network latency can become a significant limit.
This should be changed to have just one initial request from
replicating storage, and multiple packets from reference storage with
database range checksums. When receiving these checksums, replicating
storage must compare with what it has, and ask row lists (might not even
be required) and data when there are differences. Quick fetching from
network with asynchronous checking (=queueing) + congestion control
(asking reference storage's to pause its packet flow) will probably be
required.
This should make it easier to throttle replication workload on reference
storage node, as it can decide to postpone replication-related packets on
its own.
- Partial replication (SPEED)
In its current implementation, replication always happens on a whole
partition. In typical use, only a few last transactions will have been
missed, so replicating only past a given TID would be much faster.
To achieve this, storage nodes must store 2 values:
- a pack identifier, which must be different each time a pack occurs
(increasing number sequence, TID-ish, etc) to trigger a
whole-partition replication when a pack happened (this could be
improved too, later)
- the latest (-ish) transaction committed locally, to use as a lower
replication boundary
- tpc_finish failures propagation to master (FUNCTIONALITY)
When asked to lock transaction data, if something goes wrong the master
node must be informed.
......
......@@ -9,7 +9,7 @@ SQL commands to migrate each storage from NEO 0.10.x::
CREATE TABLE new_data (id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, hash BINARY(20) NOT NULL UNIQUE, compression TINYINT UNSIGNED NULL, value LONGBLOB NULL) ENGINE = InnoDB SELECT DISTINCT obj.hash as hash, compression, value FROM obj, data WHERE obj.hash=data.hash ORDER BY serial;
DROP TABLE data;
RENAME TABLE new_data TO data;
CREATE TABLE new_obj (partition SMALLINT UNSIGNED NOT NULL, oid BIGINT UNSIGNED NOT NULL, serial BIGINT UNSIGNED NOT NULL, data_id BIGINT UNSIGNED NULL, value_serial BIGINT UNSIGNED NULL, PRIMARY KEY (partition, oid, serial), KEY (data_id)) ENGINE = InnoDB SELECT partition, oid, serial, data.id as data_id, value_serial FROM obj LEFT JOIN data ON (obj.hash=data.hash);
CREATE TABLE new_obj (partition SMALLINT UNSIGNED NOT NULL, oid BIGINT UNSIGNED NOT NULL, serial BIGINT UNSIGNED NOT NULL, data_id BIGINT UNSIGNED NULL, value_serial BIGINT UNSIGNED NULL, PRIMARY KEY (partition, serial, oid), KEY (partition, oid, serial), KEY (data_id)) ENGINE = InnoDB SELECT partition, oid, serial, data.id as data_id, value_serial FROM obj LEFT JOIN data ON (obj.hash=data.hash);
DROP TABLE obj;
RENAME TABLE new_obj TO obj;
ALTER TABLE tobj CHANGE hash data_id BIGINT UNSIGNED NULL;
......
......@@ -959,7 +959,7 @@ class Application(object):
tid_list = []
# request a tid list for each partition
for offset in xrange(self.pt.getPartitions()):
p = Packets.AskTIDsFrom(start, stop, limit, [offset])
p = Packets.AskTIDsFrom(start, stop, limit, offset)
for node, conn in self.cp.iterateForObject(offset, readable=True):
try:
r = self._askStorage(conn, p)
......
......@@ -90,3 +90,8 @@ class ConfigurationManager(object):
# only from command line
return util.bin(self.argument_list.get('uuid', None))
def getUpstreamCluster(self):
return self.__get('upstream_cluster', True)
def getUpstreamMasters(self):
return util.parseMasterList(self.__get('upstream_masters'))
......@@ -79,6 +79,9 @@ class EpollEventManager(object):
self.epoll.unregister(fd)
del self.connection_dict[fd]
def isIdle(self):
return not (self._pending_processing or self.writer_set)
def _addPendingConnection(self, conn):
pending_processing = self._pending_processing
if conn not in pending_processing:
......
......@@ -48,6 +48,7 @@ class ErrorCodes(Enum):
PROTOCOL_ERROR = Enum.Item(4)
BROKEN_NODE = Enum.Item(5)
ALREADY_PENDING = Enum.Item(7)
REPLICATION_ERROR = Enum.Item(8)
ErrorCodes = ErrorCodes()
class ClusterStates(Enum):
......@@ -55,6 +56,9 @@ class ClusterStates(Enum):
VERIFYING = Enum.Item(2)
RUNNING = Enum.Item(3)
STOPPING = Enum.Item(4)
STARTING_BACKUP = Enum.Item(5)
BACKINGUP = Enum.Item(6)
STOPPING_BACKUP = Enum.Item(7)
ClusterStates = ClusterStates()
class NodeTypes(Enum):
......@@ -117,6 +121,7 @@ ZERO_TID = '\0' * 8
ZERO_OID = '\0' * 8
OID_LEN = len(INVALID_OID)
TID_LEN = len(INVALID_TID)
MAX_TID = '\x7f' + '\xff' * 7 # SQLite does not accept numbers above 2^63-1
UUID_NAMESPACES = {
NodeTypes.STORAGE: 'S',
......@@ -723,6 +728,7 @@ class LastIDs(Packet):
POID('last_oid'),
PTID('last_tid'),
PPTID('last_ptid'),
PTID('backup_tid'),
)
class PartitionTable(Packet):
......@@ -760,16 +766,6 @@ class PartitionChanges(Packet):
),
)
class ReplicationDone(Packet):
"""
Notify the master node that a partition has been successully replicated from
a storage to another.
S -> M
"""
_fmt = PStruct('notify_replication_done',
PNumber('offset'),
)
class StartOperation(Packet):
"""
Tell a storage nodes to start an operation. Until a storage node receives
......@@ -965,7 +961,7 @@ class GetObject(Packet):
"""
Ask a stored object by its OID and a serial or a TID if given. If a serial
is specified, the specified revision of an object will be returned. If
a TID is specified, an object right before the TID will be returned. S,C -> S.
a TID is specified, an object right before the TID will be returned. C -> S.
Answer the requested object. S -> C.
"""
_fmt = PStruct('ask_object',
......@@ -1003,16 +999,14 @@ class TIDList(Packet):
class TIDListFrom(Packet):
"""
Ask for length TIDs starting at min_tid. The order of TIDs is ascending.
S -> S.
Answer the requested TIDs. S -> S
C -> S.
Answer the requested TIDs. S -> C
"""
_fmt = PStruct('tid_list_from',
PTID('min_tid'),
PTID('max_tid'),
PNumber('length'),
PList('partition_list',
PNumber('partition'),
),
PNumber('partition'),
)
_answer = PStruct('answer_tids',
......@@ -1054,27 +1048,6 @@ class ObjectHistory(Packet):
PFHistoryList,
)
class ObjectHistoryFrom(Packet):
"""
Ask history information for a given object. The order of serials is
ascending, and starts at (or above) min_serial for min_oid. S -> S.
Answer the requested serials. S -> S.
"""
_fmt = PStruct('ask_object_history',
POID('min_oid'),
PTID('min_serial'),
PTID('max_serial'),
PNumber('length'),
PNumber('partition'),
)
_answer = PStruct('ask_finish_transaction',
PDict('object_dict',
POID('oid'),
PFTidList,
),
)
class PartitionList(Packet):
"""
All the following messages are for neoctl to admin node
......@@ -1341,6 +1314,110 @@ class NotifyReady(Packet):
"""
pass
# replication
class FetchTransactions(Packet):
"""
S -> S
"""
_fmt = PStruct('ask_transaction_list',
PNumber('partition'),
PNumber('length'),
PTID('min_tid'),
PTID('max_tid'),
PFTidList, # already known transactions
)
_answer = PStruct('answer_transaction_list',
PTID('pack_tid'),
PTID('next_tid'),
PFTidList, # transactions to delete
)
class AddTransaction(Packet):
"""
S -> S
"""
_fmt = PStruct('add_transaction',
PTID('tid'),
PString('user'),
PString('description'),
PString('extension'),
PBoolean('packed'),
PFOidList,
)
class FetchObjects(Packet):
"""
S -> S
"""
_fmt = PStruct('ask_object_list',
PNumber('partition'),
PNumber('length'),
PTID('min_tid'),
PTID('max_tid'),
POID('min_oid'),
PDict('object_dict', # already known objects
PTID('serial'),
PFOidList,
),
)
_answer = PStruct('answer_object_list',
PTID('pack_tid'),
PTID('next_tid'),
POID('next_oid'),
PDict('object_dict', # objects to delete
PTID('serial'),
PFOidList,
),
)
class AddObject(Packet):
"""
S -> S
"""
_fmt = PStruct('add_object',
POID('oid'),
PTID('serial'),
PBoolean('compression'),
PChecksum('checksum'),
PString('data'),
PTID('data_serial'),
)
class Replicate(Packet):
"""
M -> S
"""
_fmt = PStruct('replicate',
PTID('tid'),
PString('upstream_name'),
PDict('source_dict',
PNumber('partition'),
PAddress('address'),
)
)
class ReplicationDone(Packet):
"""
Notify the master node that a partition has been successully replicated from
a storage to another.
S -> M
"""
_fmt = PStruct('notify_replication_done',
PNumber('offset'),
PTID('tid'),
)
class Truncate(Packet):
"""
M -> S
"""
_fmt = PStruct('ask_truncate',
PTID('tid'),
)
_answer = PFEmpty
StaticRegistry = {}
def register(request, ignore_when_closed=None):
""" Register a packet in the packet registry """
......@@ -1516,16 +1593,12 @@ class Packets(dict):
ClusterState)
NotifyLastOID = register(
NotifyLastOID)
NotifyReplicationDone = register(
ReplicationDone)
AskObjectUndoSerial, AnswerObjectUndoSerial = register(
ObjectUndoSerial)
AskHasLock, AnswerHasLock = register(
HasLock)
AskTIDsFrom, AnswerTIDsFrom = register(
TIDListFrom)
AskObjectHistoryFrom, AnswerObjectHistoryFrom = register(
ObjectHistoryFrom)
AskPack, AnswerPack = register(
Pack, ignore_when_closed=False)
AskCheckTIDRange, AnswerCheckTIDRange = register(
......@@ -1540,6 +1613,20 @@ class Packets(dict):
CheckCurrentSerial)
NotifyTransactionFinished = register(
NotifyTransactionFinished)
Replicate = register(
Replicate)
NotifyReplicationDone = register(
ReplicationDone)
AskFetchTransactions, AnswerFetchTransactions = register(
FetchTransactions)
AskFetchObjects, AnswerFetchObjects = register(
FetchObjects)
AddTransaction = register(
AddTransaction)
AddObject = register(
AddObject)
AskTruncate, AnswerTruncate = register(
Truncate)
def Errors():
registry_dict = {}
......
......@@ -150,6 +150,11 @@ class PartitionTable(object):
return True
return False
def getCell(self, offset, uuid):
for cell in self.partition_list[offset]:
if cell.getUUID() == uuid:
return cell
def setCell(self, offset, node, state):
if state == CellStates.DISCARDED:
return self.removeCell(offset, node)
......@@ -157,28 +162,19 @@ class PartitionTable(object):
raise PartitionTableException('Invalid node state')
self.count_dict.setdefault(node, 0)
row = self.partition_list[offset]
if len(row) == 0:
# Create a new row.
row = [Cell(node, state), ]
if state != CellStates.FEEDING:
self.count_dict[node] += 1
self.partition_list[offset] = row
self.num_filled_rows += 1
for cell in self.partition_list[offset]:
if cell.getNode() is node:
if not cell.isFeeding():
self.count_dict[node] -= 1
cell.setState(state)
break
else:
# XXX this can be slow, but it is necessary to remove a duplicate,
# if any.
for cell in row:
if cell.getNode() == node:
row.remove(cell)
if not cell.isFeeding():
self.count_dict[node] -= 1
break
row = self.partition_list[offset]
self.num_filled_rows += not row
row.append(Cell(node, state))
if state != CellStates.FEEDING:
self.count_dict[node] += 1
return (offset, node.getUUID(), state)
if state != CellStates.FEEDING:
self.count_dict[node] += 1
return offset, node.getUUID(), state
def removeCell(self, offset, node):
row = self.partition_list[offset]
......
......@@ -28,6 +28,10 @@ from neo.lib.event import EventManager
from neo.lib.connection import ListeningConnection, ClientConnection
from neo.lib.exception import ElectionFailure, PrimaryFailure, OperationFailure
from neo.lib.util import dump
class StateChangedException(Exception): pass
from .backup_app import BackupApplication
from .handlers import election, identification, secondary
from .handlers import administration, client, storage, shutdown
from .pt import PartitionTable
......@@ -41,6 +45,8 @@ class Application(object):
packing = None
# Latest completely commited TID
last_transaction = ZERO_TID
backup_tid = None
backup_app = None
def __init__(self, config):
# Internal attributes.
......@@ -90,16 +96,29 @@ class Application(object):
self._current_manager = None
# backup
upstream_cluster = config.getUpstreamCluster()
if upstream_cluster:
if upstream_cluster == self.name:
raise ValueError("upstream cluster name must be"
" different from cluster name")
self.backup_app = BackupApplication(self, upstream_cluster,
*config.getUpstreamMasters())
registerLiveDebugger(on_log=self.log)
def close(self):
self.listening_conn = None
if self.backup_app is not None:
self.backup_app.close()
self.nm.close()
self.em.close()
del self.__dict__
def log(self):
self.em.log()
if self.backup_app is not None:
self.backup_app.log()
self.nm.log()
self.tm.log()
if self.pt is not None:
......@@ -257,27 +276,29 @@ class Application(object):
a shutdown.
"""
neo.lib.logging.info('provide service')
em = self.em
poll = self.em.poll
self.tm.reset()
self.changeClusterState(ClusterStates.RUNNING)
# Now everything is passive.
while True:
try:
em.poll(1)
except OperationFailure:
# If not operational, send Stop Operation packets to storage
# nodes and client nodes. Abort connections to client nodes.
neo.lib.logging.critical('No longer operational')
for node in self.nm.getIdentifiedList():
if node.isStorage() or node.isClient():
node.notify(Packets.StopOperation())
if node.isClient():
node.getConnection().abort()
# Then, go back, and restart.
return
try:
while True:
poll(1)
except OperationFailure:
# If not operational, send Stop Operation packets to storage
# nodes and client nodes. Abort connections to client nodes.
neo.lib.logging.critical('No longer operational')
except StateChangedException, e:
assert e.args[0] == ClusterStates.STARTING_BACKUP
self.backup_tid = tid = self.getLastTransaction()
self.pt.setBackupTidDict(dict((node.getUUID(), tid)
for node in self.nm.getStorageList(only_identified=True)))
for node in self.nm.getIdentifiedList():
if node.isStorage() or node.isClient():
node.notify(Packets.StopOperation())
if node.isClient():
node.getConnection().abort()
def playPrimaryRole(self):
neo.lib.logging.info(
......@@ -314,7 +335,13 @@ class Application(object):
self.runManager(RecoveryManager)
while True:
self.runManager(VerificationManager)
self.provideService()
if self.backup_tid:
if self.backup_app is None:
raise RuntimeError("No upstream cluster to backup"
" defined in configuration")
self.backup_app.provideService()
else:
self.provideService()
def playSecondaryRole(self):
"""
......@@ -364,7 +391,8 @@ class Application(object):
# select the storage handler
client_handler = client.ClientServiceHandler(self)
if state == ClusterStates.RUNNING:
if state in (ClusterStates.RUNNING, ClusterStates.STARTING_BACKUP,
ClusterStates.BACKINGUP, ClusterStates.STOPPING_BACKUP):
storage_handler = storage.StorageServiceHandler(self)
elif self._current_manager is not None:
storage_handler = self._current_manager.getHandler()
......@@ -389,8 +417,9 @@ class Application(object):
handler = storage_handler
else:
continue # keep handler
conn.setHandler(handler)
handler.connectionCompleted(conn)
if type(handler) is not type(conn.getLastHandler()):
conn.setHandler(handler)
handler.connectionCompleted(conn)
self.cluster_state = state
def getNewUUID(self, node_type):
......@@ -437,19 +466,13 @@ class Application(object):
sys.exit()
def identifyStorageNode(self, uuid, node):
state = NodeStates.RUNNING
handler = None
if self.cluster_state == ClusterStates.RUNNING:
if uuid is None or node is None:
# same as for verification
state = NodeStates.PENDING
handler = storage.StorageServiceHandler(self)
elif self.cluster_state == ClusterStates.STOPPING:
if self.cluster_state == ClusterStates.STOPPING:
raise NotReadyError
else:
raise RuntimeError('unhandled cluster state: %s' %
(self.cluster_state, ))
return (uuid, state, handler)
state = NodeStates.RUNNING
if uuid is None or node is None:
# same as for verification
state = NodeStates.PENDING
return uuid, state, storage.StorageServiceHandler(self)
def identifyNode(self, node_type, uuid, node):
......
##############################################################################
#
# Copyright (c) 2011 Nexedi SARL and Contributors. All Rights Reserved.
# Julien Muchembled <jm@nexedi.com>
#
# WARNING: This program as such is intended to be used by professional
# programmers who take the whole responsibility of assessing all potential
# consequences resulting from its eventual inadequacies and bugs
# End users who are looking for a ready-to-use solution with commercial
# guarantees and support are strongly advised to contract a Free Software
# Service Company
#
# This program is Free Software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
##############################################################################
import random, weakref
from bisect import bisect
import neo.lib
from neo.lib.bootstrap import BootstrapManager
from neo.lib.connector import getConnectorHandler
from neo.lib.exception import PrimaryFailure
from neo.lib.node import NodeManager
from neo.lib.protocol import CellStates, ClusterStates, NodeTypes, Packets
from neo.lib.protocol import INVALID_TID, ZERO_TID
from neo.lib.util import add64, u64, dump
from .app import StateChangedException
from .pt import PartitionTable
from .handlers.backup import BackupHandler
"""
Backup algorithm
This implementation relies on normal storage replication.
Storage nodes that are specialised for backup are not in the same NEO cluster,
but are managed by another master in a different cluster.
When the cluster is in BACKINGUP state, its master acts like a client to the
master of the main cluster. It gets notified of new data thanks to invalidation,
and notifies in turn its storage nodes what/when to replicate.
Storages stay in UP_TO_DATE state, even if partitions are synchronized up to
different tids. Storage nodes remember they are in such state and when
switching into RUNNING state, the cluster cuts the DB at the last TID for which
we have all data.
Out of backup storage nodes assigned to a partition, one is chosen as primary
for that partition. It means only this node will fetch data from the upstream
cluster, to minimize bandwidth between clusters. Other replicas will
synchronize from the primary node.
There is no UUID conflict between the 2 clusters:
- Storage nodes connect anonymously to upstream.
- Master node receives a new from upstream master and uses it only when
communicating with it.
"""
class BackupApplication(object):
pt = None
def __init__(self, app, name, master_addresses, connector_name):
self.app = weakref.proxy(app)
self.name = name
self.nm = NodeManager()
self.connector_handler = getConnectorHandler(connector_name)
for master_address in master_addresses:
self.nm.createMaster(address=master_address)
em = property(lambda self: self.app.em)
def close(self):
self.nm.close()
del self.__dict__
def log(self):
self.nm.log()
if self.pt is not None:
self.pt.log()
def provideService(self):
neo.lib.logging.info('provide backup')
poll = self.em.poll
app = self.app
pt = app.pt
while True:
app.changeClusterState(ClusterStates.STARTING_BACKUP)
bootstrap = BootstrapManager(self, self.name, NodeTypes.CLIENT)
# {offset -> node}
self.primary_partition_dict = {}
# [[tid]]
self.tid_list = tuple([] for _ in xrange(pt.getPartitions()))
try:
node, conn, uuid, num_partitions, num_replicas = \
bootstrap.getPrimaryConnection(self.connector_handler)
try:
app.changeClusterState(ClusterStates.BACKINGUP)
del bootstrap, node
if num_partitions != pt.getPartitions():
raise RuntimeError("inconsistent number of partitions")
self.pt = PartitionTable(num_partitions, num_replicas)
conn.setHandler(BackupHandler(self))
conn.ask(Packets.AskNodeInformation())
conn.ask(Packets.AskPartitionTable())
conn.ask(Packets.AskLastTransaction())
# debug variable to log how big 'tid_list' can be.
self.debug_tid_count = 0
while True:
poll(1)
except PrimaryFailure, msg:
neo.lib.logging.error('upstream master is down: %s', msg)
finally:
app.backup_tid = pt.getBackupTid()
try:
conn.close()
except PrimaryFailure:
pass
try:
del self.pt
except AttributeError:
pass
except StateChangedException, e:
app.changeClusterState(*e.args)
last_tid = app.getLastTransaction()
if last_tid < app.backup_tid:
neo.lib.logging.warning(
"Truncating at %s (last_tid was %s)",
dump(app.backup_tid), dump(last_tid))
p = Packets.AskTruncate(app.backup_tid)
connection_list = []
for node in app.nm.getStorageList(only_identified=True):
conn = node.getConnection()
conn.ask(p)
connection_list.append(conn)
for conn in connection_list:
while conn.isPending():
poll(1)
app.setLastTransaction(app.backup_tid)
del app.backup_tid
break
finally:
del self.primary_partition_dict, self.tid_list
def nodeLost(self, node):
getCellList = self.app.pt.getCellList
trigger_set = set()
for offset, primary_node in self.primary_partition_dict.items():
if primary_node is not node:
continue
cell_list = getCellList(offset, readable=True)
cell = max(cell_list, key=lambda cell: cell.backup_tid)
tid = cell.backup_tid
self.primary_partition_dict[offset] = primary_node = cell.getNode()
p = Packets.Replicate(tid, '', {offset: primary_node.getAddress()})
for cell in cell_list:
cell.replicating = tid
if cell.backup_tid < tid:
neo.lib.logging.debug(
"ask %s to replicate partition %u up to %u from %r",
dump(cell.getUUID()), offset, u64(tid),
dump(primary_node.getUUID()))
cell.getNode().getConnection().notify(p)
trigger_set.add(primary_node)
for node in trigger_set:
self.triggerBackup(node)
def invalidatePartitions(self, tid, partition_set):
app = self.app
prev_tid = app.getLastTransaction()
app.setLastTransaction(tid)
pt = app.pt
getByUUID = app.nm.getByUUID
trigger_set = set()
for offset in xrange(pt.getPartitions()):
try:
last_max_tid = self.tid_list[offset][-1]
except IndexError:
last_max_tid = INVALID_TID
if offset in partition_set:
self.tid_list[offset].append(tid)
node_list = []
for cell in pt.getCellList(offset, readable=True):
node = cell.getNode()
assert node.isConnected()
node_list.append(node)
if last_max_tid <= cell.backup_tid:
# This is the last time we can increase
# 'backup_tid' without replication.
neo.lib.logging.debug(
"partition %u: updating backup_tid of %r to %u",
offset, cell, u64(prev_tid))
cell.backup_tid = prev_tid
assert node_list
trigger_set.update(node_list)
# Make sure we have a primary storage for this partition.
if offset not in self.primary_partition_dict:
self.primary_partition_dict[offset] = \
random.choice(node_list)
else:
# Partition not touched, so increase 'backup_tid' of all
# "up-to-date" replicas, without having to replicate.
for cell in pt.getCellList(offset, readable=True):
if last_max_tid <= cell.backup_tid:
cell.backup_tid = tid
neo.lib.logging.debug(
"partition %u: updating backup_tid of %r to %u",
offset, cell, u64(tid))
for node in trigger_set:
self.triggerBackup(node)
count = sum(map(len, self.tid_list))
if self.debug_tid_count < count:
neo.lib.logging.debug("Maximum number of tracked tids: %u", count)
self.debug_tid_count = count
def triggerBackup(self, node):
tid_list = self.tid_list
tid = self.app.getLastTransaction()
replicate_list = []
for offset, cell in self.app.pt.iterNodeCell(node):
max_tid = tid_list[offset]
if max_tid and self.primary_partition_dict[offset] is node and \
max(cell.backup_tid, cell.replicating) < max_tid[-1]:
cell.replicating = tid
replicate_list.append(offset)
if not replicate_list:
return
getByUUID = self.nm.getByUUID
getCellList = self.pt.getCellList
source_dict = {}
address_set = set()
for offset in replicate_list:
cell_list = getCellList(offset, readable=True)
random.shuffle(cell_list)
assert cell_list, offset
for cell in cell_list:
addr = cell.getAddress()
if addr in address_set:
break
else:
address_set.add(addr)
source_dict[offset] = addr
neo.lib.logging.debug(
"ask %s to replicate partition %u up to %u from %r",
dump(node.getUUID()), offset, u64(tid), addr)
node.getConnection().notify(Packets.Replicate(
tid, self.name, source_dict))
def notifyReplicationDone(self, node, offset, tid):
app = self.app
cell = app.pt.getCell(offset, node.getUUID())
tid_list = self.tid_list[offset]
if tid_list: # may be empty if the cell is out-of-date
# or if we're not fully initialized
if tid < tid_list[0]:
cell.replicating = tid
else:
try:
tid = add64(tid_list[bisect(tid_list, tid)], -1)
except IndexError:
tid = app.getLastTransaction()
neo.lib.logging.debug("partition %u: updating backup_tid of %r to %u",
offset, cell, u64(tid))
cell.backup_tid = tid
# Forget tids we won't need anymore.
cell_list = app.pt.getCellList(offset, readable=True)
del tid_list[:bisect(tid_list, min(x.backup_tid for x in cell_list))]
primary_node = self.primary_partition_dict.get(offset)
primary = primary_node is node
result = None if primary else app.pt.setUpToDate(node, offset)
if app.getClusterState() == ClusterStates.BACKINGUP:
assert not cell.isOutOfDate()
if result: # was out-of-date
max_tid, = [x.backup_tid for x in cell_list
if x.getNode() is primary_node]
if tid < max_tid:
cell.replicating = max_tid
neo.lib.logging.debug(
"ask %s to replicate partition %u up to %u from %r",
dump(node.getUUID()), offset, u64(max_tid),
dump(primary_node.getUUID()))
node.getConnection().notify(Packets.Replicate(max_tid,
'', {offset: primary_node.getAddress()}))
else:
self.triggerBackup(node)
if primary:
# Notify secondary storages that they can replicate from
# primary ones, even if they are already replicating.
p = Packets.Replicate(tid, '', {offset: node.getAddress()})
for cell in cell_list:
if max(cell.backup_tid, cell.replicating) < tid:
cell.replicating = tid
neo.lib.logging.debug(
"ask %s to replicate partition %u up to %u from"
" %r", dump(cell.getUUID()), offset, u64(tid),
dump(node.getUUID()))
cell.getNode().getConnection().notify(p)
return result
......@@ -18,15 +18,18 @@
import neo
from . import MasterHandler
from ..app import StateChangedException
from neo.lib.protocol import ClusterStates, NodeStates, Packets, ProtocolError
from neo.lib.protocol import Errors
from neo.lib.util import dump
CLUSTER_STATE_WORKFLOW = {
# destination: sources
ClusterStates.VERIFYING: set([ClusterStates.RECOVERING]),
ClusterStates.STOPPING: set([ClusterStates.RECOVERING,
ClusterStates.VERIFYING, ClusterStates.RUNNING]),
ClusterStates.VERIFYING: (ClusterStates.RECOVERING,),
ClusterStates.STARTING_BACKUP: (ClusterStates.RUNNING,
ClusterStates.STOPPING_BACKUP),
ClusterStates.STOPPING_BACKUP: (ClusterStates.BACKINGUP,
ClusterStates.STARTING_BACKUP),
}
class AdministrationHandler(MasterHandler):
......@@ -42,16 +45,17 @@ class AdministrationHandler(MasterHandler):
conn.answer(Packets.AnswerPrimary(app.uuid, []))
def setClusterState(self, conn, state):
app = self.app
# check request
if state not in CLUSTER_STATE_WORKFLOW:
try:
if app.cluster_state not in CLUSTER_STATE_WORKFLOW[state]:
raise ProtocolError('Can not switch to this state')
except KeyError:
raise ProtocolError('Invalid state requested')
valid_current_states = CLUSTER_STATE_WORKFLOW[state]
if self.app.cluster_state not in valid_current_states:
raise ProtocolError('Cannot switch to this state')
# change state
if state == ClusterStates.VERIFYING:
storage_list = self.app.nm.getStorageList(only_identified=True)
storage_list = app.nm.getStorageList(only_identified=True)
if not storage_list:
raise ProtocolError('Cannot exit recovery without any '
'storage node')
......@@ -60,15 +64,18 @@ class AdministrationHandler(MasterHandler):
if node.getConnection().isPending():
raise ProtocolError('Cannot exit recovery now: node %r is '
'entering cluster' % (node, ))
self.app._startup_allowed = True
else:
self.app.changeClusterState(state)
app._startup_allowed = True
state = app.cluster_state
elif state == ClusterStates.STARTING_BACKUP:
if app.tm.hasPending() or app.nm.getClientList(True):
raise ProtocolError("Can not switch to %s state with pending"
" transactions or connected clients" % state)
elif state != ClusterStates.STOPPING_BACKUP:
app.changeClusterState(state)
# answer
conn.answer(Errors.Ack('Cluster state changed'))
if state == ClusterStates.STOPPING:
self.app.cluster_state = state
self.app.shutdown()
if state != app.cluster_state:
raise StateChangedException(state)
def setNodeState(self, conn, uuid, state, modify_partition_table):
neo.lib.logging.info("set node state for %s-%s : %s" %
......
##############################################################################
#
# Copyright (c) 2011 Nexedi SARL and Contributors. All Rights Reserved.
# Julien Muchembled <jm@nexedi.com>
#
# WARNING: This program as such is intended to be used by professional
# programmers who take the whole responsibility of assessing all potential
# consequences resulting from its eventual inadequacies and bugs
# End users who are looking for a ready-to-use solution with commercial
# guarantees and support are strongly advised to contract a Free Software
# Service Company
#
# This program is Free Software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
##############################################################################
from neo.lib.exception import PrimaryFailure
from neo.lib.handler import EventHandler
from neo.lib.protocol import CellStates
class BackupHandler(EventHandler):
"""Handler dedicated to upstream master during BACKINGUP state"""
def connectionLost(self, conn, new_state):
if self.app.app.listening_conn: # if running
raise PrimaryFailure('connection lost')
def answerPartitionTable(self, conn, ptid, row_list):
self.app.pt.load(ptid, row_list, self.app.nm)
def notifyPartitionChanges(self, conn, ptid, cell_list):
self.app.pt.update(ptid, cell_list, self.app.nm)
def answerNodeInformation(self, conn):
pass
def notifyNodeInformation(self, conn, node_list):
self.app.nm.update(node_list)
def answerLastTransaction(self, conn, tid):
app = self.app
app.invalidatePartitions(tid, set(xrange(app.pt.getPartitions())))
def invalidateObjects(self, conn, tid, oid_list):
app = self.app
getPartition = app.app.pt.getPartition
partition_set = set(map(getPartition, oid_list))
partition_set.add(getPartition(tid))
app.invalidatePartitions(tid, partition_set)
......@@ -16,7 +16,7 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import neo.lib
from neo.lib.protocol import Packets, ProtocolError
from neo.lib.protocol import ClusterStates, Packets, ProtocolError
from neo.lib.exception import OperationFailure
from neo.lib.util import dump
from neo.lib.connector import ConnectorConnectionClosedException
......@@ -45,14 +45,18 @@ class StorageServiceHandler(BaseServiceHandler):
if not app.pt.operational():
raise OperationFailure, 'cannot continue operation'
app.tm.forget(conn.getUUID())
if app.getClusterState() == ClusterStates.BACKINGUP:
app.backup_app.nodeLost(node)
if app.packing is not None:
self.answerPack(conn, False)
def askLastIDs(self, conn):
app = self.app
loid = app.tm.getLastOID()
ltid = app.tm.getLastTID()
conn.answer(Packets.AnswerLastIDs(loid, ltid, app.pt.getID()))
conn.answer(Packets.AnswerLastIDs(
app.tm.getLastOID(),
app.tm.getLastTID(),
app.pt.getID(),
app.backup_tid))
def askUnfinishedTransactions(self, conn):
tm = self.app.tm
......@@ -68,15 +72,26 @@ class StorageServiceHandler(BaseServiceHandler):
# transaction locked on this storage node
self.app.tm.lock(ttid, conn.getUUID())
def notifyReplicationDone(self, conn, offset):
node = self.app.nm.getByUUID(conn.getUUID())
neo.lib.logging.debug("%s is up for offset %s" % (node, offset))
try:
cell_list = self.app.pt.setUpToDate(node, offset)
except PartitionTableException, e:
raise ProtocolError(str(e))
def notifyReplicationDone(self, conn, offset, tid):
app = self.app
node = app.nm.getByUUID(conn.getUUID())
if app.backup_tid:
cell_list = app.backup_app.notifyReplicationDone(node, offset, tid)
if not cell_list:
return
else:
try:
cell_list = self.app.pt.setUpToDate(node, offset)
if not cell_list:
raise ProtocolError('Non-oudated partition')
except PartitionTableException, e:
raise ProtocolError(str(e))
neo.lib.logging.debug("%s is up for offset %s", node, offset)
self.app.broadcastPartitionChanges(cell_list)
def answerTruncate(self, conn):
pass
def answerPack(self, conn, status):
app = self.app
if app.packing is not None:
......
......@@ -17,11 +17,25 @@
import neo.lib.pt
from struct import pack, unpack
from neo.lib.protocol import CellStates
from neo.lib.pt import PartitionTableException
from neo.lib.pt import PartitionTable
from neo.lib.protocol import CellStates, ZERO_TID
class PartitionTable(PartitionTable):
class Cell(neo.lib.pt.Cell):
replicating = ZERO_TID
def setState(self, state):
try:
if CellStates.OUT_OF_DATE == state != self.state:
del self.backup_tid, self.replicating
except AttributeError:
pass
return super(Cell, self).setState(state)
neo.lib.pt.Cell = Cell
class PartitionTable(neo.lib.pt.PartitionTable):
"""This class manages a partition table for the primary master node"""
def setID(self, id):
......@@ -54,7 +68,7 @@ class PartitionTable(PartitionTable):
row = []
for _ in xrange(repeats):
node = node_list[index]
row.append(neo.lib.pt.Cell(node))
row.append(Cell(node))
self.count_dict[node] = self.count_dict.get(node, 0) + 1
index += 1
if index == len(node_list):
......@@ -88,7 +102,7 @@ class PartitionTable(PartitionTable):
node_list = [c.getNode() for c in row]
n = self.findLeastUsedNode(node_list)
if n is not None:
row.append(neo.lib.pt.Cell(n,
row.append(Cell(n,
CellStates.OUT_OF_DATE))
self.count_dict[n] += 1
cell_list.append((offset, n.getUUID(),
......@@ -132,11 +146,11 @@ class PartitionTable(PartitionTable):
# check the partition is assigned and known as outdated
for cell in self.getCellList(offset):
if cell.getUUID() == uuid:
if not cell.isOutOfDate():
raise PartitionTableException('Non-oudated partition')
break
if cell.isOutOfDate():
break
return
else:
raise PartitionTableException('Non-assigned partition')
raise neo.lib.pt.PartitionTableException('Non-assigned partition')
# update the partition table
cell_list = [self.setCell(offset, node, CellStates.UP_TO_DATE)]
......@@ -177,7 +191,7 @@ class PartitionTable(PartitionTable):
else:
if num_cells <= self.nr:
row.append(neo.lib.pt.Cell(node, CellStates.OUT_OF_DATE))
row.append(Cell(node, CellStates.OUT_OF_DATE))
cell_list.append((offset, node.getUUID(),
CellStates.OUT_OF_DATE))
node_count += 1
......@@ -196,7 +210,7 @@ class PartitionTable(PartitionTable):
CellStates.FEEDING))
# Don't count a feeding cell.
self.count_dict[max_cell.getNode()] -= 1
row.append(neo.lib.pt.Cell(node, CellStates.OUT_OF_DATE))
row.append(Cell(node, CellStates.OUT_OF_DATE))
cell_list.append((offset, node.getUUID(),
CellStates.OUT_OF_DATE))
node_count += 1
......@@ -277,7 +291,7 @@ class PartitionTable(PartitionTable):
node = self.findLeastUsedNode([cell.getNode() for cell in row])
if node is None:
break
row.append(neo.lib.pt.Cell(node, CellStates.OUT_OF_DATE))
row.append(Cell(node, CellStates.OUT_OF_DATE))
changed_cell_list.append((offset, node.getUUID(),
CellStates.OUT_OF_DATE))
self.count_dict[node] += 1
......@@ -309,6 +323,13 @@ class PartitionTable(PartitionTable):
CellStates.OUT_OF_DATE))
return change_list
def iterNodeCell(self, node):
for offset, row in enumerate(self.partition_list):
for cell in row:
if cell.getNode() is node:
yield offset, cell
break
def getUpToDateCellNodeSet(self):
"""
Return a set of all nodes which are part of at least one UP TO DATE
......@@ -329,3 +350,16 @@ class PartitionTable(PartitionTable):
for cell in row
if cell.isOutOfDate())
def setBackupTidDict(self, backup_tid_dict):
for row in self.partition_list:
for cell in row:
cell.backup_tid = backup_tid_dict.get(cell.getUUID(),
ZERO_TID)
def getBackupTid(self):
try:
return min(max(cell.backup_tid for cell in row
if not cell.isOutOfDate())
for row in self.partition_list)
except ValueError:
return ZERO_TID
......@@ -33,6 +33,7 @@ class RecoveryManager(MasterHandler):
super(RecoveryManager, self).__init__(app)
# The target node's uuid to request next.
self.target_ptid = None
self.backup_tid_dict = {}
def getHandler(self):
return self
......@@ -98,6 +99,9 @@ class RecoveryManager(MasterHandler):
app.tm.setLastOID(ZERO_OID)
pt.make(allowed_node_set)
self._broadcastPartitionTable(pt.getID(), pt.getRowList())
elif app.backup_tid:
pt.setBackupTidDict(self.backup_tid_dict)
app.backup_tid = pt.getBackupTid()
app.setLastTransaction(app.tm.getLastTID())
neo.lib.logging.debug(
......@@ -118,7 +122,7 @@ class RecoveryManager(MasterHandler):
# ask the last IDs to perform the recovery
conn.ask(Packets.AskLastIDs())
def answerLastIDs(self, conn, loid, ltid, lptid):
def answerLastIDs(self, conn, loid, ltid, lptid, backup_tid):
# Get max values.
if loid is not None:
self.app.tm.setLastOID(loid)
......@@ -128,6 +132,7 @@ class RecoveryManager(MasterHandler):
# something newer
self.target_ptid = lptid
conn.ask(Packets.AskPartitionTable())
self.backup_tid_dict[conn.getUUID()] = backup_tid
def answerPartitionTable(self, conn, ptid, row_list):
if ptid != self.target_ptid:
......@@ -136,6 +141,7 @@ class RecoveryManager(MasterHandler):
dump(self.target_ptid))
else:
self._broadcastPartitionTable(ptid, row_list)
self.app.backup_tid = self.backup_tid_dict[conn.getUUID()]
def _broadcastPartitionTable(self, ptid, row_list):
try:
......
......@@ -113,19 +113,21 @@ class VerificationManager(BaseServiceHandler):
def verifyData(self):
"""Verify the data in storage nodes and clean them up, if necessary."""
em, nm = self.app.em, self.app.nm
app = self.app
# wait for any missing node
neo.lib.logging.debug('waiting for the cluster to be operational')
while not self.app.pt.operational():
em.poll(1)
while not app.pt.operational():
app.em.poll(1)
if app.backup_tid:
return
neo.lib.logging.info('start to verify data')
getIdentifiedList = app.nm.getIdentifiedList
# Gather all unfinished transactions.
self._askStorageNodesAndWait(Packets.AskUnfinishedTransactions(),
[x for x in self.app.nm.getIdentifiedList() if x.isStorage()])
[x for x in getIdentifiedList() if x.isStorage()])
# Gather OIDs for each unfinished TID, and verify whether the
# transaction can be finished or must be aborted. This could be
......@@ -136,17 +138,16 @@ class VerificationManager(BaseServiceHandler):
if uuid_set is None:
packet = Packets.DeleteTransaction(tid, self._oid_set or [])
# Make sure that no node has this transaction.
for node in self.app.nm.getIdentifiedList():
for node in getIdentifiedList():
if node.isStorage():
node.notify(packet)
else:
packet = Packets.CommitTransaction(tid)
for node in self.app.nm.getIdentifiedList(pool_set=uuid_set):
for node in getIdentifiedList(pool_set=uuid_set):
node.notify(packet)
self._oid_set = set()
# If possible, send the packets now.
em.poll(0)
app.em.poll(0)
def verifyTransaction(self, tid):
em = self.app.em
......@@ -189,11 +190,11 @@ class VerificationManager(BaseServiceHandler):
return uuid_set
def answerLastIDs(self, conn, loid, ltid, lptid):
def answerLastIDs(self, conn, loid, ltid, lptid, backup_tid):
# FIXME: this packet should not allowed here, the master already
# accepted the current partition table end IDs. As there were manually
# approved during recovery, there is no need to check them here.
pass
raise RuntimeError
def answerUnfinishedTransactions(self, conn, max_tid, tid_list):
uuid = conn.getUUID()
......
......@@ -54,15 +54,10 @@ UNIT_TEST_MODULES = [
'neo.tests.storage.testInitializationHandler',
'neo.tests.storage.testMasterHandler',
'neo.tests.storage.testStorageApp',
'neo.tests.storage.testStorageHandler',
'neo.tests.storage.testStorageMySQLdb',
'neo.tests.storage.testStorageBTree',
'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
'neo.tests.storage.testVerificationHandler',
'neo.tests.storage.testIdentificationHandler',
'neo.tests.storage.testTransactions',
'neo.tests.storage.testReplicationHandler',
'neo.tests.storage.testReplicator',
'neo.tests.storage.testReplication',
# client application
'neo.tests.client.testClientApp',
'neo.tests.client.testMasterHandler',
......@@ -70,6 +65,7 @@ UNIT_TEST_MODULES = [
'neo.tests.client.testConnectionPool',
# light functional tests
'neo.tests.threaded.test',
'neo.tests.threaded.testReplication',
]
FUNC_TEST_MODULES = [
......
......@@ -113,28 +113,21 @@ class Application(object):
"""Load persistent configuration data from the database.
If data is not present, generate it."""
def NoneOnKeyError(getter):
try:
return getter()
except KeyError:
return None
dm = self.dm
# check cluster name
try:
dm_name = dm.getName()
except KeyError:
name = dm.getName()
if name is None:
dm.setName(self.name)
else:
if dm_name != self.name:
raise RuntimeError('name %r does not match with the '
'database: %r' % (self.name, dm_name))
elif name != self.name:
raise RuntimeError('name %r does not match with the database: %r'
% (self.name, dm_name))
# load configuration
self.uuid = NoneOnKeyError(dm.getUUID)
num_partitions = NoneOnKeyError(dm.getNumPartitions)
num_replicas = NoneOnKeyError(dm.getNumReplicas)
ptid = NoneOnKeyError(dm.getPTID)
self.uuid = dm.getUUID()
num_partitions = dm.getNumPartitions()
num_replicas = dm.getNumReplicas()
ptid = dm.getPTID()
# check partition table configuration
if num_partitions is not None and num_replicas is not None:
......@@ -152,10 +145,7 @@ class Application(object):
def loadPartitionTable(self):
"""Load a partition table from the database."""
try:
ptid = self.dm.getPTID()
except KeyError:
ptid = None
ptid = self.dm.getPTID()
cell_list = self.dm.getPartitionTable()
new_cell_list = []
for offset, uuid, state in cell_list:
......@@ -216,9 +206,7 @@ class Application(object):
except OperationFailure, msg:
neo.lib.logging.error('operation stopped: %s', msg)
except PrimaryFailure, msg:
self.replicator.masterLost()
neo.lib.logging.error('primary master is down: %s', msg)
self.master_node = None
def connectToPrimary(self):
"""Find a primary master node, and connect to it.
......@@ -296,6 +284,7 @@ class Application(object):
neo.lib.logging.info('doing operation')
_poll = self._poll
isIdle = self.em.isIdle
handler = master.MasterOperationHandler(self)
self.master_conn.setHandler(handler)
......@@ -304,16 +293,21 @@ class Application(object):
self.dm.dropUnfinishedData()
self.tm.reset()
while True:
_poll()
if self.replicator.pending():
# Call processDelayedTasks before act, so tasks added in the
# act call are executed after one poll call, so that sent
# packets are already on the network and delayed task
# processing happens in parallel with the same task on the
# other storage node.
self.replicator.processDelayedTasks()
self.replicator.act()
self.task_queue = task_queue = deque()
try:
while True:
while task_queue and isIdle():
try:
task_queue[-1].next()
task_queue.rotate()
except StopIteration:
task_queue.pop()
_poll()
finally:
del self.task_queue
# Abort any replication, whether we are feeding or out-of-date.
for node in self.nm.getStorageList(only_identified=True):
node.getConnection().close()
def wait(self):
# change handler
......@@ -368,6 +362,13 @@ class Application(object):
neo.lib.logging.info(' %r:%r: %r:%r %r %r', key, event.__name__,
_msg_id, _conn, args)
def newTask(self, iterator):
try:
iterator.next()
except StopIteration:
return
self.task_queue.appendleft(iterator)
def shutdown(self, erase=False):
"""Close all connections and exit"""
for c in self.em.getConnectionList():
......
......@@ -15,10 +15,13 @@
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
LOG_QUERIES = False
from neo.lib.exception import DatabaseFailure
from .manager import DatabaseManager
from .sqlite import SQLiteDatabaseManager
DATABASE_MANAGER_DICT = {}
DATABASE_MANAGER_DICT = {'SQLite': SQLiteDatabaseManager}
try:
from .mysqldb import MySQLDatabaseManager
......@@ -27,17 +30,6 @@ except ImportError:
else:
DATABASE_MANAGER_DICT['MySQL'] = MySQLDatabaseManager
try:
from .btree import BTreeDatabaseManager
except ImportError:
pass
else:
# XXX: warning: name might change in the future.
DATABASE_MANAGER_DICT['BTree'] = BTreeDatabaseManager
if not DATABASE_MANAGER_DICT:
raise ImportError('No database back-end available.')
def buildDatabaseManager(name, args=(), kw={}):
if name is None:
name = DATABASE_MANAGER_DICT.keys()[0]
......
#
# Copyright (C) 2010 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
"""
Naive b-tree implementation.
Simple, though not so well tested.
Not persistent ! (no data retained after process exit)
"""
from BTrees.OOBTree import OOBTree as _OOBTree
import neo.lib
from hashlib import sha1
from . import DatabaseManager
from .manager import CreationUndone
from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID
from neo.lib import util
# Keep dropped trees in memory to avoid instanciating when not needed.
TREE_POOL = []
# How many empty BTree istance to keep in ram
MAX_TREE_POOL_SIZE = 100
def batchDelete(tree, tester_callback=None, deleter_callback=None, **kw):
"""
Iter over given BTree and delete found entries.
tree BTree
Tree to delete entries from.
tester_callback function(key, value) -> boolean
Called with each key, value pair found in tree.
If return value is true, delete entry. Otherwise, skip to next key.
deleter_callback function(tree, key_list) -> None (None)
Custom function to delete items
**kw
Keyword arguments for tree.items .
"""
if tester_callback is None:
key_list = list(safeIter(tree.iterkeys, **kw))
else:
key_list = [key for key, value in safeIter(tree.iteritems, **kw)
if tester_callback(key, value)]
if deleter_callback is None:
for key in key_list:
del tree[key]
else:
deleter_callback(tree, key_list)
def OOBTree():
try:
result = TREE_POOL.pop()
except IndexError:
result = _OOBTree()
# Next btree we prune will have room, restore prune method
global prune
prune = _prune
return result
def _prune(tree):
tree.clear()
TREE_POOL.append(tree)
if len(TREE_POOL) >= MAX_TREE_POOL_SIZE:
# Already at/above max pool size, disable ourselve.
global prune
prune = _noPrune
def _noPrune(_):
pass
prune = _prune
def iterObjSerials(obj):
for tserial in obj.values():
for serial in tserial.keys():
yield serial
def descItems(tree):
try:
key = tree.maxKey()
except ValueError:
pass
else:
while True:
yield (key, tree[key])
try:
key = tree.maxKey(key - 1)
except ValueError:
break
def descKeys(tree):
try:
key = tree.maxKey()
except ValueError:
pass
else:
while True:
yield key
try:
key = tree.maxKey(key - 1)
except ValueError:
break
def safeIter(func, *args, **kw):
try:
some_list = func(*args, **kw)
except ValueError:
some_list = []
return some_list
class BTreeDatabaseManager(DatabaseManager):
def __init__(self, database, wait):
super(BTreeDatabaseManager, self).__init__(database, wait)
self.setup(reset=1)
@property
def _num_partitions(self):
return self._config['partitions']
def setup(self, reset=0):
if reset:
self._data = OOBTree()
self._obj = OOBTree()
self._trans = OOBTree()
self._tobj = OOBTree()
self._ttrans = OOBTree()
self._pt = {}
self._config = {}
self._uncommitted_data = {}
def _begin(self):
pass
def _commit(self):
pass
def _rollback(self):
pass
def getConfiguration(self, key):
return self._config[key]
def _setConfiguration(self, key, value):
self._config[key] = value
def _setPackTID(self, tid):
self._setConfiguration('_pack_tid', tid)
def _getPackTID(self):
try:
result = int(self.getConfiguration('_pack_tid'))
except KeyError:
result = -1
return result
def getPartitionTable(self):
pt = []
append = pt.append
for (offset, uuid), state in self._pt.iteritems():
append((offset, uuid, state))
return pt
def getLastTID(self, all=True):
try:
ltid = self._trans.maxKey()
except ValueError:
ltid = None
if all:
try:
tmp_ltid = self._ttrans.maxKey()
except ValueError:
tmp_ltid = None
tmp_serial = None
for tserial in self._tobj.values():
try:
max_tmp_serial = tserial.maxKey()
except ValueError:
pass
else:
tmp_serial = max(tmp_serial, max_tmp_serial)
ltid = max(ltid, tmp_ltid, tmp_serial)
if ltid is not None:
ltid = util.p64(ltid)
return ltid
def getUnfinishedTIDList(self):
p64 = util.p64
tid_set = set(p64(x) for x in self._ttrans.keys())
tid_set.update(p64(x) for x in iterObjSerials(self._tobj))
return list(tid_set)
def objectPresent(self, oid, tid, all=True):
u64 = util.u64
oid = u64(oid)
tid = u64(tid)
try:
result = self._obj[oid].has_key(tid)
except KeyError:
if all:
try:
result = self._tobj[oid].has_key(tid)
except KeyError:
result = False
else:
result = False
return result
def _getObject(self, oid, tid=None, before_tid=None):
tserial = self._obj.get(oid)
if tserial is not None:
if tid is None:
try:
if before_tid is None:
tid = tserial.maxKey()
else:
tid = tserial.maxKey(before_tid - 1)
except ValueError:
return False
try:
checksum, value_serial = tserial[tid]
except KeyError:
return False
try:
next_serial = tserial.minKey(tid + 1)
except ValueError:
next_serial = None
if checksum is None:
compression = data = None
else:
compression, data, _ = self._data[checksum]
return tid, next_serial, compression, checksum, data, value_serial
def doSetPartitionTable(self, ptid, cell_list, reset):
pt = self._pt
if reset:
pt.clear()
for offset, uuid, state in cell_list:
# TODO: this logic should move out of database manager
# add 'dropCells(cell_list)' to API and use one query
key = (offset, uuid)
if state == CellStates.DISCARDED:
pt.pop(key, None)
else:
pt[key] = int(state)
self.setPTID(ptid)
def changePartitionTable(self, ptid, cell_list):
self.doSetPartitionTable(ptid, cell_list, False)
def setPartitionTable(self, ptid, cell_list):
self.doSetPartitionTable(ptid, cell_list, True)
def _oidDeleterCallback(self, oid):
data = self._data
uncommitted_data = self._uncommitted_data
def deleter_callback(tree, key_list):
for tid in key_list:
checksum = tree.pop(tid)[0]
if checksum:
index = data[checksum][2]
index.remove((oid, tid))
if not index and checksum not in uncommitted_data:
del data[checksum]
return deleter_callback
def _objDeleterCallback(self, tree, key_list):
data = self._data
checksum_list = []
checksum_set = set()
for oid in key_list:
tserial = tree.pop(oid)
for tid, (checksum, _) in tserial.items():
if checksum:
index = data[checksum][2]
try:
index.remove((oid, tid))
except KeyError: # _tobj
checksum_list.append(checksum)
checksum_set.add(checksum)
prune(tserial)
self.unlockData(checksum_list)
self._pruneData(checksum_set)
def dropPartitions(self, offset_list):
offset_list = frozenset(offset_list)
num_partitions = self._num_partitions
def same_partition(key, _):
return key % num_partitions in offset_list
batchDelete(self._obj, same_partition, self._objDeleterCallback)
batchDelete(self._trans, same_partition)
def dropUnfinishedData(self):
batchDelete(self._tobj, deleter_callback=self._objDeleterCallback)
self._ttrans.clear()
def storeTransaction(self, tid, object_list, transaction, temporary=True):
u64 = util.u64
tid = u64(tid)
if temporary:
obj = self._tobj
trans = self._ttrans
else:
obj = self._obj
trans = self._trans
data = self._data
for oid, checksum, value_serial in object_list:
oid = u64(oid)
if value_serial:
value_serial = u64(value_serial)
checksum = self._obj[oid][value_serial][0]
if temporary:
self.storeData(checksum)
if checksum:
if not temporary:
data[checksum][2].add((oid, tid))
try:
tserial = obj[oid]
except KeyError:
tserial = obj[oid] = OOBTree()
tserial[tid] = checksum, value_serial
if transaction is not None:
oid_list, user, desc, ext, packed = transaction
trans[tid] = (tuple(oid_list), user, desc, ext, packed)
def _pruneData(self, checksum_list):
data = self._data
for checksum in set(checksum_list).difference(self._uncommitted_data):
if not data[checksum][2]:
del data[checksum]
def _storeData(self, checksum, data, compression):
try:
if self._data[checksum][:2] != (compression, data):
raise AssertionError("hash collision")
except KeyError:
self._data[checksum] = compression, data, set()
return checksum
def finishTransaction(self, tid):
tid = util.u64(tid)
self._popTransactionFromTObj(tid, True)
ttrans = self._ttrans
try:
data = ttrans[tid]
except KeyError:
pass
else:
del ttrans[tid]
self._trans[tid] = data
def _popTransactionFromTObj(self, tid, to_obj):
checksum_list = []
if to_obj:
deleter_callback = None
obj = self._obj
def callback(oid, data):
try:
tserial = obj[oid]
except KeyError:
tserial = obj[oid] = OOBTree()
tserial[tid] = data
checksum = data[0]
if checksum:
self._data[checksum][2].add((oid, tid))
checksum_list.append(checksum)
else:
deleter_callback = self._objDeleterCallback
callback = lambda oid, data: None
def tester_callback(oid, tserial):
try:
data = tserial[tid]
except KeyError:
pass
else:
del tserial[tid]
callback(oid, data)
return not tserial
batchDelete(self._tobj, tester_callback, deleter_callback)
self.unlockData(checksum_list)
def deleteTransaction(self, tid, oid_list=()):
u64 = util.u64
tid = u64(tid)
self._popTransactionFromTObj(tid, False)
try:
del self._ttrans[tid]
except KeyError:
pass
for oid in oid_list:
self._deleteObject(u64(oid), tid)
try:
del self._trans[tid]
except KeyError:
pass
def deleteTransactionsAbove(self, partition, tid, max_tid):
num_partitions = self._num_partitions
def same_partition(key, _):
return key % num_partitions == partition
batchDelete(self._trans, same_partition,
min=util.u64(tid), max=util.u64(max_tid))
def deleteObject(self, oid, serial=None):
u64 = util.u64
self._deleteObject(u64(oid), serial and u64(serial))
def _deleteObject(self, oid, serial=None):
obj = self._obj
try:
tserial = obj[oid]
except KeyError:
return
batchDelete(tserial, deleter_callback=self._oidDeleterCallback(oid),
min=serial, max=serial)
if not tserial:
del obj[oid]
def deleteObjectsAbove(self, partition, oid, serial, max_tid):
obj = self._obj
u64 = util.u64
oid = u64(oid)
serial = u64(serial)
max_tid = u64(max_tid)
num_partitions = self._num_partitions
if oid % num_partitions == partition:
try:
tserial = obj[oid]
except KeyError:
pass
else:
batchDelete(tserial, min=serial, max=max_tid,
deleter_callback=self._oidDeleterCallback(oid))
if not tserial:
                    del obj[oid]
def same_partition(key, _):
return key % num_partitions == partition
batchDelete(obj, same_partition, self._objDeleterCallback,
min=oid, excludemin=True, max=max_tid)
def getTransaction(self, tid, all=False):
tid = util.u64(tid)
try:
result = self._trans[tid]
except KeyError:
if all:
try:
result = self._ttrans[tid]
except KeyError:
result = None
else:
result = None
if result is not None:
oid_list, user, desc, ext, packed = result
result = (list(oid_list), user, desc, ext, packed)
return result
def _getObjectLength(self, oid, value_serial):
if value_serial is None:
raise CreationUndone
checksum, value_serial = self._obj[oid][value_serial]
if checksum is None:
neo.lib.logging.info("Multiple levels of indirection when " \
"searching for object data for oid %d at tid %d. This " \
"causes suboptimal performance." % (oid, value_serial))
return self._getObjectLength(oid, value_serial)
return len(self._data[checksum][1])
def getObjectHistory(self, oid, offset=0, length=1):
        # FIXME: This method doesn't take client's current transaction id as
# parameter, which means it can return transactions in the future of
# client's transaction.
oid = util.u64(oid)
p64 = util.p64
pack_tid = self._getPackTID()
try:
tserial = self._obj[oid]
except KeyError:
result = None
else:
result = []
append = result.append
tserial_iter = descItems(tserial)
while offset > 0:
tserial_iter.next()
offset -= 1
data = self._data
for serial, (checksum, value_serial) in tserial_iter:
if length == 0 or serial < pack_tid:
break
length -= 1
if checksum is None:
try:
data_length = self._getObjectLength(oid, value_serial)
except CreationUndone:
data_length = 0
else:
data_length = len(data[checksum][1])
append((p64(serial), data_length))
if not result:
result = None
return result
def getObjectHistoryFrom(self, min_oid, min_serial, max_serial, length,
partition):
u64 = util.u64
p64 = util.p64
min_oid = u64(min_oid)
min_serial = u64(min_serial)
max_serial = u64(max_serial)
result = {}
num_partitions = self._num_partitions
for oid, tserial in safeIter(self._obj.items, min=min_oid):
if oid % num_partitions == partition:
if length == 0:
break
if oid == min_oid:
try:
tid_seq = tserial.keys(min=min_serial, max=max_serial)
except ValueError:
continue
else:
tid_seq = tserial.keys(max=max_serial)
if not tid_seq:
continue
result[p64(oid)] = tid_list = []
append = tid_list.append
for tid in tid_seq:
if length == 0:
break
length -= 1
append(p64(tid))
else:
continue
break
return result
def getTIDList(self, offset, length, partition_list):
p64 = util.p64
partition_list = frozenset(partition_list)
result = []
append = result.append
trans_iter = descKeys(self._trans)
num_partitions = self._num_partitions
while offset > 0:
tid = trans_iter.next()
if tid % num_partitions in partition_list:
offset -= 1
for tid in trans_iter:
if tid % num_partitions in partition_list:
if length == 0:
break
length -= 1
append(p64(tid))
return result
def getReplicationTIDList(self, min_tid, max_tid, length, partition):
p64 = util.p64
u64 = util.u64
result = []
append = result.append
num_partitions = self._num_partitions
for tid in safeIter(self._trans.keys, min=u64(min_tid), max=u64(max_tid)):
if tid % num_partitions == partition:
if length == 0:
break
length -= 1
append(p64(tid))
return result
def _updatePackFuture(self, oid, orig_serial, max_serial):
        # Before deleting this object's revision, see if there is any
# transaction referencing its value at max_serial or above.
# If there is, copy value to the first future transaction. Any further
# reference is just updated to point to the new data location.
new_serial = None
obj = self._obj
for tree in (obj, self._tobj):
try:
tserial = tree[oid]
except KeyError:
continue
for serial, (checksum, value_serial) in tserial.iteritems(
min=max_serial):
if value_serial == orig_serial:
tserial[serial] = checksum, new_serial
if not new_serial:
new_serial = serial
return new_serial
def pack(self, tid, updateObjectDataForPack):
p64 = util.p64
tid = util.u64(tid)
updatePackFuture = self._updatePackFuture
self._setPackTID(tid)
def obj_callback(oid, tserial):
try:
max_serial = tserial.maxKey(tid)
except ValueError:
# No entry before pack TID, nothing to pack on this object.
pass
else:
if tserial[max_serial][0] is None:
# Last version before/at pack TID is a creation undo, drop
# it too.
max_serial += 1
def serial_callback(serial, value):
new_serial = updatePackFuture(oid, serial, max_serial)
if new_serial:
new_serial = p64(new_serial)
updateObjectDataForPack(p64(oid), p64(serial),
new_serial, value[0])
batchDelete(tserial, serial_callback,
self._oidDeleterCallback(oid),
max=max_serial, excludemax=True)
return not tserial
batchDelete(self._obj, obj_callback, self._objDeleterCallback)
def checkTIDRange(self, min_tid, max_tid, length, partition):
if length:
tid_list = []
num_partitions = self._num_partitions
for tid in safeIter(self._trans.keys, min=util.u64(min_tid),
max=util.u64(max_tid)):
if tid % num_partitions == partition:
tid_list.append(tid)
if len(tid_list) >= length:
break
if tid_list:
return (len(tid_list),
sha1(','.join(map(str, tid_list))).digest(),
util.p64(tid_list[-1]))
return 0, ZERO_HASH, ZERO_TID
def checkSerialRange(self, min_oid, min_serial, max_tid, length, partition):
if length:
u64 = util.u64
min_oid = u64(min_oid)
max_tid = u64(max_tid)
oid_list = []
serial_list = []
num_partitions = self._num_partitions
for oid, tserial in safeIter(self._obj.items, min=min_oid):
if oid % num_partitions == partition:
try:
if oid == min_oid:
tserial = tserial.keys(min=u64(min_serial),
max=max_tid)
else:
tserial = tserial.keys(max=max_tid)
except ValueError:
continue
for serial in tserial:
oid_list.append(oid)
serial_list.append(serial)
if len(oid_list) >= length:
break
else:
continue
break
if oid_list:
p64 = util.p64
return (len(oid_list),
sha1(','.join(map(str, oid_list))).digest(),
p64(oid_list[-1]),
sha1(','.join(map(str, serial_list))).digest(),
p64(serial_list[-1]))
return 0, ZERO_HASH, ZERO_OID, ZERO_HASH, ZERO_TID
......@@ -18,6 +18,7 @@
import neo.lib
from neo.lib import util
from neo.lib.exception import DatabaseFailure
from neo.lib.protocol import ZERO_TID
class CreationUndone(Exception):
pass
......@@ -37,34 +38,6 @@ class DatabaseManager(object):
"""Called during instanciation, to process database parameter."""
pass
def isUnderTransaction(self):
return self._under_transaction
def begin(self):
"""
Begin a transaction
"""
if self._under_transaction:
raise DatabaseFailure('A transaction has already begun')
self._begin()
self._under_transaction = True
def commit(self):
"""
Commit the current transaction
"""
if not self._under_transaction:
raise DatabaseFailure('The transaction has not begun')
self._commit()
self._under_transaction = False
def rollback(self):
"""
Rollback the current transaction
"""
self._rollback()
self._under_transaction = False
def setup(self, reset = 0):
"""Set up a database
......@@ -79,14 +52,33 @@ class DatabaseManager(object):
"""
raise NotImplementedError
def _begin(self):
raise NotImplementedError
def __enter__(self):
"""
Begin a transaction
"""
if self._under_transaction:
raise DatabaseFailure('A transaction has already begun')
r = self.begin()
self._under_transaction = True
return r
def _commit(self):
raise NotImplementedError
def __exit__(self, exc_type, exc_value, tb):
if not self._under_transaction:
raise DatabaseFailure('The transaction has not begun')
self._under_transaction = False
if exc_type is None:
self.commit()
else:
self.rollback()
def _rollback(self):
raise NotImplementedError
def begin(self):
pass
def commit(self):
pass
def rollback(self):
pass
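The begin/commit/rollback methods above are now no-ops by default, and the transaction boundary is expressed through the context-manager protocol (__enter__/__exit__). As an illustration only, here is a minimal sketch of the intended usage, assuming db is a concrete DatabaseManager subclass whose begin() returns the query callable (as the MySQL and SQLite backends below do); the SQL strings are made up:

    # Sketch, not part of the source: __enter__ calls begin() and returns its
    # result, __exit__ commits on success and rolls back if an exception
    # escapes the block.
    with db as q:
        q("DELETE FROM trans WHERE partition=0 AND tid=123")
        q("DELETE FROM obj WHERE partition=0 AND serial=123")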
def _getPartition(self, oid_or_tid):
return oid_or_tid % self.getNumPartitions()
......@@ -104,13 +96,8 @@ class DatabaseManager(object):
if self._under_transaction:
self._setConfiguration(key, value)
else:
self.begin()
try:
with self:
self._setConfiguration(key, value)
except:
self.rollback()
raise
self.commit()
def _setConfiguration(self, key, value):
raise NotImplementedError
......@@ -171,7 +158,9 @@ class DatabaseManager(object):
"""
Load a Partition Table ID from a database.
"""
return long(self.getConfiguration('ptid'))
ptid = self.getConfiguration('ptid')
if ptid is not None:
return long(ptid)
def setPTID(self, ptid):
"""
......@@ -194,18 +183,31 @@ class DatabaseManager(object):
"""
self.setConfiguration('loid', util.dump(loid))
def getBackupTID(self):
return util.bin(self.getConfiguration('backup_tid'))
def setBackupTID(self, backup_tid):
return self.setConfiguration('backup_tid', util.dump(backup_tid))
def getPartitionTable(self):
"""Return a whole partition table as a tuple of rows. Each row
is again a tuple of an offset (row ID), an UUID of a storage
node, and a cell state."""
raise NotImplementedError
def getLastTID(self, all = True):
"""Return the last TID in a database. If all is true,
        unfinished transactions must be taken into account. If there
is no TID in the database, return None."""
def _getLastTIDs(self, all=True):
raise NotImplementedError
def getLastTIDs(self, all=True):
trans, obj = self._getLastTIDs()
if trans:
tid = max(trans.itervalues())
if obj:
tid = max(tid, max(obj.itervalues()))
else:
tid = max(obj.itervalues()) if obj else None
return tid, trans, obj
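For clarity, a sketch of the shape of the value returned by getLastTIDs above (all numbers are made up): trans and obj map each assigned partition to its highest committed TID, the None key covers unfinished data from ttrans/tobj when all is true, and tid is the overall maximum or None when the database is empty (p64 stands for neo.lib.util.p64):

    # Illustration only, values are hypothetical.
    # tid, trans, obj = dm.getLastTIDs()
    # trans == {0: p64(0x125), 1: p64(0x120), None: p64(0x126)}
    # obj   == {0: p64(0x124), 1: p64(0x121), None: p64(0x126)}
    # tid   == p64(0x126)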
def getUnfinishedTIDList(self):
"""Return a list of unfinished transaction's IDs."""
raise NotImplementedError
......@@ -352,13 +354,8 @@ class DatabaseManager(object):
else:
del refcount[data_id]
if prune:
self.begin()
try:
with self:
self._pruneData(data_id_list)
except:
self.rollback()
raise
self.commit()
__getDataTID = set()
def _getDataTID(self, oid, tid=None, before_tid=None):
......@@ -466,23 +463,24 @@ class DatabaseManager(object):
an oid list"""
raise NotImplementedError
def deleteTransactionsAbove(self, partition, tid, max_tid):
"""Delete all transactions above given TID (inclued) in given
partition, but never above max_tid (in case transactions are committed
during replication)."""
raise NotImplementedError
def deleteObject(self, oid, serial=None):
"""Delete given object. If serial is given, only delete that serial for
given oid."""
raise NotImplementedError
def deleteObjectsAbove(self, partition, oid, serial, max_tid):
"""Delete all objects above given OID and serial (inclued) in given
partition, but never above max_tid (in case objects are stored during
replication)"""
def _deleteRange(self, partition, min_tid=None, max_tid=None):
"""Delete all objects and transactions between given min_tid (excluded)
and max_tid (included)"""
raise NotImplementedError
def truncate(self, tid):
assert tid not in (None, ZERO_TID), tid
with self:
assert self.getBackupTID()
self.setBackupTID(tid)
for partition in xrange(self.getNumPartitions()):
self._deleteRange(partition, tid)
def getTransaction(self, tid, all = False):
"""Return a tuple of the list of OIDs, user information,
a description, and extension information, for a given transaction
......@@ -498,10 +496,10 @@ class DatabaseManager(object):
If there is no such object ID in a database, return None."""
raise NotImplementedError
def getObjectHistoryFrom(self, oid, min_serial, max_serial, length,
partition):
"""Return a dict of length serials grouped by oid at (or above)
min_oid and min_serial and below max_serial, for given partition,
def getReplicationObjectList(self, min_tid, max_tid, length, partition,
min_oid):
"""Return a dict of length oids grouped by serial at (or above)
min_tid and min_oid and below max_tid, for given partition,
sorted in ascending order."""
raise NotImplementedError
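To make the pagination contract concrete, here is a hedged sketch (not from the source; dm, max_tid and partition are assumed to exist) of how a caller could walk a partition with this method, resuming just after the last (serial, oid) pair of each reply:

    # Sketch only: hypothetical pagination over getReplicationObjectList.
    from neo.lib.protocol import ZERO_OID, ZERO_TID
    from neo.lib.util import add64
    min_tid, min_oid = ZERO_TID, ZERO_OID
    while True:
        pairs = dm.getReplicationObjectList(min_tid, max_tid, 1000,
                                            partition, min_oid)
        if not pairs:
            break
        for serial, oid in pairs:
            pass  # process one (serial, oid) pair
        min_tid, last_oid = pairs[-1]
        min_oid = add64(last_oid, 1)  # resume just after the last pair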
......
......@@ -27,14 +27,12 @@ import re
import string
import time
from . import DatabaseManager
from . import DatabaseManager, LOG_QUERIES
from .manager import CreationUndone
from neo.lib.exception import DatabaseFailure
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, ZERO_HASH
from neo.lib import util
LOG_QUERIES = False
def splitOIDField(tid, oids):
if (len(oids) % 8) != 0 or len(oids) == 0:
raise DatabaseFailure('invalid oids length for tid %d: %d' % (tid,
......@@ -99,18 +97,22 @@ class MySQLDatabaseManager(DatabaseManager):
self.conn.query("SET SESSION group_concat_max_len = -1")
self.conn.set_sql_mode("TRADITIONAL,NO_ENGINE_SUBSTITUTION")
def _begin(self):
self.query("""BEGIN""")
def begin(self):
q = self.query
q("BEGIN")
return q
    def _commit(self):
        if LOG_QUERIES:
            neo.lib.logging.debug('committing...')
        self.conn.commit()
    def _rollback(self):
        if LOG_QUERIES:
            neo.lib.logging.debug('aborting...')
        self.conn.rollback()
    if LOG_QUERIES:
        def commit(self):
            neo.lib.logging.debug('committing...')
            self.conn.commit()
        def rollback(self):
            neo.lib.logging.debug('aborting...')
            self.conn.rollback()
    else:
        commit = property(lambda self: self.conn.commit)
        rollback = property(lambda self: self.conn.rollback)
def query(self, query):
"""Query data from a database."""
......@@ -194,7 +196,8 @@ class MySQLDatabaseManager(DatabaseManager):
serial BIGINT UNSIGNED NOT NULL,
data_id BIGINT UNSIGNED NULL,
value_serial BIGINT UNSIGNED NULL,
PRIMARY KEY (partition, oid, serial),
PRIMARY KEY (partition, serial, oid),
KEY (partition, oid, serial),
KEY (data_id)
) ENGINE = InnoDB""" + p)
......@@ -233,17 +236,17 @@ class MySQLDatabaseManager(DatabaseManager):
" FROM tobj WHERE data_id IS NOT NULL GROUP BY data_id") or ())
def getConfiguration(self, key):
if key in self._config:
return self._config[key]
q = self.query
e = self.escape
sql_key = e(str(key))
try:
r = q("SELECT value FROM config WHERE name = '%s'" % sql_key)[0][0]
except IndexError:
raise KeyError, key
self._config[key] = r
return r
return self._config[key]
except KeyError:
sql_key = self.escape(str(key))
try:
r = self.query("SELECT value FROM config WHERE name = '%s'"
% sql_key)[0][0]
except IndexError:
r = None
self._config[key] = r
return r
def _setConfiguration(self, key, value):
q = self.query
......@@ -251,20 +254,19 @@ class MySQLDatabaseManager(DatabaseManager):
self._config[key] = value
key = e(str(key))
        if value is None:
            value = 'NULL'
        else:
            value = "'%s'" % (e(str(value)), )
        q("""REPLACE INTO config VALUES ('%s', %s)""" % (key, value))
        if value is None:
            q("DELETE FROM config WHERE name = '%s'" % key)
        else:
            value = e(str(value))
            q("REPLACE INTO config VALUES ('%s', '%s')" % (key, value))
def _setPackTID(self, tid):
self._setConfiguration('_pack_tid', tid)
def _getPackTID(self):
try:
result = int(self.getConfiguration('_pack_tid'))
except KeyError:
result = -1
return result
return int(self.getConfiguration('_pack_tid'))
except TypeError:
return -1
def getPartitionTable(self):
q = self.query
......@@ -275,58 +277,42 @@ class MySQLDatabaseManager(DatabaseManager):
pt.append((offset, uuid, state))
return pt
def getLastTID(self, all = True):
# XXX this does not consider serials in obj.
# I am not sure if this is really harmful. For safety,
# check for tobj only at the moment. The reason why obj is
# not tested is that it is too slow to get the max serial
# from obj when it has a huge number of objects, because
# serial is the second part of the primary key, so the index
# is not used in this case. If doing it, it is better to
# make another index for serial, but I doubt the cost increase
# is worth.
q = self.query
self.begin()
ltid = q("SELECT MAX(value) FROM (SELECT MAX(tid) AS value FROM trans "
"GROUP BY partition) AS foo")[0][0]
if all:
tmp_ltid = q("""SELECT MAX(tid) FROM ttrans""")[0][0]
if ltid is None or (tmp_ltid is not None and ltid < tmp_ltid):
ltid = tmp_ltid
tmp_serial = q("""SELECT MAX(serial) FROM tobj""")[0][0]
if ltid is None or (tmp_serial is not None and ltid < tmp_serial):
ltid = tmp_serial
self.commit()
if ltid is not None:
ltid = util.p64(ltid)
return ltid
def _getLastTIDs(self, all=True):
p64 = util.p64
with self as q:
trans = dict((partition, p64(tid))
for partition, tid in q("SELECT partition, MAX(tid)"
" FROM trans GROUP BY partition"))
obj = dict((partition, p64(tid))
for partition, tid in q("SELECT partition, MAX(serial)"
" FROM obj GROUP BY partition"))
if all:
tid = q("SELECT MAX(tid) FROM ttrans")[0][0]
if tid is not None:
trans[None] = p64(tid)
tid = q("SELECT MAX(serial) FROM tobj")[0][0]
if tid is not None:
obj[None] = p64(tid)
return trans, obj
def getUnfinishedTIDList(self):
q = self.query
tid_set = set()
self.begin()
r = q("""SELECT tid FROM ttrans""")
tid_set.update((util.p64(t[0]) for t in r))
r = q("""SELECT serial FROM tobj""")
self.commit()
with self as q:
r = q("""SELECT tid FROM ttrans""")
tid_set.update((util.p64(t[0]) for t in r))
r = q("""SELECT serial FROM tobj""")
tid_set.update((util.p64(t[0]) for t in r))
return list(tid_set)
def objectPresent(self, oid, tid, all = True):
q = self.query
oid = util.u64(oid)
tid = util.u64(tid)
partition = self._getPartition(oid)
self.begin()
r = q("SELECT oid FROM obj WHERE partition=%d AND oid=%d AND "
"serial=%d" % (partition, oid, tid))
if not r and all:
r = q("SELECT oid FROM tobj WHERE serial=%d AND oid=%d"
% (tid, oid))
self.commit()
if r:
return True
return False
with self as q:
return q("SELECT oid FROM obj WHERE partition=%d AND oid=%d AND "
"serial=%d" % (partition, oid, tid)) or all and \
q("SELECT oid FROM tobj WHERE serial=%d AND oid=%d"
% (tid, oid))
def _getObject(self, oid, tid=None, before_tid=None):
q = self.query
......@@ -357,11 +343,9 @@ class MySQLDatabaseManager(DatabaseManager):
return serial, next_serial, compression, checksum, data, value_serial
def doSetPartitionTable(self, ptid, cell_list, reset):
q = self.query
e = self.escape
offset_list = []
self.begin()
try:
with self as q:
if reset:
q("""TRUNCATE pt""")
for offset, uuid, state in cell_list:
......@@ -377,10 +361,6 @@ class MySQLDatabaseManager(DatabaseManager):
ON DUPLICATE KEY UPDATE state = %d""" \
% (offset, uuid, state, state))
self.setPTID(ptid)
except:
self.rollback()
raise
self.commit()
if self._use_partition:
for offset in offset_list:
add = """ALTER TABLE %%s ADD PARTITION (
......@@ -399,9 +379,7 @@ class MySQLDatabaseManager(DatabaseManager):
self.doSetPartitionTable(ptid, cell_list, True)
def dropPartitions(self, offset_list):
q = self.query
self.begin()
try:
with self as q:
            # XXX: these queries are inefficient (execution time increases with
# row count, although we use indexes) when there are rows to
# delete. It should be done as an idle task, by chunks.
......@@ -413,10 +391,6 @@ class MySQLDatabaseManager(DatabaseManager):
q("DELETE FROM obj" + where)
q("DELETE FROM trans" + where)
self._pruneData(data_id_list)
except:
self.rollback()
raise
self.commit()
if self._use_partition:
drop = "ALTER TABLE %s DROP PARTITION" + \
','.join(' p%u' % i for i in offset_list)
......@@ -428,20 +402,13 @@ class MySQLDatabaseManager(DatabaseManager):
raise
def dropUnfinishedData(self):
q = self.query
self.begin()
try:
with self as q:
data_id_list = [x for x, in q("SELECT data_id FROM tobj") if x]
q("""TRUNCATE tobj""")
q("""TRUNCATE ttrans""")
except:
self.rollback()
raise
self.commit()
self.unlockData(data_id_list, True)
def storeTransaction(self, tid, object_list, transaction, temporary = True):
q = self.query
e = self.escape
u64 = util.u64
tid = u64(tid)
......@@ -453,8 +420,7 @@ class MySQLDatabaseManager(DatabaseManager):
obj_table = 'obj'
trans_table = 'trans'
self.begin()
try:
with self as q:
for oid, data_id, value_serial in object_list:
oid = u64(oid)
partition = self._getPartition(oid)
......@@ -481,10 +447,6 @@ class MySQLDatabaseManager(DatabaseManager):
q("REPLACE INTO %s VALUES (%d, %d, %i, '%s', '%s', '%s', '%s')"
% (trans_table, partition, tid, packed, oids, user, desc,
ext))
except:
self.rollback()
raise
self.commit()
def _pruneData(self, data_id_list):
data_id_list = set(data_id_list).difference(self._uncommitted_data)
......@@ -497,24 +459,19 @@ class MySQLDatabaseManager(DatabaseManager):
def _storeData(self, checksum, data, compression):
e = self.escape
checksum = e(checksum)
self.begin()
try:
with self as q:
try:
self.query("INSERT INTO data VALUES (NULL, '%s', %d, '%s')" %
q("INSERT INTO data VALUES (NULL, '%s', %d, '%s')" %
(checksum, compression, e(data)))
except IntegrityError, (code, _):
if code != DUP_ENTRY:
raise
(r, c, d), = self.query("SELECT id, compression, value"
" FROM data WHERE hash='%s'" % checksum)
(r, c, d), = q("SELECT id, compression, value"
" FROM data WHERE hash='%s'" % checksum)
if c != compression or d != data:
raise
else:
r = self.conn.insert_id()
except:
self.rollback()
raise
self.commit()
return r
def _getDataTID(self, oid, tid=None, before_tid=None):
......@@ -540,27 +497,20 @@ class MySQLDatabaseManager(DatabaseManager):
def finishTransaction(self, tid):
q = self.query
tid = util.u64(tid)
self.begin()
try:
with self as q:
sql = " FROM tobj WHERE serial=%d" % tid
data_id_list = [x for x, in q("SELECT data_id" + sql) if x]
q("INSERT INTO obj SELECT *" + sql)
q("DELETE FROM tobj WHERE serial=%d" % tid)
q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid)
q("DELETE FROM ttrans WHERE tid=%d" % tid)
except:
self.rollback()
raise
self.commit()
self.unlockData(data_id_list)
def deleteTransaction(self, tid, oid_list=()):
q = self.query
u64 = util.u64
tid = u64(tid)
getPartition = self._getPartition
self.begin()
try:
with self as q:
sql = " FROM tobj WHERE serial=%d" % tid
data_id_list = [x for x, in q("SELECT data_id" + sql) if x]
self.unlockData(data_id_list)
......@@ -578,77 +528,45 @@ class MySQLDatabaseManager(DatabaseManager):
q("DELETE" + sql)
data_id_set.discard(None)
self._pruneData(data_id_set)
except:
self.rollback()
raise
self.commit()
def deleteTransactionsAbove(self, partition, tid, max_tid):
self.begin()
try:
self.query('DELETE FROM trans WHERE partition=%(partition)d AND '
'%(tid)d <= tid AND tid <= %(max_tid)d' % {
'partition': partition,
'tid': util.u64(tid),
'max_tid': util.u64(max_tid),
})
except:
self.rollback()
raise
self.commit()
def deleteObject(self, oid, serial=None):
q = self.query
u64 = util.u64
oid = u64(oid)
sql = " FROM obj WHERE partition=%d AND oid=%d" \
% (self._getPartition(oid), oid)
if serial:
sql += ' AND serial=%d' % u64(serial)
self.begin()
try:
with self as q:
data_id_list = [x for x, in q("SELECT DISTINCT data_id" + sql) if x]
q("DELETE" + sql)
self._pruneData(data_id_list)
except:
self.rollback()
raise
self.commit()
def deleteObjectsAbove(self, partition, oid, serial, max_tid):
def _deleteRange(self, partition, min_tid=None, max_tid=None):
sql = " WHERE partition=%d" % partition
if min_tid:
sql += " AND %d < tid" % util.u64(min_tid)
if max_tid:
sql += " AND tid <= %d" % util.u64(max_tid)
q = self.query
u64 = util.u64
oid = u64(oid)
sql = (" FROM obj WHERE partition=%d AND serial <= %d"
" AND (oid > %d OR (oid = %d AND serial >= %d))" %
(partition, u64(max_tid), oid, oid, u64(serial)))
self.begin()
try:
data_id_list = [x for x, in q("SELECT DISTINCT data_id" + sql) if x]
q("DELETE" + sql)
self._pruneData(data_id_list)
except:
self.rollback()
raise
self.commit()
q("DELETE FROM trans" + sql)
sql = " FROM obj" + sql.replace('tid', 'serial')
data_id_list = [x for x, in q("SELECT DISTINCT data_id" + sql) if x]
q("DELETE" + sql)
self._pruneData(data_id_list)
def getTransaction(self, tid, all = False):
q = self.query
tid = util.u64(tid)
self.begin()
r = q("""SELECT oids, user, description, ext, packed FROM trans
WHERE partition = %d AND tid = %d""" \
% (self._getPartition(tid), tid))
if not r and all:
r = q("""SELECT oids, user, description, ext, packed FROM ttrans
WHERE tid = %d""" \
% tid)
self.commit()
with self as q:
r = q("SELECT oids, user, description, ext, packed FROM trans"
" WHERE partition = %d AND tid = %d"
% (self._getPartition(tid), tid))
if not r and all:
r = q("SELECT oids, user, description, ext, packed FROM ttrans"
" WHERE tid = %d" % tid)
if r:
oids, user, desc, ext, packed = r[0]
oid_list = splitOIDField(tid, oids)
return oid_list, user, desc, ext, bool(packed)
return None
def _getObjectLength(self, oid, value_serial):
if value_serial is None:
......@@ -690,34 +608,17 @@ class MySQLDatabaseManager(DatabaseManager):
return result
return None
def getObjectHistoryFrom(self, min_oid, min_serial, max_serial, length,
partition):
q = self.query
def getReplicationObjectList(self, min_tid, max_tid, length, partition,
min_oid):
u64 = util.u64
p64 = util.p64
min_oid = u64(min_oid)
min_serial = u64(min_serial)
max_serial = u64(max_serial)
r = q('SELECT oid, serial FROM obj '
'WHERE partition = %(partition)s '
'AND serial <= %(max_serial)d '
'AND ((oid = %(min_oid)d AND serial >= %(min_serial)d) '
'OR oid > %(min_oid)d) '
'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % {
'min_oid': min_oid,
'min_serial': min_serial,
'max_serial': max_serial,
'length': length,
'partition': partition,
})
result = {}
for oid, serial in r:
try:
serial_list = result[oid]
except KeyError:
serial_list = result[oid] = []
serial_list.append(p64(serial))
return dict((p64(x), y) for x, y in result.iteritems())
min_tid = u64(min_tid)
r = self.query('SELECT serial, oid FROM obj'
' WHERE partition = %d AND serial <= %d'
' AND (serial = %d AND %d <= oid OR %d < serial)'
' ORDER BY serial ASC, oid ASC LIMIT %d' % (
partition, u64(max_tid), min_tid, u64(min_oid), min_tid, length))
return [(p64(serial), p64(oid)) for serial, oid in r]
def getTIDList(self, offset, length, partition_list):
q = self.query
......@@ -727,12 +628,11 @@ class MySQLDatabaseManager(DatabaseManager):
return [util.p64(t[0]) for t in r]
def getReplicationTIDList(self, min_tid, max_tid, length, partition):
q = self.query
u64 = util.u64
p64 = util.p64
min_tid = u64(min_tid)
max_tid = u64(max_tid)
r = q("""SELECT tid FROM trans
r = self.query("""SELECT tid FROM trans
WHERE partition = %(partition)d
AND tid >= %(min_tid)d AND tid <= %(max_tid)d
ORDER BY tid ASC LIMIT %(length)d""" % {
......@@ -772,13 +672,11 @@ class MySQLDatabaseManager(DatabaseManager):
def pack(self, tid, updateObjectDataForPack):
# TODO: unit test (along with updatePackFuture)
q = self.query
p64 = util.p64
tid = util.u64(tid)
updatePackFuture = self._updatePackFuture
getPartition = self._getPartition
self.begin()
try:
with self as q:
self._setPackTID(tid)
for count, oid, max_serial in q('SELECT COUNT(*) - 1, oid, '
'MAX(serial) FROM obj WHERE serial <= %d GROUP BY oid'
......@@ -804,10 +702,6 @@ class MySQLDatabaseManager(DatabaseManager):
q('DELETE' + sql)
data_id_set.discard(None)
self._pruneData(data_id_set)
except:
self.rollback()
raise
self.commit()
def checkTIDRange(self, min_tid, max_tid, length, partition):
count, tid_checksum, max_tid = self.query(
......@@ -816,11 +710,11 @@ class MySQLDatabaseManager(DatabaseManager):
WHERE partition = %(partition)s
AND tid >= %(min_tid)d
AND tid <= %(max_tid)d
ORDER BY tid ASC LIMIT %(length)d) AS t""" % {
ORDER BY tid ASC %(limit)s) AS t""" % {
'partition': partition,
'min_tid': util.u64(min_tid),
'max_tid': util.u64(max_tid),
'length': length,
                'limit': '' if length is None else 'LIMIT %d' % length,
})[0]
if count:
return count, a2b_hex(tid_checksum), util.p64(max_tid)
......@@ -839,11 +733,11 @@ class MySQLDatabaseManager(DatabaseManager):
AND serial <= %(max_tid)d
AND (oid > %(min_oid)d OR
oid = %(min_oid)d AND serial >= %(min_serial)d)
ORDER BY oid ASC, serial ASC LIMIT %(length)d""" % {
ORDER BY oid ASC, serial ASC %(limit)s""" % {
'min_oid': u64(min_oid),
'min_serial': u64(min_serial),
'max_tid': u64(max_tid),
'length': length,
                'limit': '' if length is None else 'LIMIT %d' % length,
'partition': partition,
})
if r:
......
#
# Copyright (C) 2012 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import sqlite3
import neo.lib
from array import array
from hashlib import sha1
import re
import string
from . import DatabaseManager, LOG_QUERIES
from .manager import CreationUndone
from neo.lib.exception import DatabaseFailure
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, ZERO_HASH
from neo.lib import util
def splitOIDField(tid, oids):
if (len(oids) % 8) != 0 or len(oids) == 0:
raise DatabaseFailure('invalid oids length for tid %d: %d' % (tid,
len(oids)))
oid_list = []
append = oid_list.append
for i in xrange(0, len(oids), 8):
append(oids[i:i+8])
return oid_list
class SQLiteDatabaseManager(DatabaseManager):
"""This class manages a database on SQLite.
CAUTION: Make sure we never use statement journal files, as explained at
             http://www.sqlite.org/tempfiles.html.
In other words, temporary files (by default in /var/tmp !) must
never be used for small requests.
"""
def __init__(self, *args, **kw):
super(SQLiteDatabaseManager, self).__init__(*args, **kw)
self._config = {}
self._connect()
def _parse(self, database):
self.db = database
def close(self):
self.conn.close()
def _connect(self):
neo.lib.logging.info('connecting to SQLite database %r', self.db)
self.conn = sqlite3.connect(self.db, isolation_level=None,
check_same_thread=False)
def begin(self):
q = self.query
q("BEGIN IMMEDIATE")
return q
if LOG_QUERIES:
def commit(self):
neo.lib.logging.debug('committing...')
self.conn.commit()
def rollback(self):
neo.lib.logging.debug('aborting...')
self.conn.rollback()
def query(self, query):
printable_char_list = []
for c in query.split('\n', 1)[0][:70]:
if c not in string.printable or c in '\t\x0b\x0c\r':
c = '\\x%02x' % ord(c)
printable_char_list.append(c)
neo.lib.logging.debug('querying %s...',
''.join(printable_char_list))
return self.conn.execute(query)
else:
commit = property(lambda self: self.conn.commit)
rollback = property(lambda self: self.conn.rollback)
query = property(lambda self: self.conn.execute)
def setup(self, reset = 0):
self._config.clear()
q = self.query
if reset:
for t in 'config', 'pt', 'trans', 'obj', 'data', 'ttrans', 'tobj':
q('DROP TABLE IF EXISTS ' + t)
# The table "config" stores configuration parameters which affect the
# persistent data.
q("""CREATE TABLE IF NOT EXISTS config (
name TEXT NOT NULL PRIMARY KEY,
value BLOB)
""")
# The table "pt" stores a partition table.
q("""CREATE TABLE IF NOT EXISTS pt (
rid INTEGER NOT NULL,
uuid BLOB NOT NULL,
state INTEGER NOT NULL,
PRIMARY KEY (rid, uuid))
""")
# The table "trans" stores information on committed transactions.
q("""CREATE TABLE IF NOT EXISTS trans (
partition INTEGER NOT NULL,
tid INTEGER NOT NULL,
packed BOOLEAN NOT NULL,
oids BLOB NOT NULL,
user BLOB NOT NULL,
description BLOB NOT NULL,
ext BLOB NOT NULL,
PRIMARY KEY (partition, tid))
""")
# The table "obj" stores committed object metadata.
q("""CREATE TABLE IF NOT EXISTS obj (
partition INTEGER NOT NULL,
oid INTEGER NOT NULL,
serial INTEGER NOT NULL,
data_id INTEGER,
value_serial INTEGER,
PRIMARY KEY (partition, serial, oid))
""")
q("""CREATE INDEX IF NOT EXISTS _obj_i1 ON
obj(partition, oid, serial)
""")
q("""CREATE INDEX IF NOT EXISTS _obj_i2 ON
obj(data_id)
""")
# The table "data" stores object data.
q("""CREATE TABLE IF NOT EXISTS data (
id INTEGER PRIMARY KEY AUTOINCREMENT,
hash BLOB NOT NULL UNIQUE,
compression INTEGER,
value BLOB)
""")
# The table "ttrans" stores information on uncommitted transactions.
q("""CREATE TABLE IF NOT EXISTS ttrans (
partition INTEGER NOT NULL,
tid INTEGER NOT NULL,
packed BOOLEAN NOT NULL,
oids BLOB NOT NULL,
user BLOB NOT NULL,
description BLOB NOT NULL,
ext BLOB NOT NULL)
""")
# The table "tobj" stores uncommitted object metadata.
q("""CREATE TABLE IF NOT EXISTS tobj (
partition INTEGER NOT NULL,
oid INTEGER NOT NULL,
serial INTEGER NOT NULL,
data_id INTEGER,
value_serial INTEGER,
PRIMARY KEY (serial, oid))
""")
self._uncommitted_data = dict(q("SELECT data_id, count(*)"
" FROM tobj WHERE data_id IS NOT NULL GROUP BY data_id"))
def getConfiguration(self, key):
try:
return self._config[key]
except KeyError:
try:
r = str(self.query("SELECT value FROM config WHERE name=?",
(key,)).fetchone()[0])
except TypeError:
r = None
self._config[key] = r
return r
def _setConfiguration(self, key, value):
q = self.query
self._config[key] = value
if value is None:
q("DELETE FROM config WHERE name=?", (key,))
else:
q("REPLACE INTO config VALUES (?,?)", (key, buffer(str(value))))
def _setPackTID(self, tid):
self._setConfiguration('_pack_tid', tid)
def _getPackTID(self):
try:
return int(self.getConfiguration('_pack_tid'))
except TypeError:
return -1
def getPartitionTable(self):
return [(offset, util.bin(uuid), state)
for offset, uuid, state in self.query(
"SELECT rid, uuid, state FROM pt")]
def _getLastTIDs(self, all=True):
p64 = util.p64
with self as q:
trans = dict((partition, p64(tid))
for partition, tid in q("SELECT partition, MAX(tid)"
" FROM trans GROUP BY partition"))
obj = dict((partition, p64(tid))
for partition, tid in q("SELECT partition, MAX(serial)"
" FROM obj GROUP BY partition"))
if all:
tid = q("SELECT MAX(tid) FROM ttrans").fetchone()[0]
if tid is not None:
trans[None] = p64(tid)
tid = q("SELECT MAX(serial) FROM tobj").fetchone()[0]
if tid is not None:
obj[None] = p64(tid)
return trans, obj
def getUnfinishedTIDList(self):
p64 = util.p64
tid_set = set()
with self as q:
tid_set.update((p64(t[0]) for t in q("SELECT tid FROM ttrans")))
tid_set.update((p64(t[0]) for t in q("SELECT serial FROM tobj")))
return list(tid_set)
def objectPresent(self, oid, tid, all=True):
oid = util.u64(oid)
tid = util.u64(tid)
with self as q:
r = q("SELECT 1 FROM obj WHERE partition=? AND oid=? AND serial=?",
(self._getPartition(oid), oid, tid)).fetchone()
if not r and all:
r = q("SELECT 1 FROM tobj WHERE serial=? AND oid=?",
(tid, oid)).fetchone()
return bool(r)
def _getObject(self, oid, tid=None, before_tid=None):
q = self.query
partition = self._getPartition(oid)
sql = ('SELECT serial, compression, data.hash, value, value_serial'
' FROM obj LEFT JOIN data ON obj.data_id = data.id'
' WHERE partition=? AND oid=?')
if tid is not None:
r = q(sql + ' AND serial=?', (partition, oid, tid))
elif before_tid is not None:
r = q(sql + ' AND serial<? ORDER BY serial DESC LIMIT 1',
(partition, oid, before_tid))
else:
r = q(sql + ' ORDER BY serial DESC LIMIT 1', (partition, oid))
try:
serial, compression, checksum, data, value_serial = r.fetchone()
except TypeError:
return None
r = q("""SELECT serial FROM obj
WHERE partition=? AND oid=? AND serial>?
ORDER BY serial LIMIT 1""",
(partition, oid, serial)).fetchone()
if checksum:
checksum = str(checksum)
data = str(data)
return serial, r and r[0], compression, checksum, data, value_serial
def doSetPartitionTable(self, ptid, cell_list, reset):
with self as q:
if reset:
q("DELETE FROM pt")
for offset, uuid, state in cell_list:
uuid = buffer(util.dump(uuid))
# TODO: this logic should move out of database manager
# add 'dropCells(cell_list)' to API and use one query
# WKRD: Why does SQLite need a statement journal file
# whereas we try to replace only 1 value ?
# We don't want to remove the 'NOT NULL' constraint
# so we must simulate a "REPLACE OR FAIL".
q("DELETE FROM pt WHERE rid=? AND uuid=?", (offset, uuid))
if state != CellStates.DISCARDED:
q("INSERT OR FAIL INTO pt VALUES (?,?,?)",
(offset, uuid, int(state)))
self.setPTID(ptid)
def changePartitionTable(self, ptid, cell_list):
self.doSetPartitionTable(ptid, cell_list, False)
def setPartitionTable(self, ptid, cell_list):
self.doSetPartitionTable(ptid, cell_list, True)
def dropPartitions(self, offset_list):
where = " WHERE partition=?"
with self as q:
for partition in offset_list:
data_id_list = [x for x, in
q("SELECT DISTINCT data_id FROM obj" + where,
(partition,)) if x]
q("DELETE FROM obj" + where, (partition,))
q("DELETE FROM trans" + where, (partition,))
self._pruneData(data_id_list)
def dropUnfinishedData(self):
with self as q:
data_id_list = [x for x, in q("SELECT data_id FROM tobj") if x]
q("DELETE FROM tobj")
q("DELETE FROM ttrans")
self.unlockData(data_id_list, True)
def storeTransaction(self, tid, object_list, transaction, temporary=True):
u64 = util.u64
tid = u64(tid)
T = 't' if temporary else ''
obj_sql = "INSERT OR FAIL INTO %sobj VALUES (?,?,?,?,?)" % T
with self as q:
for oid, data_id, value_serial in object_list:
oid = u64(oid)
partition = self._getPartition(oid)
if value_serial:
value_serial = u64(value_serial)
(data_id,), = q("SELECT data_id FROM obj"
" WHERE partition=? AND oid=? AND serial=?",
(partition, oid, value_serial))
if temporary:
self.storeData(data_id)
try:
q(obj_sql, (partition, oid, tid, data_id, value_serial))
except sqlite3.IntegrityError:
# This may happen if a previous replication of 'obj' was
# interrupted.
if not T:
r, = q("SELECT data_id, value_serial FROM obj"
" WHERE partition=? AND oid=? AND serial=?",
(partition, oid, tid))
if r == (data_id, value_serial):
continue
raise
if transaction is not None:
oid_list, user, desc, ext, packed = transaction
partition = self._getPartition(tid)
assert packed in (0, 1)
q("INSERT OR FAIL INTO %strans VALUES (?,?,?,?,?,?,?)" % T,
(partition, tid, packed, buffer(''.join(oid_list)),
buffer(user), buffer(desc), buffer(ext)))
def _pruneData(self, data_id_list):
data_id_list = set(data_id_list).difference(self._uncommitted_data)
if data_id_list:
q = self.query
data_id_list.difference_update(x for x, in q(
"SELECT DISTINCT data_id FROM obj WHERE data_id IN (%s)"
% ",".join(map(str, data_id_list))))
q("DELETE FROM data WHERE id IN (%s)"
% ",".join(map(str, data_id_list)))
def _storeData(self, checksum, data, compression):
H = buffer(checksum)
with self as q:
try:
return q("INSERT INTO data VALUES (NULL,?,?,?)",
(H, compression, buffer(data))).lastrowid
except sqlite3.IntegrityError, e:
if e.args[0] == 'column hash is not unique':
(r, c, d), = q("SELECT id, compression, value"
" FROM data WHERE hash=?", (H,))
if c == compression and str(d) == data:
return r
raise
def _getDataTID(self, oid, tid=None, before_tid=None):
partition = self._getPartition(oid)
sql = 'SELECT serial, data_id, value_serial FROM obj' \
' WHERE partition=? AND oid=?'
if tid is not None:
r = self.query(sql + ' AND serial=?', (partition, oid, tid))
elif before_tid is not None:
r = self.query(sql + ' AND serial<? ORDER BY serial DESC LIMIT 1',
(partition, oid, before_tid))
else:
r = self.query(sql + ' ORDER BY serial DESC LIMIT 1',
(partition, oid))
r = r.fetchone()
if r:
serial, data_id, value_serial = r
if value_serial is None and data_id:
return serial, serial
return serial, value_serial
return None, None
def finishTransaction(self, tid):
args = util.u64(tid),
with self as q:
sql = " FROM tobj WHERE serial=?"
data_id_list = [x for x, in q("SELECT data_id" + sql, args) if x]
q("INSERT OR FAIL INTO obj SELECT *" + sql, args)
q("DELETE FROM tobj WHERE serial=?", args)
q("INSERT OR FAIL INTO trans SELECT * FROM ttrans WHERE tid=?",
args)
q("DELETE FROM ttrans WHERE tid=?", args)
self.unlockData(data_id_list)
def deleteTransaction(self, tid, oid_list=()):
u64 = util.u64
tid = u64(tid)
getPartition = self._getPartition
with self as q:
sql = " FROM tobj WHERE serial=?"
data_id_list = [x for x, in q("SELECT data_id" + sql, (tid,)) if x]
self.unlockData(data_id_list)
q("DELETE" + sql, (tid,))
q("DELETE FROM ttrans WHERE tid=?", (tid,))
q("DELETE FROM trans WHERE partition=? AND tid=?",
(getPartition(tid), tid))
# delete from obj using indexes
data_id_set = set()
for oid in oid_list:
oid = u64(oid)
sql = " FROM obj WHERE partition=? AND oid=? AND serial=?"
args = getPartition(oid), oid, tid
data_id_set.update(*q("SELECT data_id" + sql, args))
q("DELETE" + sql, args)
data_id_set.discard(None)
self._pruneData(data_id_set)
def deleteObject(self, oid, serial=None):
oid = util.u64(oid)
sql = " FROM obj WHERE partition=? AND oid=?"
args = [self._getPartition(oid), oid]
if serial:
sql += " AND serial=?"
args.append(util.u64(serial))
with self as q:
data_id_list = [x for x, in q("SELECT DISTINCT data_id" + sql, args)
if x]
q("DELETE" + sql, args)
self._pruneData(data_id_list)
def _deleteRange(self, partition, min_tid=None, max_tid=None):
sql = " WHERE partition=?"
args = [partition]
if min_tid:
sql += " AND ? < tid"
args.append(util.u64(min_tid))
if max_tid:
sql += " AND tid <= ?"
args.append(util.u64(max_tid))
q = self.query
q("DELETE FROM trans" + sql, args)
sql = " FROM obj" + sql.replace('tid', 'serial')
data_id_list = [x for x, in q("SELECT DISTINCT data_id" + sql, args)
if x]
q("DELETE" + sql, args)
self._pruneData(data_id_list)
def getTransaction(self, tid, all=False):
tid = util.u64(tid)
with self as q:
r = q("SELECT oids, user, description, ext, packed FROM trans"
" WHERE partition=? AND tid=?",
(self._getPartition(tid), tid)).fetchone()
if not r and all:
r = q("SELECT oids, user, description, ext, packed FROM ttrans"
" WHERE tid=?", (tid,)).fetchone()
if r:
oids, user, description, ext, packed = r
return splitOIDField(tid, oids), str(user), \
str(description), str(ext), packed
def _getObjectLength(self, oid, value_serial):
if value_serial is None:
raise CreationUndone
length, value_serial = self.query("""SELECT LENGTH(value), value_serial
FROM obj LEFT JOIN data ON obj.data_id=data.id
WHERE partition=? AND oid=? AND serial=?""",
(self._getPartition(oid), oid, value_serial)).fetchone()
if length is None:
neo.lib.logging.info("Multiple levels of indirection"
" when searching for object data for oid %d at tid %d."
" This causes suboptimal performance.", oid, value_serial)
length = self._getObjectLength(oid, value_serial)
return length
def getObjectHistory(self, oid, offset=0, length=1):
# FIXME: This method doesn't take client's current transaction id as
# parameter, which means it can return transactions in the future of
# client's transaction.
p64 = util.p64
oid = util.u64(oid)
pack_tid = self._getPackTID()
result = []
append = result.append
with self as q:
for serial, length, value_serial in q("""\
SELECT serial, LENGTH(value), value_serial
FROM obj LEFT JOIN data ON obj.data_id = data.id
WHERE partition=? AND oid=? AND serial>=?
ORDER BY serial DESC LIMIT ?,?""",
(self._getPartition(oid), oid, pack_tid, offset, length)):
if length is None:
try:
length = self._getObjectLength(oid, value_serial)
except CreationUndone:
length = 0
append((p64(serial), length))
return result or None
def getReplicationObjectList(self, min_tid, max_tid, length, partition,
min_oid):
u64 = util.u64
p64 = util.p64
min_tid = u64(min_tid)
return [(p64(serial), p64(oid)) for serial, oid in self.query("""\
SELECT serial, oid FROM obj
WHERE partition=? AND serial<=?
AND (serial=? AND ?<=oid OR ?<serial)
ORDER BY serial ASC, oid ASC LIMIT ?""",
(partition, u64(max_tid), min_tid, u64(min_oid), min_tid, length))]
def getTIDList(self, offset, length, partition_list):
p64 = util.p64
return [p64(t[0]) for t in self.query("""\
SELECT tid FROM trans WHERE partition in (%s)
ORDER BY tid DESC LIMIT %d,%d"""
% (','.join(map(str, partition_list)), offset, length))]
def getReplicationTIDList(self, min_tid, max_tid, length, partition):
u64 = util.u64
p64 = util.p64
min_tid = u64(min_tid)
max_tid = u64(max_tid)
return [p64(t[0]) for t in self.query("""\
SELECT tid FROM trans
WHERE partition=? AND ?<=tid AND tid<=?
ORDER BY tid ASC LIMIT ?""",
(partition, min_tid, max_tid, length))]
def _updatePackFuture(self, oid, orig_serial, max_serial):
        # Before deleting this object's revision, see if there is any
# transaction referencing its value at max_serial or above.
# If there is, copy value to the first future transaction. Any further
# reference is just updated to point to the new data location.
partition = self._getPartition(oid)
value_serial = None
q = self.query
for T in '', 't':
update = """UPDATE OR FAIL %sobj SET value_serial=?
WHERE partition=? AND oid=? AND serial=?""" % T
for serial, in q("""SELECT serial FROM %sobj
WHERE partition=? AND oid=? AND serial>=? AND value_serial=?
ORDER BY serial ASC""" % T,
(partition, oid, max_serial, orig_serial)):
q(update, (value_serial, partition, oid, serial))
if value_serial is None:
# First found, mark its serial for future reference.
value_serial = serial
return value_serial
def pack(self, tid, updateObjectDataForPack):
# TODO: unit test (along with updatePackFuture)
p64 = util.p64
tid = util.u64(tid)
updatePackFuture = self._updatePackFuture
getPartition = self._getPartition
with self as q:
self._setPackTID(tid)
for count, oid, max_serial in q("SELECT COUNT(*) - 1, oid,"
" MAX(serial) FROM obj WHERE serial<=? GROUP BY oid",
(tid,)):
partition = getPartition(oid)
if q("SELECT 1 FROM obj WHERE partition=?"
" AND oid=? AND serial=? AND data_id IS NULL",
(partition, oid, max_serial)).fetchone():
max_serial += 1
elif not count:
continue
# There are things to delete for this object
data_id_set = set()
sql = " FROM obj WHERE partition=? AND oid=? AND serial<?"
args = partition, oid, max_serial
for serial, data_id in q("SELECT serial, data_id" + sql, args):
data_id_set.add(data_id)
new_serial = updatePackFuture(oid, serial, max_serial)
if new_serial:
new_serial = p64(new_serial)
updateObjectDataForPack(p64(oid), p64(serial),
new_serial, data_id)
q("DELETE" + sql, args)
data_id_set.discard(None)
self._pruneData(data_id_set)
def checkTIDRange(self, min_tid, max_tid, length, partition):
count, tids, max_tid = self.query("""\
SELECT COUNT(*), GROUP_CONCAT(tid), MAX(tid)
FROM (SELECT tid FROM trans
WHERE partition=? AND ?<=tid AND tid<=?
ORDER BY tid ASC LIMIT ?) AS t""",
(partition, util.u64(min_tid), util.u64(max_tid),
-1 if length is None else length)).fetchone()
if count:
return count, sha1(tids).digest(), util.p64(max_tid)
return 0, ZERO_HASH, ZERO_TID
def checkSerialRange(self, min_oid, min_serial, max_tid, length, partition):
u64 = util.u64
        # We don't ask SQLite to compute everything (like in checkTIDRange)
# because it's difficult to get the last serial _for the last oid_.
# We would need a function (that could be named 'LAST') that returns the
# last grouped value, instead of the greatest one.
min_oid = u64(min_oid)
r = self.query("""\
SELECT oid, serial
FROM obj
WHERE partition=? AND serial<=?
AND (oid>? OR oid=? AND serial>=?)
ORDER BY oid ASC, serial ASC LIMIT ?""",
(partition, u64(max_tid), min_oid, min_oid, u64(min_serial),
-1 if length is None else length)).fetchall()
if r:
p64 = util.p64
return (len(r),
sha1(','.join(str(x[0]) for x in r)).digest(),
p64(r[-1][0]),
sha1(','.join(str(x[1]) for x in r)).digest(),
p64(r[-1][1]))
return 0, ZERO_HASH, ZERO_OID, ZERO_HASH, ZERO_TID
......@@ -18,15 +18,15 @@
import neo
from neo.lib.handler import EventHandler
from neo.lib import protocol
from neo.lib.util import dump
from neo.lib.exception import PrimaryFailure, OperationFailure
from neo.lib.protocol import NodeStates, NodeTypes, Packets, Errors, ZERO_HASH
from neo.lib.protocol import NodeStates, NodeTypes
class BaseMasterHandler(EventHandler):
def connectionLost(self, conn, new_state):
if self.app.listening_conn: # if running
self.app.master_node = None
raise PrimaryFailure('connection lost')
def stopOperation(self, conn):
......@@ -62,44 +62,5 @@ class BaseMasterHandler(EventHandler):
dump(uuid))
self.app.tm.abortFor(uuid)
class BaseClientAndStorageOperationHandler(EventHandler):
""" Accept requests common to client and storage nodes """
def askTransactionInformation(self, conn, tid):
app = self.app
t = app.dm.getTransaction(tid)
if t is None:
p = Errors.TidNotFound('%s does not exist' % dump(tid))
else:
p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3],
t[4], t[0])
conn.answer(p)
def _askObject(self, oid, serial, tid):
raise NotImplementedError
def askObject(self, conn, oid, serial, tid):
app = self.app
if self.app.tm.loadLocked(oid):
# Delay the response.
app.queueEvent(self.askObject, conn, (oid, serial, tid))
return
o = self._askObject(oid, serial, tid)
if o is None:
neo.lib.logging.debug('oid = %s does not exist', dump(oid))
p = Errors.OidDoesNotExist(dump(oid))
elif o is False:
neo.lib.logging.debug('oid = %s not found', dump(oid))
p = Errors.OidNotFound(dump(oid))
else:
serial, next_serial, compression, checksum, data, data_serial = o
neo.lib.logging.debug('oid = %s, serial = %s, next_serial = %s',
dump(oid), dump(serial), dump(next_serial))
if checksum is None:
checksum = ZERO_HASH
data = ''
p = Packets.AnswerObject(oid, serial, next_serial,
compression, checksum, data, data_serial)
conn.answer(p)
def answerUnfinishedTransactions(self, conn, *args, **kw):
self.app.replicator.setUnfinishedTIDList(*args, **kw)
......@@ -16,10 +16,10 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import neo.lib
from neo.lib import protocol
from neo.lib.handler import EventHandler
from neo.lib.util import dump, makeChecksum
from neo.lib.protocol import Packets, LockState, Errors, ZERO_HASH
from . import BaseClientAndStorageOperationHandler
from neo.lib.protocol import Packets, LockState, Errors, ProtocolError, \
ZERO_HASH, INVALID_PARTITION
from ..transactions import ConflictError, DelayedError
from ..exception import AlreadyPendingError
import time
......@@ -28,10 +28,40 @@ import time
# Set to None to disable.
SLOW_STORE = 2
class ClientOperationHandler(BaseClientAndStorageOperationHandler):
class ClientOperationHandler(EventHandler):
def _askObject(self, oid, serial, ttid):
return self.app.dm.getObject(oid, serial, ttid)
def askTransactionInformation(self, conn, tid):
t = self.app.dm.getTransaction(tid)
if t is None:
p = Errors.TidNotFound('%s does not exist' % dump(tid))
else:
p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3],
t[4], t[0])
conn.answer(p)
def askObject(self, conn, oid, serial, tid):
app = self.app
if app.tm.loadLocked(oid):
# Delay the response.
app.queueEvent(self.askObject, conn, (oid, serial, tid))
return
o = app.dm.getObject(oid, serial, tid)
if o is None:
neo.lib.logging.debug('oid = %s does not exist', dump(oid))
p = Errors.OidDoesNotExist(dump(oid))
elif o is False:
neo.lib.logging.debug('oid = %s not found', dump(oid))
p = Errors.OidNotFound(dump(oid))
else:
serial, next_serial, compression, checksum, data, data_serial = o
neo.lib.logging.debug('oid = %s, serial = %s, next_serial = %s',
dump(oid), dump(serial), dump(next_serial))
if checksum is None:
checksum = ZERO_HASH
data = ''
p = Packets.AnswerObject(oid, serial, next_serial,
compression, checksum, data, data_serial)
conn.answer(p)
def connectionLost(self, conn, new_state):
uuid = conn.getUUID()
......@@ -96,22 +126,18 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
self._askStoreObject(conn, oid, serial, compression, checksum, data,
data_serial, ttid, unlock, time.time())
def askTIDsFrom(self, conn, min_tid, max_tid, length, partition_list):
getReplicationTIDList = self.app.dm.getReplicationTIDList
tid_list = []
extend = tid_list.extend
for partition in partition_list:
extend(getReplicationTIDList(min_tid, max_tid, length, partition))
conn.answer(Packets.AnswerTIDsFrom(tid_list))
def askTIDsFrom(self, conn, min_tid, max_tid, length, partition):
conn.answer(Packets.AnswerTIDsFrom(self.app.dm.getReplicationTIDList(
min_tid, max_tid, length, partition)))
def askTIDs(self, conn, first, last, partition):
# This method is complicated, because I must return TIDs only
# about usable partitions assigned to me.
if first >= last:
raise protocol.ProtocolError('invalid offsets')
raise ProtocolError('invalid offsets')
app = self.app
if partition == protocol.INVALID_PARTITION:
if partition == INVALID_PARTITION:
partition_list = app.pt.getAssignedPartitionList(app.uuid)
else:
partition_list = [partition]
......@@ -149,7 +175,7 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
def askObjectHistory(self, conn, oid, first, last):
if first >= last:
raise protocol.ProtocolError( 'invalid offsets')
raise ProtocolError('invalid offsets')
app = self.app
history_list = app.dm.getObjectHistory(oid, first, last - first)
......
......@@ -21,6 +21,7 @@ from neo.lib.handler import EventHandler
from neo.lib.protocol import NodeTypes, Packets, NotReadyError
from neo.lib.protocol import ProtocolError, BrokenNodeDisallowedError
from neo.lib.util import dump
from .storage import StorageOperationHandler
class IdentificationHandler(EventHandler):
""" Handler used for incoming connections during operation state """
......@@ -35,37 +36,42 @@ class IdentificationHandler(EventHandler):
if not self.app.ready:
raise NotReadyError
app = self.app
node = app.nm.getByUUID(uuid)
# If this node is broken, reject it.
if node is not None and node.isBroken():
raise BrokenNodeDisallowedError
# choose the handler according to the node type
if node_type == NodeTypes.CLIENT:
from .client import ClientOperationHandler
handler = ClientOperationHandler
if node is None:
node = app.nm.createClient(uuid=uuid)
elif node.isConnected():
# cut previous connection
node.getConnection().close()
assert not node.isConnected()
node.setRunning()
elif node_type == NodeTypes.STORAGE:
from .storage import StorageOperationHandler
handler = StorageOperationHandler
if node is None:
neo.lib.logging.error('reject an unknown storage node %s',
dump(uuid))
raise NotReadyError
if uuid is None:
if node_type != NodeTypes.STORAGE:
raise ProtocolError('reject anonymous non-storage node')
handler = StorageOperationHandler(self.app)
conn.setHandler(handler)
else:
raise ProtocolError('reject non-client-or-storage node')
# apply the handler and set up the connection
handler = handler(self.app)
conn.setHandler(handler)
node.setConnection(conn)
args = (NodeTypes.STORAGE, app.uuid, app.pt.getPartitions(),
app.pt.getReplicas(), uuid)
if uuid == app.uuid:
raise ProtocolError("uuid conflict or loopback connection")
node = app.nm.getByUUID(uuid)
# If this node is broken, reject it.
if node is not None and node.isBroken():
raise BrokenNodeDisallowedError
# choose the handler according to the node type
if node_type == NodeTypes.CLIENT:
from .client import ClientOperationHandler
handler = ClientOperationHandler
if node is None:
node = app.nm.createClient(uuid=uuid)
elif node.isConnected():
# cut previous connection
node.getConnection().close()
assert not node.isConnected()
node.setRunning()
elif node_type == NodeTypes.STORAGE:
if node is None:
neo.lib.logging.error('reject an unknown storage node %s',
dump(uuid))
raise NotReadyError
handler = StorageOperationHandler
else:
raise ProtocolError('reject non-client-or-storage node')
# apply the handler and set up the connection
handler = handler(self.app)
conn.setHandler(handler)
node.setConnection(conn, app.uuid < uuid)
# accept the identification and trigger an event
conn.answer(Packets.AcceptIdentification(*args))
conn.answer(Packets.AcceptIdentification(NodeTypes.STORAGE, uuid and
app.uuid, app.pt.getPartitions(), app.pt.getReplicas(), uuid))
handler.connectionCompleted(conn)
......@@ -25,10 +25,6 @@ class InitializationHandler(BaseMasterHandler):
def answerNodeInformation(self, conn):
pass
def notifyNodeInformation(self, conn, node_list):
# the whole node list is received here
BaseMasterHandler.notifyNodeInformation(self, conn, node_list)
def answerPartitionTable(self, conn, ptid, row_list):
app = self.app
pt = app.pt
......@@ -53,8 +49,9 @@ class InitializationHandler(BaseMasterHandler):
app.dm.setPartitionTable(ptid, cell_list)
def answerLastIDs(self, conn, loid, ltid, lptid):
def answerLastIDs(self, conn, loid, ltid, lptid, backup_tid):
self.app.dm.setLastOID(loid)
self.app.dm.setBackupTID(backup_tid)
def notifyPartitionChanges(self, conn, ptid, cell_list):
# XXX: This is safe to ignore those notifications because all of the
......
......@@ -24,11 +24,8 @@ from . import BaseMasterHandler
class MasterOperationHandler(BaseMasterHandler):
""" This handler is used for the primary master """
def answerUnfinishedTransactions(self, conn, max_tid, ttid_list):
self.app.replicator.setUnfinishedTIDList(max_tid, ttid_list)
def notifyTransactionFinished(self, conn, ttid, max_tid):
self.app.replicator.transactionFinished(ttid, max_tid)
def notifyTransactionFinished(self, conn, *args, **kw):
self.app.replicator.transactionFinished(*args, **kw)
def notifyPartitionChanges(self, conn, ptid, cell_list):
"""This is very similar to Send Partition Table, except that
......@@ -44,14 +41,7 @@ class MasterOperationHandler(BaseMasterHandler):
app.dm.changePartitionTable(ptid, cell_list)
# Check changes for replications
if app.replicator is not None:
for offset, uuid, state in cell_list:
if uuid == app.uuid:
# If this is for myself, this can affect replications.
if state == CellStates.DISCARDED:
app.replicator.removePartition(offset)
elif state == CellStates.OUT_OF_DATE:
app.replicator.addPartition(offset)
app.replicator.notifyPartitionChanges(cell_list)
def askLockInformation(self, conn, ttid, tid, oid_list):
if not ttid in self.app.tm:
......@@ -74,3 +64,11 @@ class MasterOperationHandler(BaseMasterHandler):
if not conn.isClosed():
conn.answer(Packets.AnswerPack(True))
def replicate(self, conn, tid, upstream_name, source_dict):
self.app.replicator.backup(tid,
dict((p, (a, upstream_name))
for p, a in source_dict.iteritems()))
def askTruncate(self, conn, tid):
self.app.dm.truncate(tid)
conn.answer(Packets.AnswerTruncate())
#
# Copyright (C) 2006-2010 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from functools import wraps
import neo.lib
from neo.lib.handler import EventHandler
from neo.lib.protocol import Packets, ZERO_HASH, ZERO_TID, ZERO_OID
from neo.lib.util import add64, u64
# TODO: benchmark how different values behave
RANGE_LENGTH = 4000
MIN_RANGE_LENGTH = 1000
CHECK_CHUNK = 0
CHECK_REPLICATE = 1
CHECK_DONE = 2
"""
Replication algorithm
Purpose: replicate the content of a reference node into a replicating node,
bringing it up-to-date.
This happens both when a new storage is added to en existing cluster, as well
as when a nde was separated from cluster and rejoins it.
Replication happens per partition. Reference node can change between
partitions.
2 parts, done sequentially:
- Transaction (metadata) replication
- Object (data) replication
Both parts follow the same mechanism:
- On both sides (replicating and reference), compute a checksum of a chunk
(RANGE_LENGTH number of entries). If there is a mismatch, chunk size is
reduced, and scan restarts from same row, until it reaches a minimal length
(MIN_RANGE_LENGTH). Then, it replicates all rows in that chunk. If the
content of chunks match, it moves on to the next chunk.
- Replicating a chunk starts with asking for a list of all entries (only their
identifier) and skipping those both side have, deleting those which reference
has and replicating doesn't, and asking individually all entries missing in
replicating.
"""
# TODO: Make object replication get ordered by serial first and oid second, so
# changes are in a big segment at the end, rather than in many segments (one
# per object).
# TODO: To improve performance when a pack happened, the following algorithm
# should be used:
# - If reference node packed, find non-existant oids in reference node (their
# creation was undone, and pack pruned them), and delete them.
# - Run current algorithm, starting at our last pack TID.
# - Pack partition at reference's TID.
def checkConnectionIsReplicatorConnection(func):
def decorator(self, conn, *args, **kw):
if self.app.replicator.isCurrentConnection(conn):
return func(self, conn, *args, **kw)
# Should probably raise & close connection...
return wraps(func)(decorator)
class ReplicationHandler(EventHandler):
"""This class handles events for replications."""
def connectionLost(self, conn, new_state):
replicator = self.app.replicator
if replicator.isCurrentConnection(conn):
if replicator.pending():
neo.lib.logging.warning(
'replication is stopped due to a connection lost')
replicator.storageLost()
def connectionFailed(self, conn):
neo.lib.logging.warning(
'replication is stopped due to connection failure')
self.app.replicator.storageLost()
def acceptIdentification(self, conn, node_type,
uuid, num_partitions, num_replicas, your_uuid):
self.startReplication(conn)
def startReplication(self, conn):
max_tid = self.app.replicator.getCurrentCriticalTID()
conn.ask(self._doAskCheckTIDRange(ZERO_TID, max_tid), timeout=300)
@checkConnectionIsReplicatorConnection
def answerTIDsFrom(self, conn, tid_list):
assert tid_list
app = self.app
ask = conn.ask
# If I have pending TIDs, check which TIDs I don't have, and
# request the data.
tid_set = frozenset(tid_list)
my_tid_set = frozenset(app.replicator.getTIDsFromResult())
extra_tid_set = my_tid_set - tid_set
if extra_tid_set:
deleteTransaction = app.dm.deleteTransaction
for tid in extra_tid_set:
deleteTransaction(tid)
missing_tid_set = tid_set - my_tid_set
for tid in missing_tid_set:
ask(Packets.AskTransactionInformation(tid), timeout=300)
if len(tid_list) == MIN_RANGE_LENGTH:
# If we received fewer, we knew it before sending AskTIDsFrom, and
# we should have finished TID replication at that time.
max_tid = self.app.replicator.getCurrentCriticalTID()
ask(self._doAskCheckTIDRange(add64(tid_list[-1], 1), max_tid,
RANGE_LENGTH))
@checkConnectionIsReplicatorConnection
def answerTransactionInformation(self, conn, tid,
user, desc, ext, packed, oid_list):
app = self.app
# Directly store the transaction.
app.dm.storeTransaction(tid, (), (oid_list, user, desc, ext, packed),
False)
@checkConnectionIsReplicatorConnection
def answerObjectHistoryFrom(self, conn, object_dict):
assert object_dict
app = self.app
ask = conn.ask
deleteObject = app.dm.deleteObject
my_object_dict = app.replicator.getObjectHistoryFromResult()
object_set = set()
max_oid = max(object_dict.iterkeys())
max_serial = max(object_dict[max_oid])
for oid, serial_list in object_dict.iteritems():
for serial in serial_list:
object_set.add((oid, serial))
my_object_set = set()
for oid, serial_list in my_object_dict.iteritems():
filter = lambda x: True
if max_oid is not None:
if oid > max_oid:
continue
elif oid == max_oid:
filter = lambda x: x <= max_serial
for serial in serial_list:
if filter(serial):
my_object_set.add((oid, serial))
extra_object_set = my_object_set - object_set
for oid, serial in extra_object_set:
deleteObject(oid, serial)
missing_object_set = object_set - my_object_set
for oid, serial in missing_object_set:
if not app.dm.objectPresent(oid, serial):
ask(Packets.AskObject(oid, serial, None), timeout=300)
if sum(map(len, object_dict.itervalues())) == MIN_RANGE_LENGTH:
max_tid = self.app.replicator.getCurrentCriticalTID()
ask(self._doAskCheckSerialRange(max_oid, add64(max_serial, 1),
max_tid, RANGE_LENGTH))
@checkConnectionIsReplicatorConnection
def answerObject(self, conn, oid, serial_start,
serial_end, compression, checksum, data, data_serial):
dm = self.app.dm
if data or checksum != ZERO_HASH:
data_id = dm.storeData(checksum, data, compression)
else:
data_id = None
# Directly store the transaction.
obj = oid, data_id, data_serial
dm.storeTransaction(serial_start, [obj], None, False)
def _doAskCheckSerialRange(self, min_oid, min_tid, max_tid,
length=RANGE_LENGTH):
replicator = self.app.replicator
partition = replicator.getCurrentOffset()
neo.lib.logging.debug("Check serial range (offset=%s, min_oid=%x,"
" min_tid=%x, max_tid=%x, length=%s)", partition, u64(min_oid),
u64(min_tid), u64(max_tid), length)
check_args = (min_oid, min_tid, max_tid, length, partition)
replicator.checkSerialRange(*check_args)
return Packets.AskCheckSerialRange(*check_args)
def _doAskCheckTIDRange(self, min_tid, max_tid, length=RANGE_LENGTH):
replicator = self.app.replicator
partition = replicator.getCurrentOffset()
neo.lib.logging.debug(
"Check TID range (offset=%s, min_tid=%x, max_tid=%x, length=%s)",
partition, u64(min_tid), u64(max_tid), length)
replicator.checkTIDRange(min_tid, max_tid, length, partition)
return Packets.AskCheckTIDRange(min_tid, max_tid, length, partition)
def _doAskTIDsFrom(self, min_tid, length):
replicator = self.app.replicator
partition_id = replicator.getCurrentOffset()
max_tid = replicator.getCurrentCriticalTID()
replicator.getTIDsFrom(min_tid, max_tid, length, partition_id)
neo.lib.logging.debug("Ask TIDs (offset=%s, min_tid=%x, max_tid=%x,"
"length=%s)", partition_id, u64(min_tid), u64(max_tid), length)
return Packets.AskTIDsFrom(min_tid, max_tid, length, [partition_id])
def _doAskObjectHistoryFrom(self, min_oid, min_serial, length):
replicator = self.app.replicator
partition_id = replicator.getCurrentOffset()
max_serial = replicator.getCurrentCriticalTID()
replicator.getObjectHistoryFrom(min_oid, min_serial, max_serial,
length, partition_id)
return Packets.AskObjectHistoryFrom(min_oid, min_serial, max_serial,
length, partition_id)
def _checkRange(self, match, current_boundary, next_boundary, length,
count):
if count == 0:
# Reference storage has no data for this chunk, stop and truncate.
return CHECK_DONE, (current_boundary, )
if match:
# Same data on both sides
if length < RANGE_LENGTH and length == count:
# ...and previous check detected a difference - and we still
# haven't reached the end. This means that we just check the
# first half of a chunk which, as a whole, is different. So
# next test must happen on the next chunk.
recheck_min_boundary = next_boundary
else:
# ...and we just checked a whole chunk, move on to the next
# one.
recheck_min_boundary = None
else:
# Something is different in current chunk
recheck_min_boundary = current_boundary
if recheck_min_boundary is None:
if count == length:
# Go on with next chunk
action = CHECK_CHUNK
params = (next_boundary, RANGE_LENGTH)
else:
# No more chunks.
action = CHECK_DONE
params = (next_boundary, )
else:
# We must recheck current chunk.
if not match and count <= MIN_RANGE_LENGTH:
# We are already at minimum chunk length, replicate.
action = CHECK_REPLICATE
params = (recheck_min_boundary, )
else:
# Check a smaller chunk.
# Note: +1, so we can detect we reached the end when answer
# comes back.
action = CHECK_CHUNK
params = (recheck_min_boundary, max(min(length / 2, count + 1),
MIN_RANGE_LENGTH))
return action, params
@checkConnectionIsReplicatorConnection
def answerCheckTIDRange(self, conn, min_tid, length, count, tid_checksum,
max_tid):
pkt_min_tid = min_tid
ask = conn.ask
app = self.app
replicator = app.replicator
next_tid = add64(max_tid, 1)
action, params = self._checkRange(
replicator.getTIDCheckResult(min_tid, length) == (
count, tid_checksum, max_tid), min_tid, next_tid, length,
count)
critical_tid = replicator.getCurrentCriticalTID()
if action == CHECK_REPLICATE:
(min_tid, ) = params
ask(self._doAskTIDsFrom(min_tid, count))
if length != count:
action = CHECK_DONE
params = (next_tid, )
if action == CHECK_CHUNK:
(min_tid, count) = params
if min_tid >= critical_tid:
# Stop if past critical TID
action = CHECK_DONE
params = (next_tid, )
else:
ask(self._doAskCheckTIDRange(min_tid, critical_tid, count))
if action == CHECK_DONE:
# Delete all transactions we might have which are beyond what peer
# knows.
(last_tid, ) = params
offset = replicator.getCurrentOffset()
neo.lib.logging.debug("TID range checked (offset=%s, min_tid=%x,"
" length=%s, count=%s, max_tid=%x, last_tid=%x,"
" critical_tid=%x)", offset, u64(pkt_min_tid), length, count,
u64(max_tid), u64(last_tid), u64(critical_tid))
app.dm.deleteTransactionsAbove(offset, last_tid, critical_tid)
# If no more TID, a replication of transactions is finished.
# So start to replicate objects now.
ask(self._doAskCheckSerialRange(ZERO_OID, ZERO_TID, critical_tid))
@checkConnectionIsReplicatorConnection
def answerCheckSerialRange(self, conn, min_oid, min_serial, length, count,
oid_checksum, max_oid, serial_checksum, max_serial):
ask = conn.ask
app = self.app
replicator = app.replicator
next_params = (max_oid, add64(max_serial, 1))
action, params = self._checkRange(
replicator.getSerialCheckResult(min_oid, min_serial, length) == (
count, oid_checksum, max_oid, serial_checksum, max_serial),
(min_oid, min_serial), next_params, length, count)
if action == CHECK_REPLICATE:
((min_oid, min_serial), ) = params
ask(self._doAskObjectHistoryFrom(min_oid, min_serial, count))
if length != count:
action = CHECK_DONE
params = (next_params, )
if action == CHECK_CHUNK:
((min_oid, min_serial), count) = params
max_tid = replicator.getCurrentCriticalTID()
ask(self._doAskCheckSerialRange(min_oid, min_serial, max_tid, count))
if action == CHECK_DONE:
# Delete all objects we might have which are beyond what peer
# knows.
((last_oid, last_serial), ) = params
offset = replicator.getCurrentOffset()
max_tid = replicator.getCurrentCriticalTID()
neo.lib.logging.debug("Serial range checked (offset=%s, min_oid=%x,"
" min_serial=%x, length=%s, count=%s, max_oid=%x,"
" max_serial=%x, last_oid=%x, last_serial=%x, critical_tid=%x)",
offset, u64(min_oid), u64(min_serial), length, count,
u64(max_oid), u64(max_serial), u64(last_oid), u64(last_serial),
u64(max_tid))
app.dm.deleteObjectsAbove(offset, last_oid, last_serial, max_tid)
# Nothing remains, so the replication for this partition is
# finished.
replicator.setReplicationDone()
......@@ -15,36 +15,101 @@
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from . import BaseClientAndStorageOperationHandler
from neo.lib.protocol import Packets
import weakref
from functools import wraps
import neo.lib
from neo.lib.connector import ConnectorConnectionClosedException
from neo.lib.handler import EventHandler
from neo.lib.protocol import Errors, NodeStates, Packets, \
ZERO_HASH, ZERO_TID, ZERO_OID
from neo.lib.util import add64, u64
class StorageOperationHandler(BaseClientAndStorageOperationHandler):
def checkConnectionIsReplicatorConnection(func):
def decorator(self, conn, *args, **kw):
assert self.app.replicator.getCurrentConnection() is conn
return func(self, conn, *args, **kw)
return wraps(func)(decorator)
def _askObject(self, oid, serial, tid):
result = self.app.dm.getObject(oid, serial, tid)
if result and result[5]:
return result[:2] + (None, None, None) + result[4:]
return result
class StorageOperationHandler(EventHandler):
"""This class handles events for replications."""
def askLastIDs(self, conn):
app = self.app
oid = app.dm.getLastOID()
tid = app.dm.getLastTID()
conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID()))
def askTIDsFrom(self, conn, min_tid, max_tid, length, partition_list):
assert len(partition_list) == 1, partition_list
tid_list = self.app.dm.getReplicationTIDList(min_tid, max_tid, length,
partition_list[0])
conn.answer(Packets.AnswerTIDsFrom(tid_list))
def askObjectHistoryFrom(self, conn, min_oid, min_serial, max_serial,
length, partition):
object_dict = self.app.dm.getObjectHistoryFrom(min_oid, min_serial,
max_serial, length, partition)
conn.answer(Packets.AnswerObjectHistoryFrom(object_dict))
def connectionLost(self, conn, new_state):
if self.app.listening_conn and conn.isClient():
# XXX: Connection and Node should merged.
uuid = conn.getUUID()
if uuid:
node = self.app.nm.getByUUID(uuid)
else:
node = self.app.nm.getByAddress(conn.getAddress())
node.setState(NodeStates.DOWN)
replicator = self.app.replicator
if replicator.current_node is node:
replicator.abort()
# Client
def connectionFailed(self, conn):
if self.app.listening_conn:
self.app.replicator.abort()
@checkConnectionIsReplicatorConnection
def acceptIdentification(self, conn, node_type,
uuid, num_partitions, num_replicas, your_uuid):
self.app.replicator.fetchTransactions()
@checkConnectionIsReplicatorConnection
def answerFetchTransactions(self, conn, pack_tid, next_tid, tid_list):
if tid_list:
deleteTransaction = self.app.dm.deleteTransaction
for tid in tid_list:
deleteTransaction(tid)
assert not pack_tid, "TODO"
if next_tid:
self.app.replicator.fetchTransactions(next_tid)
else:
self.app.replicator.fetchObjects()
@checkConnectionIsReplicatorConnection
def addTransaction(self, conn, tid, user, desc, ext, packed, oid_list):
# Directly store the transaction.
self.app.dm.storeTransaction(tid, (),
(oid_list, user, desc, ext, packed), False)
@checkConnectionIsReplicatorConnection
def answerFetchObjects(self, conn, pack_tid, next_tid,
next_oid, object_dict):
if object_dict:
deleteObject = self.app.dm.deleteObject
for serial, oid_list in object_dict.iteritems():
for oid in oid_list:
delObject(oid, serial)
assert not pack_tid, "TODO"
if next_tid:
self.app.replicator.fetchObjects(next_tid, next_oid)
else:
self.app.replicator.finish()
@checkConnectionIsReplicatorConnection
def addObject(self, conn, oid, serial, compression,
checksum, data, data_serial):
dm = self.app.dm
if data or checksum != ZERO_HASH:
data_id = dm.storeData(checksum, data, compression)
else:
data_id = None
# Directly store the transaction.
obj = oid, data_id, data_serial
dm.storeTransaction(serial, (obj,), None, False)
@checkConnectionIsReplicatorConnection
def replicationError(self, conn, message):
self.app.replicator.abort('source message: ' + message)
# Server (all methods must set connection as server so that it isn't closed
# if client tasks are finished)
def askCheckTIDRange(self, conn, min_tid, max_tid, length, partition):
conn.asServer()
count, tid_checksum, max_tid = self.app.dm.checkTIDRange(min_tid,
max_tid, length, partition)
conn.answer(Packets.AnswerCheckTIDRange(min_tid, length,
......@@ -52,9 +117,91 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler):
def askCheckSerialRange(self, conn, min_oid, min_serial, max_tid, length,
partition):
conn.asServer()
count, oid_checksum, max_oid, serial_checksum, max_serial = \
self.app.dm.checkSerialRange(min_oid, min_serial, max_tid, length,
partition)
conn.answer(Packets.AnswerCheckSerialRange(min_oid, min_serial, length,
count, oid_checksum, max_oid, serial_checksum, max_serial))
def askFetchTransactions(self, conn, partition, length, min_tid, max_tid,
tid_list):
app = self.app
cell = app.pt.getCell(partition, app.uuid)
if cell is None or cell.isOutOfDate():
return conn.answer(Errors.ReplicationError(
"partition %u not readable" % partition))
conn.asServer()
msg_id = conn.getPeerId()
conn = weakref.proxy(conn)
peer_tid_set = set(tid_list)
dm = app.dm
tid_list = dm.getReplicationTIDList(min_tid, max_tid, length + 1,
partition)
next_tid = tid_list.pop() if length < len(tid_list) else None
def push():
try:
pack_tid = None # TODO
for tid in tid_list:
if tid in peer_tid_set:
peer_tid_set.remove(tid)
else:
t = dm.getTransaction(tid)
if t is None:
conn.answer(Errors.ReplicationError(
"partition %u dropped" % partition))
return
oid_list, user, desc, ext, packed = t
conn.notify(Packets.AddTransaction(
tid, user, desc, ext, packed, oid_list))
yield
conn.answer(Packets.AnswerFetchTransactions(
pack_tid, next_tid, peer_tid_set), msg_id)
yield
except (weakref.ReferenceError, ConnectorConnectionClosedException):
pass
app.newTask(push())
def askFetchObjects(self, conn, partition, length, min_tid, max_tid,
min_oid, object_dict):
app = self.app
cell = app.pt.getCell(partition, app.uuid)
if cell is None or cell.isOutOfDate():
return conn.answer(Errors.ReplicationError(
"partition %u not readable" % partition))
conn.asServer()
msg_id = conn.getPeerId()
conn = weakref.proxy(conn)
dm = app.dm
object_list = dm.getReplicationObjectList(min_tid, max_tid, length,
partition, min_oid)
if length < len(object_list):
next_tid, next_oid = object_list.pop()
else:
next_tid = next_oid = None
def push():
try:
pack_tid = None # TODO
for serial, oid in object_list:
oid_set = object_dict.get(serial)
if oid_set:
if type(oid_set) is list:
object_dict[serial] = oid_set = set(oid_set)
if oid in oid_set:
oid_set.remove(oid)
if not oid_set:
del object_dict[serial]
continue
object = dm.getObject(oid, serial)
if object is None:
conn.answer(Errors.ReplicationError(
"partition %u dropped" % partition))
return
conn.notify(Packets.AddObject(oid, serial, *object[2:]))
yield
conn.answer(Packets.AnswerFetchObjects(
pack_tid, next_tid, next_oid, object_dict), msg_id)
yield
except (weakref.ReferenceError, ConnectorConnectionClosedException):
pass
app.newTask(push())
......@@ -27,15 +27,11 @@ class VerificationHandler(BaseMasterHandler):
def askLastIDs(self, conn):
app = self.app
try:
oid = app.dm.getLastOID()
except KeyError:
oid = None
try:
tid = app.dm.getLastTID()
except KeyError:
tid = None
conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID()))
conn.answer(Packets.AnswerLastIDs(
app.dm.getLastOID(),
app.dm.getLastTIDs()[0],
app.pt.getID(),
app.dm.getBackupTID()))
def askPartitionTable(self, conn):
pt = self.app.pt
......
......@@ -15,363 +15,300 @@
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import neo.lib
from random import choice
"""
Replication algorithm
from .handlers import replication
from neo.lib.protocol import NodeTypes, NodeStates, Packets
from neo.lib.connection import ClientConnection
from neo.lib.util import dump
Purpose: replicate the content of a reference node into a replicating node,
bringing it up-to-date. This happens in the following cases:
- A new storage is added to en existing cluster.
- A node was separated from cluster and rejoins it.
- In a backup cluster, the master notifies a node that new data exists upstream
(note that in this case, the cell is always marked as UP_TO_DATE).
class Partition(object):
"""This class abstracts the state of a partition."""
Replication happens per partition. Reference node can change between
partitions.
def __init__(self, offset, max_tid, ttid_list):
# Possible optimization:
# _pending_ttid_list & _critical_tid can be shared amongst partitions
# created at the same time (cf Replicator.setUnfinishedTIDList).
# Replicator.transactionFinished would only have to iterate on these
# different sets, instead of all partitions.
self._offset = offset
self._pending_ttid_list = set(ttid_list)
# pending upper bound
self._critical_tid = max_tid
2 parts, done sequentially:
- Transaction (metadata) replication
- Object (data) replication
def getOffset(self):
return self._offset
Both parts follow the same mechanism:
- The range of data to replicate is split into chunks of FETCH_COUNT items
(transaction or object).
- For every chunk, the requesting node sends to seeding node the list of items
it already has.
- Before answering, the seeding node sends 1 packet for every missing item.
- The seeding node finally answers with the list of items to delete (usually
empty).
def getCriticalTID(self):
return self._critical_tid
Replication is partial, starting from the greatest stored tid in the partition:
- For transactions, this tid is excluded from replication.
- For objects, this tid is included unless the storage already knows it has
all oids for it.
def transactionFinished(self, ttid, max_tid):
self._pending_ttid_list.remove(ttid)
assert max_tid is not None
# final upper bound
self._critical_tid = max_tid
There is no check that item values on both nodes matches.
def safe(self):
return not self._pending_ttid_list
TODO: Packing and replication currently fail when then happen at the same time.
"""
class Task(object):
"""
A Task is a callable to execute at another time, with given parameters.
Execution result is kept and can be retrieved later.
"""
import random
_func = None
_args = None
_kw = None
_result = None
_processed = False
import neo.lib
from neo.lib.protocol import CellStates, NodeTypes, NodeStates, Packets, \
INVALID_TID, ZERO_TID, ZERO_OID
from neo.lib.connection import ClientConnection
from neo.lib.util import add64, u64
from .handlers.storage import StorageOperationHandler
FETCH_COUNT = 1000
def __init__(self, func, args=(), kw=None):
self._func = func
self._args = args
if kw is None:
kw = {}
self._kw = kw
def process(self):
if self._processed:
raise ValueError, 'You cannot process a single Task twice'
self._processed = True
self._result = self._func(*self._args, **self._kw)
class Partition(object):
def getResult(self):
# Should we instead execute immediately rather than raising ?
if not self._processed:
raise ValueError, 'You cannot get a result until task is executed'
return self._result
__slots__ = 'next_trans', 'next_obj', 'max_ttid'
def __repr__(self):
fmt = '<%s at %x %r(*%r, **%r)%%s>' % (self.__class__.__name__,
id(self), self._func, self._args, self._kw)
if self._processed:
extra = ' => %r' % (self._result, )
else:
extra = ''
return fmt % (extra, )
return '<%s(%s) at 0x%x>' % (self.__class__.__name__,
', '.join('%s=%r' % (x, getattr(self, x)) for x in self.__slots__
if hasattr(self, x)),
id(self))
class Replicator(object):
"""This class handles replications of objects and transactions.
Assumptions:
- Client nodes recognize partition changes reasonably quickly.
- When an out of date partition is added, next transaction ID
is given after the change is notified and serialized.
Procedures:
- Get the last TID right after a partition is added. This TID
is called a "critical TID", because this and TIDs before this
may not be present in this storage node yet. After a critical
TID, all transactions must exist in this storage node.
- Check if a primary master node still has pending transactions
before and at a critical TID. If so, I must wait for them to be
committed or aborted.
- In order to copy data, first get the list of TIDs. This is done
part by part, because the list can be very huge. When getting
a part of the list, I verify if they are in my database, and
ask data only for non-existing TIDs. This is performed until
the check reaches a critical TID.
- Next, get the list of OIDs. And, for each OID, ask the history,
namely, a list of serials. This is also done part by part, and
I ask only non-existing data. """
# new_partition_set
# outdated partitions for which no pending transactions was asked to
# primary master yet
# partition_dict
# outdated partitions with pending transaction and temporary critical
# tid
# current_partition
# partition being currently synchronised
# current_connection
# connection to a storage node we are replicating from
# waiting_for_unfinished_tids
# unfinished tids have been asked to primary master node, but it
# didn't answer yet.
# replication_done
# False if we know there is something to replicate.
# True when current_partition is replicated, or we don't know yet if
# there is something to replicate
current_node = None
current_partition = None
current_connection = None
waiting_for_unfinished_tids = False
replication_done = True
def __init__(self, app):
self.app = app
self.new_partition_set = set()
self.partition_dict = {}
self.task_list = []
self.task_dict = {}
def masterLost(self):
"""
When connection to primary master is lost, stop waiting for unfinished
transactions.
"""
self.waiting_for_unfinished_tids = False
def getCurrentConnection(self):
node = self.current_node
if node is not None and node.isConnected():
return node.getConnection()
def storageLost(self):
"""
Restart replicating.
"""
self.reset()
def populate(self):
"""
Populate partitions to replicate. Must be called when partition
table is the one accepted by primary master.
Implies a reset.
"""
partition_list = self.app.pt.getOutdatedOffsetListFor(self.app.uuid)
self.new_partition_set = set(partition_list)
self.partition_dict = {}
self.reset()
def reset(self):
"""Reset attributes to restart replicating."""
self.task_list = []
self.task_dict = {}
self.current_partition = None
self.current_connection = None
self.replication_done = True
def pending(self):
"""Return whether there is any pending partition."""
return bool(self.partition_dict or self.new_partition_set)
def getCurrentOffset(self):
assert self.current_partition is not None
return self.current_partition.getOffset()
def getCurrentCriticalTID(self):
assert self.current_partition is not None
return self.current_partition.getCriticalTID()
def setReplicationDone(self):
""" Callback from ReplicationHandler """
self.replication_done = True
def isCurrentConnection(self, conn):
return self.current_connection is conn
def setUnfinishedTIDList(self, max_tid, ttid_list):
def setUnfinishedTIDList(self, max_tid, ttid_list, offset_list):
"""This is a callback from MasterOperationHandler."""
neo.lib.logging.debug('setting unfinished TTIDs %s',
','.join(map(dump, ttid_list)))
# all new outdated partition must wait those ttid
new_partition_set = self.new_partition_set
while new_partition_set:
offset = new_partition_set.pop()
self.partition_dict[offset] = Partition(offset, max_tid, ttid_list)
self.waiting_for_unfinished_tids = False
if ttid_list:
self.ttid_set.update(ttid_list)
max_ttid = max(ttid_list)
else:
max_ttid = None
for offset in offset_list:
self.partition_dict[offset].max_ttid = max_ttid
self.replicate_dict[offset] = max_tid
self._nextPartition()
def transactionFinished(self, ttid, max_tid):
""" Callback from MasterOperationHandler """
for partition in self.partition_dict.itervalues():
partition.transactionFinished(ttid, max_tid)
def _askUnfinishedTIDs(self):
conn = self.app.master_conn
conn.ask(Packets.AskUnfinishedTransactions())
self.waiting_for_unfinished_tids = True
self.ttid_set.remove(ttid)
min_ttid = min(self.ttid_set) if self.ttid_set else INVALID_TID
for offset, p in self.partition_dict.iteritems():
if p.max_ttid and p.max_ttid < min_ttid:
p.max_ttid = None
self.replicate_dict[offset] = max_tid
self._nextPartition()
def getBackupTID(self):
outdated_set = set(self.app.pt.getOutdatedOffsetListFor(self.app.uuid))
tid = INVALID_TID
for offset, p in self.partition_dict.iteritems():
if offset not in outdated_set:
tid = min(tid, p.next_trans, p.next_obj)
if tid not in (ZERO_TID, INVALID_TID):
return add64(tid, -1)
def _startReplication(self):
# Choose a storage node for the source.
def populate(self):
app = self.app
cell_list = app.pt.getCellList(self.current_partition.getOffset(),
readable=True)
node_list = [cell.getNode() for cell in cell_list
if cell.getNodeState() == NodeStates.RUNNING]
try:
node = choice(node_list)
except IndexError:
# Not operational.
neo.lib.logging.error('not operational', exc_info = 1)
self.current_partition = None
return
addr = node.getAddress()
if addr is None:
neo.lib.logging.error("no address known for the selected node %s" %
(dump(node.getUUID()), ))
pt = app.pt
uuid = app.uuid
self.partition_dict = p = {}
self.replicate_dict = {}
self.source_dict = {}
self.ttid_set = set()
last_tid, last_trans_dict, last_obj_dict = app.dm.getLastTIDs()
backup_tid = app.dm.getBackupTID()
if backup_tid and last_tid < backup_tid:
last_tid = backup_tid
outdated_list = []
for offset in xrange(pt.getPartitions()):
for cell in pt.getCellList(offset):
if cell.getUUID() == uuid:
self.partition_dict[offset] = p = Partition()
if cell.isOutOfDate():
outdated_list.append(offset)
try:
p.next_trans = add64(last_trans_dict[offset], 1)
except KeyError:
p.next_trans = ZERO_TID
p.next_obj = last_obj_dict.get(offset, ZERO_TID)
p.max_ttid = INVALID_TID
else:
p.next_trans = p.next_obj = last_tid
p.max_ttid = None
if outdated_list:
self.app.master_conn.ask(Packets.AskUnfinishedTransactions(),
offset_list=outdated_list)
def notifyPartitionChanges(self, cell_list):
"""This is a callback from MasterOperationHandler."""
abort = False
added_list = []
app = self.app
for offset, uuid, state in cell_list:
if uuid == app.uuid:
if state == CellStates.DISCARDED:
del self.partition_dict[offset]
self.replicate_dict.pop(offset, None)
self.source_dict.pop(offset, None)
abort = abort or self.current_partition == offset
elif state == CellStates.OUT_OF_DATE:
assert offset not in self.partition_dict
self.partition_dict[offset] = p = Partition()
p.next_trans = p.next_obj = ZERO_TID
p.max_ttid = INVALID_TID
added_list.append(offset)
if added_list:
self.app.master_conn.ask(Packets.AskUnfinishedTransactions(),
offset_list=added_list)
if abort:
self.abort()
def backup(self, tid, source_dict):
for offset in source_dict:
self.replicate_dict[offset] = tid
self.source_dict.update(source_dict)
self._nextPartition()
def _nextPartition(self):
# XXX: One connection to another storage may remain open forever.
# All other previous connections are automatically closed
# after some time of inactivity.
# This should be improved in several ways:
# - Keeping connections open between 2 clusters (backup case) is
# quite a good thing because establishing a connection costs
# time/bandwidth and replication is actually never finished.
# - When all storages of a non-backup cluster are up-to-date,
# there's no reason to keep any connection open.
if self.current_partition is not None or not self.replicate_dict:
return
connection = self.current_connection
if connection is None or connection.getAddress() != addr:
handler = replication.ReplicationHandler(app)
self.current_connection = ClientConnection(app.em, handler,
node=node, connector=app.connector_handler())
p = Packets.RequestIdentification(NodeTypes.STORAGE,
app.uuid, app.server, app.name)
self.current_connection.ask(p)
if connection is not None:
connection.close()
else:
connection.getHandler().startReplication(connection)
self.replication_done = False
def _finishReplication(self):
# TODO: remove try..except: pass
app = self.app
# Choose a partition with no unfinished transaction if possible.
for offset in self.replicate_dict:
if not self.partition_dict[offset].max_ttid:
break
try:
# Notify to a primary master node that my cell is now up-to-date.
conn = self.app.master_conn
offset = self.current_partition.getOffset()
self.partition_dict.pop(offset)
conn.notify(Packets.NotifyReplicationDone(offset))
addr, name = self.source_dict[offset]
except KeyError:
pass
if self.pending():
self.current_partition = None
assert self.app.pt.getCell(offset, self.app.uuid).isOutOfDate()
node = random.choice([cell.getNode()
for cell in app.pt.getCellList(offset, readable=True)
if cell.getNodeState() == NodeStates.RUNNING])
name = None
else:
self.current_connection.close()
def act(self):
if self.current_partition is not None:
# Don't end replication until we have received all expected
# answers, as we might have asked object data just before the last
# AnswerCheckSerialRange.
if self.replication_done and \
not self.current_connection.isPending():
# finish a replication
neo.lib.logging.info('replication is done for %s' %
(self.current_partition.getOffset(), ))
self._finishReplication()
return
if self.waiting_for_unfinished_tids:
# Still waiting.
neo.lib.logging.debug('waiting for unfinished tids')
return
if self.new_partition_set:
# Ask pending transactions.
neo.lib.logging.debug('asking unfinished tids')
self._askUnfinishedTIDs()
return
# Try to select something.
for partition in self.partition_dict.values():
# XXX: replication could start up to the initial critical tid, that
# is below the pending transactions, then finish when all pending
# transactions are committed.
if partition.safe():
self.current_partition = partition
break
node = app.nm.getByAddress(addr)
if node is None:
assert name, addr
node = app.nm.createStorage(address=addr)
self.current_partition = offset
previous_node = self.current_node
self.current_node = node
if node.isConnected():
node.getConnection().asClient()
self.fetchTransactions()
if node is previous_node:
return
else:
assert name or node.getUUID() != app.uuid, "loopback connection"
conn = ClientConnection(app.em, StorageOperationHandler(app),
node=node, connector=app.connector_handler())
conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE,
None if name else app.uuid, app.server, name or app.name))
if previous_node is not None and previous_node.isConnected():
previous_node.getConnection().closeClient()
def fetchTransactions(self, min_tid=None):
offset = self.current_partition
p = self.partition_dict[offset]
if min_tid:
p.next_trans = min_tid
else:
# Not yet.
neo.lib.logging.debug('not ready yet')
try:
addr, name = self.source_dict[offset]
except KeyError:
pass
else:
if addr != self.current_node.getAddress():
return self.abort()
min_tid = p.next_trans
self.replicate_tid = self.replicate_dict.pop(offset)
neo.lib.logging.debug("starting replication of <partition=%u"
" min_tid=%u max_tid=%u> from %r", offset, u64(min_tid),
u64(self.replicate_tid), self.current_node)
max_tid = self.replicate_tid
tid_list = self.app.dm.getReplicationTIDList(min_tid, max_tid,
FETCH_COUNT, offset)
self.current_node.getConnection().ask(Packets.AskFetchTransactions(
offset, FETCH_COUNT, min_tid, max_tid, tid_list))
def fetchObjects(self, min_tid=None, min_oid=ZERO_OID):
offset = self.current_partition
p = self.partition_dict[offset]
max_tid = self.replicate_tid
if min_tid:
if p.next_obj < self.next_backup_tid:
self.app.dm.setBackupTID(min_tid)
else:
min_tid = p.next_obj
p.next_trans = p.next_obj = add64(max_tid, 1)
if self.app.dm.getBackupTID() is None or \
self.app.pt.getCell(offset, self.app.uuid).isOutOfDate():
self.next_backup_tid = ZERO_TID
else:
self.next_backup_tid = self.getBackupTID()
p.next_obj = min_tid
object_dict = {}
for serial, oid in self.app.dm.getReplicationObjectList(min_tid,
max_tid, FETCH_COUNT, offset, min_oid):
try:
object_dict[serial].append(oid)
except KeyError:
object_dict[serial] = [oid]
self.current_node.getConnection().ask(Packets.AskFetchObjects(
offset, FETCH_COUNT, min_tid, max_tid, min_oid, object_dict))
def finish(self):
offset = self.current_partition
tid = self.replicate_tid
del self.current_partition, self.replicate_tid, self.next_backup_tid
p = self.partition_dict[offset]
p.next_obj = add64(tid, 1)
self.app.dm.setBackupTID(self.getBackupTID())
if not p.max_ttid:
p = Packets.NotifyReplicationDone(offset, tid)
self.app.master_conn.notify(p)
neo.lib.logging.debug("partition %u replicated up to %u from %r",
offset, u64(tid), self.current_node)
self._nextPartition()
def abort(self, message=''):
offset = self.current_partition
if offset is None:
return
self._startReplication()
def removePartition(self, offset):
"""This is a callback from MasterOperationHandler."""
self.partition_dict.pop(offset, None)
self.new_partition_set.discard(offset)
def addPartition(self, offset):
"""This is a callback from MasterOperationHandler."""
if not self.partition_dict.has_key(offset):
self.new_partition_set.add(offset)
def _addTask(self, key, func, args=(), kw=None):
task = Task(func, args, kw)
task_dict = self.task_dict
if key in task_dict:
raise ValueError, 'Task with key %r already exists (%r), cannot ' \
'add %r' % (key, task_dict[key], task)
task_dict[key] = task
self.task_list.append(task)
def processDelayedTasks(self):
task_list = self.task_list
if task_list:
for task in task_list:
task.process()
self.task_list = []
def checkTIDRange(self, min_tid, max_tid, length, partition):
self._addTask(('TID', min_tid, length),
self.app.dm.checkTIDRange, (min_tid, max_tid, length, partition))
def checkSerialRange(self, min_oid, min_serial, max_tid, length,
partition):
self._addTask(('Serial', min_oid, min_serial, length),
self.app.dm.checkSerialRange, (min_oid, min_serial, max_tid, length,
partition))
def getTIDsFrom(self, min_tid, max_tid, length, partition):
self._addTask('TIDsFrom', self.app.dm.getReplicationTIDList,
(min_tid, max_tid, length, partition))
def getObjectHistoryFrom(self, min_oid, min_serial, max_serial, length,
partition):
self._addTask('ObjectHistoryFrom', self.app.dm.getObjectHistoryFrom,
(min_oid, min_serial, max_serial, length, partition))
def _getCheckResult(self, key):
return self.task_dict.pop(key).getResult()
def getTIDCheckResult(self, min_tid, length):
return self._getCheckResult(('TID', min_tid, length))
def getSerialCheckResult(self, min_oid, min_serial, length):
return self._getCheckResult(('Serial', min_oid, min_serial, length))
def getTIDsFromResult(self):
return self._getCheckResult('TIDsFrom')
def getObjectHistoryFromResult(self):
return self._getCheckResult('ObjectHistoryFrom')
del self.current_partition
neo.lib.logging.warning('replication aborted for partition %u%s',
offset, message and ' (%s)' % message)
if self.app.master_node is None:
return
if offset in self.partition_dict:
# XXX: Try another partition if possible, to increase probability to
# connect to another node. It would be better to explicitely
# search for another node instead.
tid = self.replicate_dict.pop(offset, None) or self.replicate_tid
if self.replicate_dict:
self._nextPartition()
self.replicate_dict[offset] = tid
else:
self.replicate_dict[offset] = tid
self._nextPartition()
else: # partition removed
self._nextPartition()
......@@ -131,6 +131,11 @@ class NeoTestBase(unittest.TestCase):
sys.stdout.write('\n')
sys.stdout.flush()
class failureException(AssertionError):
def __init__(self, msg=None):
neo.lib.logging.error(msg)
AssertionError.__init__(self, msg)
failIfEqual = failUnlessEqual = assertEquals = assertNotEquals = None
def assertNotEqual(self, first, second, msg=None):
......
......@@ -25,6 +25,7 @@ import signal
import random
import weakref
import MySQLdb
import sqlite3
import unittest
import tempfile
import traceback
......@@ -242,9 +243,15 @@ class NEOCluster(object):
self.cleanup_on_delete = cleanup_on_delete
self.verbose = verbose
self.uuid_set = set()
self.db_user = db_user
self.db_password = db_password
self.db_list = db_list
if adapter == 'MySQL':
self.db_user = db_user
self.db_password = db_password
self.db_template = '%s:%s@%%s' % (db_user, db_password)
elif adapter == 'SQLite':
self.db_template = os.path.join(temp_dir, '%s.sqlite')
else:
assert False, adapter
self.address_type = address_type
self.local_ip = local_ip = IP_VERSION_FORMAT_DICT[self.address_type]
self.setupDB(clear_databases)
......@@ -290,7 +297,7 @@ class NEOCluster(object):
self.local_ip),
0 ),
'--masters': self.master_nodes,
'--database': '%s:%s@%s' % (db_user, db_password, db),
'--database': self.db_template % db,
'--adapter': adapter,
})
# create neoctl
......@@ -316,6 +323,17 @@ class NEOCluster(object):
if self.adapter == 'MySQL':
setupMySQLdb(self.db_list, self.db_user, self.db_password,
clear_databases)
elif self.adapter == 'SQLite':
if clear_databases:
for db in self.db_list:
try:
os.remove(self.db_template % db)
except OSError, e:
if e.errno != errno.ENOENT:
raise
else:
neo.lib.logging.debug('%r deleted',
db_template % db)
def run(self, except_storages=()):
""" Start cluster processes except some storage nodes """
......@@ -402,11 +420,14 @@ class NEOCluster(object):
db = ZODB.DB(storage=self.getZODBStorage(**kw))
return (db, db.open())
def getSQLConnection(self, db, autocommit=False):
def getSQLConnection(self, db):
assert db in self.db_list
conn = MySQLdb.Connect(user=self.db_user, passwd=self.db_password,
db=db)
conn.autocommit(autocommit)
if self.adapter == 'MySQL':
conn = MySQLdb.Connect(user=self.db_user, passwd=self.db_password,
db=db)
conn.autocommit(True)
elif self.adapter == 'SQLite':
conn = sqlite3.connect(self.db_template % db, isolation_level=None)
return conn
def _getProcessList(self, type):
......
......@@ -234,6 +234,9 @@ class ClientTests(NEOFunctionalTest):
temp_dir=self.getTempDirectory())
neoctl = self.neo.getNEOCTL()
self.neo.start()
# BUG: The following 2 lines creates 2 app, i.e. 2 TCP connections
# to the storage, so there may be a race condition at network
# level and 'st2.store' may be effective before 'st1.store'.
db1, conn1 = self.neo.getZODBConnection()
db2, conn2 = self.neo.getZODBConnection()
st1, st2 = conn1._storage, conn2._storage
......
......@@ -35,7 +35,7 @@ class ClusterTests(NEOFunctionalTest):
def testClusterStartup(self):
neo = NEOCluster(['test_neo1', 'test_neo2'], replicas=1,
adapter='MySQL', temp_dir=self.getTempDirectory())
temp_dir=self.getTempDirectory())
neoctl = neo.getNEOCTL()
neo.run()
# Runing a new cluster doesn't exit Recovery state.
......
......@@ -23,7 +23,7 @@ from persistent import Persistent
from . import NEOCluster, NEOFunctionalTest
from neo.lib.protocol import ClusterStates, NodeStates
from ZODB.tests.StorageTestBase import zodb_pickle
from MySQLdb import ProgrammingError
import MySQLdb, sqlite3
from MySQLdb.constants.ER import NO_SUCH_TABLE
class PObject(Persistent):
......@@ -46,9 +46,11 @@ class StorageTests(NEOFunctionalTest):
NEOFunctionalTest.tearDown(self)
def queryCount(self, db, query):
db.query(query)
result = db.store_result().fetch_row()[0][0]
return result
try:
db.query(query)
except AttributeError:
return db.execute(query).fetchone()[0]
return db.store_result().fetch_row()[0][0]
def __setup(self, storage_number=2, pending_number=0, replicas=1,
partitions=10, master_count=2):
......@@ -58,7 +60,6 @@ class StorageTests(NEOFunctionalTest):
partitions=partitions, replicas=replicas,
temp_dir=self.getTempDirectory(),
clear_databases=True,
adapter='MySQL',
)
# too many pending storage nodes requested
assert pending_number <= storage_number
......@@ -80,7 +81,7 @@ class StorageTests(NEOFunctionalTest):
db.close()
def __checkDatabase(self, db_name):
db = self.neo.getSQLConnection(db_name, autocommit=True)
db = self.neo.getSQLConnection(db_name)
# wait for the sql transaction to be commited
def callback(last_try):
object_number = self.queryCount(db, 'select count(*) from obj')
......@@ -124,13 +125,16 @@ class StorageTests(NEOFunctionalTest):
def __checkReplicateCount(self, db_name, target_count, timeout=0, delay=1):
db = self.neo.getSQLConnection(db_name, autocommit=True)
def callback(last_try):
replicate_count = 0
try:
replicate_count = self.queryCount(db,
'select count(distinct uuid) from pt')
except ProgrammingError, exc:
if exc[0] != NO_SUCH_TABLE:
except MySQLdb.ProgrammingError, e:
if e[0] != NO_SUCH_TABLE:
raise
except sqlite3.OperationalError, e:
if not e[0].startswith('no such table:'):
raise
replicate_count = 0
if last_try is not None and last_try < replicate_count:
raise AssertionError, 'Regression: %s became %s' % \
(last_try, replicate_count)
......
......@@ -85,7 +85,7 @@ class MasterRecoveryTests(NeoUnitTestBase):
self.assertTrue(ptid2 > self.app.pt.getID())
self.assertTrue(oid2 > self.app.tm.getLastOID())
self.assertTrue(tid2 > self.app.tm.getLastTID())
recovery.answerLastIDs(conn, oid2, tid2, ptid2)
recovery.answerLastIDs(conn, oid2, tid2, ptid2, None)
self.assertEqual(oid2, self.app.tm.getLastOID())
self.assertEqual(tid2, self.app.tm.getLastTID())
self.assertEqual(ptid2, recovery.target_ptid)
......
......@@ -130,10 +130,11 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
self.app.tm.setLastTID(tid)
service.askLastIDs(conn)
packet = self.checkAnswerLastIDs(conn)
loid, ltid, lptid = packet.decode()
loid, ltid, lptid, backup_tid = packet.decode()
self.assertEqual(loid, oid)
self.assertEqual(ltid, tid)
self.assertEqual(lptid, ptid)
self.assertEqual(backup_tid, None)
def test_13_askUnfinishedTransactions(self):
service = self.service
......
......@@ -3,7 +3,7 @@
import math, os, random, sys, time
from cStringIO import StringIO
from persistent.TimeStamp import TimeStamp
from ZODB.utils import p64, newTid
from ZODB.utils import p64, u64
from ZODB.BaseStorage import TransactionRecord
from ZODB.FileStorage import FileStorage
......@@ -44,6 +44,7 @@ class DummyZODB(object):
self.new_ratio = new_ratio
self.next_oid = 0
self.err_count = 0
self.tid = u64('TID\0\0\0\0\0')
def __call__(self):
variate = self.random.lognormvariate
......@@ -63,9 +64,11 @@ class DummyZODB(object):
yield p64(oid), int(round(variate(self.obj_size_mu,
self.obj_size_sigma))) or 1
def as_storage(self, transaction_count, dummy_data_file=None):
def as_storage(self, stop, dummy_data_file=None):
if dummy_data_file is None:
dummy_data_file = DummyData(self.random)
if isinstance(stop, int):
stop = (lambda x: lambda y: x <= y)(stop)
class dummy_change(object):
data_txn = None
version = ''
......@@ -97,12 +100,14 @@ class DummyZODB(object):
size = 0
def iterator(storage, *args):
args = ' ', '', '', {}
tid = None
for i in xrange(1, transaction_count+1):
tid = newTid(tid)
t = dummy_transaction(tid, *args)
i = 0
variate = self.random.lognormvariate
while not stop(i):
self.tid += max(1, int(variate(10, 3)))
t = dummy_transaction(p64(self.tid), *args)
storage.size += t.size
yield t
i += 1
def getSize(self):
return self.size
return dummy_storage()
......
......@@ -164,19 +164,6 @@ class StorageMasterHandlerTests(NeoUnitTestBase):
calls[0].checkArgs(tid)
self.checkNoPacketSent(conn)
def test_31_answerUnfinishedTransactions(self):
# set unfinished TID on replicator
conn = self.getFakeConnection()
self.app.replicator = Mock()
self.operation.answerUnfinishedTransactions(
conn=conn,
max_tid=INVALID_TID,
ttid_list=(INVALID_TID, ),
)
calls = self.app.replicator.mockGetNamedCalls('setUnfinishedTIDList')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(INVALID_TID, (INVALID_TID, ))
def test_askPack(self):
self.app.dm = Mock({'pack': None})
conn = self.getFakeConnection()
......
#
# Copyright (C) 2010 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import unittest
from mock import Mock
from struct import pack
from collections import deque
from .. import NeoUnitTestBase
from neo.storage.database import buildDatabaseManager
from neo.storage.handlers.replication import ReplicationHandler
from neo.storage.handlers.replication import RANGE_LENGTH
from neo.storage.handlers.storage import StorageOperationHandler
from neo.storage.replicator import Replicator
from neo.lib.protocol import ZERO_OID, ZERO_TID
MAX_TRANSACTIONS = 10000
MAX_OBJECTS = 100000
MAX_TID = '\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFE' # != INVALID_TID
class FakeConnection(object):
def __init__(self):
self._msg_id = 0
self._queue = deque()
def allocateId(self):
self._msg_id += 1
return self._msg_id
def _addPacket(self, packet, *args, **kw):
packet.setId(self.allocateId())
self._queue.append(packet)
ask = _addPacket
answer = _addPacket
notify = _addPacket
def setPeerId(self, msg_id):
pass
def process(self, dhandler, dconn):
if not self._queue:
return False
while self._queue:
dhandler.dispatch(dconn, self._queue.popleft())
return True
class ReplicationTests(NeoUnitTestBase):
def checkReplicationProcess(self, reference, outdated):
pt = Mock({'getPartitions': 1})
# reference application
rapp = Mock({})
rapp.pt = pt
rapp.dm = reference
rapp.tm = Mock({'loadLocked': False})
mconn = FakeConnection()
rapp.master_conn = mconn
# outdated application
oapp = Mock({})
oapp.dm = outdated
oapp.pt = pt
oapp.master_conn = mconn
oapp.replicator = Replicator(oapp)
oapp.replicator.getCurrentOffset = lambda: 0
oapp.replicator.isCurrentConnection = lambda c: True
oapp.replicator.getCurrentCriticalTID = lambda: MAX_TID
# handlers and connections
rhandler = StorageOperationHandler(rapp)
rconn = FakeConnection()
ohandler = ReplicationHandler(oapp)
oconn = FakeConnection()
# run replication
ohandler.startReplication(oconn)
process = True
while process:
process = oconn.process(rhandler, rconn)
oapp.replicator.processDelayedTasks()
process |= rconn.process(ohandler, oconn)
# check transactions
for tid in reference.getTIDList(0, MAX_TRANSACTIONS, [0]):
self.assertEqual(
reference.getTransaction(tid),
outdated.getTransaction(tid),
)
for tid in outdated.getTIDList(0, MAX_TRANSACTIONS, [0]):
self.assertEqual(
outdated.getTransaction(tid),
reference.getTransaction(tid),
)
# check transactions
params = ZERO_TID, '\xFF' * 8, MAX_TRANSACTIONS, 0
self.assertEqual(
reference.getReplicationTIDList(*params),
outdated.getReplicationTIDList(*params),
)
# check objects
params = ZERO_OID, ZERO_TID, '\xFF' * 8, MAX_OBJECTS, 0
self.assertEqual(
reference.getObjectHistoryFrom(*params),
outdated.getObjectHistoryFrom(*params),
)
def buildStorage(self, transactions, objects, name='BTree', database=None):
def makeid(oid_or_tid):
return pack('!Q', oid_or_tid)
storage = buildDatabaseManager(name, (database, 0))
storage.setup(reset=True)
storage.setNumPartitions(1)
storage._transactions = transactions
storage._objects = objects
# store transactions
for tid in transactions:
transaction = ([ZERO_OID], 'user', 'desc', '', False)
storage.storeTransaction(makeid(tid), [], transaction, False)
# store object history
H = "0" * 20
storage.storeData(H, '', 0)
storage.unlockData((H,))
for tid, oid_list in objects.iteritems():
object_list = [(makeid(oid), H, None) for oid in oid_list]
storage.storeTransaction(makeid(tid), object_list, None, False)
return storage
def testReplication0(self):
self.checkReplicationProcess(
reference=self.buildStorage(
transactions=[1, 2, 3],
objects={1: [1], 2: [1], 3: [1]},
),
outdated=self.buildStorage(
transactions=[],
objects={},
),
)
def testReplication1(self):
self.checkReplicationProcess(
reference=self.buildStorage(
transactions=[1, 2, 3],
objects={1: [1], 2: [1], 3: [1]},
),
outdated=self.buildStorage(
transactions=[1],
objects={1: [1]},
),
)
def testReplication2(self):
self.checkReplicationProcess(
reference=self.buildStorage(
transactions=[1, 2, 3],
objects={1: [1, 2, 3]},
),
outdated=self.buildStorage(
transactions=[1, 2, 3],
objects={1: [1, 2, 3]},
),
)
def testChunkBeginning(self):
ref_number = range(RANGE_LENGTH + 1)
out_number = range(RANGE_LENGTH)
obj_list = [1, 2, 3]
self.checkReplicationProcess(
reference=self.buildStorage(
transactions=ref_number,
objects=dict.fromkeys(ref_number, obj_list),
),
outdated=self.buildStorage(
transactions=out_number,
objects=dict.fromkeys(out_number, obj_list),
),
)
def testChunkEnd(self):
ref_number = range(RANGE_LENGTH)
out_number = range(RANGE_LENGTH - 1)
obj_list = [1, 2, 3]
self.checkReplicationProcess(
reference=self.buildStorage(
transactions=ref_number,
objects=dict.fromkeys(ref_number, obj_list)
),
outdated=self.buildStorage(
transactions=out_number,
objects=dict.fromkeys(out_number, obj_list)
),
)
def testChunkMiddle(self):
obj_list = [1, 2, 3]
ref_number = range(RANGE_LENGTH)
out_number = range(4000)
out_number.remove(3000)
self.checkReplicationProcess(
reference=self.buildStorage(
transactions=ref_number,
objects=dict.fromkeys(ref_number, obj_list)
),
outdated=self.buildStorage(
transactions=out_number,
objects=dict.fromkeys(out_number, obj_list)
),
)
def testFullChunkPart(self):
obj_list = [1, 2, 3]
ref_number = range(1001)
out_number = {}
self.checkReplicationProcess(
reference=self.buildStorage(
transactions=ref_number,
objects=dict.fromkeys(ref_number, obj_list)
),
outdated=self.buildStorage(
transactions=out_number,
objects=dict.fromkeys(out_number, obj_list)
),
)
def testSameData(self):
obj_list = [1, 2, 3]
number = range(RANGE_LENGTH * 2)
self.checkReplicationProcess(
reference=self.buildStorage(
transactions=number,
objects=dict.fromkeys(number, obj_list)
),
outdated=self.buildStorage(
transactions=number,
objects=dict.fromkeys(number, obj_list)
),
)
def testTooManyData(self):
obj_list = [0, 1]
ref_number = range(RANGE_LENGTH)
out_number = range(RANGE_LENGTH + 2)
self.checkReplicationProcess(
reference=self.buildStorage(
transactions=ref_number,
objects=dict.fromkeys(ref_number, obj_list)
),
outdated=self.buildStorage(
transactions=out_number,
objects=dict.fromkeys(out_number, obj_list)
),
)
def testMissingObject(self):
self.checkReplicationProcess(
reference=self.buildStorage(
transactions=[1, 2],
objects=dict.fromkeys([1, 2], [1, 2]),
),
outdated=self.buildStorage(
transactions=[1, 2],
objects=dict.fromkeys([1], [1]),
),
)
if __name__ == "__main__":
unittest.main()
#
# Copyright (C) 2010 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import unittest
from mock import Mock
from neo.lib.util import add64
from .. import NeoUnitTestBase
from neo.lib.protocol import Packets, ZERO_OID, ZERO_TID
from neo.storage.handlers.replication import ReplicationHandler
from neo.storage.handlers.replication import RANGE_LENGTH, MIN_RANGE_LENGTH
class FakeDict(object):
def __init__(self, items):
self._items = items
self._dict = dict(items)
assert len(self._dict) == len(items), self._dict
def iteritems(self):
for item in self._items:
yield item
def iterkeys(self):
for key, value in self.iteritems():
yield key
def itervalues(self):
for key, value in self.iteritems():
yield value
def items(self):
return self._items[:]
def keys(self):
return [x for x, y in self._items]
def values(self):
return [y for x, y in self._items]
def __getitem__(self, key):
return self._dict[key]
def __getattr__(self, key):
return getattr(self._dict, key)
def __len__(self):
return len(self._dict)
class StorageReplicationHandlerTests(NeoUnitTestBase):
def setup(self):
pass
def teardown(self):
pass
def getApp(self, conn=None, tid_check_result=(0, 0, ZERO_TID),
serial_check_result=(0, 0, ZERO_OID, 0, ZERO_TID),
tid_result=(),
history_result=None,
rid=0, critical_tid=ZERO_TID,
num_partitions=1,
):
if history_result is None:
history_result = {}
replicator = Mock({
'__repr__': 'Fake replicator',
'reset': None,
'checkSerialRange': None,
'checkTIDRange': None,
'getTIDCheckResult': tid_check_result,
'getSerialCheckResult': serial_check_result,
'getTIDsFromResult': tid_result,
'getObjectHistoryFromResult': history_result,
'checkSerialRange': None,
'checkTIDRange': None,
'getTIDsFrom': None,
'getObjectHistoryFrom': None,
'getCurrentOffset': rid,
'getCurrentCriticalTID': critical_tid,
})
def isCurrentConnection(other_conn):
return other_conn is conn
replicator.isCurrentConnection = isCurrentConnection
real_replicator = replicator
class FakeApp(object):
replicator = real_replicator
dm = Mock({
'storeTransaction': None,
'deleteObject': None,
})
pt = Mock({
'getPartitions': num_partitions,
})
return FakeApp
def _checkReplicationStarted(self, conn, rid, replicator):
min_tid, max_tid, length, partition = self.checkAskPacket(conn,
Packets.AskCheckTIDRange, decode=True)
self.assertEqual(min_tid, ZERO_TID)
self.assertEqual(length, RANGE_LENGTH)
self.assertEqual(partition, rid)
calls = replicator.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_tid, max_tid, length, partition)
def _checkPacketTIDList(self, conn, tid_list, next_tid, app):
packet_list = [x.getParam(0) for x in conn.mockGetNamedCalls('ask')]
packet_list, next_range = packet_list[:-1], packet_list[-1]
self.assertEqual(type(next_range), Packets.AskCheckTIDRange)
pmin_tid, plength, ppartition = next_range.decode()
self.assertEqual(pmin_tid, add64(next_tid, 1))
self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, app.replicator.getCurrentOffset())
calls = app.replicator.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, plength, ppartition)
self.assertEqual(len(packet_list), len(tid_list))
for packet in packet_list:
self.assertEqual(type(packet),
Packets.AskTransactionInformation)
ptid = packet.decode()[0]
for tid in tid_list:
if ptid == tid:
tid_list.remove(tid)
break
else:
raise AssertionFailed('%s not found in %r'
% (dump(ptid), map(dump, tid_list)))
def _checkPacketSerialList(self, conn, object_list, next_oid, next_serial, app):
packet_list = [x.getParam(0) for x in conn.mockGetNamedCalls('ask')]
packet_list, next_range = packet_list[:-1], packet_list[-1]
self.assertEqual(type(next_range), Packets.AskCheckSerialRange)
pmin_oid, pmin_serial, plength, ppartition = next_range.decode()
self.assertEqual(pmin_oid, next_oid)
self.assertEqual(pmin_serial, add64(next_serial, 1))
self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, app.replicator.getCurrentOffset())
calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)
self.assertEqual(len(packet_list), len(object_list),
([x.decode() for x in packet_list], object_list))
reference_set = set((x + (None, ) for x in object_list))
packet_set = set((x.decode() for x in packet_list))
assert len(packet_list) == len(reference_set) == len(packet_set)
self.assertEqual(reference_set, packet_set)
def test_connectionLost(self):
app = self.getApp()
ReplicationHandler(app).connectionLost(None, None)
self.assertEqual(len(app.replicator.mockGetNamedCalls('storageLost')), 1)
def test_connectionFailed(self):
app = self.getApp()
ReplicationHandler(app).connectionFailed(None)
self.assertEqual(len(app.replicator.mockGetNamedCalls('storageLost')), 1)
def test_acceptIdentification(self):
rid = 24
app = self.getApp(rid=rid)
conn = self.getFakeConnection()
replication = ReplicationHandler(app)
replication.acceptIdentification(conn, None, None, None,
None, None)
self._checkReplicationStarted(conn, rid, app.replicator)
def test_startReplication(self):
rid = 24
app = self.getApp(rid=rid)
conn = self.getFakeConnection()
ReplicationHandler(app).startReplication(conn)
self._checkReplicationStarted(conn, rid, app.replicator)
def test_answerTIDsFrom(self):
conn = self.getFakeConnection()
tid_list = [self.getOID(0), self.getOID(1), self.getOID(2)]
app = self.getApp(conn=conn, tid_result=[])
# With no known TID
ReplicationHandler(app).answerTIDsFrom(conn, tid_list)
# With some TIDs known
conn = self.getFakeConnection()
known_tid_list = [tid_list[0], tid_list[1]]
unknown_tid_list = [tid_list[2], ]
app = self.getApp(conn=conn, tid_result=known_tid_list)
ReplicationHandler(app).answerTIDsFrom(conn, tid_list[1:])
calls = app.dm.mockGetNamedCalls('deleteTransaction')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid_list[0])
def test_answerTransactionInformation(self):
conn = self.getFakeConnection()
app = self.getApp(conn=conn)
tid = self.getNextTID()
user = 'foo'
desc = 'bar'
ext = 'baz'
packed = True
oid_list = [self.getOID(1), self.getOID(2)]
ReplicationHandler(app).answerTransactionInformation(conn, tid, user,
desc, ext, packed, oid_list)
calls = app.dm.mockGetNamedCalls('storeTransaction')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid, (), (oid_list, user, desc, ext, packed), False)
def test_answerObjectHistoryFrom(self):
conn = self.getFakeConnection()
oid_1 = self.getOID(1)
oid_2 = self.getOID(2)
oid_3 = self.getOID(3)
oid_4 = self.getOID(4)
oid_5 = self.getOID(5)
tid_list = map(self.getOID, xrange(7))
oid_dict = FakeDict((
(oid_1, [tid_list[0], tid_list[1]]),
(oid_2, [tid_list[2], tid_list[3]]),
(oid_4, [tid_list[5]]),
))
flat_oid_list = []
for oid, serial_list in oid_dict.iteritems():
for serial in serial_list:
flat_oid_list.append((oid, serial))
app = self.getApp(conn=conn, history_result={})
# With no known OID/Serial
ReplicationHandler(app).answerObjectHistoryFrom(conn, oid_dict)
# With some known OID/Serials
# For test to be realist, history_result should contain the same
# number of serials as oid_dict, otherise it just tests the special
# case of the last check in a partition.
conn = self.getFakeConnection()
app = self.getApp(conn=conn, history_result={
oid_1: [oid_dict[oid_1][0], ],
oid_3: [tid_list[2]],
oid_4: [tid_list[4], oid_dict[oid_4][0], tid_list[6]],
oid_5: [tid_list[6]],
})
ReplicationHandler(app).answerObjectHistoryFrom(conn, oid_dict)
calls = app.dm.mockGetNamedCalls('deleteObject')
actual_deletes = set(((x.getParam(0), x.getParam(1)) for x in calls))
expected_deletes = set((
(oid_3, tid_list[2]),
(oid_4, tid_list[4]),
))
self.assertEqual(actual_deletes, expected_deletes)
def test_answerObject(self):
conn = self.getFakeConnection()
app = self.getApp(conn=conn)
oid = self.getOID(1)
serial_start = self.getNextTID()
serial_end = self.getNextTID()
compression = 1
checksum = "0" * 20
data = 'foo'
data_serial = None
app.dm.mockAddReturnValues(storeData=checksum)
ReplicationHandler(app).answerObject(conn, oid, serial_start,
serial_end, compression, checksum, data, data_serial)
calls = app.dm.mockGetNamedCalls('storeTransaction')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(serial_start, [(oid, checksum, data_serial)],
None, False)
# CheckTIDRange
def test_answerCheckTIDFullRangeIdenticalChunkWithNext(self):
min_tid = self.getNextTID()
max_tid = self.getNextTID()
critical_tid = self.getNextTID()
assert max_tid < critical_tid
length = RANGE_LENGTH
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length, 0, max_tid), rid=rid,
conn=conn, critical_tid=critical_tid)
handler = ReplicationHandler(app)
# Peer has the same data as we have: length, checksum and max_tid
# match.
handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
# Result: go on with next chunk
pmin_tid, pmax_tid, plength, ppartition = self.checkAskPacket(conn,
Packets.AskCheckTIDRange, decode=True)
self.assertEqual(pmin_tid, add64(max_tid, 1))
self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition)
def test_answerCheckTIDSmallRangeIdenticalChunkWithNext(self):
min_tid = self.getNextTID()
max_tid = self.getNextTID()
critical_tid = self.getNextTID()
assert max_tid < critical_tid
length = RANGE_LENGTH / 2
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length, 0, max_tid), rid=rid,
conn=conn, critical_tid=critical_tid)
handler = ReplicationHandler(app)
# Peer has the same data as we have: length, checksum and max_tid
# match.
handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
# Result: go on with next chunk
pmin_tid, pmax_tid, plength, ppartition = self.checkAskPacket(conn,
Packets.AskCheckTIDRange, decode=True)
self.assertEqual(pmax_tid, critical_tid)
self.assertEqual(pmin_tid, add64(max_tid, 1))
self.assertEqual(plength, length / 2)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition)
def test_answerCheckTIDRangeIdenticalChunkAboveCriticalTID(self):
critical_tid = self.getNextTID()
min_tid = self.getNextTID()
max_tid = self.getNextTID()
assert critical_tid < max_tid
length = RANGE_LENGTH / 2
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length, 0, max_tid), rid=rid,
conn=conn, critical_tid=critical_tid)
handler = ReplicationHandler(app)
# Peer has the same data as we have: length, checksum and max_tid
# match.
handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
# Result: go on with object range checks
pmin_oid, pmin_serial, pmax_tid, plength, ppartition = \
self.checkAskPacket(conn, Packets.AskCheckSerialRange, decode=True)
self.assertEqual(pmin_oid, ZERO_OID)
self.assertEqual(pmin_serial, ZERO_TID)
self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, pmax_tid, plength, ppartition)
def test_answerCheckTIDRangeIdenticalChunkWithoutNext(self):
min_tid = self.getNextTID()
max_tid = self.getNextTID()
length = RANGE_LENGTH / 2
rid = 12
num_partitions = 13
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 1, 0, max_tid), rid=rid,
conn=conn, num_partitions=num_partitions)
handler = ReplicationHandler(app)
# Peer has the same data as we have: length, checksum and max_tid
# match.
handler.answerCheckTIDRange(conn, min_tid, length, length - 1, 0,
max_tid)
# Result: go on with object range checks
pmin_oid, pmin_serial, pmax_tid, plength, ppartition = \
self.checkAskPacket(conn, Packets.AskCheckSerialRange, decode=True)
self.assertEqual(pmin_oid, ZERO_OID)
self.assertEqual(pmin_serial, ZERO_TID)
self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, pmax_tid, plength, ppartition)
# ...and delete partition tail
calls = app.dm.mockGetNamedCalls('deleteTransactionsAbove')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(rid, add64(max_tid, 1), ZERO_TID)
def test_answerCheckTIDRangeDifferentBigChunk(self):
min_tid = self.getNextTID()
max_tid = self.getNextTID()
critical_tid = self.getNextTID()
assert min_tid < max_tid < critical_tid, (min_tid, max_tid,
critical_tid)
length = RANGE_LENGTH / 2
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_tid), rid=rid,
conn=conn, critical_tid=critical_tid)
handler = ReplicationHandler(app)
# Peer has different data
handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
# Result: ask again, length halved
pmin_tid, pmax_tid, plength, ppartition = self.checkAskPacket(conn,
Packets.AskCheckTIDRange, decode=True)
self.assertEqual(pmin_tid, min_tid)
self.assertEqual(plength, length / 2)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition)
def test_answerCheckTIDRangeDifferentSmallChunkWithNext(self):
min_tid = self.getNextTID()
max_tid = self.getNextTID()
critical_tid = self.getNextTID()
length = MIN_RANGE_LENGTH - 1
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_tid), rid=rid,
conn=conn, critical_tid=critical_tid)
handler = ReplicationHandler(app)
# Peer has different data
handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
# Result: ask tid list, and ask next chunk
calls = conn.mockGetNamedCalls('ask')
self.assertEqual(len(calls), 1)
tid_packet = calls[0].getParam(0)
self.assertEqual(type(tid_packet), Packets.AskTIDsFrom)
pmin_tid, pmax_tid, plength, ppartition = tid_packet.decode()
self.assertEqual(pmin_tid, min_tid)
self.assertEqual(pmax_tid, critical_tid)
self.assertEqual(plength, length)
self.assertEqual(ppartition, [rid])
calls = app.replicator.mockGetNamedCalls('getTIDsFrom')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition[0])
def test_answerCheckTIDRangeDifferentSmallChunkWithoutNext(self):
min_tid = self.getNextTID()
max_tid = self.getNextTID()
critical_tid = self.getNextTID()
length = MIN_RANGE_LENGTH - 1
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_tid), rid=rid,
conn=conn, critical_tid=critical_tid)
handler = ReplicationHandler(app)
# Peer has different data, and less than length
handler.answerCheckTIDRange(conn, min_tid, length, length - 1, 0,
max_tid)
# Result: ask tid list, and start replicating object range
calls = conn.mockGetNamedCalls('ask')
self.assertEqual(len(calls), 2)
tid_packet = calls[0].getParam(0)
self.assertEqual(type(tid_packet), Packets.AskTIDsFrom)
pmin_tid, pmax_tid, plength, ppartition = tid_packet.decode()
self.assertEqual(pmin_tid, min_tid)
self.assertEqual(pmax_tid, critical_tid)
self.assertEqual(plength, length - 1)
self.assertEqual(ppartition, [rid])
calls = app.replicator.mockGetNamedCalls('getTIDsFrom')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition[0])
# CheckSerialRange
def test_answerCheckSerialFullRangeIdenticalChunkWithNext(self):
min_oid = self.getOID(1)
max_oid = self.getOID(10)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
length = RANGE_LENGTH
rid = 12
conn = self.getFakeConnection()
app = self.getApp(serial_check_result=(length, 0, max_oid, 1,
max_serial), rid=rid, conn=conn)
handler = ReplicationHandler(app)
# Peer has the same data as we have
handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length, 0, max_oid, 1, max_serial)
# Result: go on with next chunk
pmin_oid, pmin_serial, pmax_tid, plength, ppartition = \
self.checkAskPacket(conn, Packets.AskCheckSerialRange, decode=True)
self.assertEqual(pmin_oid, max_oid)
self.assertEqual(pmin_serial, add64(max_serial, 1))
self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, pmax_tid, plength, ppartition)
def test_answerCheckSerialSmallRangeIdenticalChunkWithNext(self):
min_oid = self.getOID(1)
max_oid = self.getOID(10)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
length = RANGE_LENGTH / 2
rid = 12
conn = self.getFakeConnection()
app = self.getApp(serial_check_result=(length, 0, max_oid, 1,
max_serial), rid=rid, conn=conn)
handler = ReplicationHandler(app)
# Peer has the same data as we have
handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length, 0, max_oid, 1, max_serial)
# Result: go on with next chunk
pmin_oid, pmin_serial, pmax_tid, plength, ppartition = \
self.checkAskPacket(conn, Packets.AskCheckSerialRange, decode=True)
self.assertEqual(pmin_oid, max_oid)
self.assertEqual(pmin_serial, add64(max_serial, 1))
self.assertEqual(plength, length / 2)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, pmax_tid, plength, ppartition)
def test_answerCheckSerialRangeIdenticalChunkWithoutNext(self):
min_oid = self.getOID(1)
max_oid = self.getOID(10)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
length = RANGE_LENGTH / 2
rid = 12
num_partitions = 13
conn = self.getFakeConnection()
app = self.getApp(serial_check_result=(length - 1, 0, max_oid, 1,
max_serial), rid=rid, conn=conn, num_partitions=num_partitions)
handler = ReplicationHandler(app)
# Peer has the same data as we have
handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length - 1, 0, max_oid, 1, max_serial)
# Result: mark replication as done
self.checkNoPacketSent(conn)
self.assertTrue(app.replicator.replication_done)
# ...and delete partition tail
calls = app.dm.mockGetNamedCalls('deleteObjectsAbove')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(rid, max_oid, add64(max_serial, 1), ZERO_TID)
def test_answerCheckSerialRangeDifferentBigChunk(self):
min_oid = self.getOID(1)
max_oid = self.getOID(10)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
length = RANGE_LENGTH / 2
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_oid, 1,
max_serial), rid=rid, conn=conn)
handler = ReplicationHandler(app)
# Peer has different data
handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length, 0, max_oid, 1, max_serial)
# Result: ask again, length halved
pmin_oid, pmin_serial, pmax_tid, plength, ppartition = \
self.checkAskPacket(conn, Packets.AskCheckSerialRange, decode=True)
self.assertEqual(pmin_oid, min_oid)
self.assertEqual(pmin_serial, min_serial)
self.assertEqual(plength, length / 2)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, pmax_tid, plength, ppartition)
def test_answerCheckSerialRangeDifferentSmallChunkWithNext(self):
min_oid = self.getOID(1)
max_oid = self.getOID(10)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
critical_tid = self.getNextTID()
length = MIN_RANGE_LENGTH - 1
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_oid, 1,
max_serial), rid=rid, conn=conn, critical_tid=critical_tid)
handler = ReplicationHandler(app)
# Peer has different data
handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length, 0, max_oid, 1, max_serial)
# Result: ask serial list, and ask next chunk
calls = conn.mockGetNamedCalls('ask')
self.assertEqual(len(calls), 1)
serial_packet = calls[0].getParam(0)
self.assertEqual(type(serial_packet), Packets.AskObjectHistoryFrom)
pmin_oid, pmin_serial, pmax_serial, plength, ppartition = \
serial_packet.decode()
self.assertEqual(pmin_oid, min_oid)
self.assertEqual(pmin_serial, min_serial)
self.assertEqual(pmax_serial, critical_tid)
self.assertEqual(plength, length)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('getObjectHistoryFrom')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, pmax_serial, plength,
ppartition)
def test_answerCheckSerialRangeDifferentSmallChunkWithoutNext(self):
min_oid = self.getOID(1)
max_oid = self.getOID(10)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
critical_tid = self.getNextTID()
length = MIN_RANGE_LENGTH - 1
rid = 12
num_partitions = 13
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_oid,
1, max_serial), rid=rid, conn=conn, critical_tid=critical_tid,
num_partitions=num_partitions,
)
handler = ReplicationHandler(app)
# Peer has different data, and less than length
handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length - 1, 0, max_oid, 1, max_serial)
# Result: ask tid list, and mark replication as done
pmin_oid, pmin_serial, pmax_serial, plength, ppartition = \
self.checkAskPacket(conn, Packets.AskObjectHistoryFrom,
decode=True)
self.assertEqual(pmin_oid, min_oid)
self.assertEqual(pmin_serial, min_serial)
self.assertEqual(pmax_serial, critical_tid)
self.assertEqual(plength, length - 1)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('getObjectHistoryFrom')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, pmax_serial, plength,
ppartition)
self.assertTrue(app.replicator.replication_done)
# ...and delete partition tail
calls = app.dm.mockGetNamedCalls('deleteObjectsAbove')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(rid, max_oid, add64(max_serial, 1), critical_tid)
if __name__ == "__main__":
unittest.main()
#
# Copyright (C) 2010 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import unittest
from mock import Mock, ReturnValues
from .. import NeoUnitTestBase
from neo.storage.replicator import Replicator, Partition, Task
from neo.lib.protocol import CellStates, NodeStates, Packets
class StorageReplicatorTests(NeoUnitTestBase):
def setup(self):
pass
def teardown(self):
pass
def test_populate(self):
my_uuid = self.getNewUUID()
other_uuid = self.getNewUUID()
app = Mock()
app.uuid = my_uuid
app.pt = Mock({
'getPartitions': 2,
'getOutdatedOffsetListFor': [0],
})
replicator = Replicator(app)
self.assertEqual(replicator.new_partition_set, set())
replicator.populate()
self.assertEqual(replicator.new_partition_set, set([0]))
def test_reset(self):
replicator = Replicator(None)
replicator.task_list = ['foo']
replicator.task_dict = {'foo': 'bar'}
replicator.current_partition = 'foo'
replicator.current_connection = 'foo'
replicator.replication_done = 'foo'
replicator.reset()
self.assertEqual(replicator.task_list, [])
self.assertEqual(replicator.task_dict, {})
self.assertEqual(replicator.current_partition, None)
self.assertEqual(replicator.current_connection, None)
self.assertTrue(replicator.replication_done)
def test_setCriticalTID(self):
critical_tid = self.getNextTID()
partition = Partition(0, critical_tid, [])
self.assertEqual(partition.getCriticalTID(), critical_tid)
self.assertEqual(partition.getOffset(), 0)
def test_act(self):
# Also tests "pending"
uuid = self.getNewUUID()
master_uuid = self.getNewUUID()
critical_tid_0 = self.getNextTID()
critical_tid_1 = self.getNextTID()
critical_tid_2 = self.getNextTID()
unfinished_ttid_1 = self.getOID(1)
unfinished_ttid_2 = self.getOID(2)
app = Mock()
app.server = ('127.0.0.1', 10000)
app.name = 'fake cluster'
app.em = Mock({
'register': None,
})
def connectorGenerator():
return Mock()
app.connector_handler = connectorGenerator
app.uuid = uuid
node_addr = ('127.0.0.1', 1234)
node = Mock({
'getAddress': node_addr,
})
running_cell = Mock({
'getNodeState': NodeStates.RUNNING,
'getNode': node,
})
unknown_cell = Mock({
'getNodeState': NodeStates.UNKNOWN,
})
app.pt = Mock({
'getCellList': [running_cell, unknown_cell],
'getOutdatedOffsetListFor': [0],
'getPartition': 0,
})
node_conn_handler = Mock({
'startReplication': None,
})
node_conn = Mock({
'getAddress': node_addr,
'getHandler': node_conn_handler,
})
replicator = Replicator(app)
replicator.populate()
def act():
app.master_conn = self.getFakeConnection(uuid=master_uuid)
self.assertTrue(replicator.pending())
replicator.act()
# ask unfinished tids
act()
unfinished_tids = app.master_conn.mockGetNamedCalls('ask')[0].getParam(0)
self.assertTrue(replicator.new_partition_set)
self.assertEqual(type(unfinished_tids),
Packets.AskUnfinishedTransactions)
self.assertTrue(replicator.waiting_for_unfinished_tids)
# nothing happens until waiting_for_unfinished_tids becomes False
act()
self.checkNoPacketSent(app.master_conn)
self.assertTrue(replicator.waiting_for_unfinished_tids)
# first time, there is an unfinished tid before critical tid,
# replication cannot start, and unfinished TIDs are asked again
replicator.setUnfinishedTIDList(critical_tid_0,
[unfinished_ttid_1, unfinished_ttid_2])
self.assertFalse(replicator.waiting_for_unfinished_tids)
# Note: detection that nothing can be replicated happens on first call
# and unfinished tids are asked again on second call. This is ok, but
# might change, so just call twice.
act()
replicator.transactionFinished(unfinished_ttid_1, critical_tid_1)
act()
replicator.transactionFinished(unfinished_ttid_2, critical_tid_2)
replicator.current_connection = node_conn
act()
self.assertEqual(replicator.current_partition,
replicator.partition_dict[0])
self.assertEqual(len(node_conn_handler.mockGetNamedCalls(
'startReplication')), 1)
self.assertFalse(replicator.replication_done)
# Other calls should do nothing
replicator.current_connection = Mock()
act()
self.checkNoPacketSent(app.master_conn)
self.checkNoPacketSent(replicator.current_connection)
# Mark replication over for this partition
replicator.replication_done = True
# Don't finish while there are pending answers
replicator.current_connection = Mock({
'isPending': True,
})
act()
self.assertTrue(replicator.pending())
replicator.current_connection = Mock({
'isPending': False,
})
act()
# also, replication is over
self.assertFalse(replicator.pending())
def test_removePartition(self):
replicator = Replicator(None)
replicator.partition_dict = {0: None, 2: None}
replicator.new_partition_set = set([1])
replicator.removePartition(0)
self.assertEqual(replicator.partition_dict, {2: None})
self.assertEqual(replicator.new_partition_set, set([1]))
replicator.removePartition(1)
replicator.removePartition(2)
self.assertEqual(replicator.partition_dict, {})
self.assertEqual(replicator.new_partition_set, set())
# Must not raise
replicator.removePartition(3)
def test_addPartition(self):
replicator = Replicator(None)
replicator.partition_dict = {0: None}
replicator.new_partition_set = set([1])
replicator.addPartition(0)
replicator.addPartition(1)
self.assertEqual(replicator.partition_dict, {0: None})
self.assertEqual(replicator.new_partition_set, set([1]))
replicator.addPartition(2)
self.assertEqual(replicator.partition_dict, {0: None})
self.assertEqual(len(replicator.new_partition_set), 2)
self.assertEqual(replicator.new_partition_set, set([1, 2]))
def test_processDelayedTasks(self):
replicator = Replicator(None)
replicator.reset()
marker = []
def someCallable(foo, bar=None):
return (foo, bar)
replicator._addTask(1, someCallable, args=('foo', ))
self.assertRaises(ValueError, replicator._addTask, 1, None)
replicator._addTask(2, someCallable, args=('foo', ), kw={'bar': 'bar'})
replicator.processDelayedTasks()
self.assertEqual(replicator._getCheckResult(1), ('foo', None))
self.assertEqual(replicator._getCheckResult(2), ('foo', 'bar'))
# Also test Task
task = Task(someCallable, args=('foo', ))
self.assertRaises(ValueError, task.getResult)
task.process()
self.assertRaises(ValueError, task.process)
self.assertEqual(task.getResult(), ('foo', None))
if __name__ == "__main__":
unittest.main()
......@@ -18,11 +18,10 @@
import unittest
from mock import Mock
from neo.lib.util import dump, p64, u64
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, MAX_TID
from .. import NeoUnitTestBase
from neo.lib.exception import DatabaseFailure
MAX_TID = '\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFE' # != INVALID_TID
class StorageDBTests(NeoUnitTestBase):
......@@ -74,7 +73,7 @@ class StorageDBTests(NeoUnitTestBase):
def checkConfigEntry(self, get_call, set_call, value):
# generic test for all configuration entries accessors
self.assertRaises(KeyError, get_call)
self.assertEqual(get_call(), None)
set_call(value)
self.assertEqual(get_call(), value)
set_call(value * 2)
......@@ -92,6 +91,29 @@ class StorageDBTests(NeoUnitTestBase):
db = self.getDB()
self.checkConfigEntry(db.getPTID, db.setPTID, self.getPTID(1))
def test_transaction(self):
db = self.getDB()
x = []
class DB(db.__class__):
begin = lambda self: x.append('begin')
commit = lambda self: x.append('commit')
rollback = lambda self: x.append('rollback')
db.__class__ = DB
with db:
self.assertEqual(x.pop(), 'begin')
self.assertEqual(x.pop(), 'commit')
try:
with db:
self.assertEqual(x.pop(), 'begin')
with db:
self.fail()
self.fail()
except DatabaseFailure:
pass
self.assertEqual(x.pop(), 'rollback')
self.assertRaises(DatabaseFailure, db.__exit__, None, None, None)
self.assertFalse(x)
def test_getPartitionTable(self):
db = self.getDB()
ptid = self.getPTID(1)
......@@ -128,21 +150,22 @@ class StorageDBTests(NeoUnitTestBase):
def checkSet(self, list1, list2):
self.assertEqual(set(list1), set(list2))
def test_getLastTID(self):
def test_getLastTIDs(self):
tid1, tid2, tid3, tid4 = self.getTIDs(4)
oid1, oid2 = self.getOIDs(2)
txn, objs = self.getTransaction([oid1, oid2])
# max TID is in obj table
self.db.storeTransaction(tid1, objs, txn, False)
self.db.storeTransaction(tid2, objs, txn, False)
self.assertEqual(self.db.getLastTID(), tid2)
# max tid is in ttrans table
self.assertEqual(self.db.getLastTIDs(), (tid2, {0: tid2}, {0: tid2}))
self.db.storeTransaction(tid3, objs, txn)
result = self.db.getLastTID()
self.assertEqual(self.db.getLastTID(), tid3)
# max tid is in tobj (serial)
tids = {0: tid2, None: tid3}
self.assertEqual(self.db.getLastTIDs(), (tid3, tids, tids))
self.db.storeTransaction(tid4, objs, None)
self.assertEqual(self.db.getLastTID(), tid4)
self.assertEqual(self.db.getLastTIDs(),
(tid4, tids, {0: tid2, None: tid4}))
self.db.finishTransaction(tid3)
self.assertEqual(self.db.getLastTIDs(),
(tid4, {0: tid3}, {0: tid3, None: tid4}))
def test_getUnfinishedTIDList(self):
tid1, tid2, tid3, tid4 = self.getTIDs(4)
......@@ -294,7 +317,7 @@ class StorageDBTests(NeoUnitTestBase):
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2])
# nothing in database
self.assertEqual(self.db.getLastTID(), None)
self.assertEqual(self.db.getLastTIDs(), (None, {}, {}))
self.assertEqual(self.db.getUnfinishedTIDList(), [])
self.assertEqual(self.db.getObject(oid1), None)
self.assertEqual(self.db.getObject(oid2), None)
......@@ -362,24 +385,6 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(self.db.getTransaction(tid1, True), None)
self.assertEqual(self.db.getTransaction(tid2, True), None)
def test_deleteTransactionsAbove(self):
self.setNumPartitions(2)
tid1 = self.getOID(0)
tid2 = self.getOID(1)
tid3 = self.getOID(2)
oid1 = self.getOID(1)
for tid in (tid1, tid2, tid3):
txn, objs = self.getTransaction([oid1])
self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid)
self.db.deleteTransactionsAbove(0, tid2, tid3)
# Right partition, below cutoff
self.assertNotEqual(self.db.getTransaction(tid1, True), None)
# Wrong partition, above cutoff
self.assertNotEqual(self.db.getTransaction(tid2, True), None)
# Right partition, above cutoff
self.assertEqual(self.db.getTransaction(tid3, True), None)
def test_deleteObject(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
......@@ -397,34 +402,28 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(self.db.getObject(oid2, tid=tid2),
(tid2, None, 1, "0" * 20, '', None))
def test_deleteObjectsAbove(self):
self.setNumPartitions(2)
tid1 = self.getOID(1)
tid2 = self.getOID(2)
tid3 = self.getOID(3)
oid1 = self.getOID(0)
oid2 = self.getOID(1)
oid3 = self.getOID(2)
for tid in (tid1, tid2, tid3):
txn, objs = self.getTransaction([oid1, oid2, oid3])
def test_deleteRange(self):
np = 4
self.setNumPartitions(np)
t1, t2, t3 = map(self.getOID, (1, 2, 3))
oid_list = self.getOIDs(np * 2)
for tid in t1, t2, t3:
txn, objs = self.getTransaction(oid_list)
self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid)
self.db.deleteObjectsAbove(0, oid1, tid2, tid3)
# Check getObjectHistoryFrom because MySQL adapter use two tables
# that must be synchronized
self.assertEqual(self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID,
MAX_TID, 10, 0), {oid1: [tid1]})
# Right partition, below cutoff
self.assertNotEqual(self.db.getObject(oid1, tid=tid1), None)
# Right partition, above tid cutoff
self.assertFalse(self.db.getObject(oid1, tid=tid2))
self.assertFalse(self.db.getObject(oid1, tid=tid3))
# Wrong partition, above cutoff
self.assertNotEqual(self.db.getObject(oid2, tid=tid1), None)
self.assertNotEqual(self.db.getObject(oid2, tid=tid2), None)
self.assertNotEqual(self.db.getObject(oid2, tid=tid3), None)
# Right partition, above cutoff
self.assertEqual(self.db.getObject(oid3), None)
def check(offset, tid_list, *tids):
self.assertEqual(self.db.getReplicationTIDList(ZERO_TID,
MAX_TID, len(tid_list) + 1, offset), tid_list)
expected = [(t, oid_list[offset+i]) for t in tids for i in 0, np]
self.assertEqual(self.db.getReplicationObjectList(ZERO_TID,
MAX_TID, len(expected) + 1, offset, ZERO_OID), expected)
self.db._deleteRange(0, MAX_TID)
self.db._deleteRange(0, max_tid=ZERO_TID)
check(0, [], t1, t2, t3)
self.db._deleteRange(0); check(0, [])
self.db._deleteRange(1, t2); check(1, [t1], t1, t2)
self.db._deleteRange(2, max_tid=t2); check(2, [], t3)
self.db._deleteRange(3, t1, t2); check(3, [t3], t1, t3)
def test_getTransaction(self):
oid1, oid2 = self.getOIDs(2)
......@@ -467,59 +466,6 @@ class StorageDBTests(NeoUnitTestBase):
result = self.db.getObjectHistory(oid, 2, 3)
self.assertEqual(result, None)
def test_getObjectHistoryFrom(self):
self.setNumPartitions(2)
oid1 = self.getOID(0)
oid2 = self.getOID(2)
oid3 = self.getOID(1)
tid1, tid2, tid3, tid4, tid5 = self.getTIDs(5)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2])
txn3, objs3 = self.getTransaction([oid1])
txn4, objs4 = self.getTransaction([oid2])
txn5, objs5 = self.getTransaction([oid3])
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.storeTransaction(tid3, objs3, txn3)
self.db.storeTransaction(tid4, objs4, txn4)
self.db.storeTransaction(tid5, objs5, txn5)
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
self.db.finishTransaction(tid3)
self.db.finishTransaction(tid4)
self.db.finishTransaction(tid5)
# Check full result
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, MAX_TID, 10,
0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2, tid4],
})
# Lower bound is inclusive
result = self.db.getObjectHistoryFrom(oid1, tid1, MAX_TID, 10, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2, tid4],
})
# Upper bound is inclusive
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, tid3, 10, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2],
})
# Length is total number of serials
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, MAX_TID, 3, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2],
})
# Partition constraints are honored
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, MAX_TID, 10,
1)
self.assertEqual(result, {
oid3: [tid5],
})
def _storeTransactions(self, count):
# use OID generator to know result of tid % N
tid_list = self.getOIDs(count)
......
#
# Copyright (C) 2009-2010 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import unittest
from mock import Mock
from collections import deque
from .. import NeoUnitTestBase
from neo.storage.app import Application
from neo.storage.handlers.storage import StorageOperationHandler
from neo.lib.protocol import INVALID_PARTITION, Packets
from neo.lib.protocol import INVALID_TID, INVALID_OID
class StorageStorageHandlerTests(NeoUnitTestBase):
def checkHandleUnexpectedPacket(self, _call, _msg_type, _listening=True, **kwargs):
conn = self.getFakeConnection(address=("127.0.0.1", self.master_port),
is_server=_listening)
# hook
self.operation.peerBroken = lambda c: c.peerBrokendCalled()
self.checkUnexpectedPacketRaised(_call, conn=conn, **kwargs)
def setUp(self):
NeoUnitTestBase.setUp(self)
self.prepareDatabase(number=1)
# create an application object
config = self.getStorageConfiguration(master_number=1)
self.app = Application(config)
self.app.transaction_dict = {}
self.app.store_lock_dict = {}
self.app.load_lock_dict = {}
self.app.event_queue = deque()
self.app.event_queue_dict = {}
# handler
self.operation = StorageOperationHandler(self.app)
# set pmn
self.master_uuid = self.getNewUUID()
pmn = self.app.nm.getMasterList()[0]
pmn.setUUID(self.master_uuid)
self.app.primary_master_node = pmn
self.master_port = 10010
def test_18_askTransactionInformation1(self):
# transaction does not exists
conn = self.getFakeConnection()
self.app.dm = Mock({'getNumPartitions': 1})
self.operation.askTransactionInformation(conn, INVALID_TID)
self.checkErrorPacket(conn)
def test_18_askTransactionInformation2(self):
# answer
conn = self.getFakeConnection()
tid = self.getNextTID()
oid_list = [self.getOID(1), self.getOID(2)]
dm = Mock({"getTransaction": (oid_list, 'user', 'desc', '', False), })
self.app.dm = dm
self.operation.askTransactionInformation(conn, tid)
self.checkAnswerTransactionInformation(conn)
def test_24_askObject1(self):
# delayed response
conn = self.getFakeConnection()
oid = self.getOID(1)
tid = self.getNextTID()
serial = self.getNextTID()
self.app.dm = Mock()
self.app.tm = Mock({'loadLocked': True})
self.app.load_lock_dict[oid] = object()
self.assertEqual(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=oid, serial=serial, tid=tid)
self.assertEqual(len(self.app.event_queue), 1)
self.checkNoPacketSent(conn)
self.assertEqual(len(self.app.dm.mockGetNamedCalls('getObject')), 0)
def test_24_askObject2(self):
# invalid serial / tid / packet not found
self.app.dm = Mock({'getObject': None})
conn = self.getFakeConnection()
oid = self.getOID(1)
tid = self.getNextTID()
serial = self.getNextTID()
self.assertEqual(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=oid, serial=serial, tid=tid)
calls = self.app.dm.mockGetNamedCalls('getObject')
self.assertEqual(len(self.app.event_queue), 0)
self.assertEqual(len(calls), 1)
calls[0].checkArgs(oid, serial, tid)
self.checkErrorPacket(conn)
def test_24_askObject3(self):
oid = self.getOID(1)
tid = self.getNextTID()
serial = self.getNextTID()
next_serial = self.getNextTID()
H = "0" * 20
# object found => answer
self.app.dm = Mock({'getObject': (serial, next_serial, 0, H, '', None)})
conn = self.getFakeConnection()
self.assertEqual(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=oid, serial=serial, tid=tid)
self.assertEqual(len(self.app.event_queue), 0)
self.checkAnswerObject(conn)
def test_25_askTIDsFrom(self):
# well case => answer
conn = self.getFakeConnection()
self.app.dm = Mock({'getReplicationTIDList': (INVALID_TID, )})
self.app.pt = Mock({'getPartitions': 1})
tid = self.getNextTID()
tid2 = self.getNextTID()
self.operation.askTIDsFrom(conn, tid, tid2, 2, [1])
calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid, tid2, 2, 1)
self.checkAnswerTidsFrom(conn)
def test_26_askObjectHistoryFrom(self):
min_oid = self.getOID(2)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
length = 4
partition = 8
num_partitions = 16
tid = self.getNextTID()
conn = self.getFakeConnection()
self.app.dm = Mock({'getObjectHistoryFrom': {min_oid: [tid]},})
self.app.pt = Mock({
'getPartitions': num_partitions,
})
self.operation.askObjectHistoryFrom(conn, min_oid, min_serial,
max_serial, length, partition)
self.checkAnswerObjectHistoryFrom(conn)
calls = self.app.dm.mockGetNamedCalls('getObjectHistoryFrom')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_oid, min_serial, max_serial, length, partition)
def test_askCheckTIDRange(self):
count = 1
tid_checksum = "1" * 20
min_tid = self.getNextTID()
num_partitions = 4
length = 5
partition = 6
max_tid = self.getNextTID()
self.app.dm = Mock({'checkTIDRange': (count, tid_checksum, max_tid)})
self.app.pt = Mock({'getPartitions': num_partitions})
conn = self.getFakeConnection()
self.operation.askCheckTIDRange(conn, min_tid, max_tid, length, partition)
calls = self.app.dm.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_tid, max_tid, length, partition)
pmin_tid, plength, pcount, ptid_checksum, pmax_tid = \
self.checkAnswerPacket(conn, Packets.AnswerCheckTIDRange,
decode=True)
self.assertEqual(min_tid, pmin_tid)
self.assertEqual(length, plength)
self.assertEqual(count, pcount)
self.assertEqual(tid_checksum, ptid_checksum)
self.assertEqual(max_tid, pmax_tid)
def test_askCheckSerialRange(self):
count = 1
oid_checksum = "2" * 20
min_oid = self.getOID(1)
num_partitions = 4
length = 5
partition = 6
serial_checksum = "3" * 20
min_serial = self.getNextTID()
max_serial = self.getNextTID()
max_oid = self.getOID(2)
self.app.dm = Mock({'checkSerialRange': (count, oid_checksum, max_oid,
serial_checksum, max_serial)})
self.app.pt = Mock({'getPartitions': num_partitions})
conn = self.getFakeConnection()
self.operation.askCheckSerialRange(conn, min_oid, min_serial,
max_serial, length, partition)
calls = self.app.dm.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_oid, min_serial, max_serial, length, partition)
pmin_oid, pmin_serial, plength, pcount, poid_checksum, pmax_oid, \
pserial_checksum, pmax_serial = self.checkAnswerPacket(conn,
Packets.AnswerCheckSerialRange, decode=True)
self.assertEqual(min_oid, pmin_oid)
self.assertEqual(min_serial, pmin_serial)
self.assertEqual(length, plength)
self.assertEqual(count, pcount)
self.assertEqual(oid_checksum, poid_checksum)
self.assertEqual(max_oid, pmax_oid)
self.assertEqual(serial_checksum, pserial_checksum)
self.assertEqual(max_serial, pmax_serial)
if __name__ == "__main__":
unittest.main()
......@@ -35,11 +35,6 @@ class StorageMySQSLdbTests(StorageDBTests):
db.setup(reset)
return db
def checkCalledQuery(self, query=None, call=0):
self.assertTrue(len(self.db.conn.mockGetNamedCalls('query')) > call)
call = self.db.conn.mockGetNamedCalls('query')[call]
call.checkArgs('BEGIN')
def test_MySQLDatabaseManagerInit(self):
db = MySQLDatabaseManager('%s@%s' % (NEO_SQL_USER, NEO_SQL_DATABASE),
0)
......@@ -48,30 +43,6 @@ class StorageMySQSLdbTests(StorageDBTests):
self.assertEqual(db.user, NEO_SQL_USER)
# & connect
self.assertTrue(isinstance(db.conn, MySQLdb.connection))
self.assertFalse(db.isUnderTransaction())
def test_begin(self):
# no current transaction
self.db.conn = Mock({ })
self.assertFalse(self.db.isUnderTransaction())
self.db.begin()
self.checkCalledQuery(query='COMMIT')
self.assertTrue(self.db.isUnderTransaction())
def test_commit(self):
self.db.conn = Mock()
self.db.begin()
self.db.commit()
self.assertEqual(len(self.db.conn.mockGetNamedCalls('commit')), 1)
self.assertFalse(self.db.isUnderTransaction())
def test_rollback(self):
# rollback called and no current transaction
self.db.conn = Mock({ })
self.db.under_transaction = True
self.db.rollback()
self.assertEqual(len(self.db.conn.mockGetNamedCalls('rollback')), 1)
self.assertFalse(self.db.isUnderTransaction())
def test_query1(self):
# fake result object
......
......@@ -16,15 +16,13 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import unittest
from mock import Mock
from .testStorageDBTests import StorageDBTests
from neo.storage.database.btree import BTreeDatabaseManager
from neo.storage.database.sqlite import SQLiteDatabaseManager
class StorageBTreeTests(StorageDBTests):
class StorageSQLiteTests(StorageDBTests):
def getDB(self, reset=0):
# db manager
db = BTreeDatabaseManager('', 0)
db = SQLiteDatabaseManager(':memory:', 0)
db.setup(reset)
return db
......
......@@ -68,40 +68,6 @@ class StorageVerificationHandlerTests(NeoUnitTestBase):
# nothing happens
self.checkNoPacketSent(conn)
def test_07_askLastIDs(self):
conn = self.getClientConnection()
last_ptid = self.getPTID(1)
last_oid = self.getOID(2)
self.app.pt = Mock({'getID': last_ptid})
class DummyDM(object):
def getLastOID(self):
raise KeyError
getLastTID = getLastOID
self.app.dm = DummyDM()
self.verification.askLastIDs(conn)
oid, tid, ptid = self.checkAnswerLastIDs(conn, decode=True)
self.assertEqual(oid, None)
self.assertEqual(tid, None)
self.assertEqual(ptid, last_ptid)
# return value stored in db
conn = self.getClientConnection()
self.app.dm = Mock({
'getLastOID': last_oid,
'getLastTID': p64(4),
})
self.verification.askLastIDs(conn)
oid, tid, ptid = self.checkAnswerLastIDs(conn, decode=True)
self.assertEqual(oid, last_oid)
self.assertEqual(u64(tid), 4)
self.assertEqual(ptid, self.app.pt.getID())
call_list = self.app.dm.mockGetNamedCalls('getLastOID')
self.assertEqual(len(call_list), 1)
call_list[0].checkArgs()
call_list = self.app.dm.mockGetNamedCalls('getLastTID')
self.assertEqual(len(call_list), 1)
call_list[0].checkArgs()
def test_08_askPartitionTable(self):
node = self.app.nm.createStorage(
address=("127.7.9.9", 1),
......
......@@ -160,16 +160,6 @@ class ProtocolTests(NeoUnitTestBase):
p = Packets.AskLastIDs()
self.assertEqual(p.decode(), ())
def test_19_answerLastIDs(self):
oid = self.getNextTID()
tid = self.getNextTID()
ptid = self.getPTID()
p = Packets.AnswerLastIDs(oid, tid, ptid)
loid, ltid, lptid = p.decode()
self.assertEqual(loid, oid)
self.assertEqual(ltid, tid)
self.assertEqual(lptid, ptid)
def test_20_askPartitionTable(self):
self.assertEqual(Packets.AskPartitionTable().decode(), ())
......@@ -638,40 +628,16 @@ class ProtocolTests(NeoUnitTestBase):
def test_AskTIDsFrom(self):
tid = self.getNextTID()
tid2 = self.getNextTID()
p = Packets.AskTIDsFrom(tid, tid2, 1000, [5])
p = Packets.AskTIDsFrom(tid, tid2, 1000, 5)
min_tid, max_tid, length, partition = p.decode()
self.assertEqual(min_tid, tid)
self.assertEqual(max_tid, tid2)
self.assertEqual(length, 1000)
self.assertEqual(partition, [5])
self.assertEqual(partition, 5)
def test_AnswerTIDsFrom(self):
self._test_AnswerTIDs(Packets.AnswerTIDsFrom)
def test_AskObjectHistoryFrom(self):
oid = self.getOID(1)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
length = 5
partition = 4
p = Packets.AskObjectHistoryFrom(oid, min_serial, max_serial, length,
partition)
p_oid, p_min_serial, p_max_serial, p_length, p_partition = p.decode()
self.assertEqual(p_oid, oid)
self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_max_serial, max_serial)
self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition)
def test_AnswerObjectHistoryFrom(self):
object_dict = {}
for int_oid in xrange(4):
object_dict[self.getOID(int_oid)] = [self.getNextTID() \
for _ in xrange(5)]
p = Packets.AnswerObjectHistoryFrom(object_dict)
p_object_dict = p.decode()[0]
self.assertEqual(object_dict, p_object_dict)
def test_AskCheckTIDRange(self):
min_tid = self.getNextTID()
max_tid = self.getNextTID()
......
......@@ -32,8 +32,9 @@ from neo.lib.connection import BaseConnection, Connection
from neo.lib.connector import SocketConnector, \
ConnectorConnectionRefusedException, ConnectorTryAgainException
from neo.lib.event import EventManager
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, NodeTypes
from neo.lib.util import SOCKET_CONNECTORS_DICT, parseMasterList
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, NodeTypes, \
UUID_NAMESPACES, INVALID_UUID
from neo.lib.util import SOCKET_CONNECTORS_DICT, parseMasterList, p64
from .. import NeoTestBase, getTempDirectory, setupMySQLdb, \
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_USER
......@@ -293,38 +294,18 @@ class StorageApplication(ServerNode, neo.storage.app.Application):
pass
def switchTables(self):
adapter = self._init_args['getAdapter']
dm = self.dm
if adapter == 'BTree':
dm._obj, dm._tobj = dm._tobj, dm._obj
dm._trans, dm._ttrans = dm._ttrans, dm._trans
uncommitted_data = dm._uncommitted_data
for checksum, (_, _, index) in dm._data.iteritems():
uncommitted_data[checksum] = len(index)
index.clear()
elif adapter == 'MySQL':
q = dm.query
dm.begin()
with self.dm as q:
for table in ('trans', 'obj'):
q('RENAME TABLE %s to tmp' % table)
q('RENAME TABLE t%s to %s' % (table, table))
q('RENAME TABLE tmp to t%s' % table)
dm.commit()
else:
assert False
q('ALTER TABLE %s RENAME TO tmp' % table)
q('ALTER TABLE t%s RENAME TO %s' % (table, table))
q('ALTER TABLE tmp RENAME TO t%s' % table)
def getDataLockInfo(self):
adapter = self._init_args['getAdapter']
dm = self.dm
if adapter == 'BTree':
checksum_dict = dict((x, x) for x in dm._data)
elif adapter == 'MySQL':
checksum_dict = dict(dm.query("SELECT id, hash FROM data"))
else:
assert False
checksum_dict = dict(dm.query("SELECT id, hash FROM data"))
assert set(dm._uncommitted_data).issubset(checksum_dict)
get = dm._uncommitted_data.get
return dict((v, get(k, 0)) for k, v in checksum_dict.iteritems())
return dict((str(v), get(k, 0)) for k, v in checksum_dict.iteritems())
class ClientApplication(Node, neo.client.app.Application):
......@@ -406,13 +387,15 @@ class Patch(object):
class ConnectionFilter(object):
filtered_count = 0
def __init__(self, *conns):
self.filter_dict = {}
self.lock = threading.Lock()
self.conn_list = [(conn, self._patch(conn)) for conn in conns]
def _patch(self, conn):
assert '_addPacket' not in conn.__dict__
assert '_addPacket' not in conn.__dict__, "already patched"
lock = self.lock
filter_dict = self.filter_dict
orig = conn.__class__._addPacket
......@@ -423,6 +406,7 @@ class ConnectionFilter(object):
if not queue:
for filter in filter_dict:
if filter(conn, packet):
self.filtered_count += 1
break
else:
return orig(conn, packet)
......@@ -551,8 +535,8 @@ class NEOCluster(object):
SocketConnector.send = cls.SocketConnector_send
Storage.setupLog = setupLog
def __init__(self, master_count=1, partitions=1, replicas=0,
adapter=os.getenv('NEO_TESTS_ADAPTER', 'BTree'),
def __init__(self, master_count=1, partitions=1, replicas=0, upstream=None,
adapter=os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
storage_count=None, db_list=None, clear_databases=True,
db_user=DB_USER, db_password='', verbose=None):
if verbose is not None:
......@@ -570,6 +554,10 @@ class NEOCluster(object):
weak_self = weakref.proxy(self)
kw = dict(cluster=weak_self, getReplicas=replicas, getAdapter=adapter,
getPartitions=partitions, getReset=clear_databases)
if upstream is not None:
self.upstream = weakref.proxy(upstream)
kw.update(getUpstreamCluster=upstream.name,
getUpstreamMasters=parseMasterList(upstream.master_nodes))
self.master_list = [MasterApplication(address=x, **kw)
for x in master_list]
if db_list is None:
......@@ -581,8 +569,8 @@ class NEOCluster(object):
if adapter == 'MySQL':
setupMySQLdb(db_list, db_user, db_password, clear_databases)
db = '%s:%s@%%s' % (db_user, db_password)
elif adapter == 'BTree':
db = '%s'
elif adapter == 'SQLite':
db = os.path.join(getTempDirectory(), '%s.sqlite')
else:
assert False, adapter
self.storage_list = [StorageApplication(getDatabase=db % x, **kw)
......@@ -607,6 +595,11 @@ class NEOCluster(object):
return admin
###
@property
def primary_master(self):
master, = [master for master in self.master_list if master.primary]
return master
def reset(self, clear_database=False):
for node_type in 'master', 'storage', 'admin':
kw = {}
......@@ -635,7 +628,7 @@ class NEOCluster(object):
self._startCluster()
self.tic()
state = self.neoctl.getClusterState()
assert state == ClusterStates.RUNNING, state
assert state in (ClusterStates.RUNNING, ClusterStates.BACKINGUP), state
self.enableStorageList(storage_list)
def _startCluster(self):
......@@ -644,6 +637,7 @@ class NEOCluster(object):
except RuntimeError:
self.tic()
if self.neoctl.getClusterState() not in (
ClusterStates.BACKINGUP,
ClusterStates.RUNNING,
ClusterStates.VERIFYING,
):
......@@ -704,7 +698,7 @@ class NEOCluster(object):
self.client.setPoll(True)
return Storage.Storage(None, self.name, _app=self.client, **kw)
def populate(self, dummy_zodb=None, random=random):
def importZODB(self, dummy_zodb=None, random=random):
if dummy_zodb is None:
from ..stat_zodb import PROD1
dummy_zodb = PROD1(random)
......@@ -713,6 +707,20 @@ class NEOCluster(object):
return lambda count: self.getZODBStorage().importFrom(
as_storage(count), preindex=preindex)
def populate(self, transaction_list, tid=lambda i: p64(i+1),
oid=lambda i: p64(i+1)):
storage = self.getZODBStorage()
tid_dict = {}
for i, oid_list in enumerate(transaction_list):
txn = transaction.Transaction()
storage.tpc_begin(txn, tid(i))
for o in oid_list:
storage.store(p64(o), tid_dict.get(o), repr((i, o)), '', txn)
storage.tpc_vote(txn)
i = storage.tpc_finish(txn)
for o in oid_list:
tid_dict[o] = i
def getTransaction(self):
txn = transaction.TransactionManager()
return txn, self.db.open(transaction_manager=txn)
......@@ -774,3 +782,28 @@ class NEOThreadedTest(NeoTestBase):
etype, value, tb = self.__exc_info
del self.__exc_info
raise etype, value, tb
def predictable_random(seed=None):
# Because we have 2 running threads when client works, we can't
# patch neo.client.pool (and cluster should have 1 storage).
from neo.master import backup_app
from neo.storage import replicator
def decorator(wrapped):
def wrapper(*args, **kw):
s = repr(time.time()) if seed is None else seed
neo.lib.logging.info("using seed %r", s)
r = random.Random(s)
try:
MasterApplication.getNewUUID = lambda self, node_type: (
super(MasterApplication, self).getNewUUID(node_type)
if node_type == NodeTypes.CLIENT else
UUID_NAMESPACES[node_type] + ''.join(
chr(r.randrange(256)) for _ in xrange(15)))
backup_app.random = replicator.random = r
return wrapped(*args, **kw)
finally:
del MasterApplication.getNewUUID
backup_app.random = replicator.random = random
return wraps(wrapped)(wrapper)
return decorator
......@@ -26,8 +26,7 @@ from neo.storage.transactions import TransactionManager, \
DelayedError, ConflictError
from neo.lib.connection import MTClientConnection
from neo.lib.protocol import NodeStates, Packets, ZERO_TID
from . import NEOCluster, NEOThreadedTest, \
Patch, ConnectionFilter
from . import NEOCluster, NEOThreadedTest, Patch
from neo.lib.util import makeChecksum
from neo.client.pool import CELL_CONNECTED, CELL_GOOD
......
#
# Copyright (c) 2011 Nexedi SARL and Contributors. All Rights Reserved.
# Julien Muchembled <jm@nexedi.com>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import random
import sys
import time
import threading
import transaction
import unittest
import neo.lib
from neo.storage.transactions import TransactionManager, \
DelayedError, ConflictError
from neo.lib.connection import MTClientConnection
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
ZERO_OID, ZERO_TID, MAX_TID
from . import NEOCluster, NEOThreadedTest, Patch, predictable_random
from neo.client.pool import CELL_CONNECTED, CELL_GOOD
class ReplicationTests(NEOThreadedTest):
def checksumPartition(self, storage, partition):
dm = storage.dm
args = ZERO_TID, MAX_TID, None, partition
return dm.checkTIDRange(*args), dm.checkSerialRange(ZERO_TID, *args)
def checkPartitionReplicated(self, source, destination, partition):
self.assertEqual(self.checksumPartition(source, partition),
self.checksumPartition(destination, partition))
def checkBackup(self, cluster):
upstream_pt = cluster.upstream.primary_master.pt
pt = cluster.primary_master.pt
np = pt.getPartitions()
self.assertEqual(np, upstream_pt.getPartitions())
checked = 0
source_dict = dict((x.uuid, x) for x in cluster.upstream.storage_list)
for storage in cluster.storage_list:
self.assertEqual(np, storage.pt.getPartitions())
for partition in pt.getAssignedPartitionList(storage.uuid):
cell_list = upstream_pt.getCellList(partition, readable=True)
source = source_dict[random.choice(cell_list).getUUID()]
self.checkPartitionReplicated(source, storage, partition)
checked += 1
return checked
def testBackupNormalCase(self):
upstream = NEOCluster(partitions=7, replicas=1, storage_count=3)
try:
upstream.start()
importZODB = upstream.importZODB()
importZODB(3)
upstream.client.setPoll(0)
backup = NEOCluster(partitions=7, replicas=1, storage_count=5,
upstream=upstream)
try:
backup.start()
# Initialize & catch up.
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
backup.tic()
self.assertEqual(14, self.checkBackup(backup))
# Normal case, following upstream cluster closely.
importZODB(17)
upstream.client.setPoll(0)
backup.tic()
self.assertEqual(14, self.checkBackup(backup))
# Check that a backup cluster can be restarted.
finally:
backup.stop()
backup.reset()
try:
backup.start()
self.assertEqual(backup.neoctl.getClusterState(),
ClusterStates.BACKINGUP)
importZODB(17)
upstream.client.setPoll(0)
backup.tic()
self.assertEqual(14, self.checkBackup(backup))
# Stop backing up, nothing truncated.
backup.neoctl.setClusterState(ClusterStates.STOPPING_BACKUP)
backup.tic()
self.assertEqual(14, self.checkBackup(backup))
self.assertEqual(backup.neoctl.getClusterState(),
ClusterStates.RUNNING)
finally:
backup.stop()
finally:
upstream.stop()
@predictable_random()
def testBackupNodeLost(self):
"""Check backup cluster can recover after random connection loss
- backup master disconnected from upstream master
- primary storage disconnected from backup master
- non-primary storage disconnected from backup master
"""
from neo.master.backup_app import random
def fetchObjects(orig, min_tid=None, min_oid=ZERO_OID):
if min_tid is None:
counts[0] += 1
if counts[0] > 1:
orig.im_self.app.master_conn.close()
return orig(min_tid, min_oid)
def onTransactionCommitted(orig, txn):
counts[0] += 1
if counts[0] > 1:
node_list = orig.im_self.nm.getClientList(only_identified=True)
node_list.remove(txn.getNode())
node_list[0].getConnection().close()
return orig(txn)
upstream = NEOCluster(partitions=4, replicas=0, storage_count=1)
try:
upstream.start()
importZODB = upstream.importZODB(random=random)
backup = NEOCluster(partitions=4, replicas=2, storage_count=4,
upstream=upstream)
try:
backup.start()
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
backup.tic()
storage_list = [x.uuid for x in backup.storage_list]
slave = set(xrange(len(storage_list))).difference
for event in xrange(10):
counts = [0]
if event == 5:
p = Patch(upstream.master.tm,
_on_commit=onTransactionCommitted)
else:
primary_dict = {}
for k, v in sorted(backup.master.backup_app
.primary_partition_dict.iteritems()):
primary_dict.setdefault(storage_list.index(v._uuid),
[]).append(k)
if event % 2:
storage = slave(primary_dict).pop()
else:
storage, partition_list = primary_dict.popitem()
# Populate until the found storage performs
# a second replication partially and aborts.
p = Patch(backup.storage_list[storage].replicator,
fetchObjects=fetchObjects)
try:
importZODB(lambda x: counts[0] > 1)
finally:
del p
upstream.client.setPoll(0)
backup.tic()
self.assertEqual(12, self.checkBackup(backup))
finally:
backup.stop()
finally:
upstream.stop()
def testReplicationAbortedBySource(self):
"""
Check that a feeding node aborts replication when its partition is
dropped, and that the out-of-date node finishes to replicate from
another source.
Here are the different states of partitions over time:
pt: 0: U|U|U
pt: 0: UO|UO|UO
pt: 0: FOO|UO.|U.O # node 1 replicates from node 0
pt: 0: .OU|UO.|U.O # here node 0 lost partition 0
# and node 1 must switch to node 2
pt: 0: .OU|UO.|U.U
pt: 0: .OU|UU.|U.U
pt: 0: .UU|UU.|U.U
"""
def connected(orig, *args, **kw):
patch[0] = s1.filterConnection(s0)
patch[0].add(delayAskFetch,
Patch(s0.dm, changePartitionTable=changePartitionTable))
return orig(*args, **kw)
def delayAskFetch(conn, packet):
return isinstance(packet, delayed) and packet.decode()[0] == offset
def changePartitionTable(orig, ptid, cell_list):
if (offset, s0.uuid, CellStates.DISCARDED) in cell_list:
patch[0].remove(delayAskFetch)
# XXX: this is currently not done by
# default for performance reason
orig.im_self.dropPartitions((offset,))
return orig(ptid, cell_list)
cluster = NEOCluster(partitions=3, replicas=1, storage_count=3)
s0, s1, s2 = cluster.storage_list
for delayed in Packets.AskFetchTransactions, Packets.AskFetchObjects:
try:
cluster.start([s0])
cluster.populate([range(6)] * 3)
cluster.client.setPoll(0)
s1.start()
s2.start()
cluster.tic()
cluster.neoctl.enableStorageList([s1.uuid, s2.uuid])
offset, = [offset for offset, row in enumerate(
cluster.master.pt.partition_list)
for cell in row if cell.isFeeding()]
patch = [Patch(s1.replicator, fetchTransactions=connected)]
try:
cluster.tic()
self.assertEqual(1, patch[0].filtered_count)
patch[0]()
finally:
del patch[:]
cluster.tic()
self.checkPartitionReplicated(s1, s2, offset)
finally:
cluster.stop()
cluster.reset(True)
if __name__ == "__main__":
unittest.main()
......@@ -29,7 +29,7 @@ extras_require = {
'client': ['ZODB3'], # ZODB3 >= 3.10
'ctl': [],
'master': [],
'storage-btree': ['ZODB3'],
'storage-sqlite': [],
'storage-mysqldb': ['MySQL-python'],
}
extras_require['tests'] = ['zope.testing', 'psutil',
......
......@@ -78,21 +78,22 @@ def main():
if subprocess.call((os.path.join(bin, 'buildout'), '-v'),
cwd=test_home):
continue
title = '[%s:%s-g%s:%s]' % (branch,
git('rev-list', '--topo-order', '--count', revision),
revision[:7], os.path.basename(test_home))
if tests:
subprocess.call([os.path.join(bin, 'neotestrunner'),
'-' + tests, '--title',
'NEO tests ' + title,
] + sys.argv[1:arg_count])
if 'm' in tasks:
subprocess.call([os.path.join(bin, 'python'),
'tools/matrix', '--repeat=2',
'--min-storages=1', '--max-storages=24',
'--min-replicas=0', '--max-replicas=3',
'--title', 'Matrix ' + title,
] + sys.argv[1:arg_count])
for backend in 'SQLite', 'MySQL':
os.environ['NEO_TESTS_ADAPTER'] = backend
title = '[%s:%s-g%s:%s:%s]' % (branch,
git('rev-list', '--topo-order', '--count', revision),
revision[:7], os.path.basename(test_home), backend)
if tests:
subprocess.call([os.path.join(bin, 'neotestrunner'),
'-' + tests, '--title', 'NEO tests ' + title,
] + sys.argv[1:arg_count])
if 'm' in tasks:
subprocess.call([os.path.join(bin, 'python'),
'tools/matrix', '--repeat=2',
'--min-storages=1', '--max-storages=24',
'--min-replicas=0', '--max-replicas=3',
'--title', 'Matrix ' + title,
] + sys.argv[1:arg_count])
finally:
s.close()
clean()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment