Commit 7eb7cf1b authored by Julien Muchembled's avatar Julien Muchembled

Minimize the amount of work during tpc_finish

NEO did not ensure that all data and metadata are written on disk before
tpc_finish, and it was for example vulnerable to ENOSPC errors.
In other words, some work had to be moved to tpc_vote:

- In tpc_vote, all involved storage nodes are now asked to write all metadata
  to ttrans/tobj and _commit_. Because the final tid is not known yet, the tid
  column of ttrans and tobj now contains NULL and the ttid respectively.

- In tpc_finish, AskLockInformation is still required for read locking,
  ttrans.tid is updated with the final value and this change is _committed_.

- The verification phase is greatly simplified, more reliable and faster. For
  all voted transactions, we can know if a tpc_finish was started by getting
  the final tid from the ttid, either from ttrans or from trans. And we know
  that such transactions can't be partial so we don't need to check oids.

So in addition to minimizing the risk of failures during tpc_finish, we also
fix a bug causing the verification phase to discard transactions with objects
for which readCurrent was called.

On performance side:

- Although tpc_vote now asks all involved storages, instead of only those
  storing the transaction metadata, the client has been improved to do this
  in parallel. The additional commits are also all done in parallel.

- A possible improvement to compensate the additional commits is to delay the
  commit done by the unlock.

- By minimizing the time to lock transactions, objects are read-locked for a
  much shorter period. This is even more important that locked transactions
  must be unlocked in the same order.

Transactions with too many modified objects will now timeout inside tpc_vote
instead of tpc_finish. Of course, such transactions may still cause other
transaction to timeout in tpc_finish.
parent 99ac542c
......@@ -58,8 +58,6 @@
committed by future transactions.
- Add a 'devid' storage configuration so that master do not distribute
replicated partitions on storages with same 'devid'.
- Make tpc_finish safer as described in its __doc__: moving work to
tpc_vote and recover from master failure when possible.
Storage
- Use libmysqld instead of a stand-alone MySQL server.
......
......@@ -612,18 +612,29 @@ class Application(ThreadedApplication):
packet = Packets.AskStoreTransaction(ttid, str(transaction.user),
str(transaction.description), dumps(transaction._extension),
txn_context['cache_dict'])
add_involved_nodes = txn_context['involved_nodes'].add
queue = txn_context['queue']
trans_nodes = []
for node, conn in self.cp.iterateForObject(ttid):
logging.debug("voting transaction %s on %s", dump(ttid),
dump(conn.getUUID()))
try:
self._askStorage(conn, packet)
conn.ask(packet, queue=queue)
except ConnectionClosed:
continue
add_involved_nodes(node)
trans_nodes.append(node)
# check at least one storage node accepted
if txn_context['involved_nodes']:
if trans_nodes:
involved_nodes = txn_context['involved_nodes']
packet = Packets.AskVoteTransaction(ttid)
for node in involved_nodes.difference(trans_nodes):
conn = self.cp.getConnForNode(node)
if conn is not None:
try:
conn.ask(packet, queue=queue)
except ConnectionClosed:
pass
involved_nodes.update(trans_nodes)
self.waitResponses(queue)
txn_context['voted'] = None
# We must not go further if connection to master was lost since
# tpc_begin, to lower the probability of failing during tpc_finish.
......@@ -667,27 +678,14 @@ class Application(ThreadedApplication):
fail in tpc_finish. In particular, making a transaction permanent
should ideally be as simple as switching a bit permanently.
In NEO, tpc_finish breaks this promise by not ensuring earlier that all
data and metadata are written, and it is for example vulnerable to
ENOSPC errors. In other words, some work should be moved to tpc_vote.
TODO: - In tpc_vote, all involved storage nodes must be asked to write
all metadata to ttrans/tobj and _commit_. AskStoreTransaction
can be extended for this: for nodes that don't store anything
in ttrans, it can just contain the ttid. The final tid is not
known yet, so ttrans/tobj would contain the ttid.
- In tpc_finish, AskLockInformation is still required for read
locking, ttrans.tid must be updated with the final value and
ttrans _committed_.
- The Verification phase would need some change because
ttrans/tobj may contain data for which tpc_finish was not
called. The ttid is also in trans so a mapping ttid<->tid is
always possible and can be forwarded via the master so that all
storage are still able to update the tid column with the final
value when moving rows from tobj to obj.
The resulting cost is:
- additional RPCs in tpc_vote
- 1 updated row in ttrans + commit
In NEO, all the data (with the exception of the tid, simply because
it is not known yet) is already flushed on disk at the end on the vote.
During tpc_finish, all nodes storing the transaction metadata are asked
to commit by saving the new tid and flushing again: for SQL backends,
it's just an UPDATE of 1 cell. At last, the metadata is moved to
a final place so that the new transaction is readable, but this is
something that can always be replayed (during the verification phase)
if any failure happens.
TODO: We should recover from master failures when the transaction got
successfully committed. More precisely, we should not raise:
......
......@@ -112,9 +112,11 @@ class StorageAnswersHandler(AnswerBaseHandler):
answerCheckCurrentSerial = answerStoreObject
def answerStoreTransaction(self, conn, _):
def answerStoreTransaction(self, conn):
pass
answerVoteTransaction = answerStoreTransaction
def answerTIDsFrom(self, conn, tid_list):
logging.debug('Get %u TIDs from %r', len(tid_list), conn)
self.app.setHandlerData(tid_list)
......
......@@ -786,8 +786,8 @@ class StopOperation(Packet):
class UnfinishedTransactions(Packet):
"""
Ask unfinished transactions PM -> S.
Answer unfinished transactions S -> PM.
Ask unfinished transactions S -> PM.
Answer unfinished transactions PM -> S.
"""
_answer = PStruct('answer_unfinished_transactions',
PTID('max_tid'),
......@@ -796,36 +796,36 @@ class UnfinishedTransactions(Packet):
),
)
class ObjectPresent(Packet):
class LockedTransactions(Packet):
"""
Ask if an object is present. If not present, OID_NOT_FOUND should be
returned. PM -> S.
Answer that an object is present. PM -> S.
Ask locked transactions PM -> S.
Answer locked transactions S -> PM.
"""
_fmt = PStruct('object_present',
POID('oid'),
PTID('tid'),
)
_answer = PStruct('object_present',
POID('oid'),
_answer = PStruct('answer_locked_transactions',
PDict('tid_dict',
PTID('ttid'),
PTID('tid'),
),
)
class DeleteTransaction(Packet):
class FinalTID(Packet):
"""
Delete a transaction. PM -> S.
Return final tid if ttid has been committed. * -> S.
"""
_fmt = PStruct('delete_transaction',
_fmt = PStruct('final_tid',
PTID('ttid'),
)
_answer = PStruct('final_tid',
PTID('tid'),
PFOidList,
)
class CommitTransaction(Packet):
class ValidateTransaction(Packet):
"""
Commit a transaction. PM -> S.
"""
_fmt = PStruct('commit_transaction',
_fmt = PStruct('validate_transaction',
PTID('ttid'),
PTID('tid'),
)
......@@ -878,11 +878,10 @@ class LockInformation(Packet):
_fmt = PStruct('ask_lock_informations',
PTID('ttid'),
PTID('tid'),
PFOidList,
)
_answer = PStruct('answer_information_locked',
PTID('tid'),
PTID('ttid'),
)
class InvalidateObjects(Packet):
......@@ -899,7 +898,7 @@ class UnlockInformation(Packet):
Unlock information on a transaction. PM -> S.
"""
_fmt = PStruct('notify_unlock_information',
PTID('tid'),
PTID('ttid'),
)
class GenerateOIDs(Packet):
......@@ -961,10 +960,17 @@ class StoreTransaction(Packet):
PString('extension'),
PFOidList,
)
_answer = PFEmpty
_answer = PStruct('answer_store_transaction',
class VoteTransaction(Packet):
"""
Ask to store a transaction. C -> S.
Answer if transaction has been stored. S -> C.
"""
_fmt = PStruct('ask_vote_transaction',
PTID('tid'),
)
_answer = PFEmpty
class GetObject(Packet):
"""
......@@ -1600,12 +1606,12 @@ class Packets(dict):
StopOperation)
AskUnfinishedTransactions, AnswerUnfinishedTransactions = register(
UnfinishedTransactions)
AskObjectPresent, AnswerObjectPresent = register(
ObjectPresent)
DeleteTransaction = register(
DeleteTransaction)
CommitTransaction = register(
CommitTransaction)
AskLockedTransactions, AnswerLockedTransactions = register(
LockedTransactions)
AskFinalTID, AnswerFinalTID = register(
FinalTID)
ValidateTransaction = register(
ValidateTransaction)
AskBeginTransaction, AnswerBeginTransaction = register(
BeginTransaction)
AskFinishTransaction, AnswerTransactionFinished = register(
......@@ -1624,6 +1630,8 @@ class Packets(dict):
AbortTransaction)
AskStoreTransaction, AnswerStoreTransaction = register(
StoreTransaction)
AskVoteTransaction, AnswerVoteTransaction = register(
VoteTransaction)
AskObject, AnswerObject = register(
GetObject)
AskTIDs, AnswerTIDs = register(
......
......@@ -59,9 +59,10 @@ class ClientServiceHandler(MasterHandler):
pt = app.pt
# Collect partitions related to this transaction.
lock_oid_list = oid_list + checked_list
partition_set = set(map(pt.getPartition, lock_oid_list))
partition_set.add(pt.getPartition(ttid))
getPartition = pt.getPartition
partition_set = set(map(getPartition, oid_list))
partition_set.update(map(getPartition, checked_list))
partition_set.add(getPartition(ttid))
# Collect the UUIDs of nodes related to this transaction.
uuid_list = filter(app.isStorageReady, {cell.getUUID()
......@@ -85,7 +86,6 @@ class ClientServiceHandler(MasterHandler):
{x.getUUID() for x in identified_node_list},
conn.getPeerId(),
),
lock_oid_list,
)
for node in identified_node_list:
node.ask(p, timeout=60)
......
......@@ -359,7 +359,12 @@ class TransactionManager(object):
self._unlockPending()
def _unlockPending(self):
# unlock pending transactions
"""Serialize transaction unlocks
This should rarely delay unlocks since the time needed to lock a
transaction is roughly constant. The most common case where reordering
is required is when some storages are already busy by other tasks.
"""
queue = self._queue
pop = queue.pop
insert = queue.insert
......
......@@ -14,6 +14,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from collections import defaultdict
from neo.lib import logging
from neo.lib.util import dump
from neo.lib.protocol import ClusterStates, Packets, NodeStates
......@@ -37,10 +38,9 @@ class VerificationManager(BaseServiceHandler):
"""
def __init__(self, app):
self._oid_set = set()
self._tid_set = set()
self._locked_dict = {}
self._voted_dict = defaultdict(set)
self._uuid_set = set()
self._object_present = False
def _askStorageNodesAndWait(self, packet, node_list):
poll = self.app.em.poll
......@@ -87,7 +87,6 @@ class VerificationManager(BaseServiceHandler):
self.app.broadcastPartitionChanges(self.app.pt.outdate())
def verifyData(self):
"""Verify the data in storage nodes and clean them up, if necessary."""
app = self.app
# wait for any missing node
......@@ -100,106 +99,59 @@ class VerificationManager(BaseServiceHandler):
logging.info('start to verify data')
getIdentifiedList = app.nm.getIdentifiedList
# Gather all unfinished transactions.
self._askStorageNodesAndWait(Packets.AskUnfinishedTransactions(),
# Gather all transactions that may have been partially finished.
self._askStorageNodesAndWait(Packets.AskLockedTransactions(),
[x for x in getIdentifiedList() if x.isStorage()])
# Gather OIDs for each unfinished TID, and verify whether the
# transaction can be finished or must be aborted. This could be
# in parallel in theory, but not so easy. Thus do it one-by-one
# at the moment.
for tid in self._tid_set:
uuid_set = self.verifyTransaction(tid)
if uuid_set is None:
packet = Packets.DeleteTransaction(tid, self._oid_set or [])
# Make sure that no node has this transaction.
for node in getIdentifiedList():
if node.isStorage():
node.notify(packet)
# Some nodes may have already unlocked these transactions and
# _locked_dict is incomplete, but we can ask them the final tid.
for ttid, voted_set in self._voted_dict.iteritems():
if ttid in self._locked_dict:
continue
partition = app.pt.getPartition(ttid)
for node in getIdentifiedList(pool_set={cell.getUUID()
# If an outdated cell had unlocked ttid, then either
# it is already in _locked_dict or a readable cell also
# unlocked it.
for cell in app.pt.getCellList(partition, readable=True)
} - voted_set):
self._askStorageNodesAndWait(Packets.AskFinalTID(ttid), (node,))
if self._tid is not None:
self._locked_dict[ttid] = self._tid
break
else:
# Transaction not locked. No need to tell nodes to delete it,
# since they drop any unfinished data just before being
# operational.
pass
# Finish all transactions for which we know that tpc_finish was called
# but not fully processed. This may include replicas with transactions
# that were not even locked.
for ttid, tid in self._locked_dict.iteritems():
uuid_set = self._voted_dict.get(ttid)
if uuid_set:
packet = Packets.ValidateTransaction(ttid, tid)
for node in getIdentifiedList(pool_set=uuid_set):
node.notify(packet)
if app.getLastTransaction() < tid: # XXX: refactoring needed
app.setLastTransaction(tid)
app.tm.setLastTID(tid)
packet = Packets.CommitTransaction(tid)
for node in getIdentifiedList(pool_set=uuid_set):
node.notify(packet)
self._oid_set = set()
# If possible, send the packets now.
app.em.poll(0)
def verifyTransaction(self, tid):
nm = self.app.nm
uuid_set = set()
# Determine to which nodes I should ask.
partition = self.app.pt.getPartition(tid)
uuid_list = [cell.getUUID() for cell \
in self.app.pt.getCellList(partition, readable=True)]
if len(uuid_list) == 0:
raise VerificationFailure
uuid_set.update(uuid_list)
# Gather OIDs.
node_list = self.app.nm.getIdentifiedList(pool_set=uuid_list)
if len(node_list) == 0:
raise VerificationFailure
self._askStorageNodesAndWait(Packets.AskTransactionInformation(tid),
node_list)
if self._oid_set is None or len(self._oid_set) == 0:
# Not commitable.
return None
# Verify that all objects are present.
for oid in self._oid_set:
partition = self.app.pt.getPartition(oid)
object_uuid_list = [cell.getUUID() for cell \
in self.app.pt.getCellList(partition, readable=True)]
if len(object_uuid_list) == 0:
raise VerificationFailure
uuid_set.update(object_uuid_list)
self._object_present = True
self._askStorageNodesAndWait(Packets.AskObjectPresent(oid, tid),
nm.getIdentifiedList(pool_set=object_uuid_list))
if not self._object_present:
# Not commitable.
return None
return uuid_set
def answerUnfinishedTransactions(self, conn, max_tid, tid_list):
logging.info('got unfinished transactions %s from %r',
map(dump, tid_list), conn)
self._uuid_set.remove(conn.getUUID())
self._tid_set.update(tid_list)
def answerTransactionInformation(self, conn, tid,
user, desc, ext, packed, oid_list):
self._uuid_set.remove(conn.getUUID())
oid_set = set(oid_list)
if self._oid_set is None:
# Someone does not agree.
pass
elif len(self._oid_set) == 0:
# This is the first answer.
self._oid_set.update(oid_set)
elif self._oid_set != oid_set:
raise ValueError, "Inconsistent transaction %s" % \
(dump(tid, ))
def tidNotFound(self, conn, message):
logging.info('TID not found: %s', message)
self._uuid_set.remove(conn.getUUID())
self._oid_set = None
def answerObjectPresent(self, conn, oid, tid):
logging.info('object %s:%s found', dump(oid), dump(tid))
self._uuid_set.remove(conn.getUUID())
def answerLockedTransactions(self, conn, tid_dict):
uuid = conn.getUUID()
self._uuid_set.remove(uuid)
for ttid, tid in tid_dict.iteritems():
if tid:
self._locked_dict[ttid] = tid
self._voted_dict[ttid].add(uuid)
def oidNotFound(self, conn, message):
logging.info('OID not found: %s', message)
def answerFinalTID(self, conn, tid):
self._uuid_set.remove(conn.getUUID())
self._object_present = False
self._tid = tid
def connectionCompleted(self, conn):
pass
......
......@@ -53,7 +53,6 @@ UNIT_TEST_MODULES = [
'neo.tests.storage.testMasterHandler',
'neo.tests.storage.testStorageApp',
'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
'neo.tests.storage.testVerificationHandler',
'neo.tests.storage.testIdentificationHandler',
'neo.tests.storage.testTransactions',
# client application
......
......@@ -297,8 +297,9 @@ class ImporterDatabaseManager(DatabaseManager):
self.db = buildDatabaseManager(main['adapter'],
(main['database'], main.get('engine'), main['wait']))
for x in """query erase getConfiguration _setConfiguration
getPartitionTable changePartitionTable getUnfinishedTIDList
dropUnfinishedData storeTransaction finishTransaction
getPartitionTable changePartitionTable
getUnfinishedTIDDict dropUnfinishedData abortTransaction
storeTransaction lockTransaction unlockTransaction
storeData _pruneData
""".split():
setattr(self, x, getattr(self.db, x))
......@@ -421,7 +422,7 @@ class ImporterDatabaseManager(DatabaseManager):
logging.warning("All data are imported. You should change"
" your configuration to use the native backend and restart.")
self._import = None
for x in """getObject objectPresent getReplicationTIDList
for x in """getObject getReplicationTIDList
""".split():
setattr(self, x, getattr(self.db, x))
......@@ -434,23 +435,11 @@ class ImporterDatabaseManager(DatabaseManager):
zodb = self.zodb[bisect(self.zodb_index, oid) - 1]
return zodb, oid - zodb.shift_oid
def getLastIDs(self, all=True):
tid, _, _, oid = self.db.getLastIDs(all)
def getLastIDs(self):
tid, _, _, oid = self.db.getLastIDs()
return (max(tid, util.p64(self.zodb_ltid)), None, None,
max(oid, util.p64(self.zodb_loid)))
def objectPresent(self, oid, tid, all=True):
r = self.db.objectPresent(oid, tid, all)
if not r:
u_oid = util.u64(oid)
u_tid = util.u64(tid)
if self.inZodb(u_oid, u_tid):
zodb, oid = self.zodbFromOid(u_oid)
try:
return zodb.loadSerial(util.p64(oid), tid)
except POSKeyError:
pass
def getObject(self, oid, tid=None, before_tid=None):
u64 = util.u64
u_oid = u64(oid)
......@@ -511,6 +500,16 @@ class ImporterDatabaseManager(DatabaseManager):
else:
return self.db.getTransaction(tid, all)
def getFinalTID(self, ttid):
if u64(ttid) <= self.zodb_ltid and self._import:
raise NotImplementedError
return self.db.getFinalTID(ttid)
def deleteTransaction(self, tid):
if u64(tid) <= self.zodb_ltid and self._import:
raise NotImplementedError
self.db.deleteTransaction(tid)
def getReplicationTIDList(self, min_tid, max_tid, length, partition):
p64 = util.p64
tid = p64(self.zodb_tid)
......
......@@ -222,10 +222,10 @@ class DatabaseManager(object):
"""
raise NotImplementedError
def _getLastIDs(self, all=True):
def _getLastIDs(self):
raise NotImplementedError
def getLastIDs(self, all=True):
def getLastIDs(self):
trans, obj, oid = self._getLastIDs()
if trans:
tid = max(trans.itervalues())
......@@ -241,16 +241,16 @@ class DatabaseManager(object):
trans = obj = {}
return tid, trans, obj, oid
def getUnfinishedTIDList(self):
"""Return a list of unfinished transaction's IDs."""
def _getUnfinishedTIDDict(self):
raise NotImplementedError
def objectPresent(self, oid, tid, all = True):
"""Return true iff an object specified by a given pair of an
object ID and a transaction ID is present in a database.
Otherwise, return false. If all is true, the object must be
searched from unfinished transactions as well."""
raise NotImplementedError
def getUnfinishedTIDDict(self):
trans, obj = self._getUnfinishedTIDDict()
obj = dict.fromkeys(obj)
obj.update(trans)
p64 = util.p64
return {p64(ttid): None if tid is None else p64(tid)
for ttid, tid in obj.iteritems()}
@fallback
def getLastObjectTID(self, oid):
......@@ -478,14 +478,18 @@ class DatabaseManager(object):
data_tid = p64(data_tid)
return p64(current_tid), data_tid, is_current
def finishTransaction(self, tid):
"""Finish a transaction specified by a given ID, by moving
temporarily data to a finished area."""
def lockTransaction(self, tid, ttid):
"""Mark voted transaction 'ttid' as committed with given 'tid'"""
raise NotImplementedError
def unlockTransaction(self, tid, ttid):
"""Finalize a transaction by moving data to a finished area."""
raise NotImplementedError
def abortTransaction(self, ttid):
raise NotImplementedError
def deleteTransaction(self, tid, oid_list=()):
"""Delete a transaction and its content specified by a given ID and
an oid list"""
def deleteTransaction(self, tid):
raise NotImplementedError
def deleteObject(self, oid, serial=None):
......
......@@ -214,7 +214,7 @@ class MySQLDatabaseManager(DatabaseManager):
# The table "ttrans" stores information on uncommitted transactions.
q("""CREATE TABLE IF NOT EXISTS ttrans (
`partition` SMALLINT UNSIGNED NOT NULL,
tid BIGINT UNSIGNED NOT NULL,
tid BIGINT UNSIGNED,
packed BOOLEAN NOT NULL,
oids MEDIUMBLOB NOT NULL,
user BLOB NOT NULL,
......@@ -274,7 +274,7 @@ class MySQLDatabaseManager(DatabaseManager):
return self.query("SELECT MAX(t) FROM (SELECT MAX(tid) as t FROM trans"
" WHERE tid<=%s GROUP BY `partition`) as t" % max_tid)[0][0]
def _getLastIDs(self, all=True):
def _getLastIDs(self):
p64 = util.p64
q = self.query
trans = {partition: p64(tid)
......@@ -285,29 +285,21 @@ class MySQLDatabaseManager(DatabaseManager):
" FROM obj GROUP BY `partition`")}
oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj"
" GROUP BY `partition`) as t")[0][0]
if all:
tid = q("SELECT MAX(tid) FROM ttrans")[0][0]
if tid is not None:
trans[None] = p64(tid)
tid, toid = q("SELECT MAX(tid), MAX(oid) FROM tobj")[0]
if tid is not None:
obj[None] = p64(tid)
if toid is not None and (oid < toid or oid is None):
oid = toid
return trans, obj, None if oid is None else p64(oid)
def getUnfinishedTIDList(self):
p64 = util.p64
return [p64(t[0]) for t in self.query("SELECT tid FROM ttrans"
" UNION SELECT tid FROM tobj")]
def objectPresent(self, oid, tid, all = True):
oid = util.u64(oid)
tid = util.u64(tid)
def _getUnfinishedTIDDict(self):
q = self.query
return q("SELECT 1 FROM obj WHERE `partition`=%d AND oid=%d AND tid=%d"
% (self._getPartition(oid), oid, tid)) or all and \
q("SELECT 1 FROM tobj WHERE tid=%d AND oid=%d" % (tid, oid))
return q("SELECT ttid, tid FROM ttrans"), (ttid
for ttid, in q("SELECT DISTINCT tid FROM tobj"))
def getFinalTID(self, ttid):
ttid = util.u64(ttid)
# MariaDB is smart enough to realize that 'ttid' is constant.
r = self.query("SELECT tid FROM trans"
" WHERE `partition`=%s AND tid>=ttid AND ttid=%s LIMIT 1"
% (self._getPartition(ttid), ttid))
if r:
return util.p64(r[0][0])
def getLastObjectTID(self, oid):
oid = util.u64(oid)
......@@ -450,9 +442,9 @@ class MySQLDatabaseManager(DatabaseManager):
oid_list, user, desc, ext, packed, ttid = transaction
partition = self._getPartition(tid)
assert packed in (0, 1)
q("REPLACE INTO %s VALUES (%d,%d,%i,'%s','%s','%s','%s',%d)" % (
trans_table, partition, tid, packed, e(''.join(oid_list)),
e(user), e(desc), e(ext), u64(ttid)))
q("REPLACE INTO %s VALUES (%s,%s,%s,'%s','%s','%s','%s',%s)" % (
trans_table, partition, 'NULL' if temporary else tid, packed,
e(''.join(oid_list)), e(user), e(desc), e(ext), u64(ttid)))
if temporary:
self.commit()
......@@ -544,40 +536,40 @@ class MySQLDatabaseManager(DatabaseManager):
r = self.query(sql)
return r[0] if r else (None, None)
def finishTransaction(self, tid):
def lockTransaction(self, tid, ttid):
u64 = util.u64
self.query("UPDATE ttrans SET tid=%d WHERE ttid=%d LIMIT 1"
% (u64(tid), u64(ttid)))
self.commit()
def unlockTransaction(self, tid, ttid):
q = self.query
tid = util.u64(tid)
sql = " FROM tobj WHERE tid=%d" % tid
u64 = util.u64
tid = u64(tid)
sql = " FROM tobj WHERE tid=%d" % u64(ttid)
data_id_list = [x for x, in q("SELECT data_id" + sql) if x]
q("INSERT INTO obj SELECT *" + sql)
q("DELETE FROM tobj WHERE tid=%d" % tid)
q("INSERT INTO obj SELECT `partition`, oid, %d, data_id, value_tid %s"
% (tid, sql))
q("DELETE" + sql)
q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid)
q("DELETE FROM ttrans WHERE tid=%d" % tid)
self.releaseData(data_id_list)
self.commit()
def deleteTransaction(self, tid, oid_list=()):
u64 = util.u64
tid = u64(tid)
getPartition = self._getPartition
def abortTransaction(self, ttid):
ttid = util.u64(ttid)
q = self.query
sql = " FROM tobj WHERE tid=%d" % tid
sql = " FROM tobj WHERE tid=%s" % ttid
data_id_list = [x for x, in q("SELECT data_id" + sql) if x]
self.releaseData(data_id_list)
q("DELETE" + sql)
q("""DELETE FROM ttrans WHERE tid = %d""" % tid)
q("""DELETE FROM trans WHERE `partition` = %d AND tid = %d""" %
(getPartition(tid), tid))
# delete from obj using indexes
data_id_list = set(data_id_list)
for oid in oid_list:
oid = u64(oid)
sql = " FROM obj WHERE `partition`=%d AND oid=%d AND tid=%d" \
% (getPartition(oid), oid, tid)
data_id_list.update(*q("SELECT data_id" + sql))
q("DELETE" + sql)
data_id_list.discard(None)
self._pruneData(data_id_list)
q("DELETE FROM ttrans WHERE ttid=%s" % ttid)
self.releaseData(data_id_list, True)
def deleteTransaction(self, tid):
tid = util.u64(tid)
getPartition = self._getPartition
self.query("DELETE FROM trans WHERE `partition`=%s AND tid=%s" %
(self._getPartition(tid), tid))
def deleteObject(self, oid, serial=None):
u64 = util.u64
......
......@@ -162,7 +162,7 @@ class SQLiteDatabaseManager(DatabaseManager):
# The table "ttrans" stores information on uncommitted transactions.
q("""CREATE TABLE IF NOT EXISTS ttrans (
partition INTEGER NOT NULL,
tid INTEGER NOT NULL,
tid INTEGER,
packed BOOLEAN NOT NULL,
oids BLOB NOT NULL,
user BLOB NOT NULL,
......@@ -221,7 +221,7 @@ class SQLiteDatabaseManager(DatabaseManager):
return self.query("SELECT MAX(tid) FROM trans WHERE tid<=?",
(max_tid,)).next()[0]
def _getLastIDs(self, all=True):
def _getLastIDs(self):
p64 = util.p64
q = self.query
trans = {partition: p64(tid)
......@@ -232,30 +232,21 @@ class SQLiteDatabaseManager(DatabaseManager):
" FROM obj GROUP BY partition")}
oid = q("SELECT MAX(oid) FROM (SELECT MAX(oid) AS oid FROM obj"
" GROUP BY partition) as t").next()[0]
if all:
tid = q("SELECT MAX(tid) FROM ttrans").next()[0]
if tid is not None:
trans[None] = p64(tid)
tid, toid = q("SELECT MAX(tid), MAX(oid) FROM tobj").next()
if tid is not None:
obj[None] = p64(tid)
if toid is not None and (oid < toid or oid is None):
oid = toid
return trans, obj, None if oid is None else p64(oid)
def getUnfinishedTIDList(self):
p64 = util.p64
return [p64(t[0]) for t in self.query("SELECT tid FROM ttrans"
" UNION SELECT tid FROM tobj")]
def objectPresent(self, oid, tid, all=True):
oid = util.u64(oid)
tid = util.u64(tid)
def _getUnfinishedTIDDict(self):
q = self.query
return q("SELECT 1 FROM obj WHERE partition=? AND oid=? AND tid=?",
(self._getPartition(oid), oid, tid)).fetchone() or all and \
q("SELECT 1 FROM tobj WHERE tid=? AND oid=?",
(tid, oid)).fetchone()
return q("SELECT ttid, tid FROM ttrans"), (ttid
for ttid, in q("SELECT DISTINCT tid FROM tobj"))
def getFinalTID(self, ttid):
ttid = util.u64(ttid)
# As of SQLite 3.8.7.1, 'tid>=ttid' would ignore the index on tid,
# even though ttid is a constant.
for tid, in self.query("SELECT tid FROM trans"
" WHERE partition=? AND tid>=? AND ttid=? LIMIT 1",
(self._getPartition(ttid), ttid, ttid)):
return util.p64(tid)
def getLastObjectTID(self, oid):
oid = util.u64(oid)
......@@ -362,7 +353,8 @@ class SQLiteDatabaseManager(DatabaseManager):
partition = self._getPartition(tid)
assert packed in (0, 1)
q("INSERT OR FAIL INTO %strans VALUES (?,?,?,?,?,?,?,?)" % T,
(partition, tid, packed, buffer(''.join(oid_list)),
(partition, None if temporary else tid,
packed, buffer(''.join(oid_list)),
buffer(user), buffer(desc), buffer(ext), u64(ttid)))
if temporary:
self.commit()
......@@ -407,40 +399,41 @@ class SQLiteDatabaseManager(DatabaseManager):
r = r.fetchone()
return r or (None, None)
def finishTransaction(self, tid):
args = util.u64(tid),
def lockTransaction(self, tid, ttid):
u64 = util.u64
self.query("UPDATE ttrans SET tid=? WHERE ttid=?",
(u64(tid), u64(ttid)))
self.commit()
def unlockTransaction(self, tid, ttid):
q = self.query
u64 = util.u64
tid = u64(tid)
ttid = u64(ttid)
sql = " FROM tobj WHERE tid=?"
data_id_list = [x for x, in q("SELECT data_id" + sql, args) if x]
q("INSERT OR FAIL INTO obj SELECT *" + sql, args)
q("DELETE FROM tobj WHERE tid=?", args)
q("INSERT OR FAIL INTO trans SELECT * FROM ttrans WHERE tid=?", args)
q("DELETE FROM ttrans WHERE tid=?", args)
data_id_list = [x for x, in q("SELECT data_id" + sql, (ttid,)) if x]
q("INSERT INTO obj SELECT partition, oid, ?, data_id, value_tid" + sql,
(tid, ttid))
q("DELETE" + sql, (ttid,))
q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=?", (tid,))
q("DELETE FROM ttrans WHERE tid=?", (tid,))
self.releaseData(data_id_list)
self.commit()
def deleteTransaction(self, tid, oid_list=()):
u64 = util.u64
tid = u64(tid)
getPartition = self._getPartition
def abortTransaction(self, ttid):
args = util.u64(ttid),
q = self.query
sql = " FROM tobj WHERE tid=?"
data_id_list = [x for x, in q("SELECT data_id" + sql, (tid,)) if x]
self.releaseData(data_id_list)
q("DELETE" + sql, (tid,))
q("DELETE FROM ttrans WHERE tid=?", (tid,))
q("DELETE FROM trans WHERE partition=? AND tid=?",
(getPartition(tid), tid))
# delete from obj using indexes
data_id_list = set(data_id_list)
for oid in oid_list:
oid = u64(oid)
sql = " FROM obj WHERE partition=? AND oid=? AND tid=?"
args = getPartition(oid), oid, tid
data_id_list.update(*q("SELECT data_id" + sql, args))
data_id_list = [x for x, in q("SELECT data_id" + sql, args) if x]
q("DELETE" + sql, args)
data_id_list.discard(None)
self._pruneData(data_id_list)
q("DELETE FROM ttrans WHERE ttid=?", args)
self.releaseData(data_id_list, True)
def deleteTransaction(self, tid):
tid = util.u64(tid)
getPartition = self._getPartition
self.query("DELETE FROM trans WHERE partition=? AND tid=?",
(self._getPartition(tid), tid))
def deleteObject(self, oid, serial=None):
oid = util.u64(oid)
......
......@@ -56,3 +56,6 @@ class BaseMasterHandler(EventHandler):
def answerUnfinishedTransactions(self, conn, *args, **kw):
self.app.replicator.setUnfinishedTIDList(*args, **kw)
def askFinalTID(self, conn, ttid):
conn.answer(Packets.AnswerFinalTID(self.app.dm.getFinalTID(ttid)))
......@@ -19,7 +19,7 @@ from neo.lib.handler import EventHandler
from neo.lib.util import dump, makeChecksum
from neo.lib.protocol import Packets, LockState, Errors, ProtocolError, \
ZERO_HASH, INVALID_PARTITION
from ..transactions import ConflictError, DelayedError
from ..transactions import ConflictError, DelayedError, NotRegisteredError
from ..exception import AlreadyPendingError
import time
......@@ -68,21 +68,17 @@ class ClientOperationHandler(EventHandler):
def abortTransaction(self, conn, ttid):
self.app.tm.abort(ttid)
def askStoreTransaction(self, conn, ttid, user, desc, ext, oid_list):
def askStoreTransaction(self, conn, ttid, *txn_info):
self.app.tm.register(conn.getUUID(), ttid)
self.app.tm.storeTransaction(ttid, oid_list, user, desc, ext, False)
conn.answer(Packets.AnswerStoreTransaction(ttid))
self.app.tm.vote(ttid, txn_info)
conn.answer(Packets.AnswerStoreTransaction())
def askVoteTransaction(self, conn, ttid):
self.app.tm.vote(ttid)
conn.answer(Packets.AnswerVoteTransaction())
def _askStoreObject(self, conn, oid, serial, compression, checksum, data,
data_serial, ttid, unlock, request_time):
if ttid not in self.app.tm:
# transaction was aborted, cancel this event
logging.info('Forget store of %s:%s by %s delayed by %s',
dump(oid), dump(serial), dump(ttid),
dump(self.app.tm.getLockingTID(oid)))
# send an answer as the client side is waiting for it
conn.answer(Packets.AnswerStoreObject(0, oid, serial))
return
try:
self.app.tm.storeObject(ttid, serial, oid, compression,
checksum, data, data_serial, unlock)
......@@ -101,6 +97,13 @@ class ClientOperationHandler(EventHandler):
raise_on_duplicate=unlock)
except AlreadyPendingError:
conn.answer(Errors.AlreadyPending(dump(oid)))
except NotRegisteredError:
# transaction was aborted, cancel this event
logging.info('Forget store of %s:%s by %s delayed by %s',
dump(oid), dump(serial), dump(ttid),
dump(self.app.tm.getLockingTID(oid)))
# send an answer as the client side is waiting for it
conn.answer(Packets.AnswerStoreObject(0, oid, serial))
else:
if SLOW_STORE is not None:
duration = time.time() - request_time
......@@ -189,14 +192,6 @@ class ClientOperationHandler(EventHandler):
self._askCheckCurrentSerial(conn, ttid, serial, oid, time.time())
def _askCheckCurrentSerial(self, conn, ttid, serial, oid, request_time):
if ttid not in self.app.tm:
# transaction was aborted, cancel this event
logging.info('Forget serial check of %s:%s by %s delayed by %s',
dump(oid), dump(serial), dump(ttid),
dump(self.app.tm.getLockingTID(oid)))
# send an answer as the client side is waiting for it
conn.answer(Packets.AnswerCheckCurrentSerial(0, oid, serial))
return
try:
self.app.tm.checkCurrentSerial(ttid, serial, oid)
except ConflictError, err:
......@@ -210,6 +205,13 @@ class ClientOperationHandler(EventHandler):
serial, oid, request_time), key=(oid, ttid))
except AlreadyPendingError:
conn.answer(Errors.AlreadyPending(dump(oid)))
except NotRegisteredError:
# transaction was aborted, cancel this event
logging.info('Forget serial check of %s:%s by %s delayed by %s',
dump(oid), dump(serial), dump(ttid),
dump(self.app.tm.getLockingTID(oid)))
# send an answer as the client side is waiting for it
conn.answer(Packets.AnswerCheckCurrentSerial(0, oid, serial))
else:
if SLOW_STORE is not None:
duration = time.time() - request_time
......
......@@ -42,17 +42,11 @@ class MasterOperationHandler(BaseMasterHandler):
# Check changes for replications
app.replicator.notifyPartitionChanges(cell_list)
def askLockInformation(self, conn, ttid, tid, oid_list):
if not ttid in self.app.tm:
raise ProtocolError('Unknown transaction')
self.app.tm.lock(ttid, tid, oid_list)
if not conn.isClosed():
def askLockInformation(self, conn, ttid, tid):
self.app.tm.lock(ttid, tid)
conn.answer(Packets.AnswerInformationLocked(ttid))
def notifyUnlockInformation(self, conn, ttid):
if not ttid in self.app.tm:
raise ProtocolError('Unknown transaction')
# TODO: send an answer
self.app.tm.unlock(ttid)
def askPack(self, conn, tid):
......@@ -60,7 +54,6 @@ class MasterOperationHandler(BaseMasterHandler):
logging.info('Pack started, up to %s...', dump(tid))
app.dm.pack(tid, app.tm.updateObjectDataForPack)
logging.info('Pack finished.')
if not conn.isClosed():
conn.answer(Packets.AnswerPack(True))
def replicate(self, conn, tid, upstream_name, source_dict):
......
......@@ -16,8 +16,7 @@
from . import BaseMasterHandler
from neo.lib import logging
from neo.lib.protocol import Packets, Errors, INVALID_TID, ZERO_TID
from neo.lib.util import dump
from neo.lib.protocol import Packets, ZERO_TID
from neo.lib.exception import OperationFailure
class VerificationHandler(BaseMasterHandler):
......@@ -62,31 +61,14 @@ class VerificationHandler(BaseMasterHandler):
def stopOperation(self, conn):
raise OperationFailure('operation stopped')
def askUnfinishedTransactions(self, conn):
tid_list = self.app.dm.getUnfinishedTIDList()
conn.answer(Packets.AnswerUnfinishedTransactions(INVALID_TID, tid_list))
def askLockedTransactions(self, conn):
conn.answer(Packets.AnswerLockedTransactions(
self.app.dm.getUnfinishedTIDDict()))
def askTransactionInformation(self, conn, tid):
app = self.app
t = app.dm.getTransaction(tid, all=True)
if t is None:
p = Errors.TidNotFound('%s does not exist' % dump(tid))
else:
p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3],
t[4], t[0])
conn.answer(p)
def askObjectPresent(self, conn, oid, tid):
if self.app.dm.objectPresent(oid, tid):
p = Packets.AnswerObjectPresent(oid, tid)
else:
p = Errors.OidNotFound(
'%s:%s do not exist' % (dump(oid), dump(tid)))
conn.answer(p)
def deleteTransaction(self, conn, tid, oid_list):
self.app.dm.deleteTransaction(tid, oid_list)
def commitTransaction(self, conn, tid):
self.app.dm.finishTransaction(tid)
def askFinalTID(self, conn, ttid):
conn.answer(Packets.AnswerFinalTID(self.app.dm.getFinalTID(ttid)))
def validateTransaction(self, conn, ttid, tid):
dm = self.app.dm
dm.lockTransaction(tid, ttid)
dm.unlockTransaction(tid, ttid)
......@@ -38,18 +38,22 @@ class DelayedError(Exception):
Raised when an object is locked by a previous transaction
"""
class NotRegisteredError(Exception):
"""
Raised when a ttid is not registered
"""
class Transaction(object):
"""
Container for a pending transaction
"""
_tid = None
has_trans = False
def __init__(self, uuid, ttid):
self._uuid = uuid
self._ttid = ttid
self._object_dict = {}
self._transaction = None
self._locked = False
self._birth = time()
self._checked_set = set()
......@@ -89,13 +93,6 @@ class Transaction(object):
def isLocked(self):
return self._locked
def prepare(self, oid_list, user, desc, ext, packed):
"""
Set the transaction informations
"""
# assert self._transaction is not None
self._transaction = oid_list, user, desc, ext, packed, self._ttid
def addObject(self, oid, data_id, value_serial):
"""
Add an object to the transaction
......@@ -121,9 +118,6 @@ class Transaction(object):
def getLockedOIDList(self):
return self._object_dict.keys() + list(self._checked_set)
def getTransactionInformations(self):
return self._transaction
class TransactionManager(object):
"""
......@@ -137,12 +131,6 @@ class TransactionManager(object):
self._load_lock_dict = {}
self._uuid_dict = {}
def __contains__(self, ttid):
"""
Returns True if the TID is known by the manager
"""
return ttid in self._transaction_dict
def register(self, uuid, ttid):
"""
Register a transaction, it may be already registered
......@@ -174,7 +162,21 @@ class TransactionManager(object):
self._load_lock_dict.clear()
self._uuid_dict.clear()
def lock(self, ttid, tid, oid_list):
def vote(self, ttid, txn_info=None):
"""
Store transaction information received from client node
"""
logging.debug('Vote TXN %s', dump(ttid))
transaction = self._transaction_dict[ttid]
object_list = transaction.getObjectList()
if txn_info:
user, desc, ext, oid_list = txn_info
txn_info = oid_list, user, desc, ext, False, ttid
transaction.has_trans = True
# store metadata to temporary table
self._app.dm.storeTransaction(ttid, object_list, txn_info)
def lock(self, ttid, tid):
"""
Lock a transaction
"""
......@@ -182,43 +184,22 @@ class TransactionManager(object):
transaction = self._transaction_dict[ttid]
# remember that the transaction has been locked
transaction.lock()
for oid in transaction.getOIDList():
self._load_lock_dict[oid] = ttid
# check every object that should be locked
uuid = transaction.getUUID()
is_assigned = self._app.pt.isAssigned
for oid in oid_list:
if is_assigned(oid, uuid) and \
self._load_lock_dict.get(oid) != ttid:
raise ValueError, 'Some locks are not held'
object_list = transaction.getObjectList()
# txn_info is None is the transaction information is not stored on
# this storage.
txn_info = transaction.getTransactionInformations()
# store data from memory to temporary table
self._app.dm.storeTransaction(tid, object_list, txn_info)
# ...and remember its definitive TID
self._load_lock_dict.update(
dict.fromkeys(transaction.getOIDList(), ttid))
# commit transaction and remember its definitive TID
if transaction.has_trans:
self._app.dm.lockTransaction(tid, ttid)
transaction.setTID(tid)
def getTIDFromTTID(self, ttid):
return self._transaction_dict[ttid].getTID()
def unlock(self, ttid):
"""
Unlock transaction
"""
logging.debug('Unlock TXN %s', dump(ttid))
self._app.dm.finishTransaction(self.getTIDFromTTID(ttid))
tid = self._transaction_dict[ttid].getTID()
logging.debug('Unlock TXN %s (ttid=%s)', dump(tid), dump(ttid))
self._app.dm.unlockTransaction(tid, ttid)
self.abort(ttid, even_if_locked=True)
def storeTransaction(self, ttid, oid_list, user, desc, ext, packed):
"""
Store transaction information received from client node
"""
assert ttid in self, "Transaction not registered"
transaction = self._transaction_dict[ttid]
transaction.prepare(oid_list, user, desc, ext, packed)
def getLockingTID(self, oid):
return self._store_lock_dict.get(oid)
......@@ -283,9 +264,11 @@ class TransactionManager(object):
self._store_lock_dict[oid] = ttid
def checkCurrentSerial(self, ttid, serial, oid):
self.lockObject(ttid, serial, oid, unlock=True)
assert ttid in self, "Transaction not registered"
try:
transaction = self._transaction_dict[ttid]
except KeyError:
raise NotRegisteredError
self.lockObject(ttid, serial, oid, unlock=True)
transaction.addCheckedObject(oid)
def storeObject(self, ttid, serial, oid, compression, checksum, data,
......@@ -293,14 +276,17 @@ class TransactionManager(object):
"""
Store an object received from client node
"""
try:
transaction = self._transaction_dict[ttid]
except KeyError:
raise NotRegisteredError
self.lockObject(ttid, serial, oid, unlock=unlock)
# store object
assert ttid in self, "Transaction not registered"
if data is None:
data_id = None
else:
data_id = self._app.dm.holdData(checksum, data, compression)
self._transaction_dict[ttid].addObject(oid, data_id, value_serial)
transaction.addObject(oid, data_id, value_serial)
def abort(self, ttid, even_if_locked=False):
"""
......@@ -322,9 +308,7 @@ class TransactionManager(object):
if not even_if_locked:
return
else:
self._app.dm.releaseData([data_id
for oid, data_id, value_serial in transaction.getObjectList()
if data_id], True)
self._app.dm.abortTransaction(ttid)
# unlock any object
for oid in transaction.getLockedOIDList():
if has_load_lock:
......
......@@ -463,9 +463,6 @@ class NeoUnitTestBase(NeoTestBase):
def checkAskTransactionInformation(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskTransactionInformation, **kw)
def checkAskObjectPresent(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskObjectPresent, **kw)
def checkAskObject(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskObject, **kw)
......@@ -514,18 +511,12 @@ class NeoUnitTestBase(NeoTestBase):
def checkAnswerObjectHistory(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObjectHistory, **kw)
def checkAnswerStoreTransaction(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerStoreTransaction, **kw)
def checkAnswerStoreObject(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerStoreObject, **kw)
def checkAnswerPartitionTable(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerPartitionTable, **kw)
def checkAnswerObjectPresent(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObjectPresent, **kw)
class Patch(object):
......
......@@ -191,18 +191,6 @@ class StorageClientHandlerTests(NeoUnitTestBase):
self.operation.askObjectHistory(conn, oid2, 1, 2)
self.checkAnswerObjectHistory(conn)
def test_askStoreTransaction(self):
conn = self._getConnection(uuid=self.getClientUUID())
tid = self.getNextTID()
user = 'USER'
desc = 'DESC'
ext = 'EXT'
oid_list = (self.getOID(1), self.getOID(2))
self.operation.askStoreTransaction(conn, tid, user, desc, ext, oid_list)
calls = self.app.tm.mockGetNamedCalls('storeTransaction')
self.assertEqual(len(calls), 1)
self.checkAnswerStoreTransaction(conn)
def _getObject(self):
oid = self.getOID(0)
serial = self.getNextTID()
......
......@@ -112,50 +112,6 @@ class StorageMasterHandlerTests(NeoUnitTestBase):
def _getConnection(self):
return self.getFakeConnection()
def test_askLockInformation1(self):
""" Unknown transaction """
self.app.tm = Mock({'__contains__': False})
conn = self._getConnection()
oid_list = [self.getOID(1), self.getOID(2)]
tid = self.getNextTID()
ttid = self.getNextTID()
handler = self.operation
self.assertRaises(ProtocolError, handler.askLockInformation, conn,
ttid, tid, oid_list)
def test_askLockInformation2(self):
""" Lock transaction """
self.app.tm = Mock({'__contains__': True})
conn = self._getConnection()
tid = self.getNextTID()
ttid = self.getNextTID()
oid_list = [self.getOID(1), self.getOID(2)]
self.operation.askLockInformation(conn, ttid, tid, oid_list)
calls = self.app.tm.mockGetNamedCalls('lock')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(ttid, tid, oid_list)
self.checkAnswerInformationLocked(conn)
def test_notifyUnlockInformation1(self):
""" Unknown transaction """
self.app.tm = Mock({'__contains__': False})
conn = self._getConnection()
tid = self.getNextTID()
handler = self.operation
self.assertRaises(ProtocolError, handler.notifyUnlockInformation,
conn, tid)
def test_notifyUnlockInformation2(self):
""" Unlock transaction """
self.app.tm = Mock({'__contains__': True})
conn = self._getConnection()
tid = self.getNextTID()
self.operation.notifyUnlockInformation(conn, tid)
calls = self.app.tm.mockGetNamedCalls('unlock')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid)
self.checkNoPacketSent(conn)
def test_askPack(self):
self.app.dm = Mock({'pack': None})
conn = self.getFakeConnection()
......
......@@ -15,6 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from binascii import a2b_hex
from contextlib import contextmanager
import unittest
from neo.lib.util import add64, p64, u64
from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID, MAX_TID
......@@ -80,6 +81,17 @@ class StorageDBTests(NeoUnitTestBase):
set_call(value * 2)
self.assertEqual(get_call(), value * 2)
@contextmanager
def commitTransaction(self, tid, objs, txn, commit=True):
ttid = txn[-1]
self.db.storeTransaction(ttid, objs, txn)
self.db.lockTransaction(tid, ttid)
yield
if commit:
self.db.unlockTransaction(tid, ttid)
elif commit is not None:
self.db.abortTransaction(ttid)
def test_UUID(self):
db = self.getDB()
self.checkConfigEntry(db.getUUID, db.setUUID, 123)
......@@ -122,38 +134,24 @@ class StorageDBTests(NeoUnitTestBase):
def checkSet(self, list1, list2):
self.assertEqual(set(list1), set(list2))
def test_getUnfinishedTIDList(self):
def test_getUnfinishedTIDDict(self):
tid1, tid2, tid3, tid4 = self.getTIDs(4)
oid1, oid2 = self.getOIDs(2)
txn, objs = self.getTransaction([oid1, oid2])
# nothing pending
self.db.storeTransaction(tid1, objs, txn, False)
self.checkSet(self.db.getUnfinishedTIDList(), [])
# one unfinished txn
self.db.storeTransaction(tid2, objs, txn)
self.checkSet(self.db.getUnfinishedTIDList(), [tid2])
with self.commitTransaction(tid2, objs, txn):
expected = {txn[-1]: tid2}
self.assertEqual(self.db.getUnfinishedTIDDict(), expected)
# no changes
self.db.storeTransaction(tid3, objs, None, False)
self.checkSet(self.db.getUnfinishedTIDList(), [tid2])
self.assertEqual(self.db.getUnfinishedTIDDict(), expected)
# a second txn known by objs only
expected[tid4] = None
self.db.storeTransaction(tid4, objs, None)
self.checkSet(self.db.getUnfinishedTIDList(), [tid2, tid4])
def test_objectPresent(self):
tid = self.getNextTID()
oid = self.getOID(1)
txn, objs = self.getTransaction([oid])
# not present
self.assertFalse(self.db.objectPresent(oid, tid, all=True))
self.assertFalse(self.db.objectPresent(oid, tid, all=False))
# available in temp table
self.db.storeTransaction(tid, objs, txn)
self.assertTrue(self.db.objectPresent(oid, tid, all=True))
self.assertFalse(self.db.objectPresent(oid, tid, all=False))
# available in both tables
self.db.finishTransaction(tid)
self.assertTrue(self.db.objectPresent(oid, tid, all=True))
self.assertTrue(self.db.objectPresent(oid, tid, all=False))
self.assertEqual(self.db.getUnfinishedTIDDict(), expected)
self.db.abortTransaction(tid4)
# nothing pending
self.assertEqual(self.db.getUnfinishedTIDDict(), {})
def test_getObject(self):
oid1, = self.getOIDs(1)
......@@ -169,27 +167,26 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(self.db.getObject(oid1, tid1), None)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1), None)
# one non-commited version
self.db.storeTransaction(tid1, objs1, txn1)
with self.commitTransaction(tid1, objs1, txn1):
self.assertEqual(self.db.getObject(oid1), None)
self.assertEqual(self.db.getObject(oid1, tid1), None)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1), None)
# one commited version
self.db.finishTransaction(tid1)
self.assertEqual(self.db.getObject(oid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1),
FOUND_BUT_NOT_VISIBLE)
# two version available, one non-commited
self.db.storeTransaction(tid2, objs2, txn2)
with self.commitTransaction(tid2, objs2, txn2):
self.assertEqual(self.db.getObject(oid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1),
FOUND_BUT_NOT_VISIBLE)
self.assertEqual(self.db.getObject(oid1, tid2), FOUND_BUT_NOT_VISIBLE)
self.assertEqual(self.db.getObject(oid1, tid2),
FOUND_BUT_NOT_VISIBLE)
self.assertEqual(self.db.getObject(oid1, before_tid=tid2),
OBJECT_T1_NO_NEXT)
# two commited versions
self.db.finishTransaction(tid2)
self.assertEqual(self.db.getObject(oid1), OBJECT_T2)
self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NEXT)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1),
......@@ -242,82 +239,28 @@ class StorageDBTests(NeoUnitTestBase):
result = db.getPartitionTable()
self.assertEqual(list(result), [cell1])
def test_dropUnfinishedData(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid1])
# nothing
self.assertEqual(self.db.getObject(oid1), None)
self.assertEqual(self.db.getObject(oid2), None)
self.assertEqual(self.db.getUnfinishedTIDList(), [])
# one is still pending
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.finishTransaction(tid1)
result = self.db.getObject(oid1)
self.assertEqual(result, (tid1, None, 1, "0"*20, '', None))
self.assertEqual(self.db.getObject(oid2), None)
self.assertEqual(self.db.getUnfinishedTIDList(), [tid2])
# drop it
self.db.dropUnfinishedData()
self.assertEqual(self.db.getUnfinishedTIDList(), [])
result = self.db.getObject(oid1)
self.assertEqual(result, (tid1, None, 1, "0"*20, '', None))
self.assertEqual(self.db.getObject(oid2), None)
def test_storeTransaction(self):
def test_commitTransaction(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2])
# nothing in database
self.assertEqual(self.db.getLastIDs(), (None, {}, {}, None))
self.assertEqual(self.db.getUnfinishedTIDList(), [])
self.assertEqual(self.db.getUnfinishedTIDDict(), {})
self.assertEqual(self.db.getObject(oid1), None)
self.assertEqual(self.db.getObject(oid2), None)
self.assertEqual(self.db.getTransaction(tid1, True), None)
self.assertEqual(self.db.getTransaction(tid2, True), None)
self.assertEqual(self.db.getTransaction(tid1, False), None)
self.assertEqual(self.db.getTransaction(tid2, False), None)
# store in temporary tables
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
result = self.db.getTransaction(tid1, True)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, True)
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2)))
self.assertEqual(self.db.getTransaction(tid1, False), None)
self.assertEqual(self.db.getTransaction(tid2, False), None)
# commit pending transaction
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
result = self.db.getTransaction(tid1, True)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, True)
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2)))
result = self.db.getTransaction(tid1, False)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, False)
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2)))
def test_askFinishTransaction(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2])
# stored but not finished
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
result = self.db.getTransaction(tid1, True)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, True)
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2)))
with self.commitTransaction(tid1, objs1, txn1), \
self.commitTransaction(tid2, objs2, txn2):
self.assertEqual(self.db.getTransaction(tid1, True),
([oid1], 'user', 'desc', 'ext', False, p64(1)))
self.assertEqual(self.db.getTransaction(tid2, True),
([oid2], 'user', 'desc', 'ext', False, p64(2)))
self.assertEqual(self.db.getTransaction(tid1, False), None)
self.assertEqual(self.db.getTransaction(tid2, False), None)
# stored and finished
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
result = self.db.getTransaction(tid1, True)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, True)
......@@ -328,32 +271,29 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(result, ([oid2], 'user', 'desc', 'ext', False, p64(2)))
def test_deleteTransaction(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2])
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.finishTransaction(tid1)
self.db.deleteTransaction(tid1, [oid1])
self.db.deleteTransaction(tid2, [oid2])
self.assertEqual(self.db.getTransaction(tid1, True), None)
self.assertEqual(self.db.getTransaction(tid2, True), None)
txn, objs = self.getTransaction([])
tid = txn[-1]
self.db.storeTransaction(tid, objs, txn, False)
self.assertEqual(self.db.getTransaction(tid), txn)
self.db.deleteTransaction(tid)
self.assertEqual(self.db.getTransaction(tid), None)
def test_deleteObject(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1, oid2])
txn2, objs2 = self.getTransaction([oid1, oid2])
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
tid1 = txn1[-1]
tid2 = txn2[-1]
self.db.storeTransaction(tid1, objs1, txn1, False)
self.db.storeTransaction(tid2, objs2, txn2, False)
self.assertEqual(self.db.getObject(oid1, tid=tid1),
(tid1, tid2, 1, "0" * 20, '', None))
self.db.deleteObject(oid1)
self.assertEqual(self.db.getObject(oid1, tid=tid1), None)
self.assertEqual(self.db.getObject(oid1, tid=tid2), None)
self.assertIs(self.db.getObject(oid1, tid=tid1), None)
self.assertIs(self.db.getObject(oid1, tid=tid2), None)
self.db.deleteObject(oid2, serial=tid1)
self.assertFalse(self.db.getObject(oid2, tid=tid1))
self.assertIs(self.db.getObject(oid2, tid=tid1), False)
self.assertEqual(self.db.getObject(oid2, tid=tid2),
(tid2, None, 1, "0" * 20, '', None))
......@@ -364,8 +304,7 @@ class StorageDBTests(NeoUnitTestBase):
oid_list = self.getOIDs(np * 2)
for tid in t1, t2, t3:
txn, objs = self.getTransaction(oid_list)
self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid)
self.db.storeTransaction(tid, objs, txn, False)
def check(offset, tid_list, *tids):
self.assertEqual(self.db.getReplicationTIDList(ZERO_TID,
MAX_TID, len(tid_list) + 1, offset), tid_list)
......@@ -386,9 +325,9 @@ class StorageDBTests(NeoUnitTestBase):
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2])
# get from temporary table or not
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.finishTransaction(tid1)
with self.commitTransaction(tid1, objs1, txn1), \
self.commitTransaction(tid2, objs2, txn2, None):
pass
result = self.db.getTransaction(tid1, True)
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False, p64(1)))
result = self.db.getTransaction(tid2, True)
......@@ -405,15 +344,13 @@ class StorageDBTests(NeoUnitTestBase):
txn2, objs2 = self.getTransaction([oid])
txn3, objs3 = self.getTransaction([oid])
# one revision
self.db.storeTransaction(tid1, objs1, txn1)
self.db.finishTransaction(tid1)
self.db.storeTransaction(tid1, objs1, txn1, False)
result = self.db.getObjectHistory(oid, 0, 3)
self.assertEqual(result, [(tid1, 0)])
result = self.db.getObjectHistory(oid, 1, 1)
self.assertEqual(result, None)
# two revisions
self.db.storeTransaction(tid2, objs2, txn2)
self.db.finishTransaction(tid2)
self.db.storeTransaction(tid2, objs2, txn2, False)
result = self.db.getObjectHistory(oid, 0, 3)
self.assertEqual(result, [(tid2, 0), (tid1, 0)])
result = self.db.getObjectHistory(oid, 1, 3)
......@@ -427,8 +364,7 @@ class StorageDBTests(NeoUnitTestBase):
oid = self.getOID(1)
for tid in tid_list:
txn, objs = self.getTransaction([oid])
self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid)
self.db.storeTransaction(tid, objs, txn, False)
return tid_list
def test_getTIDList(self):
......
......@@ -45,15 +45,6 @@ class TransactionTests(NeoUnitTestBase):
# disallow lock more than once
self.assertRaises(AssertionError, txn.lock)
def testTransaction(self):
txn = Transaction(self.getClientUUID(), self.getNextTID())
repr(txn) # check __repr__ does not raise
oid_list = [self.getOID(1), self.getOID(2)]
txn_info = (oid_list, 'USER', 'DESC', 'EXT', False)
txn.prepare(*txn_info)
self.assertEqual(txn.getTransactionInformations(),
txn_info + (txn.getTTID(),))
def testObjects(self):
txn = Transaction(self.getClientUUID(), self.getNextTID())
oid1, oid2 = self.getOID(1), self.getOID(2)
......@@ -91,10 +82,10 @@ class TransactionManagerTests(NeoUnitTestBase):
def _getTransaction(self):
tid = self.getNextTID(self.ltid)
oid_list = [self.getOID(1), self.getOID(2)]
return (tid, (oid_list, 'USER', 'DESC', 'EXT', False))
return (tid, ('USER', 'DESC', 'EXT', oid_list))
def _storeTransactionObjects(self, tid, txn):
for i, oid in enumerate(txn[0]):
for i, oid in enumerate(txn[3]):
self.manager.storeObject(tid, None,
oid, 1, '%020d' % i, '0' + str(i), None)
......@@ -108,15 +99,21 @@ class TransactionManagerTests(NeoUnitTestBase):
self.assertEqual(len(calls), 1)
calls[0].checkArgs(*args)
def _checkTransactionFinished(self, tid):
calls = self.app.dm.mockGetNamedCalls('finishTransaction')
def _checkTransactionFinished(self, *args):
calls = self.app.dm.mockGetNamedCalls('unlockTransaction')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid)
calls[0].checkArgs(*args)
def _checkQueuedEventExecuted(self, number=1):
calls = self.app.mockGetNamedCalls('executeQueuedEvents')
self.assertEqual(len(calls), number)
def assertRegistered(self, ttid):
self.assertIn(ttid, self.manager._transaction_dict)
def assertNotRegistered(self, ttid):
self.assertNotIn(ttid, self.manager._transaction_dict)
def testSimpleCase(self):
""" One node, one transaction, not abort """
data_id_list = random.random(), random.random()
......@@ -127,18 +124,23 @@ class TransactionManagerTests(NeoUnitTestBase):
serial1, object1 = self._getObject(1)
serial2, object2 = self._getObject(2)
self.manager.register(uuid, ttid)
self.manager.storeTransaction(ttid, *txn)
self.manager.storeObject(ttid, serial1, *object1)
self.manager.storeObject(ttid, serial2, *object2)
self.assertTrue(ttid in self.manager)
self.manager.lock(ttid, tid, txn[0])
self._checkTransactionStored(tid, [
self.assertRegistered(ttid)
self.manager.vote(ttid, txn)
user, desc, ext, oid_list = txn
call, = self.app.dm.mockGetNamedCalls('storeTransaction')
call.checkArgs(ttid, [
(object1[0], data_id_list[0], object1[4]),
(object2[0], data_id_list[1], object2[4]),
], txn + (ttid,))
], (oid_list, user, desc, ext, False, ttid))
self.manager.lock(ttid, tid)
call, = self.app.dm.mockGetNamedCalls('lockTransaction')
call.checkArgs(tid, ttid)
self.manager.unlock(ttid)
self.assertFalse(ttid in self.manager)
self._checkTransactionFinished(tid)
self.assertNotRegistered(ttid)
call, = self.app.dm.mockGetNamedCalls('unlockTransaction')
call.checkArgs(tid, ttid)
def testDelayed(self):
""" Two transactions, the first cause the second to be delayed """
......@@ -150,14 +152,13 @@ class TransactionManagerTests(NeoUnitTestBase):
serial, obj = self._getObject(1)
# first transaction lock the object
self.manager.register(uuid, ttid1)
self.manager.storeTransaction(ttid1, *txn1)
self.assertTrue(ttid1 in self.manager)
self.assertRegistered(ttid1)
self._storeTransactionObjects(ttid1, txn1)
self.manager.lock(ttid1, tid1, txn1[0])
self.manager.vote(ttid1, txn1)
self.manager.lock(ttid1, tid1)
# the second is delayed
self.manager.register(uuid, ttid2)
self.manager.storeTransaction(ttid2, *txn2)
self.assertTrue(ttid2 in self.manager)
self.assertRegistered(ttid2)
self.assertRaises(DelayedError, self.manager.storeObject,
ttid2, serial, *obj)
......@@ -171,14 +172,13 @@ class TransactionManagerTests(NeoUnitTestBase):
serial, obj = self._getObject(1)
# the (later) transaction lock (change) the object
self.manager.register(uuid, ttid2)
self.manager.storeTransaction(ttid2, *txn2)
self.assertTrue(ttid2 in self.manager)
self.assertRegistered(ttid2)
self._storeTransactionObjects(ttid2, txn2)
self.manager.lock(ttid2, tid2, txn2[0])
self.manager.vote(ttid2, txn2)
self.manager.lock(ttid2, tid2)
# the previous it's not using the latest version
self.manager.register(uuid, ttid1)
self.manager.storeTransaction(ttid1, *txn1)
self.assertTrue(ttid1 in self.manager)
self.assertRegistered(ttid1)
self.assertRaises(ConflictError, self.manager.storeObject,
ttid1, serial, *obj)
......@@ -191,7 +191,6 @@ class TransactionManagerTests(NeoUnitTestBase):
# try to store without the last revision
self.app.dm = Mock({'getLastObjectTID': next_serial})
self.manager.register(uuid, tid)
self.manager.storeTransaction(tid, *txn)
self.assertRaises(ConflictError, self.manager.storeObject,
tid, serial, *obj)
......@@ -208,15 +207,14 @@ class TransactionManagerTests(NeoUnitTestBase):
serial2, obj2 = self._getObject(2)
# first transaction lock objects
self.manager.register(uuid1, ttid1)
self.manager.storeTransaction(ttid1, *txn1)
self.assertTrue(ttid1 in self.manager)
self.assertRegistered(ttid1)
self.manager.storeObject(ttid1, serial1, *obj1)
self.manager.storeObject(ttid1, serial1, *obj2)
self.manager.lock(ttid1, tid1, txn1[0])
self.manager.vote(ttid1, txn1)
self.manager.lock(ttid1, tid1)
# second transaction is delayed
self.manager.register(uuid2, ttid2)
self.manager.storeTransaction(ttid2, *txn2)
self.assertTrue(ttid2 in self.manager)
self.assertRegistered(ttid2)
self.assertRaises(DelayedError, self.manager.storeObject,
ttid2, serial1, *obj1)
self.assertRaises(DelayedError, self.manager.storeObject,
......@@ -235,15 +233,14 @@ class TransactionManagerTests(NeoUnitTestBase):
serial2, obj2 = self._getObject(2)
# the second transaction lock objects
self.manager.register(uuid2, ttid2)
self.manager.storeTransaction(ttid2, *txn2)
self.manager.storeObject(ttid2, serial1, *obj1)
self.manager.storeObject(ttid2, serial2, *obj2)
self.assertTrue(ttid2 in self.manager)
self.manager.lock(ttid2, tid2, txn1[0])
self.assertRegistered(ttid2)
self.manager.vote(ttid2, txn2)
self.manager.lock(ttid2, tid2)
# the first get a conflict
self.manager.register(uuid1, ttid1)
self.manager.storeTransaction(ttid1, *txn1)
self.assertTrue(ttid1 in self.manager)
self.assertRegistered(ttid1)
self.assertRaises(ConflictError, self.manager.storeObject,
ttid1, serial1, *obj1)
self.assertRaises(ConflictError, self.manager.storeObject,
......@@ -255,12 +252,12 @@ class TransactionManagerTests(NeoUnitTestBase):
tid, txn = self._getTransaction()
serial, obj = self._getObject(1)
self.manager.register(uuid, tid)
self.manager.storeTransaction(tid, *txn)
self.manager.storeObject(tid, serial, *obj)
self.assertTrue(tid in self.manager)
self.assertRegistered(tid)
self.manager.vote(tid, txn)
# transaction is not locked
self.manager.abort(tid)
self.assertFalse(tid in self.manager)
self.assertNotRegistered(tid)
self.assertFalse(self.manager.loadLocked(obj[0]))
self._checkQueuedEventExecuted()
......@@ -270,14 +267,14 @@ class TransactionManagerTests(NeoUnitTestBase):
ttid = self.getNextTID()
tid, txn = self._getTransaction()
self.manager.register(uuid, ttid)
self.manager.storeTransaction(ttid, *txn)
self._storeTransactionObjects(ttid, txn)
self.manager.vote(ttid, txn)
# lock transaction
self.manager.lock(ttid, tid, txn[0])
self.assertTrue(ttid in self.manager)
self.manager.lock(ttid, tid)
self.assertRegistered(ttid)
self.manager.abort(ttid)
self.assertTrue(ttid in self.manager)
for oid in txn[0]:
self.assertRegistered(ttid)
for oid in txn[-1]:
self.assertTrue(self.manager.loadLocked(oid))
self._checkQueuedEventExecuted(number=0)
......@@ -295,20 +292,20 @@ class TransactionManagerTests(NeoUnitTestBase):
self.manager.register(uuid1, ttid1)
self.manager.register(uuid2, ttid2)
self.manager.register(uuid2, ttid3)
self.manager.storeTransaction(ttid1, *txn1)
self.manager.vote(ttid1, txn1)
# node 2 owns tid2 & tid3 and lock tid2 only
self.manager.storeTransaction(ttid2, *txn2)
self.manager.storeTransaction(ttid3, *txn3)
self._storeTransactionObjects(ttid2, txn2)
self.manager.lock(ttid2, tid2, txn2[0])
self.assertTrue(ttid1 in self.manager)
self.assertTrue(ttid2 in self.manager)
self.assertTrue(ttid3 in self.manager)
self.manager.vote(ttid2, txn2)
self.manager.vote(ttid3, txn3)
self.manager.lock(ttid2, tid2)
self.assertRegistered(ttid1)
self.assertRegistered(ttid2)
self.assertRegistered(ttid3)
self.manager.abortFor(uuid2)
# only tid3 is aborted
self.assertTrue(ttid1 in self.manager)
self.assertTrue(ttid2 in self.manager)
self.assertFalse(ttid3 in self.manager)
self.assertRegistered(ttid1)
self.assertRegistered(ttid2)
self.assertNotRegistered(ttid3)
self._checkQueuedEventExecuted(number=1)
def testReset(self):
......@@ -317,12 +314,12 @@ class TransactionManagerTests(NeoUnitTestBase):
tid, txn = self._getTransaction()
ttid = self.getNextTID()
self.manager.register(uuid, ttid)
self.manager.storeTransaction(ttid, *txn)
self._storeTransactionObjects(ttid, txn)
self.manager.lock(ttid, tid, txn[0])
self.assertTrue(ttid in self.manager)
self.manager.vote(ttid, txn)
self.manager.lock(ttid, tid)
self.assertRegistered(ttid)
self.manager.reset()
self.assertFalse(ttid in self.manager)
self.assertNotRegistered(ttid)
for oid in txn[0]:
self.assertFalse(self.manager.loadLocked(oid))
......
#
# Copyright (C) 2009-2015 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from mock import Mock
from .. import NeoUnitTestBase
from neo.lib.pt import PartitionTable
from neo.storage.app import Application
from neo.storage.handlers.verification import VerificationHandler
from neo.lib.protocol import CellStates, ErrorCodes
from neo.lib.exception import PrimaryFailure
from neo.lib.util import p64, u64
class StorageVerificationHandlerTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
self.prepareDatabase(number=1)
# create an application object
config = self.getStorageConfiguration(master_number=1)
self.app = Application(config)
self.verification = VerificationHandler(self.app)
# define some variable to simulate client and storage node
self.master_port = 10010
self.storage_port = 10020
self.client_port = 11011
self.num_partitions = 1009
self.num_replicas = 2
self.app.operational = False
self.app.load_lock_dict = {}
self.app.pt = PartitionTable(self.num_partitions, self.num_replicas)
def _tearDown(self, success):
self.app.close()
del self.app
super(StorageVerificationHandlerTests, self)._tearDown(success)
# Common methods
def getMasterConnection(self):
return self.getFakeConnection(address=("127.0.0.1", self.master_port))
# Tests
def test_03_connectionClosed(self):
conn = self.getMasterConnection()
self.app.listening_conn = object() # mark as running
self.assertRaises(PrimaryFailure, self.verification.connectionClosed, conn,)
# nothing happens
self.checkNoPacketSent(conn)
def test_08_askPartitionTable(self):
node = self.app.nm.createStorage(
address=("127.7.9.9", 1),
uuid=self.getStorageUUID()
)
self.app.pt.setCell(1, node, CellStates.UP_TO_DATE)
self.assertTrue(self.app.pt.hasOffset(1))
conn = self.getMasterConnection()
self.verification.askPartitionTable(conn)
ptid, row_list = self.checkAnswerPartitionTable(conn, decode=True)
self.assertEqual(len(row_list), 1009)
def test_10_notifyPartitionChanges(self):
# old partition change
conn = self.getMasterConnection()
self.verification.notifyPartitionChanges(conn, 1, ())
self.verification.notifyPartitionChanges(conn, 0, ())
self.assertEqual(self.app.pt.getID(), 1)
# new node
conn = self.getMasterConnection()
new_uuid = self.getStorageUUID()
cell = (0, new_uuid, CellStates.UP_TO_DATE)
self.app.nm.createStorage(uuid=new_uuid)
self.app.pt = PartitionTable(1, 1)
self.app.dm = Mock({ })
ptid = self.getPTID()
# pt updated
self.verification.notifyPartitionChanges(conn, ptid, (cell, ))
# check db update
calls = self.app.dm.mockGetNamedCalls('changePartitionTable')
self.assertEqual(len(calls), 1)
self.assertEqual(calls[0].getParam(0), ptid)
self.assertEqual(calls[0].getParam(1), (cell, ))
def test_13_askUnfinishedTransactions(self):
# client connection with no data
self.app.dm = Mock({
'getUnfinishedTIDList': [],
})
conn = self.getMasterConnection()
self.verification.askUnfinishedTransactions(conn)
(max_tid, tid_list) = self.checkAnswerUnfinishedTransactions(conn, decode=True)
self.assertEqual(len(tid_list), 0)
call_list = self.app.dm.mockGetNamedCalls('getUnfinishedTIDList')
self.assertEqual(len(call_list), 1)
call_list[0].checkArgs()
# client connection with some data
self.app.dm = Mock({
'getUnfinishedTIDList': [p64(4)],
})
conn = self.getMasterConnection()
self.verification.askUnfinishedTransactions(conn)
(max_tid, tid_list) = self.checkAnswerUnfinishedTransactions(conn, decode=True)
self.assertEqual(len(tid_list), 1)
self.assertEqual(u64(tid_list[0]), 4)
def test_14_askTransactionInformation(self):
# ask from client conn with no data
self.app.dm = Mock({
'getTransaction': None,
})
conn = self.getMasterConnection()
tid = p64(1)
self.verification.askTransactionInformation(conn, tid)
code, message = self.checkErrorPacket(conn, decode=True)
self.assertEqual(code, ErrorCodes.TID_NOT_FOUND)
call_list = self.app.dm.mockGetNamedCalls('getTransaction')
self.assertEqual(len(call_list), 1)
call_list[0].checkArgs(tid, all=True)
# input some tmp data and ask from client, must find both transaction
self.app.dm = Mock({
'getTransaction': ([p64(2)], 'u2', 'd2', 'e2', False),
})
conn = self.getMasterConnection()
self.verification.askTransactionInformation(conn, p64(1))
tid, user, desc, ext, packed, oid_list = self.checkAnswerTransactionInformation(conn, decode=True)
self.assertEqual(u64(tid), 1)
self.assertEqual(user, 'u2')
self.assertEqual(desc, 'd2')
self.assertEqual(ext, 'e2')
self.assertFalse(packed)
self.assertEqual(len(oid_list), 1)
self.assertEqual(u64(oid_list[0]), 2)
def test_15_askObjectPresent(self):
# client connection with no data
self.app.dm = Mock({
'objectPresent': False,
})
conn = self.getMasterConnection()
oid, tid = p64(1), p64(2)
self.verification.askObjectPresent(conn, oid, tid)
code, message = self.checkErrorPacket(conn, decode=True)
self.assertEqual(code, ErrorCodes.OID_NOT_FOUND)
call_list = self.app.dm.mockGetNamedCalls('objectPresent')
self.assertEqual(len(call_list), 1)
call_list[0].checkArgs(oid, tid)
# client connection with some data
self.app.dm = Mock({
'objectPresent': True,
})
conn = self.getMasterConnection()
self.verification.askObjectPresent(conn, oid, tid)
oid, tid = self.checkAnswerObjectPresent(conn, decode=True)
self.assertEqual(u64(tid), 2)
self.assertEqual(u64(oid), 1)
def test_16_deleteTransaction(self):
# client connection with no data
self.app.dm = Mock({
'deleteTransaction': None,
})
conn = self.getMasterConnection()
oid_list = [self.getOID(1), self.getOID(2)]
tid = p64(1)
self.verification.deleteTransaction(conn, tid, oid_list)
call_list = self.app.dm.mockGetNamedCalls('deleteTransaction')
self.assertEqual(len(call_list), 1)
call_list[0].checkArgs(tid, oid_list)
def test_17_commitTransaction(self):
# commit a transaction
conn = self.getMasterConnection()
dm = Mock()
self.app.dm = dm
self.verification.commitTransaction(conn, p64(1))
self.assertEqual(len(dm.mockGetNamedCalls("finishTransaction")), 1)
call = dm.mockGetNamedCalls("finishTransaction")[0]
tid = call.getParam(0)
self.assertEqual(u64(tid), 1)
if __name__ == "__main__":
unittest.main()
......@@ -21,7 +21,7 @@ import transaction
import unittest
from thread import get_ident
from zlib import compress
from persistent import Persistent
from persistent import Persistent, GHOST
from ZODB import DB, POSException
from neo.storage.transactions import TransactionManager, \
DelayedError, ConflictError
......@@ -398,34 +398,46 @@ class Test(NEOThreadedTest):
""" Verification step should commit locked transactions """
def delayUnlockInformation(conn, packet):
return isinstance(packet, Packets.NotifyUnlockInformation)
def onStoreTransaction(storage, die=False):
def storeTransaction(orig, *args, **kw):
orig(*args, **kw)
def onLockTransaction(storage, die=False):
def lock(orig, *args, **kw):
if die:
sys.exit()
orig(*args, **kw)
storage.master_conn.close()
return Patch(storage.dm, storeTransaction=storeTransaction)
return Patch(storage.tm, lock=lock)
cluster = NEOCluster(partitions=2, storage_count=2)
try:
cluster.start()
s0, s1 = cluster.sortStorageList()
t, c = cluster.getTransaction()
r = c.root()
r[0] = x = PCounter()
r[0] = PCounter()
tids = [r._p_serial]
t.commit()
with onLockTransaction(s0), onLockTransaction(s1):
self.assertRaises(ConnectionClosed, t.commit)
self.assertEqual(r._p_state, GHOST)
self.tic()
t.begin()
x = r[0]
self.assertEqual(x.value, 0)
cluster.master.tm._last_oid = x._p_oid
tids.append(r._p_serial)
r[1] = PCounter()
with onStoreTransaction(s0), onStoreTransaction(s1):
c.readCurrent(x)
with cluster.moduloTID(1):
with onLockTransaction(s0), onLockTransaction(s1):
self.assertRaises(ConnectionClosed, t.commit)
self.tic()
t.begin()
# The following line checks that s1 moved the transaction
# metadata to final place during the verification phase.
# If it didn't, a NEOStorageError would be raised.
self.assertEqual(3, len(c.db().history(r._p_oid, 4)))
y = r[1]
self.assertEqual(y.value, 0)
assert [u64(o._p_oid) for o in (r, x, y)] == range(3)
self.assertEqual([u64(o._p_oid) for o in (r, x, y)], range(3))
r[2] = 'ok'
with cluster.master.filterConnection(s0) as m2s, \
cluster.moduloTID(1):
with cluster.master.filterConnection(s0) as m2s:
m2s.add(delayUnlockInformation)
t.commit()
x.value = 1
......@@ -433,12 +445,15 @@ class Test(NEOThreadedTest):
# never lock the transaction (packets from master delayed),
# so the last transaction will be dropped.
y.value = 2
with onStoreTransaction(s1, die=True):
di0 = s0.getDataLockInfo()
with onLockTransaction(s1, die=True):
self.assertRaises(ConnectionClosed, t.commit)
finally:
cluster.stop()
cluster.reset()
di0 = s0.getDataLockInfo()
(k, v), = set(s0.getDataLockInfo().iteritems()
).difference(di0.iteritems())
self.assertEqual(v, 1)
k, = (k for k, v in di0.iteritems() if v == 1)
di0[k] = 0 # r[2] = 'ok'
self.assertEqual(di0.values(), [0, 0, 0, 0, 0])
......@@ -458,6 +473,29 @@ class Test(NEOThreadedTest):
finally:
cluster.stop()
def testDropUnfinishedData(self):
def lock(orig, *args, **kw):
orig(*args, **kw)
storage.master_conn.close()
r = []
def dropUnfinishedData(orig):
r.append(len(orig.__self__.getUnfinishedTIDDict()))
orig()
r.append(len(orig.__self__.getUnfinishedTIDDict()))
cluster = NEOCluster(partitions=2, storage_count=2, replicas=1)
try:
cluster.start()
t, c = cluster.getTransaction()
c.root()._p_changed = 1
storage = cluster.storage_list[0]
with Patch(storage.tm, lock=lock), \
Patch(storage.dm, dropUnfinishedData=dropUnfinishedData):
t.commit()
self.tic()
self.assertEqual(r, [1, 0])
finally:
cluster.stop()
def testStorageReconnectDuringStore(self):
cluster = NEOCluster(replicas=1)
try:
......@@ -908,5 +946,27 @@ class Test(NEOThreadedTest):
cluster.stop()
del cluster.startCluster
def testAbortVotedTransaction(self):
r = []
def tpc_finish(*args, **kw):
for storage in cluster.storage_list:
r.append(len(storage.dm.getUnfinishedTIDDict()))
raise NEOStorageError
cluster = NEOCluster(storage_count=2, partitions=2)
try:
cluster.start()
t, c = cluster.getTransaction()
c.root()['x'] = PCounter()
with Patch(cluster.client, tpc_finish=tpc_finish):
self.assertRaises(NEOStorageError, t.commit)
self.tic()
self.assertEqual(r, [1, 1])
for storage in cluster.storage_list:
self.assertFalse(storage.dm.getUnfinishedTIDDict())
t.begin()
self.assertNotIn('x', c.root())
finally:
cluster.stop()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment