Commit e97427e9 authored by Vincent Pelletier's avatar Vincent Pelletier

Transactions must not be bound to the thread which created them.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2643 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent c5a1efea
...@@ -15,10 +15,9 @@ ...@@ -15,10 +15,9 @@
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from thread import get_ident
from cPickle import dumps, loads from cPickle import dumps, loads
from zlib import compress as real_compress, decompress from zlib import compress as real_compress, decompress
from neo.lib.locking import Queue, Empty from neo.lib.locking import Empty
from random import shuffle from random import shuffle
import time import time
import os import os
...@@ -49,6 +48,7 @@ from neo.lib.util import u64, parseMasterList ...@@ -49,6 +48,7 @@ from neo.lib.util import u64, parseMasterList
from neo.lib.profiling import profiler_decorator, PROFILING_ENABLED from neo.lib.profiling import profiler_decorator, PROFILING_ENABLED
from neo.lib.live_debug import register as registerLiveDebugger from neo.lib.live_debug import register as registerLiveDebugger
from neo.client.mq_index import RevisionIndex from neo.client.mq_index import RevisionIndex
from neo.client.container import ThreadContainer, TransactionContainer
if PROFILING_ENABLED: if PROFILING_ENABLED:
# Those functions require a "real" python function wrapper before they can # Those functions require a "real" python function wrapper before they can
...@@ -65,60 +65,6 @@ else: ...@@ -65,60 +65,6 @@ else:
compress = real_compress compress = real_compress
makeChecksum = real_makeChecksum makeChecksum = real_makeChecksum
class ThreadContext(object):
def __init__(self):
super(ThreadContext, self).__setattr__('_threads_dict', {})
def __getThreadData(self):
thread_id = get_ident()
try:
result = self._threads_dict[thread_id]
except KeyError:
self.clear(thread_id)
result = self._threads_dict[thread_id]
return result
def __getattr__(self, name):
thread_data = self.__getThreadData()
try:
return thread_data[name]
except KeyError:
raise AttributeError, name
def __setattr__(self, name, value):
thread_data = self.__getThreadData()
thread_data[name] = value
def clear(self, thread_id=None):
if thread_id is None:
thread_id = get_ident()
thread_dict = self._threads_dict.get(thread_id)
if thread_dict is None:
queue = Queue(0)
else:
queue = thread_dict['queue']
self._threads_dict[thread_id] = {
'tid': None,
'txn': None,
'data_dict': {},
'data_list': [],
'object_base_serial_dict': {},
'object_serial_dict': {},
'object_stored_counter_dict': {},
'conflict_serial_dict': {},
'resolved_conflict_serial_dict': {},
'txn_voted': False,
'queue': queue,
'txn_info': 0,
'history': None,
'node_tids': {},
'asked_object': 0,
'undo_object_tid_dict': {},
'involved_nodes': set(),
'last_transaction': None,
}
class Application(object): class Application(object):
"""The client node application.""" """The client node application."""
...@@ -158,7 +104,8 @@ class Application(object): ...@@ -158,7 +104,8 @@ class Application(object):
self.primary_bootstrap_handler = master.PrimaryBootstrapHandler(self) self.primary_bootstrap_handler = master.PrimaryBootstrapHandler(self)
self.notifications_handler = master.PrimaryNotificationsHandler( self) self.notifications_handler = master.PrimaryNotificationsHandler( self)
# Internal attribute distinct between thread # Internal attribute distinct between thread
self.local_var = ThreadContext() self._thread_container = ThreadContainer()
self._txn_container = TransactionContainer()
# Lock definition : # Lock definition :
# _load_lock is used to make loading and storing atomic # _load_lock is used to make loading and storing atomic
lock = Lock() lock = Lock()
...@@ -185,6 +132,15 @@ class Application(object): ...@@ -185,6 +132,15 @@ class Application(object):
self.compress = compress self.compress = compress
registerLiveDebugger(on_log=self.log) registerLiveDebugger(on_log=self.log)
def getHandlerData(self):
return self._thread_container.get()['answer']
def setHandlerData(self, data):
self._thread_container.get()['answer'] = data
def _getThreadQueue(self):
return self._thread_container.get()['queue']
def log(self): def log(self):
self.em.log() self.em.log()
self.nm.log() self.nm.log()
...@@ -222,7 +178,7 @@ class Application(object): ...@@ -222,7 +178,7 @@ class Application(object):
conn.unlock() conn.unlock()
@profiler_decorator @profiler_decorator
def _waitAnyMessage(self, block=True): def _waitAnyMessage(self, queue, block=True):
""" """
Handle all pending packets. Handle all pending packets.
block block
...@@ -230,7 +186,6 @@ class Application(object): ...@@ -230,7 +186,6 @@ class Application(object):
received. received.
""" """
pending = self.dispatcher.pending pending = self.dispatcher.pending
queue = self.local_var.queue
get = queue.get get = queue.get
_handlePacket = self._handlePacket _handlePacket = self._handlePacket
while pending(queue): while pending(queue):
...@@ -247,39 +202,53 @@ class Application(object): ...@@ -247,39 +202,53 @@ class Application(object):
except ConnectionClosed: except ConnectionClosed:
pass pass
def _waitAnyTransactionMessage(self, txn_context, block=True):
"""
Just like _waitAnyMessage, but for per-transaction exchanges, rather
than per-thread.
"""
queue = txn_context['queue']
self.setHandlerData(txn_context)
try:
self._waitAnyMessage(queue, block=block)
finally:
# Don't leave access to thread context, even if a raise happens.
self.setHandlerData(None)
@profiler_decorator @profiler_decorator
def _waitMessage(self, target_conn, msg_id, handler=None): def _ask(self, conn, packet, handler=None):
"""Wait for a message returned by the dispatcher in queues.""" self.setHandlerData(None)
get = self.local_var.queue.get queue = self._getThreadQueue()
msg_id = conn.ask(packet, queue=queue)
get = queue.get
_handlePacket = self._handlePacket _handlePacket = self._handlePacket
while True: while True:
conn, packet = get(True) qconn, qpacket = get(True)
is_forgotten = isinstance(packet, ForgottenPacket) is_forgotten = isinstance(qpacket, ForgottenPacket)
if target_conn is conn: if conn is qconn:
# check fake packet # check fake packet
if packet is None: if qpacket is None:
raise ConnectionClosed raise ConnectionClosed
if msg_id == packet.getId(): if msg_id == qpacket.getId():
if is_forgotten: if is_forgotten:
raise ValueError, 'ForgottenPacket for an ' \ raise ValueError, 'ForgottenPacket for an ' \
'explicitely expected packet.' 'explicitely expected packet.'
_handlePacket(conn, packet, handler=handler) _handlePacket(qconn, qpacket, handler=handler)
break break
if not is_forgotten and packet is not None: if not is_forgotten and qpacket is not None:
_handlePacket(conn, packet) _handlePacket(qconn, qpacket)
return self.getHandlerData()
@profiler_decorator @profiler_decorator
def _askStorage(self, conn, packet): def _askStorage(self, conn, packet):
""" Send a request to a storage node and process its answer """ """ Send a request to a storage node and process its answer """
msg_id = conn.ask(packet, queue=self.local_var.queue) return self._ask(conn, packet, handler=self.storage_handler)
self._waitMessage(conn, msg_id, self.storage_handler)
@profiler_decorator @profiler_decorator
def _askPrimary(self, packet): def _askPrimary(self, packet):
""" Send a request to the primary master and process its answer """ """ Send a request to the primary master and process its answer """
conn = self._getMasterConnection() return self._ask(self._getMasterConnection(), packet,
msg_id = conn.ask(packet, queue=self.local_var.queue) handler=self.primary_handler)
self._waitMessage(conn, msg_id, self.primary_handler)
@profiler_decorator @profiler_decorator
def _getMasterConnection(self): def _getMasterConnection(self):
...@@ -311,7 +280,6 @@ class Application(object): ...@@ -311,7 +280,6 @@ class Application(object):
neo.lib.logging.debug('connecting to primary master...') neo.lib.logging.debug('connecting to primary master...')
ready = False ready = False
nm = self.nm nm = self.nm
queue = self.local_var.queue
packet = Packets.AskPrimary() packet = Packets.AskPrimary()
while not ready: while not ready:
# Get network connection to primary master # Get network connection to primary master
...@@ -346,9 +314,8 @@ class Application(object): ...@@ -346,9 +314,8 @@ class Application(object):
self.trying_master_node) self.trying_master_node)
continue continue
try: try:
msg_id = conn.ask(packet, queue=queue) self._ask(conn, packet,
self._waitMessage(conn, msg_id, handler=self.primary_bootstrap_handler)
handler=self.primary_bootstrap_handler)
except ConnectionClosed: except ConnectionClosed:
continue continue
# If we reached the primary master node, mark as connected # If we reached the primary master node, mark as connected
...@@ -373,24 +340,19 @@ class Application(object): ...@@ -373,24 +340,19 @@ class Application(object):
looked-up again. looked-up again.
""" """
neo.lib.logging.info('Initializing from master') neo.lib.logging.info('Initializing from master')
queue = self.local_var.queue ask = self._ask
handler = self.primary_bootstrap_handler
# Identify to primary master and request initial data # Identify to primary master and request initial data
p = Packets.RequestIdentification(NodeTypes.CLIENT, self.uuid, None, p = Packets.RequestIdentification(NodeTypes.CLIENT, self.uuid, None,
self.name) self.name)
while conn.getUUID() is None: while conn.getUUID() is None:
self._waitMessage(conn, conn.ask(p, queue=queue), ask(conn, p, handler=handler)
handler=self.primary_bootstrap_handler)
if conn.getUUID() is None: if conn.getUUID() is None:
# Node identification was refused by master, it is considered # Node identification was refused by master, it is considered
# as the primary as long as we are connected to it. # as the primary as long as we are connected to it.
time.sleep(1) time.sleep(1)
if self.uuid is not None: ask(conn, Packets.AskNodeInformation(), handler=handler)
msg_id = conn.ask(Packets.AskNodeInformation(), queue=queue) ask(conn, Packets.AskPartitionTable(), handler=handler)
self._waitMessage(conn, msg_id,
handler=self.primary_bootstrap_handler)
msg_id = conn.ask(Packets.AskPartitionTable(), queue=queue)
self._waitMessage(conn, msg_id,
handler=self.primary_bootstrap_handler)
return self.pt.operational() return self.pt.operational()
def registerDB(self, db, limit): def registerDB(self, db, limit):
...@@ -490,32 +452,27 @@ class Application(object): ...@@ -490,32 +452,27 @@ class Application(object):
@profiler_decorator @profiler_decorator
def _loadFromStorage(self, oid, at_tid, before_tid): def _loadFromStorage(self, oid, at_tid, before_tid):
self.local_var.asked_object = 0 data = None
packet = Packets.AskObject(oid, at_tid, before_tid) packet = Packets.AskObject(oid, at_tid, before_tid)
for node, conn in self.cp.iterateForObject(oid, readable=True): for node, conn in self.cp.iterateForObject(oid, readable=True):
try: try:
self._askStorage(conn, packet) noid, tid, next_tid, compression, checksum, data \
= self._askStorage(conn, packet)
except ConnectionClosed: except ConnectionClosed:
continue continue
# Check data if checksum != makeChecksum(data):
noid, tid, next_tid, compression, checksum, data \ # Warning: see TODO file.
= self.local_var.asked_object
if noid != oid:
# Oops, try with next node
neo.lib.logging.error('got wrong oid %s instead of %s from %s',
noid, dump(oid), conn)
self.local_var.asked_object = -1
continue
elif checksum != makeChecksum(data):
# Check checksum. # Check checksum.
neo.lib.logging.error('wrong checksum from %s for oid %s', neo.lib.logging.error('wrong checksum from %s for oid %s',
conn, dump(oid)) conn, dump(oid))
self.local_var.asked_object = -1 data = None
continue continue
break break
if self.local_var.asked_object == -1: if data is None:
raise NEOStorageError('inconsistent data') # We didn't got any object from all storage node because of
# connection error
raise NEOStorageError('connection failure')
# Uncompress data # Uncompress data
if compression: if compression:
...@@ -547,30 +504,34 @@ class Application(object): ...@@ -547,30 +504,34 @@ class Application(object):
@profiler_decorator @profiler_decorator
def tpc_begin(self, transaction, tid=None, status=' '): def tpc_begin(self, transaction, tid=None, status=' '):
"""Begin a new transaction.""" """Begin a new transaction."""
txn_container = self._txn_container
# First get a transaction, only one is allowed at a time # First get a transaction, only one is allowed at a time
if self.local_var.txn is transaction: if txn_container.get(transaction) is not None:
# We already begin the same transaction # We already begin the same transaction
raise StorageTransactionError('Duplicate tpc_begin calls') raise StorageTransactionError('Duplicate tpc_begin calls')
if self.local_var.txn is not None: txn_context = txn_container.new(transaction)
raise NeoException, 'local_var is not clean in tpc_begin'
# use the given TID or request a new one to the master # use the given TID or request a new one to the master
self._askPrimary(Packets.AskBeginTransaction(tid)) answer_ttid = self._askPrimary(Packets.AskBeginTransaction(tid))
if self.local_var.tid is None: if answer_ttid is None:
raise NEOStorageError('tpc_begin failed') raise NEOStorageError('tpc_begin failed')
assert tid in (None, self.local_var.tid), (tid, self.local_var.tid) assert tid in (None, answer_ttid), (tid, answer_ttid)
self.local_var.txn = transaction txn_context['txn'] = transaction
txn_context['ttid'] = answer_ttid
@profiler_decorator @profiler_decorator
def store(self, oid, serial, data, version, transaction): def store(self, oid, serial, data, version, transaction):
"""Store object.""" """Store object."""
if transaction is not self.local_var.txn: txn_context = self._txn_container.get(transaction)
if txn_context is None:
raise StorageTransactionError(self, transaction) raise StorageTransactionError(self, transaction)
neo.lib.logging.debug( neo.lib.logging.debug(
'storing oid %s serial %s', dump(oid), dump(serial)) 'storing oid %s serial %s', dump(oid), dump(serial))
self._store(oid, serial, data) self._store(txn_context, oid, serial, data)
return None return None
def _store(self, oid, serial, data, data_serial=None, unlock=False): def _store(self, txn_context, oid, serial, data, data_serial=None,
unlock=False):
ttid = txn_context['ttid']
if data is None: if data is None:
# This is some undo: either a no-data object (undoing object # This is some undo: either a no-data object (undoing object
# creation) or a back-pointer to an earlier revision (going back to # creation) or a back-pointer to an earlier revision (going back to
...@@ -589,33 +550,33 @@ class Application(object): ...@@ -589,33 +550,33 @@ class Application(object):
else: else:
compression = 1 compression = 1
checksum = makeChecksum(compressed_data) checksum = makeChecksum(compressed_data)
on_timeout = OnTimeout(self.onStoreTimeout, self.local_var.tid, oid) on_timeout = OnTimeout(self.onStoreTimeout, ttid, oid)
# Store object in tmp cache # Store object in tmp cache
local_var = self.local_var data_dict = txn_context['data_dict']
data_dict = local_var.data_dict
if oid not in data_dict: if oid not in data_dict:
local_var.data_list.append(oid) txn_context['data_list'].append(oid)
data_dict[oid] = data data_dict[oid] = data
# Store data on each node # Store data on each node
self.local_var.object_stored_counter_dict[oid] = {} txn_context['object_stored_counter_dict'][oid] = {}
object_base_serial_dict = local_var.object_base_serial_dict object_base_serial_dict = txn_context['object_base_serial_dict']
if oid not in object_base_serial_dict: if oid not in object_base_serial_dict:
object_base_serial_dict[oid] = serial object_base_serial_dict[oid] = serial
self.local_var.object_serial_dict[oid] = serial txn_context['object_serial_dict'][oid] = serial
queue = self.local_var.queue queue = txn_context['queue']
add_involved_nodes = self.local_var.involved_nodes.add involved_nodes = txn_context['involved_nodes']
add_involved_nodes = involved_nodes.add
packet = Packets.AskStoreObject(oid, serial, compression, packet = Packets.AskStoreObject(oid, serial, compression,
checksum, compressed_data, data_serial, self.local_var.tid, unlock) checksum, compressed_data, data_serial, ttid, unlock)
for node, conn in self.cp.iterateForObject(oid, writable=True): for node, conn in self.cp.iterateForObject(oid, writable=True):
try: try:
conn.ask(packet, on_timeout=on_timeout, queue=queue) conn.ask(packet, on_timeout=on_timeout, queue=queue)
add_involved_nodes(node) add_involved_nodes(node)
except ConnectionClosed: except ConnectionClosed:
continue continue
if not self.local_var.involved_nodes: if not involved_nodes:
raise NEOStorageError("Store failed") raise NEOStorageError("Store failed")
self._waitAnyMessage(False) self._waitAnyTransactionMessage(txn_context, False)
def onStoreTimeout(self, conn, msg_id, ttid, oid): def onStoreTimeout(self, conn, msg_id, ttid, oid):
# NOTE: this method is called from poll thread, don't use # NOTE: this method is called from poll thread, don't use
...@@ -628,17 +589,17 @@ class Application(object): ...@@ -628,17 +589,17 @@ class Application(object):
return True return True
@profiler_decorator @profiler_decorator
def _handleConflicts(self, tryToResolveConflict): def _handleConflicts(self, txn_context, tryToResolveConflict):
result = [] result = []
append = result.append append = result.append
local_var = self.local_var
# Check for conflicts # Check for conflicts
data_dict = local_var.data_dict data_dict = txn_context['data_dict']
object_base_serial_dict = local_var.object_base_serial_dict object_base_serial_dict = txn_context['object_base_serial_dict']
object_serial_dict = local_var.object_serial_dict object_serial_dict = txn_context['object_serial_dict']
conflict_serial_dict = local_var.conflict_serial_dict.copy() conflict_serial_dict = txn_context['conflict_serial_dict'].copy()
local_var.conflict_serial_dict.clear() txn_context['conflict_serial_dict'].clear()
resolved_conflict_serial_dict = local_var.resolved_conflict_serial_dict resolved_conflict_serial_dict = txn_context[
'resolved_conflict_serial_dict']
for oid, conflict_serial_set in conflict_serial_dict.iteritems(): for oid, conflict_serial_set in conflict_serial_dict.iteritems():
resolved_serial_set = resolved_conflict_serial_dict.setdefault( resolved_serial_set = resolved_conflict_serial_dict.setdefault(
oid, set()) oid, set())
...@@ -650,7 +611,6 @@ class Application(object): ...@@ -650,7 +611,6 @@ class Application(object):
continue continue
serial = object_serial_dict[oid] serial = object_serial_dict[oid]
data = data_dict[oid] data = data_dict[oid]
tid = local_var.tid
resolved = False resolved = False
if conflict_serial == ZERO_TID: if conflict_serial == ZERO_TID:
# Storage refused us from taking object lock, to avoid a # Storage refused us from taking object lock, to avoid a
...@@ -665,12 +625,11 @@ class Application(object): ...@@ -665,12 +625,11 @@ class Application(object):
# object data again. # object data again.
neo.lib.logging.info('Deadlock avoidance triggered on %r:%r', neo.lib.logging.info('Deadlock avoidance triggered on %r:%r',
dump(oid), dump(serial)) dump(oid), dump(serial))
for store_oid, store_data in \ for store_oid, store_data in data_dict.iteritems():
local_var.data_dict.iteritems():
store_serial = object_serial_dict[store_oid] store_serial = object_serial_dict[store_oid]
if store_data is None: if store_data is None:
self.checkCurrentSerialInTransaction(store_oid, self._checkCurrentSerialInTransaction(txn_context,
store_serial) store_oid, store_serial)
else: else:
if store_data is '': if store_data is '':
# Some undo # Some undo
...@@ -678,8 +637,8 @@ class Application(object): ...@@ -678,8 +637,8 @@ class Application(object):
' reliably work with undo, this must be ' ' reliably work with undo, this must be '
'implemented.') 'implemented.')
break break
self._store(store_oid, store_serial, store_data, self._store(txn_context, store_oid, store_serial,
unlock=True) store_data, unlock=True)
else: else:
resolved = True resolved = True
elif data is not None: elif data is not None:
...@@ -694,7 +653,7 @@ class Application(object): ...@@ -694,7 +653,7 @@ class Application(object):
# Base serial changes too, as we resolved a conflict # Base serial changes too, as we resolved a conflict
object_base_serial_dict[oid] = conflict_serial object_base_serial_dict[oid] = conflict_serial
# Try to store again # Try to store again
self._store(oid, conflict_serial, new_data) self._store(txn_context, oid, conflict_serial, new_data)
append(oid) append(oid)
resolved = True resolved = True
else: else:
...@@ -704,49 +663,51 @@ class Application(object): ...@@ -704,49 +663,51 @@ class Application(object):
if not resolved: if not resolved:
# XXX: Is it really required to remove from data_dict ? # XXX: Is it really required to remove from data_dict ?
del data_dict[oid] del data_dict[oid]
local_var.data_list.remove(oid) txn_context['data_list'].remove(oid)
if data is None: if data is None:
exc = ReadConflictError(oid=oid, serials=(conflict_serial, exc = ReadConflictError(oid=oid, serials=(conflict_serial,
serial)) serial))
else: else:
exc = ConflictError(oid=oid, serials=(tid, serial), exc = ConflictError(oid=oid, serials=(txn_context['ttid'],
data=data) serial), data=data)
raise exc raise exc
return result return result
@profiler_decorator @profiler_decorator
def waitResponses(self): def waitResponses(self, queue, handler_data):
"""Wait for all requests to be answered (or their connection to be """Wait for all requests to be answered (or their connection to be
detected as closed)""" detected as closed)"""
queue = self.local_var.queue
pending = self.dispatcher.pending pending = self.dispatcher.pending
_waitAnyMessage = self._waitAnyMessage _waitAnyMessage = self._waitAnyMessage
self.setHandlerData(handler_data)
while pending(queue): while pending(queue):
_waitAnyMessage() _waitAnyMessage(queue)
@profiler_decorator @profiler_decorator
def waitStoreResponses(self, tryToResolveConflict): def waitStoreResponses(self, txn_context, tryToResolveConflict):
result = [] result = []
append = result.append append = result.append
resolved_oid_set = set() resolved_oid_set = set()
update = resolved_oid_set.update update = resolved_oid_set.update
local_var = self.local_var ttid = txn_context['ttid']
tid = local_var.tid
_handleConflicts = self._handleConflicts _handleConflicts = self._handleConflicts
conflict_serial_dict = local_var.conflict_serial_dict queue = txn_context['queue']
queue = local_var.queue conflict_serial_dict = txn_context['conflict_serial_dict']
pending = self.dispatcher.pending pending = self.dispatcher.pending
_waitAnyMessage = self._waitAnyMessage _waitAnyTransactionMessage = self._waitAnyTransactionMessage
while pending(queue) or conflict_serial_dict: while pending(queue) or conflict_serial_dict:
_waitAnyMessage() # Note: handler data can be overwritten by _handleConflicts
# so we must set it for each iteration.
_waitAnyTransactionMessage(txn_context)
if conflict_serial_dict: if conflict_serial_dict:
conflicts = _handleConflicts(tryToResolveConflict) conflicts = _handleConflicts(txn_context,
tryToResolveConflict)
if conflicts: if conflicts:
update(conflicts) update(conflicts)
# Check for never-stored objects, and update result for all others # Check for never-stored objects, and update result for all others
for oid, store_dict in \ for oid, store_dict in \
local_var.object_stored_counter_dict.iteritems(): txn_context['object_stored_counter_dict'].iteritems():
if not store_dict: if not store_dict:
neo.lib.logging.error('tpc_store failed') neo.lib.logging.error('tpc_store failed')
raise NEOStorageError('tpc_store failed') raise NEOStorageError('tpc_store failed')
...@@ -757,27 +718,27 @@ class Application(object): ...@@ -757,27 +718,27 @@ class Application(object):
@profiler_decorator @profiler_decorator
def tpc_vote(self, transaction, tryToResolveConflict): def tpc_vote(self, transaction, tryToResolveConflict):
"""Store current transaction.""" """Store current transaction."""
local_var = self.local_var txn_context = self._txn_container.get(transaction)
if transaction is not local_var.txn: if txn_context is None or transaction is not txn_context['txn']:
raise StorageTransactionError(self, transaction) raise StorageTransactionError(self, transaction)
result = self.waitStoreResponses(tryToResolveConflict) result = self.waitStoreResponses(txn_context, tryToResolveConflict)
tid = local_var.tid ttid = txn_context['ttid']
# Store data on each node # Store data on each node
txn_stored_counter = 0 txn_stored_counter = 0
packet = Packets.AskStoreTransaction(tid, str(transaction.user), packet = Packets.AskStoreTransaction(ttid, str(transaction.user),
str(transaction.description), dumps(transaction._extension), str(transaction.description), dumps(transaction._extension),
local_var.data_list) txn_context['data_list'])
add_involved_nodes = self.local_var.involved_nodes.add add_involved_nodes = txn_context['involved_nodes'].add
for node, conn in self.cp.iterateForObject(tid, writable=True): for node, conn in self.cp.iterateForObject(ttid, writable=True):
neo.lib.logging.debug("voting object %s on %s", dump(tid), neo.lib.logging.debug("voting object %s on %s", dump(ttid),
dump(conn.getUUID())) dump(conn.getUUID()))
try: try:
self._askStorage(conn, packet) self._askStorage(conn, packet)
add_involved_nodes(node)
except ConnectionClosed: except ConnectionClosed:
continue continue
add_involved_nodes(node)
txn_stored_counter += 1 txn_stored_counter += 1
# check at least one storage node accepted # check at least one storage node accepted
...@@ -790,20 +751,22 @@ class Application(object): ...@@ -790,20 +751,22 @@ class Application(object):
# tpc_finish. # tpc_finish.
self._getMasterConnection() self._getMasterConnection()
local_var.txn_voted = True txn_context['txn_voted'] = True
return result return result
@profiler_decorator @profiler_decorator
def tpc_abort(self, transaction): def tpc_abort(self, transaction):
"""Abort current transaction.""" """Abort current transaction."""
if transaction is not self.local_var.txn: txn_container = self._txn_container
txn_context = txn_container.get(transaction)
if txn_context is None:
return return
tid = self.local_var.tid ttid = txn_context['ttid']
p = Packets.AbortTransaction(tid) p = Packets.AbortTransaction(ttid)
getConnForNode = self.cp.getConnForNode getConnForNode = self.cp.getConnForNode
# cancel transaction one all those nodes # cancel transaction one all those nodes
for node in self.local_var.involved_nodes: for node in txn_context['involved_nodes']:
conn = getConnForNode(node) conn = getConnForNode(node)
if conn is None: if conn is None:
continue continue
...@@ -815,28 +778,30 @@ class Application(object): ...@@ -815,28 +778,30 @@ class Application(object):
'storage node %r of abortion, ignoring.', 'storage node %r of abortion, ignoring.',
conn, exc_info=1) conn, exc_info=1)
self._getMasterConnection().notify(p) self._getMasterConnection().notify(p)
queue = self.local_var.queue queue = txn_context['queue']
self.dispatcher.forget_queue(queue) # We don't need to flush queue, as it won't be reused by future
self.local_var.clear() # transactions (deleted on next line & indexed by transaction object
# instance).
self.dispatcher.forget_queue(queue, flush_queue=False)
txn_container.delete(transaction)
@profiler_decorator @profiler_decorator
def tpc_finish(self, transaction, tryToResolveConflict, f=None): def tpc_finish(self, transaction, tryToResolveConflict, f=None):
"""Finish current transaction.""" """Finish current transaction."""
local_var = self.local_var txn_container = self._txn_container
if local_var.txn is not transaction: txn_context = txn_container.get(transaction)
if txn_context is None:
raise StorageTransactionError('tpc_finish called for wrong ' raise StorageTransactionError('tpc_finish called for wrong '
'transaction') 'transaction')
if not local_var.txn_voted: if not txn_context['txn_voted']:
self.tpc_vote(transaction, tryToResolveConflict) self.tpc_vote(transaction, tryToResolveConflict)
self._load_lock_acquire() self._load_lock_acquire()
try: try:
# Call finish on master # Call finish on master
oid_list = local_var.data_list oid_list = txn_context['data_list']
p = Packets.AskFinishTransaction(local_var.tid, oid_list) p = Packets.AskFinishTransaction(txn_context['ttid'], oid_list)
self._askPrimary(p) tid = self._askPrimary(p)
# From now on, self.local_var.tid holds the "real" TID.
tid = local_var.tid
# Call function given by ZODB # Call function given by ZODB
if f is not None: if f is not None:
f(tid) f(tid)
...@@ -851,8 +816,8 @@ class Application(object): ...@@ -851,8 +816,8 @@ class Application(object):
assert next_tid is None, (dump(oid), dump(base_tid), assert next_tid is None, (dump(oid), dump(base_tid),
dump(next_tid)) dump(next_tid))
return (data, tid) return (data, tid)
get_baseTID = local_var.object_base_serial_dict.get get_baseTID = txn_context['object_base_serial_dict'].get
for oid, data in local_var.data_dict.iteritems(): for oid, data in txn_context['data_dict'].iteritems():
if data is None: if data is None:
# this is just a remain of # this is just a remain of
# checkCurrentSerialInTransaction call, ignore (no data # checkCurrentSerialInTransaction call, ignore (no data
...@@ -871,13 +836,14 @@ class Application(object): ...@@ -871,13 +836,14 @@ class Application(object):
mq_cache[(oid, tid)] = (data, None) mq_cache[(oid, tid)] = (data, None)
finally: finally:
self._cache_lock_release() self._cache_lock_release()
local_var.clear() txn_container.delete(transaction)
return tid return tid
finally: finally:
self._load_lock_release() self._load_lock_release()
def undo(self, snapshot_tid, undone_tid, txn, tryToResolveConflict): def undo(self, snapshot_tid, undone_tid, txn, tryToResolveConflict):
if txn is not self.local_var.txn: txn_context = self._txn_container.get(txn)
if txn_context is None:
raise StorageTransactionError(self, undone_tid) raise StorageTransactionError(self, undone_tid)
txn_info, txn_ext = self._getTransactionInformation(undone_tid) txn_info, txn_ext = self._getTransactionInformation(undone_tid)
...@@ -899,22 +865,23 @@ class Application(object): ...@@ -899,22 +865,23 @@ class Application(object):
getCellList = pt.getCellList getCellList = pt.getCellList
getCellSortKey = self.cp.getCellSortKey getCellSortKey = self.cp.getCellSortKey
getConnForCell = self.cp.getConnForCell getConnForCell = self.cp.getConnForCell
queue = self.local_var.queue queue = self._getThreadQueue()
undo_object_tid_dict = self.local_var.undo_object_tid_dict = {} ttid = txn_context['ttid']
for partition, oid_list in partition_oid_dict.iteritems(): for partition, oid_list in partition_oid_dict.iteritems():
cell_list = getCellList(partition, readable=True) cell_list = getCellList(partition, readable=True)
shuffle(cell_list) shuffle(cell_list)
cell_list.sort(key=getCellSortKey) cell_list.sort(key=getCellSortKey)
storage_conn = getConnForCell(cell_list[0]) storage_conn = getConnForCell(cell_list[0])
storage_conn.ask(Packets.AskObjectUndoSerial(self.local_var.tid, storage_conn.ask(Packets.AskObjectUndoSerial(ttid,
snapshot_tid, undone_tid, oid_list), queue=queue) snapshot_tid, undone_tid, oid_list), queue=queue)
# Wait for all AnswerObjectUndoSerial. We might get OidNotFoundError, # Wait for all AnswerObjectUndoSerial. We might get OidNotFoundError,
# meaning that objects in transaction's oid_list do not exist any # meaning that objects in transaction's oid_list do not exist any
# longer. This is the symptom of a pack, so forbid undoing transaction # longer. This is the symptom of a pack, so forbid undoing transaction
# when it happens. # when it happens.
undo_object_tid_dict = {}
try: try:
self.waitResponses() self.waitResponses(queue, undo_object_tid_dict)
except NEOStorageNotFoundError: except NEOStorageNotFoundError:
self.dispatcher.forget_queue(queue) self.dispatcher.forget_queue(queue)
raise UndoError('non-undoable transaction') raise UndoError('non-undoable transaction')
...@@ -929,9 +896,11 @@ class Application(object): ...@@ -929,9 +896,11 @@ class Application(object):
# object. This is an undo conflict, try to resolve it. # object. This is an undo conflict, try to resolve it.
try: try:
# Load the latest version we are supposed to see # Load the latest version we are supposed to see
data = self.load(snapshot_tid, oid, serial=current_serial)[0] data = self.load(snapshot_tid, oid,
serial=current_serial)[0]
# Load the version we were undoing to # Load the version we were undoing to
undo_data = self.load(snapshot_tid, oid, serial=undo_serial)[0] undo_data = self.load(snapshot_tid, oid,
serial=undo_serial)[0]
except NEOStorageNotFoundError: except NEOStorageNotFoundError:
raise UndoError('Object not found while resolving undo ' raise UndoError('Object not found while resolving undo '
'conflict') 'conflict')
...@@ -945,7 +914,7 @@ class Application(object): ...@@ -945,7 +914,7 @@ class Application(object):
raise UndoError('Some data were modified by a later ' \ raise UndoError('Some data were modified by a later ' \
'transaction', oid) 'transaction', oid)
undo_serial = None undo_serial = None
self._store(oid, current_serial, data, undo_serial) self._store(txn_context, oid, current_serial, data, undo_serial)
def _insertMetadata(self, txn_info, extension): def _insertMetadata(self, txn_info, extension):
for k, v in loads(extension).items(): for k, v in loads(extension).items():
...@@ -955,7 +924,7 @@ class Application(object): ...@@ -955,7 +924,7 @@ class Application(object):
packet = Packets.AskTransactionInformation(tid) packet = Packets.AskTransactionInformation(tid)
for node, conn in self.cp.iterateForObject(tid, readable=True): for node, conn in self.cp.iterateForObject(tid, readable=True):
try: try:
self._askStorage(conn, packet) txn_info, txn_ext = self._askStorage(conn, packet)
except ConnectionClosed: except ConnectionClosed:
continue continue
except NEOStorageNotFoundError: except NEOStorageNotFoundError:
...@@ -964,7 +933,7 @@ class Application(object): ...@@ -964,7 +933,7 @@ class Application(object):
break break
else: else:
raise NEOStorageError('Transaction %r not found' % (tid, )) raise NEOStorageError('Transaction %r not found' % (tid, ))
return (self.local_var.txn_info, self.local_var.txn_ext) return (txn_info, txn_ext)
def undoLog(self, first, last, filter=None, block=0): def undoLog(self, first, last, filter=None, block=0):
# XXX: undoLog is broken # XXX: undoLog is broken
...@@ -978,8 +947,7 @@ class Application(object): ...@@ -978,8 +947,7 @@ class Application(object):
pt = self.getPartitionTable() pt = self.getPartitionTable()
storage_node_list = pt.getNodeList() storage_node_list = pt.getNodeList()
self.local_var.node_tids = {} queue = self._getThreadQueue()
queue = self.local_var.queue
packet = Packets.AskTIDs(first, last, INVALID_PARTITION) packet = Packets.AskTIDs(first, last, INVALID_PARTITION)
for storage_node in storage_node_list: for storage_node in storage_node_list:
conn = self.cp.getConnForNode(storage_node) conn = self.cp.getConnForNode(storage_node)
...@@ -988,15 +956,11 @@ class Application(object): ...@@ -988,15 +956,11 @@ class Application(object):
conn.ask(packet, queue=queue) conn.ask(packet, queue=queue)
# Wait for answers from all storages. # Wait for answers from all storages.
self.waitResponses() tid_set = set()
self.waitResponses(queue, tid_set)
# Reorder tids # Reorder tids
ordered_tids = set() ordered_tids = sorted(tid_set, reverse=True)
update = ordered_tids.update
for tid_list in self.local_var.node_tids.itervalues():
update(tid_list)
ordered_tids = list(ordered_tids)
ordered_tids.sort(reverse=True)
neo.lib.logging.debug( neo.lib.logging.debug(
"UndoLog tids %s", [dump(x) for x in ordered_tids]) "UndoLog tids %s", [dump(x) for x in ordered_tids])
# For each transaction, get info # For each transaction, get info
...@@ -1004,11 +968,10 @@ class Application(object): ...@@ -1004,11 +968,10 @@ class Application(object):
append = undo_info.append append = undo_info.append
for tid in ordered_tids: for tid in ordered_tids:
(txn_info, txn_ext) = self._getTransactionInformation(tid) (txn_info, txn_ext) = self._getTransactionInformation(tid)
if filter is None or filter(self.local_var.txn_info): if filter is None or filter(txn_info):
txn_info = self.local_var.txn_info
txn_info.pop('packed') txn_info.pop('packed')
txn_info.pop("oids") txn_info.pop("oids")
self._insertMetadata(txn_info, self.local_var.txn_ext) self._insertMetadata(txn_info, txn_ext)
append(txn_info) append(txn_info)
if len(undo_info) >= last - first: if len(undo_info) >= last - first:
break break
...@@ -1024,9 +987,8 @@ class Application(object): ...@@ -1024,9 +987,8 @@ class Application(object):
node_list = node_map.keys() node_list = node_map.keys()
node_list.sort(key=self.cp.getCellSortKey) node_list.sort(key=self.cp.getCellSortKey)
partition_set = set(range(self.pt.getPartitions())) partition_set = set(range(self.pt.getPartitions()))
queue = self.local_var.queue queue = self._getThreadQueue()
# request a tid list for each partition # request a tid list for each partition
self.local_var.tids_from = set()
for node in node_list: for node in node_list:
conn = self.cp.getConnForNode(node) conn = self.cp.getConnForNode(node)
request_set = set(node_map[node]) & partition_set request_set = set(node_map[node]) & partition_set
...@@ -1038,40 +1000,34 @@ class Application(object): ...@@ -1038,40 +1000,34 @@ class Application(object):
if not partition_set: if not partition_set:
break break
assert not partition_set assert not partition_set
self.waitResponses() tid_set = set()
self.waitResponses(queue, tid_set)
# request transactions informations # request transactions informations
txn_list = [] txn_list = []
append = txn_list.append append = txn_list.append
tid = None tid = None
for tid in sorted(self.local_var.tids_from): for tid in sorted(tid_set):
(txn_info, txn_ext) = self._getTransactionInformation(tid) (txn_info, txn_ext) = self._getTransactionInformation(tid)
txn_info['ext'] = loads(self.local_var.txn_ext) txn_info['ext'] = loads(txn_ext)
append(txn_info) append(txn_info)
return (tid, txn_list) return (tid, txn_list)
def history(self, oid, version=None, size=1, filter=None): def history(self, oid, version=None, size=1, filter=None):
queue = self._getThreadQueue()
# Get history informations for object first # Get history informations for object first
packet = Packets.AskObjectHistory(oid, 0, size) packet = Packets.AskObjectHistory(oid, 0, size)
for node, conn in self.cp.iterateForObject(oid, readable=True): for node, conn in self.cp.iterateForObject(oid, readable=True):
# FIXME: we keep overwriting self.local_var.history here, we
# should aggregate it instead.
self.local_var.history = None
try: try:
self._askStorage(conn, packet) conn.ask(packet, queue=queue)
except ConnectionClosed: except ConnectionClosed:
continue continue
history_dict = {}
if self.local_var.history[0] != oid: self.waitResponses(queue, history_dict)
# Got history for wrong oid
raise NEOStorageError('inconsistency in storage: asked oid ' \
'%r, got %r' % (oid, self.local_var.history[0]))
if not isinstance(self.local_var.history, tuple):
raise NEOStorageError('history failed')
# Now that we have object informations, get txn informations # Now that we have object informations, get txn informations
history_list = [] history_list = []
for serial, size in self.local_var.history[1]: append = history_list.append
for serial in sorted(history_dict.keys(), reverse=True):
size = history_dict[serial]
txn_info, txn_ext = self._getTransactionInformation(serial) txn_info, txn_ext = self._getTransactionInformation(serial)
# create history dict # create history dict
txn_info.pop('id') txn_info.pop('id')
...@@ -1081,9 +1037,8 @@ class Application(object): ...@@ -1081,9 +1037,8 @@ class Application(object):
txn_info['version'] = '' txn_info['version'] = ''
txn_info['size'] = size txn_info['size'] = size
if filter is None or filter(txn_info): if filter is None or filter(txn_info):
history_list.append(txn_info) append(txn_info)
self._insertMetadata(txn_info, txn_ext) self._insertMetadata(txn_info, txn_ext)
return history_list return history_list
@profiler_decorator @profiler_decorator
...@@ -1111,16 +1066,15 @@ class Application(object): ...@@ -1111,16 +1066,15 @@ class Application(object):
return Iterator(self, start, stop) return Iterator(self, start, stop)
def lastTransaction(self): def lastTransaction(self):
self._askPrimary(Packets.AskLastTransaction()) return self._askPrimary(Packets.AskLastTransaction())
return self.local_var.last_transaction
def abortVersion(self, src, transaction): def abortVersion(self, src, transaction):
if transaction is not self.local_var.txn: if self._txn_container.get(transaction) is None:
raise StorageTransactionError(self, transaction) raise StorageTransactionError(self, transaction)
return '', [] return '', []
def commitVersion(self, src, dest, transaction): def commitVersion(self, src, dest, transaction):
if transaction is not self.local_var.txn: if self._txn_container.get(transaction) is None:
raise StorageTransactionError(self, transaction) raise StorageTransactionError(self, transaction)
return '', [] return '', []
...@@ -1141,12 +1095,6 @@ class Application(object): ...@@ -1141,12 +1095,6 @@ class Application(object):
def invalidationBarrier(self): def invalidationBarrier(self):
self._askPrimary(Packets.AskBarrier()) self._askPrimary(Packets.AskBarrier())
def setTID(self, value):
self.local_var.tid = value
def getTID(self):
return self.local_var.tid
def pack(self, t): def pack(self, t):
tid = repr(TimeStamp(*time.gmtime(t)[:5] + (t % 60, ))) tid = repr(TimeStamp(*time.gmtime(t)[:5] + (t % 60, )))
if tid == ZERO_TID: if tid == ZERO_TID:
...@@ -1166,24 +1114,27 @@ class Application(object): ...@@ -1166,24 +1114,27 @@ class Application(object):
return self.load(None, oid)[1] return self.load(None, oid)[1]
def checkCurrentSerialInTransaction(self, oid, serial, transaction): def checkCurrentSerialInTransaction(self, oid, serial, transaction):
local_var = self.local_var txn_context = self._txn_container.get(transaction)
if transaction is not local_var.txn: if txn_context is None:
raise StorageTransactionError(self, transaction) raise StorageTransactionError(self, transaction)
local_var.object_serial_dict[oid] = serial self._checkCurrentSerialInTransaction(txn_context, oid, serial)
def _checkCurrentSerialInTransaction(self, txn_context, oid, serial):
ttid = txn_context['ttid']
txn_context['object_serial_dict'][oid] = serial
# Placeholders # Placeholders
queue = local_var.queue queue = txn_context['queue']
local_var.object_stored_counter_dict[oid] = {} txn_context['object_stored_counter_dict'][oid] = {}
data_dict = local_var.data_dict data_dict = txn_context['data_dict']
if oid not in data_dict: if oid not in data_dict:
# Marker value so we don't try to resolve conflicts. # Marker value so we don't try to resolve conflicts.
data_dict[oid] = None data_dict[oid] = None
local_var.data_list.append(oid) txn_context['data_list'].append(oid)
packet = Packets.AskCheckCurrentSerial(local_var.tid, serial, oid) packet = Packets.AskCheckCurrentSerial(ttid, serial, oid)
for node, conn in self.cp.iterateForObject(oid, writable=True): for node, conn in self.cp.iterateForObject(oid, writable=True):
try: try:
conn.ask(packet, queue=queue) conn.ask(packet, queue=queue)
except ConnectionClosed: except ConnectionClosed:
continue continue
self._waitAnyTransactionMessage(txn_context, False)
self._waitAnyMessage(False)
#
# Copyright (C) 2011 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from thread import get_ident
from neo.lib.locking import Queue
class ContainerBase(object):
def __init__(self):
self._context_dict = {}
def _getID(self, *args, **kw):
raise NotImplementedError
def _new(self, *args, **kw):
raise NotImplementedError
def delete(self, *args, **kw):
del self._context_dict[self._getID(*args, **kw)]
def get(self, *args, **kw):
return self._context_dict.get(self._getID(*args, **kw))
def new(self, *args, **kw):
result = self._context_dict[self._getID(*args, **kw)] = self._new(
*args, **kw)
return result
class ThreadContainer(ContainerBase):
def _getID(self):
return get_ident()
def _new(self):
return {
'queue': Queue(0),
'answer': None,
}
def get(self):
"""
Implicitely create a thread context if it doesn't exist.
"""
my_id = self._getID()
try:
result = self._context_dict[my_id]
except KeyError:
result = self._context_dict[my_id] = self._new()
return result
class TransactionContainer(ContainerBase):
def _getID(self, txn):
return id(txn)
def _new(self, txn):
return {
'queue': Queue(0),
'txn': txn,
'ttid': None,
'data_dict': {},
'data_list': [],
'object_base_serial_dict': {},
'object_serial_dict': {},
'object_stored_counter_dict': {},
'conflict_serial_dict': {},
'resolved_conflict_serial_dict': {},
'txn_voted': False,
'involved_nodes': set(),
}
...@@ -156,21 +156,19 @@ class PrimaryNotificationsHandler(BaseHandler): ...@@ -156,21 +156,19 @@ class PrimaryNotificationsHandler(BaseHandler):
class PrimaryAnswersHandler(AnswerBaseHandler): class PrimaryAnswersHandler(AnswerBaseHandler):
""" Handle that process expected packets from the primary master """ """ Handle that process expected packets from the primary master """
def answerBeginTransaction(self, conn, tid): def answerBeginTransaction(self, conn, ttid):
self.app.setTID(tid) self.app.setHandlerData(ttid)
def answerNewOIDs(self, conn, oid_list): def answerNewOIDs(self, conn, oid_list):
self.app.new_oid_list = oid_list self.app.new_oid_list = oid_list
def answerTransactionFinished(self, conn, ttid, tid): def answerTransactionFinished(self, conn, _, tid):
if ttid != self.app.getTID(): self.app.setHandlerData(tid)
raise NEOStorageError('Wrong TID, transaction not started')
self.app.setTID(tid)
def answerPack(self, conn, status): def answerPack(self, conn, status):
if not status: if not status:
raise NEOStorageError('Already packing') raise NEOStorageError('Already packing')
def answerLastTransaction(self, conn, ltid): def answerLastTransaction(self, conn, ltid):
self.app.local_var.last_transaction = ltid self.app.setHandlerData(ltid)
...@@ -68,23 +68,25 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -68,23 +68,25 @@ class StorageAnswersHandler(AnswerBaseHandler):
if data_serial is not None: if data_serial is not None:
raise NEOStorageError, 'Storage should never send non-None ' \ raise NEOStorageError, 'Storage should never send non-None ' \
'data_serial to clients, got %s' % (dump(data_serial), ) 'data_serial to clients, got %s' % (dump(data_serial), )
self.app.local_var.asked_object = (oid, start_serial, end_serial, self.app.setHandlerData((oid, start_serial, end_serial,
compression, checksum, data) compression, checksum, data))
def answerStoreObject(self, conn, conflicting, oid, serial): def answerStoreObject(self, conn, conflicting, oid, serial):
local_var = self.app.local_var txn_context = self.app.getHandlerData()
object_stored_counter_dict = local_var.object_stored_counter_dict[oid] object_stored_counter_dict = txn_context[
'object_stored_counter_dict'][oid]
if conflicting: if conflicting:
neo.lib.logging.info('%r report a conflict for %r with %r', conn, neo.lib.logging.info('%r report a conflict for %r with %r', conn,
dump(oid), dump(serial)) dump(oid), dump(serial))
conflict_serial_dict = local_var.conflict_serial_dict conflict_serial_dict = txn_context['conflict_serial_dict']
if serial in object_stored_counter_dict: if serial in object_stored_counter_dict:
raise NEOStorageError, 'A storage accepted object for ' \ raise NEOStorageError, 'A storage accepted object for ' \
'serial %s but another reports a conflict for it.' % ( 'serial %s but another reports a conflict for it.' % (
dump(serial), ) dump(serial), )
# If this conflict is not already resolved, mark it for # If this conflict is not already resolved, mark it for
# resolution. # resolution.
if serial not in local_var.resolved_conflict_serial_dict.get(oid, ()): if serial not in txn_context[
'resolved_conflict_serial_dict'].get(oid, ()):
conflict_serial_dict.setdefault(oid, set()).add(serial) conflict_serial_dict.setdefault(oid, set()).add(serial)
else: else:
object_stored_counter_dict[serial] = \ object_stored_counter_dict[serial] = \
...@@ -92,31 +94,29 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -92,31 +94,29 @@ class StorageAnswersHandler(AnswerBaseHandler):
answerCheckCurrentSerial = answerStoreObject answerCheckCurrentSerial = answerStoreObject
def answerStoreTransaction(self, conn, tid): def answerStoreTransaction(self, conn, _):
if tid != self.app.getTID(): pass
raise NEOStorageError('Wrong TID, transaction not started')
def answerTIDsFrom(self, conn, tid_list): def answerTIDsFrom(self, conn, tid_list):
neo.lib.logging.debug('Get %d TIDs from %r', len(tid_list), conn) neo.lib.logging.debug('Get %d TIDs from %r', len(tid_list), conn)
assert not self.app.local_var.tids_from.intersection(set(tid_list)) tids_from = self.app.getHandlerData()
self.app.local_var.tids_from.update(tid_list) assert not tids_from.intersection(set(tid_list))
tids_from.update(tid_list)
def answerTransactionInformation(self, conn, tid, def answerTransactionInformation(self, conn, tid,
user, desc, ext, packed, oid_list): user, desc, ext, packed, oid_list):
# transaction information are returned as a dict self.app.setHandlerData(({
info = {} 'time': TimeStamp(tid).timeTime(),
info['time'] = TimeStamp(tid).timeTime() 'user_name': user,
info['user_name'] = user 'description': desc,
info['description'] = desc 'id': tid,
info['id'] = tid 'oids': oid_list,
info['oids'] = oid_list 'packed': packed,
info['packed'] = packed }, ext))
self.app.local_var.txn_ext = ext
self.app.local_var.txn_info = info def answerObjectHistory(self, conn, _, history_list):
def answerObjectHistory(self, conn, oid, history_list):
# history_list is a list of tuple (serial, size) # history_list is a list of tuple (serial, size)
self.app.local_var.history = oid, history_list self.app.getHandlerData().update(history_list)
def oidNotFound(self, conn, message): def oidNotFound(self, conn, message):
# This can happen either when : # This can happen either when :
...@@ -132,10 +132,10 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -132,10 +132,10 @@ class StorageAnswersHandler(AnswerBaseHandler):
raise NEOStorageNotFoundError(message) raise NEOStorageNotFoundError(message)
def answerTIDs(self, conn, tid_list): def answerTIDs(self, conn, tid_list):
self.app.local_var.node_tids[conn.getUUID()] = tid_list self.app.getHandlerData().update(tid_list)
def answerObjectUndoSerial(self, conn, object_tid_dict): def answerObjectUndoSerial(self, conn, object_tid_dict):
self.app.local_var.undo_object_tid_dict.update(object_tid_dict) self.app.getHandlerData().update(object_tid_dict)
def answerHasLock(self, conn, oid, status): def answerHasLock(self, conn, oid, status):
if status == LockState.GRANTED_TO_OTHER: if status == LockState.GRANTED_TO_OTHER:
......
...@@ -66,9 +66,7 @@ class ConnectionPool(object): ...@@ -66,9 +66,7 @@ class ConnectionPool(object):
p = Packets.RequestIdentification(NodeTypes.CLIENT, p = Packets.RequestIdentification(NodeTypes.CLIENT,
app.uuid, None, app.name) app.uuid, None, app.name)
try: try:
msg_id = conn.ask(p, queue=app.local_var.queue) app._ask(conn, p, handler=app.storage_bootstrap_handler)
app._waitMessage(conn, msg_id,
handler=app.storage_bootstrap_handler)
except ConnectionClosed: except ConnectionClosed:
neo.lib.logging.error('Connection to %r failed', node) neo.lib.logging.error('Connection to %r failed', node)
self.notifyFailure(node) self.notifyFailure(node)
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import new
import unittest import unittest
from cPickle import dumps from cPickle import dumps
from mock import Mock, ReturnValues from mock import Mock, ReturnValues
...@@ -24,7 +23,8 @@ from neo.tests import NeoUnitTestBase ...@@ -24,7 +23,8 @@ from neo.tests import NeoUnitTestBase
from neo.client.app import Application, RevisionIndex from neo.client.app import Application, RevisionIndex
from neo.client.exception import NEOStorageError, NEOStorageNotFoundError from neo.client.exception import NEOStorageError, NEOStorageNotFoundError
from neo.client.exception import NEOStorageDoesNotExistError from neo.client.exception import NEOStorageDoesNotExistError
from neo.lib.protocol import Packet, Packets, Errors, INVALID_TID from neo.lib.protocol import Packet, Packets, Errors, INVALID_TID, \
INVALID_PARTITION
from neo.lib.util import makeChecksum from neo.lib.util import makeChecksum
import time import time
...@@ -45,11 +45,14 @@ def getPartitionTable(self): ...@@ -45,11 +45,14 @@ def getPartitionTable(self):
self.master_conn = _getMasterConnection(self) self.master_conn = _getMasterConnection(self)
return self.pt return self.pt
def _waitMessage(self, conn, msg_id, handler=None): def _ask(self, conn, packet, handler=None):
conn.ask(packet)
self.setHandlerData(None)
if handler is None: if handler is None:
raise NotImplementedError raise NotImplementedError
else: else:
handler.dispatch(conn, conn.fakeReceived()) handler.dispatch(conn, conn.fakeReceived())
return self.getHandlerData()
def resolving_tryToResolveConflict(oid, conflict_serial, serial, data): def resolving_tryToResolveConflict(oid, conflict_serial, serial, data):
return data return data
...@@ -63,10 +66,10 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -63,10 +66,10 @@ class ClientApplicationTests(NeoUnitTestBase):
NeoUnitTestBase.setUp(self) NeoUnitTestBase.setUp(self)
# apply monkey patches # apply monkey patches
self._getMasterConnection = Application._getMasterConnection self._getMasterConnection = Application._getMasterConnection
self._waitMessage = Application._waitMessage self._ask = Application._ask
self.getPartitionTable = Application.getPartitionTable self.getPartitionTable = Application.getPartitionTable
Application._getMasterConnection = _getMasterConnection Application._getMasterConnection = _getMasterConnection
Application._waitMessage = _waitMessage Application._ask = _ask
Application.getPartitionTable = getPartitionTable Application.getPartitionTable = getPartitionTable
self._to_stop_list = [] self._to_stop_list = []
...@@ -76,12 +79,19 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -76,12 +79,19 @@ class ClientApplicationTests(NeoUnitTestBase):
app.close() app.close()
# restore environnement # restore environnement
Application._getMasterConnection = self._getMasterConnection Application._getMasterConnection = self._getMasterConnection
Application._waitMessage = self._waitMessage Application._ask = self._ask
Application.getPartitionTable = self.getPartitionTable Application.getPartitionTable = self.getPartitionTable
NeoUnitTestBase.tearDown(self) NeoUnitTestBase.tearDown(self)
# some helpers # some helpers
def _begin(self, app, txn, tid=None):
txn_context = app._txn_container.new(txn)
if tid is None:
tid = self.makeTID()
txn_context['ttid'] = tid
return txn_context
def checkAskPacket(self, conn, packet_type, decode=False): def checkAskPacket(self, conn, packet_type, decode=False):
calls = conn.mockGetNamedCalls('ask') calls = conn.mockGetNamedCalls('ask')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
...@@ -154,13 +164,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -154,13 +164,6 @@ class ClientApplicationTests(NeoUnitTestBase):
#self.assertEquals(calls[0].getParam(0), conn) #self.assertEquals(calls[0].getParam(0), conn)
#self.assertTrue(isinstance(calls[0].getParam(2), Queue)) #self.assertTrue(isinstance(calls[0].getParam(2), Queue))
def test_getQueue(self):
app = self.getApp()
# Test sanity check
self.assertTrue(getattr(app, 'local_var', None) is not None)
# Test that queue is created
self.assertTrue(getattr(app.local_var, 'queue', None) is not None)
def test_registerDB(self): def test_registerDB(self):
app = self.getApp() app = self.getApp()
dummy_db = [] dummy_db = []
...@@ -186,7 +189,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -186,7 +189,6 @@ class ClientApplicationTests(NeoUnitTestBase):
def test_load(self): def test_load(self):
app = self.getApp() app = self.getApp()
app.local_var.barrier_done = True
mq = app.mq_cache mq = app.mq_cache
oid = self.makeOID() oid = self.makeOID()
tid1 = self.makeTID(1) tid1 = self.makeTID(1)
...@@ -235,7 +237,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -235,7 +237,6 @@ class ClientApplicationTests(NeoUnitTestBase):
'fakeReceived': packet, 'fakeReceived': packet,
}) })
app.cp = self.getConnectionPool([(Mock(), conn)]) app.cp = self.getConnectionPool([(Mock(), conn)])
app.local_var.asked_object = an_object[:-1]
answer_barrier = Packets.AnswerBarrier() answer_barrier = Packets.AnswerBarrier()
answer_barrier.setId(1) answer_barrier.setId(1)
app.master_conn = Mock({ app.master_conn = Mock({
...@@ -257,7 +258,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -257,7 +258,6 @@ class ClientApplicationTests(NeoUnitTestBase):
def test_loadSerial(self): def test_loadSerial(self):
app = self.getApp() app = self.getApp()
app.local_var.barrier_done = True
mq = app.mq_cache mq = app.mq_cache
oid = self.makeOID() oid = self.makeOID()
tid1 = self.makeTID(1) tid1 = self.makeTID(1)
...@@ -292,7 +292,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -292,7 +292,6 @@ class ClientApplicationTests(NeoUnitTestBase):
'fakeReceived': packet, 'fakeReceived': packet,
}) })
app.cp = self.getConnectionPool([(Mock(), conn)]) app.cp = self.getConnectionPool([(Mock(), conn)])
app.local_var.asked_object = another_object[:-1]
result = loadSerial(oid, tid1) result = loadSerial(oid, tid1)
self.assertEquals(result, 'RIGHT') self.assertEquals(result, 'RIGHT')
self.checkAskObject(conn) self.checkAskObject(conn)
...@@ -300,7 +299,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -300,7 +299,6 @@ class ClientApplicationTests(NeoUnitTestBase):
def test_loadBefore(self): def test_loadBefore(self):
app = self.getApp() app = self.getApp()
app.local_var.barrier_done = True
mq = app.mq_cache mq = app.mq_cache
oid = self.makeOID() oid = self.makeOID()
tid1 = self.makeTID(1) tid1 = self.makeTID(1)
...@@ -332,7 +330,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -332,7 +330,6 @@ class ClientApplicationTests(NeoUnitTestBase):
'fakeReceived': packet, 'fakeReceived': packet,
}) })
app.cp = self.getConnectionPool([(Mock(), conn)]) app.cp = self.getConnectionPool([(Mock(), conn)])
app.local_var.asked_object = an_object[:-1]
self.assertRaises(NEOStorageError, loadBefore, oid, tid1) self.assertRaises(NEOStorageError, loadBefore, oid, tid1)
# object should not have been cached # object should not have been cached
self.assertFalse((oid, tid1) in mq) self.assertFalse((oid, tid1) in mq)
...@@ -349,7 +346,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -349,7 +346,6 @@ class ClientApplicationTests(NeoUnitTestBase):
'fakeReceived': packet, 'fakeReceived': packet,
}) })
app.cp = self.getConnectionPool([(Mock(), conn)]) app.cp = self.getConnectionPool([(Mock(), conn)])
app.local_var.asked_object = another_object
result = loadBefore(oid, tid3) result = loadBefore(oid, tid3)
self.assertEquals(result, ('RIGHT', tid2, tid3)) self.assertEquals(result, ('RIGHT', tid2, tid3))
self.checkAskObject(conn) self.checkAskObject(conn)
...@@ -369,16 +365,17 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -369,16 +365,17 @@ class ClientApplicationTests(NeoUnitTestBase):
'fakeReceived': packet, 'fakeReceived': packet,
}) })
app.tpc_begin(transaction=txn, tid=tid) app.tpc_begin(transaction=txn, tid=tid)
self.assertTrue(app.local_var.txn is txn) txn_context = app._txn_container.get(txn)
self.assertEquals(app.local_var.tid, tid) self.assertTrue(txn_context['txn'] is txn)
self.assertEquals(txn_context['ttid'], tid)
# next, the transaction already begin -> raise # next, the transaction already begin -> raise
self.assertRaises(StorageTransactionError, app.tpc_begin, self.assertRaises(StorageTransactionError, app.tpc_begin,
transaction=txn, tid=None) transaction=txn, tid=None)
self.assertTrue(app.local_var.txn is txn) txn_context = app._txn_container.get(txn)
self.assertEquals(app.local_var.tid, tid) self.assertTrue(txn_context['txn'] is txn)
# cancel and start a transaction without tid self.assertEquals(txn_context['ttid'], tid)
app.local_var.txn = None # start a transaction without tid
app.local_var.tid = None txn = Mock()
# no connection -> NEOStorageError (wait until connected to primary) # no connection -> NEOStorageError (wait until connected to primary)
#self.assertRaises(NEOStorageError, app.tpc_begin, transaction=txn, tid=None) #self.assertRaises(NEOStorageError, app.tpc_begin, transaction=txn, tid=None)
# ask a tid to pmn # ask a tid to pmn
...@@ -392,8 +389,9 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -392,8 +389,9 @@ class ClientApplicationTests(NeoUnitTestBase):
self.checkAskNewTid(app.master_conn) self.checkAskNewTid(app.master_conn)
self.checkDispatcherRegisterCalled(app, app.master_conn) self.checkDispatcherRegisterCalled(app, app.master_conn)
# check attributes # check attributes
self.assertTrue(app.local_var.txn is txn) txn_context = app._txn_container.get(txn)
self.assertEquals(app.local_var.tid, tid) self.assertTrue(txn_context['txn'] is txn)
self.assertEquals(txn_context['ttid'], tid)
def test_store1(self): def test_store1(self):
app = self.getApp() app = self.getApp()
...@@ -401,14 +399,10 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -401,14 +399,10 @@ class ClientApplicationTests(NeoUnitTestBase):
tid = self.makeTID() tid = self.makeTID()
txn = self.makeTransactionObject() txn = self.makeTransactionObject()
# invalid transaction > StorageTransactionError # invalid transaction > StorageTransactionError
app.local_var.txn = old_txn = object()
self.assertTrue(app.local_var.txn is not txn)
self.assertRaises(StorageTransactionError, app.store, oid, tid, '', self.assertRaises(StorageTransactionError, app.store, oid, tid, '',
None, txn) None, txn)
self.assertEquals(app.local_var.txn, old_txn)
# check partition_id and an empty cell list -> NEOStorageError # check partition_id and an empty cell list -> NEOStorageError
app.local_var.txn = txn self._begin(app, txn, self.makeTID())
app.local_var.tid = tid
app.pt = Mock({ 'getCellListForOID': (), }) app.pt = Mock({ 'getCellListForOID': (), })
app.num_partitions = 2 app.num_partitions = 2
self.assertRaises(NEOStorageError, app.store, oid, tid, '', None, self.assertRaises(NEOStorageError, app.store, oid, tid, '', None,
...@@ -423,8 +417,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -423,8 +417,7 @@ class ClientApplicationTests(NeoUnitTestBase):
tid = self.makeTID() tid = self.makeTID()
txn = self.makeTransactionObject() txn = self.makeTransactionObject()
# build conflicting state # build conflicting state
app.local_var.txn = txn txn_context = self._begin(app, txn, tid)
app.local_var.tid = tid
packet = Packets.AnswerStoreObject(conflicting=1, oid=oid, serial=tid) packet = Packets.AnswerStoreObject(conflicting=1, oid=oid, serial=tid)
packet.setId(0) packet.setId(0)
storage_address = ('127.0.0.1', 10020) storage_address = ('127.0.0.1', 10020)
...@@ -436,14 +429,15 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -436,14 +429,15 @@ class ClientApplicationTests(NeoUnitTestBase):
return not queue.empty() return not queue.empty()
app.dispatcher = Dispatcher() app.dispatcher = Dispatcher()
app.nm.createStorage(address=storage_address) app.nm.createStorage(address=storage_address)
app.local_var.data_dict[oid] = 'BEFORE' data_dict = txn_context['data_dict']
app.local_var.data_list.append(oid) data_dict[oid] = 'BEFORE'
txn_context['data_list'].append(oid)
app.store(oid, tid, '', None, txn) app.store(oid, tid, '', None, txn)
app.local_var.queue.put((conn, packet)) txn_context['queue'].put((conn, packet))
self.assertRaises(ConflictError, app.waitStoreResponses, self.assertRaises(ConflictError, app.waitStoreResponses, txn_context,
failing_tryToResolveConflict) failing_tryToResolveConflict)
self.assertTrue(oid not in app.local_var.data_dict) self.assertTrue(oid not in data_dict)
self.assertEquals(app.local_var.object_stored_counter_dict[oid], {}) self.assertEquals(txn_context['object_stored_counter_dict'][oid], {})
self.checkAskStoreObject(conn) self.checkAskStoreObject(conn)
def test_store3(self): def test_store3(self):
...@@ -452,8 +446,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -452,8 +446,7 @@ class ClientApplicationTests(NeoUnitTestBase):
tid = self.makeTID() tid = self.makeTID()
txn = self.makeTransactionObject() txn = self.makeTransactionObject()
# case with no conflict # case with no conflict
app.local_var.txn = txn txn_context = self._begin(app, txn, tid)
app.local_var.tid = tid
packet = Packets.AnswerStoreObject(conflicting=0, oid=oid, serial=tid) packet = Packets.AnswerStoreObject(conflicting=0, oid=oid, serial=tid)
packet.setId(0) packet.setId(0)
storage_address = ('127.0.0.1', 10020) storage_address = ('127.0.0.1', 10020)
...@@ -467,47 +460,25 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -467,47 +460,25 @@ class ClientApplicationTests(NeoUnitTestBase):
app.nm.createStorage(address=storage_address) app.nm.createStorage(address=storage_address)
app.store(oid, tid, 'DATA', None, txn) app.store(oid, tid, 'DATA', None, txn)
self.checkAskStoreObject(conn) self.checkAskStoreObject(conn)
app.local_var.queue.put((conn, packet)) txn_context['queue'].put((conn, packet))
app.waitStoreResponses(resolving_tryToResolveConflict) app.waitStoreResponses(txn_context, resolving_tryToResolveConflict)
self.assertEquals(app.local_var.object_stored_counter_dict[oid], {tid: 1}) self.assertEquals(txn_context['object_stored_counter_dict'][oid],
self.assertEquals(app.local_var.data_dict.get(oid, None), 'DATA') {tid: 1})
self.assertFalse(oid in app.local_var.conflict_serial_dict) self.assertEquals(txn_context['data_dict'].get(oid, None), 'DATA')
self.assertFalse(oid in txn_context['conflict_serial_dict'])
def test_tpc_vote1(self): def test_tpc_vote1(self):
app = self.getApp() app = self.getApp()
oid = self.makeOID(11)
txn = self.makeTransactionObject() txn = self.makeTransactionObject()
# invalid transaction > StorageTransactionError # invalid transaction > StorageTransactionError
app.local_var.txn = old_txn = object()
self.assertTrue(app.local_var.txn is not txn)
self.assertRaises(StorageTransactionError, app.tpc_vote, txn, self.assertRaises(StorageTransactionError, app.tpc_vote, txn,
resolving_tryToResolveConflict) resolving_tryToResolveConflict)
self.assertEquals(app.local_var.txn, old_txn)
def test_tpc_vote2(self):
# fake transaction object
app = self.getApp()
app.local_var.txn = self.makeTransactionObject()
app.local_var.tid = self.makeTID()
# wrong answer -> failure
packet = Packets.AnswerStoreTransaction(INVALID_TID)
packet.setId(0)
conn = Mock({
'getNextId': 1,
'fakeReceived': packet,
'getAddress': ('127.0.0.1', 0),
})
app.cp = self.getConnectionPool([(Mock(), conn)])
self.assertRaises(NEOStorageError, app.tpc_vote, app.local_var.txn,
resolving_tryToResolveConflict)
self.checkAskPacket(conn, Packets.AskStoreTransaction)
def test_tpc_vote3(self): def test_tpc_vote3(self):
app = self.getApp() app = self.getApp()
tid = self.makeTID() tid = self.makeTID()
txn = self.makeTransactionObject() txn = self.makeTransactionObject()
app.local_var.txn = txn self._begin(app, txn, tid)
app.local_var.tid = tid
# response -> OK # response -> OK
packet = Packets.AnswerStoreTransaction(tid=tid) packet = Packets.AnswerStoreTransaction(tid=tid)
packet.setId(0) packet.setId(0)
...@@ -529,10 +500,9 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -529,10 +500,9 @@ class ClientApplicationTests(NeoUnitTestBase):
app = self.getApp() app = self.getApp()
tid = self.makeTID() tid = self.makeTID()
txn = self.makeTransactionObject() txn = self.makeTransactionObject()
app.local_var.txn = old_txn = object() old_txn = object()
self._begin(app, old_txn, tid)
app.master_conn = Mock() app.master_conn = Mock()
app.local_var.tid = tid
self.assertFalse(app.local_var.txn is txn)
conn = Mock() conn = Mock()
cell = Mock() cell = Mock()
app.pt = Mock({'getCellListForTID': (cell, cell)}) app.pt = Mock({'getCellListForTID': (cell, cell)})
...@@ -541,8 +511,9 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -541,8 +511,9 @@ class ClientApplicationTests(NeoUnitTestBase):
# no packet sent # no packet sent
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
self.checkNoPacketSent(app.master_conn) self.checkNoPacketSent(app.master_conn)
self.assertEquals(app.local_var.txn, old_txn) txn_context = app._txn_container.get(old_txn)
self.assertEquals(app.local_var.tid, tid) self.assertTrue(txn_context['txn'] is old_txn)
self.assertEquals(txn_context['ttid'], tid)
def test_tpc_abort2(self): def test_tpc_abort2(self):
# 2 nodes : 1 transaction in the first, 2 objects in the second # 2 nodes : 1 transaction in the first, 2 objects in the second
...@@ -552,7 +523,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -552,7 +523,7 @@ class ClientApplicationTests(NeoUnitTestBase):
oid1, oid2 = self.makeOID(2), self.makeOID(4) # on partition 0 oid1, oid2 = self.makeOID(2), self.makeOID(4) # on partition 0
app, tid = self.getApp(), self.makeTID(1) # on partition 1 app, tid = self.getApp(), self.makeTID(1) # on partition 1
txn = self.makeTransactionObject() txn = self.makeTransactionObject()
app.local_var.txn, app.local_var.tid = txn, tid txn_context = self._begin(app, txn, tid)
app.master_conn = Mock({'__hash__': 0}) app.master_conn = Mock({'__hash__': 0})
app.num_partitions = 2 app.num_partitions = 2
cell1 = Mock({ 'getNode': 'NODE1', '__hash__': 1 }) cell1 = Mock({ 'getNode': 'NODE1', '__hash__': 1 })
...@@ -560,16 +531,13 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -560,16 +531,13 @@ class ClientApplicationTests(NeoUnitTestBase):
conn1, conn2 = Mock({ 'getNextId': 1, }), Mock({ 'getNextId': 2, }) conn1, conn2 = Mock({ 'getNextId': 1, }), Mock({ 'getNextId': 2, })
app.cp = Mock({ 'getConnForNode': ReturnValues(conn1, conn2), }) app.cp = Mock({ 'getConnForNode': ReturnValues(conn1, conn2), })
# fake data # fake data
app.local_var.data_dict = {oid1: '', oid2: ''} txn_context['involved_nodes'].update([cell1, cell2])
app.local_var.involved_nodes = set([cell1, cell2])
app.tpc_abort(txn) app.tpc_abort(txn)
# will check if there was just one call/packet : # will check if there was just one call/packet :
self.checkNotifyPacket(conn1, Packets.AbortTransaction) self.checkNotifyPacket(conn1, Packets.AbortTransaction)
self.checkNotifyPacket(conn2, Packets.AbortTransaction) self.checkNotifyPacket(conn2, Packets.AbortTransaction)
self.assertEquals(app.local_var.tid, None) self.checkNotifyPacket(app.master_conn, Packets.AbortTransaction)
self.assertEquals(app.local_var.txn, None) self.assertEqual(app._txn_container.get(txn), None)
self.assertEquals(app.local_var.data_dict, {})
self.assertEquals(app.local_var.txn_voted, False)
def test_tpc_abort3(self): def test_tpc_abort3(self):
""" check that abort is sent to all nodes involved in the transaction """ """ check that abort is sent to all nodes involved in the transaction """
...@@ -617,19 +585,20 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -617,19 +585,20 @@ class ClientApplicationTests(NeoUnitTestBase):
}) })
app.master_conn = Mock({'__hash__': 0}) app.master_conn = Mock({'__hash__': 0})
txn = self.makeTransactionObject() txn = self.makeTransactionObject()
app.local_var.txn, app.local_var.tid = txn, tid txn_context = self._begin(app, txn, tid)
class Dispatcher(object): class Dispatcher(object):
def pending(self, queue): def pending(self, queue):
return not queue.empty() return not queue.empty()
def forget_queue(self, queue): def forget_queue(self, queue, flush_queue=True):
pass pass
app.dispatcher = Dispatcher() app.dispatcher = Dispatcher()
# conflict occurs on storage 2 # conflict occurs on storage 2
app.store(oid1, tid, 'DATA', None, txn) app.store(oid1, tid, 'DATA', None, txn)
app.store(oid2, tid, 'DATA', None, txn) app.store(oid2, tid, 'DATA', None, txn)
app.local_var.queue.put((conn2, packet2)) queue = txn_context['queue']
app.local_var.queue.put((conn3, packet3)) queue.put((conn2, packet2))
queue.put((conn3, packet3))
# vote fails as the conflict is not resolved, nothing is sent to storage 3 # vote fails as the conflict is not resolved, nothing is sent to storage 3
self.assertRaises(ConflictError, app.tpc_vote, txn, failing_tryToResolveConflict) self.assertRaises(ConflictError, app.tpc_vote, txn, failing_tryToResolveConflict)
# abort must be sent to storage 1 and 2 # abort must be sent to storage 1 and 2
...@@ -640,59 +609,11 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -640,59 +609,11 @@ class ClientApplicationTests(NeoUnitTestBase):
def test_tpc_finish1(self): def test_tpc_finish1(self):
# transaction mismatch: raise # transaction mismatch: raise
app = self.getApp() app = self.getApp()
tid = self.makeTID()
txn = self.makeTransactionObject() txn = self.makeTransactionObject()
app.local_var.txn = old_txn = object()
app.master_conn = Mock() app.master_conn = Mock()
self.assertFalse(app.local_var.txn is txn)
conn = Mock()
self.assertRaises(StorageTransactionError, app.tpc_finish, txn, None) self.assertRaises(StorageTransactionError, app.tpc_finish, txn, None)
# no packet sent # no packet sent
self.checkNoPacketSent(conn)
self.checkNoPacketSent(app.master_conn) self.checkNoPacketSent(app.master_conn)
self.assertEquals(app.local_var.txn, old_txn)
def test_tpc_finish2(self):
# bad answer -> NEOStorageError
app = self.getApp()
tid = self.makeTID()
txn = self.makeTransactionObject()
app.local_var.txn, app.local_var.tid = txn, tid
# test callable passed to tpc_finish
self.f_called = False
self.f_called_with_tid = None
def hook(tid):
self.f_called = True
self.f_called_with_tid = tid
packet = Packets.AnswerTransactionFinished(INVALID_TID, INVALID_TID)
packet.setId(0)
app.master_conn = Mock({
'getNextId': 1,
'getAddress': ('127.0.0.1', 10000),
'fakeReceived': packet,
})
self.vote_params = None
tpc_vote = app.tpc_vote
def voteDetector(transaction, tryToResolveConflict):
self.vote_params = (transaction, tryToResolveConflict)
dummy_tryToResolveConflict = []
app.tpc_vote = voteDetector
app.local_var.txn_voted = True
self.assertRaises(NEOStorageError, app.tpc_finish, txn,
dummy_tryToResolveConflict, hook)
self.assertFalse(self.f_called)
self.assertEqual(self.vote_params, None)
self.checkAskFinishTransaction(app.master_conn)
self.checkDispatcherRegisterCalled(app, app.master_conn)
# Call again, but this time transaction is not voted yet
app.local_var.txn_voted = False
self.f_called = False
self.assertRaises(NEOStorageError, app.tpc_finish, txn,
dummy_tryToResolveConflict, hook)
self.assertFalse(self.f_called)
self.assertTrue(self.vote_params[0] is txn)
self.assertTrue(self.vote_params[1] is dummy_tryToResolveConflict)
app.tpc_vote = tpc_vote
def test_tpc_finish3(self): def test_tpc_finish3(self):
# transaction is finished # transaction is finished
...@@ -700,7 +621,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -700,7 +621,7 @@ class ClientApplicationTests(NeoUnitTestBase):
tid = self.makeTID() tid = self.makeTID()
ttid = self.makeTID() ttid = self.makeTID()
txn = self.makeTransactionObject() txn = self.makeTransactionObject()
app.local_var.txn, app.local_var.tid = txn, ttid txn_context = self._begin(app, txn, tid)
self.f_called = False self.f_called = False
self.f_called_with_tid = None self.f_called_with_tid = None
def hook(tid): def hook(tid):
...@@ -713,17 +634,13 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -713,17 +634,13 @@ class ClientApplicationTests(NeoUnitTestBase):
'getAddress': ('127.0.0.1', 10010), 'getAddress': ('127.0.0.1', 10010),
'fakeReceived': packet, 'fakeReceived': packet,
}) })
app.local_var.txn_voted = True txn_context['txn_voted'] = True
app.local_var.txn_finished = True
app.tpc_finish(txn, None, hook) app.tpc_finish(txn, None, hook)
self.assertTrue(self.f_called) self.assertTrue(self.f_called)
self.assertEquals(self.f_called_with_tid, tid) self.assertEquals(self.f_called_with_tid, tid)
self.checkAskFinishTransaction(app.master_conn) self.checkAskFinishTransaction(app.master_conn)
#self.checkDispatcherRegisterCalled(app, app.master_conn) #self.checkDispatcherRegisterCalled(app, app.master_conn)
self.assertEquals(app.local_var.tid, None) self.assertEqual(app._txn_container.get(txn), None)
self.assertEquals(app.local_var.txn, None)
self.assertEquals(app.local_var.data_dict, {})
self.assertEquals(app.local_var.txn_voted, False)
def test_undo1(self): def test_undo1(self):
# invalid transaction # invalid transaction
...@@ -731,21 +648,15 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -731,21 +648,15 @@ class ClientApplicationTests(NeoUnitTestBase):
tid = self.makeTID() tid = self.makeTID()
snapshot_tid = self.getNextTID() snapshot_tid = self.getNextTID()
txn = self.makeTransactionObject() txn = self.makeTransactionObject()
marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data): def tryToResolveConflict(oid, conflict_serial, serial, data):
marker.append(1) pass
app.local_var.txn = old_txn = object()
app.master_conn = Mock() app.master_conn = Mock()
self.assertFalse(app.local_var.txn is txn)
conn = Mock() conn = Mock()
self.assertRaises(StorageTransactionError, app.undo, snapshot_tid, tid, self.assertRaises(StorageTransactionError, app.undo, snapshot_tid, tid,
txn, tryToResolveConflict) txn, tryToResolveConflict)
# no packet sent # no packet sent
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
self.checkNoPacketSent(app.master_conn) self.checkNoPacketSent(app.master_conn)
# nothing done
self.assertEquals(marker, [])
self.assertEquals(app.local_var.txn, old_txn)
def _getAppForUndoTests(self, oid0, tid0, tid1, tid2): def _getAppForUndoTests(self, oid0, tid0, tid1, tid2):
app = self.getApp() app = self.getApp()
...@@ -780,10 +691,10 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -780,10 +691,10 @@ class ClientApplicationTests(NeoUnitTestBase):
return ({tid0: 'dummy', tid2: 'cdummy'}[serial], None, None) return ({tid0: 'dummy', tid2: 'cdummy'}[serial], None, None)
app.load = load app.load = load
store_marker = [] store_marker = []
def _store(oid, serial, data, data_serial=None): def _store(txn_context, oid, serial, data, data_serial=None,
unlock=False):
store_marker.append((oid, serial, data, data_serial)) store_marker.append((oid, serial, data, data_serial))
app._store = _store app._store = _store
app.local_var.clear()
return app, conn, store_marker return app, conn, store_marker
def test_undoWithResolutionSuccess(self): def test_undoWithResolutionSuccess(self):
...@@ -805,7 +716,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -805,7 +716,7 @@ class ClientApplicationTests(NeoUnitTestBase):
undo_serial = Packets.AnswerObjectUndoSerial({ undo_serial = Packets.AnswerObjectUndoSerial({
oid0: (tid2, tid0, False)}) oid0: (tid2, tid0, False)})
undo_serial.setId(2) undo_serial.setId(2)
app.local_var.queue.put((conn, undo_serial)) app._getThreadQueue().put((conn, undo_serial))
marker = [] marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data, def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''): committedData=''):
...@@ -846,7 +757,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -846,7 +757,7 @@ class ClientApplicationTests(NeoUnitTestBase):
undo_serial.setId(2) undo_serial.setId(2)
app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1, app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1,
tid2) tid2)
app.local_var.queue.put((conn, undo_serial)) app._getThreadQueue().put((conn, undo_serial))
marker = [] marker = []
def tryToResolveConflict(oid, conflict_serial, serial, data, def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''): committedData=''):
...@@ -872,7 +783,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -872,7 +783,7 @@ class ClientApplicationTests(NeoUnitTestBase):
marker.append((oid, conflict_serial, serial, data, committedData)) marker.append((oid, conflict_serial, serial, data, committedData))
raise ConflictError raise ConflictError
# The undo # The undo
app.local_var.queue.put((conn, undo_serial)) app._getThreadQueue().put((conn, undo_serial))
self.assertRaises(UndoError, app.undo, snapshot_tid, tid1, txn, self.assertRaises(UndoError, app.undo, snapshot_tid, tid1, txn,
tryToResolveConflict) tryToResolveConflict)
# Checking what happened # Checking what happened
...@@ -905,7 +816,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -905,7 +816,7 @@ class ClientApplicationTests(NeoUnitTestBase):
undo_serial.setId(2) undo_serial.setId(2)
app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1, app, conn, store_marker = self._getAppForUndoTests(oid0, tid0, tid1,
tid2) tid2)
app.local_var.queue.put((conn, undo_serial)) app._getThreadQueue().put((conn, undo_serial))
def tryToResolveConflict(oid, conflict_serial, serial, data, def tryToResolveConflict(oid, conflict_serial, serial, data,
committedData=''): committedData=''):
raise Exception, 'Test called conflict resolution, but there ' \ raise Exception, 'Test called conflict resolution, but there ' \
...@@ -929,15 +840,19 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -929,15 +840,19 @@ class ClientApplicationTests(NeoUnitTestBase):
cell1, cell2 = Mock({}), Mock({}) cell1, cell2 = Mock({}), Mock({})
tid1, tid2 = self.makeTID(1), self.makeTID(2) tid1, tid2 = self.makeTID(1), self.makeTID(2)
oid1, oid2 = self.makeOID(1), self.makeOID(2) oid1, oid2 = self.makeOID(1), self.makeOID(2)
# TIDs packets supplied by _waitMessage hook # TIDs packets supplied by _ask hook
# TXN info packets # TXN info packets
extension = dumps({}) extension = dumps({})
p1 = Packets.AnswerTIDs([tid1])
p2 = Packets.AnswerTIDs([tid2])
p3 = Packets.AnswerTransactionInformation(tid1, '', '', p3 = Packets.AnswerTransactionInformation(tid1, '', '',
extension, False, (oid1, )) extension, False, (oid1, ))
p4 = Packets.AnswerTransactionInformation(tid2, '', '', p4 = Packets.AnswerTransactionInformation(tid2, '', '',
extension, False, (oid2, )) extension, False, (oid2, ))
p3.setId(0) p1.setId(0)
p4.setId(1) p2.setId(1)
p3.setId(2)
p4.setId(3)
conn = Mock({ conn = Mock({
'getNextId': 1, 'getNextId': 1,
'getUUID': ReturnValues(uuid1, uuid2), 'getUUID': ReturnValues(uuid1, uuid2),
...@@ -945,17 +860,36 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -945,17 +860,36 @@ class ClientApplicationTests(NeoUnitTestBase):
'fakeReceived': ReturnValues(p3, p4), 'fakeReceived': ReturnValues(p3, p4),
'getAddress': ('127.0.0.1', 10010), 'getAddress': ('127.0.0.1', 10010),
}) })
storage_1_conn = Mock()
storage_2_conn = Mock()
app.pt = Mock({ app.pt = Mock({
'getNodeList': (node1, node2, ), 'getNodeList': (node1, node2, ),
'getCellListForTID': ReturnValues([cell1], [cell2]), 'getCellListForTID': ReturnValues([cell1], [cell2]),
}) })
app.cp = self.getConnectionPool([(Mock(), conn)]) app.cp = Mock({
def waitResponses(self): 'getConnForNode': ReturnValues(storage_1_conn, storage_2_conn),
self.local_var.node_tids = {uuid1: (tid1, ), uuid2: (tid2, )} 'iterateForObject': [(Mock(), conn)]
app.waitResponses = new.instancemethod(waitResponses, app, Application) })
def waitResponses(queue, handler_data):
app.setHandlerData(handler_data)
for p in (p1, p2):
app._handlePacket(Mock(), p, handler=app.storage_handler)
app.waitResponses = waitResponses
def txn_filter(info): def txn_filter(info):
return info['id'] > '\x00' * 8 return info['id'] > '\x00' * 8
result = app.undoLog(0, 4, filter=txn_filter) first = 0
last = 4
result = app.undoLog(first, last, filter=txn_filter)
pfirst, plast, ppartition = self.checkAskPacket(storage_1_conn,
Packets.AskTIDs, decode=True)
self.assertEqual(pfirst, first)
self.assertEqual(plast, last)
self.assertEqual(ppartition, INVALID_PARTITION)
pfirst, plast, ppartition = self.checkAskPacket(storage_2_conn,
Packets.AskTIDs, decode=True)
self.assertEqual(pfirst, first)
self.assertEqual(plast, last)
self.assertEqual(ppartition, INVALID_PARTITION)
self.assertEquals(result[0]['id'], tid1) self.assertEquals(result[0]['id'], tid1)
self.assertEquals(result[1]['id'], tid2) self.assertEquals(result[1]['id'], tid2)
...@@ -968,9 +902,9 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -968,9 +902,9 @@ class ClientApplicationTests(NeoUnitTestBase):
p2 = Packets.AnswerObjectHistory(oid, object_history) p2 = Packets.AnswerObjectHistory(oid, object_history)
extension = dumps({'k': 'v'}) extension = dumps({'k': 'v'})
# transaction history # transaction history
p3 = Packets.AnswerTransactionInformation(tid1, 'u', 'd', p3 = Packets.AnswerTransactionInformation(tid2, 'u', 'd',
extension, False, (oid, )) extension, False, (oid, ))
p4 = Packets.AnswerTransactionInformation(tid2, 'u', 'd', p4 = Packets.AnswerTransactionInformation(tid1, 'u', 'd',
extension, False, (oid, )) extension, False, (oid, ))
p2.setId(0) p2.setId(0)
p3.setId(1) p3.setId(1)
...@@ -979,7 +913,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -979,7 +913,7 @@ class ClientApplicationTests(NeoUnitTestBase):
conn = Mock({ conn = Mock({
'getNextId': 1, 'getNextId': 1,
'fakeGetApp': app, 'fakeGetApp': app,
'fakeReceived': ReturnValues(p2, p3, p4), 'fakeReceived': ReturnValues(p3, p4),
'getAddress': ('127.0.0.1', 10010), 'getAddress': ('127.0.0.1', 10010),
}) })
object_cells = [ Mock({}), ] object_cells = [ Mock({}), ]
...@@ -988,12 +922,18 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -988,12 +922,18 @@ class ClientApplicationTests(NeoUnitTestBase):
'getCellListForOID': object_cells, 'getCellListForOID': object_cells,
'getCellListForTID': ReturnValues(history_cells, history_cells), 'getCellListForTID': ReturnValues(history_cells, history_cells),
}) })
app.cp = self.getConnectionPool([(Mock(), conn)]) app.cp = Mock({
'iterateForObject': [(Mock(), conn)],
})
def waitResponses(queue, handler_data):
app.setHandlerData(handler_data)
app._handlePacket(Mock(), p2, handler=app.storage_handler)
app.waitResponses = waitResponses
# start test here # start test here
result = app.history(oid) result = app.history(oid)
self.assertEquals(len(result), 2) self.assertEquals(len(result), 2)
self.assertEquals(result[0]['tid'], tid1) self.assertEquals(result[0]['tid'], tid2)
self.assertEquals(result[1]['tid'], tid2) self.assertEquals(result[1]['tid'], tid1)
self.assertEquals(result[0]['size'], 42) self.assertEquals(result[0]['size'], 42)
self.assertEquals(result[1]['size'], 42) self.assertEquals(result[1]['size'], 42)
...@@ -1010,43 +950,41 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -1010,43 +950,41 @@ class ClientApplicationTests(NeoUnitTestBase):
# TODO: test more connection failure cases # TODO: test more connection failure cases
# Seventh packet : askNodeInformation succeeded # Seventh packet : askNodeInformation succeeded
all_passed = [] all_passed = []
def _waitMessage8(conn, msg_id, handler=None): def _ask8(_):
all_passed.append(1) all_passed.append(1)
# Sixth packet : askPartitionTable succeeded # Sixth packet : askPartitionTable succeeded
def _waitMessage7(conn, msg_id, handler=None): def _ask7(_):
app.pt = Mock({'operational': True}) app.pt = Mock({'operational': True})
app._waitMessage = _waitMessage8
# fifth packet : request node identification succeeded # fifth packet : request node identification succeeded
def _waitMessage6(conn, msg_id, handler=None): def _ask6(conn):
conn.setUUID('D' * 16) conn.setUUID('D' * 16)
app.uuid = 'C' * 16 app.uuid = 'C' * 16
app._waitMessage = _waitMessage7
# fourth iteration : connection to primary master succeeded # fourth iteration : connection to primary master succeeded
def _waitMessage5(conn, msg_id, handler=None): def _ask5(_):
app.trying_master_node = app.primary_master_node = Mock({ app.trying_master_node = app.primary_master_node = Mock({
'getAddress': ('192.168.1.1', 10000), 'getAddress': ('192.168.1.1', 10000),
'__str__': 'Fake master node', '__str__': 'Fake master node',
}) })
app._waitMessage = _waitMessage6
# third iteration : node not ready # third iteration : node not ready
def _waitMessage4(conn, msg_id, handler=None): def _ask4(_):
app.trying_master_node = None app.trying_master_node = None
app._waitMessage = _waitMessage5
# second iteration : master node changed # second iteration : master node changed
def _waitMessage3(conn, msg_id, handler=None): def _ask3(_):
app.primary_master_node = Mock({ app.primary_master_node = Mock({
'getAddress': ('192.168.1.1', 10000), 'getAddress': ('192.168.1.1', 10000),
'__str__': 'Fake master node', '__str__': 'Fake master node',
}) })
app._waitMessage = _waitMessage4
# first iteration : connection failed # first iteration : connection failed
def _waitMessage2(conn, msg_id, handler=None): def _ask2(_):
app.trying_master_node = None app.trying_master_node = None
app._waitMessage = _waitMessage3
# do nothing for the first call # do nothing for the first call
def _waitMessage1(conn, msg_id, handler=None): def _ask1(_):
app._waitMessage = _waitMessage2 pass
app._waitMessage = _waitMessage1 ask_func_list = [_ask1, _ask2, _ask3, _ask4, _ask5, _ask6, _ask7,
_ask8]
def _ask_base(conn, _, handler=None):
ask_func_list.pop(0)(conn)
app._ask = _ask_base
# faked environnement # faked environnement
app.connector_handler = DoNothingConnector app.connector_handler = DoNothingConnector
app.em = Mock({'getConnectionList': []}) app.em = Mock({'getConnectionList': []})
...@@ -1056,23 +994,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -1056,23 +994,6 @@ class ClientApplicationTests(NeoUnitTestBase):
self.assertTrue(app.master_conn is not None) self.assertTrue(app.master_conn is not None)
self.assertTrue(app.pt.operational()) self.assertTrue(app.pt.operational())
def test_askStorage(self):
""" _askStorage is private but test it anyway """
app = self.getApp('')
conn = Mock()
self.test_ok = False
def _waitMessage_hook(app, conn, msg_id, handler=None):
self.test_ok = True
packet = Packets.AskBeginTransaction()
packet.setId(0)
app._waitMessage = _waitMessage_hook
app._askStorage(conn, packet)
# check packet sent, connection unlocked and dispatcher updated
self.checkAskNewTid(conn)
self.checkDispatcherRegisterCalled(app, conn)
# and _waitMessage called
self.assertTrue(self.test_ok)
def test_askPrimary(self): def test_askPrimary(self):
""" _askPrimary is private but test it anyway """ """ _askPrimary is private but test it anyway """
app = self.getApp('') app = self.getApp('')
...@@ -1080,21 +1001,22 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -1080,21 +1001,22 @@ class ClientApplicationTests(NeoUnitTestBase):
app.master_conn = conn app.master_conn = conn
app.primary_handler = Mock() app.primary_handler = Mock()
self.test_ok = False self.test_ok = False
def _waitMessage_hook(app, conn, msg_id, handler=None): def _ask_hook(app, conn, packet, handler=None):
conn.ask(packet)
self.assertTrue(handler is app.primary_handler) self.assertTrue(handler is app.primary_handler)
self.test_ok = True self.test_ok = True
_waitMessage_old = Application._waitMessage _ask_old = Application._ask
Application._waitMessage = _waitMessage_hook Application._ask = _ask_hook
packet = Packets.AskBeginTransaction() packet = Packets.AskBeginTransaction()
packet.setId(0) packet.setId(0)
try: try:
app._askPrimary(packet) app._askPrimary(packet)
finally: finally:
Application._waitMessage = _waitMessage_old Application._ask = _ask_old
# check packet sent, connection locked during process and dispatcher updated # check packet sent, connection locked during process and dispatcher updated
self.checkAskNewTid(conn) self.checkAskNewTid(conn)
self.checkDispatcherRegisterCalled(app, conn) self.checkDispatcherRegisterCalled(app, conn)
# and _waitMessage called # and _ask called
self.assertTrue(self.test_ok) self.assertTrue(self.test_ok)
# check NEOStorageError is raised when the primary connection is lost # check NEOStorageError is raised when the primary connection is lost
app.master_conn = None app.master_conn = None
...@@ -1105,15 +1027,16 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -1105,15 +1027,16 @@ class ClientApplicationTests(NeoUnitTestBase):
""" Thread context properties must not be visible accross instances """ Thread context properties must not be visible accross instances
while remaining in the same thread """ while remaining in the same thread """
app1 = self.getApp() app1 = self.getApp()
app1_local = app1.local_var app1_local = app1._thread_container.get()
app2 = self.getApp() app2 = self.getApp()
app2_local = app2.local_var app2_local = app2._thread_container.get()
property_id = 'thread_context_test' property_id = 'thread_context_test'
self.assertFalse(hasattr(app1_local, property_id)) value = 'value'
self.assertFalse(hasattr(app2_local, property_id)) self.assertRaises(KeyError, app1_local.__getitem__, property_id)
setattr(app1_local, property_id, 'value') self.assertRaises(KeyError, app2_local.__getitem__, property_id)
self.assertTrue(hasattr(app1_local, property_id)) app1_local[property_id] = value
self.assertFalse(hasattr(app2_local, property_id)) self.assertEqual(app1_local[property_id], value)
self.assertRaises(KeyError, app2_local.__getitem__, property_id)
def test_pack(self): def test_pack(self):
app = self.getApp() app = self.getApp()
......
...@@ -235,7 +235,7 @@ class MasterAnswersHandlerTests(MasterHandlerTests): ...@@ -235,7 +235,7 @@ class MasterAnswersHandlerTests(MasterHandlerTests):
tid = self.getNextTID() tid = self.getNextTID()
conn = self.getConnection() conn = self.getConnection()
self.handler.answerBeginTransaction(conn, tid) self.handler.answerBeginTransaction(conn, tid)
calls = self.app.mockGetNamedCalls('setTID') calls = self.app.mockGetNamedCalls('setHandlerData')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid) calls[0].checkArgs(tid)
...@@ -247,18 +247,12 @@ class MasterAnswersHandlerTests(MasterHandlerTests): ...@@ -247,18 +247,12 @@ class MasterAnswersHandlerTests(MasterHandlerTests):
def test_answerTransactionFinished(self): def test_answerTransactionFinished(self):
conn = self.getConnection() conn = self.getConnection()
ttid1 = self.getNextTID() ttid2 = self.getNextTID()
ttid2 = self.getNextTID(ttid1) tid2 = self.getNextTID()
tid2 = self.getNextTID(ttid2) self.handler.answerTransactionFinished(conn, ttid2, tid2)
# wrong TID calls = self.app.mockGetNamedCalls('setHandlerData')
self.app = Mock({'getTID': ttid1}) self.assertEqual(len(calls), 1)
self.assertRaises(NEOStorageError, calls[0].checkArgs(tid2)
self.handler.answerTransactionFinished,
conn, ttid2, tid2)
# matching TID
app = Mock({'getTID': ttid2})
handler = PrimaryAnswersHandler(app=app)
handler.answerTransactionFinished(conn, ttid2, tid2)
def test_answerPack(self): def test_answerPack(self):
self.assertRaises(NEOStorageError, self.handler.answerPack, None, False) self.assertRaises(NEOStorageError, self.handler.answerPack, None, False)
......
...@@ -25,6 +25,7 @@ from neo.client.exception import NEOStorageError, NEOStorageNotFoundError ...@@ -25,6 +25,7 @@ from neo.client.exception import NEOStorageError, NEOStorageNotFoundError
from neo.client.exception import NEOStorageDoesNotExistError from neo.client.exception import NEOStorageDoesNotExistError
from ZODB.POSException import ConflictError from ZODB.POSException import ConflictError
from neo.lib.exception import NodeNotReady from neo.lib.exception import NodeNotReady
from ZODB.TimeStamp import TimeStamp
MARKER = [] MARKER = []
...@@ -69,45 +70,57 @@ class StorageAnswerHandlerTests(NeoUnitTestBase): ...@@ -69,45 +70,57 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
def setUp(self): def setUp(self):
NeoUnitTestBase.setUp(self) NeoUnitTestBase.setUp(self)
self.app = Mock() self.app = Mock()
self.app.local_var = Mock()
self.handler = StorageAnswersHandler(self.app) self.handler = StorageAnswersHandler(self.app)
def getConnection(self): def getConnection(self):
return self.getFakeConnection() return self.getFakeConnection()
def _checkHandlerData(self, ref):
calls = self.app.mockGetNamedCalls('setHandlerData')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(ref)
def test_answerObject(self): def test_answerObject(self):
conn = self.getConnection() conn = self.getConnection()
oid = self.getOID(0) oid = self.getOID(0)
tid1 = self.getNextTID() tid1 = self.getNextTID()
tid2 = self.getNextTID(tid1) tid2 = self.getNextTID(tid1)
the_object = (oid, tid1, tid2, 0, '', 'DATA', None) the_object = (oid, tid1, tid2, 0, '', 'DATA', None)
self.app.local_var.asked_object = None
self.handler.answerObject(conn, *the_object) self.handler.answerObject(conn, *the_object)
self.assertEqual(self.app.local_var.asked_object, the_object[:-1]) self._checkHandlerData(the_object[:-1])
# Check handler raises on non-None data_serial. # Check handler raises on non-None data_serial.
the_object = (oid, tid1, tid2, 0, '', 'DATA', self.getNextTID()) the_object = (oid, tid1, tid2, 0, '', 'DATA', self.getNextTID())
self.app.local_var.asked_object = None
self.assertRaises(NEOStorageError, self.handler.answerObject, conn, self.assertRaises(NEOStorageError, self.handler.answerObject, conn,
*the_object) *the_object)
def _getAnswerStoreObjectHandler(self, object_stored_counter_dict,
conflict_serial_dict, resolved_conflict_serial_dict):
app = Mock({
'getHandlerData': {
'object_stored_counter_dict': object_stored_counter_dict,
'conflict_serial_dict': conflict_serial_dict,
'resolved_conflict_serial_dict': resolved_conflict_serial_dict,
}
})
return StorageAnswersHandler(app)
def test_answerStoreObject_1(self): def test_answerStoreObject_1(self):
conn = self.getConnection() conn = self.getConnection()
oid = self.getOID(0) oid = self.getOID(0)
tid = self.getNextTID() tid = self.getNextTID()
# conflict # conflict
local_var = self.app.local_var object_stored_counter_dict = {oid: {}}
local_var.object_stored_counter_dict = {oid: {}} conflict_serial_dict = {}
local_var.conflict_serial_dict = {} resolved_conflict_serial_dict = {}
local_var.resolved_conflict_serial_dict = {} self._getAnswerStoreObjectHandler(object_stored_counter_dict,
self.handler.answerStoreObject(conn, 1, oid, tid) conflict_serial_dict, resolved_conflict_serial_dict,
self.assertEqual(local_var.conflict_serial_dict[oid], set([tid, ])) ).answerStoreObject(conn, 1, oid, tid)
self.assertEqual(local_var.object_stored_counter_dict[oid], {}) self.assertEqual(conflict_serial_dict[oid], set([tid, ]))
self.assertFalse(oid in local_var.resolved_conflict_serial_dict) self.assertEqual(object_stored_counter_dict[oid], {})
self.assertFalse(oid in resolved_conflict_serial_dict)
# object was already accepted by another storage, raise # object was already accepted by another storage, raise
local_var.object_stored_counter_dict = {oid: {tid: 1}} handler = self._getAnswerStoreObjectHandler({oid: {tid: 1}}, {}, {})
local_var.conflict_serial_dict = {} self.assertRaises(NEOStorageError, handler.answerStoreObject,
local_var.resolved_conflict_serial_dict = {}
self.assertRaises(NEOStorageError, self.handler.answerStoreObject,
conn, 1, oid, tid) conn, 1, oid, tid)
def test_answerStoreObject_2(self): def test_answerStoreObject_2(self):
...@@ -116,25 +129,23 @@ class StorageAnswerHandlerTests(NeoUnitTestBase): ...@@ -116,25 +129,23 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
tid = self.getNextTID() tid = self.getNextTID()
tid_2 = self.getNextTID() tid_2 = self.getNextTID()
# resolution-pending conflict # resolution-pending conflict
local_var = self.app.local_var object_stored_counter_dict = {oid: {}}
local_var.object_stored_counter_dict = {oid: {}} conflict_serial_dict = {oid: set([tid, ])}
local_var.conflict_serial_dict = {oid: set([tid, ])} resolved_conflict_serial_dict = {}
local_var.resolved_conflict_serial_dict = {} self._getAnswerStoreObjectHandler(object_stored_counter_dict,
self.handler.answerStoreObject(conn, 1, oid, tid) conflict_serial_dict, resolved_conflict_serial_dict,
self.assertEqual(local_var.conflict_serial_dict[oid], set([tid, ])) ).answerStoreObject(conn, 1, oid, tid)
self.assertFalse(oid in local_var.resolved_conflict_serial_dict) self.assertEqual(conflict_serial_dict[oid], set([tid, ]))
self.assertEqual(local_var.object_stored_counter_dict[oid], {}) self.assertFalse(oid in resolved_conflict_serial_dict)
self.assertEqual(object_stored_counter_dict[oid], {})
# object was already accepted by another storage, raise # object was already accepted by another storage, raise
local_var.object_stored_counter_dict = {oid: {tid: 1}} handler = self._getAnswerStoreObjectHandler({oid: {tid: 1}},
local_var.conflict_serial_dict = {oid: set([tid, ])} {oid: set([tid, ])}, {})
local_var.resolved_conflict_serial_dict = {} self.assertRaises(NEOStorageError, handler.answerStoreObject,
self.assertRaises(NEOStorageError, self.handler.answerStoreObject,
conn, 1, oid, tid) conn, 1, oid, tid)
# detected conflict is different, don't raise # detected conflict is different, don't raise
local_var.object_stored_counter_dict = {oid: {}} self._getAnswerStoreObjectHandler({oid: {}}, {oid: set([tid, ])}, {},
local_var.conflict_serial_dict = {oid: set([tid, ])} ).answerStoreObject(conn, 1, oid, tid_2)
local_var.resolved_conflict_serial_dict = {}
self.handler.answerStoreObject(conn, 1, oid, tid_2)
def test_answerStoreObject_3(self): def test_answerStoreObject_3(self):
conn = self.getConnection() conn = self.getConnection()
...@@ -145,49 +156,34 @@ class StorageAnswerHandlerTests(NeoUnitTestBase): ...@@ -145,49 +156,34 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
# This case happens if a storage is answering a store action for which # This case happens if a storage is answering a store action for which
# any other storage already answered (with same conflict) and any other # any other storage already answered (with same conflict) and any other
# storage accepted the resolved object. # storage accepted the resolved object.
local_var = self.app.local_var object_stored_counter_dict = {oid: {tid_2: 1}}
local_var.object_stored_counter_dict = {oid: {tid_2: 1}} conflict_serial_dict = {}
local_var.conflict_serial_dict = {} resolved_conflict_serial_dict = {oid: set([tid, ])}
local_var.resolved_conflict_serial_dict = {oid: set([tid, ])} self._getAnswerStoreObjectHandler(object_stored_counter_dict,
self.handler.answerStoreObject(conn, 1, oid, tid) conflict_serial_dict, resolved_conflict_serial_dict,
self.assertFalse(oid in local_var.conflict_serial_dict) ).answerStoreObject(conn, 1, oid, tid)
self.assertEqual(local_var.resolved_conflict_serial_dict[oid], self.assertFalse(oid in conflict_serial_dict)
self.assertEqual(resolved_conflict_serial_dict[oid],
set([tid, ])) set([tid, ]))
self.assertEqual(local_var.object_stored_counter_dict[oid], {tid_2: 1}) self.assertEqual(object_stored_counter_dict[oid], {tid_2: 1})
# detected conflict is different, don't raise # detected conflict is different, don't raise
local_var.object_stored_counter_dict = {oid: {tid: 1}} self._getAnswerStoreObjectHandler({oid: {tid: 1}}, {},
local_var.conflict_serial_dict = {} {oid: set([tid, ])}).answerStoreObject(conn, 1, oid, tid_2)
local_var.resolved_conflict_serial_dict = {oid: set([tid, ])}
self.handler.answerStoreObject(conn, 1, oid, tid_2)
def test_answerStoreObject_4(self): def test_answerStoreObject_4(self):
conn = self.getConnection() conn = self.getConnection()
oid = self.getOID(0) oid = self.getOID(0)
tid = self.getNextTID() tid = self.getNextTID()
# no conflict # no conflict
local_var = self.app.local_var object_stored_counter_dict = {oid: {}}
local_var.object_stored_counter_dict = {oid: {}} conflict_serial_dict = {}
local_var.conflict_serial_dict = {} resolved_conflict_serial_dict = {}
local_var.resolved_conflict_serial_dict = {} self._getAnswerStoreObjectHandler(object_stored_counter_dict,
self.handler.answerStoreObject(conn, 0, oid, tid) conflict_serial_dict, resolved_conflict_serial_dict,
self.assertFalse(oid in local_var.conflict_serial_dict) ).answerStoreObject(conn, 0, oid, tid)
self.assertFalse(oid in local_var.resolved_conflict_serial_dict) self.assertFalse(oid in conflict_serial_dict)
self.assertEqual(local_var.object_stored_counter_dict[oid], {tid: 1}) self.assertFalse(oid in resolved_conflict_serial_dict)
self.assertEqual(object_stored_counter_dict[oid], {tid: 1})
def test_answerStoreTransaction(self):
conn = self.getConnection()
tid1 = self.getNextTID()
tid2 = self.getNextTID(tid1)
# wrong tid
app = Mock({'getTID': tid1})
handler = StorageAnswersHandler(app=app)
self.assertRaises(NEOStorageError,
handler.answerStoreTransaction, conn,
tid2)
# good tid
app = Mock({'getTID': tid2})
handler = StorageAnswersHandler(app=app)
handler.answerStoreTransaction(conn, tid2)
def test_answerTransactionInformation(self): def test_answerTransactionInformation(self):
conn = self.getConnection() conn = self.getConnection()
...@@ -195,24 +191,30 @@ class StorageAnswerHandlerTests(NeoUnitTestBase): ...@@ -195,24 +191,30 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
user = 'USER' user = 'USER'
desc = 'DESC' desc = 'DESC'
ext = 'EXT' ext = 'EXT'
packed = False
oid_list = [self.getOID(0), self.getOID(1)] oid_list = [self.getOID(0), self.getOID(1)]
self.app.local_var.txn_info = None
self.handler.answerTransactionInformation(conn, tid, user, desc, ext, self.handler.answerTransactionInformation(conn, tid, user, desc, ext,
False, oid_list) packed, oid_list)
txn_info = self.app.local_var.txn_info self._checkHandlerData(({
self.assertTrue(isinstance(txn_info, dict)) 'time': TimeStamp(tid).timeTime(),
self.assertEqual(txn_info['user_name'], user) 'user_name': user,
self.assertEqual(txn_info['description'], desc) 'description': desc,
self.assertEqual(txn_info['id'], tid) 'id': tid,
self.assertEqual(txn_info['oids'], oid_list) 'oids': oid_list,
'packed': packed,
}, ext))
def test_answerObjectHistory(self): def test_answerObjectHistory(self):
conn = self.getConnection() conn = self.getConnection()
oid = self.getOID(0) oid = self.getOID(0)
history_list = [] history_list = [self.getNextTID(), self.getNextTID()]
self.app.local_var.history = None history_set = set()
self.handler.answerObjectHistory(conn, oid, history_list) app = Mock({
self.assertEqual(self.app.local_var.history, (oid, history_list)) 'getHandlerData': history_set,
})
handler = StorageAnswersHandler(app)
handler.answerObjectHistory(conn, oid, history_list)
self.assertEqual(history_set, set(history_list))
def test_oidNotFound(self): def test_oidNotFound(self):
conn = self.getConnection() conn = self.getConnection()
...@@ -235,10 +237,14 @@ class StorageAnswerHandlerTests(NeoUnitTestBase): ...@@ -235,10 +237,14 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
tid2 = self.getNextTID(tid1) tid2 = self.getNextTID(tid1)
tid_list = [tid1, tid2] tid_list = [tid1, tid2]
conn = self.getFakeConnection(uuid=uuid) conn = self.getFakeConnection(uuid=uuid)
self.app.local_var.node_tids = {} tid_set = set()
self.handler.answerTIDs(conn, tid_list) app = Mock({
self.assertTrue(uuid in self.app.local_var.node_tids) 'getHandlerData': tid_set,
self.assertEqual(self.app.local_var.node_tids[uuid], tid_list) })
handler = StorageAnswersHandler(app)
handler.answerTIDs(conn, tid_list)
self.assertEqual(tid_set, set(tid_list))
def test_answerObjectUndoSerial(self): def test_answerObjectUndoSerial(self):
uuid = self.getNewUUID() uuid = self.getNewUUID()
...@@ -249,12 +255,14 @@ class StorageAnswerHandlerTests(NeoUnitTestBase): ...@@ -249,12 +255,14 @@ class StorageAnswerHandlerTests(NeoUnitTestBase):
tid1 = self.getNextTID() tid1 = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
tid3 = self.getNextTID() tid3 = self.getNextTID()
self.app.local_var.undo_object_tid_dict = undo_dict = { undo_dict = {}
oid1: [tid0, tid1], app = Mock({
} 'getHandlerData': undo_dict,
self.handler.answerObjectUndoSerial(conn, {
oid2: [tid2, tid3],
}) })
handler = StorageAnswersHandler(app)
handler.answerObjectUndoSerial(conn, {oid1: [tid0, tid1]})
self.assertEqual(undo_dict, {oid1: [tid0, tid1]})
handler.answerObjectUndoSerial(conn, {oid2: [tid2, tid3]})
self.assertEqual(undo_dict, { self.assertEqual(undo_dict, {
oid1: [tid0, tid1], oid1: [tid0, tid1],
oid2: [tid2, tid3], oid2: [tid2, tid3],
......
...@@ -22,9 +22,7 @@ from ZODB.tests.StorageTestBase import StorageTestBase ...@@ -22,9 +22,7 @@ from ZODB.tests.StorageTestBase import StorageTestBase
from neo.tests.zodb import ZODBTestCase from neo.tests.zodb import ZODBTestCase
class BasicTests(ZODBTestCase, StorageTestBase, BasicStorage): class BasicTests(ZODBTestCase, StorageTestBase, BasicStorage):
pass
def check_tid_ordering_w_commit(self):
self.fail("Test disabled")
if __name__ == "__main__": if __name__ == "__main__":
suite = unittest.makeSuite(BasicTests, 'check') suite = unittest.makeSuite(BasicTests, 'check')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment