Commit e97427e9 authored by Vincent Pelletier's avatar Vincent Pelletier

Transactions must not be bound to the thread which created them.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2643 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent c5a1efea
This diff is collapsed.
#
# Copyright (C) 2011 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from thread import get_ident
from neo.lib.locking import Queue
class ContainerBase(object):
def __init__(self):
self._context_dict = {}
def _getID(self, *args, **kw):
raise NotImplementedError
def _new(self, *args, **kw):
raise NotImplementedError
def delete(self, *args, **kw):
del self._context_dict[self._getID(*args, **kw)]
def get(self, *args, **kw):
return self._context_dict.get(self._getID(*args, **kw))
def new(self, *args, **kw):
result = self._context_dict[self._getID(*args, **kw)] = self._new(
*args, **kw)
return result
class ThreadContainer(ContainerBase):
def _getID(self):
return get_ident()
def _new(self):
return {
'queue': Queue(0),
'answer': None,
}
def get(self):
"""
Implicitely create a thread context if it doesn't exist.
"""
my_id = self._getID()
try:
result = self._context_dict[my_id]
except KeyError:
result = self._context_dict[my_id] = self._new()
return result
class TransactionContainer(ContainerBase):
def _getID(self, txn):
return id(txn)
def _new(self, txn):
return {
'queue': Queue(0),
'txn': txn,
'ttid': None,
'data_dict': {},
'data_list': [],
'object_base_serial_dict': {},
'object_serial_dict': {},
'object_stored_counter_dict': {},
'conflict_serial_dict': {},
'resolved_conflict_serial_dict': {},
'txn_voted': False,
'involved_nodes': set(),
}
......@@ -156,21 +156,19 @@ class PrimaryNotificationsHandler(BaseHandler):
class PrimaryAnswersHandler(AnswerBaseHandler):
""" Handle that process expected packets from the primary master """
def answerBeginTransaction(self, conn, tid):
self.app.setTID(tid)
def answerBeginTransaction(self, conn, ttid):
self.app.setHandlerData(ttid)
def answerNewOIDs(self, conn, oid_list):
self.app.new_oid_list = oid_list
def answerTransactionFinished(self, conn, ttid, tid):
if ttid != self.app.getTID():
raise NEOStorageError('Wrong TID, transaction not started')
self.app.setTID(tid)
def answerTransactionFinished(self, conn, _, tid):
self.app.setHandlerData(tid)
def answerPack(self, conn, status):
if not status:
raise NEOStorageError('Already packing')
def answerLastTransaction(self, conn, ltid):
self.app.local_var.last_transaction = ltid
self.app.setHandlerData(ltid)
......@@ -68,23 +68,25 @@ class StorageAnswersHandler(AnswerBaseHandler):
if data_serial is not None:
raise NEOStorageError, 'Storage should never send non-None ' \
'data_serial to clients, got %s' % (dump(data_serial), )
self.app.local_var.asked_object = (oid, start_serial, end_serial,
compression, checksum, data)
self.app.setHandlerData((oid, start_serial, end_serial,
compression, checksum, data))
def answerStoreObject(self, conn, conflicting, oid, serial):
local_var = self.app.local_var
object_stored_counter_dict = local_var.object_stored_counter_dict[oid]
txn_context = self.app.getHandlerData()
object_stored_counter_dict = txn_context[
'object_stored_counter_dict'][oid]
if conflicting:
neo.lib.logging.info('%r report a conflict for %r with %r', conn,
dump(oid), dump(serial))
conflict_serial_dict = local_var.conflict_serial_dict
conflict_serial_dict = txn_context['conflict_serial_dict']
if serial in object_stored_counter_dict:
raise NEOStorageError, 'A storage accepted object for ' \
'serial %s but another reports a conflict for it.' % (
dump(serial), )
# If this conflict is not already resolved, mark it for
# resolution.
if serial not in local_var.resolved_conflict_serial_dict.get(oid, ()):
if serial not in txn_context[
'resolved_conflict_serial_dict'].get(oid, ()):
conflict_serial_dict.setdefault(oid, set()).add(serial)
else:
object_stored_counter_dict[serial] = \
......@@ -92,31 +94,29 @@ class StorageAnswersHandler(AnswerBaseHandler):
answerCheckCurrentSerial = answerStoreObject
def answerStoreTransaction(self, conn, tid):
if tid != self.app.getTID():
raise NEOStorageError('Wrong TID, transaction not started')
def answerStoreTransaction(self, conn, _):
pass
def answerTIDsFrom(self, conn, tid_list):
neo.lib.logging.debug('Get %d TIDs from %r', len(tid_list), conn)
assert not self.app.local_var.tids_from.intersection(set(tid_list))
self.app.local_var.tids_from.update(tid_list)
tids_from = self.app.getHandlerData()
assert not tids_from.intersection(set(tid_list))
tids_from.update(tid_list)
def answerTransactionInformation(self, conn, tid,
user, desc, ext, packed, oid_list):
# transaction information are returned as a dict
info = {}
info['time'] = TimeStamp(tid).timeTime()
info['user_name'] = user
info['description'] = desc
info['id'] = tid
info['oids'] = oid_list
info['packed'] = packed
self.app.local_var.txn_ext = ext
self.app.local_var.txn_info = info
def answerObjectHistory(self, conn, oid, history_list):
self.app.setHandlerData(({
'time': TimeStamp(tid).timeTime(),
'user_name': user,
'description': desc,
'id': tid,
'oids': oid_list,
'packed': packed,
}, ext))
def answerObjectHistory(self, conn, _, history_list):
# history_list is a list of tuple (serial, size)
self.app.local_var.history = oid, history_list
self.app.getHandlerData().update(history_list)
def oidNotFound(self, conn, message):
# This can happen either when :
......@@ -132,10 +132,10 @@ class StorageAnswersHandler(AnswerBaseHandler):
raise NEOStorageNotFoundError(message)
def answerTIDs(self, conn, tid_list):
self.app.local_var.node_tids[conn.getUUID()] = tid_list
self.app.getHandlerData().update(tid_list)
def answerObjectUndoSerial(self, conn, object_tid_dict):
self.app.local_var.undo_object_tid_dict.update(object_tid_dict)
self.app.getHandlerData().update(object_tid_dict)
def answerHasLock(self, conn, oid, status):
if status == LockState.GRANTED_TO_OTHER:
......
......@@ -66,9 +66,7 @@ class ConnectionPool(object):
p = Packets.RequestIdentification(NodeTypes.CLIENT,
app.uuid, None, app.name)
try:
msg_id = conn.ask(p, queue=app.local_var.queue)
app._waitMessage(conn, msg_id,
handler=app.storage_bootstrap_handler)
app._ask(conn, p, handler=app.storage_bootstrap_handler)
except ConnectionClosed:
neo.lib.logging.error('Connection to %r failed', node)
self.notifyFailure(node)
......
This diff is collapsed.
......@@ -235,7 +235,7 @@ class MasterAnswersHandlerTests(MasterHandlerTests):
tid = self.getNextTID()
conn = self.getConnection()
self.handler.answerBeginTransaction(conn, tid)
calls = self.app.mockGetNamedCalls('setTID')
calls = self.app.mockGetNamedCalls('setHandlerData')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid)
......@@ -247,18 +247,12 @@ class MasterAnswersHandlerTests(MasterHandlerTests):
def test_answerTransactionFinished(self):
conn = self.getConnection()
ttid1 = self.getNextTID()
ttid2 = self.getNextTID(ttid1)
tid2 = self.getNextTID(ttid2)
# wrong TID
self.app = Mock({'getTID': ttid1})
self.assertRaises(NEOStorageError,
self.handler.answerTransactionFinished,
conn, ttid2, tid2)
# matching TID
app = Mock({'getTID': ttid2})
handler = PrimaryAnswersHandler(app=app)
handler.answerTransactionFinished(conn, ttid2, tid2)
ttid2 = self.getNextTID()
tid2 = self.getNextTID()
self.handler.answerTransactionFinished(conn, ttid2, tid2)
calls = self.app.mockGetNamedCalls('setHandlerData')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid2)
def test_answerPack(self):
self.assertRaises(NEOStorageError, self.handler.answerPack, None, False)
......
This diff is collapsed.
......@@ -22,9 +22,7 @@ from ZODB.tests.StorageTestBase import StorageTestBase
from neo.tests.zodb import ZODBTestCase
class BasicTests(ZODBTestCase, StorageTestBase, BasicStorage):
def check_tid_ordering_w_commit(self):
self.fail("Test disabled")
pass
if __name__ == "__main__":
suite = unittest.makeSuite(BasicTests, 'check')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment