Commit e97427e9 authored by Vincent Pelletier's avatar Vincent Pelletier

Transactions must not be bound to the thread which created them.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2643 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent c5a1efea
......@@ -15,10 +15,9 @@
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from thread import get_ident
from cPickle import dumps, loads
from zlib import compress as real_compress, decompress
from neo.lib.locking import Queue, Empty
from neo.lib.locking import Empty
from random import shuffle
import time
import os
......@@ -49,6 +48,7 @@ from neo.lib.util import u64, parseMasterList
from neo.lib.profiling import profiler_decorator, PROFILING_ENABLED
from neo.lib.live_debug import register as registerLiveDebugger
from neo.client.mq_index import RevisionIndex
from neo.client.container import ThreadContainer, TransactionContainer
if PROFILING_ENABLED:
# Those functions require a "real" python function wrapper before they can
......@@ -65,60 +65,6 @@ else:
compress = real_compress
makeChecksum = real_makeChecksum
class ThreadContext(object):
def __init__(self):
super(ThreadContext, self).__setattr__('_threads_dict', {})
def __getThreadData(self):
thread_id = get_ident()
try:
result = self._threads_dict[thread_id]
except KeyError:
self.clear(thread_id)
result = self._threads_dict[thread_id]
return result
def __getattr__(self, name):
thread_data = self.__getThreadData()
try:
return thread_data[name]
except KeyError:
raise AttributeError, name
def __setattr__(self, name, value):
thread_data = self.__getThreadData()
thread_data[name] = value
def clear(self, thread_id=None):
if thread_id is None:
thread_id = get_ident()
thread_dict = self._threads_dict.get(thread_id)
if thread_dict is None:
queue = Queue(0)
else:
queue = thread_dict['queue']
self._threads_dict[thread_id] = {
'tid': None,
'txn': None,
'data_dict': {},
'data_list': [],
'object_base_serial_dict': {},
'object_serial_dict': {},
'object_stored_counter_dict': {},
'conflict_serial_dict': {},
'resolved_conflict_serial_dict': {},
'txn_voted': False,
'queue': queue,
'txn_info': 0,
'history': None,
'node_tids': {},
'asked_object': 0,
'undo_object_tid_dict': {},
'involved_nodes': set(),
'last_transaction': None,
}
class Application(object):
"""The client node application."""
......@@ -158,7 +104,8 @@ class Application(object):
self.primary_bootstrap_handler = master.PrimaryBootstrapHandler(self)
self.notifications_handler = master.PrimaryNotificationsHandler( self)
# Internal attribute distinct between thread
self.local_var = ThreadContext()
self._thread_container = ThreadContainer()
self._txn_container = TransactionContainer()
# Lock definition :
# _load_lock is used to make loading and storing atomic
lock = Lock()
......@@ -185,6 +132,15 @@ class Application(object):
self.compress = compress
registerLiveDebugger(on_log=self.log)
def getHandlerData(self):
return self._thread_container.get()['answer']
def setHandlerData(self, data):
self._thread_container.get()['answer'] = data
def _getThreadQueue(self):
return self._thread_container.get()['queue']
def log(self):
self.em.log()
self.nm.log()
......@@ -222,7 +178,7 @@ class Application(object):
conn.unlock()
@profiler_decorator
def _waitAnyMessage(self, block=True):
def _waitAnyMessage(self, queue, block=True):
"""
Handle all pending packets.
block
......@@ -230,7 +186,6 @@ class Application(object):
received.
"""
pending = self.dispatcher.pending
queue = self.local_var.queue
get = queue.get
_handlePacket = self._handlePacket
while pending(queue):
......@@ -247,39 +202,53 @@ class Application(object):
except ConnectionClosed:
pass
def _waitAnyTransactionMessage(self, txn_context, block=True):
"""
Just like _waitAnyMessage, but for per-transaction exchanges, rather
than per-thread.
"""
queue = txn_context['queue']
self.setHandlerData(txn_context)
try:
self._waitAnyMessage(queue, block=block)
finally:
# Don't leave access to thread context, even if a raise happens.
self.setHandlerData(None)
@profiler_decorator
def _waitMessage(self, target_conn, msg_id, handler=None):
"""Wait for a message returned by the dispatcher in queues."""
get = self.local_var.queue.get
def _ask(self, conn, packet, handler=None):
self.setHandlerData(None)
queue = self._getThreadQueue()
msg_id = conn.ask(packet, queue=queue)
get = queue.get
_handlePacket = self._handlePacket
while True:
conn, packet = get(True)
is_forgotten = isinstance(packet, ForgottenPacket)
if target_conn is conn:
qconn, qpacket = get(True)
is_forgotten = isinstance(qpacket, ForgottenPacket)
if conn is qconn:
# check fake packet
if packet is None:
if qpacket is None:
raise ConnectionClosed
if msg_id == packet.getId():
if msg_id == qpacket.getId():
if is_forgotten:
raise ValueError, 'ForgottenPacket for an ' \
'explicitely expected packet.'
_handlePacket(conn, packet, handler=handler)
_handlePacket(qconn, qpacket, handler=handler)
break
if not is_forgotten and packet is not None:
_handlePacket(conn, packet)
if not is_forgotten and qpacket is not None:
_handlePacket(qconn, qpacket)
return self.getHandlerData()
@profiler_decorator
def _askStorage(self, conn, packet):
""" Send a request to a storage node and process its answer """
msg_id = conn.ask(packet, queue=self.local_var.queue)
self._waitMessage(conn, msg_id, self.storage_handler)
return self._ask(conn, packet, handler=self.storage_handler)
@profiler_decorator
def _askPrimary(self, packet):
""" Send a request to the primary master and process its answer """
conn = self._getMasterConnection()
msg_id = conn.ask(packet, queue=self.local_var.queue)
self._waitMessage(conn, msg_id, self.primary_handler)
return self._ask(self._getMasterConnection(), packet,
handler=self.primary_handler)
@profiler_decorator
def _getMasterConnection(self):
......@@ -311,7 +280,6 @@ class Application(object):
neo.lib.logging.debug('connecting to primary master...')
ready = False
nm = self.nm
queue = self.local_var.queue
packet = Packets.AskPrimary()
while not ready:
# Get network connection to primary master
......@@ -346,8 +314,7 @@ class Application(object):
self.trying_master_node)
continue
try:
msg_id = conn.ask(packet, queue=queue)
self._waitMessage(conn, msg_id,
self._ask(conn, packet,
handler=self.primary_bootstrap_handler)
except ConnectionClosed:
continue
......@@ -373,24 +340,19 @@ class Application(object):
looked-up again.
"""
neo.lib.logging.info('Initializing from master')
queue = self.local_var.queue
ask = self._ask
handler = self.primary_bootstrap_handler
# Identify to primary master and request initial data
p = Packets.RequestIdentification(NodeTypes.CLIENT, self.uuid, None,
self.name)
while conn.getUUID() is None:
self._waitMessage(conn, conn.ask(p, queue=queue),
handler=self.primary_bootstrap_handler)
ask(conn, p, handler=handler)
if conn.getUUID() is None:
# Node identification was refused by master, it is considered
# as the primary as long as we are connected to it.
time.sleep(1)
if self.uuid is not None:
msg_id = conn.ask(Packets.AskNodeInformation(), queue=queue)
self._waitMessage(conn, msg_id,
handler=self.primary_bootstrap_handler)
msg_id = conn.ask(Packets.AskPartitionTable(), queue=queue)
self._waitMessage(conn, msg_id,
handler=self.primary_bootstrap_handler)
ask(conn, Packets.AskNodeInformation(), handler=handler)
ask(conn, Packets.AskPartitionTable(), handler=handler)
return self.pt.operational()
def registerDB(self, db, limit):
......@@ -490,32 +452,27 @@ class Application(object):
@profiler_decorator
def _loadFromStorage(self, oid, at_tid, before_tid):
self.local_var.asked_object = 0
data = None
packet = Packets.AskObject(oid, at_tid, before_tid)
for node, conn in self.cp.iterateForObject(oid, readable=True):
try:
self._askStorage(conn, packet)
noid, tid, next_tid, compression, checksum, data \
= self._askStorage(conn, packet)
except ConnectionClosed:
continue
# Check data
noid, tid, next_tid, compression, checksum, data \
= self.local_var.asked_object
if noid != oid:
# Oops, try with next node
neo.lib.logging.error('got wrong oid %s instead of %s from %s',
noid, dump(oid), conn)
self.local_var.asked_object = -1
continue
elif checksum != makeChecksum(data):
if checksum != makeChecksum(data):
# Warning: see TODO file.
# Check checksum.
neo.lib.logging.error('wrong checksum from %s for oid %s',
conn, dump(oid))
self.local_var.asked_object = -1
data = None
continue
break
if self.local_var.asked_object == -1:
raise NEOStorageError('inconsistent data')
if data is None:
# We didn't got any object from all storage node because of
# connection error
raise NEOStorageError('connection failure')
# Uncompress data
if compression:
......@@ -547,30 +504,34 @@ class Application(object):
@profiler_decorator
def tpc_begin(self, transaction, tid=None, status=' '):
"""Begin a new transaction."""
txn_container = self._txn_container
# First get a transaction, only one is allowed at a time
if self.local_var.txn is transaction:
if txn_container.get(transaction) is not None:
# We already begin the same transaction
raise StorageTransactionError('Duplicate tpc_begin calls')
if self.local_var.txn is not None:
raise NeoException, 'local_var is not clean in tpc_begin'
txn_context = txn_container.new(transaction)
# use the given TID or request a new one to the master
self._askPrimary(Packets.AskBeginTransaction(tid))
if self.local_var.tid is None:
answer_ttid = self._askPrimary(Packets.AskBeginTransaction(tid))
if answer_ttid is None:
raise NEOStorageError('tpc_begin failed')
assert tid in (None, self.local_var.tid), (tid, self.local_var.tid)
self.local_var.txn = transaction
assert tid in (None, answer_ttid), (tid, answer_ttid)
txn_context['txn'] = transaction
txn_context['ttid'] = answer_ttid
@profiler_decorator
def store(self, oid, serial, data, version, transaction):
"""Store object."""
if transaction is not self.local_var.txn:
txn_context = self._txn_container.get(transaction)
if txn_context is None:
raise StorageTransactionError(self, transaction)
neo.lib.logging.debug(
'storing oid %s serial %s', dump(oid), dump(serial))
self._store(oid, serial, data)
self._store(txn_context, oid, serial, data)
return None
def _store(self, oid, serial, data, data_serial=None, unlock=False):
def _store(self, txn_context, oid, serial, data, data_serial=None,
unlock=False):
ttid = txn_context['ttid']
if data is None:
# This is some undo: either a no-data object (undoing object
# creation) or a back-pointer to an earlier revision (going back to
......@@ -589,33 +550,33 @@ class Application(object):
else:
compression = 1
checksum = makeChecksum(compressed_data)
on_timeout = OnTimeout(self.onStoreTimeout, self.local_var.tid, oid)
on_timeout = OnTimeout(self.onStoreTimeout, ttid, oid)
# Store object in tmp cache
local_var = self.local_var
data_dict = local_var.data_dict
data_dict = txn_context['data_dict']
if oid not in data_dict:
local_var.data_list.append(oid)
txn_context['data_list'].append(oid)
data_dict[oid] = data
# Store data on each node
self.local_var.object_stored_counter_dict[oid] = {}
object_base_serial_dict = local_var.object_base_serial_dict
txn_context['object_stored_counter_dict'][oid] = {}
object_base_serial_dict = txn_context['object_base_serial_dict']
if oid not in object_base_serial_dict:
object_base_serial_dict[oid] = serial
self.local_var.object_serial_dict[oid] = serial
queue = self.local_var.queue
add_involved_nodes = self.local_var.involved_nodes.add
txn_context['object_serial_dict'][oid] = serial
queue = txn_context['queue']
involved_nodes = txn_context['involved_nodes']
add_involved_nodes = involved_nodes.add
packet = Packets.AskStoreObject(oid, serial, compression,
checksum, compressed_data, data_serial, self.local_var.tid, unlock)
checksum, compressed_data, data_serial, ttid, unlock)
for node, conn in self.cp.iterateForObject(oid, writable=True):
try:
conn.ask(packet, on_timeout=on_timeout, queue=queue)
add_involved_nodes(node)
except ConnectionClosed:
continue
if not self.local_var.involved_nodes:
if not involved_nodes:
raise NEOStorageError("Store failed")
self._waitAnyMessage(False)
self._waitAnyTransactionMessage(txn_context, False)
def onStoreTimeout(self, conn, msg_id, ttid, oid):
# NOTE: this method is called from poll thread, don't use
......@@ -628,17 +589,17 @@ class Application(object):
return True
@profiler_decorator
def _handleConflicts(self, tryToResolveConflict):
def _handleConflicts(self, txn_context, tryToResolveConflict):
result = []
append = result.append
local_var = self.local_var
# Check for conflicts
data_dict = local_var.data_dict
object_base_serial_dict = local_var.object_base_serial_dict
object_serial_dict = local_var.object_serial_dict
conflict_serial_dict = local_var.conflict_serial_dict.copy()
local_var.conflict_serial_dict.clear()
resolved_conflict_serial_dict = local_var.resolved_conflict_serial_dict
data_dict = txn_context['data_dict']
object_base_serial_dict = txn_context['object_base_serial_dict']
object_serial_dict = txn_context['object_serial_dict']
conflict_serial_dict = txn_context['conflict_serial_dict'].copy()
txn_context['conflict_serial_dict'].clear()
resolved_conflict_serial_dict = txn_context[
'resolved_conflict_serial_dict']
for oid, conflict_serial_set in conflict_serial_dict.iteritems():
resolved_serial_set = resolved_conflict_serial_dict.setdefault(
oid, set())
......@@ -650,7 +611,6 @@ class Application(object):
continue
serial = object_serial_dict[oid]
data = data_dict[oid]
tid = local_var.tid
resolved = False
if conflict_serial == ZERO_TID:
# Storage refused us from taking object lock, to avoid a
......@@ -665,12 +625,11 @@ class Application(object):
# object data again.
neo.lib.logging.info('Deadlock avoidance triggered on %r:%r',
dump(oid), dump(serial))
for store_oid, store_data in \
local_var.data_dict.iteritems():
for store_oid, store_data in data_dict.iteritems():
store_serial = object_serial_dict[store_oid]
if store_data is None:
self.checkCurrentSerialInTransaction(store_oid,
store_serial)
self._checkCurrentSerialInTransaction(txn_context,
store_oid, store_serial)
else:
if store_data is '':
# Some undo
......@@ -678,8 +637,8 @@ class Application(object):
' reliably work with undo, this must be '
'implemented.')
break
self._store(store_oid, store_serial, store_data,
unlock=True)
self._store(txn_context, store_oid, store_serial,
store_data, unlock=True)
else:
resolved = True
elif data is not None:
......@@ -694,7 +653,7 @@ class Application(object):
# Base serial changes too, as we resolved a conflict
object_base_serial_dict[oid] = conflict_serial
# Try to store again
self._store(oid, conflict_serial, new_data)
self._store(txn_context, oid, conflict_serial, new_data)
append(oid)
resolved = True
else:
......@@ -704,49 +663,51 @@ class Application(object):
if not resolved:
# XXX: Is it really required to remove from data_dict ?
del data_dict[oid]
local_var.data_list.remove(oid)
txn_context['data_list'].remove(oid)
if data is None:
exc = ReadConflictError(oid=oid, serials=(conflict_serial,
serial))
else:
exc = ConflictError(oid=oid, serials=(tid, serial),
data=data)
exc = ConflictError(oid=oid, serials=(txn_context['ttid'],
serial), data=data)
raise exc
return result
@profiler_decorator
def waitResponses(self):
def waitResponses(self, queue, handler_data):
"""Wait for all requests to be answered (or their connection to be
detected as closed)"""
queue = self.local_var.queue
pending = self.dispatcher.pending
_waitAnyMessage = self._waitAnyMessage
self.setHandlerData(handler_data)
while pending(queue):
_waitAnyMessage()
_waitAnyMessage(queue)
@profiler_decorator
def waitStoreResponses(self, tryToResolveConflict):
def waitStoreResponses(self, txn_context, tryToResolveConflict):
result = []
append = result.append
resolved_oid_set = set()
update = resolved_oid_set.update
local_var = self.local_var
tid = local_var.tid
ttid = txn_context['ttid']
_handleConflicts = self._handleConflicts
conflict_serial_dict = local_var.conflict_serial_dict
queue = local_var.queue
queue = txn_context['queue']
conflict_serial_dict = txn_context['conflict_serial_dict']
pending = self.dispatcher.pending
_waitAnyMessage = self._waitAnyMessage
_waitAnyTransactionMessage = self._waitAnyTransactionMessage
while pending(queue) or conflict_serial_dict:
_waitAnyMessage()
# Note: handler data can be overwritten by _handleConflicts
# so we must set it for each iteration.
_waitAnyTransactionMessage(txn_context)
if conflict_serial_dict:
conflicts = _handleConflicts(tryToResolveConflict)
conflicts = _handleConflicts(txn_context,
tryToResolveConflict)
if conflicts:
update(conflicts)
# Check for never-stored objects, and update result for all others
for oid, store_dict in \
local_var.object_stored_counter_dict.iteritems():
txn_context['object_stored_counter_dict'].iteritems():
if not store_dict:
neo.lib.logging.error('tpc_store failed')
raise NEOStorageError('tpc_store failed')
......@@ -757,27 +718,27 @@ class Application(object):
@profiler_decorator
def tpc_vote(self, transaction, tryToResolveConflict):
"""Store current transaction."""
local_var = self.local_var
if transaction is not local_var.txn:
txn_context = self._txn_container.get(transaction)
if txn_context is None or transaction is not txn_context['txn']:
raise StorageTransactionError(self, transaction)
result = self.waitStoreResponses(tryToResolveConflict)
result = self.waitStoreResponses(txn_context, tryToResolveConflict)
tid = local_var.tid
ttid = txn_context['ttid']
# Store data on each node
txn_stored_counter = 0
packet = Packets.AskStoreTransaction(tid, str(transaction.user),
packet = Packets.AskStoreTransaction(ttid, str(transaction.user),
str(transaction.description), dumps(transaction._extension),
local_var.data_list)
add_involved_nodes = self.local_var.involved_nodes.add
for node, conn in self.cp.iterateForObject(tid, writable=True):
neo.lib.logging.debug("voting object %s on %s", dump(tid),
txn_context['data_list'])
add_involved_nodes = txn_context['involved_nodes'].add
for node, conn in self.cp.iterateForObject(ttid, writable=True):
neo.lib.logging.debug("voting object %s on %s", dump(ttid),
dump(conn.getUUID()))
try:
self._askStorage(conn, packet)
add_involved_nodes(node)
except ConnectionClosed:
continue
add_involved_nodes(node)
txn_stored_counter += 1
# check at least one storage node accepted
......@@ -790,20 +751,22 @@ class Application(object):
# tpc_finish.
self._getMasterConnection()
local_var.txn_voted = True
txn_context['txn_voted'] = True
return result
@profiler_decorator
def tpc_abort(self, transaction):
"""Abort current transaction."""
if transaction is not self.local_var.txn:
txn_container = self._txn_container
txn_context = txn_container.get(transaction)
if txn_context is None:
return
tid = self.local_var.tid
p = Packets.AbortTransaction(tid)
ttid = txn_context['ttid']
p = Packets.AbortTransaction(ttid)
getConnForNode = self.cp.getConnForNode
# cancel transaction one all those nodes
for node in self.local_var.involved_nodes:
for node in txn_context['involved_nodes']:
conn = getConnForNode(node)
if conn is None:
continue
......@@ -815,28 +778,30 @@ class Application(object):
'storage node %r of abortion, ignoring.',
conn, exc_info=1)
self._getMasterConnection().notify(p)
queue = self.local_var.queue
self.dispatcher.forget_queue(queue)
self.local_var.clear()
queue = txn_context['queue']
# We don't need to flush queue, as it won't be reused by future
# transactions (deleted on next line & indexed by transaction object
# instance).
self.dispatcher.forget_queue(queue, flush_queue=False)
txn_container.delete(transaction)
@profiler_decorator
def tpc_finish(self, transaction, tryToResolveConflict, f=None):
"""Finish current transaction."""
local_var = self.local_var
if local_var.txn is not transaction:
txn_container = self._txn_container
txn_context = txn_container.get(transaction)
if txn_context is None:
raise StorageTransactionError('tpc_finish called for wrong '
'transaction')
if not local_var.txn_voted:
if not txn_context['txn_voted']:
self.tpc_vote(transaction, tryToResolveConflict)
self._load_lock_acquire()
try:
# Call finish on master
oid_list = local_var.data_list
p = Packets.AskFinishTransaction(local_var.tid, oid_list)
self._askPrimary(p)
oid_list = txn_context['data_list']
p = Packets.AskFinishTransaction(txn_context['ttid'], oid_list)
tid = self._askPrimary(p)
# From now on, self.local_var.tid holds the "real" TID.
tid = local_var.tid
# Call function given by ZODB
if f is not None:
f(tid)
......@@ -851,8 +816,8 @@ class Application(object):
assert next_tid is None, (dump(oid), dump(base_tid),
dump(next_tid))
return (data, tid)
get_baseTID = local_var.object_base_serial_dict.get
for oid, data in local_var.data_dict.iteritems():
get_baseTID = txn_context['object_base_serial_dict'].get
for oid, data in txn_context['data_dict'].iteritems():
if data is None:
# this is just a remain of
# checkCurrentSerialInTransaction call, ignore (no data
......@@ -871,13 +836,14 @@ class Application(object):
mq_cache[(oid, tid)] = (data, None)
finally:
self._cache_lock_release()
local_var.clear()
txn_container.delete(transaction)
return tid
finally:
self._load_lock_release()
def undo(self, snapshot_tid, undone_tid, txn, tryToResolveConflict):
if txn is not self.local_var.txn:
txn_context = self._txn_container.get(txn)
if txn_context is None:
raise StorageTransactionError(self, undone_tid)
txn_info, txn_ext = self._getTransactionInformation(undone_tid)
......@@ -899,22 +865,23 @@ class Application(object):
getCellList = pt.getCellList
getCellSortKey = self.cp.getCellSortKey
getConnForCell = self.cp.getConnForCell
queue = self.local_var.queue
undo_object_tid_dict = self.local_var.undo_object_tid_dict = {}
queue = self._getThreadQueue()
ttid = txn_context['ttid']
for partition, oid_list in partition_oid_dict.iteritems():
cell_list = getCellList(partition, readable=True)
shuffle(cell_list)
cell_list.sort(key=getCellSortKey)
storage_conn = getConnForCell(cell_list[0])
storage_conn.ask(Packets.AskObjectUndoSerial(self.local_var.tid,
storage_conn.ask(Packets.AskObjectUndoSerial(ttid,
snapshot_tid, undone_tid, oid_list), queue=queue)
# Wait for all AnswerObjectUndoSerial. We might get OidNotFoundError,
# meaning that objects in transaction's oid_list do not exist any
# longer. This is the symptom of a pack, so forbid undoing transaction
# when it happens.
undo_object_tid_dict = {}
try:
self.waitResponses()
self.waitResponses(queue, undo_object_tid_dict)
except NEOStorageNotFoundError:
self.dispatcher.forget_queue(queue)
raise UndoError('non-undoable transaction')
......@@ -929,9 +896,11 @@ class Application(object):
# object. This is an undo conflict, try to resolve it.
try:
# Load the latest version we are supposed to see
data = self.load(snapshot_tid, oid, serial=current_serial)[0]
data = self.load(snapshot_tid, oid,
serial=current_serial)[0]
# Load the version we were undoing to
undo_data = self.load(snapshot_tid, oid, serial=undo_serial)[0]
undo_data = self.load(snapshot_tid, oid,
serial=undo_serial)[0]
except NEOStorageNotFoundError:
raise UndoError('Object not found while resolving undo '
'conflict')
......@@ -945,7 +914,7 @@ class Application(object):
raise UndoError('Some data were modified by a later ' \
'transaction', oid)
undo_serial = None
self._store(oid, current_serial, data, undo_serial)
self._store(txn_context, oid, current_serial, data, undo_serial)
def _insertMetadata(self, txn_info, extension):
for k, v in loads(extension).items():
......@@ -955,7 +924,7 @@ class Application(object):
packet = Packets.AskTransactionInformation(tid)
for node, conn in self.cp.iterateForObject(tid, readable=True):
try:
self._askStorage(conn, packet)
txn_info, txn_ext = self._askStorage(conn, packet)
except ConnectionClosed:
continue
except NEOStorageNotFoundError:
......@@ -964,7 +933,7 @@ class Application(object):
break
else:
raise NEOStorageError('Transaction %r not found' % (tid, ))
return (self.local_var.txn_info, self.local_var.txn_ext)
return (txn_info, txn_ext)
def undoLog(self, first, last, filter=None, block=0):
# XXX: undoLog is broken
......@@ -978,8 +947,7 @@ class Application(object):
pt = self.getPartitionTable()
storage_node_list = pt.getNodeList()
self.local_var.node_tids = {}
queue = self.local_var.queue
queue = self._getThreadQueue()
packet = Packets.AskTIDs(first, last, INVALID_PARTITION)
for storage_node in storage_node_list:
conn = self.cp.getConnForNode(storage_node)
......@@ -988,15 +956,11 @@ class Application(object):
conn.ask(packet, queue=queue)
# Wait for answers from all storages.
self.waitResponses()
tid_set = set()
self.waitResponses(queue, tid_set)
# Reorder tids
ordered_tids = set()
update = ordered_tids.update
for tid_list in self.local_var.node_tids.itervalues():
update(tid_list)
ordered_tids = list(ordered_tids)
ordered_tids.sort(reverse=True)
ordered_tids = sorted(tid_set, reverse=True)
neo.lib.logging.debug(
"UndoLog tids %s", [dump(x) for x in ordered_tids])
# For each transaction, get info
......@@ -1004,11 +968,10 @@ class Application(object):
append = undo_info.append
for tid in ordered_tids:
(txn_info, txn_ext) = self._getTransactionInformation(tid)
if filter is None or filter(self.local_var.txn_info):
txn_info = self.local_var.txn_info
if filter is None or filter(txn_info):
txn_info.pop('packed')
txn_info.pop("oids")
self._insertMetadata(txn_info, self.local_var.txn_ext)
self._insertMetadata(txn_info, txn_ext)
append(txn_info)
if len(undo_info) >= last - first:
break
......@@ -1024,9 +987,8 @@ class Application(object):
node_list = node_map.keys()
node_list.sort(key=self.cp.getCellSortKey)
partition_set = set(range(self.pt.getPartitions()))
queue = self.local_var.queue
queue = self._getThreadQueue()
# request a tid list for each partition
self.local_var.tids_from = set()
for node in node_list:
conn = self.cp.getConnForNode(node)
request_set = set(node_map[node]) & partition_set
......@@ -1038,40 +1000,34 @@ class Application(object):
if not partition_set:
break
assert not partition_set
self.waitResponses()
tid_set = set()
self.waitResponses(queue, tid_set)
# request transactions informations
txn_list = []
append = txn_list.append
tid = None
for tid in sorted(self.local_var.tids_from):
for tid in sorted(tid_set):
(txn_info, txn_ext) = self._getTransactionInformation(tid)
txn_info['ext'] = loads(self.local_var.txn_ext)
txn_info['ext'] = loads(txn_ext)
append(txn_info)
return (tid, txn_list)
def history(self, oid, version=None, size=1, filter=None):
queue = self._getThreadQueue()
# Get history informations for object first
packet = Packets.AskObjectHistory(oid, 0, size)
for node, conn in self.cp.iterateForObject(oid, readable=True):
# FIXME: we keep overwriting self.local_var.history here, we
# should aggregate it instead.
self.local_var.history = None
try:
self._askStorage(conn, packet)
conn.ask(packet, queue=queue)
except ConnectionClosed:
continue
if self.local_var.history[0] != oid:
# Got history for wrong oid
raise NEOStorageError('inconsistency in storage: asked oid ' \
'%r, got %r' % (oid, self.local_var.history[0]))
if not isinstance(self.local_var.history, tuple):
raise NEOStorageError('history failed')
history_dict = {}
self.waitResponses(queue, history_dict)
# Now that we have object informations, get txn informations
history_list = []
for serial, size in self.local_var.history[1]:
append = history_list.append
for serial in sorted(history_dict.keys(), reverse=True):
size = history_dict[serial]
txn_info, txn_ext = self._getTransactionInformation(serial)
# create history dict
txn_info.pop('id')
......@@ -1081,9 +1037,8 @@ class Application(object):
txn_info['version'] = ''
txn_info['size'] = size
if filter is None or filter(txn_info):
history_list.append(txn_info)
append(txn_info)
self._insertMetadata(txn_info, txn_ext)
return history_list
@profiler_decorator
......@@ -1111,16 +1066,15 @@ class Application(object):
return Iterator(self, start, stop)
def lastTransaction(self):
self._askPrimary(Packets.AskLastTransaction())
return self.local_var.last_transaction
return self._askPrimary(Packets.AskLastTransaction())
def abortVersion(self, src, transaction):
if transaction is not self.local_var.txn:
if self._txn_container.get(transaction) is None:
raise StorageTransactionError(self, transaction)
return '', []
def commitVersion(self, src, dest, transaction):
if transaction is not self.local_var.txn:
if self._txn_container.get(transaction) is None:
raise StorageTransactionError(self, transaction)
return '', []
......@@ -1141,12 +1095,6 @@ class Application(object):
def invalidationBarrier(self):
self._askPrimary(Packets.AskBarrier())
def setTID(self, value):
self.local_var.tid = value
def getTID(self):
return self.local_var.tid
def pack(self, t):
tid = repr(TimeStamp(*time.gmtime(t)[:5] + (t % 60, )))
if tid == ZERO_TID:
......@@ -1166,24 +1114,27 @@ class Application(object):
return self.load(None, oid)[1]
def checkCurrentSerialInTransaction(self, oid, serial, transaction):
local_var = self.local_var
if transaction is not local_var.txn:
txn_context = self._txn_container.get(transaction)
if txn_context is None:
raise StorageTransactionError(self, transaction)
local_var.object_serial_dict[oid] = serial
self._checkCurrentSerialInTransaction(txn_context, oid, serial)
def _checkCurrentSerialInTransaction(self, txn_context, oid, serial):
ttid = txn_context['ttid']
txn_context['object_serial_dict'][oid] = serial
# Placeholders
queue = local_var.queue
local_var.object_stored_counter_dict[oid] = {}
data_dict = local_var.data_dict
queue = txn_context['queue']
txn_context['object_stored_counter_dict'][oid] = {}
data_dict = txn_context['data_dict']
if oid not in data_dict:
# Marker value so we don't try to resolve conflicts.
data_dict[oid] = None
local_var.data_list.append(oid)
packet = Packets.AskCheckCurrentSerial(local_var.tid, serial, oid)
txn_context['data_list'].append(oid)
packet = Packets.AskCheckCurrentSerial(ttid, serial, oid)
for node, conn in self.cp.iterateForObject(oid, writable=True):
try:
conn.ask(packet, queue=queue)
except ConnectionClosed:
continue
self._waitAnyMessage(False)
self._waitAnyTransactionMessage(txn_context, False)
#
# Copyright (C) 2011 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from thread import get_ident
from neo.lib.locking import Queue
class ContainerBase(object):
def __init__(self):
self._context_dict = {}
def _getID(self, *args, **kw):
raise NotImplementedError
def _new(self, *args, **kw):
raise NotImplementedError
def delete(self, *args, **kw):
del self._context_dict[self._getID(*args, **kw)]
def get(self, *args, **kw):
return self._context_dict.get(self._getID(*args, **kw))
def new(self, *args, **kw):
result = self._context_dict[self._getID(*args, **kw)] = self._new(
*args, **kw)
return result
class ThreadContainer(ContainerBase):
def _getID(self):
return get_ident()
def _new(self):
return {
'queue': Queue(0),
'answer': None,
}
def get(self):
"""
Implicitely create a thread context if it doesn't exist.
"""
my_id = self._getID()
try:
result = self._context_dict[my_id]
except KeyError:
result = self._context_dict[my_id] = self._new()
return result
class TransactionContainer(ContainerBase):
def _getID(self, txn):
return id(txn)
def _new(self, txn):
return {
'queue': Queue(0),
'txn': txn,
'ttid': None,
'data_dict': {},
'data_list': [],
'object_base_serial_dict': {},
'object_serial_dict': {},
'object_stored_counter_dict': {},
'conflict_serial_dict': {},
'resolved_conflict_serial_dict': {},
'txn_voted': False,
'involved_nodes': set(),
}
......@@ -156,21 +156,19 @@ class PrimaryNotificationsHandler(BaseHandler):
class PrimaryAnswersHandler(AnswerBaseHandler):
""" Handle that process expected packets from the primary master """
def answerBeginTransaction(self, conn, tid):
self.app.setTID(tid)
def answerBeginTransaction(self, conn, ttid):
self.app.setHandlerData(ttid)
def answerNewOIDs(self, conn, oid_list):
self.app.new_oid_list = oid_list
def answerTransactionFinished(self, conn, ttid, tid):
if ttid != self.app.getTID():
raise NEOStorageError('Wrong TID, transaction not started')
self.app.setTID(tid)
def answerTransactionFinished(self, conn, _, tid):
self.app.setHandlerData(tid)
def answerPack(self, conn, status):
if not status:
raise NEOStorageError('Already packing')
def answerLastTransaction(self, conn, ltid):
self.app.local_var.last_transaction = ltid
self.app.setHandlerData(ltid)
......@@ -68,23 +68,25 @@ class StorageAnswersHandler(AnswerBaseHandler):
if data_serial is not None:
raise NEOStorageError, 'Storage should never send non-None ' \
'data_serial to clients, got %s' % (dump(data_serial), )
self.app.local_var.asked_object = (oid, start_serial, end_serial,
compression, checksum, data)
self.app.setHandlerData((oid, start_serial, end_serial,
compression, checksum, data))
def answerStoreObject(self, conn, conflicting, oid, serial):
local_var = self.app.local_var
object_stored_counter_dict = local_var.object_stored_counter_dict[oid]
txn_context = self.app.getHandlerData()
object_stored_counter_dict = txn_context[
'object_stored_counter_dict'][oid]
if conflicting:
neo.lib.logging.info('%r report a conflict for %r with %r', conn,
dump(oid), dump(serial))
conflict_serial_dict = local_var.conflict_serial_dict
conflict_serial_dict = txn_context['conflict_serial_dict']
if serial in object_stored_counter_dict:
raise NEOStorageError, 'A storage accepted object for ' \
'serial %s but another reports a conflict for it.' % (
dump(serial), )
# If this conflict is not already resolved, mark it for
# resolution.
if serial not in local_var.resolved_conflict_serial_dict.get(oid, ()):
if serial not in txn_context[
'resolved_conflict_serial_dict'].get(oid, ()):
conflict_serial_dict.setdefault(oid, set()).add(serial)
else:
object_stored_counter_dict[serial] = \
......@@ -92,31 +94,29 @@ class StorageAnswersHandler(AnswerBaseHandler):
answerCheckCurrentSerial = answerStoreObject
def answerStoreTransaction(self, conn, tid):
if tid != self.app.getTID():
raise NEOStorageError('Wrong TID, transaction not started')
def answerStoreTransaction(self, conn, _):
pass
def answerTIDsFrom(self, conn, tid_list):
neo.lib.logging.debug('Get %d TIDs from %r', len(tid_list), conn)
assert not self.app.local_var.tids_from.intersection(set(tid_list))
self.app.local_var.tids_from.update(tid_list)
tids_from = self.app.getHandlerData()
assert not tids_from.intersection(set(tid_list))
tids_from.update(tid_list)
def answerTransactionInformation(self, conn, tid,
user, desc, ext, packed, oid_list):
# transaction information are returned as a dict
info = {}
info['time'] = TimeStamp(tid).timeTime()
info['user_name'] = user
info['description'] = desc
info['id'] = tid
info['oids'] = oid_list
info['packed'] = packed
self.app.local_var.txn_ext = ext
self.app.local_var.txn_info = info
def answerObjectHistory(self, conn, oid, history_list):
self.app.setHandlerData(({
'time': TimeStamp(tid).timeTime(),
'user_name': user,
'description': desc,
'id': tid,
'oids': oid_list,
'packed': packed,
}, ext))
def answerObjectHistory(self, conn, _, history_list):
# history_list is a list of tuple (serial, size)
self.app.local_var.history = oid, history_list
self.app.getHandlerData().update(history_list)
def oidNotFound(self, conn, message):
# This can happen either when :
......@@ -132,10 +132,10 @@ class StorageAnswersHandler(AnswerBaseHandler):
raise NEOStorageNotFoundError(message)
def answerTIDs(self, conn, tid_list):
self.app.local_var.node_tids[conn.getUUID()] = tid_list
self.app.getHandlerData().update(tid_list)
def answerObjectUndoSerial(self, conn, object_tid_dict):
self.app.local_var.undo_object_tid_dict.update(object_tid_dict)
self.app.getHandlerData().update(object_tid_dict)
def answerHasLock(self, conn, oid, status):
if status == LockState.GRANTED_TO_OTHER:
......
......@@ -66,9 +66,7 @@ class ConnectionPool(object):
p = Packets.RequestIdentification(NodeTypes.CLIENT,
app.uuid, None, app.name)
try:
msg_id = conn.ask(p, queue=app.local_var.queue)
app._waitMessage(conn, msg_id,
handler=app.storage_bootstrap_handler)
app._ask(conn, p, handler=app.storage_bootstrap_handler)
except ConnectionClosed:
neo.lib.logging.error('Connection to %r failed', node)
self.notifyFailure(node)
......
......@@ -15,7 +15,6 @@
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import new
import unittest
from cPickle import dumps
from mock import Mock, ReturnValues
......@@ -24,7 +23,8 @@ from neo.tests import NeoUnitTestBase
from neo.client.app import Application, RevisionIndex
from neo.client.exception import NEOStorageError, NEOStorageNotFoundError
from neo.client.exception import NEOStorageDoesNotExistError
from neo.lib.protocol import Packet, Packets, Errors, INVALID_TID
from neo.lib.protocol import Packet, Packets, Errors, INVALID_TID, \
INVALID_PARTITION
from neo.lib.util import makeChecksum
import time
......@@ -45,11 +45,14 @@ def getPartitionTable(self):
self.master_conn = _getMasterConnection(self)
return self.pt
def _waitMessage(self, conn, msg_id, handler=None):
def _ask(self, conn, packet, handler=None):
conn.ask(packet)
self.setHandlerData(None)
if handler is None:
raise NotImplementedError
else:
handler.dispatch(conn, conn.fakeReceived())
return self.getHandlerData()
def resolving_tryToResolveConflict(oid, conflict_serial, serial, data):
return data
......@@ -63,10 +66,10 @@ class ClientApplicationTests(NeoUnitTestBase):
NeoUnitTestBase.setUp(self)
# apply monkey patches
self._getMasterConnection = Application._getMasterConnection
self._waitMessage = Application._waitMessage
self._ask = Application._ask
self.getPartitionTable = Application.getPartitionTable
Application._getMasterConnection = _getMasterConnection
Application._waitMessage = _waitMessage
Application._ask = _ask
Application.getPartitionTable = getPartitionTable
self._to_stop_list = []
......@@ -76,12 +79,19 @@ class ClientApplicationTests(NeoUnitTestBase):
app.close()
# restore environnement
Application._getMasterConnection = self._getMasterConnection
Application._waitMessage = self._waitMessage
Application._ask = self._ask
Application.getPartitionTable = self.getPartitionTable
NeoUnitTestBase.tearDown(self)
# some helpers
def _begin(self, app, txn, tid=None):
txn_context = app._txn_container.new(txn)
if tid is None:
tid = self.makeTID()
txn_context['ttid'] = tid
return txn_context
def checkAskPacket(self, conn, packet_type, decode=False):
calls = conn.mockGetNamedCalls('ask')
self.assertEquals(len(calls), 1)
......@@ -154,13 +164,6 @@ class ClientApplicationTests(NeoUnitTestBase):
#self.assertEquals(calls[0].getParam(0), conn)
#self.assertTrue(isinstance(calls[0].getParam(2), Queue))
def test_getQueue(self):
app = self.getApp()
# Test sanity check
self.assertTrue(getattr(app, 'local_var', None) is not None)
# Test that queue is created
self.assertTrue(getattr(app.local_var, 'queue', None) is not None)
def test_registerDB(self):
app = self.getApp()
dummy_db = []
......@@ -186,7 +189,6 @@ class ClientApplicationTests(NeoUnitTestBase):
def test_load(self):
app = self.getApp()
app.local_var.barrier_done = True
mq = app.mq_cache
oid = self.makeOID()
tid1 = self.makeTID(1)
......@@ -235,7 +237,6 @@ class ClientApplicationTests(NeoUnitTestBase):
'fakeReceived': packet,
})
app.cp = self.getConnectionPool([(Mock(), conn)])
app.local_var.asked_object = an_object[:-1]
answer_barrier = Packets.AnswerBarrier()
answer_barrier.setId(1)
app.master_conn = Mock({
......@@ -257,7 +258,6 @@ class ClientApplicationTests(NeoUnitTestBase):
def test_loadSerial(self):
app = self.getApp()
app.local_var.barrier_done = True
mq = app.mq_cache
oid = self.makeOID()
tid1 = self.makeTID(1)
......@@ -292,7 +292,6 @@ class ClientApplicationTests(NeoUnitTestBase):
'fakeReceived': packet,
})
app.cp = self.getConnectionPool([(Mock(), conn)])
app.local_var.asked_object = another_object[:-1]
result = loadSerial(oid, tid1)
self.assertEquals(result, 'RIGHT')
self.checkAskObject(conn)
......@@ -300,7 +299,6 @@ class ClientApplicationTests(NeoUnitTestBase):
def test_loadBefore(self):
app = self.getApp()
app.local_var.barrier_done = True
mq = app.mq_cache
oid = self.makeOID()
tid1 = self.makeTID(1)
......@@ -332,7 +330,6 @@ class ClientApplicationTests(NeoUnitTestBase):
'fakeReceived': packet,
})
app.cp = self.getConnectionPool([(Mock(), conn)])
app.local_var.asked_object = an_object[:-1]
self.assertRaises(NEOStorageError, loadBefore, oid, tid1)
# object should not have been cached
self.assertFalse((oid, tid1) in mq)
......@@ -349,7 +346,6 @@ class ClientApplicationTests(NeoUnitTestBase):
'fakeReceived': packet,
})
app.cp = self.getConnectionPool([(Mock(), conn)])
app.local_var.asked_object = another_object
result = loadBefore(oid, tid3)
self.assertEquals(result, ('RIGHT', tid2, tid3))
self.checkAskObject(conn)
......@@ -369,16 +365,17 @@ class ClientApplicationTests(NeoUnitTestBase):
'fakeReceived': packet,
})
app.tpc_begin(transaction=txn, tid=tid)
self.assertTrue(app.local_var.txn is txn)
self.assertEquals(app.local_var.tid, tid)
txn_context = app._txn_container.get(txn)
self.assertTrue(txn_context['txn'] is txn)
self.assertEquals(txn_context['ttid'], tid)
# next, the transaction already begin -> raise
self.assertRaises(StorageTransactionError, app.tpc_begin,
transaction=txn, tid=None)
self.assertTrue(app.local_var.txn is txn)
self.assertEquals(app.local_var.tid, tid)
# cancel and start a transaction without tid
app.local_var.txn = None
app.local_var.tid = None
txn_context = app._txn_container.get(txn)
self.assertTrue(txn_context['txn'] is txn)
self.assertEquals(txn_context['ttid'], tid)
# start a transaction without tid
txn = Mock()
# no connection -> NEOStorageError (wait until connected to primary)
#self.assertRaises(NEOStorageError, app.tpc_begin, transaction=txn, tid=None)
# ask a tid to pmn
......@@ -392,8 +389,9 @@ class ClientApplicationTests(NeoUnitTestBase):
self.checkAskNewTid(app.master_conn)
self.checkDispatcherRegisterCalled(app, app.master_conn)
# check attributes
self.assertTrue(app.local_var.txn is txn)
self.assertEquals(app.local_var.tid, tid)
txn_context = app._txn_container.get(txn)
self.assertTrue(txn_context['txn'] is txn)
self.assertEquals(txn_context['ttid'], tid)
def test_store1(self):
app = self.getApp()
......@@ -401,14 +399,10 @@ class ClientApplicationTests(NeoUnitTestBase):
tid = self.makeTID()
txn = self.makeTransactionObject()
# invalid transaction > StorageTransactionError
app.local_var.txn = old_txn = object()
self.assertTrue(app.local_var.txn is not txn)
self.assertRaises(StorageTransactionError, app.store, oid, tid, '',
None, txn)
self.assertEquals(app.local_var.txn, old_txn)
# check partition_id and an empty cell list -> NEOStorageError
app.local_var.txn = txn
app.local_var.tid = tid
self._begin(app, txn, self.makeTID())
app.pt = Mock({ 'getCellListForOID': (), })
app.num_partitions = 2
self.assertRaises(NEOStorageError, app.store, oid, tid, '', None,
......@@ -423,8 +417,7 @@ class ClientApplicationTests(NeoUnitTestBase):
tid = self.makeTID()
txn = self.makeTransactionObject()
# build conflicting state
app.local_var.txn = txn
app.local_var.tid = tid
txn_context = self._begin(app, txn, tid)
packet = Packets.AnswerStoreObject(conflicting=1, oid=oid, serial=tid)
packet.setId(0)
storage_address = ('127.0.0.1', 10020)
......@@ -436,14 +429,15 @@ class ClientApplicationTests(NeoUnitTestBase):
return not queue.empty()
app.dispatcher = Dispatcher()
app.nm.createStorage(address=storage_address)
app.local_var.data_dict[oid] = 'BEFORE'
app.local_var.data_list.append(oid)
data_dict = txn_context['data_dict']
data_dict[oid] = 'BEFORE'
txn_context['data_list'].append(oid)
app.store(oid, tid, '', None, txn)
app.local_var.queue.put((conn, packet))
self.assertRaises(ConflictError, app.waitStoreResponses,
txn_context['queue'].put((conn, packet))
self.assertRaises(ConflictError, app.waitStoreResponses, txn_context,
failing_tryToResolveConflict)
self.assertTrue(oid not in app.local_var.data_dict)
self.assertEquals(app.local_var.object_stored_counter_dict[oid], {})
self.assertTrue(oid not in data_dict)
self.assertEquals(txn_context['object_stored_counter_dict'][oid], {})
self.checkAskStoreObject(conn)
def test_store3(self):
......@@ -452,8 +446,7 @@ class ClientApplicationTests(NeoUnitTestBase):
tid = self.makeTID()
txn = self.makeTransactionObject()
# case with no conflict
app.local_var.txn = txn
app.local_var.tid = tid
txn_context = self._begin(app, txn, tid)
packet = Packets.AnswerStoreObject(conflicting=0, oid=oid, serial=tid)
packet.setId(0)
storage_address = ('127.0.0.1', 10020)
......@@ -467,47 +460,25 @@ class ClientApplicationTests(NeoUnitTestBase):
app.nm.createStorage(address=storage_address)
app.store(oid, tid, 'DATA', None, txn)
self.checkAskStoreObject(conn)
app.local_var.queue.put((conn, packet))
app.waitStoreResponses(resolving_tryToResolveConflict)
self.assertEquals(app.local_var.object_stored_counter_dict[oid], {tid: 1})
self.assertEquals(app.local_var.data_dict.get(oid, None), 'DATA')
self.assertFalse(oid in app.local_var.conflict_serial_dict)
txn_context['queue'].put((conn, packet))
app.waitStoreResponses(txn_context, resolving_tryToResolveConflict)
self.assertEquals(txn_context['object_stored_counter_dict'][oid],
{tid: 1})
self.assertEquals(txn_context['data_dict'].get(oid, None), 'DATA')
self.assertFalse(oid in txn_context['conflict_serial_dict'])
def test_tpc_vote1(self):
app = self.getApp()
oid = self.makeOID(11)
txn = self.makeTransactionObject()
# invalid transaction > StorageTransactionError
app.local_var.txn = old_txn = object()
self.assertTrue(app.local_var.txn is not txn)
self.assertRaises(StorageTransactionError, app.tpc_vote, txn,
resolving_tryToResolveConflict)
self.assertEquals(app.local_var.txn, old_txn)
def test_tpc_vote2(self):
# fake transaction object
app = self.getApp()
app.local_var.txn = self.makeTransactionObject()
app.local_var.tid = self.makeTID()
# wrong answer -> failure
packet = Packets.AnswerStoreTransaction(INVALID_TID)
packet.setId(0)
conn = Mock({
'getNextId': 1,
'fakeReceived': packet,
'getAddress': ('127.0.0.1', 0),
})
app.cp = self.getConnectionPool([(Mock(), conn)])
self.assertRaises(NEOStorageError, app.tpc_vote, app.local_var.txn,
resolving_tryToResolveConflict)
self.checkAskPacket(conn, Packets.AskStoreTransaction)
def test_tpc_vote3(self):
app = self.getApp()
tid = self.makeTID()
txn = self.makeTransactionObject()
app.local_var.txn = txn
app.local_var.tid = tid
self._begin(app, txn, tid)
# response -> OK
packet = Packets.AnswerStoreTransaction(tid=tid)
packet.setId(0)
......@@ -529,10 +500,9 @@ class ClientApplicationTests(NeoUnitTestBase):
app = self.getApp()
tid = self.makeTID()
txn = self.makeTransactionObject()
app.local_var.txn = old_txn = object()
old_txn = object()
self._begin(app, old_txn, tid)
app.master_conn = Mock()
app.local_var.tid = tid
self.assertFalse(app.local_var.txn is txn)
conn = Mock()
cell = Mock()
app.pt = Mock({'getCellListForTID': (cell, cell)})
......@@ -541,8 +511,9 @@ class ClientApplicationTests(NeoUnitTestBase):
# no packet sent
self.checkNoPacketSent(conn)
self.checkNoPacketSent(app.master_conn)
self.assertEquals(app.local_var.txn, old_txn)
self.assertEquals(app.local_var.tid, tid)
txn_context = app._txn_container.get(old_txn)
self.assertTrue(txn_context['txn'] is old_txn)
self.assertEquals(txn_context['ttid'], tid)
def test_tpc_abort2(self):
# 2 nodes : 1 transaction in the first, 2 objects in the second
......@@ -552,7 +523,7 @@ class ClientApplicationTests(NeoUnitTestBase):
oid1, oid2 = self.makeOID(2), self.makeOID(4) # on partition 0
app, tid = self.getApp(), self.makeTID(1) # on partition 1
txn = self.makeTransactionObject()
app.local_var.txn, app.local_var.tid = txn, tid
txn_context = self._begin(app, txn, tid)
app.master_conn = Mock({'__hash__': 0})
app.num_partitions = 2
cell1 = Mock({ 'getNode': 'NODE1', '__hash__': 1 })
......@@ -560,16 +531,13 @@ class ClientApplicationTests(NeoUnitTestBase):
conn1, conn2 = Mock({ 'getNextId': 1, }), Mock({ 'getNextId': 2, })
app.cp = Mock({ 'getConnForNode': ReturnValues(conn1, conn2), })
# fake data
app.local_var.data_dict = {oid1: '', oid2: ''}
app.local_var.involved_nodes = set([cell1, cell2])
txn_context['involved_nodes'].update([cell1, cell2])
app.tpc_abort(txn)
# will check if there was just one call/packet :
self.checkNotifyPacket(conn1, Packets.AbortTransaction)
self.checkNotifyPacket(conn2, Packets.AbortTransaction)
self.assertEquals(app.local_var.tid, None)
self.assertEquals(app.local_var.txn, None)
self.assertEquals(app.local_var.data_dict, {})
self.assertEquals(app.local_var.txn_voted, False)
self.checkNotifyPacket(app.master_conn, Packets.AbortTransaction)
self.assertEqual(app._txn_container.get(txn), None)
def test_tpc_abort3(self):
""" check that abort is sent to all nodes involved in the transaction """
......@@ -617,19 +585,20 @@ class ClientApplicationTests(NeoUnitTestBase):
})
app.master_conn = Mock({'__hash__': 0})
txn = self.makeTransactionObject()
app.local_var.txn, app.local_var.tid = txn, tid
txn_context = self._begin(app, txn, tid)
class Dispatcher(object):
def pending(self, queue):
return not queue.empty()
def forget_queue(self, queue):
def forget_queue(self, queue, flush_queue=True):
pass
app.dispatcher = Dispatcher()
# conflict occurs on storage 2
app.store(oid1, tid, 'DATA', None, txn)
app.store(oid2, tid, 'DATA', None, txn)
app.local_var.queue.put((conn2, packet2))
app.local_var.queue.put((conn3, packet3))
queue = txn_context['queue']
queue.put((conn2, packet2))
queue.put((conn3, packet3))
# vote fails as the conflict is not resolved, nothing is sent to storage 3
self.assertRaises(ConflictError, app.tpc_vote, txn, failing_tryToResolveConflict)
# abort must be sent to storage 1 and 2
......@@ -640,59 +609,11 @@ class ClientApplicationTests(NeoUnitTestBase):
def test_tpc_finish1(self):
# transaction mismatch: raise
app = self.getApp()
tid = self.makeTID()
txn = self.makeTransactionObject()
app.local_var.txn = old_txn = object()
app.master_conn = Mock()
self.assertFalse(app.local_var.txn is txn)
conn = Mock()
self.assertRaises(StorageTransactionError, app.tpc_finish, txn, None)
# no packet sent
self.checkNoPacketSent(conn)
self.checkNoPacketSent(app.master_conn)
self.assertEquals(app.local_var.txn, old_txn)
def test_tpc_finish2(self):
# bad answer -> NEOStorageError
app = self.getApp()
tid = self.makeTID()
txn = self.makeTransactionObject()
app.local_var.txn, app.local_var.tid = txn, tid
# test callable passed to tpc_finish
self.f_called = False
self.f_called_with_tid = None
def hook(tid):
self.f_called = True
self.f_called_with_tid = tid
packet = Packets.AnswerTransactionFinished(INVALID_TID, INVALID_TID)
packet.setId(0)
app.master_conn = Mock({
'getNextId': 1,
'getAddress': ('127.0.0.1', 10000),
'fakeReceived': packet,
})
self.vote_params = None
tpc_vote = app.tpc_vote
def voteDetector(transaction, tryToResolveConflict):
self.vote_params = (transaction, tryToResolveConflict)
dummy_tryToResolveConflict = []
app.tpc_vote = voteDetector
app.local_var.txn_voted = True
self.assertRaises(NEOStorageError, app.tpc_finish, txn,
dummy_tryToResolveConflict, hook)
self.assertFalse(self.f_called)
self.assertEqual(self.vote_params, None)
self.checkAskFinishTransaction(app.master_conn)
self.checkDispatcherRegisterCalled(app, app.master_conn)
# Call again, but this time transaction is not voted yet
app.local_var.txn_voted = False
self.f_called = False
self.assertRaises(NEOStorageError, app.tpc_finish, txn,
dummy_tryToResolveConflict, hook)
self.assertFalse(self.f_called)
self.assertTrue(self.vote_params[0] is txn)
self.assertTrue(self.vote_params[1] is dummy_tryToResolveConflict)
app.tpc_vote = tpc_vote
def test_tpc_finish3(self):
# transaction is finished
......@@ -700,7 +621,7 @@ class ClientApplicationTests(NeoUnitTestBase):
tid = self.makeTID()
ttid = self.makeTID()
txn = self.makeTransactionObject()
app.local_var.txn, app.local_var.tid = txn, ttid
txn_context = self._begin(app, txn, tid)
self.f_called = False
self.f_called_with_tid = None
def hook(tid):
......@@ -713,17 +634,13 @@ class ClientApplicationTests(NeoUnitTestBase):
'getAddress': ('127.0.0.1', 10010),
'fakeReceived': packet,
})
app.local_var.txn_voted = True
app.local_var.txn_finished = True
txn_context['txn_voted'] = True
app.tpc_finish(txn, None, hook)
self.assertTrue(self.f_called)
self.assertEquals(self.f_called_with_tid, tid)
self.checkAskFinishTransaction(app.master_conn)
#self.checkDispatcherRegisterCalled(app, app.master_conn)
self.assertEquals(app.local_var.tid, None)
self.assertEquals(app.local_var.txn, None)
self.assertEquals(app.local_var.data_dict, {})
self.assertEquals(app.local_var.txn_voted, False)
self.assertEqual(app._txn_container.get(txn), None)
def test_undo1(self):
# invalid transaction
......@@ -731,21 +648,15 @@ class ClientApplicationTests(NeoUnitTestBase):
tid = self.makeTID()
snapshot_tid = self.getNextTID()
txn = self.makeTransactionObject()
marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data):
marker.append(1)
app.local_var.txn = old_txn = object()
pass
app.master_conn = Mock()
self.assertFalse(app.local_var.txn is txn)
conn = Mock()
self.assertRaises(StorageTransactionError, app.undo, snapshot_tid, tid,
txn, tryToResolveConflict)
# no packet sent
self.checkNoPacketSent(conn)
self.checkNoPacketSent(app.master_conn)
# nothing done
self.assertEquals(marker, [])
self.assertEquals(app.local_var.txn, old_txn)
def _getAppForUndoTests(self, oid0, tid0, tid1, tid2):
app = self.getApp()
......@@ -780,10 +691,10 @@ class ClientApplicationTests(NeoUnitTestBase):
return ({tid0: 'dummy', tid2: 'cdummy'}[serial], None, None)
app.load = load
store_marker = []
def _store(oid, serial, data, data_serial=None):
def _store(txn_context, oid, serial, data, data_serial=None,
unlock=False):
store_marker.append((oid, serial, data, data_serial))
app._store = _store
app.local_var.clear()
return app, conn, store_marker
def test_undoWithResolutionSuccess(self):
......@@ -805,7 +716,7 @@ class ClientApplicationTests(NeoUnitTestBase):
undo_serial = Packets.AnswerObjectUndoSerial({
oid0: (tid2, tid0, False)})
undo_serial.setId(2)
app.local_var.queue.put((conn, undo_serial))
app._getThreadQueue().put((conn, undo_serial))
marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''):
......@@ -846,7 +757,7 @@ class ClientApplicationTests(NeoUnitTestBase):
undo_serial.setId(2)
app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1,
tid2)
app.local_var.queue.put((conn, undo_serial))
app._getThreadQueue().put((conn, undo_serial))
marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''):
......@@ -872,7 +783,7 @@ class ClientApplicationTests(NeoUnitTestBase):
marker.append((oid, conflict_serial, serial, data, committedData))
raise ConflictError
# The undo
app.local_var.queue.put((conn, undo_serial))
app._getThreadQueue().put((conn, undo_serial))
self.assertRaises(UndoError, app.undo, snapshot_tid, tid1, txn,
tryToResolveConflict)
# Checking what happened
......@@ -905,7 +816,7 @@ class ClientApplicationTests(NeoUnitTestBase):
undo_serial.setId(2)
app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1,
tid2)
app.local_var.queue.put((conn, undo_serial))
app._getThreadQueue().put((conn, undo_serial))
def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''):
raise Exception, 'Test called conflict resolution, but there ' \
......@@ -929,15 +840,19 @@ class ClientApplicationTests(NeoUnitTestBase):
cell1, cell2 = Mock({}), Mock({})
tid1, tid2 = self.makeTID(1), self.makeTID(2)
oid1, oid2 = self.makeOID(1), self.makeOID(2)
# TIDs packets supplied by _waitMessage hook
# TIDs packets supplied by _ask hook
# TXN info packets
extension = dumps({})
p1 = Packets.AnswerTIDs([tid1])
p2 = Packets.AnswerTIDs([tid2])
p3 = Packets.AnswerTransactionInformation(tid1, '', '',
extension, False, (oid1, ))
p4 = Packets.AnswerTransactionInformation(tid2, '', '',
extension, False, (oid2, ))
p3.setId(0)
p4.setId(1)
p1.setId(0)
p2.setId(1)
p3.setId(2)
p4.setId(3)
conn = Mock({
'getNextId': 1,
'getUUID': ReturnValues(uuid1, uuid2),
......@@ -945,17 +860,36 @@ class ClientApplicationTests(NeoUnitTestBase):
'fakeReceived': ReturnValues(p3, p4),
'getAddress': ('127.0.0.1', 10010),
})
storage_1_conn = Mock()
storage_2_conn = Mock()
app.pt = Mock({
'getNodeList': (node1, node2, ),
'getCellListForTID': ReturnValues([cell1], [cell2]),
})
app.cp = self.getConnectionPool([(Mock(), conn)])
def waitResponses(self):
self.local_var.node_tids = {uuid1: (tid1, ), uuid2: (tid2, )}
app.waitResponses = new.instancemethod(waitResponses, app, Application)
app.cp = Mock({
'getConnForNode': ReturnValues(storage_1_conn, storage_2_conn),
'iterateForObject': [(Mock(), conn)]
})
def waitResponses(queue, handler_data):
app.setHandlerData(handler_data)
for p in (p1, p2):
app._handlePacket(Mock(), p, handler=app.storage_handler)
app.waitResponses = waitResponses
def txn_filter(info):
return info['id'] > '\x00' * 8
result = app.undoLog(0, 4, filter=txn_filter)
first = 0
last = 4
result = app.undoLog(first, last, filter=txn_filter)
pfirst, plast, ppartition = self.checkAskPacket(storage_1_conn,
Packets.AskTIDs, decode=True)
self.assertEqual(pfirst, first)
self.assertEqual(plast, last)
self.assertEqual(ppartition, INVALID_PARTITION)
pfirst, plast, ppartition = self.checkAskPacket(storage_2_conn,
Packets.AskTIDs, decode=True)
self.assertEqual(pfirst, first)
self.assertEqual(plast, last)
self.assertEqual(ppartition, INVALID_PARTITION)
self.assertEquals(result[0]['id'], tid1)
self.assertEquals(result[1]['id'], tid2)
......@@ -968,9 +902,9 @@ class ClientApplicationTests(NeoUnitTestBase):
p2 = Packets.AnswerObjectHistory(oid, object_history)
extension = dumps({'k': 'v'})
# transaction history
p3 = Packets.AnswerTransactionInformation(tid1, 'u', 'd',
p3 = Packets.AnswerTransactionInformation(tid2, 'u', 'd',
extension, False, (oid, ))
p4 = Packets.AnswerTransactionInformation(tid2, 'u', 'd',
p4 = Packets.AnswerTransactionInformation(tid1, 'u', 'd',
extension, False, (oid, ))
p2.setId(0)
p3.setId(1)
......@@ -979,7 +913,7 @@ class ClientApplicationTests(NeoUnitTestBase):
conn = Mock({
'getNextId': 1,
'fakeGetApp': app,
'fakeReceived': ReturnValues(p2, p3, p4),
'fakeReceived': ReturnValues(p3, p4),
'getAddress': ('127.0.0.1', 10010),
})
object_cells = [ Mock({}), ]
......@@ -988,12 +922,18 @@ class ClientApplicationTests(NeoUnitTestBase):
'getCellListForOID': object_cells,
'getCellListForTID': ReturnValues(history_cells, history_cells),
})
app.cp = self.getConnectionPool([(Mock(), conn)])
app.cp = Mock({
'iterateForObject': [(Mock(), conn)],
})
def waitResponses(queue, handler_data):
app.setHandlerData(handler_data)
app._handlePacket(Mock(), p2, handler=app.storage_handler)
app.waitResponses = waitResponses
# start test here
result = app.history(oid)
self.assertEquals(len(result), 2)
self.assertEquals(result[0]['tid'], tid1)
self.assertEquals(result[1]['tid'], tid2)
self.assertEquals(result[0]['tid'], tid2)
self.assertEquals(result[1]['tid'], tid1)
self.assertEquals(result[0]['size'], 42)
self.assertEquals(result[1]['size'], 42)
......@@ -1010,43 +950,41 @@ class ClientApplicationTests(NeoUnitTestBase):
# TODO: test more connection failure cases
# Seventh packet : askNodeInformation succeeded
all_passed = []
def _waitMessage8(conn, msg_id, handler=None):
def _ask8(_):
all_passed.append(1)
# Sixth packet : askPartitionTable succeeded
def _waitMessage7(conn, msg_id, handler=None):
def _ask7(_):
app.pt = Mock({'operational': True})
app._waitMessage = _waitMessage8
# fifth packet : request node identification succeeded
def _waitMessage6(conn, msg_id, handler=None):
def _ask6(conn):
conn.setUUID('D' * 16)
app.uuid = 'C' * 16
app._waitMessage = _waitMessage7
# fourth iteration : connection to primary master succeeded
def _waitMessage5(conn, msg_id, handler=None):
def _ask5(_):
app.trying_master_node = app.primary_master_node = Mock({
'getAddress': ('192.168.1.1', 10000),
'__str__': 'Fake master node',
})
app._waitMessage = _waitMessage6
# third iteration : node not ready
def _waitMessage4(conn, msg_id, handler=None):
def _ask4(_):
app.trying_master_node = None
app._waitMessage = _waitMessage5
# second iteration : master node changed
def _waitMessage3(conn, msg_id, handler=None):
def _ask3(_):
app.primary_master_node = Mock({
'getAddress': ('192.168.1.1', 10000),
'__str__': 'Fake master node',
})
app._waitMessage = _waitMessage4
# first iteration : connection failed
def _waitMessage2(conn, msg_id, handler=None):
def _ask2(_):
app.trying_master_node = None
app._waitMessage = _waitMessage3
# do nothing for the first call
def _waitMessage1(conn, msg_id, handler=None):
app._waitMessage = _waitMessage2
app._waitMessage = _waitMessage1
def _ask1(_):
pass
ask_func_list = [_ask1, _ask2, _ask3, _ask4, _ask5, _ask6, _ask7,
_ask8]
def _ask_base(conn, _, handler=None):
ask_func_list.pop(0)(conn)
app._ask = _ask_base
# faked environnement
app.connector_handler = DoNothingConnector
app.em = Mock({'getConnectionList': []})
......@@ -1056,23 +994,6 @@ class ClientApplicationTests(NeoUnitTestBase):
self.assertTrue(app.master_conn is not None)
self.assertTrue(app.pt.operational())
def test_askStorage(self):
""" _askStorage is private but test it anyway """
app = self.getApp('')
conn = Mock()
self.test_ok = False
def _waitMessage_hook(app, conn, msg_id, handler=None):
self.test_ok = True
packet = Packets.AskBeginTransaction()
packet.setId(0)
app._waitMessage = _waitMessage_hook
app._askStorage(conn, packet)
# check packet sent, connection unlocked and dispatcher updated
self.checkAskNewTid(conn)
self.checkDispatcherRegisterCalled(app, conn)
# and _waitMessage called
self.assertTrue(self.test_ok)
def test_askPrimary(self):
""" _askPrimary is private but test it anyway """
app = self.getApp('')
......@@ -1080,21 +1001,22 @@ class ClientApplicationTests(NeoUnitTestBase):
app.master_conn = conn
app.primary_handler = Mock()
self.test_ok = False
def _waitMessage_hook(app, conn, msg_id, handler=None):
def _ask_hook(app, conn, packet, handler=None):
conn.ask(packet)
self.assertTrue(handler is app.primary_handler)
self.test_ok = True
_waitMessage_old = Application._waitMessage
Application._waitMessage = _waitMessage_hook
_ask_old = Application._ask
Application._ask = _ask_hook
packet = Packets.AskBeginTransaction()
packet.setId(0)
try:
app._askPrimary(packet)
finally:
Application._waitMessage = _waitMessage_old
Application._ask = _ask_old
# check packet sent, connection locked during process and dispatcher updated
self.checkAskNewTid(conn)
self.checkDispatcherRegisterCalled(app, conn)
# and _waitMessage called
# and _ask called
self.assertTrue(self.test_ok)
# check NEOStorageError is raised when the primary connection is lost
app.master_conn = None
......@@ -1105,15 +1027,16 @@ class ClientApplicationTests(NeoUnitTestBase):
""" Thread context properties must not be visible accross instances
while remaining in the same thread """
app1 = self.getApp()
app1_local = app1.local_var
app1_local = app1._thread_container.get()
app2 = self.getApp()
app2_local = app2.local_var
app2_local = app2._thread_container.get()
property_id = 'thread_context_test'
self.assertFalse(hasattr(app1_local, property_id))
self.assertFalse(hasattr(app2_local, property_id))
setattr(app1_local, property_id, 'value')
self.assertTrue(hasattr(app1_local, property_id))
self.assertFalse(hasattr(app2_local, property_id))
value = 'value'
self.assertRaises(KeyError, app1_local.__getitem__, property_id)
self.assertRaises(KeyError, app2_local.__getitem__, property_id)
app1_local[property_id] = value
self.assertEqual(app1_local[property_id], value)
self.assertRaises(KeyError, app2_local.__getitem__, property_id)
def test_pack(self):
app = self.getApp()
......
......@@ -235,7 +235,7 @@ class MasterAnswersHandlerTests(MasterHandlerTests):
tid = self.getNextTID()
conn = self.getConnection()
self.handler.answerBeginTransaction(conn, tid)
calls = self.app.mockGetNamedCalls('setTID')
calls = self.app.mockGetNamedCalls('setHandlerData')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid)
......@@ -247,18 +247,12 @@ class MasterAnswersHandlerTests(MasterHandlerTests):
def test_answerTransactionFinished(self):
conn = self.getConnection()
ttid1 = self.getNextTID()
ttid2 = self.getNextTID(ttid1)
tid2 = self.getNextTID(ttid2)
# wrong TID
self.app = Mock({'getTID': ttid1})
self.assertRaises(NEOStorageError,
self.handler.answerTransactionFinished,
conn, ttid2, tid2)
# matching TID
app = Mock({'getTID': ttid2})
handler = PrimaryAnswersHandler(app=app)
handler.answerTransactionFinished(conn, ttid2, tid2)
ttid2 = self.getNextTID()
tid2 = self.getNextTID()
self.handler.answerTransactionFinished(conn, ttid2, tid2)
calls = self.app.mockGetNamedCalls('setHandlerData')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid2)
def test_answerPack(self):
self.assertRaises(NEOStorageError, self.handler.answerPack, None, False)
......
......@@ -25,6 +25,7 @@ from neo.client.exception import NEOStorageError, NEOStorageNotFoundError
from neo.client.exception import NEOStorageDoesNotExistError
from ZODB.POSException import ConflictError
from neo.lib.exception import NodeNotReady
from ZODB.TimeStamp import TimeStamp
MARKER = []
......@@ -69,45 +70,57 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
self.app = Mock()
self.app.local_var = Mock()
self.handler = StorageAnswersHandler(self.app)
def getConnection(self):
return self.getFakeConnection()
def _checkHandlerData(self, ref):
calls = self.app.mockGetNamedCalls('setHandlerData')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(ref)
def test_answerObject(self):
conn = self.getConnection()
oid = self.getOID(0)
tid1 = self.getNextTID()
tid2 = self.getNextTID(tid1)
the_object = (oid, tid1, tid2, 0, '', 'DATA', None)
self.app.local_var.asked_object = None
self.handler.answerObject(conn, *the_object)
self.assertEqual(self.app.local_var.asked_object, the_object[:-1])
self._checkHandlerData(the_object[:-1])
# Check handler raises on non-None data_serial.
the_object = (oid, tid1, tid2, 0, '', 'DATA', self.getNextTID())
self.app.local_var.asked_object = None
self.assertRaises(NEOStorageError, self.handler.answerObject, conn,
*the_object)
def _getAnswerStoreObjectHandler(self, object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict):
app = Mock({
'getHandlerData': {
'object_stored_counter_dict': object_stored_counter_dict,
'conflict_serial_dict': conflict_serial_dict,
'resolved_conflict_serial_dict': resolved_conflict_serial_dict,
}
})
return StorageAnswersHandler(app)
def test_answerStoreObject_1(self):
conn = self.getConnection()
oid = self.getOID(0)
tid = self.getNextTID()
# conflict
local_var = self.app.local_var
local_var.object_stored_counter_dict = {oid: {}}
local_var.conflict_serial_dict = {}
local_var.resolved_conflict_serial_dict = {}
self.handler.answerStoreObject(conn, 1, oid, tid)
self.assertEqual(local_var.conflict_serial_dict[oid], set([tid, ]))
self.assertEqual(local_var.object_stored_counter_dict[oid], {})
self.assertFalse(oid in local_var.resolved_conflict_serial_dict)
object_stored_counter_dict = {oid: {}}
conflict_serial_dict = {}
resolved_conflict_serial_dict = {}
self._getAnswerStoreObjectHandler(object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict,
).answerStoreObject(conn, 1, oid, tid)
self.assertEqual(conflict_serial_dict[oid], set([tid, ]))
self.assertEqual(object_stored_counter_dict[oid], {})
self.assertFalse(oid in resolved_conflict_serial_dict)
# object was already accepted by another storage, raise
local_var.object_stored_counter_dict = {oid: {tid: 1}}
local_var.conflict_serial_dict = {}
local_var.resolved_conflict_serial_dict = {}
self.assertRaises(NEOStorageError, self.handler.answerStoreObject,
handler = self._getAnswerStoreObjectHandler({oid: {tid: 1}}, {}, {})
self.assertRaises(NEOStorageError, handler.answerStoreObject,
conn, 1, oid, tid)
def test_answerStoreObject_2(self):
......@@ -116,25 +129,23 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
tid = self.getNextTID()
tid_2 = self.getNextTID()
# resolution-pending conflict
local_var = self.app.local_var
local_var.object_stored_counter_dict = {oid: {}}
local_var.conflict_serial_dict = {oid: set([tid, ])}
local_var.resolved_conflict_serial_dict = {}
self.handler.answerStoreObject(conn, 1, oid, tid)
self.assertEqual(local_var.conflict_serial_dict[oid], set([tid, ]))
self.assertFalse(oid in local_var.resolved_conflict_serial_dict)
self.assertEqual(local_var.object_stored_counter_dict[oid], {})
object_stored_counter_dict = {oid: {}}
conflict_serial_dict = {oid: set([tid, ])}
resolved_conflict_serial_dict = {}
self._getAnswerStoreObjectHandler(object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict,
).answerStoreObject(conn, 1, oid, tid)
self.assertEqual(conflict_serial_dict[oid], set([tid, ]))
self.assertFalse(oid in resolved_conflict_serial_dict)
self.assertEqual(object_stored_counter_dict[oid], {})
# object was already accepted by another storage, raise
local_var.object_stored_counter_dict = {oid: {tid: 1}}
local_var.conflict_serial_dict = {oid: set([tid, ])}
local_var.resolved_conflict_serial_dict = {}
self.assertRaises(NEOStorageError, self.handler.answerStoreObject,
handler = self._getAnswerStoreObjectHandler({oid: {tid: 1}},
{oid: set([tid, ])}, {})
self.assertRaises(NEOStorageError, handler.answerStoreObject,
conn, 1, oid, tid)
# detected conflict is different, don't raise
local_var.object_stored_counter_dict = {oid: {}}
local_var.conflict_serial_dict = {oid: set([tid, ])}
local_var.resolved_conflict_serial_dict = {}
self.handler.answerStoreObject(conn, 1, oid, tid_2)
self._getAnswerStoreObjectHandler({oid: {}}, {oid: set([tid, ])}, {},
).answerStoreObject(conn, 1, oid, tid_2)
def test_answerStoreObject_3(self):
conn = self.getConnection()
......@@ -145,49 +156,34 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
# This case happens if a storage is answering a store action for which
# any other storage already answered (with same conflict) and any other
# storage accepted the resolved object.
local_var = self.app.local_var
local_var.object_stored_counter_dict = {oid: {tid_2: 1}}
local_var.conflict_serial_dict = {}
local_var.resolved_conflict_serial_dict = {oid: set([tid, ])}
self.handler.answerStoreObject(conn, 1, oid, tid)
self.assertFalse(oid in local_var.conflict_serial_dict)
self.assertEqual(local_var.resolved_conflict_serial_dict[oid],
object_stored_counter_dict = {oid: {tid_2: 1}}
conflict_serial_dict = {}
resolved_conflict_serial_dict = {oid: set([tid, ])}
self._getAnswerStoreObjectHandler(object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict,
).answerStoreObject(conn, 1, oid, tid)
self.assertFalse(oid in conflict_serial_dict)
self.assertEqual(resolved_conflict_serial_dict[oid],
set([tid, ]))
self.assertEqual(local_var.object_stored_counter_dict[oid], {tid_2: 1})
self.assertEqual(object_stored_counter_dict[oid], {tid_2: 1})
# detected conflict is different, don't raise
local_var.object_stored_counter_dict = {oid: {tid: 1}}
local_var.conflict_serial_dict = {}
local_var.resolved_conflict_serial_dict = {oid: set([tid, ])}
self.handler.answerStoreObject(conn, 1, oid, tid_2)
self._getAnswerStoreObjectHandler({oid: {tid: 1}}, {},
{oid: set([tid, ])}).answerStoreObject(conn, 1, oid, tid_2)
def test_answerStoreObject_4(self):
conn = self.getConnection()
oid = self.getOID(0)
tid = self.getNextTID()
# no conflict
local_var = self.app.local_var
local_var.object_stored_counter_dict = {oid: {}}
local_var.conflict_serial_dict = {}
local_var.resolved_conflict_serial_dict = {}
self.handler.answerStoreObject(conn, 0, oid, tid)
self.assertFalse(oid in local_var.conflict_serial_dict)
self.assertFalse(oid in local_var.resolved_conflict_serial_dict)
self.assertEqual(local_var.object_stored_counter_dict[oid], {tid: 1})
def test_answerStoreTransaction(self):
conn = self.getConnection()
tid1 = self.getNextTID()
tid2 = self.getNextTID(tid1)
# wrong tid
app = Mock({'getTID': tid1})
handler = StorageAnswersHandler(app=app)
self.assertRaises(NEOStorageError,
handler.answerStoreTransaction, conn,
tid2)
# good tid
app = Mock({'getTID': tid2})
handler = StorageAnswersHandler(app=app)
handler.answerStoreTransaction(conn, tid2)
object_stored_counter_dict = {oid: {}}
conflict_serial_dict = {}
resolved_conflict_serial_dict = {}
self._getAnswerStoreObjectHandler(object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict,
).answerStoreObject(conn, 0, oid, tid)
self.assertFalse(oid in conflict_serial_dict)
self.assertFalse(oid in resolved_conflict_serial_dict)
self.assertEqual(object_stored_counter_dict[oid], {tid: 1})
def test_answerTransactionInformation(self):
conn = self.getConnection()
......@@ -195,24 +191,30 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
user = 'USER'
desc = 'DESC'
ext = 'EXT'
packed = False
oid_list = [self.getOID(0), self.getOID(1)]
self.app.local_var.txn_info = None
self.handler.answerTransactionInformation(conn, tid, user, desc, ext,
False, oid_list)
txn_info = self.app.local_var.txn_info
self.assertTrue(isinstance(txn_info, dict))
self.assertEqual(txn_info['user_name'], user)
self.assertEqual(txn_info['description'], desc)
self.assertEqual(txn_info['id'], tid)
self.assertEqual(txn_info['oids'], oid_list)
packed, oid_list)
self._checkHandlerData(({
'time': TimeStamp(tid).timeTime(),
'user_name': user,
'description': desc,
'id': tid,
'oids': oid_list,
'packed': packed,
}, ext))
def test_answerObjectHistory(self):
conn = self.getConnection()
oid = self.getOID(0)
history_list = []
self.app.local_var.history = None
self.handler.answerObjectHistory(conn, oid, history_list)
self.assertEqual(self.app.local_var.history, (oid, history_list))
history_list = [self.getNextTID(), self.getNextTID()]
history_set = set()
app = Mock({
'getHandlerData': history_set,
})
handler = StorageAnswersHandler(app)
handler.answerObjectHistory(conn, oid, history_list)
self.assertEqual(history_set, set(history_list))
def test_oidNotFound(self):
conn = self.getConnection()
......@@ -235,10 +237,14 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
tid2 = self.getNextTID(tid1)
tid_list = [tid1, tid2]
conn = self.getFakeConnection(uuid=uuid)
self.app.local_var.node_tids = {}
self.handler.answerTIDs(conn, tid_list)
self.assertTrue(uuid in self.app.local_var.node_tids)
self.assertEqual(self.app.local_var.node_tids[uuid], tid_list)
tid_set = set()
app = Mock({
'getHandlerData': tid_set,
})
handler = StorageAnswersHandler(app)
handler.answerTIDs(conn, tid_list)
self.assertEqual(tid_set, set(tid_list))
def test_answerObjectUndoSerial(self):
uuid = self.getNewUUID()
......@@ -249,12 +255,14 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
tid1 = self.getNextTID()
tid2 = self.getNextTID()
tid3 = self.getNextTID()
self.app.local_var.undo_object_tid_dict = undo_dict = {
oid1: [tid0, tid1],
}
self.handler.answerObjectUndoSerial(conn, {
oid2: [tid2, tid3],
undo_dict = {}
app = Mock({
'getHandlerData': undo_dict,
})
handler = StorageAnswersHandler(app)
handler.answerObjectUndoSerial(conn, {oid1: [tid0, tid1]})
self.assertEqual(undo_dict, {oid1: [tid0, tid1]})
handler.answerObjectUndoSerial(conn, {oid2: [tid2, tid3]})
self.assertEqual(undo_dict, {
oid1: [tid0, tid1],
oid2: [tid2, tid3],
......
......@@ -22,9 +22,7 @@ from ZODB.tests.StorageTestBase import StorageTestBase
from neo.tests.zodb import ZODBTestCase
class BasicTests(ZODBTestCase, StorageTestBase, BasicStorage):
def check_tid_ordering_w_commit(self):
self.fail("Test disabled")
pass
if __name__ == "__main__":
suite = unittest.makeSuite(BasicTests, 'check')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment