Commit c4863a33 authored by Julien Muchembled's avatar Julien Muchembled

Simplify EventHandler by removing 'packet_dispatch_table'

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2685 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 7b76e88b
This diff is collapsed.
...@@ -28,7 +28,6 @@ class PacketLogger(object): ...@@ -28,7 +28,6 @@ class PacketLogger(object):
def __init__(self): def __init__(self):
_temp = EventHandler(None) _temp = EventHandler(None)
self.packet_dispatch_table = _temp.packet_dispatch_table
self.error_dispatch_table = _temp.error_dispatch_table self.error_dispatch_table = _temp.error_dispatch_table
self.enable(LOGGER_ENABLED) self.enable(LOGGER_ENABLED)
...@@ -38,7 +37,6 @@ class PacketLogger(object): ...@@ -38,7 +37,6 @@ class PacketLogger(object):
def _dispatch(self, conn, packet, direction): def _dispatch(self, conn, packet, direction):
"""This is a helper method to handle various packet types.""" """This is a helper method to handle various packet types."""
# default log message # default log message
klass = packet.getType()
uuid = dump(conn.getUUID()) uuid = dump(conn.getUUID())
ip, port = conn.getAddress() ip, port = conn.getAddress()
packet_name = packet.__class__.__name__ packet_name = packet.__class__.__name__
...@@ -47,8 +45,7 @@ class PacketLogger(object): ...@@ -47,8 +45,7 @@ class PacketLogger(object):
neo.lib.logging.debug('#0x%08x %-30s %s %s (%s:%d)', packet.getId(), neo.lib.logging.debug('#0x%08x %-30s %s %s (%s:%d)', packet.getId(),
packet_name, direction, uuid, ip, port) packet_name, direction, uuid, ip, port)
# look for custom packet logger # look for custom packet logger
logger = self.packet_dispatch_table.get(klass, None) logger = getattr(self, packet.handler_method_name, None)
logger = logger and getattr(self, logger.im_func.__name__, None)
if logger is None: if logger is None:
return return
# enhanced log # enhanced log
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
import socket import socket
import sys import sys
import traceback import traceback
from types import ClassType
from socket import inet_ntoa, inet_aton from socket import inet_ntoa, inet_aton
from cStringIO import StringIO from cStringIO import StringIO
from struct import Struct from struct import Struct
...@@ -673,6 +672,9 @@ class RequestIdentification(Packet): ...@@ -673,6 +672,9 @@ class RequestIdentification(Packet):
args.insert(0, PROTOCOL_VERSION) args.insert(0, PROTOCOL_VERSION)
super(RequestIdentification, self).__init__(*args, **kw) super(RequestIdentification, self).__init__(*args, **kw)
def decode(self):
return super(RequestIdentification, self).decode()[1:]
class PrimaryMaster(Packet): class PrimaryMaster(Packet):
""" """
Ask a current primary master node. This must be the second message when Ask a current primary master node. This must be the second message when
...@@ -1355,7 +1357,7 @@ def register(code, request, ignore_when_closed=None): ...@@ -1355,7 +1357,7 @@ def register(code, request, ignore_when_closed=None):
if answer in (Error, None): if answer in (Error, None):
return request return request
# build a class for the answer # build a class for the answer
answer = ClassType('Answer%s' % (request.__name__, ), (Packet, ), {}) answer = type('Answer%s' % (request.__name__, ), (Packet, ), {})
answer._fmt = request._answer answer._fmt = request._answer
# compute the answer code # compute the answer code
code = code | RESPONSE_MASK code = code | RESPONSE_MASK
...@@ -1384,14 +1386,16 @@ class ParserState(object): ...@@ -1384,14 +1386,16 @@ class ParserState(object):
def clear(self): def clear(self):
self.payload = None self.payload = None
class PacketRegistry(dict): class Packets(dict):
""" """
Packet registry that check packet code unicity and provide an index Packet registry that check packet code unicity and provide an index
""" """
def __init__(self): def __metaclass__(name, base, d):
dict.__init__(self) for k, v in d.iteritems():
# load packet classes if isinstance(v, type) and issubclass(v, Packet):
self.update(StaticRegistry) v.handler_method_name = k[0].lower() + k[1:]
# this builds a "singleton"
return type('PacketRegistry', base, d)(StaticRegistry)
def parse(self, buf, state_container): def parse(self, buf, state_container):
state = state_container.get() state = state_container.get()
...@@ -1531,9 +1535,6 @@ class PacketRegistry(dict): ...@@ -1531,9 +1535,6 @@ class PacketRegistry(dict):
NotifyTransactionFinished = register( NotifyTransactionFinished = register(
0x003E, NotifyTransactionFinished) 0x003E, NotifyTransactionFinished)
# build a "singleton"
Packets = PacketRegistry()
def register_error(code): def register_error(code):
def wrapper(registry, message=''): def wrapper(registry, message=''):
return Error(code, message) return Error(code, message)
......
...@@ -373,6 +373,7 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -373,6 +373,7 @@ class ConnectionTests(NeoUnitTestBase):
def test_07_Connection_addPacket(self): def test_07_Connection_addPacket(self):
# new packet # new packet
p = Mock({"encode" : "testdata"}) p = Mock({"encode" : "testdata"})
p.handler_method_name = 'testmethod'
bc = self._makeConnection() bc = self._makeConnection()
self._checkWriteBuf(bc, '') self._checkWriteBuf(bc, '')
bc._addPacket(p) bc._addPacket(p)
......
...@@ -28,22 +28,17 @@ class HandlerTests(NeoUnitTestBase): ...@@ -28,22 +28,17 @@ class HandlerTests(NeoUnitTestBase):
NeoUnitTestBase.setUp(self) NeoUnitTestBase.setUp(self)
app = Mock() app = Mock()
self.handler = EventHandler(app) self.handler = EventHandler(app)
self.fake_type = 'FAKE_PACKET_TYPE'
def setFakeMethod(self, method): def setFakeMethod(self, method):
self.handler.packet_dispatch_table[self.fake_type] = method self.handler.fake_method = method
def getFakePacket(self): def getFakePacket(self):
return Mock({ p = Mock({
'getType': self.fake_type,
'decode': (), 'decode': (),
'__repr__': 'Fake Packet', '__repr__': 'Fake Packet',
}) })
p.handler_method_name = 'fake_method'
def checkFakeCalled(self): return p
method = self.handler.packet_dispatch_table[self.fake_type]
calls = method.getNamedCalls('__call__')
self.assertEquals(len(calls), 1)
def test_dispatch(self): def test_dispatch(self):
conn = self.getFakeConnection() conn = self.getFakeConnection()
......
...@@ -73,7 +73,7 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -73,7 +73,7 @@ class ProtocolTests(NeoUnitTestBase):
uuid = self.getNewUUID() uuid = self.getNewUUID()
p = Packets.RequestIdentification(NodeTypes.CLIENT, p = Packets.RequestIdentification(NodeTypes.CLIENT,
uuid, (self.local_ip, 9080), "unittest") uuid, (self.local_ip, 9080), "unittest")
(plow, phigh), node, p_uuid, (ip, port), name = p.decode() node, p_uuid, (ip, port), name = p.decode()
self.assertEqual(node, NodeTypes.CLIENT) self.assertEqual(node, NodeTypes.CLIENT)
self.assertEqual(p_uuid, uuid) self.assertEqual(p_uuid, uuid)
self.assertEqual(ip, self.local_ip) self.assertEqual(ip, self.local_ip)
...@@ -85,7 +85,7 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -85,7 +85,7 @@ class ProtocolTests(NeoUnitTestBase):
self.local_ip = IP_VERSION_FORMAT_DICT[socket.AF_INET6] self.local_ip = IP_VERSION_FORMAT_DICT[socket.AF_INET6]
p = Packets.RequestIdentification(NodeTypes.CLIENT, p = Packets.RequestIdentification(NodeTypes.CLIENT,
uuid, (self.local_ip, 9080), "unittest") uuid, (self.local_ip, 9080), "unittest")
(plow, phigh), node, p_uuid, (ip, port), name = p.decode() node, p_uuid, (ip, port), name = p.decode()
self.assertEqual(node, NodeTypes.CLIENT) self.assertEqual(node, NodeTypes.CLIENT)
self.assertEqual(p_uuid, uuid) self.assertEqual(p_uuid, uuid)
self.assertEqual(ip, self.local_ip) self.assertEqual(ip, self.local_ip)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment