Commit 1e0c5efc authored by Julien Muchembled's avatar Julien Muchembled

Master must not die if client sends an invalid ttid

parent 8aef2569
......@@ -733,6 +733,7 @@ class Application(object):
tid = self._askPrimary(Packets.AskFinishTransaction(
txn_context['ttid'], cache_dict),
cache_dict=cache_dict, callback=f)
assert tid
return tid
finally:
self._load_lock_release()
......
......@@ -16,6 +16,7 @@
from neo.lib.handler import EventHandler
from neo.lib.protocol import ProtocolError, Packets
from ZODB.POSException import StorageError
class BaseHandler(EventHandler):
"""Base class for client-side EventHandler implementations."""
......@@ -59,3 +60,5 @@ class AnswerBaseHandler(EventHandler):
packetReceived = unexpectedInAnswerHandler
peerBroken = unexpectedInAnswerHandler
def protocolError(self, conn, message):
raise StorageError("protocol error: %s" % message)
......@@ -162,7 +162,7 @@ class EventHandler(object):
# Error packet handlers.
def error(self, conn, code, message):
def error(self, conn, code, message, **kw):
try:
getattr(self, Errors[code])(conn, message)
except (AttributeError, ValueError):
......
......@@ -17,7 +17,7 @@
from time import time
from struct import pack, unpack
from neo.lib import logging
from neo.lib.protocol import uuid_str, ZERO_TID
from neo.lib.protocol import ProtocolError, uuid_str, ZERO_TID
from neo.lib.util import dump, u64, addTID, tidFromTime
class DelayedError(Exception):
......@@ -295,7 +295,10 @@ class TransactionManager(object):
Prepare a transaction to be finished
"""
# XXX: not efficient but the list should be often small
txn = self._ttid_dict[ttid]
try:
txn = self._ttid_dict[ttid]
except KeyError:
raise ProtocolError("unknown ttid %s" % dump(ttid))
node = txn.getNode()
for _, tid in self._queue:
if ttid == tid:
......
......@@ -27,7 +27,7 @@ from neo.lib.connection import MTClientConnection
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
ZERO_TID
from . import ClientApplication, NEOCluster, NEOThreadedTest, Patch
from neo.lib.util import makeChecksum
from neo.lib.util import add64, makeChecksum
from neo.client.pool import CELL_CONNECTED, CELL_GOOD
class PCounter(Persistent):
......@@ -649,6 +649,21 @@ class Test(NEOThreadedTest):
finally:
cluster.stop()
def testInvalidTTID(self):
cluster = NEOCluster()
try:
cluster.start()
client = cluster.client
client.setPoll(1)
txn = transaction.Transaction()
client.tpc_begin(txn)
txn_context = client._txn_container.get(txn)
txn_context['ttid'] = add64(txn_context['ttid'], 1)
self.assertRaises(POSException.StorageError,
client.tpc_finish, txn, None)
finally:
cluster.stop()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment