Commit 60422372 authored by Grégory Wisniewski's avatar Grégory Wisniewski

* Move parse() into packet registry

* Support named arguments in default encoder
* Add some docstring and comments
* Describe assertions


git-svn-id: https://svn.erp5.org/repos/neo/trunk@1360 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 3db58c9a
...@@ -237,7 +237,7 @@ class Connection(BaseConnection): ...@@ -237,7 +237,7 @@ class Connection(BaseConnection):
while 1: while 1:
packet = None packet = None
try: try:
packet = protocol.parse(self.read_buf) packet = Packets.parse(self.read_buf)
except PacketMalformedError, msg: except PacketMalformedError, msg:
self.handler._packetMalformed(self, packet, msg) self.handler._packetMalformed(self, packet, msg)
return return
......
...@@ -97,16 +97,11 @@ INVALID_PTID = '\0' * 8 ...@@ -97,16 +97,11 @@ INVALID_PTID = '\0' * 8
INVALID_SERIAL = INVALID_TID INVALID_SERIAL = INVALID_TID
INVALID_PARTITION = 0xffffffff INVALID_PARTITION = 0xffffffff
STORAGE_NS = 'S'
MASTER_NS = 'M'
CLIENT_NS = 'C'
ADMIN_NS = 'A'
UUID_NAMESPACES = { UUID_NAMESPACES = {
NodeTypes.STORAGE: STORAGE_NS, NodeTypes.STORAGE: 'S',
NodeTypes.MASTER: MASTER_NS, NodeTypes.MASTER: 'M',
NodeTypes.CLIENT: CLIENT_NS, NodeTypes.CLIENT: 'C',
NodeTypes.ADMIN: ADMIN_NS, NodeTypes.ADMIN: 'A',
} }
class ProtocolError(Exception): class ProtocolError(Exception):
...@@ -164,6 +159,7 @@ def _decodeAddress(address): ...@@ -164,6 +159,7 @@ def _decodeAddress(address):
def _encodeAddress(address): def _encodeAddress(address):
if address is None: if address is None:
return '\0' * 6 return '\0' * 6
# address is a tuple (ip, port)
return pack('!4sH', inet_aton(address[0]), address[1]) return pack('!4sH', inet_aton(address[0]), address[1])
def _decodeUUID(uuid): def _decodeUUID(uuid):
...@@ -204,29 +200,14 @@ def _readString(buf, name, offset=0): ...@@ -204,29 +200,14 @@ def _readString(buf, name, offset=0):
raise PacketMalformedError("can't read string <%s>" % name) raise PacketMalformedError("can't read string <%s>" % name)
return (string, buf[offset+size:]) return (string, buf[offset+size:])
def parse(msg):
if len(msg) < MIN_PACKET_SIZE:
return None
msg_id, msg_type, msg_len = unpack('!LHL', msg[:PACKET_HEADER_SIZE])
try:
packet_klass = Packets[msg_type]
except KeyError:
raise PacketMalformedError('Unknown packet type')
if msg_len > MAX_PACKET_SIZE:
raise PacketMalformedError('message too big (%d)' % msg_len)
if msg_len < MIN_PACKET_SIZE:
raise PacketMalformedError('message too small (%d)' % msg_len)
if len(msg) < msg_len:
# Not enough.
return None
packet = packet_klass()
packet.setContent(msg_type, msg_id, msg[PACKET_HEADER_SIZE:msg_len])
return packet
class Packet(object): class Packet(object):
"""
Base class for any packet definition.
Each subclass should override _encode() and _decode() and return a string or
a tuple respectively.
"""
# XXX: use a global a class-attribute
_body = None _body = None
_code = None _code = None
_args = None _args = None
...@@ -252,8 +233,8 @@ class Packet(object): ...@@ -252,8 +233,8 @@ class Packet(object):
name = self.__class__.__name__ name = self.__class__.__name__
raise PacketMalformedError("%s fail (%s)" % (name, msg)) raise PacketMalformedError("%s fail (%s)" % (name, msg))
def setContent(self, code, msg_id, body): def setContent(self, msg_id, body):
self._code = code """ Register the packet content for future decoding """
self._id = msg_id self._id = msg_id
self._body = body self._body = body
...@@ -261,11 +242,10 @@ class Packet(object): ...@@ -261,11 +242,10 @@ class Packet(object):
self._id = value self._id = value
def getId(self): def getId(self):
assert self._id is not None assert self._id is not None, "No identifier applied on the packet"
return self._id return self._id
def getCode(self): def getCode(self):
assert self._code is not None
return self._code return self._code
def getType(self): def getType(self):
...@@ -273,10 +253,9 @@ class Packet(object): ...@@ -273,10 +253,9 @@ class Packet(object):
# TODO: replace this with __call__ that take the id as parameter # TODO: replace this with __call__ that take the id as parameter
def __str__(self): def __str__(self):
assert self._id is not None
content = self._body content = self._body
length = PACKET_HEADER_SIZE + len(content) length = PACKET_HEADER_SIZE + len(content)
return pack('!LHL', self._id, self._code, length) + content return pack('!LHL', self.getId(), self._code, length) + content
def __len__(self): def __len__(self):
return PACKET_HEADER_SIZE + len(self._body) return PACKET_HEADER_SIZE + len(self._body)
...@@ -290,12 +269,9 @@ class Packet(object): ...@@ -290,12 +269,9 @@ class Packet(object):
def _encode(self, *args, **kw): def _encode(self, *args, **kw):
""" Default encoder, join all arguments """ """ Default encoder, join all arguments """
# XXX: this is a convenient method but there is a lack of argument args = list(args)
# checking (tid length...) args.extend(kw.values())
assert kw == {} return ''.join([str(i) for i in args] or '')
if args:
return ''.join([str(i) for i in args])
return ''
def _decode(self, body): def _decode(self, body):
""" Default decoder, message must be empty """ """ Default decoder, message must be empty """
...@@ -303,7 +279,6 @@ class Packet(object): ...@@ -303,7 +279,6 @@ class Packet(object):
return () return ()
def isResponse(self): def isResponse(self):
# FIXME: usefull ?
return self._code & 0x8000 == 0x8000 return self._code & 0x8000 == 0x8000
...@@ -1237,6 +1212,7 @@ class Error(Packet): ...@@ -1237,6 +1212,7 @@ class Error(Packet):
StaticRegistry = {} StaticRegistry = {}
def register(code, cls): def register(code, cls):
""" Register a packet in the packet registry """
assert code not in StaticRegistry, "Duplicate packet code" assert code not in StaticRegistry, "Duplicate packet code"
cls._code = code cls._code = code
StaticRegistry[code] = cls StaticRegistry[code] = cls
...@@ -1244,17 +1220,35 @@ def register(code, cls): ...@@ -1244,17 +1220,35 @@ def register(code, cls):
class PacketRegistry(dict): class PacketRegistry(dict):
"""
Packet registry that check packet code unicity and provide an index
"""
def __init__(self): def __init__(self):
dict.__init__(self) dict.__init__(self)
# TODO: self load and lookup cls in module from attr name: # load packet classes
# for attr in self:
# if issubclass(module.get(attr), Packet):
# cls = module.get(attr)
# code = ???
# self[code] = cls
self.update(StaticRegistry) self.update(StaticRegistry)
def parse(self, msg):
if len(msg) < MIN_PACKET_SIZE:
return None
msg_id, msg_type, msg_len = unpack('!LHL', msg[:PACKET_HEADER_SIZE])
try:
packet_klass = self[msg_type]
except KeyError:
raise PacketMalformedError('Unknown packet type')
if msg_len > MAX_PACKET_SIZE:
raise PacketMalformedError('message too big (%d)' % msg_len)
if msg_len < MIN_PACKET_SIZE:
raise PacketMalformedError('message too small (%d)' % msg_len)
if len(msg) < msg_len:
# Not enough.
return None
packet = packet_klass()
packet.setContent(msg_id, msg[PACKET_HEADER_SIZE:msg_len])
return packet
# packets registration
Error = register(0x8000, Error) Error = register(0x8000, Error)
Ping = register(0x0001, Ping) Ping = register(0x0001, Ping)
Pong = register(0x8001, Pong) Pong = register(0x8001, Pong)
...@@ -1320,6 +1314,7 @@ class PacketRegistry(dict): ...@@ -1320,6 +1314,7 @@ class PacketRegistry(dict):
AnswerClusterState = register(0x8028, AnswerClusterState) AnswerClusterState = register(0x8028, AnswerClusterState)
NotifyLastOID = register(0x0030, NotifyLastOID) NotifyLastOID = register(0x0030, NotifyLastOID)
# build a "singleton"
Packets = PacketRegistry() Packets = PacketRegistry()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment