Commit 8aef2569 authored by Julien Muchembled's avatar Julien Muchembled

client: optimize/refactor thread/transaction containers

parent 11c428e0
......@@ -120,13 +120,10 @@ class Application(object):
registerLiveDebugger(on_log=self.log)
def getHandlerData(self):
return self._thread_container.get()['answer']
return self._thread_container.answer
def setHandlerData(self, data):
self._thread_container.get()['answer'] = data
def _getThreadQueue(self):
return self._thread_container.get()['queue']
self._thread_container.answer = data
def log(self):
self.em.log()
......@@ -202,7 +199,7 @@ class Application(object):
def _ask(self, conn, packet, handler=None, **kw):
self.setHandlerData(None)
queue = self._getThreadQueue()
queue = self._thread_container.queue
msg_id = conn.ask(packet, queue=queue, **kw)
get = queue.get
_handlePacket = self._handlePacket
......@@ -454,12 +451,8 @@ class Application(object):
def tpc_begin(self, transaction, tid=None, status=' '):
"""Begin a new transaction."""
txn_container = self._txn_container
# First get a transaction, only one is allowed at a time
if txn_container.get(transaction) is not None:
# We already begin the same transaction
raise StorageTransactionError('Duplicate tpc_begin calls')
txn_context = txn_container.new(transaction)
txn_context = self._txn_container.new(transaction)
# use the given TID or request a new one to the master
answer_ttid = self._askPrimary(Packets.AskBeginTransaction(tid))
if answer_ttid is None:
......@@ -469,11 +462,8 @@ class Application(object):
def store(self, oid, serial, data, version, transaction):
"""Store object."""
txn_context = self._txn_container.get(transaction)
if txn_context is None:
raise StorageTransactionError(self, transaction)
logging.debug('storing oid %s serial %s', dump(oid), dump(serial))
self._store(txn_context, oid, serial, data)
self._store(self._txn_container.get(transaction), oid, serial, data)
def _store(self, txn_context, oid, serial, data, data_serial=None,
unlock=False):
......@@ -673,9 +663,6 @@ class Application(object):
def tpc_vote(self, transaction, tryToResolveConflict):
"""Store current transaction."""
txn_context = self._txn_container.get(transaction)
if txn_context is None or transaction is not txn_context['txn']:
raise StorageTransactionError(self, transaction)
result = self.waitStoreResponses(txn_context, tryToResolveConflict)
ttid = txn_context['ttid']
......@@ -711,11 +698,9 @@ class Application(object):
def tpc_abort(self, transaction):
"""Abort current transaction."""
txn_container = self._txn_container
txn_context = txn_container.get(transaction)
txn_context = self._txn_container.pop(transaction)
if txn_context is None:
return
ttid = txn_context['ttid']
p = Packets.AbortTransaction(ttid)
getConnForNode = self.cp.getConnForNode
......@@ -730,38 +715,30 @@ class Application(object):
logging.exception('Exception in tpc_abort while notifying'
'storage node %r of abortion, ignoring.', conn)
self._getMasterConnection().notify(p)
queue = txn_context['queue']
# We don't need to flush queue, as it won't be reused by future
# transactions (deleted on next line & indexed by transaction object
# instance).
self.dispatcher.forget_queue(queue, flush_queue=False)
txn_container.delete(transaction)
self.dispatcher.forget_queue(txn_context['queue'], flush_queue=False)
def tpc_finish(self, transaction, tryToResolveConflict, f=None):
"""Finish current transaction."""
txn_container = self._txn_container
txn_context = txn_container.get(transaction)
if txn_context is None:
raise StorageTransactionError('tpc_finish called for wrong '
'transaction')
if not txn_context['txn_voted']:
if not txn_container.get(transaction)['txn_voted']:
self.tpc_vote(transaction, tryToResolveConflict)
self._load_lock_acquire()
try:
# Call finish on master
txn_context = txn_container.pop(transaction)
cache_dict = txn_context['cache_dict']
tid = self._askPrimary(Packets.AskFinishTransaction(
txn_context['ttid'], cache_dict),
cache_dict=cache_dict, callback=f)
txn_container.delete(transaction)
return tid
finally:
self._load_lock_release()
def undo(self, undone_tid, txn, tryToResolveConflict):
txn_context = self._txn_container.get(txn)
if txn_context is None:
raise StorageTransactionError(self, undone_tid)
txn_info, txn_ext = self._getTransactionInformation(undone_tid)
txn_oid_list = txn_info['oids']
......@@ -782,7 +759,7 @@ class Application(object):
getCellList = pt.getCellList
getCellSortKey = self.cp.getCellSortKey
getConnForCell = self.cp.getConnForCell
queue = self._getThreadQueue()
queue = self._thread_container.queue
ttid = txn_context['ttid']
undo_object_tid_dict = {}
snapshot_tid = p64(u64(self.last_tid) + 1)
......@@ -866,7 +843,7 @@ class Application(object):
# Each storage node will return TIDs only for UP_TO_DATE state and
# FEEDING state cells
pt = self.getPartitionTable()
queue = self._getThreadQueue()
queue = self._thread_container.queue
packet = Packets.AskTIDs(first, last, INVALID_PARTITION)
tid_set = set()
for storage_node in pt.getNodeSet(True):
......@@ -1015,10 +992,8 @@ class Application(object):
return self.load(oid)[1]
def checkCurrentSerialInTransaction(self, oid, serial, transaction):
txn_context = self._txn_container.get(transaction)
if txn_context is None:
raise StorageTransactionError(self, transaction)
self._checkCurrentSerialInTransaction(txn_context, oid, serial)
self._checkCurrentSerialInTransaction(
self._txn_container.get(transaction), oid, serial)
def _checkCurrentSerialInTransaction(self, txn_context, oid, serial):
ttid = txn_context['ttid']
......
......@@ -14,9 +14,10 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from thread import get_ident
import threading
from neo.lib.locking import Lock, Empty
from collections import deque
from ZODB.POSException import StorageTransactionError
class SimpleQueue(object):
"""
......@@ -63,54 +64,29 @@ class SimpleQueue(object):
def empty(self):
return not self._queue
class ContainerBase(object):
__slots__ = ('_context_dict', )
class ThreadContainer(threading.local):
def __init__(self):
self._context_dict = {}
self.queue = SimpleQueue()
self.answer = None
def _getID(self, *args, **kw):
raise NotImplementedError
class TransactionContainer(dict):
def _new(self, *args, **kw):
raise NotImplementedError
def pop(self, txn):
return dict.pop(self, id(txn), None)
def delete(self, *args, **kw):
del self._context_dict[self._getID(*args, **kw)]
def get(self, *args, **kw):
return self._context_dict.get(self._getID(*args, **kw))
def new(self, *args, **kw):
result = self._context_dict[self._getID(*args, **kw)] = self._new(
*args, **kw)
return result
class ThreadContainer(ContainerBase):
def _getID(self):
return get_ident()
def _new(self):
return {
'queue': SimpleQueue(),
'answer': None,
}
def get(self):
"""
Implicitely create a thread context if it doesn't exist.
"""
def get(self, txn):
try:
return self._context_dict[self._getID()]
return self[id(txn)]
except KeyError:
return self.new()
class TransactionContainer(ContainerBase):
def _getID(self, txn):
return id(txn)
def _new(self, txn):
return {
raise StorageTransactionError("unknown transaction %r" % txn)
def new(self, txn):
key = id(txn)
if key in self:
raise StorageTransactionError("commit of transaction %r"
" already started" % txn)
context = self[key] = {
'queue': SimpleQueue(),
'txn': txn,
'ttid': None,
......@@ -126,4 +102,4 @@ class TransactionContainer(ContainerBase):
'txn_voted': False,
'involved_nodes': set(),
}
return context
......@@ -245,7 +245,7 @@ class ClientApplicationTests(NeoUnitTestBase):
tid = self.makeTID()
txn = Mock()
# first, tid is supplied
self.assertTrue(app._txn_container.get(txn) is None)
self.assertRaises(StorageTransactionError, app._txn_container.get, txn)
packet = Packets.AnswerBeginTransaction(tid=tid)
packet.setId(0)
app.master_conn = Mock({
......@@ -419,7 +419,7 @@ class ClientApplicationTests(NeoUnitTestBase):
self.checkNotifyPacket(conn1, Packets.AbortTransaction)
self.checkNotifyPacket(conn2, Packets.AbortTransaction)
self.checkNotifyPacket(app.master_conn, Packets.AbortTransaction)
self.assertEqual(app._txn_container.get(txn), None)
self.assertRaises(StorageTransactionError, app._txn_container.get, txn)
def test_tpc_abort3(self):
""" check that abort is sent to all nodes involved in the transaction """
......@@ -503,7 +503,7 @@ class ClientApplicationTests(NeoUnitTestBase):
app.tpc_finish(txn, None)
self.checkAskFinishTransaction(app.master_conn)
#self.checkDispatcherRegisterCalled(app, app.master_conn)
self.assertEqual(app._txn_container.get(txn), None)
self.assertRaises(StorageTransactionError, app._txn_container.get, txn)
def test_undo1(self):
# invalid transaction
......@@ -843,16 +843,16 @@ class ClientApplicationTests(NeoUnitTestBase):
""" Thread context properties must not be visible accross instances
while remaining in the same thread """
app1 = self.getApp()
app1_local = app1._thread_container.get()
app1_local = app1._thread_container
app2 = self.getApp()
app2_local = app2._thread_container.get()
app2_local = app2._thread_container
property_id = 'thread_context_test'
value = 'value'
self.assertRaises(KeyError, app1_local.__getitem__, property_id)
self.assertRaises(KeyError, app2_local.__getitem__, property_id)
app1_local[property_id] = value
self.assertEqual(app1_local[property_id], value)
self.assertRaises(KeyError, app2_local.__getitem__, property_id)
self.assertFalse(hasattr(app1_local, property_id))
self.assertFalse(hasattr(app2_local, property_id))
setattr(app1_local, property_id, value)
self.assertEqual(getattr(app1_local, property_id), value)
self.assertFalse(hasattr(app2_local, property_id))
def test_pack(self):
app = self.getApp()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment