Commit c4863a33 authored by Julien Muchembled's avatar Julien Muchembled

Simplify EventHandler by removing 'packet_dispatch_table'

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2685 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 7b76e88b
......@@ -26,7 +26,6 @@ class EventHandler(object):
def __init__(self, app):
self.app = app
self.packet_dispatch_table = self.__initPacketDispatchTable()
self.error_dispatch_table = self.__initErrorDispatchTable()
def __repr__(self):
......@@ -49,8 +48,8 @@ class EventHandler(object):
"""This is a helper method to handle various packet types."""
try:
try:
method = self.packet_dispatch_table[packet.getType()]
except KeyError:
method = getattr(self, packet.handler_method_name)
except AttributeError:
raise UnexpectedPacketError('no handler found')
args = packet.decode() or ()
conn.setPeerId(packet.getId())
......@@ -133,259 +132,12 @@ class EventHandler(object):
def notify(self, conn, message):
neo.lib.logging.info('notification from %r: %s', conn, message)
def requestIdentification(self, conn, node_type, uuid, address, name):
raise UnexpectedPacketError
def _requestIdentification(self, conn, protocol, node_type,
uuid, address, name):
self.requestIdentification(conn, node_type, uuid, address, name)
def acceptIdentification(self, conn, node_type,
uuid, num_partitions, num_replicas, your_uuid):
raise UnexpectedPacketError
def askPrimary(self, conn):
raise UnexpectedPacketError
def answerPrimary(self, conn, primary_uuid,
known_master_list):
raise UnexpectedPacketError
def announcePrimary(self, con):
raise UnexpectedPacketError
def reelectPrimary(self, conn):
raise UnexpectedPacketError
def notifyNodeInformation(self, conn, node_list):
raise UnexpectedPacketError
def askLastIDs(self, conn):
raise UnexpectedPacketError
def answerLastIDs(self, conn, loid, ltid, lptid):
raise UnexpectedPacketError
def askPartitionTable(self, conn):
raise UnexpectedPacketError
def answerPartitionTable(self, conn, ptid, row_list):
raise UnexpectedPacketError
def sendPartitionTable(self, conn, ptid, row_list):
raise UnexpectedPacketError
def notifyPartitionChanges(self, conn, ptid, cell_list):
raise UnexpectedPacketError
def startOperation(self, conn):
raise UnexpectedPacketError
def stopOperation(self, conn):
raise UnexpectedPacketError
def askUnfinishedTransactions(self, conn):
raise UnexpectedPacketError
def answerUnfinishedTransactions(self, conn, max_tid, ttid_list):
raise UnexpectedPacketError
def askObjectPresent(self, conn, oid, tid):
raise UnexpectedPacketError
def answerObjectPresent(self, conn, oid, tid):
raise UnexpectedPacketError
def deleteTransaction(self, conn, tid, oid_list):
raise UnexpectedPacketError
def commitTransaction(self, conn, tid):
raise UnexpectedPacketError
def askBeginTransaction(self, conn, tid):
raise UnexpectedPacketError
def answerBeginTransaction(self, conn, ttid):
raise UnexpectedPacketError
def askNewOIDs(self, conn, num_oids):
raise UnexpectedPacketError
def answerNewOIDs(self, conn, num_oids):
raise UnexpectedPacketError
def askFinishTransaction(self, conn, ttid, oid_list):
raise UnexpectedPacketError
def answerTransactionFinished(self, conn, ttid, tid):
raise UnexpectedPacketError
def askLockInformation(self, conn, ttid, tid, oid_list):
raise UnexpectedPacketError
def answerInformationLocked(self, conn, ttid):
raise UnexpectedPacketError
def invalidateObjects(self, conn, tid, oid_list):
raise UnexpectedPacketError
def notifyUnlockInformation(self, conn, ttid):
raise UnexpectedPacketError
def notifyTransactionFinished(self, conn, ttid, max_tid):
raise UnexpectedPacketError
def askStoreObject(self, conn, oid, serial,
compression, checksum, data, data_serial, ttid, unlock):
raise UnexpectedPacketError
def answerStoreObject(self, conn, conflicting, oid, serial):
raise UnexpectedPacketError
def abortTransaction(self, conn, ttid):
raise UnexpectedPacketError
def askStoreTransaction(self, conn, ttid, user, desc,
ext, oid_list):
raise UnexpectedPacketError
def answerStoreTransaction(self, conn, ttid):
raise UnexpectedPacketError
def askObject(self, conn, oid, serial, ttid):
raise UnexpectedPacketError
def answerObject(self, conn, oid, serial_start,
serial_end, compression, checksum, data, data_serial):
raise UnexpectedPacketError
def askTIDs(self, conn, first, last, partition):
raise UnexpectedPacketError
def answerTIDs(self, conn, tid_list):
raise UnexpectedPacketError
def askTIDsFrom(self, conn, min_tid, max_tid, length, partition):
raise UnexpectedPacketError
def answerTIDsFrom(self, conn, tid_list):
raise UnexpectedPacketError
def askTransactionInformation(self, conn, tid):
raise UnexpectedPacketError
def answerTransactionInformation(self, conn, tid,
user, desc, ext, packed, oid_list):
raise UnexpectedPacketError
def askObjectHistory(self, conn, oid, first, last):
raise UnexpectedPacketError
def answerObjectHistory(self, conn, oid, history_list):
raise UnexpectedPacketError
def askObjectHistoryFrom(self, conn, oid, min_serial, max_serial, length,
partition):
raise UnexpectedPacketError
def answerObjectHistoryFrom(self, conn, object_dict):
raise UnexpectedPacketError
def askPartitionList(self, conn, min_offset, max_offset, uuid):
raise UnexpectedPacketError
def answerPartitionList(self, conn, ptid, row_list):
raise UnexpectedPacketError
def askNodeList(self, conn, offset_list):
raise UnexpectedPacketError
def answerNodeList(self, conn, node_list):
raise UnexpectedPacketError
def setNodeState(self, conn, uuid, state, modify_partition_table):
raise UnexpectedPacketError
def addPendingNodes(self, conn, uuid_list):
raise UnexpectedPacketError
def askNodeInformation(self, conn):
raise UnexpectedPacketError
def answerNodeInformation(self, conn):
raise UnexpectedPacketError
def askClusterState(self, conn):
raise UnexpectedPacketError
def answerClusterState(self, conn, state):
raise UnexpectedPacketError
def setClusterState(self, conn, state):
raise UnexpectedPacketError
def notifyClusterInformation(self, conn, state):
raise UnexpectedPacketError
def notifyLastOID(self, conn, oid):
raise UnexpectedPacketError
def notifyReplicationDone(self, conn, offset):
raise UnexpectedPacketError
def askObjectUndoSerial(self, conn, ttid, ltid, undone_tid, oid_list):
raise UnexpectedPacketError
def answerObjectUndoSerial(self, conn, object_tid_dict):
raise UnexpectedPacketError
def askHasLock(self, conn, ttid, oid):
raise UnexpectedPacketError
def answerHasLock(self, conn, oid, status):
raise UnexpectedPacketError
def askBarrier(self, conn):
conn.answer(Packets.AnswerBarrier())
def answerBarrier(self, conn):
pass
def askPack(self, conn, tid):
raise UnexpectedPacketError
def answerPack(self, conn, status):
raise UnexpectedPacketError
def askCheckTIDRange(self, conn, min_tid, max_tid, length, partition):
raise UnexpectedPacketError
def answerCheckTIDRange(self, conn, min_tid, length, count, tid_checksum,
max_tid):
raise UnexpectedPacketError
def askCheckSerialRange(self, conn, min_oid, min_serial, max_tid, length,
partition):
raise UnexpectedPacketError
def answerCheckSerialRange(self, conn, min_oid, min_serial, length, count,
oid_checksum, max_oid, serial_checksum, max_serial):
raise UnexpectedPacketError
def notifyReady(self, conn):
raise UnexpectedPacketError
def askLastTransaction(self, conn):
raise UnexpectedPacketError
def answerLastTransaction(self, conn, tid):
raise UnexpectedPacketError
def askCheckCurrentSerial(self, conn, ttid, serial, oid):
raise UnexpectedPacketError
answerCheckCurrentSerial = answerStoreObject
# Error packet handlers.
def error(self, conn, code, message):
......@@ -426,96 +178,6 @@ class EventHandler(object):
# Fetch tables initialization
def __initPacketDispatchTable(self):
d = {}
d[Packets.Error] = self.error
d[Packets.Notify] = self.notify
d[Packets.RequestIdentification] = self._requestIdentification
d[Packets.AcceptIdentification] = self.acceptIdentification
d[Packets.AskPrimary] = self.askPrimary
d[Packets.AnswerPrimary] = self.answerPrimary
d[Packets.AnnouncePrimary] = self.announcePrimary
d[Packets.ReelectPrimary] = self.reelectPrimary
d[Packets.NotifyNodeInformation] = self.notifyNodeInformation
d[Packets.AskLastIDs] = self.askLastIDs
d[Packets.AnswerLastIDs] = self.answerLastIDs
d[Packets.AskPartitionTable] = self.askPartitionTable
d[Packets.AnswerPartitionTable] = self.answerPartitionTable
d[Packets.SendPartitionTable] = self.sendPartitionTable
d[Packets.NotifyPartitionChanges] = self.notifyPartitionChanges
d[Packets.StartOperation] = self.startOperation
d[Packets.StopOperation] = self.stopOperation
d[Packets.AskUnfinishedTransactions] = self.askUnfinishedTransactions
d[Packets.AnswerUnfinishedTransactions] = \
self.answerUnfinishedTransactions
d[Packets.AskObjectPresent] = self.askObjectPresent
d[Packets.AnswerObjectPresent] = self.answerObjectPresent
d[Packets.DeleteTransaction] = self.deleteTransaction
d[Packets.CommitTransaction] = self.commitTransaction
d[Packets.AskBeginTransaction] = self.askBeginTransaction
d[Packets.AnswerBeginTransaction] = self.answerBeginTransaction
d[Packets.AskFinishTransaction] = self.askFinishTransaction
d[Packets.AnswerTransactionFinished] = self.answerTransactionFinished
d[Packets.AskLockInformation] = self.askLockInformation
d[Packets.AnswerInformationLocked] = self.answerInformationLocked
d[Packets.InvalidateObjects] = self.invalidateObjects
d[Packets.NotifyUnlockInformation] = self.notifyUnlockInformation
d[Packets.AskNewOIDs] = self.askNewOIDs
d[Packets.AnswerNewOIDs] = self.answerNewOIDs
d[Packets.AskStoreObject] = self.askStoreObject
d[Packets.AnswerStoreObject] = self.answerStoreObject
d[Packets.AbortTransaction] = self.abortTransaction
d[Packets.AskStoreTransaction] = self.askStoreTransaction
d[Packets.AnswerStoreTransaction] = self.answerStoreTransaction
d[Packets.AskObject] = self.askObject
d[Packets.AnswerObject] = self.answerObject
d[Packets.AskTIDs] = self.askTIDs
d[Packets.AnswerTIDs] = self.answerTIDs
d[Packets.AskTIDsFrom] = self.askTIDsFrom
d[Packets.AnswerTIDsFrom] = self.answerTIDsFrom
d[Packets.AskTransactionInformation] = self.askTransactionInformation
d[Packets.AnswerTransactionInformation] = \
self.answerTransactionInformation
d[Packets.AskObjectHistory] = self.askObjectHistory
d[Packets.AnswerObjectHistory] = self.answerObjectHistory
d[Packets.AskObjectHistoryFrom] = self.askObjectHistoryFrom
d[Packets.AnswerObjectHistoryFrom] = self.answerObjectHistoryFrom
d[Packets.AskPartitionList] = self.askPartitionList
d[Packets.AnswerPartitionList] = self.answerPartitionList
d[Packets.AskNodeList] = self.askNodeList
d[Packets.AnswerNodeList] = self.answerNodeList
d[Packets.SetNodeState] = self.setNodeState
d[Packets.SetClusterState] = self.setClusterState
d[Packets.AddPendingNodes] = self.addPendingNodes
d[Packets.AskNodeInformation] = self.askNodeInformation
d[Packets.AnswerNodeInformation] = self.answerNodeInformation
d[Packets.AskClusterState] = self.askClusterState
d[Packets.AnswerClusterState] = self.answerClusterState
d[Packets.NotifyClusterInformation] = self.notifyClusterInformation
d[Packets.NotifyLastOID] = self.notifyLastOID
d[Packets.NotifyReplicationDone] = self.notifyReplicationDone
d[Packets.AskObjectUndoSerial] = self.askObjectUndoSerial
d[Packets.AnswerObjectUndoSerial] = self.answerObjectUndoSerial
d[Packets.AskHasLock] = self.askHasLock
d[Packets.AnswerHasLock] = self.answerHasLock
d[Packets.AskBarrier] = self.askBarrier
d[Packets.AnswerBarrier] = self.answerBarrier
d[Packets.AskPack] = self.askPack
d[Packets.AnswerPack] = self.answerPack
d[Packets.AskCheckTIDRange] = self.askCheckTIDRange
d[Packets.AnswerCheckTIDRange] = self.answerCheckTIDRange
d[Packets.AskCheckSerialRange] = self.askCheckSerialRange
d[Packets.AnswerCheckSerialRange] = self.answerCheckSerialRange
d[Packets.NotifyReady] = self.notifyReady
d[Packets.AskLastTransaction] = self.askLastTransaction
d[Packets.AnswerLastTransaction] = self.answerLastTransaction
d[Packets.AskCheckCurrentSerial] = self.askCheckCurrentSerial
d[Packets.AnswerCheckCurrentSerial] = self.answerCheckCurrentSerial
d[Packets.NotifyTransactionFinished] = self.notifyTransactionFinished
return d
def __initErrorDispatchTable(self):
d = {}
......
......@@ -28,7 +28,6 @@ class PacketLogger(object):
def __init__(self):
_temp = EventHandler(None)
self.packet_dispatch_table = _temp.packet_dispatch_table
self.error_dispatch_table = _temp.error_dispatch_table
self.enable(LOGGER_ENABLED)
......@@ -38,7 +37,6 @@ class PacketLogger(object):
def _dispatch(self, conn, packet, direction):
"""This is a helper method to handle various packet types."""
# default log message
klass = packet.getType()
uuid = dump(conn.getUUID())
ip, port = conn.getAddress()
packet_name = packet.__class__.__name__
......@@ -47,8 +45,7 @@ class PacketLogger(object):
neo.lib.logging.debug('#0x%08x %-30s %s %s (%s:%d)', packet.getId(),
packet_name, direction, uuid, ip, port)
# look for custom packet logger
logger = self.packet_dispatch_table.get(klass, None)
logger = logger and getattr(self, logger.im_func.__name__, None)
logger = getattr(self, packet.handler_method_name, None)
if logger is None:
return
# enhanced log
......
......@@ -18,7 +18,6 @@
import socket
import sys
import traceback
from types import ClassType
from socket import inet_ntoa, inet_aton
from cStringIO import StringIO
from struct import Struct
......@@ -673,6 +672,9 @@ class RequestIdentification(Packet):
args.insert(0, PROTOCOL_VERSION)
super(RequestIdentification, self).__init__(*args, **kw)
def decode(self):
return super(RequestIdentification, self).decode()[1:]
class PrimaryMaster(Packet):
"""
Ask a current primary master node. This must be the second message when
......@@ -1355,7 +1357,7 @@ def register(code, request, ignore_when_closed=None):
if answer in (Error, None):
return request
# build a class for the answer
answer = ClassType('Answer%s' % (request.__name__, ), (Packet, ), {})
answer = type('Answer%s' % (request.__name__, ), (Packet, ), {})
answer._fmt = request._answer
# compute the answer code
code = code | RESPONSE_MASK
......@@ -1384,14 +1386,16 @@ class ParserState(object):
def clear(self):
self.payload = None
class PacketRegistry(dict):
class Packets(dict):
"""
Packet registry that check packet code unicity and provide an index
"""
def __init__(self):
dict.__init__(self)
# load packet classes
self.update(StaticRegistry)
def __metaclass__(name, base, d):
for k, v in d.iteritems():
if isinstance(v, type) and issubclass(v, Packet):
v.handler_method_name = k[0].lower() + k[1:]
# this builds a "singleton"
return type('PacketRegistry', base, d)(StaticRegistry)
def parse(self, buf, state_container):
state = state_container.get()
......@@ -1531,9 +1535,6 @@ class PacketRegistry(dict):
NotifyTransactionFinished = register(
0x003E, NotifyTransactionFinished)
# build a "singleton"
Packets = PacketRegistry()
def register_error(code):
def wrapper(registry, message=''):
return Error(code, message)
......
......@@ -373,6 +373,7 @@ class ConnectionTests(NeoUnitTestBase):
def test_07_Connection_addPacket(self):
# new packet
p = Mock({"encode" : "testdata"})
p.handler_method_name = 'testmethod'
bc = self._makeConnection()
self._checkWriteBuf(bc, '')
bc._addPacket(p)
......
......@@ -28,22 +28,17 @@ class HandlerTests(NeoUnitTestBase):
NeoUnitTestBase.setUp(self)
app = Mock()
self.handler = EventHandler(app)
self.fake_type = 'FAKE_PACKET_TYPE'
def setFakeMethod(self, method):
self.handler.packet_dispatch_table[self.fake_type] = method
self.handler.fake_method = method
def getFakePacket(self):
return Mock({
'getType': self.fake_type,
p = Mock({
'decode': (),
'__repr__': 'Fake Packet',
})
def checkFakeCalled(self):
method = self.handler.packet_dispatch_table[self.fake_type]
calls = method.getNamedCalls('__call__')
self.assertEquals(len(calls), 1)
p.handler_method_name = 'fake_method'
return p
def test_dispatch(self):
conn = self.getFakeConnection()
......
......@@ -73,7 +73,7 @@ class ProtocolTests(NeoUnitTestBase):
uuid = self.getNewUUID()
p = Packets.RequestIdentification(NodeTypes.CLIENT,
uuid, (self.local_ip, 9080), "unittest")
(plow, phigh), node, p_uuid, (ip, port), name = p.decode()
node, p_uuid, (ip, port), name = p.decode()
self.assertEqual(node, NodeTypes.CLIENT)
self.assertEqual(p_uuid, uuid)
self.assertEqual(ip, self.local_ip)
......@@ -85,7 +85,7 @@ class ProtocolTests(NeoUnitTestBase):
self.local_ip = IP_VERSION_FORMAT_DICT[socket.AF_INET6]
p = Packets.RequestIdentification(NodeTypes.CLIENT,
uuid, (self.local_ip, 9080), "unittest")
(plow, phigh), node, p_uuid, (ip, port), name = p.decode()
node, p_uuid, (ip, port), name = p.decode()
self.assertEqual(node, NodeTypes.CLIENT)
self.assertEqual(p_uuid, uuid)
self.assertEqual(ip, self.local_ip)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment