Commit ca2caf87 authored by Julien Muchembled's avatar Julien Muchembled

Bump protocol version and upgrade storages automatically

parents cff279af d3c8b76d
...@@ -58,8 +58,6 @@ ...@@ -58,8 +58,6 @@
committed by future transactions. committed by future transactions.
- Add a 'devid' storage configuration so that master do not distribute - Add a 'devid' storage configuration so that master do not distribute
replicated partitions on storages with same 'devid'. replicated partitions on storages with same 'devid'.
- Make tpc_finish safer as described in its __doc__: moving work to
tpc_vote and recover from master failure when possible.
Storage Storage
- Use libmysqld instead of a stand-alone MySQL server. - Use libmysqld instead of a stand-alone MySQL server.
...@@ -143,9 +141,7 @@ ...@@ -143,9 +141,7 @@
Admin Admin
- Make admin node able to monitor multiple clusters simultaneously - Make admin node able to monitor multiple clusters simultaneously
- Send notifications (ie: mail) when a storage or master node is lost - Send notifications (ie: mail) when a storage or master node is lost
- Add ctl command to truncate DB at arbitrary TID. 'Truncate' message - Add ctl command to list last transactions, like fstail for FileStorage.
can be reused. There should also be a way to list last transactions,
like fstail for FileStorage.
Tests Tests
- Use another mock library: Python 3.3+ has unittest.mock, which is - Use another mock library: Python 3.3+ has unittest.mock, which is
......
...@@ -65,10 +65,12 @@ class AdminEventHandler(EventHandler): ...@@ -65,10 +65,12 @@ class AdminEventHandler(EventHandler):
askLastIDs = forward_ask(Packets.AskLastIDs) askLastIDs = forward_ask(Packets.AskLastIDs)
askLastTransaction = forward_ask(Packets.AskLastTransaction) askLastTransaction = forward_ask(Packets.AskLastTransaction)
addPendingNodes = forward_ask(Packets.AddPendingNodes) addPendingNodes = forward_ask(Packets.AddPendingNodes)
askRecovery = forward_ask(Packets.AskRecovery)
tweakPartitionTable = forward_ask(Packets.TweakPartitionTable) tweakPartitionTable = forward_ask(Packets.TweakPartitionTable)
setClusterState = forward_ask(Packets.SetClusterState) setClusterState = forward_ask(Packets.SetClusterState)
setNodeState = forward_ask(Packets.SetNodeState) setNodeState = forward_ask(Packets.SetNodeState)
checkReplicas = forward_ask(Packets.CheckReplicas) checkReplicas = forward_ask(Packets.CheckReplicas)
truncate = forward_ask(Packets.Truncate)
class MasterEventHandler(EventHandler): class MasterEventHandler(EventHandler):
......
...@@ -612,18 +612,29 @@ class Application(ThreadedApplication): ...@@ -612,18 +612,29 @@ class Application(ThreadedApplication):
packet = Packets.AskStoreTransaction(ttid, str(transaction.user), packet = Packets.AskStoreTransaction(ttid, str(transaction.user),
str(transaction.description), dumps(transaction._extension), str(transaction.description), dumps(transaction._extension),
txn_context['cache_dict']) txn_context['cache_dict'])
add_involved_nodes = txn_context['involved_nodes'].add queue = txn_context['queue']
trans_nodes = []
for node, conn in self.cp.iterateForObject(ttid): for node, conn in self.cp.iterateForObject(ttid):
logging.debug("voting transaction %s on %s", dump(ttid), logging.debug("voting transaction %s on %s", dump(ttid),
dump(conn.getUUID())) dump(conn.getUUID()))
try: try:
self._askStorage(conn, packet) conn.ask(packet, queue=queue)
except ConnectionClosed: except ConnectionClosed:
continue continue
add_involved_nodes(node) trans_nodes.append(node)
# check at least one storage node accepted # check at least one storage node accepted
if txn_context['involved_nodes']: if trans_nodes:
involved_nodes = txn_context['involved_nodes']
packet = Packets.AskVoteTransaction(ttid)
for node in involved_nodes.difference(trans_nodes):
conn = self.cp.getConnForNode(node)
if conn is not None:
try:
conn.ask(packet, queue=queue)
except ConnectionClosed:
pass
involved_nodes.update(trans_nodes)
self.waitResponses(queue)
txn_context['voted'] = None txn_context['voted'] = None
# We must not go further if connection to master was lost since # We must not go further if connection to master was lost since
# tpc_begin, to lower the probability of failing during tpc_finish. # tpc_begin, to lower the probability of failing during tpc_finish.
...@@ -667,27 +678,14 @@ class Application(ThreadedApplication): ...@@ -667,27 +678,14 @@ class Application(ThreadedApplication):
fail in tpc_finish. In particular, making a transaction permanent fail in tpc_finish. In particular, making a transaction permanent
should ideally be as simple as switching a bit permanently. should ideally be as simple as switching a bit permanently.
In NEO, tpc_finish breaks this promise by not ensuring earlier that all In NEO, all the data (with the exception of the tid, simply because
data and metadata are written, and it is for example vulnerable to it is not known yet) is already flushed on disk at the end on the vote.
ENOSPC errors. In other words, some work should be moved to tpc_vote. During tpc_finish, all nodes storing the transaction metadata are asked
to commit by saving the new tid and flushing again: for SQL backends,
TODO: - In tpc_vote, all involved storage nodes must be asked to write it's just an UPDATE of 1 cell. At last, the metadata is moved to
all metadata to ttrans/tobj and _commit_. AskStoreTransaction a final place so that the new transaction is readable, but this is
can be extended for this: for nodes that don't store anything something that can always be replayed (during the verification phase)
in ttrans, it can just contain the ttid. The final tid is not if any failure happens.
known yet, so ttrans/tobj would contain the ttid.
- In tpc_finish, AskLockInformation is still required for read
locking, ttrans.tid must be updated with the final value and
ttrans _committed_.
- The Verification phase would need some change because
ttrans/tobj may contain data for which tpc_finish was not
called. The ttid is also in trans so a mapping ttid<->tid is
always possible and can be forwarded via the master so that all
storage are still able to update the tid column with the final
value when moving rows from tobj to obj.
The resulting cost is:
- additional RPCs in tpc_vote
- 1 updated row in ttrans + commit
TODO: We should recover from master failures when the transaction got TODO: We should recover from master failures when the transaction got
successfully committed. More precisely, we should not raise: successfully committed. More precisely, we should not raise:
......
...@@ -102,11 +102,17 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -102,11 +102,17 @@ class PrimaryNotificationsHandler(MTEventHandler):
if app.master_conn is None: if app.master_conn is None:
app._cache_lock_acquire() app._cache_lock_acquire()
try: try:
oid_list = app._cache.clear_current()
db = app.getDB() db = app.getDB()
if db is not None: if app.last_tid < ltid:
db.invalidate(app.last_tid and oid_list = app._cache.clear_current()
add64(app.last_tid, 1), oid_list) db is None or db.invalidate(
app.last_tid and add64(app.last_tid, 1),
oid_list)
else:
# The DB was truncated. It happens so
# rarely that we don't need to optimize.
app._cache.clear()
db is None or db.invalidateCache()
finally: finally:
app._cache_lock_release() app._cache_lock_release()
app.last_tid = ltid app.last_tid = ltid
......
...@@ -112,9 +112,11 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -112,9 +112,11 @@ class StorageAnswersHandler(AnswerBaseHandler):
answerCheckCurrentSerial = answerStoreObject answerCheckCurrentSerial = answerStoreObject
def answerStoreTransaction(self, conn, _): def answerStoreTransaction(self, conn):
pass pass
answerVoteTransaction = answerStoreTransaction
def answerTIDsFrom(self, conn, tid_list): def answerTIDsFrom(self, conn, tid_list):
logging.debug('Get %u TIDs from %r', len(tid_list), conn) logging.debug('Get %u TIDs from %r', len(tid_list), conn)
self.app.setHandlerData(tid_list) self.app.setHandlerData(tid_list)
......
...@@ -41,9 +41,6 @@ class BootstrapManager(EventHandler): ...@@ -41,9 +41,6 @@ class BootstrapManager(EventHandler):
self.num_partitions = None self.num_partitions = None
self.current = None self.current = None
def notifyNodeInformation(self, conn, node_list):
pass
def announcePrimary(self, conn): def announcePrimary(self, conn):
# We found the primary master early enough to be notified of election # We found the primary master early enough to be notified of election
# end. Lucky. Anyway, we must carry on with identification request, so # end. Lucky. Anyway, we must carry on with identification request, so
......
...@@ -23,7 +23,7 @@ class ElectionFailure(NeoException): ...@@ -23,7 +23,7 @@ class ElectionFailure(NeoException):
class PrimaryFailure(NeoException): class PrimaryFailure(NeoException):
pass pass
class OperationFailure(NeoException): class StoppedOperation(NeoException):
pass pass
class DatabaseFailure(NeoException): class DatabaseFailure(NeoException):
......
...@@ -20,7 +20,7 @@ import traceback ...@@ -20,7 +20,7 @@ import traceback
from cStringIO import StringIO from cStringIO import StringIO
from struct import Struct from struct import Struct
PROTOCOL_VERSION = 4 PROTOCOL_VERSION = 5
# Size restrictions. # Size restrictions.
MIN_PACKET_SIZE = 10 MIN_PACKET_SIZE = 10
...@@ -722,16 +722,24 @@ class ReelectPrimary(Packet): ...@@ -722,16 +722,24 @@ class ReelectPrimary(Packet):
Force a re-election of a primary master node. M -> M. Force a re-election of a primary master node. M -> M.
""" """
class Recovery(Packet):
"""
Ask all data needed by master to recover. PM -> S, S -> PM.
"""
_answer = PStruct('answer_recovery',
PPTID('ptid'),
PTID('backup_tid'),
PTID('truncate_tid'),
)
class LastIDs(Packet): class LastIDs(Packet):
""" """
Ask the last OID, the last TID and the last Partition Table ID so that Ask the last OID/TID so that a master can initialize its TransactionManager.
a master recover. PM -> S, S -> PM. PM -> S, S -> PM.
""" """
_answer = PStruct('answer_last_ids', _answer = PStruct('answer_last_ids',
POID('last_oid'), POID('last_oid'),
PTID('last_tid'), PTID('last_tid'),
PPTID('last_ptid'),
PTID('backup_tid'),
) )
class PartitionTable(Packet): class PartitionTable(Packet):
...@@ -775,6 +783,8 @@ class StartOperation(Packet): ...@@ -775,6 +783,8 @@ class StartOperation(Packet):
this message, it must not serve client nodes. PM -> S. this message, it must not serve client nodes. PM -> S.
""" """
_fmt = PStruct('start_operation', _fmt = PStruct('start_operation',
# XXX: Is this boolean needed ? Maybe this
# can be deduced from cluster state.
PBoolean('backup'), PBoolean('backup'),
) )
...@@ -786,8 +796,8 @@ class StopOperation(Packet): ...@@ -786,8 +796,8 @@ class StopOperation(Packet):
class UnfinishedTransactions(Packet): class UnfinishedTransactions(Packet):
""" """
Ask unfinished transactions PM -> S. Ask unfinished transactions S -> PM.
Answer unfinished transactions S -> PM. Answer unfinished transactions PM -> S.
""" """
_answer = PStruct('answer_unfinished_transactions', _answer = PStruct('answer_unfinished_transactions',
PTID('max_tid'), PTID('max_tid'),
...@@ -796,36 +806,36 @@ class UnfinishedTransactions(Packet): ...@@ -796,36 +806,36 @@ class UnfinishedTransactions(Packet):
), ),
) )
class ObjectPresent(Packet): class LockedTransactions(Packet):
""" """
Ask if an object is present. If not present, OID_NOT_FOUND should be Ask locked transactions PM -> S.
returned. PM -> S. Answer locked transactions S -> PM.
Answer that an object is present. PM -> S.
""" """
_fmt = PStruct('object_present', _answer = PStruct('answer_locked_transactions',
POID('oid'), PDict('tid_dict',
PTID('tid'), PTID('ttid'),
)
_answer = PStruct('object_present',
POID('oid'),
PTID('tid'), PTID('tid'),
),
) )
class DeleteTransaction(Packet): class FinalTID(Packet):
""" """
Delete a transaction. PM -> S. Return final tid if ttid has been committed. * -> S.
""" """
_fmt = PStruct('delete_transaction', _fmt = PStruct('final_tid',
PTID('ttid'),
)
_answer = PStruct('final_tid',
PTID('tid'), PTID('tid'),
PFOidList,
) )
class CommitTransaction(Packet): class ValidateTransaction(Packet):
""" """
Commit a transaction. PM -> S. Commit a transaction. PM -> S.
""" """
_fmt = PStruct('commit_transaction', _fmt = PStruct('validate_transaction',
PTID('ttid'),
PTID('tid'), PTID('tid'),
) )
...@@ -878,11 +888,10 @@ class LockInformation(Packet): ...@@ -878,11 +888,10 @@ class LockInformation(Packet):
_fmt = PStruct('ask_lock_informations', _fmt = PStruct('ask_lock_informations',
PTID('ttid'), PTID('ttid'),
PTID('tid'), PTID('tid'),
PFOidList,
) )
_answer = PStruct('answer_information_locked', _answer = PStruct('answer_information_locked',
PTID('tid'), PTID('ttid'),
) )
class InvalidateObjects(Packet): class InvalidateObjects(Packet):
...@@ -899,7 +908,7 @@ class UnlockInformation(Packet): ...@@ -899,7 +908,7 @@ class UnlockInformation(Packet):
Unlock information on a transaction. PM -> S. Unlock information on a transaction. PM -> S.
""" """
_fmt = PStruct('notify_unlock_information', _fmt = PStruct('notify_unlock_information',
PTID('tid'), PTID('ttid'),
) )
class GenerateOIDs(Packet): class GenerateOIDs(Packet):
...@@ -961,10 +970,17 @@ class StoreTransaction(Packet): ...@@ -961,10 +970,17 @@ class StoreTransaction(Packet):
PString('extension'), PString('extension'),
PFOidList, PFOidList,
) )
_answer = PFEmpty
_answer = PStruct('answer_store_transaction', class VoteTransaction(Packet):
"""
Ask to store a transaction. C -> S.
Answer if transaction has been stored. S -> C.
"""
_fmt = PStruct('ask_vote_transaction',
PTID('tid'), PTID('tid'),
) )
_answer = PFEmpty
class GetObject(Packet): class GetObject(Packet):
""" """
...@@ -1462,13 +1478,14 @@ class ReplicationDone(Packet): ...@@ -1462,13 +1478,14 @@ class ReplicationDone(Packet):
class Truncate(Packet): class Truncate(Packet):
""" """
XXX: Used for both make storage consistent and leave backup mode Request DB to be truncated. Also used to leave backup mode.
M -> S
""" """
_fmt = PStruct('truncate', _fmt = PStruct('truncate',
PTID('tid'), PTID('tid'),
) )
_answer = Error
StaticRegistry = {} StaticRegistry = {}
def register(request, ignore_when_closed=None): def register(request, ignore_when_closed=None):
...@@ -1586,6 +1603,8 @@ class Packets(dict): ...@@ -1586,6 +1603,8 @@ class Packets(dict):
ReelectPrimary) ReelectPrimary)
NotifyNodeInformation = register( NotifyNodeInformation = register(
NotifyNodeInformation) NotifyNodeInformation)
AskRecovery, AnswerRecovery = register(
Recovery)
AskLastIDs, AnswerLastIDs = register( AskLastIDs, AnswerLastIDs = register(
LastIDs) LastIDs)
AskPartitionTable, AnswerPartitionTable = register( AskPartitionTable, AnswerPartitionTable = register(
...@@ -1600,12 +1619,12 @@ class Packets(dict): ...@@ -1600,12 +1619,12 @@ class Packets(dict):
StopOperation) StopOperation)
AskUnfinishedTransactions, AnswerUnfinishedTransactions = register( AskUnfinishedTransactions, AnswerUnfinishedTransactions = register(
UnfinishedTransactions) UnfinishedTransactions)
AskObjectPresent, AnswerObjectPresent = register( AskLockedTransactions, AnswerLockedTransactions = register(
ObjectPresent) LockedTransactions)
DeleteTransaction = register( AskFinalTID, AnswerFinalTID = register(
DeleteTransaction) FinalTID)
CommitTransaction = register( ValidateTransaction = register(
CommitTransaction) ValidateTransaction)
AskBeginTransaction, AnswerBeginTransaction = register( AskBeginTransaction, AnswerBeginTransaction = register(
BeginTransaction) BeginTransaction)
AskFinishTransaction, AnswerTransactionFinished = register( AskFinishTransaction, AnswerTransactionFinished = register(
...@@ -1624,6 +1643,8 @@ class Packets(dict): ...@@ -1624,6 +1643,8 @@ class Packets(dict):
AbortTransaction) AbortTransaction)
AskStoreTransaction, AnswerStoreTransaction = register( AskStoreTransaction, AnswerStoreTransaction = register(
StoreTransaction) StoreTransaction)
AskVoteTransaction, AnswerVoteTransaction = register(
VoteTransaction)
AskObject, AnswerObject = register( AskObject, AnswerObject = register(
GetObject) GetObject)
AskTIDs, AnswerTIDs = register( AskTIDs, AnswerTIDs = register(
......
...@@ -24,7 +24,7 @@ from neo.lib.protocol import uuid_str, UUID_NAMESPACES, ZERO_TID ...@@ -24,7 +24,7 @@ from neo.lib.protocol import uuid_str, UUID_NAMESPACES, ZERO_TID
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, Packets from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, Packets
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.connection import ListeningConnection, ClientConnection from neo.lib.connection import ListeningConnection, ClientConnection
from neo.lib.exception import ElectionFailure, PrimaryFailure, OperationFailure from neo.lib.exception import ElectionFailure, PrimaryFailure, StoppedOperation
class StateChangedException(Exception): pass class StateChangedException(Exception): pass
...@@ -45,6 +45,7 @@ class Application(BaseApplication): ...@@ -45,6 +45,7 @@ class Application(BaseApplication):
backup_tid = None backup_tid = None
backup_app = None backup_app = None
uuid = None uuid = None
truncate_tid = None
def __init__(self, config): def __init__(self, config):
super(Application, self).__init__( super(Application, self).__init__(
...@@ -77,7 +78,6 @@ class Application(BaseApplication): ...@@ -77,7 +78,6 @@ class Application(BaseApplication):
self.primary = None self.primary = None
self.primary_master_node = None self.primary_master_node = None
self.cluster_state = None self.cluster_state = None
self._startup_allowed = False
uuid = config.getUUID() uuid = config.getUUID()
if uuid: if uuid:
...@@ -221,7 +221,7 @@ class Application(BaseApplication): ...@@ -221,7 +221,7 @@ class Application(BaseApplication):
self.primary = self.primary is None self.primary = self.primary is None
break break
def broadcastNodesInformation(self, node_list): def broadcastNodesInformation(self, node_list, exclude=None):
""" """
Broadcast changes for a set a nodes Broadcast changes for a set a nodes
Send only one packet per connection to reduce bandwidth Send only one packet per connection to reduce bandwidth
...@@ -243,7 +243,7 @@ class Application(BaseApplication): ...@@ -243,7 +243,7 @@ class Application(BaseApplication):
# send at most one non-empty notification packet per node # send at most one non-empty notification packet per node
for node in self.nm.getIdentifiedList(): for node in self.nm.getIdentifiedList():
node_list = node_dict.get(node.getType()) node_list = node_dict.get(node.getType())
if node_list and node.isRunning(): if node_list and node.isRunning() and node is not exclude:
node.notify(Packets.NotifyNodeInformation(node_list)) node.notify(Packets.NotifyNodeInformation(node_list))
def broadcastPartitionChanges(self, cell_list): def broadcastPartitionChanges(self, cell_list):
...@@ -254,7 +254,6 @@ class Application(BaseApplication): ...@@ -254,7 +254,6 @@ class Application(BaseApplication):
ptid = self.pt.setNextID() ptid = self.pt.setNextID()
packet = Packets.NotifyPartitionChanges(ptid, cell_list) packet = Packets.NotifyPartitionChanges(ptid, cell_list)
for node in self.nm.getIdentifiedList(): for node in self.nm.getIdentifiedList():
# TODO: notify masters
if node.isRunning() and not node.isMaster(): if node.isRunning() and not node.isMaster():
node.notify(packet) node.notify(packet)
...@@ -266,8 +265,6 @@ class Application(BaseApplication): ...@@ -266,8 +265,6 @@ class Application(BaseApplication):
""" """
logging.info('provide service') logging.info('provide service')
poll = self.em.poll poll = self.em.poll
self.tm.reset()
self.changeClusterState(ClusterStates.RUNNING) self.changeClusterState(ClusterStates.RUNNING)
# Now everything is passive. # Now everything is passive.
...@@ -278,8 +275,13 @@ class Application(BaseApplication): ...@@ -278,8 +275,13 @@ class Application(BaseApplication):
if e.args[0] != ClusterStates.STARTING_BACKUP: if e.args[0] != ClusterStates.STARTING_BACKUP:
raise raise
self.backup_tid = tid = self.getLastTransaction() self.backup_tid = tid = self.getLastTransaction()
self.pt.setBackupTidDict({node.getUUID(): tid packet = Packets.StartOperation(True)
for node in self.nm.getStorageList(only_identified=True)}) tid_dict = {}
for node in self.nm.getStorageList(only_identified=True):
tid_dict[node.getUUID()] = tid
if node.isRunning():
node.notify(packet)
self.pt.setBackupTidDict(tid_dict)
def playPrimaryRole(self): def playPrimaryRole(self):
logging.info('play the primary role with %r', self.listening_conn) logging.info('play the primary role with %r', self.listening_conn)
...@@ -323,30 +325,46 @@ class Application(BaseApplication): ...@@ -323,30 +325,46 @@ class Application(BaseApplication):
in_conflict) in_conflict)
in_conflict.setUUID(None) in_conflict.setUUID(None)
# recover the cluster status at startup # Do not restart automatically if ElectionFailure is raised, in order
# to avoid a split of the database. For example, with 2 machines with
# a master and a storage on each one and replicas=1, the secondary
# master becomes primary in case of network failure between the 2
# machines but must not start automatically: otherwise, each storage
# node would diverge.
self._startup_allowed = False
try: try:
self.runManager(RecoveryManager)
while True: while True:
self.runManager(VerificationManager) self.runManager(RecoveryManager)
try: try:
if self.backup_tid: self.runManager(VerificationManager)
if not self.backup_tid:
self.provideService()
# self.provideService only returns without raising
# when switching to backup mode.
if self.backup_app is None: if self.backup_app is None:
raise RuntimeError("No upstream cluster to backup" raise RuntimeError("No upstream cluster to backup"
" defined in configuration") " defined in configuration")
self.backup_app.provideService() truncate = Packets.Truncate(
# Reset connection with storages (and go through a self.backup_app.provideService())
# recovery phase) when leaving backup mode in order except StoppedOperation, e:
# to get correct last oid/tid.
self.runManager(RecoveryManager)
continue
self.provideService()
except OperationFailure:
logging.critical('No longer operational') logging.critical('No longer operational')
truncate = Packets.Truncate(*e.args) if e.args else None
# Automatic restart except if we truncate or retry to.
self._startup_allowed = not (self.truncate_tid or truncate)
node_list = []
for node in self.nm.getIdentifiedList(): for node in self.nm.getIdentifiedList():
if node.isStorage() or node.isClient(): if node.isStorage() or node.isClient():
node.notify(Packets.StopOperation()) conn = node.getConnection()
conn.notify(Packets.StopOperation())
if node.isClient(): if node.isClient():
node.getConnection().abort() conn.abort()
continue
if truncate:
conn.notify(truncate)
if node.isRunning():
node.setPending()
node_list.append(node)
self.broadcastNodesInformation(node_list)
except StateChangedException, e: except StateChangedException, e:
assert e.args[0] == ClusterStates.STOPPING assert e.args[0] == ClusterStates.STOPPING
self.shutdown() self.shutdown()
...@@ -427,7 +445,7 @@ class Application(BaseApplication): ...@@ -427,7 +445,7 @@ class Application(BaseApplication):
continue # keep handler continue # keep handler
if type(handler) is not type(conn.getLastHandler()): if type(handler) is not type(conn.getLastHandler()):
conn.setHandler(handler) conn.setHandler(handler)
handler.connectionCompleted(conn) handler.connectionCompleted(conn, new=False)
self.cluster_state = state self.cluster_state = state
def getNewUUID(self, uuid, address, node_type): def getNewUUID(self, uuid, address, node_type):
...@@ -461,7 +479,7 @@ class Application(BaseApplication): ...@@ -461,7 +479,7 @@ class Application(BaseApplication):
# wait for all transaction to be finished # wait for all transaction to be finished
while self.tm.hasPending(): while self.tm.hasPending():
self.em.poll(1) self.em.poll(1)
except OperationFailure: except StoppedOperation:
logging.critical('No longer operational') logging.critical('No longer operational')
logging.info("asking remaining nodes to shutdown") logging.info("asking remaining nodes to shutdown")
......
...@@ -152,24 +152,20 @@ class BackupApplication(object): ...@@ -152,24 +152,20 @@ class BackupApplication(object):
assert tid != ZERO_TID assert tid != ZERO_TID
logging.warning("Truncating at %s (last_tid was %s)", logging.warning("Truncating at %s (last_tid was %s)",
dump(app.backup_tid), dump(last_tid)) dump(app.backup_tid), dump(last_tid))
# XXX: We want to go through a recovery phase in order to else:
# initialize the transaction manager, but this is only # We will do a dummy truncation, just to leave backup mode,
# possible if storages already know that we left backup # so it's fine to start automatically if there's any
# mode. To that purpose, we always send a Truncate packet, # missing storage.
# even if there's nothing to truncate. # XXX: Consider using another method to leave backup mode,
p = Packets.Truncate(tid) # at least when there's nothing to truncate. Because
for node in app.nm.getStorageList(only_identified=True): # in case of StoppedOperation during VERIFYING state,
conn = node.getConnection() # this flag will be wrongly set to False.
conn.setHandler(handler) app._startup_allowed = True
node.setState(NodeStates.TEMPORARILY_DOWN)
# Packets will be sent at the beginning of the recovery
# phase.
conn.notify(p)
conn.abort()
# If any error happened before reaching this line, we'd go back # If any error happened before reaching this line, we'd go back
# to backup mode, which is the right mode to recover. # to backup mode, which is the right mode to recover.
del app.backup_tid del app.backup_tid
break # Now back to RECOVERY...
return tid
finally: finally:
del self.primary_partition_dict, self.tid_list del self.primary_partition_dict, self.tid_list
pt.clearReplicating() pt.clearReplicating()
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from neo.lib import logging from neo.lib import logging
from neo.lib.exception import StoppedOperation
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import (uuid_str, NodeTypes, NodeStates, Packets, from neo.lib.protocol import (uuid_str, NodeTypes, NodeStates, Packets,
BrokenNodeDisallowedError, BrokenNodeDisallowedError,
...@@ -23,6 +24,10 @@ from neo.lib.protocol import (uuid_str, NodeTypes, NodeStates, Packets, ...@@ -23,6 +24,10 @@ from neo.lib.protocol import (uuid_str, NodeTypes, NodeStates, Packets,
class MasterHandler(EventHandler): class MasterHandler(EventHandler):
"""This class implements a generic part of the event handlers.""" """This class implements a generic part of the event handlers."""
def connectionCompleted(self, conn, new=None):
if new is None:
super(MasterHandler, self).connectionCompleted(conn)
def requestIdentification(self, conn, node_type, uuid, address, name): def requestIdentification(self, conn, node_type, uuid, address, name):
self.checkClusterName(name) self.checkClusterName(name)
app = self.app app = self.app
...@@ -61,25 +66,31 @@ class MasterHandler(EventHandler): ...@@ -61,25 +66,31 @@ class MasterHandler(EventHandler):
state = self.app.getClusterState() state = self.app.getClusterState()
conn.answer(Packets.AnswerClusterState(state)) conn.answer(Packets.AnswerClusterState(state))
def askLastIDs(self, conn): def askRecovery(self, conn):
app = self.app app = self.app
conn.answer(Packets.AnswerLastIDs( conn.answer(Packets.AnswerRecovery(
app.tm.getLastOID(),
app.tm.getLastTID(),
app.pt.getID(), app.pt.getID(),
app.backup_tid)) app.backup_tid and app.pt.getBackupTid(),
app.truncate_tid))
def askLastIDs(self, conn):
tm = self.app.tm
conn.answer(Packets.AnswerLastIDs(tm.getLastOID(), tm.getLastTID()))
def askLastTransaction(self, conn): def askLastTransaction(self, conn):
conn.answer(Packets.AnswerLastTransaction( conn.answer(Packets.AnswerLastTransaction(
self.app.getLastTransaction())) self.app.getLastTransaction()))
def askNodeInformation(self, conn): def _notifyNodeInformation(self, conn):
nm = self.app.nm nm = self.app.nm
node_list = [] node_list = []
node_list.extend(n.asTuple() for n in nm.getMasterList()) node_list.extend(n.asTuple() for n in nm.getMasterList())
node_list.extend(n.asTuple() for n in nm.getClientList()) node_list.extend(n.asTuple() for n in nm.getClientList())
node_list.extend(n.asTuple() for n in nm.getStorageList()) node_list.extend(n.asTuple() for n in nm.getStorageList())
conn.notify(Packets.NotifyNodeInformation(node_list)) conn.notify(Packets.NotifyNodeInformation(node_list))
def askNodeInformation(self, conn):
self._notifyNodeInformation(conn)
conn.answer(Packets.AnswerNodeInformation()) conn.answer(Packets.AnswerNodeInformation())
def askPartitionTable(self, conn): def askPartitionTable(self, conn):
...@@ -94,15 +105,18 @@ DISCONNECTED_STATE_DICT = { ...@@ -94,15 +105,18 @@ DISCONNECTED_STATE_DICT = {
class BaseServiceHandler(MasterHandler): class BaseServiceHandler(MasterHandler):
"""This class deals with events for a service phase.""" """This class deals with events for a service phase."""
def nodeLost(self, conn, node): def connectionCompleted(self, conn, new):
# This method provides a hook point overridable by service classes. self._notifyNodeInformation(conn)
# It is triggered when a connection to a node gets lost. pt = self.app.pt
pass conn.notify(Packets.SendPartitionTable(pt.getID(), pt.getRowList()))
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
node = self.app.nm.getByUUID(conn.getUUID()) app = self.app
node = app.nm.getByUUID(conn.getUUID())
if node is None: if node is None:
return # for example, when a storage is removed by an admin return # for example, when a storage is removed by an admin
assert node.isStorage(), node
logging.info('storage node lost')
if new_state != NodeStates.BROKEN: if new_state != NodeStates.BROKEN:
new_state = DISCONNECTED_STATE_DICT.get(node.getType(), new_state = DISCONNECTED_STATE_DICT.get(node.getType(),
NodeStates.DOWN) NodeStates.DOWN)
...@@ -117,10 +131,13 @@ class BaseServiceHandler(MasterHandler): ...@@ -117,10 +131,13 @@ class BaseServiceHandler(MasterHandler):
# was in pending state, so drop it from the node manager to forget # was in pending state, so drop it from the node manager to forget
# it and do not set in running state when it comes back # it and do not set in running state when it comes back
logging.info('drop a pending node from the node manager') logging.info('drop a pending node from the node manager')
self.app.nm.remove(node) app.nm.remove(node)
self.app.broadcastNodesInformation([node]) app.broadcastNodesInformation([node])
# clean node related data in specialized handlers if app.truncate_tid:
self.nodeLost(conn, node) raise StoppedOperation
app.broadcastPartitionChanges(app.pt.outdate(node))
if not app.pt.operational():
raise StoppedOperation
def notifyReady(self, conn): def notifyReady(self, conn):
self.app.setStorageReady(conn.getUUID()) self.app.setStorageReady(conn.getUUID())
......
...@@ -19,6 +19,7 @@ import random ...@@ -19,6 +19,7 @@ import random
from . import MasterHandler from . import MasterHandler
from ..app import StateChangedException from ..app import StateChangedException
from neo.lib import logging from neo.lib import logging
from neo.lib.exception import StoppedOperation
from neo.lib.pt import PartitionTableException from neo.lib.pt import PartitionTableException
from neo.lib.protocol import ClusterStates, Errors, \ from neo.lib.protocol import ClusterStates, Errors, \
NodeStates, NodeTypes, Packets, ProtocolError, uuid_str NodeStates, NodeTypes, Packets, ProtocolError, uuid_str
...@@ -159,6 +160,13 @@ class AdministrationHandler(MasterHandler): ...@@ -159,6 +160,13 @@ class AdministrationHandler(MasterHandler):
map(app.nm.getByUUID, uuid_list))) map(app.nm.getByUUID, uuid_list)))
conn.answer(Errors.Ack('')) conn.answer(Errors.Ack(''))
def truncate(self, conn, tid):
app = self.app
if app.cluster_state != ClusterStates.RUNNING:
raise ProtocolError('Can not truncate in this state')
conn.answer(Errors.Ack(''))
raise StoppedOperation(tid)
def checkReplicas(self, conn, partition_dict, min_tid, max_tid): def checkReplicas(self, conn, partition_dict, min_tid, max_tid):
app = self.app app = self.app
pt = app.pt pt = app.pt
......
...@@ -20,9 +20,6 @@ from . import MasterHandler ...@@ -20,9 +20,6 @@ from . import MasterHandler
class ClientServiceHandler(MasterHandler): class ClientServiceHandler(MasterHandler):
""" Handler dedicated to client during service state """ """ Handler dedicated to client during service state """
def connectionCompleted(self, conn):
pass
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
# cancel its transactions and forgot the node # cancel its transactions and forgot the node
app = self.app app = self.app
...@@ -59,9 +56,10 @@ class ClientServiceHandler(MasterHandler): ...@@ -59,9 +56,10 @@ class ClientServiceHandler(MasterHandler):
pt = app.pt pt = app.pt
# Collect partitions related to this transaction. # Collect partitions related to this transaction.
lock_oid_list = oid_list + checked_list getPartition = pt.getPartition
partition_set = set(map(pt.getPartition, lock_oid_list)) partition_set = set(map(getPartition, oid_list))
partition_set.add(pt.getPartition(ttid)) partition_set.update(map(getPartition, checked_list))
partition_set.add(getPartition(ttid))
# Collect the UUIDs of nodes related to this transaction. # Collect the UUIDs of nodes related to this transaction.
uuid_list = filter(app.isStorageReady, {cell.getUUID() uuid_list = filter(app.isStorageReady, {cell.getUUID()
...@@ -85,7 +83,6 @@ class ClientServiceHandler(MasterHandler): ...@@ -85,7 +83,6 @@ class ClientServiceHandler(MasterHandler):
{x.getUUID() for x in identified_node_list}, {x.getUUID() for x in identified_node_list},
conn.getPeerId(), conn.getPeerId(),
), ),
lock_oid_list,
) )
for node in identified_node_list: for node in identified_node_list:
node.ask(p, timeout=60) node.ask(p, timeout=60)
......
...@@ -26,7 +26,7 @@ class IdentificationHandler(MasterHandler): ...@@ -26,7 +26,7 @@ class IdentificationHandler(MasterHandler):
**kw) **kw)
handler = conn.getHandler() handler = conn.getHandler()
assert not isinstance(handler, IdentificationHandler), handler assert not isinstance(handler, IdentificationHandler), handler
handler.connectionCompleted(conn) handler.connectionCompleted(conn, True)
def _setupNode(self, conn, node_type, uuid, address, node): def _setupNode(self, conn, node_type, uuid, address, node):
app = self.app app = self.app
...@@ -72,7 +72,7 @@ class IdentificationHandler(MasterHandler): ...@@ -72,7 +72,7 @@ class IdentificationHandler(MasterHandler):
node.setState(state) node.setState(state)
node.setConnection(conn) node.setConnection(conn)
conn.setHandler(handler) conn.setHandler(handler)
app.broadcastNodesInformation([node]) app.broadcastNodesInformation([node], node)
return uuid return uuid
class SecondaryIdentificationHandler(MasterHandler): class SecondaryIdentificationHandler(MasterHandler):
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
from neo.lib import logging from neo.lib import logging
from neo.lib.protocol import CellStates, ClusterStates, Packets, ProtocolError from neo.lib.protocol import CellStates, ClusterStates, Packets, ProtocolError
from neo.lib.exception import OperationFailure from neo.lib.exception import StoppedOperation
from neo.lib.pt import PartitionTableException from neo.lib.pt import PartitionTableException
from . import BaseServiceHandler from . import BaseServiceHandler
...@@ -24,25 +24,27 @@ from . import BaseServiceHandler ...@@ -24,25 +24,27 @@ from . import BaseServiceHandler
class StorageServiceHandler(BaseServiceHandler): class StorageServiceHandler(BaseServiceHandler):
""" Handler dedicated to storages during service state """ """ Handler dedicated to storages during service state """
def connectionCompleted(self, conn): def connectionCompleted(self, conn, new):
# TODO: unit test
app = self.app app = self.app
uuid = conn.getUUID() uuid = conn.getUUID()
node = app.nm.getByUUID(uuid) node = app.nm.getByUUID(uuid)
app.setStorageNotReady(uuid) app.setStorageNotReady(uuid)
if new:
super(StorageServiceHandler, self).connectionCompleted(conn, new)
# XXX: what other values could happen ? # XXX: what other values could happen ?
if node.isRunning(): if node.isRunning():
conn.notify(Packets.StartOperation(bool(app.backup_tid))) conn.notify(Packets.StartOperation(bool(app.backup_tid)))
def nodeLost(self, conn, node): def connectionLost(self, conn, new_state):
logging.info('storage node lost')
assert not node.isRunning(), node.getState()
app = self.app app = self.app
app.broadcastPartitionChanges(app.pt.outdate(node)) node = app.nm.getByUUID(conn.getUUID())
if not app.pt.operational(): super(StorageServiceHandler, self).connectionLost(conn, new_state)
raise OperationFailure, 'cannot continue operation'
app.tm.forget(conn.getUUID()) app.tm.forget(conn.getUUID())
if app.getClusterState() == ClusterStates.BACKINGUP: if (app.getClusterState() == ClusterStates.BACKINGUP
# Also check if we're exiting, because backup_app is not usable
# in this case. Maybe cluster state should be set to something
# else, like STOPPING, during cleanup (__del__/close).
and app.listening_conn):
app.backup_app.nodeLost(node) app.backup_app.nodeLost(node)
if app.packing is not None: if app.packing is not None:
self.answerPack(conn, False) self.answerPack(conn, False)
...@@ -74,7 +76,7 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -74,7 +76,7 @@ class StorageServiceHandler(BaseServiceHandler):
CellStates.CORRUPTED)) CellStates.CORRUPTED))
self.app.broadcastPartitionChanges(change_list) self.app.broadcastPartitionChanges(change_list)
if not self.app.pt.operational(): if not self.app.pt.operational():
raise OperationFailure('cannot continue operation') raise StoppedOperation
def notifyReplicationDone(self, conn, offset, tid): def notifyReplicationDone(self, conn, offset, tid):
app = self.app app = self.app
......
...@@ -299,15 +299,19 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -299,15 +299,19 @@ class PartitionTable(neo.lib.pt.PartitionTable):
yield offset, cell yield offset, cell
break break
def getReadableCellNodeSet(self): def getOperationalNodeSet(self):
""" """
Return a set of all nodes which are part of at least one UP TO DATE Return a set of all nodes which are part of at least one UP TO DATE
partition. partition. An empty list is returned if these nodes aren't enough to
become operational.
""" """
return {cell.getNode() node_set = set()
for row in self.partition_list for row in self.partition_list:
for cell in row if not any(cell.isReadable() and cell.getNode().isPending()
if cell.isReadable()} for cell in row):
return () # not operational
node_set.update(cell.getNode() for cell in row if cell.isReadable())
return node_set
def clearReplicating(self): def clearReplicating(self):
for row in self.partition_list: for row in self.partition_list:
......
...@@ -15,9 +15,7 @@ ...@@ -15,9 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from neo.lib import logging from neo.lib import logging
from neo.lib.util import dump
from neo.lib.protocol import Packets, ProtocolError, ClusterStates, NodeStates from neo.lib.protocol import Packets, ProtocolError, ClusterStates, NodeStates
from neo.lib.protocol import ZERO_OID
from .handlers import MasterHandler from .handlers import MasterHandler
...@@ -29,7 +27,9 @@ class RecoveryManager(MasterHandler): ...@@ -29,7 +27,9 @@ class RecoveryManager(MasterHandler):
def __init__(self, app): def __init__(self, app):
# The target node's uuid to request next. # The target node's uuid to request next.
self.target_ptid = None self.target_ptid = None
self.ask_pt = []
self.backup_tid_dict = {} self.backup_tid_dict = {}
self.truncate_dict = {}
def getHandler(self): def getHandler(self):
return self return self
...@@ -51,7 +51,7 @@ class RecoveryManager(MasterHandler): ...@@ -51,7 +51,7 @@ class RecoveryManager(MasterHandler):
app = self.app app = self.app
pt = app.pt pt = app.pt
app.changeClusterState(ClusterStates.RECOVERING) app.changeClusterState(ClusterStates.RECOVERING)
pt.setID(None) pt.clear()
# collect the last partition table available # collect the last partition table available
poll = app.em.poll poll = app.em.poll
...@@ -60,10 +60,14 @@ class RecoveryManager(MasterHandler): ...@@ -60,10 +60,14 @@ class RecoveryManager(MasterHandler):
if pt.filled(): if pt.filled():
# A partition table exists, we are starting an existing # A partition table exists, we are starting an existing
# cluster. # cluster.
node_list = pt.getReadableCellNodeSet() node_list = pt.getOperationalNodeSet()
if app._startup_allowed: if app._startup_allowed:
node_list = [node for node in node_list if node.isPending()] node_list = [node for node in node_list if node.isPending()]
elif not all(node.isPending() for node in node_list): elif node_list:
# we want all nodes to be there if we're going to truncate
if app.truncate_tid:
node_list = pt.getNodeSet()
if not all(node.isPending() for node in node_list):
continue continue
elif app._startup_allowed or app.autostart: elif app._startup_allowed or app.autostart:
# No partition table and admin allowed startup, we are # No partition table and admin allowed startup, we are
...@@ -76,6 +80,17 @@ class RecoveryManager(MasterHandler): ...@@ -76,6 +80,17 @@ class RecoveryManager(MasterHandler):
if node_list and not any(node.getConnection().isPending() if node_list and not any(node.getConnection().isPending()
for node in node_list): for node in node_list):
if pt.filled(): if pt.filled():
if app.truncate_tid:
node_list = app.nm.getIdentifiedList(pool_set={uuid
for uuid, tid in self.truncate_dict.iteritems()
if not tid or app.truncate_tid < tid})
if node_list:
truncate = Packets.Truncate(app.truncate_tid)
for node in node_list:
conn = node.getConnection()
conn.notify(truncate)
self.connectionCompleted(conn, False)
continue
node_list = pt.getConnectedNodeList() node_list = pt.getConnectedNodeList()
break break
...@@ -88,64 +103,81 @@ class RecoveryManager(MasterHandler): ...@@ -88,64 +103,81 @@ class RecoveryManager(MasterHandler):
if pt.getID() is None: if pt.getID() is None:
logging.info('creating a new partition table') logging.info('creating a new partition table')
# reset IDs generators & build new partition with running nodes
app.tm.setLastOID(ZERO_OID)
pt.make(node_list) pt.make(node_list)
self._broadcastPartitionTable(pt.getID(), pt.getRowList()) self._notifyAdmins(Packets.SendPartitionTable(
elif app.backup_tid: pt.getID(), pt.getRowList()))
else:
cell_list = pt.outdate()
if cell_list:
self._notifyAdmins(Packets.NotifyPartitionChanges(
pt.setNextID(), cell_list))
if app.backup_tid:
pt.setBackupTidDict(self.backup_tid_dict) pt.setBackupTidDict(self.backup_tid_dict)
app.backup_tid = pt.getBackupTid() app.backup_tid = pt.getBackupTid()
app.setLastTransaction(app.tm.getLastTID()) logging.debug('cluster starts this partition table:')
logging.debug('cluster starts with loid=%s and this partition table :',
dump(app.tm.getLastOID()))
pt.log() pt.log()
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
node = self.app.nm.getByUUID(conn.getUUID()) uuid = conn.getUUID()
assert node is not None self.backup_tid_dict.pop(uuid, None)
self.truncate_dict.pop(uuid, None)
node = self.app.nm.getByUUID(uuid)
try:
i = self.ask_pt.index(uuid)
except ValueError:
pass
else:
del self.ask_pt[i]
if not i:
if self.ask_pt:
self.app.nm.getByUUID(self.ask_pt[0]) \
.ask(Packets.AskPartitionTable())
else:
logging.warning("Waiting for %r to come back."
" No other node has version %s of the partition table.",
node, self.target_ptid)
if node.getState() == new_state: if node.getState() == new_state:
return return
node.setState(new_state) node.setState(new_state)
# broadcast to all so that admin nodes gets informed # broadcast to all so that admin nodes gets informed
self.app.broadcastNodesInformation([node]) self.app.broadcastNodesInformation([node])
def connectionCompleted(self, conn): def connectionCompleted(self, conn, new):
# ask the last IDs to perform the recovery # ask the last IDs to perform the recovery
conn.ask(Packets.AskLastIDs()) conn.ask(Packets.AskRecovery())
def answerLastIDs(self, conn, loid, ltid, lptid, backup_tid): def answerRecovery(self, conn, ptid, backup_tid, truncate_tid):
# Get max values. uuid = conn.getUUID()
if loid is not None: if self.target_ptid <= ptid:
self.app.tm.setLastOID(loid) # Maybe a newer partition table.
if ltid is not None: if self.target_ptid == ptid and self.ask_pt:
self.app.tm.setLastTID(ltid) # Another node is already asked.
if lptid > self.target_ptid: self.ask_pt.append(uuid)
# something newer elif self.target_ptid < ptid or self.ask_pt is not ():
self.target_ptid = lptid # No node asked yet for the newest partition table.
self.target_ptid = ptid
self.ask_pt = [uuid]
conn.ask(Packets.AskPartitionTable()) conn.ask(Packets.AskPartitionTable())
self.backup_tid_dict[conn.getUUID()] = backup_tid self.backup_tid_dict[uuid] = backup_tid
self.truncate_dict[uuid] = truncate_tid
def answerPartitionTable(self, conn, ptid, row_list): def answerPartitionTable(self, conn, ptid, row_list):
if ptid != self.target_ptid:
# If this is not from a target node, ignore it. # If this is not from a target node, ignore it.
logging.warn('Got %s while waiting %s', dump(ptid), if ptid == self.target_ptid:
dump(self.target_ptid)) app = self.app
else:
self._broadcastPartitionTable(ptid, row_list)
self.app.backup_tid = self.backup_tid_dict[conn.getUUID()]
def _broadcastPartitionTable(self, ptid, row_list):
try: try:
new_nodes = self.app.pt.load(ptid, row_list, self.app.nm) new_nodes = app.pt.load(ptid, row_list, app.nm)
except IndexError: except IndexError:
raise ProtocolError('Invalid offset') raise ProtocolError('Invalid offset')
else: self._notifyAdmins(Packets.NotifyNodeInformation(new_nodes),
notification = Packets.NotifyNodeInformation(new_nodes) Packets.SendPartitionTable(ptid, row_list))
ptid = self.app.pt.getID() self.ask_pt = ()
row_list = self.app.pt.getRowList() uuid = conn.getUUID()
partition_table = Packets.SendPartitionTable(ptid, row_list) app.backup_tid = self.backup_tid_dict[uuid]
# notify the admin nodes app.truncate_tid = self.truncate_dict[uuid]
def _notifyAdmins(self, *packets):
for node in self.app.nm.getAdminList(only_identified=True): for node in self.app.nm.getAdminList(only_identified=True):
node.notify(notification) for packet in packets:
node.notify(partition_table) node.notify(packet)
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
from time import time from time import time
from struct import pack, unpack from struct import pack, unpack
from neo.lib import logging from neo.lib import logging
from neo.lib.protocol import ProtocolError, uuid_str, ZERO_TID from neo.lib.protocol import ProtocolError, uuid_str, ZERO_OID, ZERO_TID
from neo.lib.util import dump, u64, addTID, tidFromTime from neo.lib.util import dump, u64, addTID, tidFromTime
class DelayedError(Exception): class DelayedError(Exception):
...@@ -155,15 +155,18 @@ class TransactionManager(object): ...@@ -155,15 +155,18 @@ class TransactionManager(object):
""" """
Manage current transactions Manage current transactions
""" """
_last_tid = ZERO_TID
def __init__(self, on_commit): def __init__(self, on_commit):
self._on_commit = on_commit
self.reset()
def reset(self):
# ttid -> transaction # ttid -> transaction
self._ttid_dict = {} self._ttid_dict = {}
# node -> transactions mapping # node -> transactions mapping
self._node_dict = {} self._node_dict = {}
self._last_oid = None self._last_oid = ZERO_OID
self._on_commit = on_commit self._last_tid = ZERO_TID
# queue filled with ttids pointing to transactions with increasing tids # queue filled with ttids pointing to transactions with increasing tids
self._queue = [] self._queue = []
...@@ -182,8 +185,6 @@ class TransactionManager(object): ...@@ -182,8 +185,6 @@ class TransactionManager(object):
def getNextOIDList(self, num_oids): def getNextOIDList(self, num_oids):
""" Generate a new OID list """ """ Generate a new OID list """
if self._last_oid is None:
raise RuntimeError, 'I do not know the last OID'
oid = unpack('!Q', self._last_oid)[0] + 1 oid = unpack('!Q', self._last_oid)[0] + 1
oid_list = [pack('!Q', oid + i) for i in xrange(num_oids)] oid_list = [pack('!Q', oid + i) for i in xrange(num_oids)]
self._last_oid = oid_list[-1] self._last_oid = oid_list[-1]
...@@ -249,14 +250,6 @@ class TransactionManager(object): ...@@ -249,14 +250,6 @@ class TransactionManager(object):
if self._last_tid < tid: if self._last_tid < tid:
self._last_tid = tid self._last_tid = tid
def reset(self):
"""
Discard all manager content
This doesn't reset the last TID.
"""
self._ttid_dict = {}
self._node_dict = {}
def hasPending(self): def hasPending(self):
""" """
Returns True if some transactions are pending Returns True if some transactions are pending
...@@ -359,7 +352,12 @@ class TransactionManager(object): ...@@ -359,7 +352,12 @@ class TransactionManager(object):
self._unlockPending() self._unlockPending()
def _unlockPending(self): def _unlockPending(self):
# unlock pending transactions """Serialize transaction unlocks
This should rarely delay unlocks since the time needed to lock a
transaction is roughly constant. The most common case where reordering
is required is when some storages are already busy by other tasks.
"""
queue = self._queue queue = self._queue
pop = queue.pop pop = queue.pop
insert = queue.insert insert = queue.insert
......
...@@ -14,48 +14,29 @@ ...@@ -14,48 +14,29 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from collections import defaultdict
from neo.lib import logging from neo.lib import logging
from neo.lib.util import dump from neo.lib.util import dump
from neo.lib.protocol import ClusterStates, Packets, NodeStates from neo.lib.protocol import ClusterStates, Packets, NodeStates
from .handlers import BaseServiceHandler from .handlers import BaseServiceHandler
class VerificationFailure(Exception):
"""
Exception raised each time the cluster integrity failed.
- An required storage node is missing
- A transaction or an object is missing on a node
"""
pass
class VerificationManager(BaseServiceHandler): class VerificationManager(BaseServiceHandler):
"""
Manager for verification step of a NEO cluster:
- Wait for at least one available storage per partition
- Check if all expected content is present
"""
def __init__(self, app): def __init__(self, app):
self._oid_set = set() self._locked_dict = {}
self._tid_set = set() self._voted_dict = defaultdict(set)
self._uuid_set = set() self._uuid_set = set()
self._object_present = False
def _askStorageNodesAndWait(self, packet, node_list): def _askStorageNodesAndWait(self, packet, node_list):
poll = self.app.em.poll poll = self.app.em.poll
operational = self.app.pt.operational
uuid_set = self._uuid_set uuid_set = self._uuid_set
uuid_set.clear() uuid_set.clear()
for node in node_list: for node in node_list:
uuid_set.add(node.getUUID()) uuid_set.add(node.getUUID())
node.ask(packet) node.ask(packet)
while True: while uuid_set:
poll(1) poll(1)
if not operational():
raise VerificationFailure
if not uuid_set:
break
def getHandler(self): def getHandler(self):
return self return self
...@@ -76,135 +57,80 @@ class VerificationManager(BaseServiceHandler): ...@@ -76,135 +57,80 @@ class VerificationManager(BaseServiceHandler):
return state, self return state, self
def run(self): def run(self):
self.app.changeClusterState(ClusterStates.VERIFYING) app = self.app
while True: app.changeClusterState(ClusterStates.VERIFYING)
try: app.tm.reset()
if not app.backup_tid:
self.verifyData() self.verifyData()
except VerificationFailure: # This is where storages truncate if requested:
continue # - we make sure all nodes are running with a truncate_tid value saved
break # - there's no unfinished data
# At this stage, all non-working nodes are out-of-date. # - just before they return the last tid/oid
self.app.broadcastPartitionChanges(self.app.pt.outdate()) self._askStorageNodesAndWait(Packets.AskLastIDs(),
[x for x in app.nm.getIdentifiedList() if x.isStorage()])
app.setLastTransaction(app.tm.getLastTID())
# Just to not return meaningless information in AnswerRecovery.
app.truncate_tid = None
def verifyData(self): def verifyData(self):
"""Verify the data in storage nodes and clean them up, if necessary."""
app = self.app app = self.app
# wait for any missing node
logging.debug('waiting for the cluster to be operational')
while not app.pt.operational():
app.em.poll(1)
if app.backup_tid:
return
logging.info('start to verify data') logging.info('start to verify data')
getIdentifiedList = app.nm.getIdentifiedList getIdentifiedList = app.nm.getIdentifiedList
# Gather all unfinished transactions. # Gather all transactions that may have been partially finished.
self._askStorageNodesAndWait(Packets.AskUnfinishedTransactions(), self._askStorageNodesAndWait(Packets.AskLockedTransactions(),
[x for x in getIdentifiedList() if x.isStorage()]) [x for x in getIdentifiedList() if x.isStorage()])
# Gather OIDs for each unfinished TID, and verify whether the # Some nodes may have already unlocked these transactions and
# transaction can be finished or must be aborted. This could be # _locked_dict is incomplete, but we can ask them the final tid.
# in parallel in theory, but not so easy. Thus do it one-by-one for ttid, voted_set in self._voted_dict.iteritems():
# at the moment. if ttid in self._locked_dict:
for tid in self._tid_set: continue
uuid_set = self.verifyTransaction(tid) partition = app.pt.getPartition(ttid)
if uuid_set is None: for node in getIdentifiedList(pool_set={cell.getUUID()
packet = Packets.DeleteTransaction(tid, self._oid_set or []) # If an outdated cell had unlocked ttid, then either
# Make sure that no node has this transaction. # it is already in _locked_dict or a readable cell also
for node in getIdentifiedList(): # unlocked it.
if node.isStorage(): for cell in app.pt.getCellList(partition, readable=True)
node.notify(packet) } - voted_set):
self._askStorageNodesAndWait(Packets.AskFinalTID(ttid), (node,))
if self._tid is not None:
self._locked_dict[ttid] = self._tid
break
else: else:
if app.getLastTransaction() < tid: # XXX: refactoring needed # Transaction not locked. No need to tell nodes to delete it,
app.setLastTransaction(tid) # since they drop any unfinished data just before being
app.tm.setLastTID(tid) # operational.
packet = Packets.CommitTransaction(tid) pass
# Finish all transactions for which we know that tpc_finish was called
# but not fully processed. This may include replicas with transactions
# that were not even locked.
for ttid, tid in self._locked_dict.iteritems():
uuid_set = self._voted_dict.get(ttid)
if uuid_set:
packet = Packets.ValidateTransaction(ttid, tid)
for node in getIdentifiedList(pool_set=uuid_set): for node in getIdentifiedList(pool_set=uuid_set):
node.notify(packet) node.notify(packet)
self._oid_set = set()
# If possible, send the packets now.
app.em.poll(0)
def verifyTransaction(self, tid):
nm = self.app.nm
uuid_set = set()
# Determine to which nodes I should ask.
partition = self.app.pt.getPartition(tid)
uuid_list = [cell.getUUID() for cell \
in self.app.pt.getCellList(partition, readable=True)]
if len(uuid_list) == 0:
raise VerificationFailure
uuid_set.update(uuid_list)
# Gather OIDs.
node_list = self.app.nm.getIdentifiedList(pool_set=uuid_list)
if len(node_list) == 0:
raise VerificationFailure
self._askStorageNodesAndWait(Packets.AskTransactionInformation(tid),
node_list)
if self._oid_set is None or len(self._oid_set) == 0:
# Not commitable.
return None
# Verify that all objects are present.
for oid in self._oid_set:
partition = self.app.pt.getPartition(oid)
object_uuid_list = [cell.getUUID() for cell \
in self.app.pt.getCellList(partition, readable=True)]
if len(object_uuid_list) == 0:
raise VerificationFailure
uuid_set.update(object_uuid_list)
self._object_present = True
self._askStorageNodesAndWait(Packets.AskObjectPresent(oid, tid),
nm.getIdentifiedList(pool_set=object_uuid_list))
if not self._object_present:
# Not commitable.
return None
return uuid_set
def answerUnfinishedTransactions(self, conn, max_tid, tid_list):
logging.info('got unfinished transactions %s from %r',
map(dump, tid_list), conn)
self._uuid_set.remove(conn.getUUID())
self._tid_set.update(tid_list)
def answerTransactionInformation(self, conn, tid, def answerLastIDs(self, conn, loid, ltid):
user, desc, ext, packed, oid_list):
self._uuid_set.remove(conn.getUUID()) self._uuid_set.remove(conn.getUUID())
oid_set = set(oid_list) tm = self.app.tm
if self._oid_set is None: tm.setLastOID(loid)
# Someone does not agree. tm.setLastTID(ltid)
pass
elif len(self._oid_set) == 0: def answerLockedTransactions(self, conn, tid_dict):
# This is the first answer. uuid = conn.getUUID()
self._oid_set.update(oid_set) self._uuid_set.remove(uuid)
elif self._oid_set != oid_set: for ttid, tid in tid_dict.iteritems():
raise ValueError, "Inconsistent transaction %s" % \ if tid:
(dump(tid, )) self._locked_dict[ttid] = tid
self._voted_dict[ttid].add(uuid)
def tidNotFound(self, conn, message):
logging.info('TID not found: %s', message) def answerFinalTID(self, conn, tid):
self._uuid_set.remove(conn.getUUID()) self._uuid_set.remove(conn.getUUID())
self._oid_set = None self._tid = tid
def answerObjectPresent(self, conn, oid, tid):
logging.info('object %s:%s found', dump(oid), dump(tid))
self._uuid_set.remove(conn.getUUID())
def oidNotFound(self, conn, message):
logging.info('OID not found: %s', message)
self._uuid_set.remove(conn.getUUID())
self._object_present = False
def connectionCompleted(self, conn):
pass
def nodeLost(self, conn, node):
if not self.app.pt.operational():
raise VerificationFailure, 'cannot continue verification'
def connectionLost(self, conn, new_state):
self._uuid_set.discard(conn.getUUID())
super(VerificationManager, self).connectionLost(conn, new_state)
...@@ -37,6 +37,7 @@ action_dict = { ...@@ -37,6 +37,7 @@ action_dict = {
'tweak': 'tweakPartitionTable', 'tweak': 'tweakPartitionTable',
'drop': 'dropNode', 'drop': 'dropNode',
'kill': 'killNode', 'kill': 'killNode',
'truncate': 'truncate',
} }
uuid_int = (lambda ns: lambda uuid: uuid_int = (lambda ns: lambda uuid:
...@@ -85,11 +86,14 @@ class TerminalNeoCTL(object): ...@@ -85,11 +86,14 @@ class TerminalNeoCTL(object):
Get last ids. Get last ids.
""" """
assert not params assert not params
r = self.neoctl.getLastIds() ptid, backup_tid, truncate_tid = self.neoctl.getRecovery()
if r[3]: if backup_tid:
return "last_tid = 0x%x" % u64(self.neoctl.getLastTransaction()) ltid = self.neoctl.getLastTransaction()
return "last_oid = 0x%x\nlast_tid = 0x%x\nlast_ptid = %u" % ( r = "backup_tid = 0x%x" % u64(backup_tid)
u64(r[0]), u64(r[1]), r[2]) else:
loid, ltid = self.neoctl.getLastIds()
r = "last_oid = 0x%x" % u64(loid)
return r + "\nlast_tid = 0x%x\nlast_ptid = %u" % (u64(ltid), ptid)
def getPartitionRowList(self, params): def getPartitionRowList(self, params):
""" """
...@@ -193,6 +197,19 @@ class TerminalNeoCTL(object): ...@@ -193,6 +197,19 @@ class TerminalNeoCTL(object):
""" """
return uuid_str(self.neoctl.getPrimary()) return uuid_str(self.neoctl.getPrimary())
def truncate(self, params):
"""
Truncate the database at the given tid.
The cluster must be in RUNNING state, without any pending transaction.
This causes the cluster to go back in RECOVERING state, waiting all
nodes to be pending (do not use 'start' command unless you're sure
the missing nodes don't need to be truncated).
Parameters: tid
"""
self.neoctl.truncate(self.asTID(*params))
def checkReplicas(self, params): def checkReplicas(self, params):
""" """
Test whether partitions have corrupted metadata Test whether partitions have corrupted metadata
......
...@@ -61,3 +61,4 @@ class CommandEventHandler(EventHandler): ...@@ -61,3 +61,4 @@ class CommandEventHandler(EventHandler):
answerPrimary = __answer(Packets.AnswerPrimary) answerPrimary = __answer(Packets.AnswerPrimary)
answerLastIDs = __answer(Packets.AnswerLastIDs) answerLastIDs = __answer(Packets.AnswerLastIDs)
answerLastTransaction = __answer(Packets.AnswerLastTransaction) answerLastTransaction = __answer(Packets.AnswerLastTransaction)
answerRecovery = __answer(Packets.AnswerRecovery)
...@@ -120,6 +120,12 @@ class NeoCTL(BaseApplication): ...@@ -120,6 +120,12 @@ class NeoCTL(BaseApplication):
raise RuntimeError(response) raise RuntimeError(response)
return response[1] return response[1]
def getRecovery(self):
response = self.__ask(Packets.AskRecovery())
if response[0] != Packets.AnswerRecovery:
raise RuntimeError(response)
return response[1:]
def getNodeList(self, node_type=None): def getNodeList(self, node_type=None):
""" """
Get a list of nodes, filtering with given type. Get a list of nodes, filtering with given type.
...@@ -163,6 +169,12 @@ class NeoCTL(BaseApplication): ...@@ -163,6 +169,12 @@ class NeoCTL(BaseApplication):
raise RuntimeError(response) raise RuntimeError(response)
return response[1] return response[1]
def truncate(self, tid):
response = self.__ask(Packets.Truncate(tid))
if response[0] != Packets.Error or response[1] != ErrorCodes.ACK:
raise RuntimeError(response)
return response[2]
def checkReplicas(self, *args): def checkReplicas(self, *args):
response = self.__ask(Packets.CheckReplicas(*args)) response = self.__ask(Packets.CheckReplicas(*args))
if response[0] != Packets.Error or response[1] != ErrorCodes.ACK: if response[0] != Packets.Error or response[1] != ErrorCodes.ACK:
......
...@@ -53,7 +53,6 @@ UNIT_TEST_MODULES = [ ...@@ -53,7 +53,6 @@ UNIT_TEST_MODULES = [
'neo.tests.storage.testMasterHandler', 'neo.tests.storage.testMasterHandler',
'neo.tests.storage.testStorageApp', 'neo.tests.storage.testStorageApp',
'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'), 'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
'neo.tests.storage.testVerificationHandler',
'neo.tests.storage.testIdentificationHandler', 'neo.tests.storage.testIdentificationHandler',
'neo.tests.storage.testTransactions', 'neo.tests.storage.testTransactions',
# client application # client application
......
...@@ -23,14 +23,14 @@ from neo.lib.protocol import uuid_str, \ ...@@ -23,14 +23,14 @@ from neo.lib.protocol import uuid_str, \
CellStates, ClusterStates, NodeTypes, Packets CellStates, ClusterStates, NodeTypes, Packets
from neo.lib.node import NodeManager from neo.lib.node import NodeManager
from neo.lib.connection import ListeningConnection from neo.lib.connection import ListeningConnection
from neo.lib.exception import OperationFailure, PrimaryFailure from neo.lib.exception import StoppedOperation, PrimaryFailure
from neo.lib.pt import PartitionTable from neo.lib.pt import PartitionTable
from neo.lib.util import dump from neo.lib.util import dump
from neo.lib.bootstrap import BootstrapManager from neo.lib.bootstrap import BootstrapManager
from .checker import Checker from .checker import Checker
from .database import buildDatabaseManager from .database import buildDatabaseManager
from .exception import AlreadyPendingError from .exception import AlreadyPendingError
from .handlers import identification, verification, initialization from .handlers import identification, initialization
from .handlers import master, hidden from .handlers import master, hidden
from .replicator import Replicator from .replicator import Replicator
from .transactions import TransactionManager from .transactions import TransactionManager
...@@ -193,14 +193,11 @@ class Application(BaseApplication): ...@@ -193,14 +193,11 @@ class Application(BaseApplication):
self.event_queue = deque() self.event_queue = deque()
self.event_queue_dict = {} self.event_queue_dict = {}
try: try:
self.verifyData()
self.initialize() self.initialize()
self.doOperation() self.doOperation()
raise RuntimeError, 'should not reach here' raise RuntimeError, 'should not reach here'
except OperationFailure, msg: except StoppedOperation, msg:
logging.error('operation stopped: %s', msg) logging.error('operation stopped: %s', msg)
if self.cluster_state == ClusterStates.STOPPING_BACKUP:
self.dm.setBackupTID(None)
except PrimaryFailure, msg: except PrimaryFailure, msg:
logging.error('primary master is down: %s', msg) logging.error('primary master is down: %s', msg)
finally: finally:
...@@ -247,30 +244,11 @@ class Application(BaseApplication): ...@@ -247,30 +244,11 @@ class Application(BaseApplication):
self.pt = PartitionTable(num_partitions, num_replicas) self.pt = PartitionTable(num_partitions, num_replicas)
self.loadPartitionTable() self.loadPartitionTable()
def verifyData(self):
"""Verify data under the control by a primary master node.
Connections from client nodes may not be accepted at this stage."""
logging.info('verifying data')
handler = verification.VerificationHandler(self)
self.master_conn.setHandler(handler)
_poll = self._poll
while not self.operational:
_poll()
def initialize(self): def initialize(self):
""" Retreive partition table and node informations from the primary """
logging.debug('initializing...') logging.debug('initializing...')
_poll = self._poll _poll = self._poll
handler = initialization.InitializationHandler(self) self.master_conn.setHandler(initialization.InitializationHandler(self))
self.master_conn.setHandler(handler) while not self.operational:
# ask node list and partition table
self.pt.clear()
self.master_conn.ask(Packets.AskNodeInformation())
self.master_conn.ask(Packets.AskPartitionTable())
while self.master_conn.isPending():
_poll() _poll()
self.ready = True self.ready = True
self.replicator.populate() self.replicator.populate()
......
...@@ -297,8 +297,9 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -297,8 +297,9 @@ class ImporterDatabaseManager(DatabaseManager):
self.db = buildDatabaseManager(main['adapter'], self.db = buildDatabaseManager(main['adapter'],
(main['database'], main.get('engine'), main['wait'])) (main['database'], main.get('engine'), main['wait']))
for x in """query erase getConfiguration _setConfiguration for x in """query erase getConfiguration _setConfiguration
getPartitionTable changePartitionTable getUnfinishedTIDList getPartitionTable changePartitionTable
dropUnfinishedData storeTransaction finishTransaction getUnfinishedTIDDict dropUnfinishedData abortTransaction
storeTransaction lockTransaction unlockTransaction
storeData _pruneData storeData _pruneData
""".split(): """.split():
setattr(self, x, getattr(self.db, x)) setattr(self, x, getattr(self.db, x))
...@@ -421,7 +422,7 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -421,7 +422,7 @@ class ImporterDatabaseManager(DatabaseManager):
logging.warning("All data are imported. You should change" logging.warning("All data are imported. You should change"
" your configuration to use the native backend and restart.") " your configuration to use the native backend and restart.")
self._import = None self._import = None
for x in """getObject objectPresent getReplicationTIDList for x in """getObject getReplicationTIDList
""".split(): """.split():
setattr(self, x, getattr(self.db, x)) setattr(self, x, getattr(self.db, x))
...@@ -434,23 +435,11 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -434,23 +435,11 @@ class ImporterDatabaseManager(DatabaseManager):
zodb = self.zodb[bisect(self.zodb_index, oid) - 1] zodb = self.zodb[bisect(self.zodb_index, oid) - 1]
return zodb, oid - zodb.shift_oid return zodb, oid - zodb.shift_oid
def getLastIDs(self, all=True): def getLastIDs(self):
tid, _, _, oid = self.db.getLastIDs(all) tid, _, _, oid = self.db.getLastIDs()
return (max(tid, util.p64(self.zodb_ltid)), None, None, return (max(tid, util.p64(self.zodb_ltid)), None, None,
max(oid, util.p64(self.zodb_loid))) max(oid, util.p64(self.zodb_loid)))
def objectPresent(self, oid, tid, all=True):
r = self.db.objectPresent(oid, tid, all)
if not r:
u_oid = util.u64(oid)
u_tid = util.u64(tid)
if self.inZodb(u_oid, u_tid):
zodb, oid = self.zodbFromOid(u_oid)
try:
return zodb.loadSerial(util.p64(oid), tid)
except POSKeyError:
pass
def getObject(self, oid, tid=None, before_tid=None): def getObject(self, oid, tid=None, before_tid=None):
u64 = util.u64 u64 = util.u64
u_oid = u64(oid) u_oid = u64(oid)
...@@ -511,6 +500,16 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -511,6 +500,16 @@ class ImporterDatabaseManager(DatabaseManager):
else: else:
return self.db.getTransaction(tid, all) return self.db.getTransaction(tid, all)
def getFinalTID(self, ttid):
if u64(ttid) <= self.zodb_ltid and self._import:
raise NotImplementedError
return self.db.getFinalTID(ttid)
def deleteTransaction(self, tid):
if u64(tid) <= self.zodb_ltid and self._import:
raise NotImplementedError
self.db.deleteTransaction(tid)
def getReplicationTIDList(self, min_tid, max_tid, length, partition): def getReplicationTIDList(self, min_tid, max_tid, length, partition):
p64 = util.p64 p64 = util.p64
tid = p64(self.zodb_tid) tid = p64(self.zodb_tid)
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
from collections import defaultdict from collections import defaultdict
from functools import wraps from functools import wraps
from neo.lib import logging, util from neo.lib import logging, util
from neo.lib.exception import DatabaseFailure
from neo.lib.protocol import ZERO_TID, BackendNotImplemented from neo.lib.protocol import ZERO_TID, BackendNotImplemented
def lazymethod(func): def lazymethod(func):
...@@ -94,6 +95,22 @@ class DatabaseManager(object): ...@@ -94,6 +95,22 @@ class DatabaseManager(object):
""" """
raise NotImplementedError raise NotImplementedError
def nonempty(self, table):
"""Check whether table is empty or return None if it does not exist"""
raise NotImplementedError
def _checkNoUnfinishedTransactions(self, *hint):
if self.nonempty('ttrans') or self.nonempty('tobj'):
raise DatabaseFailure(
"The database can not be upgraded because you have unfinished"
" transactions. Use an older version of NEO to verify them.")
def _getVersion(self):
version = int(self.getConfiguration("version") or 0)
if self.VERSION < version:
raise DatabaseFailure("The database can not be downgraded.")
return version
def doOperation(self, app): def doOperation(self, app):
pass pass
...@@ -194,10 +211,18 @@ class DatabaseManager(object): ...@@ -194,10 +211,18 @@ class DatabaseManager(object):
def getBackupTID(self): def getBackupTID(self):
return util.bin(self.getConfiguration('backup_tid')) return util.bin(self.getConfiguration('backup_tid'))
def setBackupTID(self, backup_tid): def _setBackupTID(self, tid):
tid = util.dump(backup_tid) tid = util.dump(tid)
logging.debug('backup_tid = %s', tid) logging.debug('backup_tid = %s', tid)
return self.setConfiguration('backup_tid', tid) return self._setConfiguration('backup_tid', tid)
def getTruncateTID(self):
return util.bin(self.getConfiguration('truncate_tid'))
def _setTruncateTID(self, tid):
tid = util.dump(tid)
logging.debug('truncate_tid = %s', tid)
return self._setConfiguration('truncate_tid', tid)
def _setPackTID(self, tid): def _setPackTID(self, tid):
self._setConfiguration('_pack_tid', tid) self._setConfiguration('_pack_tid', tid)
...@@ -222,10 +247,10 @@ class DatabaseManager(object): ...@@ -222,10 +247,10 @@ class DatabaseManager(object):
""" """
raise NotImplementedError raise NotImplementedError
def _getLastIDs(self, all=True): def _getLastIDs(self):
raise NotImplementedError raise NotImplementedError
def getLastIDs(self, all=True): def getLastIDs(self):
trans, obj, oid = self._getLastIDs() trans, obj, oid = self._getLastIDs()
if trans: if trans:
tid = max(trans.itervalues()) tid = max(trans.itervalues())
...@@ -241,16 +266,16 @@ class DatabaseManager(object): ...@@ -241,16 +266,16 @@ class DatabaseManager(object):
trans = obj = {} trans = obj = {}
return tid, trans, obj, oid return tid, trans, obj, oid
def getUnfinishedTIDList(self): def _getUnfinishedTIDDict(self):
"""Return a list of unfinished transaction's IDs."""
raise NotImplementedError raise NotImplementedError
def objectPresent(self, oid, tid, all = True): def getUnfinishedTIDDict(self):
"""Return true iff an object specified by a given pair of an trans, obj = self._getUnfinishedTIDDict()
object ID and a transaction ID is present in a database. obj = dict.fromkeys(obj)
Otherwise, return false. If all is true, the object must be obj.update(trans)
searched from unfinished transactions as well.""" p64 = util.p64
raise NotImplementedError return {p64(ttid): None if tid is None else p64(tid)
for ttid, tid in obj.iteritems()}
@fallback @fallback
def getLastObjectTID(self, oid): def getLastObjectTID(self, oid):
...@@ -478,14 +503,18 @@ class DatabaseManager(object): ...@@ -478,14 +503,18 @@ class DatabaseManager(object):
data_tid = p64(data_tid) data_tid = p64(data_tid)
return p64(current_tid), data_tid, is_current return p64(current_tid), data_tid, is_current
def finishTransaction(self, tid): def lockTransaction(self, tid, ttid):
"""Finish a transaction specified by a given ID, by moving """Mark voted transaction 'ttid' as committed with given 'tid'"""
temporarily data to a finished area.""" raise NotImplementedError
def unlockTransaction(self, tid, ttid):
"""Finalize a transaction by moving data to a finished area."""
raise NotImplementedError
def abortTransaction(self, ttid):
raise NotImplementedError raise NotImplementedError
def deleteTransaction(self, tid, oid_list=()): def deleteTransaction(self, tid):
"""Delete a transaction and its content specified by a given ID and
an oid list"""
raise NotImplementedError raise NotImplementedError
def deleteObject(self, oid, serial=None): def deleteObject(self, oid, serial=None):
...@@ -498,12 +527,13 @@ class DatabaseManager(object): ...@@ -498,12 +527,13 @@ class DatabaseManager(object):
and max_tid (included)""" and max_tid (included)"""
raise NotImplementedError raise NotImplementedError
def truncate(self, tid): def truncate(self):
assert tid not in (None, ZERO_TID), tid tid = self.getTruncateTID()
assert self.getBackupTID() if tid:
self.setBackupTID(None) # XXX assert tid != ZERO_TID, tid
for partition in xrange(self.getNumPartitions()): for partition in xrange(self.getNumPartitions()):
self._deleteRange(partition, tid) self._deleteRange(partition, tid)
self._setTruncateTID(None)
self.commit() self.commit()
def getTransaction(self, tid, all = False): def getTransaction(self, tid, all = False):
......
...@@ -16,9 +16,10 @@ ...@@ -16,9 +16,10 @@
from binascii import a2b_hex from binascii import a2b_hex
import MySQLdb import MySQLdb
from MySQLdb import DataError, IntegrityError, OperationalError from MySQLdb import DataError, IntegrityError, \
OperationalError, ProgrammingError
from MySQLdb.constants.CR import SERVER_GONE_ERROR, SERVER_LOST from MySQLdb.constants.CR import SERVER_GONE_ERROR, SERVER_LOST
from MySQLdb.constants.ER import DATA_TOO_LONG, DUP_ENTRY from MySQLdb.constants.ER import DATA_TOO_LONG, DUP_ENTRY, NO_SUCH_TABLE
from array import array from array import array
from hashlib import sha1 from hashlib import sha1
import os import os
...@@ -42,6 +43,7 @@ def getPrintableQuery(query, max=70): ...@@ -42,6 +43,7 @@ def getPrintableQuery(query, max=70):
class MySQLDatabaseManager(DatabaseManager): class MySQLDatabaseManager(DatabaseManager):
"""This class manages a database on MySQL.""" """This class manages a database on MySQL."""
VERSION = 1
ENGINES = "InnoDB", "TokuDB" ENGINES = "InnoDB", "TokuDB"
_engine = ENGINES[0] # default engine _engine = ENGINES[0] # default engine
...@@ -144,16 +146,33 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -144,16 +146,33 @@ class MySQLDatabaseManager(DatabaseManager):
self.query("DROP TABLE IF EXISTS" self.query("DROP TABLE IF EXISTS"
" config, pt, trans, obj, data, bigdata, ttrans, tobj") " config, pt, trans, obj, data, bigdata, ttrans, tobj")
def nonempty(self, table):
try:
return bool(self.query("SELECT 1 FROM %s LIMIT 1" % table))
except ProgrammingError, (code, _):
if code != NO_SUCH_TABLE:
raise
def _setup(self): def _setup(self):
self._config.clear() self._config.clear()
q = self.query q = self.query
p = engine = self._engine p = engine = self._engine
# The table "config" stores configuration parameters which affect the
# persistent data. if self.nonempty("config") is None:
q("""CREATE TABLE IF NOT EXISTS config ( # The table "config" stores configuration
# parameters which affect the persistent data.
q("""CREATE TABLE config (
name VARBINARY(255) NOT NULL PRIMARY KEY, name VARBINARY(255) NOT NULL PRIMARY KEY,
value VARBINARY(255) NULL value VARBINARY(255) NULL
) ENGINE=""" + engine) ) ENGINE=""" + engine)
else:
# Automatic migration.
version = self._getVersion()
if version < 1:
self._checkNoUnfinishedTransactions()
q("DROP TABLE IF EXISTS ttrans")
self._setConfiguration("version", self.VERSION)
# The table "pt" stores a partition table. # The table "pt" stores a partition table.
q("""CREATE TABLE IF NOT EXISTS pt ( q("""CREATE TABLE IF NOT EXISTS pt (
...@@ -214,7 +233,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -214,7 +233,7 @@ class MySQLDatabaseManager(DatabaseManager):
# The table "ttrans" stores information on uncommitted transactions. # The table "ttrans" stores information on uncommitted transactions.
q("""CREATE TABLE IF NOT EXISTS ttrans ( q("""CREATE TABLE IF NOT EXISTS ttrans (
`partition` SMALLINT UNSIGNED NOT NULL, `partition` SMALLINT UNSIGNED NOT NULL,
tid BIGINT UNSIGNED NOT NULL, tid BIGINT UNSIGNED,
packed BOOLEAN NOT NULL, packed BOOLEAN NOT NULL,
oids MEDIUMBLOB NOT NULL, oids MEDIUMBLOB NOT NULL,
user BLOB NOT NULL, user BLOB NOT NULL,
...@@ -274,7 +293,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -274,7 +293,7 @@ class MySQLDatabaseManager(DatabaseManager):
return self.query("SELECT MAX(t) FROM (SELECT MAX(tid) as t FROM trans" return self.query("SELECT MAX(t) FROM (SELECT MAX(tid) as t FROM trans"
" WHERE tid<=%s GROUP BY `partition`) as t" % max_tid)[0][0] " WHERE tid<=%s GROUP BY `partition`) as t" % max_tid)[0][0]
def _getLastIDs(self, all=True): def _getLastIDs(self):
p64 = util.p64 p64 = util.p64
q = self.query q = self.query
trans = {partition: p64(tid) trans = {partition: p64(tid)
...@@ -285,29 +304,21 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -285,29 +304,21 @@ class MySQLDatabaseManager(DatabaseManager):
" FROM obj GROUP BY `partition`")} " FROM obj GROUP BY `partition`")}
oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj" oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj"
" GROUP BY `partition`) as t")[0][0] " GROUP BY `partition`) as t")[0][0]
if all:
tid = q("SELECT MAX(tid) FROM ttrans")[0][0]
if tid is not None:
trans[None] = p64(tid)
tid, toid = q("SELECT MAX(tid), MAX(oid) FROM tobj")[0]
if tid is not None:
obj[None] = p64(tid)
if toid is not None and (oid < toid or oid is None):
oid = toid
return trans, obj, None if oid is None else p64(oid) return trans, obj, None if oid is None else p64(oid)
def getUnfinishedTIDList(self): def _getUnfinishedTIDDict(self):
p64 = util.p64
return [p64(t[0]) for t in self.query("SELECT tid FROM ttrans"
" UNION SELECT tid FROM tobj")]
def objectPresent(self, oid, tid, all = True):
oid = util.u64(oid)
tid = util.u64(tid)
q = self.query q = self.query
return q("SELECT 1 FROM obj WHERE `partition`=%d AND oid=%d AND tid=%d" return q("SELECT ttid, tid FROM ttrans"), (ttid
% (self._getPartition(oid), oid, tid)) or all and \ for ttid, in q("SELECT DISTINCT tid FROM tobj"))
q("SELECT 1 FROM tobj WHERE tid=%d AND oid=%d" % (tid, oid))
def getFinalTID(self, ttid):
ttid = util.u64(ttid)
# MariaDB is smart enough to realize that 'ttid' is constant.
r = self.query("SELECT tid FROM trans"
" WHERE `partition`=%s AND tid>=ttid AND ttid=%s LIMIT 1"
% (self._getPartition(ttid), ttid))
if r:
return util.p64(r[0][0])
def getLastObjectTID(self, oid): def getLastObjectTID(self, oid):
oid = util.u64(oid) oid = util.u64(oid)
...@@ -450,9 +461,9 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -450,9 +461,9 @@ class MySQLDatabaseManager(DatabaseManager):
oid_list, user, desc, ext, packed, ttid = transaction oid_list, user, desc, ext, packed, ttid = transaction
partition = self._getPartition(tid) partition = self._getPartition(tid)
assert packed in (0, 1) assert packed in (0, 1)
q("REPLACE INTO %s VALUES (%d,%d,%i,'%s','%s','%s','%s',%d)" % ( q("REPLACE INTO %s VALUES (%s,%s,%s,'%s','%s','%s','%s',%s)" % (
trans_table, partition, tid, packed, e(''.join(oid_list)), trans_table, partition, 'NULL' if temporary else tid, packed,
e(user), e(desc), e(ext), u64(ttid))) e(''.join(oid_list)), e(user), e(desc), e(ext), u64(ttid)))
if temporary: if temporary:
self.commit() self.commit()
...@@ -544,40 +555,40 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -544,40 +555,40 @@ class MySQLDatabaseManager(DatabaseManager):
r = self.query(sql) r = self.query(sql)
return r[0] if r else (None, None) return r[0] if r else (None, None)
def finishTransaction(self, tid): def lockTransaction(self, tid, ttid):
u64 = util.u64
self.query("UPDATE ttrans SET tid=%d WHERE ttid=%d LIMIT 1"
% (u64(tid), u64(ttid)))
self.commit()
def unlockTransaction(self, tid, ttid):
q = self.query q = self.query
tid = util.u64(tid) u64 = util.u64
sql = " FROM tobj WHERE tid=%d" % tid tid = u64(tid)
sql = " FROM tobj WHERE tid=%d" % u64(ttid)
data_id_list = [x for x, in q("SELECT data_id" + sql) if x] data_id_list = [x for x, in q("SELECT data_id" + sql) if x]
q("INSERT INTO obj SELECT *" + sql) q("INSERT INTO obj SELECT `partition`, oid, %d, data_id, value_tid %s"
q("DELETE FROM tobj WHERE tid=%d" % tid) % (tid, sql))
q("DELETE" + sql)
q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid) q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid)
q("DELETE FROM ttrans WHERE tid=%d" % tid) q("DELETE FROM ttrans WHERE tid=%d" % tid)
self.releaseData(data_id_list) self.releaseData(data_id_list)
self.commit() self.commit()
def deleteTransaction(self, tid, oid_list=()): def abortTransaction(self, ttid):
u64 = util.u64 ttid = util.u64(ttid)
tid = u64(tid)
getPartition = self._getPartition
q = self.query q = self.query
sql = " FROM tobj WHERE tid=%d" % tid sql = " FROM tobj WHERE tid=%s" % ttid
data_id_list = [x for x, in q("SELECT data_id" + sql) if x] data_id_list = [x for x, in q("SELECT data_id" + sql) if x]
self.releaseData(data_id_list)
q("DELETE" + sql)
q("""DELETE FROM ttrans WHERE tid = %d""" % tid)
q("""DELETE FROM trans WHERE `partition` = %d AND tid = %d""" %
(getPartition(tid), tid))
# delete from obj using indexes
data_id_list = set(data_id_list)
for oid in oid_list:
oid = u64(oid)
sql = " FROM obj WHERE `partition`=%d AND oid=%d AND tid=%d" \
% (getPartition(oid), oid, tid)
data_id_list.update(*q("SELECT data_id" + sql))
q("DELETE" + sql) q("DELETE" + sql)
data_id_list.discard(None) q("DELETE FROM ttrans WHERE ttid=%s" % ttid)
self._pruneData(data_id_list) self.releaseData(data_id_list, True)
def deleteTransaction(self, tid):
tid = util.u64(tid)
getPartition = self._getPartition
self.query("DELETE FROM trans WHERE `partition`=%s AND tid=%s" %
(self._getPartition(tid), tid))
def deleteObject(self, oid, serial=None): def deleteObject(self, oid, serial=None):
u64 = util.u64 u64 = util.u64
......
...@@ -66,6 +66,8 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -66,6 +66,8 @@ class SQLiteDatabaseManager(DatabaseManager):
never be used for small requests. never be used for small requests.
""" """
VERSION = 1
def __init__(self, *args, **kw): def __init__(self, *args, **kw):
super(SQLiteDatabaseManager, self).__init__(*args, **kw) super(SQLiteDatabaseManager, self).__init__(*args, **kw)
self._config = {} self._config = {}
...@@ -101,15 +103,32 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -101,15 +103,32 @@ class SQLiteDatabaseManager(DatabaseManager):
for t in 'config', 'pt', 'trans', 'obj', 'data', 'ttrans', 'tobj': for t in 'config', 'pt', 'trans', 'obj', 'data', 'ttrans', 'tobj':
self.query('DROP TABLE IF EXISTS ' + t) self.query('DROP TABLE IF EXISTS ' + t)
def nonempty(self, table):
try:
return bool(self.query(
"SELECT 1 FROM %s LIMIT 1" % table).fetchone())
except sqlite3.OperationalError as e:
if not e.args[0].startswith("no such table:"):
raise
def _setup(self): def _setup(self):
self._config.clear() self._config.clear()
q = self.query q = self.query
# The table "config" stores configuration parameters which affect the
# persistent data. if self.nonempty("config") is None:
q("""CREATE TABLE IF NOT EXISTS config ( # The table "config" stores configuration
name TEXT NOT NULL PRIMARY KEY, # parameters which affect the persistent data.
value TEXT) q("CREATE TABLE IF NOT EXISTS config ("
""") " name TEXT NOT NULL PRIMARY KEY,"
" value TEXT)")
else:
# Automatic migration.
version = self._getVersion()
if version < 1:
self._checkNoUnfinishedTransactions()
q("DROP TABLE IF EXISTS ttrans")
self._setConfiguration("version", self.VERSION)
# The table "pt" stores a partition table. # The table "pt" stores a partition table.
q("""CREATE TABLE IF NOT EXISTS pt ( q("""CREATE TABLE IF NOT EXISTS pt (
...@@ -162,7 +181,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -162,7 +181,7 @@ class SQLiteDatabaseManager(DatabaseManager):
# The table "ttrans" stores information on uncommitted transactions. # The table "ttrans" stores information on uncommitted transactions.
q("""CREATE TABLE IF NOT EXISTS ttrans ( q("""CREATE TABLE IF NOT EXISTS ttrans (
partition INTEGER NOT NULL, partition INTEGER NOT NULL,
tid INTEGER NOT NULL, tid INTEGER,
packed BOOLEAN NOT NULL, packed BOOLEAN NOT NULL,
oids BLOB NOT NULL, oids BLOB NOT NULL,
user BLOB NOT NULL, user BLOB NOT NULL,
...@@ -221,7 +240,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -221,7 +240,7 @@ class SQLiteDatabaseManager(DatabaseManager):
return self.query("SELECT MAX(tid) FROM trans WHERE tid<=?", return self.query("SELECT MAX(tid) FROM trans WHERE tid<=?",
(max_tid,)).next()[0] (max_tid,)).next()[0]
def _getLastIDs(self, all=True): def _getLastIDs(self):
p64 = util.p64 p64 = util.p64
q = self.query q = self.query
trans = {partition: p64(tid) trans = {partition: p64(tid)
...@@ -232,30 +251,21 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -232,30 +251,21 @@ class SQLiteDatabaseManager(DatabaseManager):
" FROM obj GROUP BY partition")} " FROM obj GROUP BY partition")}
oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj" oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj"
" GROUP BY partition) as t").next()[0] " GROUP BY partition) as t").next()[0]
if all:
tid = q("SELECT MAX(tid) FROM ttrans").next()[0]
if tid is not None:
trans[None] = p64(tid)
tid, toid = q("SELECT MAX(tid), MAX(oid) FROM tobj").next()
if tid is not None:
obj[None] = p64(tid)
if toid is not None and (oid < toid or oid is None):
oid = toid
return trans, obj, None if oid is None else p64(oid) return trans, obj, None if oid is None else p64(oid)
def getUnfinishedTIDList(self): def _getUnfinishedTIDDict(self):
p64 = util.p64
return [p64(t[0]) for t in self.query("SELECT tid FROM ttrans"
" UNION SELECT tid FROM tobj")]
def objectPresent(self, oid, tid, all=True):
oid = util.u64(oid)
tid = util.u64(tid)
q = self.query q = self.query
return q("SELECT 1 FROM obj WHERE partition=? AND oid=? AND tid=?", return q("SELECT ttid, tid FROM ttrans"), (ttid
(self._getPartition(oid), oid, tid)).fetchone() or all and \ for ttid, in q("SELECT DISTINCT tid FROM tobj"))
q("SELECT 1 FROM tobj WHERE tid=? AND oid=?",
(tid, oid)).fetchone() def getFinalTID(self, ttid):
ttid = util.u64(ttid)
# As of SQLite 3.8.7.1, 'tid>=ttid' would ignore the index on tid,
# even though ttid is a constant.
for tid, in self.query("SELECT tid FROM trans"
" WHERE partition=? AND tid>=? AND ttid=? LIMIT 1",
(self._getPartition(ttid), ttid, ttid)):
return util.p64(tid)
def getLastObjectTID(self, oid): def getLastObjectTID(self, oid):
oid = util.u64(oid) oid = util.u64(oid)
...@@ -362,7 +372,8 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -362,7 +372,8 @@ class SQLiteDatabaseManager(DatabaseManager):
partition = self._getPartition(tid) partition = self._getPartition(tid)
assert packed in (0, 1) assert packed in (0, 1)
q("INSERT OR FAIL INTO %strans VALUES (?,?,?,?,?,?,?,?)" % T, q("INSERT OR FAIL INTO %strans VALUES (?,?,?,?,?,?,?,?)" % T,
(partition, tid, packed, buffer(''.join(oid_list)), (partition, None if temporary else tid,
packed, buffer(''.join(oid_list)),
buffer(user), buffer(desc), buffer(ext), u64(ttid))) buffer(user), buffer(desc), buffer(ext), u64(ttid)))
if temporary: if temporary:
self.commit() self.commit()
...@@ -407,40 +418,41 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -407,40 +418,41 @@ class SQLiteDatabaseManager(DatabaseManager):
r = r.fetchone() r = r.fetchone()
return r or (None, None) return r or (None, None)
def finishTransaction(self, tid): def lockTransaction(self, tid, ttid):
args = util.u64(tid), u64 = util.u64
self.query("UPDATE ttrans SET tid=? WHERE ttid=?",
(u64(tid), u64(ttid)))
self.commit()
def unlockTransaction(self, tid, ttid):
q = self.query q = self.query
u64 = util.u64
tid = u64(tid)
ttid = u64(ttid)
sql = " FROM tobj WHERE tid=?" sql = " FROM tobj WHERE tid=?"
data_id_list = [x for x, in q("SELECT data_id" + sql, args) if x] data_id_list = [x for x, in q("SELECT data_id" + sql, (ttid,)) if x]
q("INSERT OR FAIL INTO obj SELECT *" + sql, args) q("INSERT INTO obj SELECT partition, oid, ?, data_id, value_tid" + sql,
q("DELETE FROM tobj WHERE tid=?", args) (tid, ttid))
q("INSERT OR FAIL INTO trans SELECT * FROM ttrans WHERE tid=?", args) q("DELETE" + sql, (ttid,))
q("DELETE FROM ttrans WHERE tid=?", args) q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=?", (tid,))
q("DELETE FROM ttrans WHERE tid=?", (tid,))
self.releaseData(data_id_list) self.releaseData(data_id_list)
self.commit() self.commit()
def deleteTransaction(self, tid, oid_list=()): def abortTransaction(self, ttid):
u64 = util.u64 args = util.u64(ttid),
tid = u64(tid)
getPartition = self._getPartition
q = self.query q = self.query
sql = " FROM tobj WHERE tid=?" sql = " FROM tobj WHERE tid=?"
data_id_list = [x for x, in q("SELECT data_id" + sql, (tid,)) if x] data_id_list = [x for x, in q("SELECT data_id" + sql, args) if x]
self.releaseData(data_id_list)
q("DELETE" + sql, (tid,))
q("DELETE FROM ttrans WHERE tid=?", (tid,))
q("DELETE FROM trans WHERE partition=? AND tid=?",
(getPartition(tid), tid))
# delete from obj using indexes
data_id_list = set(data_id_list)
for oid in oid_list:
oid = u64(oid)
sql = " FROM obj WHERE partition=? AND oid=? AND tid=?"
args = getPartition(oid), oid, tid
data_id_list.update(*q("SELECT data_id" + sql, args))
q("DELETE" + sql, args) q("DELETE" + sql, args)
data_id_list.discard(None) q("DELETE FROM ttrans WHERE ttid=?", args)
self._pruneData(data_id_list) self.releaseData(data_id_list, True)
def deleteTransaction(self, tid):
tid = util.u64(tid)
getPartition = self._getPartition
self.query("DELETE FROM trans WHERE partition=? AND tid=?",
(self._getPartition(tid), tid))
def deleteObject(self, oid, serial=None): def deleteObject(self, oid, serial=None):
oid = util.u64(oid) oid = util.u64(oid)
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
from neo.lib import logging from neo.lib import logging
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.exception import PrimaryFailure, OperationFailure from neo.lib.exception import PrimaryFailure, StoppedOperation
from neo.lib.protocol import uuid_str, NodeStates, NodeTypes from neo.lib.protocol import uuid_str, NodeStates, NodeTypes, Packets
class BaseMasterHandler(EventHandler): class BaseMasterHandler(EventHandler):
...@@ -27,7 +27,7 @@ class BaseMasterHandler(EventHandler): ...@@ -27,7 +27,7 @@ class BaseMasterHandler(EventHandler):
raise PrimaryFailure('connection lost') raise PrimaryFailure('connection lost')
def stopOperation(self, conn): def stopOperation(self, conn):
raise OperationFailure('operation stopped') raise StoppedOperation
def reelectPrimary(self, conn): def reelectPrimary(self, conn):
raise PrimaryFailure('re-election occurs') raise PrimaryFailure('re-election occurs')
...@@ -48,7 +48,7 @@ class BaseMasterHandler(EventHandler): ...@@ -48,7 +48,7 @@ class BaseMasterHandler(EventHandler):
erase = state == NodeStates.DOWN erase = state == NodeStates.DOWN
self.app.shutdown(erase=erase) self.app.shutdown(erase=erase)
elif state == NodeStates.HIDDEN: elif state == NodeStates.HIDDEN:
raise OperationFailure raise StoppedOperation
elif node_type == NodeTypes.CLIENT and state != NodeStates.RUNNING: elif node_type == NodeTypes.CLIENT and state != NodeStates.RUNNING:
logging.info('Notified of non-running client, abort (%s)', logging.info('Notified of non-running client, abort (%s)',
uuid_str(uuid)) uuid_str(uuid))
...@@ -56,3 +56,6 @@ class BaseMasterHandler(EventHandler): ...@@ -56,3 +56,6 @@ class BaseMasterHandler(EventHandler):
def answerUnfinishedTransactions(self, conn, *args, **kw): def answerUnfinishedTransactions(self, conn, *args, **kw):
self.app.replicator.setUnfinishedTIDList(*args, **kw) self.app.replicator.setUnfinishedTIDList(*args, **kw)
def askFinalTID(self, conn, ttid):
conn.answer(Packets.AnswerFinalTID(self.app.dm.getFinalTID(ttid)))
...@@ -19,7 +19,7 @@ from neo.lib.handler import EventHandler ...@@ -19,7 +19,7 @@ from neo.lib.handler import EventHandler
from neo.lib.util import dump, makeChecksum from neo.lib.util import dump, makeChecksum
from neo.lib.protocol import Packets, LockState, Errors, ProtocolError, \ from neo.lib.protocol import Packets, LockState, Errors, ProtocolError, \
ZERO_HASH, INVALID_PARTITION ZERO_HASH, INVALID_PARTITION
from ..transactions import ConflictError, DelayedError from ..transactions import ConflictError, DelayedError, NotRegisteredError
from ..exception import AlreadyPendingError from ..exception import AlreadyPendingError
import time import time
...@@ -68,21 +68,17 @@ class ClientOperationHandler(EventHandler): ...@@ -68,21 +68,17 @@ class ClientOperationHandler(EventHandler):
def abortTransaction(self, conn, ttid): def abortTransaction(self, conn, ttid):
self.app.tm.abort(ttid) self.app.tm.abort(ttid)
def askStoreTransaction(self, conn, ttid, user, desc, ext, oid_list): def askStoreTransaction(self, conn, ttid, *txn_info):
self.app.tm.register(conn.getUUID(), ttid) self.app.tm.register(conn.getUUID(), ttid)
self.app.tm.storeTransaction(ttid, oid_list, user, desc, ext, False) self.app.tm.vote(ttid, txn_info)
conn.answer(Packets.AnswerStoreTransaction(ttid)) conn.answer(Packets.AnswerStoreTransaction())
def askVoteTransaction(self, conn, ttid):
self.app.tm.vote(ttid)
conn.answer(Packets.AnswerVoteTransaction())
def _askStoreObject(self, conn, oid, serial, compression, checksum, data, def _askStoreObject(self, conn, oid, serial, compression, checksum, data,
data_serial, ttid, unlock, request_time): data_serial, ttid, unlock, request_time):
if ttid not in self.app.tm:
# transaction was aborted, cancel this event
logging.info('Forget store of %s:%s by %s delayed by %s',
dump(oid), dump(serial), dump(ttid),
dump(self.app.tm.getLockingTID(oid)))
# send an answer as the client side is waiting for it
conn.answer(Packets.AnswerStoreObject(0, oid, serial))
return
try: try:
self.app.tm.storeObject(ttid, serial, oid, compression, self.app.tm.storeObject(ttid, serial, oid, compression,
checksum, data, data_serial, unlock) checksum, data, data_serial, unlock)
...@@ -101,6 +97,13 @@ class ClientOperationHandler(EventHandler): ...@@ -101,6 +97,13 @@ class ClientOperationHandler(EventHandler):
raise_on_duplicate=unlock) raise_on_duplicate=unlock)
except AlreadyPendingError: except AlreadyPendingError:
conn.answer(Errors.AlreadyPending(dump(oid))) conn.answer(Errors.AlreadyPending(dump(oid)))
except NotRegisteredError:
# transaction was aborted, cancel this event
logging.info('Forget store of %s:%s by %s delayed by %s',
dump(oid), dump(serial), dump(ttid),
dump(self.app.tm.getLockingTID(oid)))
# send an answer as the client side is waiting for it
conn.answer(Packets.AnswerStoreObject(0, oid, serial))
else: else:
if SLOW_STORE is not None: if SLOW_STORE is not None:
duration = time.time() - request_time duration = time.time() - request_time
...@@ -189,14 +192,6 @@ class ClientOperationHandler(EventHandler): ...@@ -189,14 +192,6 @@ class ClientOperationHandler(EventHandler):
self._askCheckCurrentSerial(conn, ttid, serial, oid, time.time()) self._askCheckCurrentSerial(conn, ttid, serial, oid, time.time())
def _askCheckCurrentSerial(self, conn, ttid, serial, oid, request_time): def _askCheckCurrentSerial(self, conn, ttid, serial, oid, request_time):
if ttid not in self.app.tm:
# transaction was aborted, cancel this event
logging.info('Forget serial check of %s:%s by %s delayed by %s',
dump(oid), dump(serial), dump(ttid),
dump(self.app.tm.getLockingTID(oid)))
# send an answer as the client side is waiting for it
conn.answer(Packets.AnswerCheckCurrentSerial(0, oid, serial))
return
try: try:
self.app.tm.checkCurrentSerial(ttid, serial, oid) self.app.tm.checkCurrentSerial(ttid, serial, oid)
except ConflictError, err: except ConflictError, err:
...@@ -210,6 +205,13 @@ class ClientOperationHandler(EventHandler): ...@@ -210,6 +205,13 @@ class ClientOperationHandler(EventHandler):
serial, oid, request_time), key=(oid, ttid)) serial, oid, request_time), key=(oid, ttid))
except AlreadyPendingError: except AlreadyPendingError:
conn.answer(Errors.AlreadyPending(dump(oid))) conn.answer(Errors.AlreadyPending(dump(oid)))
except NotRegisteredError:
# transaction was aborted, cancel this event
logging.info('Forget serial check of %s:%s by %s delayed by %s',
dump(oid), dump(serial), dump(ttid),
dump(self.app.tm.getLockingTID(oid)))
# send an answer as the client side is waiting for it
conn.answer(Packets.AnswerCheckCurrentSerial(0, oid, serial))
else: else:
if SLOW_STORE is not None: if SLOW_STORE is not None:
duration = time.time() - request_time duration = time.time() - request_time
......
...@@ -15,24 +15,23 @@ ...@@ -15,24 +15,23 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from . import BaseMasterHandler from . import BaseMasterHandler
from neo.lib import logging, protocol from neo.lib import logging
from neo.lib.protocol import Packets, ProtocolError, ZERO_TID
class InitializationHandler(BaseMasterHandler): class InitializationHandler(BaseMasterHandler):
def answerNodeInformation(self, conn): def answerNodeInformation(self, conn):
pass pass
def answerPartitionTable(self, conn, ptid, row_list): def sendPartitionTable(self, conn, ptid, row_list):
app = self.app app = self.app
pt = app.pt pt = app.pt
pt.load(ptid, row_list, self.app.nm) pt.load(ptid, row_list, self.app.nm)
if not pt.filled(): if not pt.filled():
raise protocol.ProtocolError('Partial partition table received') raise ProtocolError('Partial partition table received')
logging.debug('Got the partition table:')
self.app.pt.log()
# Install the partition table into the database for persistency. # Install the partition table into the database for persistency.
cell_list = [] cell_list = []
num_partitions = app.pt.getPartitions() num_partitions = pt.getPartitions()
unassigned_set = set(xrange(num_partitions)) unassigned_set = set(xrange(num_partitions))
for offset in xrange(num_partitions): for offset in xrange(num_partitions):
for cell in pt.getCellList(offset): for cell in pt.getCellList(offset):
...@@ -46,12 +45,47 @@ class InitializationHandler(BaseMasterHandler): ...@@ -46,12 +45,47 @@ class InitializationHandler(BaseMasterHandler):
app.dm.changePartitionTable(ptid, cell_list, reset=True) app.dm.changePartitionTable(ptid, cell_list, reset=True)
def notifyPartitionChanges(self, conn, ptid, cell_list): def truncate(self, conn, tid):
# XXX: This is safe to ignore those notifications because all of the dm = self.app.dm
# following applies: dm._setBackupTID(None)
# - we first ask for node information, and *then* partition dm._setTruncateTID(tid)
# table content, so it is possible to get notifyPartitionChanges dm.commit()
# packets in between (or even before asking for node information).
# - this handler will be changed after receiving answerPartitionTable def askRecovery(self, conn):
# and before handling the next packet app = self.app
logging.debug('ignoring notifyPartitionChanges during initialization') conn.answer(Packets.AnswerRecovery(
app.pt.getID(),
app.dm.getBackupTID(),
app.dm.getTruncateTID()))
def askLastIDs(self, conn):
dm = self.app.dm
dm.truncate()
ltid, _, _, loid = dm.getLastIDs()
conn.answer(Packets.AnswerLastIDs(loid, ltid))
def askPartitionTable(self, conn):
pt = self.app.pt
conn.answer(Packets.AnswerPartitionTable(pt.getID(), pt.getRowList()))
def askLockedTransactions(self, conn):
conn.answer(Packets.AnswerLockedTransactions(
self.app.dm.getUnfinishedTIDDict()))
def validateTransaction(self, conn, ttid, tid):
dm = self.app.dm
dm.lockTransaction(tid, ttid)
dm.unlockTransaction(tid, ttid)
def startOperation(self, conn, backup):
self.app.operational = True
# XXX: see comment in protocol
dm = self.app.dm
if backup:
if dm.getBackupTID():
return
tid = dm.getLastIDs()[0] or ZERO_TID
else:
tid = None
dm._setBackupTID(tid)
dm.commit()
...@@ -16,13 +16,21 @@ ...@@ -16,13 +16,21 @@
from neo.lib import logging from neo.lib import logging
from neo.lib.util import dump from neo.lib.util import dump
from neo.lib.protocol import Packets, ProtocolError from neo.lib.protocol import Packets, ProtocolError, ZERO_TID
from . import BaseMasterHandler from . import BaseMasterHandler
class MasterOperationHandler(BaseMasterHandler): class MasterOperationHandler(BaseMasterHandler):
""" This handler is used for the primary master """ """ This handler is used for the primary master """
def startOperation(self, conn, backup):
# XXX: see comment in protocol
assert self.app.operational and backup
dm = self.app.dm
if not dm.getBackupTID():
dm._setBackupTID(dm.getLastIDs()[0] or ZERO_TID)
dm.commit()
def notifyTransactionFinished(self, conn, *args, **kw): def notifyTransactionFinished(self, conn, *args, **kw):
self.app.replicator.transactionFinished(*args, **kw) self.app.replicator.transactionFinished(*args, **kw)
...@@ -42,17 +50,11 @@ class MasterOperationHandler(BaseMasterHandler): ...@@ -42,17 +50,11 @@ class MasterOperationHandler(BaseMasterHandler):
# Check changes for replications # Check changes for replications
app.replicator.notifyPartitionChanges(cell_list) app.replicator.notifyPartitionChanges(cell_list)
def askLockInformation(self, conn, ttid, tid, oid_list): def askLockInformation(self, conn, ttid, tid):
if not ttid in self.app.tm: self.app.tm.lock(ttid, tid)
raise ProtocolError('Unknown transaction')
self.app.tm.lock(ttid, tid, oid_list)
if not conn.isClosed():
conn.answer(Packets.AnswerInformationLocked(ttid)) conn.answer(Packets.AnswerInformationLocked(ttid))
def notifyUnlockInformation(self, conn, ttid): def notifyUnlockInformation(self, conn, ttid):
if not ttid in self.app.tm:
raise ProtocolError('Unknown transaction')
# TODO: send an answer
self.app.tm.unlock(ttid) self.app.tm.unlock(ttid)
def askPack(self, conn, tid): def askPack(self, conn, tid):
...@@ -60,17 +62,11 @@ class MasterOperationHandler(BaseMasterHandler): ...@@ -60,17 +62,11 @@ class MasterOperationHandler(BaseMasterHandler):
logging.info('Pack started, up to %s...', dump(tid)) logging.info('Pack started, up to %s...', dump(tid))
app.dm.pack(tid, app.tm.updateObjectDataForPack) app.dm.pack(tid, app.tm.updateObjectDataForPack)
logging.info('Pack finished.') logging.info('Pack finished.')
if not conn.isClosed():
conn.answer(Packets.AnswerPack(True)) conn.answer(Packets.AnswerPack(True))
def replicate(self, conn, tid, upstream_name, source_dict): def replicate(self, conn, tid, upstream_name, source_dict):
self.app.replicator.backup(tid, {p: a and (a, upstream_name) self.app.replicator.backup(tid, {p: a and (a, upstream_name)
for p, a in source_dict.iteritems()}) for p, a in source_dict.iteritems()})
def truncate(self, conn, tid):
self.app.replicator.cancel()
self.app.dm.truncate(tid)
conn.close()
def checkPartition(self, conn, *args): def checkPartition(self, conn, *args):
self.app.checker(*args) self.app.checker(*args)
#
# Copyright (C) 2006-2015 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from . import BaseMasterHandler
from neo.lib import logging
from neo.lib.protocol import Packets, Errors, INVALID_TID, ZERO_TID
from neo.lib.util import dump
from neo.lib.exception import OperationFailure
class VerificationHandler(BaseMasterHandler):
"""This class deals with events for a verification phase."""
def askLastIDs(self, conn):
app = self.app
ltid, _, _, loid = app.dm.getLastIDs()
conn.answer(Packets.AnswerLastIDs(
loid,
ltid,
app.pt.getID(),
app.dm.getBackupTID()))
def askPartitionTable(self, conn):
pt = self.app.pt
conn.answer(Packets.AnswerPartitionTable(pt.getID(), pt.getRowList()))
def notifyPartitionChanges(self, conn, ptid, cell_list):
"""This is very similar to Send Partition Table, except that
the information is only about changes from the previous."""
app = self.app
if ptid <= app.pt.getID():
# Ignore this packet.
logging.debug('ignoring older partition changes')
return
# update partition table in memory and the database
app.pt.update(ptid, cell_list, app.nm)
app.dm.changePartitionTable(ptid, cell_list)
def startOperation(self, conn, backup):
self.app.operational = True
dm = self.app.dm
if backup:
if dm.getBackupTID():
return
tid = dm.getLastIDs()[0] or ZERO_TID
else:
tid = None
dm.setBackupTID(tid)
def stopOperation(self, conn):
raise OperationFailure('operation stopped')
def askUnfinishedTransactions(self, conn):
tid_list = self.app.dm.getUnfinishedTIDList()
conn.answer(Packets.AnswerUnfinishedTransactions(INVALID_TID, tid_list))
def askTransactionInformation(self, conn, tid):
app = self.app
t = app.dm.getTransaction(tid, all=True)
if t is None:
p = Errors.TidNotFound('%s does not exist' % dump(tid))
else:
p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3],
t[4], t[0])
conn.answer(p)
def askObjectPresent(self, conn, oid, tid):
if self.app.dm.objectPresent(oid, tid):
p = Packets.AnswerObjectPresent(oid, tid)
else:
p = Errors.OidNotFound(
'%s:%s do not exist' % (dump(oid), dump(tid)))
conn.answer(p)
def deleteTransaction(self, conn, tid, oid_list):
self.app.dm.deleteTransaction(tid, oid_list)
def commitTransaction(self, conn, tid):
self.app.dm.finishTransaction(tid)
...@@ -128,7 +128,8 @@ class Replicator(object): ...@@ -128,7 +128,8 @@ class Replicator(object):
if tid: if tid:
new_tid = self.getBackupTID() new_tid = self.getBackupTID()
if tid != new_tid: if tid != new_tid:
dm.setBackupTID(new_tid) dm._setBackupTID(new_tid)
dm.commit()
def populate(self): def populate(self):
app = self.app app = self.app
......
...@@ -38,18 +38,22 @@ class DelayedError(Exception): ...@@ -38,18 +38,22 @@ class DelayedError(Exception):
Raised when an object is locked by a previous transaction Raised when an object is locked by a previous transaction
""" """
class NotRegisteredError(Exception):
"""
Raised when a ttid is not registered
"""
class Transaction(object): class Transaction(object):
""" """
Container for a pending transaction Container for a pending transaction
""" """
_tid = None _tid = None
has_trans = False
def __init__(self, uuid, ttid): def __init__(self, uuid, ttid):
self._uuid = uuid self._uuid = uuid
self._ttid = ttid self._ttid = ttid
self._object_dict = {} self._object_dict = {}
self._transaction = None
self._locked = False self._locked = False
self._birth = time() self._birth = time()
self._checked_set = set() self._checked_set = set()
...@@ -89,13 +93,6 @@ class Transaction(object): ...@@ -89,13 +93,6 @@ class Transaction(object):
def isLocked(self): def isLocked(self):
return self._locked return self._locked
def prepare(self, oid_list, user, desc, ext, packed):
"""
Set the transaction informations
"""
# assert self._transaction is not None
self._transaction = oid_list, user, desc, ext, packed, self._ttid
def addObject(self, oid, data_id, value_serial): def addObject(self, oid, data_id, value_serial):
""" """
Add an object to the transaction Add an object to the transaction
...@@ -121,9 +118,6 @@ class Transaction(object): ...@@ -121,9 +118,6 @@ class Transaction(object):
def getLockedOIDList(self): def getLockedOIDList(self):
return self._object_dict.keys() + list(self._checked_set) return self._object_dict.keys() + list(self._checked_set)
def getTransactionInformations(self):
return self._transaction
class TransactionManager(object): class TransactionManager(object):
""" """
...@@ -137,12 +131,6 @@ class TransactionManager(object): ...@@ -137,12 +131,6 @@ class TransactionManager(object):
self._load_lock_dict = {} self._load_lock_dict = {}
self._uuid_dict = {} self._uuid_dict = {}
def __contains__(self, ttid):
"""
Returns True if the TID is known by the manager
"""
return ttid in self._transaction_dict
def register(self, uuid, ttid): def register(self, uuid, ttid):
""" """
Register a transaction, it may be already registered Register a transaction, it may be already registered
...@@ -174,7 +162,21 @@ class TransactionManager(object): ...@@ -174,7 +162,21 @@ class TransactionManager(object):
self._load_lock_dict.clear() self._load_lock_dict.clear()
self._uuid_dict.clear() self._uuid_dict.clear()
def lock(self, ttid, tid, oid_list): def vote(self, ttid, txn_info=None):
"""
Store transaction information received from client node
"""
logging.debug('Vote TXN %s', dump(ttid))
transaction = self._transaction_dict[ttid]
object_list = transaction.getObjectList()
if txn_info:
user, desc, ext, oid_list = txn_info
txn_info = oid_list, user, desc, ext, False, ttid
transaction.has_trans = True
# store metadata to temporary table
self._app.dm.storeTransaction(ttid, object_list, txn_info)
def lock(self, ttid, tid):
""" """
Lock a transaction Lock a transaction
""" """
...@@ -182,43 +184,22 @@ class TransactionManager(object): ...@@ -182,43 +184,22 @@ class TransactionManager(object):
transaction = self._transaction_dict[ttid] transaction = self._transaction_dict[ttid]
# remember that the transaction has been locked # remember that the transaction has been locked
transaction.lock() transaction.lock()
for oid in transaction.getOIDList(): self._load_lock_dict.update(
self._load_lock_dict[oid] = ttid dict.fromkeys(transaction.getOIDList(), ttid))
# check every object that should be locked # commit transaction and remember its definitive TID
uuid = transaction.getUUID() if transaction.has_trans:
is_assigned = self._app.pt.isAssigned self._app.dm.lockTransaction(tid, ttid)
for oid in oid_list:
if is_assigned(oid, uuid) and \
self._load_lock_dict.get(oid) != ttid:
raise ValueError, 'Some locks are not held'
object_list = transaction.getObjectList()
# txn_info is None is the transaction information is not stored on
# this storage.
txn_info = transaction.getTransactionInformations()
# store data from memory to temporary table
self._app.dm.storeTransaction(tid, object_list, txn_info)
# ...and remember its definitive TID
transaction.setTID(tid) transaction.setTID(tid)
def getTIDFromTTID(self, ttid):
return self._transaction_dict[ttid].getTID()
def unlock(self, ttid): def unlock(self, ttid):
""" """
Unlock transaction Unlock transaction
""" """
logging.debug('Unlock TXN %s', dump(ttid)) tid = self._transaction_dict[ttid].getTID()
self._app.dm.finishTransaction(self.getTIDFromTTID(ttid)) logging.debug('Unlock TXN %s (ttid=%s)', dump(tid), dump(ttid))
self._app.dm.unlockTransaction(tid, ttid)
self.abort(ttid, even_if_locked=True) self.abort(ttid, even_if_locked=True)
def storeTransaction(self, ttid, oid_list, user, desc, ext, packed):
"""
Store transaction information received from client node
"""
assert ttid in self, "Transaction not registered"
transaction = self._transaction_dict[ttid]
transaction.prepare(oid_list, user, desc, ext, packed)
def getLockingTID(self, oid): def getLockingTID(self, oid):
return self._store_lock_dict.get(oid) return self._store_lock_dict.get(oid)
...@@ -283,9 +264,11 @@ class TransactionManager(object): ...@@ -283,9 +264,11 @@ class TransactionManager(object):
self._store_lock_dict[oid] = ttid self._store_lock_dict[oid] = ttid
def checkCurrentSerial(self, ttid, serial, oid): def checkCurrentSerial(self, ttid, serial, oid):
self.lockObject(ttid, serial, oid, unlock=True) try:
assert ttid in self, "Transaction not registered"
transaction = self._transaction_dict[ttid] transaction = self._transaction_dict[ttid]
except KeyError:
raise NotRegisteredError
self.lockObject(ttid, serial, oid, unlock=True)
transaction.addCheckedObject(oid) transaction.addCheckedObject(oid)
def storeObject(self, ttid, serial, oid, compression, checksum, data, def storeObject(self, ttid, serial, oid, compression, checksum, data,
...@@ -293,14 +276,17 @@ class TransactionManager(object): ...@@ -293,14 +276,17 @@ class TransactionManager(object):
""" """
Store an object received from client node Store an object received from client node
""" """
try:
transaction = self._transaction_dict[ttid]
except KeyError:
raise NotRegisteredError
self.lockObject(ttid, serial, oid, unlock=unlock) self.lockObject(ttid, serial, oid, unlock=unlock)
# store object # store object
assert ttid in self, "Transaction not registered"
if data is None: if data is None:
data_id = None data_id = None
else: else:
data_id = self._app.dm.holdData(checksum, data, compression) data_id = self._app.dm.holdData(checksum, data, compression)
self._transaction_dict[ttid].addObject(oid, data_id, value_serial) transaction.addObject(oid, data_id, value_serial)
def abort(self, ttid, even_if_locked=False): def abort(self, ttid, even_if_locked=False):
""" """
...@@ -322,9 +308,7 @@ class TransactionManager(object): ...@@ -322,9 +308,7 @@ class TransactionManager(object):
if not even_if_locked: if not even_if_locked:
return return
else: else:
self._app.dm.releaseData([data_id self._app.dm.abortTransaction(ttid)
for oid, data_id, value_serial in transaction.getObjectList()
if data_id], True)
# unlock any object # unlock any object
for oid in transaction.getLockedOIDList(): for oid in transaction.getLockedOIDList():
if has_load_lock: if has_load_lock:
......
...@@ -463,9 +463,6 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -463,9 +463,6 @@ class NeoUnitTestBase(NeoTestBase):
def checkAskTransactionInformation(self, conn, **kw): def checkAskTransactionInformation(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskTransactionInformation, **kw) return self.checkAskPacket(conn, Packets.AskTransactionInformation, **kw)
def checkAskObjectPresent(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskObjectPresent, **kw)
def checkAskObject(self, conn, **kw): def checkAskObject(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskObject, **kw) return self.checkAskPacket(conn, Packets.AskObject, **kw)
...@@ -514,18 +511,12 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -514,18 +511,12 @@ class NeoUnitTestBase(NeoTestBase):
def checkAnswerObjectHistory(self, conn, **kw): def checkAnswerObjectHistory(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObjectHistory, **kw) return self.checkAnswerPacket(conn, Packets.AnswerObjectHistory, **kw)
def checkAnswerStoreTransaction(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerStoreTransaction, **kw)
def checkAnswerStoreObject(self, conn, **kw): def checkAnswerStoreObject(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerStoreObject, **kw) return self.checkAnswerPacket(conn, Packets.AnswerStoreObject, **kw)
def checkAnswerPartitionTable(self, conn, **kw): def checkAnswerPartitionTable(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerPartitionTable, **kw) return self.checkAnswerPacket(conn, Packets.AnswerPartitionTable, **kw)
def checkAnswerObjectPresent(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObjectPresent, **kw)
class Patch(object): class Patch(object):
......
...@@ -77,7 +77,7 @@ class ClusterTests(NEOFunctionalTest): ...@@ -77,7 +77,7 @@ class ClusterTests(NEOFunctionalTest):
self.neo.expectClusterRunning() self.neo.expectClusterRunning()
self.neo.expectOudatedCells(number=0) self.neo.expectOudatedCells(number=0)
self.neo.killStorage() self.neo.killStorage()
self.neo.expectClusterVerifying() self.neo.expectClusterRecovering()
def testClusterBreaksWithTwoNodes(self): def testClusterBreaksWithTwoNodes(self):
self.neo = NEOCluster(['test_neo1', 'test_neo2'], self.neo = NEOCluster(['test_neo1', 'test_neo2'],
...@@ -88,7 +88,7 @@ class ClusterTests(NEOFunctionalTest): ...@@ -88,7 +88,7 @@ class ClusterTests(NEOFunctionalTest):
self.neo.expectClusterRunning() self.neo.expectClusterRunning()
self.neo.expectOudatedCells(number=0) self.neo.expectOudatedCells(number=0)
self.neo.killStorage() self.neo.killStorage()
self.neo.expectClusterVerifying() self.neo.expectClusterRecovering()
def testClusterDoesntBreakWithTwoNodesOneReplica(self): def testClusterDoesntBreakWithTwoNodesOneReplica(self):
self.neo = NEOCluster(['test_neo1', 'test_neo2'], self.neo = NEOCluster(['test_neo1', 'test_neo2'],
...@@ -127,7 +127,7 @@ class ClusterTests(NEOFunctionalTest): ...@@ -127,7 +127,7 @@ class ClusterTests(NEOFunctionalTest):
self.assertEqual(len(self.neo.getClientlist()), 1) self.assertEqual(len(self.neo.getClientlist()), 1)
# drop the storage, the cluster is no more operational... # drop the storage, the cluster is no more operational...
self.neo.getStorageProcessList()[0].stop() self.neo.getStorageProcessList()[0].stop()
self.neo.expectClusterVerifying() self.neo.expectClusterRecovering()
# ...and the client gets disconnected # ...and the client gets disconnected
self.assertEqual(len(self.neo.getClientlist()), 0) self.assertEqual(len(self.neo.getClientlist()), 0)
# restart storage so that the cluster is operational again # restart storage so that the cluster is operational again
......
...@@ -179,7 +179,7 @@ class StorageTests(NEOFunctionalTest): ...@@ -179,7 +179,7 @@ class StorageTests(NEOFunctionalTest):
# Cluster not operational anymore. Only cells of second storage that # Cluster not operational anymore. Only cells of second storage that
# were shared with the third one should become outdated. # were shared with the third one should become outdated.
self.neo.expectUnavailable(started[1]) self.neo.expectUnavailable(started[1])
self.neo.expectClusterVerifying() self.neo.expectClusterRecovering()
self.neo.expectOudatedCells(3) self.neo.expectOudatedCells(3)
def testVerificationTriggered(self): def testVerificationTriggered(self):
...@@ -200,7 +200,7 @@ class StorageTests(NEOFunctionalTest): ...@@ -200,7 +200,7 @@ class StorageTests(NEOFunctionalTest):
# stop it, the cluster must switch to verification # stop it, the cluster must switch to verification
started[0].stop() started[0].stop()
self.neo.expectUnavailable(started[0]) self.neo.expectUnavailable(started[0])
self.neo.expectClusterVerifying() self.neo.expectClusterRecovering()
# client must have been disconnected # client must have been disconnected
self.assertEqual(len(self.neo.getClientlist()), 0) self.assertEqual(len(self.neo.getClientlist()), 0)
conn.close() conn.close()
...@@ -245,7 +245,7 @@ class StorageTests(NEOFunctionalTest): ...@@ -245,7 +245,7 @@ class StorageTests(NEOFunctionalTest):
self.neo.expectUnavailable(started[1]) self.neo.expectUnavailable(started[1])
self.neo.expectUnavailable(started[2]) self.neo.expectUnavailable(started[2])
self.neo.expectOudatedCells(number=20) self.neo.expectOudatedCells(number=20)
self.neo.expectClusterVerifying() self.neo.expectClusterRecovering()
def testConflictingStorageRejected(self): def testConflictingStorageRejected(self):
""" Check that a storage coming after the recovery process with the same """ Check that a storage coming after the recovery process with the same
...@@ -403,7 +403,7 @@ class StorageTests(NEOFunctionalTest): ...@@ -403,7 +403,7 @@ class StorageTests(NEOFunctionalTest):
self.neo.expectUnavailable(started[0]) self.neo.expectUnavailable(started[0])
self.neo.expectUnavailable(started[1]) self.neo.expectUnavailable(started[1])
self.neo.expectOudatedCells(number=10) self.neo.expectOudatedCells(number=10)
self.neo.expectClusterVerifying() self.neo.expectClusterRecovering()
# XXX: need to sync with storages first # XXX: need to sync with storages first
self.neo.stop() self.neo.stop()
......
...@@ -67,29 +67,6 @@ class MasterRecoveryTests(NeoUnitTestBase): ...@@ -67,29 +67,6 @@ class MasterRecoveryTests(NeoUnitTestBase):
self.assertEqual(self.app.nm.getByAddress(conn.getAddress()).getState(), self.assertEqual(self.app.nm.getByAddress(conn.getAddress()).getState(),
NodeStates.TEMPORARILY_DOWN) NodeStates.TEMPORARILY_DOWN)
def test_09_answerLastIDs(self):
recovery = self.recovery
uuid = self.identifyToMasterNode()
oid1 = self.getOID(1)
oid2 = self.getOID(2)
tid1 = self.getNextTID()
tid2 = self.getNextTID(tid1)
ptid1 = self.getPTID(1)
ptid2 = self.getPTID(2)
self.app.tm.setLastOID(oid1)
self.app.tm.setLastTID(tid1)
self.app.pt.setID(ptid1)
# send information which are later to what PMN knows, this must update target node
conn = self.getFakeConnection(uuid, self.storage_port)
self.assertTrue(ptid2 > self.app.pt.getID())
self.assertTrue(oid2 > self.app.tm.getLastOID())
self.assertTrue(tid2 > self.app.tm.getLastTID())
recovery.answerLastIDs(conn, oid2, tid2, ptid2, None)
self.assertEqual(oid2, self.app.tm.getLastOID())
self.assertEqual(tid2, self.app.tm.getLastTID())
self.assertEqual(ptid2, recovery.target_ptid)
def test_10_answerPartitionTable(self): def test_10_answerPartitionTable(self):
recovery = self.recovery recovery = self.recovery
uuid = self.identifyToMasterNode(NodeTypes.MASTER, port=self.master_port) uuid = self.identifyToMasterNode(NodeTypes.MASTER, port=self.master_port)
......
...@@ -21,7 +21,7 @@ from neo.lib.protocol import NodeTypes, NodeStates, Packets ...@@ -21,7 +21,7 @@ from neo.lib.protocol import NodeTypes, NodeStates, Packets
from neo.master.handlers.storage import StorageServiceHandler from neo.master.handlers.storage import StorageServiceHandler
from neo.master.handlers.client import ClientServiceHandler from neo.master.handlers.client import ClientServiceHandler
from neo.master.app import Application from neo.master.app import Application
from neo.lib.exception import OperationFailure from neo.lib.exception import StoppedOperation
class MasterStorageHandlerTests(NeoUnitTestBase): class MasterStorageHandlerTests(NeoUnitTestBase):
...@@ -114,24 +114,6 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -114,24 +114,6 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
self.checkNotifyUnlockInformation(storage_conn_1) self.checkNotifyUnlockInformation(storage_conn_1)
self.checkNotifyUnlockInformation(storage_conn_2) self.checkNotifyUnlockInformation(storage_conn_2)
def test_12_askLastIDs(self):
service = self.service
node, conn = self.identifyToMasterNode()
# give a uuid
conn = self.getFakeConnection(node.getUUID(), self.storage_address)
ptid = self.app.pt.getID()
oid = self.getOID(1)
tid = self.getNextTID()
self.app.tm.setLastOID(oid)
self.app.tm.setLastTID(tid)
service.askLastIDs(conn)
packet = self.checkAnswerLastIDs(conn)
loid, ltid, lptid, backup_tid = packet.decode()
self.assertEqual(loid, oid)
self.assertEqual(ltid, tid)
self.assertEqual(lptid, ptid)
self.assertEqual(backup_tid, None)
def test_13_askUnfinishedTransactions(self): def test_13_askUnfinishedTransactions(self):
service = self.service service = self.service
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
...@@ -173,64 +155,10 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -173,64 +155,10 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
# drop the second, no storage node left # drop the second, no storage node left
lptid = self.app.pt.getID() lptid = self.app.pt.getID()
self.assertEqual(node2.getState(), NodeStates.RUNNING) self.assertEqual(node2.getState(), NodeStates.RUNNING)
self.assertRaises(OperationFailure, method, conn2) self.assertRaises(StoppedOperation, method, conn2)
self.assertEqual(node2.getState(), state) self.assertEqual(node2.getState(), state)
self.assertEqual(lptid, self.app.pt.getID()) self.assertEqual(lptid, self.app.pt.getID())
def test_nodeLostAfterAskLockInformation(self):
# 2 storage nodes, one will die
node1, conn1 = self._getStorage()
node2, conn2 = self._getStorage()
# client nodes, to distinguish answers for the sample transactions
client1, cconn1 = self._getClient()
client2, cconn2 = self._getClient()
client3, cconn3 = self._getClient()
oid_list = [self.getOID(), ]
# Some shortcuts to simplify test code
self.app.pt = Mock({'operational': True})
# Register some transactions
tm = self.app.tm
# Transaction 1: 2 storage nodes involved, one will die and the other
# already answered node lock
msg_id_1 = 1
ttid1 = tm.begin(client1)
tid1 = tm.prepare(ttid1, 1, oid_list,
[node1.getUUID(), node2.getUUID()], msg_id_1)
tm.lock(ttid1, node2.getUUID())
# storage 1 request a notification at commit
tm. registerForNotification(node1.getUUID())
self.checkNoPacketSent(cconn1)
# Storage 1 dies
node1.setTemporarilyDown()
self.service.nodeLost(conn1, node1)
# T1: last locking node lost, client receives AnswerTransactionFinished
self.checkAnswerTransactionFinished(cconn1)
self.checkNotifyTransactionFinished(conn1)
self.checkNotifyUnlockInformation(conn2)
# ...and notifications are sent to other clients
self.checkInvalidateObjects(cconn2)
self.checkInvalidateObjects(cconn3)
# Transaction 2: 2 storage nodes involved, one will die
msg_id_2 = 2
ttid2 = tm.begin(node1)
tid2 = tm.prepare(ttid2, 1, oid_list,
[node1.getUUID(), node2.getUUID()], msg_id_2)
# T2: pending locking answer, client keeps waiting
self.checkNoPacketSent(cconn2, check_notify=False)
tm.remove(node1.getUUID(), ttid2)
# Transaction 3: 1 storage node involved, which won't die
msg_id_3 = 3
ttid3 = tm.begin(node1)
tid3 = tm.prepare(ttid3, 1, oid_list,
[node2.getUUID(), ], msg_id_3)
# T3: action not significant to this transacion, so no response
self.checkNoPacketSent(cconn3, check_notify=False)
tm.remove(node1.getUUID(), ttid3)
def test_answerPack(self): def test_answerPack(self):
# Note: incomming status has no meaning here, so it's left to False. # Note: incomming status has no meaning here, so it's left to False.
node1, conn1 = self._getStorage() node1, conn1 = self._getStorage()
......
...@@ -112,19 +112,6 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -112,19 +112,6 @@ class testTransactionManager(NeoUnitTestBase):
# ...and the lock is available # ...and the lock is available
txnman.begin(client, self.getNextTID()) txnman.begin(client, self.getNextTID())
def test_getNextOIDList(self):
txnman = TransactionManager(lambda tid, txn: None)
# must raise as we don"t have one
self.assertEqual(txnman.getLastOID(), None)
self.assertRaises(RuntimeError, txnman.getNextOIDList, 1)
# ask list
txnman.setLastOID(self.getOID(1))
oid_list = txnman.getNextOIDList(15)
self.assertEqual(len(oid_list), 15)
# begin from 1, so generated oid from 2 to 16
for i, oid in zip(xrange(len(oid_list)), oid_list):
self.assertEqual(oid, self.getOID(i+2))
def test_forget(self): def test_forget(self):
client1 = Mock({'__hash__': 1}) client1 = Mock({'__hash__': 1})
client2 = Mock({'__hash__': 2}) client2 = Mock({'__hash__': 2})
......
...@@ -191,18 +191,6 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -191,18 +191,6 @@ class StorageClientHandlerTests(NeoUnitTestBase):
self.operation.askObjectHistory(conn, oid2, 1, 2) self.operation.askObjectHistory(conn, oid2, 1, 2)
self.checkAnswerObjectHistory(conn) self.checkAnswerObjectHistory(conn)
def test_askStoreTransaction(self):
conn = self._getConnection(uuid=self.getClientUUID())
tid = self.getNextTID()
user = 'USER'
desc = 'DESC'
ext = 'EXT'
oid_list = (self.getOID(1), self.getOID(2))
self.operation.askStoreTransaction(conn, tid, user, desc, ext, oid_list)
calls = self.app.tm.mockGetNamedCalls('storeTransaction')
self.assertEqual(len(calls), 1)
self.checkAnswerStoreTransaction(conn)
def _getObject(self): def _getObject(self):
oid = self.getOID(0) oid = self.getOID(0)
serial = self.getNextTID() serial = self.getNextTID()
......
...@@ -76,7 +76,7 @@ class StorageInitializationHandlerTests(NeoUnitTestBase): ...@@ -76,7 +76,7 @@ class StorageInitializationHandlerTests(NeoUnitTestBase):
(2, ((node_2, CellStates.UP_TO_DATE), (node_3, CellStates.UP_TO_DATE)))] (2, ((node_2, CellStates.UP_TO_DATE), (node_3, CellStates.UP_TO_DATE)))]
self.assertFalse(self.app.pt.filled()) self.assertFalse(self.app.pt.filled())
# send a complete new table and ack # send a complete new table and ack
self.verification.answerPartitionTable(conn, 2, row_list) self.verification.sendPartitionTable(conn, 2, row_list)
self.assertTrue(self.app.pt.filled()) self.assertTrue(self.app.pt.filled())
self.assertEqual(self.app.pt.getID(), 2) self.assertEqual(self.app.pt.getID(), 2)
self.assertTrue(list(self.app.dm.getPartitionTable())) self.assertTrue(list(self.app.dm.getPartitionTable()))
......
...@@ -20,7 +20,7 @@ from collections import deque ...@@ -20,7 +20,7 @@ from collections import deque
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.handlers.master import MasterOperationHandler from neo.storage.handlers.master import MasterOperationHandler
from neo.lib.exception import PrimaryFailure, OperationFailure from neo.lib.exception import PrimaryFailure
from neo.lib.pt import PartitionTable from neo.lib.pt import PartitionTable
from neo.lib.protocol import CellStates, ProtocolError, Packets from neo.lib.protocol import CellStates, ProtocolError, Packets
...@@ -104,58 +104,9 @@ class StorageMasterHandlerTests(NeoUnitTestBase): ...@@ -104,58 +104,9 @@ class StorageMasterHandlerTests(NeoUnitTestBase):
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(ptid2, cells) calls[0].checkArgs(ptid2, cells)
def test_16_stopOperation1(self):
# OperationFailure
conn = self.getFakeConnection(is_server=False)
self.assertRaises(OperationFailure, self.operation.stopOperation, conn)
def _getConnection(self): def _getConnection(self):
return self.getFakeConnection() return self.getFakeConnection()
def test_askLockInformation1(self):
""" Unknown transaction """
self.app.tm = Mock({'__contains__': False})
conn = self._getConnection()
oid_list = [self.getOID(1), self.getOID(2)]
tid = self.getNextTID()
ttid = self.getNextTID()
handler = self.operation
self.assertRaises(ProtocolError, handler.askLockInformation, conn,
ttid, tid, oid_list)
def test_askLockInformation2(self):
""" Lock transaction """
self.app.tm = Mock({'__contains__': True})
conn = self._getConnection()
tid = self.getNextTID()
ttid = self.getNextTID()
oid_list = [self.getOID(1), self.getOID(2)]
self.operation.askLockInformation(conn, ttid, tid, oid_list)
calls = self.app.tm.mockGetNamedCalls('lock')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(ttid, tid, oid_list)
self.checkAnswerInformationLocked(conn)
def test_notifyUnlockInformation1(self):
""" Unknown transaction """
self.app.tm = Mock({'__contains__': False})
conn = self._getConnection()
tid = self.getNextTID()
handler = self.operation
self.assertRaises(ProtocolError, handler.notifyUnlockInformation,
conn, tid)
def test_notifyUnlockInformation2(self):
""" Unlock transaction """
self.app.tm = Mock({'__contains__': True})
conn = self._getConnection()
tid = self.getNextTID()
self.operation.notifyUnlockInformation(conn, tid)
calls = self.app.tm.mockGetNamedCalls('unlock')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid)
self.checkNoPacketSent(conn)
def test_askPack(self): def test_askPack(self):
self.app.dm = Mock({'pack': None}) self.app.dm = Mock({'pack': None})
conn = self.getFakeConnection() conn = self.getFakeConnection()
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from binascii import a2b_hex from binascii import a2b_hex
from contextlib import contextmanager
import unittest import unittest
from neo.lib.util import add64, p64, u64 from neo.lib.util import add64, p64, u64
from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID, MAX_TID from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID, MAX_TID
...@@ -80,6 +81,17 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -80,6 +81,17 @@ class StorageDBTests(NeoUnitTestBase):
set_call(value * 2) set_call(value * 2)
self.assertEqual(get_call(), value * 2) self.assertEqual(get_call(), value * 2)
@contextmanager
def commitTransaction(self, tid, objs, txn, commit=True):
ttid = txn[-1]
self.db.storeTransaction(ttid, objs, txn)
self.db.lockTransaction(tid, ttid)
yield
if commit:
self.db.unlockTransaction(tid, ttid)
elif commit is not None:
self.db.abortTransaction(ttid)
def test_UUID(self): def test_UUID(self):
db = self.getDB() db = self.getDB()
self.checkConfigEntry(db.getUUID, db.setUUID, 123) self.checkConfigEntry(db.getUUID, db.setUUID, 123)
...@@ -122,38 +134,24 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -122,38 +134,24 @@ class StorageDBTests(NeoUnitTestBase):
def checkSet(self, list1, list2): def checkSet(self, list1, list2):
self.assertEqual(set(list1), set(list2)) self.assertEqual(set(list1), set(list2))
def test_getUnfinishedTIDList(self): def test_getUnfinishedTIDDict(self):
tid1, tid2, tid3, tid4 = self.getTIDs(4) tid1, tid2, tid3, tid4 = self.getTIDs(4)
oid1, oid2 = self.getOIDs(2) oid1, oid2 = self.getOIDs(2)
txn, objs = self.getTransaction([oid1, oid2]) txn, objs = self.getTransaction([oid1, oid2])
# nothing pending
self.db.storeTransaction(tid1, objs, txn, False)
self.checkSet(self.db.getUnfinishedTIDList(), [])
# one unfinished txn # one unfinished txn
self.db.storeTransaction(tid2, objs, txn) with self.commitTransaction(tid2, objs, txn):
self.checkSet(self.db.getUnfinishedTIDList(), [tid2]) expected = {txn[-1]: tid2}
self.assertEqual(self.db.getUnfinishedTIDDict(), expected)
# no changes # no changes
self.db.storeTransaction(tid3, objs, None, False) self.db.storeTransaction(tid3, objs, None, False)
self.checkSet(self.db.getUnfinishedTIDList(), [tid2]) self.assertEqual(self.db.getUnfinishedTIDDict(), expected)
# a second txn known by objs only # a second txn known by objs only
expected[tid4] = None
self.db.storeTransaction(tid4, objs, None) self.db.storeTransaction(tid4, objs, None)
self.checkSet(self.db.getUnfinishedTIDList(), [tid2, tid4]) self.assertEqual(self.db.getUnfinishedTIDDict(), expected)
self.db.abortTransaction(tid4)
def test_objectPresent(self): # nothing pending
tid = self.getNextTID() self.assertEqual(self.db.getUnfinishedTIDDict(), {})
oid = self.getOID(1)
txn, objs = self.getTransaction([oid])
# not present
self.assertFalse(self.db.objectPresent(oid, tid, all=True))
self.assertFalse(self.db.objectPresent(oid, tid, all=False))
# available in temp table
self.db.storeTransaction(tid, objs, txn)
self.assertTrue(self.db.objectPresent(oid, tid, all=True))
self.assertFalse(self.db.objectPresent(oid, tid, all=False))
# available in both tables
self.db.finishTransaction(tid)
self.assertTrue(self.db.objectPresent(oid, tid, all=True))
self.assertTrue(self.db.objectPresent(oid, tid, all=False))
def test_getObject(self): def test_getObject(self):
oid1, = self.getOIDs(1) oid1, = self.getOIDs(1)
...@@ -169,27 +167,26 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -169,27 +167,26 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(self.db.getObject(oid1, tid1), None) self.assertEqual(self.db.getObject(oid1, tid1), None)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1), None) self.assertEqual(self.db.getObject(oid1, before_tid=tid1), None)
# one non-commited version # one non-commited version
self.db.storeTransaction(tid1, objs1, txn1) with self.commitTransaction(tid1, objs1, txn1):
self.assertEqual(self.db.getObject(oid1), None) self.assertEqual(self.db.getObject(oid1), None)
self.assertEqual(self.db.getObject(oid1, tid1), None) self.assertEqual(self.db.getObject(oid1, tid1), None)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1), None) self.assertEqual(self.db.getObject(oid1, before_tid=tid1), None)
# one commited version # one commited version
self.db.finishTransaction(tid1)
self.assertEqual(self.db.getObject(oid1), OBJECT_T1_NO_NEXT) self.assertEqual(self.db.getObject(oid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NO_NEXT) self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1), self.assertEqual(self.db.getObject(oid1, before_tid=tid1),
FOUND_BUT_NOT_VISIBLE) FOUND_BUT_NOT_VISIBLE)
# two version available, one non-commited # two version available, one non-commited
self.db.storeTransaction(tid2, objs2, txn2) with self.commitTransaction(tid2, objs2, txn2):
self.assertEqual(self.db.getObject(oid1), OBJECT_T1_NO_NEXT) self.assertEqual(self.db.getObject(oid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NO_NEXT) self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1), self.assertEqual(self.db.getObject(oid1, before_tid=tid1),
FOUND_BUT_NOT_VISIBLE) FOUND_BUT_NOT_VISIBLE)
self.assertEqual(self.db.getObject(oid1, tid2), FOUND_BUT_NOT_VISIBLE) self.assertEqual(self.db.getObject(oid1, tid2),
FOUND_BUT_NOT_VISIBLE)
self.assertEqual(self.db.getObject(oid1, before_tid=tid2), self.assertEqual(self.db.getObject(oid1, before_tid=tid2),
OBJECT_T1_NO_NEXT) OBJECT_T1_NO_NEXT)
# two commited versions # two commited versions
self.db.finishTransaction(tid2)
self.assertEqual(self.db.getObject(oid1), OBJECT_T2) self.assertEqual(self.db.getObject(oid1), OBJECT_T2)
self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NEXT) self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NEXT)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1), self.assertEqual(self.db.getObject(oid1, before_tid=tid1),
...@@ -242,82 +239,28 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -242,82 +239,28 @@ class StorageDBTests(NeoUnitTestBase):
result = db.getPartitionTable() result = db.getPartitionTable()
self.assertEqual(list(result), [cell1]) self.assertEqual(list(result), [cell1])
def test_dropUnfinishedData(self): def test_commitTransaction(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid1])
# nothing
self.assertEqual(self.db.getObject(oid1), None)
self.assertEqual(self.db.getObject(oid2), None)
self.assertEqual(self.db.getUnfinishedTIDList(), [])
# one is still pending
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.finishTransaction(tid1)
result = self.db.getObject(oid1)
self.assertEqual(result, (tid1, None, 1, "0"*20, '', None))
self.assertEqual(self.db.getObject(oid2), None)
self.assertEqual(self.db.getUnfinishedTIDList(), [tid2])
# drop it
self.db.dropUnfinishedData()
self.assertEqual(self.db.getUnfinishedTIDList(), [])
result = self.db.getObject(oid1)
self.assertEqual(result, (tid1, None, 1, "0"*20, '', None))
self.assertEqual(self.db.getObject(oid2), None)
def test_storeTransaction(self):
oid1, oid2 = self.getOIDs(2) oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2) tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1]) txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2]) txn2, objs2 = self.getTransaction([oid2])
# nothing in database # nothing in database
self.assertEqual(self.db.getLastIDs(), (None, {}, {}, None)) self.assertEqual(self.db.getLastIDs(), (None, {}, {}, None))
self.assertEqual(self.db.getUnfinishedTIDList(), []) self.assertEqual(self.db.getUnfinishedTIDDict(), {})
self.assertEqual(self.db.getObject(oid1), None) self.assertEqual(self.db.getObject(oid1), None)
self.assertEqual(self.db.getObject(oid2), None) self.assertEqual(self.db.getObject(oid2), None)
self.assertEqual(self.db.getTransaction(tid1, True), None) self.assertEqual(self.db.getTransaction(tid1, True), None)
self.assertEqual(self.db.getTransaction(tid2, True), None) self.assertEqual(self.db.getTransaction(tid2, True), None)
self.assertEqual(self.db.getTransaction(tid1, False), None) self.assertEqual(self.db.getTransaction(tid1, False), None)
self.assertEqual(self.db.getTransaction(tid2, False), None) self.assertEqual(self.db.getTransaction(tid2, False), None)
# store in temporary tables with self.commitTransaction(tid1, objs1, txn1), \
self.db.storeTransaction(tid1, objs1, txn1) self.commitTransaction(tid2, objs2, txn2):
self.db.storeTransaction(tid2, objs2, txn2) self.assertEqual(self.db.getTransaction(tid1, True),
result = self.db.getTransaction(tid1, True) ([oid1], 'user', 'desc', 'ext', False, p64(1)))
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1))) self.assertEqual(self.db.getTransaction(tid2, True),
result = self.db.getTransaction(tid2, True) ([oid2], 'user', 'desc', 'ext', False, p64(2)))
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2)))
self.assertEqual(self.db.getTransaction(tid1, False), None)
self.assertEqual(self.db.getTransaction(tid2, False), None)
# commit pending transaction
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
result = self.db.getTransaction(tid1, True)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, True)
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2)))
result = self.db.getTransaction(tid1, False)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, False)
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2)))
def test_askFinishTransaction(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2])
# stored but not finished
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
result = self.db.getTransaction(tid1, True)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, True)
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2)))
self.assertEqual(self.db.getTransaction(tid1, False), None) self.assertEqual(self.db.getTransaction(tid1, False), None)
self.assertEqual(self.db.getTransaction(tid2, False), None) self.assertEqual(self.db.getTransaction(tid2, False), None)
# stored and finished
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
result = self.db.getTransaction(tid1, True) result = self.db.getTransaction(tid1, True)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1))) self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, True) result = self.db.getTransaction(tid2, True)
...@@ -328,32 +271,29 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -328,32 +271,29 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2))) self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2)))
def test_deleteTransaction(self): def test_deleteTransaction(self):
oid1, oid2 = self.getOIDs(2) txn, objs = self.getTransaction([])
tid1, tid2 = self.getTIDs(2) tid = txn[-1]
txn1, objs1 = self.getTransaction([oid1]) self.db.storeTransaction(tid, objs, txn, False)
txn2, objs2 = self.getTransaction([oid2]) self.assertEqual(self.db.getTransaction(tid), txn)
self.db.storeTransaction(tid1, objs1, txn1) self.db.deleteTransaction(tid)
self.db.storeTransaction(tid2, objs2, txn2) self.assertEqual(self.db.getTransaction(tid), None)
self.db.finishTransaction(tid1)
self.db.deleteTransaction(tid1, [oid1])
self.db.deleteTransaction(tid2, [oid2])
self.assertEqual(self.db.getTransaction(tid1, True), None)
self.assertEqual(self.db.getTransaction(tid2, True), None)
def test_deleteObject(self): def test_deleteObject(self):
oid1, oid2 = self.getOIDs(2) oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2) tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1, oid2]) txn1, objs1 = self.getTransaction([oid1, oid2])
txn2, objs2 = self.getTransaction([oid1, oid2]) txn2, objs2 = self.getTransaction([oid1, oid2])
self.db.storeTransaction(tid1, objs1, txn1) tid1 = txn1[-1]
self.db.storeTransaction(tid2, objs2, txn2) tid2 = txn2[-1]
self.db.finishTransaction(tid1) self.db.storeTransaction(tid1, objs1, txn1, False)
self.db.finishTransaction(tid2) self.db.storeTransaction(tid2, objs2, txn2, False)
self.assertEqual(self.db.getObject(oid1, tid=tid1),
(tid1, tid2, 1, "0" * 20, '', None))
self.db.deleteObject(oid1) self.db.deleteObject(oid1)
self.assertEqual(self.db.getObject(oid1, tid=tid1), None) self.assertIs(self.db.getObject(oid1, tid=tid1), None)
self.assertEqual(self.db.getObject(oid1, tid=tid2), None) self.assertIs(self.db.getObject(oid1, tid=tid2), None)
self.db.deleteObject(oid2, serial=tid1) self.db.deleteObject(oid2, serial=tid1)
self.assertFalse(self.db.getObject(oid2, tid=tid1)) self.assertIs(self.db.getObject(oid2, tid=tid1), False)
self.assertEqual(self.db.getObject(oid2, tid=tid2), self.assertEqual(self.db.getObject(oid2, tid=tid2),
(tid2, None, 1, "0" * 20, '', None)) (tid2, None, 1, "0" * 20, '', None))
...@@ -364,8 +304,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -364,8 +304,7 @@ class StorageDBTests(NeoUnitTestBase):
oid_list = self.getOIDs(np * 2) oid_list = self.getOIDs(np * 2)
for tid in t1, t2, t3: for tid in t1, t2, t3:
txn, objs = self.getTransaction(oid_list) txn, objs = self.getTransaction(oid_list)
self.db.storeTransaction(tid, objs, txn) self.db.storeTransaction(tid, objs, txn, False)
self.db.finishTransaction(tid)
def check(offset, tid_list, *tids): def check(offset, tid_list, *tids):
self.assertEqual(self.db.getReplicationTIDList(ZERO_TID, self.assertEqual(self.db.getReplicationTIDList(ZERO_TID,
MAX_TID, len(tid_list) + 1, offset), tid_list) MAX_TID, len(tid_list) + 1, offset), tid_list)
...@@ -386,9 +325,9 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -386,9 +325,9 @@ class StorageDBTests(NeoUnitTestBase):
txn1, objs1 = self.getTransaction([oid1]) txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2]) txn2, objs2 = self.getTransaction([oid2])
# get from temporary table or not # get from temporary table or not
self.db.storeTransaction(tid1, objs1, txn1) with self.commitTransaction(tid1, objs1, txn1), \
self.db.storeTransaction(tid2, objs2, txn2) self.commitTransaction(tid2, objs2, txn2, None):
self.db.finishTransaction(tid1) pass
result = self.db.getTransaction(tid1, True) result = self.db.getTransaction(tid1, True)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1))) self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, True) result = self.db.getTransaction(tid2, True)
...@@ -405,15 +344,13 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -405,15 +344,13 @@ class StorageDBTests(NeoUnitTestBase):
txn2, objs2 = self.getTransaction([oid]) txn2, objs2 = self.getTransaction([oid])
txn3, objs3 = self.getTransaction([oid]) txn3, objs3 = self.getTransaction([oid])
# one revision # one revision
self.db.storeTransaction(tid1, objs1, txn1) self.db.storeTransaction(tid1, objs1, txn1, False)
self.db.finishTransaction(tid1)
result = self.db.getObjectHistory(oid, 0, 3) result = self.db.getObjectHistory(oid, 0, 3)
self.assertEqual(result, [(tid1, 0)]) self.assertEqual(result, [(tid1, 0)])
result = self.db.getObjectHistory(oid, 1, 1) result = self.db.getObjectHistory(oid, 1, 1)
self.assertEqual(result, None) self.assertEqual(result, None)
# two revisions # two revisions
self.db.storeTransaction(tid2, objs2, txn2) self.db.storeTransaction(tid2, objs2, txn2, False)
self.db.finishTransaction(tid2)
result = self.db.getObjectHistory(oid, 0, 3) result = self.db.getObjectHistory(oid, 0, 3)
self.assertEqual(result, [(tid2, 0), (tid1, 0)]) self.assertEqual(result, [(tid2, 0), (tid1, 0)])
result = self.db.getObjectHistory(oid, 1, 3) result = self.db.getObjectHistory(oid, 1, 3)
...@@ -427,8 +364,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -427,8 +364,7 @@ class StorageDBTests(NeoUnitTestBase):
oid = self.getOID(1) oid = self.getOID(1)
for tid in tid_list: for tid in tid_list:
txn, objs = self.getTransaction([oid]) txn, objs = self.getTransaction([oid])
self.db.storeTransaction(tid, objs, txn) self.db.storeTransaction(tid, objs, txn, False)
self.db.finishTransaction(tid)
return tid_list return tid_list
def test_getTIDList(self): def test_getTIDList(self):
......
...@@ -45,15 +45,6 @@ class TransactionTests(NeoUnitTestBase): ...@@ -45,15 +45,6 @@ class TransactionTests(NeoUnitTestBase):
# disallow lock more than once # disallow lock more than once
self.assertRaises(AssertionError, txn.lock) self.assertRaises(AssertionError, txn.lock)
def testTransaction(self):
txn = Transaction(self.getClientUUID(), self.getNextTID())
repr(txn) # check __repr__ does not raise
oid_list = [self.getOID(1), self.getOID(2)]
txn_info = (oid_list, 'USER', 'DESC', 'EXT', False)
txn.prepare(*txn_info)
self.assertEqual(txn.getTransactionInformations(),
txn_info + (txn.getTTID(),))
def testObjects(self): def testObjects(self):
txn = Transaction(self.getClientUUID(), self.getNextTID()) txn = Transaction(self.getClientUUID(), self.getNextTID())
oid1, oid2 = self.getOID(1), self.getOID(2) oid1, oid2 = self.getOID(1), self.getOID(2)
...@@ -91,10 +82,10 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -91,10 +82,10 @@ class TransactionManagerTests(NeoUnitTestBase):
def _getTransaction(self): def _getTransaction(self):
tid = self.getNextTID(self.ltid) tid = self.getNextTID(self.ltid)
oid_list = [self.getOID(1), self.getOID(2)] oid_list = [self.getOID(1), self.getOID(2)]
return (tid, (oid_list, 'USER', 'DESC', 'EXT', False)) return (tid, ('USER', 'DESC', 'EXT', oid_list))
def _storeTransactionObjects(self, tid, txn): def _storeTransactionObjects(self, tid, txn):
for i, oid in enumerate(txn[0]): for i, oid in enumerate(txn[3]):
self.manager.storeObject(tid, None, self.manager.storeObject(tid, None,
oid, 1, '%020d' % i, '0' + str(i), None) oid, 1, '%020d' % i, '0' + str(i), None)
...@@ -108,15 +99,21 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -108,15 +99,21 @@ class TransactionManagerTests(NeoUnitTestBase):
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(*args) calls[0].checkArgs(*args)
def _checkTransactionFinished(self, tid): def _checkTransactionFinished(self, *args):
calls = self.app.dm.mockGetNamedCalls('finishTransaction') calls = self.app.dm.mockGetNamedCalls('unlockTransaction')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid) calls[0].checkArgs(*args)
def _checkQueuedEventExecuted(self, number=1): def _checkQueuedEventExecuted(self, number=1):
calls = self.app.mockGetNamedCalls('executeQueuedEvents') calls = self.app.mockGetNamedCalls('executeQueuedEvents')
self.assertEqual(len(calls), number) self.assertEqual(len(calls), number)
def assertRegistered(self, ttid):
self.assertIn(ttid, self.manager._transaction_dict)
def assertNotRegistered(self, ttid):
self.assertNotIn(ttid, self.manager._transaction_dict)
def testSimpleCase(self): def testSimpleCase(self):
""" One node, one transaction, not abort """ """ One node, one transaction, not abort """
data_id_list = random.random(), random.random() data_id_list = random.random(), random.random()
...@@ -127,18 +124,23 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -127,18 +124,23 @@ class TransactionManagerTests(NeoUnitTestBase):
serial1, object1 = self._getObject(1) serial1, object1 = self._getObject(1)
serial2, object2 = self._getObject(2) serial2, object2 = self._getObject(2)
self.manager.register(uuid, ttid) self.manager.register(uuid, ttid)
self.manager.storeTransaction(ttid, *txn)
self.manager.storeObject(ttid, serial1, *object1) self.manager.storeObject(ttid, serial1, *object1)
self.manager.storeObject(ttid, serial2, *object2) self.manager.storeObject(ttid, serial2, *object2)
self.assertTrue(ttid in self.manager) self.assertRegistered(ttid)
self.manager.lock(ttid, tid, txn[0]) self.manager.vote(ttid, txn)
self._checkTransactionStored(tid, [ user, desc, ext, oid_list = txn
call, = self.app.dm.mockGetNamedCalls('storeTransaction')
call.checkArgs(ttid, [
(object1[0], data_id_list[0], object1[4]), (object1[0], data_id_list[0], object1[4]),
(object2[0], data_id_list[1], object2[4]), (object2[0], data_id_list[1], object2[4]),
], txn + (ttid,)) ], (oid_list, user, desc, ext, False, ttid))
self.manager.lock(ttid, tid)
call, = self.app.dm.mockGetNamedCalls('lockTransaction')
call.checkArgs(tid, ttid)
self.manager.unlock(ttid) self.manager.unlock(ttid)
self.assertFalse(ttid in self.manager) self.assertNotRegistered(ttid)
self._checkTransactionFinished(tid) call, = self.app.dm.mockGetNamedCalls('unlockTransaction')
call.checkArgs(tid, ttid)
def testDelayed(self): def testDelayed(self):
""" Two transactions, the first cause the second to be delayed """ """ Two transactions, the first cause the second to be delayed """
...@@ -150,14 +152,13 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -150,14 +152,13 @@ class TransactionManagerTests(NeoUnitTestBase):
serial, obj = self._getObject(1) serial, obj = self._getObject(1)
# first transaction lock the object # first transaction lock the object
self.manager.register(uuid, ttid1) self.manager.register(uuid, ttid1)
self.manager.storeTransaction(ttid1, *txn1) self.assertRegistered(ttid1)
self.assertTrue(ttid1 in self.manager)
self._storeTransactionObjects(ttid1, txn1) self._storeTransactionObjects(ttid1, txn1)
self.manager.lock(ttid1, tid1, txn1[0]) self.manager.vote(ttid1, txn1)
self.manager.lock(ttid1, tid1)
# the second is delayed # the second is delayed
self.manager.register(uuid, ttid2) self.manager.register(uuid, ttid2)
self.manager.storeTransaction(ttid2, *txn2) self.assertRegistered(ttid2)
self.assertTrue(ttid2 in self.manager)
self.assertRaises(DelayedError, self.manager.storeObject, self.assertRaises(DelayedError, self.manager.storeObject,
ttid2, serial, *obj) ttid2, serial, *obj)
...@@ -171,14 +172,13 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -171,14 +172,13 @@ class TransactionManagerTests(NeoUnitTestBase):
serial, obj = self._getObject(1) serial, obj = self._getObject(1)
# the (later) transaction lock (change) the object # the (later) transaction lock (change) the object
self.manager.register(uuid, ttid2) self.manager.register(uuid, ttid2)
self.manager.storeTransaction(ttid2, *txn2) self.assertRegistered(ttid2)
self.assertTrue(ttid2 in self.manager)
self._storeTransactionObjects(ttid2, txn2) self._storeTransactionObjects(ttid2, txn2)
self.manager.lock(ttid2, tid2, txn2[0]) self.manager.vote(ttid2, txn2)
self.manager.lock(ttid2, tid2)
# the previous it's not using the latest version # the previous it's not using the latest version
self.manager.register(uuid, ttid1) self.manager.register(uuid, ttid1)
self.manager.storeTransaction(ttid1, *txn1) self.assertRegistered(ttid1)
self.assertTrue(ttid1 in self.manager)
self.assertRaises(ConflictError, self.manager.storeObject, self.assertRaises(ConflictError, self.manager.storeObject,
ttid1, serial, *obj) ttid1, serial, *obj)
...@@ -191,7 +191,6 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -191,7 +191,6 @@ class TransactionManagerTests(NeoUnitTestBase):
# try to store without the last revision # try to store without the last revision
self.app.dm = Mock({'getLastObjectTID': next_serial}) self.app.dm = Mock({'getLastObjectTID': next_serial})
self.manager.register(uuid, tid) self.manager.register(uuid, tid)
self.manager.storeTransaction(tid, *txn)
self.assertRaises(ConflictError, self.manager.storeObject, self.assertRaises(ConflictError, self.manager.storeObject,
tid, serial, *obj) tid, serial, *obj)
...@@ -208,15 +207,14 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -208,15 +207,14 @@ class TransactionManagerTests(NeoUnitTestBase):
serial2, obj2 = self._getObject(2) serial2, obj2 = self._getObject(2)
# first transaction lock objects # first transaction lock objects
self.manager.register(uuid1, ttid1) self.manager.register(uuid1, ttid1)
self.manager.storeTransaction(ttid1, *txn1) self.assertRegistered(ttid1)
self.assertTrue(ttid1 in self.manager)
self.manager.storeObject(ttid1, serial1, *obj1) self.manager.storeObject(ttid1, serial1, *obj1)
self.manager.storeObject(ttid1, serial1, *obj2) self.manager.storeObject(ttid1, serial1, *obj2)
self.manager.lock(ttid1, tid1, txn1[0]) self.manager.vote(ttid1, txn1)
self.manager.lock(ttid1, tid1)
# second transaction is delayed # second transaction is delayed
self.manager.register(uuid2, ttid2) self.manager.register(uuid2, ttid2)
self.manager.storeTransaction(ttid2, *txn2) self.assertRegistered(ttid2)
self.assertTrue(ttid2 in self.manager)
self.assertRaises(DelayedError, self.manager.storeObject, self.assertRaises(DelayedError, self.manager.storeObject,
ttid2, serial1, *obj1) ttid2, serial1, *obj1)
self.assertRaises(DelayedError, self.manager.storeObject, self.assertRaises(DelayedError, self.manager.storeObject,
...@@ -235,15 +233,14 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -235,15 +233,14 @@ class TransactionManagerTests(NeoUnitTestBase):
serial2, obj2 = self._getObject(2) serial2, obj2 = self._getObject(2)
# the second transaction lock objects # the second transaction lock objects
self.manager.register(uuid2, ttid2) self.manager.register(uuid2, ttid2)
self.manager.storeTransaction(ttid2, *txn2)
self.manager.storeObject(ttid2, serial1, *obj1) self.manager.storeObject(ttid2, serial1, *obj1)
self.manager.storeObject(ttid2, serial2, *obj2) self.manager.storeObject(ttid2, serial2, *obj2)
self.assertTrue(ttid2 in self.manager) self.assertRegistered(ttid2)
self.manager.lock(ttid2, tid2, txn1[0]) self.manager.vote(ttid2, txn2)
self.manager.lock(ttid2, tid2)
# the first get a conflict # the first get a conflict
self.manager.register(uuid1, ttid1) self.manager.register(uuid1, ttid1)
self.manager.storeTransaction(ttid1, *txn1) self.assertRegistered(ttid1)
self.assertTrue(ttid1 in self.manager)
self.assertRaises(ConflictError, self.manager.storeObject, self.assertRaises(ConflictError, self.manager.storeObject,
ttid1, serial1, *obj1) ttid1, serial1, *obj1)
self.assertRaises(ConflictError, self.manager.storeObject, self.assertRaises(ConflictError, self.manager.storeObject,
...@@ -255,12 +252,12 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -255,12 +252,12 @@ class TransactionManagerTests(NeoUnitTestBase):
tid, txn = self._getTransaction() tid, txn = self._getTransaction()
serial, obj = self._getObject(1) serial, obj = self._getObject(1)
self.manager.register(uuid, tid) self.manager.register(uuid, tid)
self.manager.storeTransaction(tid, *txn)
self.manager.storeObject(tid, serial, *obj) self.manager.storeObject(tid, serial, *obj)
self.assertTrue(tid in self.manager) self.assertRegistered(tid)
self.manager.vote(tid, txn)
# transaction is not locked # transaction is not locked
self.manager.abort(tid) self.manager.abort(tid)
self.assertFalse(tid in self.manager) self.assertNotRegistered(tid)
self.assertFalse(self.manager.loadLocked(obj[0])) self.assertFalse(self.manager.loadLocked(obj[0]))
self._checkQueuedEventExecuted() self._checkQueuedEventExecuted()
...@@ -270,14 +267,14 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -270,14 +267,14 @@ class TransactionManagerTests(NeoUnitTestBase):
ttid = self.getNextTID() ttid = self.getNextTID()
tid, txn = self._getTransaction() tid, txn = self._getTransaction()
self.manager.register(uuid, ttid) self.manager.register(uuid, ttid)
self.manager.storeTransaction(ttid, *txn)
self._storeTransactionObjects(ttid, txn) self._storeTransactionObjects(ttid, txn)
self.manager.vote(ttid, txn)
# lock transaction # lock transaction
self.manager.lock(ttid, tid, txn[0]) self.manager.lock(ttid, tid)
self.assertTrue(ttid in self.manager) self.assertRegistered(ttid)
self.manager.abort(ttid) self.manager.abort(ttid)
self.assertTrue(ttid in self.manager) self.assertRegistered(ttid)
for oid in txn[0]: for oid in txn[-1]:
self.assertTrue(self.manager.loadLocked(oid)) self.assertTrue(self.manager.loadLocked(oid))
self._checkQueuedEventExecuted(number=0) self._checkQueuedEventExecuted(number=0)
...@@ -295,20 +292,20 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -295,20 +292,20 @@ class TransactionManagerTests(NeoUnitTestBase):
self.manager.register(uuid1, ttid1) self.manager.register(uuid1, ttid1)
self.manager.register(uuid2, ttid2) self.manager.register(uuid2, ttid2)
self.manager.register(uuid2, ttid3) self.manager.register(uuid2, ttid3)
self.manager.storeTransaction(ttid1, *txn1) self.manager.vote(ttid1, txn1)
# node 2 owns tid2 & tid3 and lock tid2 only # node 2 owns tid2 & tid3 and lock tid2 only
self.manager.storeTransaction(ttid2, *txn2)
self.manager.storeTransaction(ttid3, *txn3)
self._storeTransactionObjects(ttid2, txn2) self._storeTransactionObjects(ttid2, txn2)
self.manager.lock(ttid2, tid2, txn2[0]) self.manager.vote(ttid2, txn2)
self.assertTrue(ttid1 in self.manager) self.manager.vote(ttid3, txn3)
self.assertTrue(ttid2 in self.manager) self.manager.lock(ttid2, tid2)
self.assertTrue(ttid3 in self.manager) self.assertRegistered(ttid1)
self.assertRegistered(ttid2)
self.assertRegistered(ttid3)
self.manager.abortFor(uuid2) self.manager.abortFor(uuid2)
# only tid3 is aborted # only tid3 is aborted
self.assertTrue(ttid1 in self.manager) self.assertRegistered(ttid1)
self.assertTrue(ttid2 in self.manager) self.assertRegistered(ttid2)
self.assertFalse(ttid3 in self.manager) self.assertNotRegistered(ttid3)
self._checkQueuedEventExecuted(number=1) self._checkQueuedEventExecuted(number=1)
def testReset(self): def testReset(self):
...@@ -317,12 +314,12 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -317,12 +314,12 @@ class TransactionManagerTests(NeoUnitTestBase):
tid, txn = self._getTransaction() tid, txn = self._getTransaction()
ttid = self.getNextTID() ttid = self.getNextTID()
self.manager.register(uuid, ttid) self.manager.register(uuid, ttid)
self.manager.storeTransaction(ttid, *txn)
self._storeTransactionObjects(ttid, txn) self._storeTransactionObjects(ttid, txn)
self.manager.lock(ttid, tid, txn[0]) self.manager.vote(ttid, txn)
self.assertTrue(ttid in self.manager) self.manager.lock(ttid, tid)
self.assertRegistered(ttid)
self.manager.reset() self.manager.reset()
self.assertFalse(ttid in self.manager) self.assertNotRegistered(ttid)
for oid in txn[0]: for oid in txn[0]:
self.assertFalse(self.manager.loadLocked(oid)) self.assertFalse(self.manager.loadLocked(oid))
......
#
# Copyright (C) 2009-2015 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from mock import Mock
from .. import NeoUnitTestBase
from neo.lib.pt import PartitionTable
from neo.storage.app import Application
from neo.storage.handlers.verification import VerificationHandler
from neo.lib.protocol import CellStates, ErrorCodes
from neo.lib.exception import PrimaryFailure
from neo.lib.util import p64, u64
class StorageVerificationHandlerTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
self.prepareDatabase(number=1)
# create an application object
config = self.getStorageConfiguration(master_number=1)
self.app = Application(config)
self.verification = VerificationHandler(self.app)
# define some variable to simulate client and storage node
self.master_port = 10010
self.storage_port = 10020
self.client_port = 11011
self.num_partitions = 1009
self.num_replicas = 2
self.app.operational = False
self.app.load_lock_dict = {}
self.app.pt = PartitionTable(self.num_partitions, self.num_replicas)
def _tearDown(self, success):
self.app.close()
del self.app
super(StorageVerificationHandlerTests, self)._tearDown(success)
# Common methods
def getMasterConnection(self):
return self.getFakeConnection(address=("127.0.0.1", self.master_port))
# Tests
def test_03_connectionClosed(self):
conn = self.getMasterConnection()
self.app.listening_conn = object() # mark as running
self.assertRaises(PrimaryFailure, self.verification.connectionClosed, conn,)
# nothing happens
self.checkNoPacketSent(conn)
def test_08_askPartitionTable(self):
node = self.app.nm.createStorage(
address=("127.7.9.9", 1),
uuid=self.getStorageUUID()
)
self.app.pt.setCell(1, node, CellStates.UP_TO_DATE)
self.assertTrue(self.app.pt.hasOffset(1))
conn = self.getMasterConnection()
self.verification.askPartitionTable(conn)
ptid, row_list = self.checkAnswerPartitionTable(conn, decode=True)
self.assertEqual(len(row_list), 1009)
def test_10_notifyPartitionChanges(self):
# old partition change
conn = self.getMasterConnection()
self.verification.notifyPartitionChanges(conn, 1, ())
self.verification.notifyPartitionChanges(conn, 0, ())
self.assertEqual(self.app.pt.getID(), 1)
# new node
conn = self.getMasterConnection()
new_uuid = self.getStorageUUID()
cell = (0, new_uuid, CellStates.UP_TO_DATE)
self.app.nm.createStorage(uuid=new_uuid)
self.app.pt = PartitionTable(1, 1)
self.app.dm = Mock({ })
ptid = self.getPTID()
# pt updated
self.verification.notifyPartitionChanges(conn, ptid, (cell, ))
# check db update
calls = self.app.dm.mockGetNamedCalls('changePartitionTable')
self.assertEqual(len(calls), 1)
self.assertEqual(calls[0].getParam(0), ptid)
self.assertEqual(calls[0].getParam(1), (cell, ))
def test_13_askUnfinishedTransactions(self):
# client connection with no data
self.app.dm = Mock({
'getUnfinishedTIDList': [],
})
conn = self.getMasterConnection()
self.verification.askUnfinishedTransactions(conn)
(max_tid, tid_list) = self.checkAnswerUnfinishedTransactions(conn, decode=True)
self.assertEqual(len(tid_list), 0)
call_list = self.app.dm.mockGetNamedCalls('getUnfinishedTIDList')
self.assertEqual(len(call_list), 1)
call_list[0].checkArgs()
# client connection with some data
self.app.dm = Mock({
'getUnfinishedTIDList': [p64(4)],
})
conn = self.getMasterConnection()
self.verification.askUnfinishedTransactions(conn)
(max_tid, tid_list) = self.checkAnswerUnfinishedTransactions(conn, decode=True)
self.assertEqual(len(tid_list), 1)
self.assertEqual(u64(tid_list[0]), 4)
def test_14_askTransactionInformation(self):
# ask from client conn with no data
self.app.dm = Mock({
'getTransaction': None,
})
conn = self.getMasterConnection()
tid = p64(1)
self.verification.askTransactionInformation(conn, tid)
code, message = self.checkErrorPacket(conn, decode=True)
self.assertEqual(code, ErrorCodes.TID_NOT_FOUND)
call_list = self.app.dm.mockGetNamedCalls('getTransaction')
self.assertEqual(len(call_list), 1)
call_list[0].checkArgs(tid, all=True)
# input some tmp data and ask from client, must find both transaction
self.app.dm = Mock({
'getTransaction': ([p64(2)], 'u2', 'd2', 'e2', False),
})
conn = self.getMasterConnection()
self.verification.askTransactionInformation(conn, p64(1))
tid, user, desc, ext, packed, oid_list = self.checkAnswerTransactionInformation(conn, decode=True)
self.assertEqual(u64(tid), 1)
self.assertEqual(user, 'u2')
self.assertEqual(desc, 'd2')
self.assertEqual(ext, 'e2')
self.assertFalse(packed)
self.assertEqual(len(oid_list), 1)
self.assertEqual(u64(oid_list[0]), 2)
def test_15_askObjectPresent(self):
# client connection with no data
self.app.dm = Mock({
'objectPresent': False,
})
conn = self.getMasterConnection()
oid, tid = p64(1), p64(2)
self.verification.askObjectPresent(conn, oid, tid)
code, message = self.checkErrorPacket(conn, decode=True)
self.assertEqual(code, ErrorCodes.OID_NOT_FOUND)
call_list = self.app.dm.mockGetNamedCalls('objectPresent')
self.assertEqual(len(call_list), 1)
call_list[0].checkArgs(oid, tid)
# client connection with some data
self.app.dm = Mock({
'objectPresent': True,
})
conn = self.getMasterConnection()
self.verification.askObjectPresent(conn, oid, tid)
oid, tid = self.checkAnswerObjectPresent(conn, decode=True)
self.assertEqual(u64(tid), 2)
self.assertEqual(u64(oid), 1)
def test_16_deleteTransaction(self):
# client connection with no data
self.app.dm = Mock({
'deleteTransaction': None,
})
conn = self.getMasterConnection()
oid_list = [self.getOID(1), self.getOID(2)]
tid = p64(1)
self.verification.deleteTransaction(conn, tid, oid_list)
call_list = self.app.dm.mockGetNamedCalls('deleteTransaction')
self.assertEqual(len(call_list), 1)
call_list[0].checkArgs(tid, oid_list)
def test_17_commitTransaction(self):
# commit a transaction
conn = self.getMasterConnection()
dm = Mock()
self.app.dm = dm
self.verification.commitTransaction(conn, p64(1))
self.assertEqual(len(dm.mockGetNamedCalls("finishTransaction")), 1)
call = dm.mockGetNamedCalls("finishTransaction")[0]
tid = call.getParam(0)
self.assertEqual(u64(tid), 1)
if __name__ == "__main__":
unittest.main()
...@@ -33,7 +33,7 @@ from neo.lib import logging ...@@ -33,7 +33,7 @@ from neo.lib import logging
from neo.lib.connection import BaseConnection, Connection from neo.lib.connection import BaseConnection, Connection
from neo.lib.connector import SocketConnector, ConnectorException from neo.lib.connector import SocketConnector, ConnectorException
from neo.lib.locking import SimpleQueue from neo.lib.locking import SimpleQueue
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, NodeTypes from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes
from neo.lib.util import cached_property, parseMasterList, p64 from neo.lib.util import cached_property, parseMasterList, p64
from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \ from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_USER ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_USER
...@@ -478,6 +478,8 @@ class ConnectionFilter(object): ...@@ -478,6 +478,8 @@ class ConnectionFilter(object):
queue.appendleft(packet) queue.appendleft(packet)
break break
else: else:
if conn.isClosed():
return
cls._addPacket(conn, packet) cls._addPacket(conn, packet)
continue continue
break break
...@@ -731,9 +733,12 @@ class NEOCluster(object): ...@@ -731,9 +733,12 @@ class NEOCluster(object):
return node[3] return node[3]
def getOutdatedCells(self): def getOutdatedCells(self):
return [cell for row in self.neoctl.getPartitionRowList()[1] # Ask the admin instead of the primary master to check that it is
for cell in row[1] # notified of every change.
if cell[1] == CellStates.OUT_OF_DATE] return [(i, cell.getUUID())
for i, row in enumerate(self.admin.pt.partition_list)
for cell in row
if not cell.isReadable()]
def getZODBStorage(self, **kw): def getZODBStorage(self, **kw):
kw['_app'] = kw.pop('client', self.client) kw['_app'] = kw.pop('client', self.client)
......
...@@ -21,12 +21,12 @@ import transaction ...@@ -21,12 +21,12 @@ import transaction
import unittest import unittest
from thread import get_ident from thread import get_ident
from zlib import compress from zlib import compress
from persistent import Persistent from persistent import Persistent, GHOST
from ZODB import DB, POSException from ZODB import DB, POSException
from neo.storage.transactions import TransactionManager, \ from neo.storage.transactions import TransactionManager, \
DelayedError, ConflictError DelayedError, ConflictError
from neo.lib.connection import ConnectionClosed, MTClientConnection from neo.lib.connection import ConnectionClosed, MTClientConnection
from neo.lib.exception import OperationFailure from neo.lib.exception import DatabaseFailure, StoppedOperation
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \ from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
ZERO_TID ZERO_TID
from .. import expectedFailure, _ExpectedFailure, _UnexpectedSuccess, Patch from .. import expectedFailure, _ExpectedFailure, _UnexpectedSuccess, Patch
...@@ -34,6 +34,7 @@ from . import NEOCluster, NEOThreadedTest ...@@ -34,6 +34,7 @@ from . import NEOCluster, NEOThreadedTest
from neo.lib.util import add64, makeChecksum, p64, u64 from neo.lib.util import add64, makeChecksum, p64, u64
from neo.client.exception import NEOStorageError from neo.client.exception import NEOStorageError
from neo.client.pool import CELL_CONNECTED, CELL_GOOD from neo.client.pool import CELL_CONNECTED, CELL_GOOD
from neo.storage.handlers.initialization import InitializationHandler
class PCounter(Persistent): class PCounter(Persistent):
value = 0 value = 0
...@@ -394,38 +395,106 @@ class Test(NEOThreadedTest): ...@@ -394,38 +395,106 @@ class Test(NEOThreadedTest):
finally: finally:
cluster.stop() cluster.stop()
def testRestartStoragesWithReplicas(self):
"""
Check that the master must discard its partition table when the
cluster is not operational anymore. Which means that it must go back
to RECOVERING state and remain there as long as the partition table
can't be operational.
This also checks that if the master remains the primary one after going
back to recovery, it automatically starts the cluster if possible
(i.e. without manual intervention).
"""
outdated = []
def doOperation(orig):
outdated.append(cluster.getOutdatedCells())
orig()
def stop():
with cluster.master.filterConnection(s0) as m2s0:
m2s0.add(lambda conn, packet:
isinstance(packet, Packets.NotifyPartitionChanges))
s1.stop()
cluster.join((s1,))
self.assertEqual(getClusterState(), ClusterStates.RUNNING)
self.assertEqual(cluster.getOutdatedCells(),
[(0, s1.uuid), (1, s1.uuid)])
s0.stop()
cluster.join((s0,))
self.assertNotEqual(getClusterState(), ClusterStates.RUNNING)
s0.resetNode()
s1.resetNode()
cluster = NEOCluster(storage_count=2, partitions=2, replicas=1)
try:
cluster.start()
s0, s1 = cluster.storage_list
getClusterState = cluster.neoctl.getClusterState
if 1:
# Scenario 1: When all storage nodes are restarting,
# we want a chance to not restart with outdated cells.
stop()
with Patch(s1, doOperation=doOperation):
s0.start()
s1.start()
self.tic()
self.assertEqual(getClusterState(), ClusterStates.RUNNING)
self.assertEqual(outdated, [[]])
if 1:
# Scenario 2: When only the first storage node to be stopped
# is started, the cluster must be able to restart.
stop()
s1.start()
self.tic()
# The master doesn't wait for s0 to come back.
self.assertEqual(getClusterState(), ClusterStates.RUNNING)
self.assertEqual(cluster.getOutdatedCells(),
[(0, s0.uuid), (1, s0.uuid)])
finally:
cluster.stop()
def testVerificationCommitUnfinishedTransactions(self): def testVerificationCommitUnfinishedTransactions(self):
""" Verification step should commit locked transactions """ """ Verification step should commit locked transactions """
def delayUnlockInformation(conn, packet): def delayUnlockInformation(conn, packet):
return isinstance(packet, Packets.NotifyUnlockInformation) return isinstance(packet, Packets.NotifyUnlockInformation)
def onStoreTransaction(storage, die=False): def onLockTransaction(storage, die=False):
def storeTransaction(orig, *args, **kw): def lock(orig, *args, **kw):
orig(*args, **kw)
if die: if die:
sys.exit() sys.exit()
orig(*args, **kw)
storage.master_conn.close() storage.master_conn.close()
return Patch(storage.dm, storeTransaction=storeTransaction) return Patch(storage.tm, lock=lock)
cluster = NEOCluster(partitions=2, storage_count=2) cluster = NEOCluster(partitions=2, storage_count=2)
try: try:
cluster.start() cluster.start()
s0, s1 = cluster.sortStorageList() s0, s1 = cluster.sortStorageList()
t, c = cluster.getTransaction() t, c = cluster.getTransaction()
r = c.root() r = c.root()
r[0] = x = PCounter() r[0] = PCounter()
tids = [r._p_serial] tids = [r._p_serial]
t.commit() with onLockTransaction(s0), onLockTransaction(s1):
self.assertRaises(ConnectionClosed, t.commit)
self.assertEqual(r._p_state, GHOST)
self.tic()
t.begin()
x = r[0]
self.assertEqual(x.value, 0)
cluster.master.tm._last_oid = x._p_oid
tids.append(r._p_serial) tids.append(r._p_serial)
r[1] = PCounter() r[1] = PCounter()
with onStoreTransaction(s0), onStoreTransaction(s1): c.readCurrent(x)
with cluster.moduloTID(1):
with onLockTransaction(s0), onLockTransaction(s1):
self.assertRaises(ConnectionClosed, t.commit) self.assertRaises(ConnectionClosed, t.commit)
self.tic() self.tic()
t.begin() t.begin()
# The following line checks that s1 moved the transaction
# metadata to final place during the verification phase.
# If it didn't, a NEOStorageError would be raised.
self.assertEqual(3, len(c.db().history(r._p_oid, 4)))
y = r[1] y = r[1]
self.assertEqual(y.value, 0) self.assertEqual(y.value, 0)
assert [u64(o._p_oid) for o in (r, x, y)] == range(3) self.assertEqual([u64(o._p_oid) for o in (r, x, y)], range(3))
r[2] = 'ok' r[2] = 'ok'
with cluster.master.filterConnection(s0) as m2s, \ with cluster.master.filterConnection(s0) as m2s:
cluster.moduloTID(1):
m2s.add(delayUnlockInformation) m2s.add(delayUnlockInformation)
t.commit() t.commit()
x.value = 1 x.value = 1
...@@ -433,12 +502,15 @@ class Test(NEOThreadedTest): ...@@ -433,12 +502,15 @@ class Test(NEOThreadedTest):
# never lock the transaction (packets from master delayed), # never lock the transaction (packets from master delayed),
# so the last transaction will be dropped. # so the last transaction will be dropped.
y.value = 2 y.value = 2
with onStoreTransaction(s1, die=True): di0 = s0.getDataLockInfo()
with onLockTransaction(s1, die=True):
self.assertRaises(ConnectionClosed, t.commit) self.assertRaises(ConnectionClosed, t.commit)
finally: finally:
cluster.stop() cluster.stop()
cluster.reset() cluster.reset()
di0 = s0.getDataLockInfo() (k, v), = set(s0.getDataLockInfo().iteritems()
).difference(di0.iteritems())
self.assertEqual(v, 1)
k, = (k for k, v in di0.iteritems() if v == 1) k, = (k for k, v in di0.iteritems() if v == 1)
di0[k] = 0 # r[2] = 'ok' di0[k] = 0 # r[2] = 'ok'
self.assertEqual(di0.values(), [0, 0, 0, 0, 0]) self.assertEqual(di0.values(), [0, 0, 0, 0, 0])
...@@ -458,6 +530,51 @@ class Test(NEOThreadedTest): ...@@ -458,6 +530,51 @@ class Test(NEOThreadedTest):
finally: finally:
cluster.stop() cluster.stop()
def testDropUnfinishedData(self):
def lock(orig, *args, **kw):
orig(*args, **kw)
storage.master_conn.close()
r = []
def dropUnfinishedData(orig):
r.append(len(orig.__self__.getUnfinishedTIDDict()))
orig()
r.append(len(orig.__self__.getUnfinishedTIDDict()))
cluster = NEOCluster(partitions=2, storage_count=2, replicas=1)
try:
cluster.start()
t, c = cluster.getTransaction()
c.root()._p_changed = 1
storage = cluster.storage_list[0]
with Patch(storage.tm, lock=lock), \
Patch(storage.dm, dropUnfinishedData=dropUnfinishedData):
t.commit()
self.tic()
self.assertEqual(r, [1, 0])
finally:
cluster.stop()
def testStorageUpgrade1(self):
cluster = NEOCluster()
try:
cluster.start()
storage = cluster.storage
t, c = cluster.getTransaction()
storage.dm.setConfiguration("version", None)
c.root()._p_changed = 1
t.commit()
storage.stop()
cluster.join((storage,))
storage.resetNode()
storage.start()
t.begin()
storage.dm.setConfiguration("version", None)
c.root()._p_changed = 1
with Patch(storage.tm, lock=lambda *_: sys.exit()):
self.assertRaises(ConnectionClosed, t.commit)
self.assertRaises(DatabaseFailure, storage.resetNode)
finally:
cluster.stop()
def testStorageReconnectDuringStore(self): def testStorageReconnectDuringStore(self):
cluster = NEOCluster(replicas=1) cluster = NEOCluster(replicas=1)
try: try:
...@@ -550,12 +667,17 @@ class Test(NEOThreadedTest): ...@@ -550,12 +667,17 @@ class Test(NEOThreadedTest):
cluster.start() cluster.start()
# prevent storage to reconnect, in order to easily test # prevent storage to reconnect, in order to easily test
# that cluster becomes non-operational # that cluster becomes non-operational
storage.connectToPrimary = sys.exit with Patch(storage, connectToPrimary=sys.exit):
# send an unexpected to master so it aborts connection to storage # send an unexpected to master so it aborts connection to storage
storage.master_conn.answer(Packets.Pong()) storage.master_conn.answer(Packets.Pong())
self.tic() self.tic()
self.assertEqual(cluster.neoctl.getClusterState(), self.assertEqual(cluster.neoctl.getClusterState(),
ClusterStates.VERIFYING) ClusterStates.RECOVERING)
storage.resetNode()
storage.start()
self.tic()
self.assertEqual(cluster.neoctl.getClusterState(),
ClusterStates.RUNNING)
finally: finally:
cluster.stop() cluster.stop()
...@@ -833,7 +955,7 @@ class Test(NEOThreadedTest): ...@@ -833,7 +955,7 @@ class Test(NEOThreadedTest):
def testStorageFailureDuringTpcFinish(self): def testStorageFailureDuringTpcFinish(self):
def answerTransactionFinished(conn, packet): def answerTransactionFinished(conn, packet):
if isinstance(packet, Packets.AnswerTransactionFinished): if isinstance(packet, Packets.AnswerTransactionFinished):
raise OperationFailure raise StoppedOperation
cluster = NEOCluster() cluster = NEOCluster()
try: try:
cluster.start() cluster.start()
...@@ -849,8 +971,12 @@ class Test(NEOThreadedTest): ...@@ -849,8 +971,12 @@ class Test(NEOThreadedTest):
raise _UnexpectedSuccess raise _UnexpectedSuccess
except ConnectionClosed, e: except ConnectionClosed, e:
e = type(e), None, None e = type(e), None, None
# Also check that the master reset the last oid to a correct value.
self.assertTrue(cluster.client.new_oid_list)
t.begin() t.begin()
self.assertIn('x', c.root()) self.assertEqual(1, u64(c.root()['x']._p_oid))
self.assertFalse(cluster.client.new_oid_list)
self.assertEqual(2, u64(cluster.client.new_oid()))
finally: finally:
cluster.stop() cluster.stop()
raise _ExpectedFailure(e) raise _ExpectedFailure(e)
...@@ -908,5 +1034,111 @@ class Test(NEOThreadedTest): ...@@ -908,5 +1034,111 @@ class Test(NEOThreadedTest):
cluster.stop() cluster.stop()
del cluster.startCluster del cluster.startCluster
def testAbortVotedTransaction(self):
r = []
def tpc_finish(*args, **kw):
for storage in cluster.storage_list:
r.append(len(storage.dm.getUnfinishedTIDDict()))
raise NEOStorageError
cluster = NEOCluster(storage_count=2, partitions=2)
try:
cluster.start()
t, c = cluster.getTransaction()
c.root()['x'] = PCounter()
with Patch(cluster.client, tpc_finish=tpc_finish):
self.assertRaises(NEOStorageError, t.commit)
self.tic()
self.assertEqual(r, [1, 1])
for storage in cluster.storage_list:
self.assertFalse(storage.dm.getUnfinishedTIDDict())
t.begin()
self.assertNotIn('x', c.root())
finally:
cluster.stop()
def testStorageLostDuringRecovery(self):
# Initialize a cluster.
cluster = NEOCluster(storage_count=2, partitions=2)
try:
cluster.start()
finally:
cluster.stop()
cluster.reset()
# Restart with a connection failure for the first AskPartitionTable.
# The master must not be stuck in RECOVERING state
# or re-make the partition table.
def make(*args):
sys.exit()
def askPartitionTable(orig, self, conn):
p.revert()
conn.close()
try:
with Patch(cluster.master.pt, make=make), \
Patch(InitializationHandler,
askPartitionTable=askPartitionTable) as p:
cluster.start()
self.assertFalse(p.applied)
finally:
cluster.stop()
def testTruncate(self):
calls = [0, 0]
def dieFirst(i):
def f(orig, *args, **kw):
calls[i] += 1
if calls[i] == 1:
sys.exit()
return orig(*args, **kw)
return f
cluster = NEOCluster(replicas=1)
try:
cluster.start()
t, c = cluster.getTransaction()
r = c.root()
tids = []
for x in xrange(4):
r[x] = None
t.commit()
tids.append(r._p_serial)
truncate_tid = tids[2]
r['x'] = PCounter()
s0, s1 = cluster.storage_list
with Patch(s0.tm, unlock=dieFirst(0)), \
Patch(s1.dm, truncate=dieFirst(1)):
t.commit()
cluster.neoctl.truncate(truncate_tid)
self.tic()
getClusterState = cluster.neoctl.getClusterState
# Unless forced, the cluster waits all nodes to be up,
# so that all nodes are truncated.
self.assertEqual(getClusterState(), ClusterStates.RECOVERING)
self.assertEqual(calls, [1, 0])
s0.resetNode()
s0.start()
# s0 died with unfinished data, and before processing the
# Truncate packet from the master.
self.assertFalse(s0.dm.getTruncateTID())
self.assertEqual(s1.dm.getTruncateTID(), truncate_tid)
self.tic()
self.assertEqual(calls, [1, 1])
self.assertEqual(getClusterState(), ClusterStates.RECOVERING)
s1.resetNode()
with Patch(s1.dm, truncate=dieFirst(1)):
s1.start()
self.assertEqual(s0.dm.getLastIDs()[0], truncate_tid)
self.assertEqual(s1.dm.getLastIDs()[0], r._p_serial)
self.tic()
self.assertEqual(calls, [1, 2])
self.assertEqual(getClusterState(), ClusterStates.RUNNING)
t.begin()
self.assertEqual(r, dict.fromkeys(xrange(3)))
self.assertEqual(r._p_serial, truncate_tid)
self.assertEqual(1, u64(c._storage.new_oid()))
for s in cluster.storage_list:
self.assertEqual(s.dm.getLastIDs()[0], truncate_tid)
finally:
cluster.stop()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -424,7 +424,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -424,7 +424,7 @@ class ReplicationTests(NEOThreadedTest):
check(ClusterStates.RUNNING, 1) check(ClusterStates.RUNNING, 1)
cluster.neoctl.checkReplicas(check_dict, ZERO_TID, None) cluster.neoctl.checkReplicas(check_dict, ZERO_TID, None)
self.tic() self.tic()
check(ClusterStates.VERIFYING, 4) check(ClusterStates.RECOVERING, 4)
finally: finally:
checker.CHECK_COUNT = CHECK_COUNT checker.CHECK_COUNT = CHECK_COUNT
cluster.stop() cluster.stop()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment