Commit e97427e9 authored by Vincent Pelletier's avatar Vincent Pelletier

Transactions must not be bound to the thread which created them.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2643 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent c5a1efea
This diff is collapsed.
#
# Copyright (C) 2011 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from thread import get_ident
from neo.lib.locking import Queue
class ContainerBase(object):
def __init__(self):
self._context_dict = {}
def _getID(self, *args, **kw):
raise NotImplementedError
def _new(self, *args, **kw):
raise NotImplementedError
def delete(self, *args, **kw):
del self._context_dict[self._getID(*args, **kw)]
def get(self, *args, **kw):
return self._context_dict.get(self._getID(*args, **kw))
def new(self, *args, **kw):
result = self._context_dict[self._getID(*args, **kw)] = self._new(
*args, **kw)
return result
class ThreadContainer(ContainerBase):
def _getID(self):
return get_ident()
def _new(self):
return {
'queue': Queue(0),
'answer': None,
}
def get(self):
"""
Implicitely create a thread context if it doesn't exist.
"""
my_id = self._getID()
try:
result = self._context_dict[my_id]
except KeyError:
result = self._context_dict[my_id] = self._new()
return result
class TransactionContainer(ContainerBase):
def _getID(self, txn):
return id(txn)
def _new(self, txn):
return {
'queue': Queue(0),
'txn': txn,
'ttid': None,
'data_dict': {},
'data_list': [],
'object_base_serial_dict': {},
'object_serial_dict': {},
'object_stored_counter_dict': {},
'conflict_serial_dict': {},
'resolved_conflict_serial_dict': {},
'txn_voted': False,
'involved_nodes': set(),
}
...@@ -156,21 +156,19 @@ class PrimaryNotificationsHandler(BaseHandler): ...@@ -156,21 +156,19 @@ class PrimaryNotificationsHandler(BaseHandler):
class PrimaryAnswersHandler(AnswerBaseHandler): class PrimaryAnswersHandler(AnswerBaseHandler):
""" Handle that process expected packets from the primary master """ """ Handle that process expected packets from the primary master """
def answerBeginTransaction(self, conn, tid): def answerBeginTransaction(self, conn, ttid):
self.app.setTID(tid) self.app.setHandlerData(ttid)
def answerNewOIDs(self, conn, oid_list): def answerNewOIDs(self, conn, oid_list):
self.app.new_oid_list = oid_list self.app.new_oid_list = oid_list
def answerTransactionFinished(self, conn, ttid, tid): def answerTransactionFinished(self, conn, _, tid):
if ttid != self.app.getTID(): self.app.setHandlerData(tid)
raise NEOStorageError('Wrong TID, transaction not started')
self.app.setTID(tid)
def answerPack(self, conn, status): def answerPack(self, conn, status):
if not status: if not status:
raise NEOStorageError('Already packing') raise NEOStorageError('Already packing')
def answerLastTransaction(self, conn, ltid): def answerLastTransaction(self, conn, ltid):
self.app.local_var.last_transaction = ltid self.app.setHandlerData(ltid)
...@@ -68,23 +68,25 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -68,23 +68,25 @@ class StorageAnswersHandler(AnswerBaseHandler):
if data_serial is not None: if data_serial is not None:
raise NEOStorageError, 'Storage should never send non-None ' \ raise NEOStorageError, 'Storage should never send non-None ' \
'data_serial to clients, got %s' % (dump(data_serial), ) 'data_serial to clients, got %s' % (dump(data_serial), )
self.app.local_var.asked_object = (oid, start_serial, end_serial, self.app.setHandlerData((oid, start_serial, end_serial,
compression, checksum, data) compression, checksum, data))
def answerStoreObject(self, conn, conflicting, oid, serial): def answerStoreObject(self, conn, conflicting, oid, serial):
local_var = self.app.local_var txn_context = self.app.getHandlerData()
object_stored_counter_dict = local_var.object_stored_counter_dict[oid] object_stored_counter_dict = txn_context[
'object_stored_counter_dict'][oid]
if conflicting: if conflicting:
neo.lib.logging.info('%r report a conflict for %r with %r', conn, neo.lib.logging.info('%r report a conflict for %r with %r', conn,
dump(oid), dump(serial)) dump(oid), dump(serial))
conflict_serial_dict = local_var.conflict_serial_dict conflict_serial_dict = txn_context['conflict_serial_dict']
if serial in object_stored_counter_dict: if serial in object_stored_counter_dict:
raise NEOStorageError, 'A storage accepted object for ' \ raise NEOStorageError, 'A storage accepted object for ' \
'serial %s but another reports a conflict for it.' % ( 'serial %s but another reports a conflict for it.' % (
dump(serial), ) dump(serial), )
# If this conflict is not already resolved, mark it for # If this conflict is not already resolved, mark it for
# resolution. # resolution.
if serial not in local_var.resolved_conflict_serial_dict.get(oid, ()): if serial not in txn_context[
'resolved_conflict_serial_dict'].get(oid, ()):
conflict_serial_dict.setdefault(oid, set()).add(serial) conflict_serial_dict.setdefault(oid, set()).add(serial)
else: else:
object_stored_counter_dict[serial] = \ object_stored_counter_dict[serial] = \
...@@ -92,31 +94,29 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -92,31 +94,29 @@ class StorageAnswersHandler(AnswerBaseHandler):
answerCheckCurrentSerial = answerStoreObject answerCheckCurrentSerial = answerStoreObject
def answerStoreTransaction(self, conn, tid): def answerStoreTransaction(self, conn, _):
if tid != self.app.getTID(): pass
raise NEOStorageError('Wrong TID, transaction not started')
def answerTIDsFrom(self, conn, tid_list): def answerTIDsFrom(self, conn, tid_list):
neo.lib.logging.debug('Get %d TIDs from %r', len(tid_list), conn) neo.lib.logging.debug('Get %d TIDs from %r', len(tid_list), conn)
assert not self.app.local_var.tids_from.intersection(set(tid_list)) tids_from = self.app.getHandlerData()
self.app.local_var.tids_from.update(tid_list) assert not tids_from.intersection(set(tid_list))
tids_from.update(tid_list)
def answerTransactionInformation(self, conn, tid, def answerTransactionInformation(self, conn, tid,
user, desc, ext, packed, oid_list): user, desc, ext, packed, oid_list):
# transaction information are returned as a dict self.app.setHandlerData(({
info = {} 'time': TimeStamp(tid).timeTime(),
info['time'] = TimeStamp(tid).timeTime() 'user_name': user,
info['user_name'] = user 'description': desc,
info['description'] = desc 'id': tid,
info['id'] = tid 'oids': oid_list,
info['oids'] = oid_list 'packed': packed,
info['packed'] = packed }, ext))
self.app.local_var.txn_ext = ext
self.app.local_var.txn_info = info def answerObjectHistory(self, conn, _, history_list):
def answerObjectHistory(self, conn, oid, history_list):
# history_list is a list of tuple (serial, size) # history_list is a list of tuple (serial, size)
self.app.local_var.history = oid, history_list self.app.getHandlerData().update(history_list)
def oidNotFound(self, conn, message): def oidNotFound(self, conn, message):
# This can happen either when : # This can happen either when :
...@@ -132,10 +132,10 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -132,10 +132,10 @@ class StorageAnswersHandler(AnswerBaseHandler):
raise NEOStorageNotFoundError(message) raise NEOStorageNotFoundError(message)
def answerTIDs(self, conn, tid_list): def answerTIDs(self, conn, tid_list):
self.app.local_var.node_tids[conn.getUUID()] = tid_list self.app.getHandlerData().update(tid_list)
def answerObjectUndoSerial(self, conn, object_tid_dict): def answerObjectUndoSerial(self, conn, object_tid_dict):
self.app.local_var.undo_object_tid_dict.update(object_tid_dict) self.app.getHandlerData().update(object_tid_dict)
def answerHasLock(self, conn, oid, status): def answerHasLock(self, conn, oid, status):
if status == LockState.GRANTED_TO_OTHER: if status == LockState.GRANTED_TO_OTHER:
......
...@@ -66,9 +66,7 @@ class ConnectionPool(object): ...@@ -66,9 +66,7 @@ class ConnectionPool(object):
p = Packets.RequestIdentification(NodeTypes.CLIENT, p = Packets.RequestIdentification(NodeTypes.CLIENT,
app.uuid, None, app.name) app.uuid, None, app.name)
try: try:
msg_id = conn.ask(p, queue=app.local_var.queue) app._ask(conn, p, handler=app.storage_bootstrap_handler)
app._waitMessage(conn, msg_id,
handler=app.storage_bootstrap_handler)
except ConnectionClosed: except ConnectionClosed:
neo.lib.logging.error('Connection to %r failed', node) neo.lib.logging.error('Connection to %r failed', node)
self.notifyFailure(node) self.notifyFailure(node)
......
This diff is collapsed.
...@@ -235,7 +235,7 @@ class MasterAnswersHandlerTests(MasterHandlerTests): ...@@ -235,7 +235,7 @@ class MasterAnswersHandlerTests(MasterHandlerTests):
tid = self.getNextTID() tid = self.getNextTID()
conn = self.getConnection() conn = self.getConnection()
self.handler.answerBeginTransaction(conn, tid) self.handler.answerBeginTransaction(conn, tid)
calls = self.app.mockGetNamedCalls('setTID') calls = self.app.mockGetNamedCalls('setHandlerData')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid) calls[0].checkArgs(tid)
...@@ -247,18 +247,12 @@ class MasterAnswersHandlerTests(MasterHandlerTests): ...@@ -247,18 +247,12 @@ class MasterAnswersHandlerTests(MasterHandlerTests):
def test_answerTransactionFinished(self): def test_answerTransactionFinished(self):
conn = self.getConnection() conn = self.getConnection()
ttid1 = self.getNextTID() ttid2 = self.getNextTID()
ttid2 = self.getNextTID(ttid1) tid2 = self.getNextTID()
tid2 = self.getNextTID(ttid2) self.handler.answerTransactionFinished(conn, ttid2, tid2)
# wrong TID calls = self.app.mockGetNamedCalls('setHandlerData')
self.app = Mock({'getTID': ttid1}) self.assertEqual(len(calls), 1)
self.assertRaises(NEOStorageError, calls[0].checkArgs(tid2)
self.handler.answerTransactionFinished,
conn, ttid2, tid2)
# matching TID
app = Mock({'getTID': ttid2})
handler = PrimaryAnswersHandler(app=app)
handler.answerTransactionFinished(conn, ttid2, tid2)
def test_answerPack(self): def test_answerPack(self):
self.assertRaises(NEOStorageError, self.handler.answerPack, None, False) self.assertRaises(NEOStorageError, self.handler.answerPack, None, False)
......
This diff is collapsed.
...@@ -22,9 +22,7 @@ from ZODB.tests.StorageTestBase import StorageTestBase ...@@ -22,9 +22,7 @@ from ZODB.tests.StorageTestBase import StorageTestBase
from neo.tests.zodb import ZODBTestCase from neo.tests.zodb import ZODBTestCase
class BasicTests(ZODBTestCase, StorageTestBase, BasicStorage): class BasicTests(ZODBTestCase, StorageTestBase, BasicStorage):
pass
def check_tid_ordering_w_commit(self):
self.fail("Test disabled")
if __name__ == "__main__": if __name__ == "__main__":
suite = unittest.makeSuite(BasicTests, 'check') suite = unittest.makeSuite(BasicTests, 'check')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment