Commit 60422372 authored by Grégory Wisniewski's avatar Grégory Wisniewski

* Move parse() into packet registry

* Support named arguments in default encoder
* Add some docstring and comments
* Describe assertions


git-svn-id: https://svn.erp5.org/repos/neo/trunk@1360 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 3db58c9a
......@@ -237,7 +237,7 @@ class Connection(BaseConnection):
while 1:
packet = None
try:
packet = protocol.parse(self.read_buf)
packet = Packets.parse(self.read_buf)
except PacketMalformedError, msg:
self.handler._packetMalformed(self, packet, msg)
return
......
......@@ -97,16 +97,11 @@ INVALID_PTID = '\0' * 8
INVALID_SERIAL = INVALID_TID
INVALID_PARTITION = 0xffffffff
STORAGE_NS = 'S'
MASTER_NS = 'M'
CLIENT_NS = 'C'
ADMIN_NS = 'A'
UUID_NAMESPACES = {
NodeTypes.STORAGE: STORAGE_NS,
NodeTypes.MASTER: MASTER_NS,
NodeTypes.CLIENT: CLIENT_NS,
NodeTypes.ADMIN: ADMIN_NS,
NodeTypes.STORAGE: 'S',
NodeTypes.MASTER: 'M',
NodeTypes.CLIENT: 'C',
NodeTypes.ADMIN: 'A',
}
class ProtocolError(Exception):
......@@ -164,6 +159,7 @@ def _decodeAddress(address):
def _encodeAddress(address):
if address is None:
return '\0' * 6
# address is a tuple (ip, port)
return pack('!4sH', inet_aton(address[0]), address[1])
def _decodeUUID(uuid):
......@@ -204,29 +200,14 @@ def _readString(buf, name, offset=0):
raise PacketMalformedError("can't read string <%s>" % name)
return (string, buf[offset+size:])
def parse(msg):
if len(msg) < MIN_PACKET_SIZE:
return None
msg_id, msg_type, msg_len = unpack('!LHL', msg[:PACKET_HEADER_SIZE])
try:
packet_klass = Packets[msg_type]
except KeyError:
raise PacketMalformedError('Unknown packet type')
if msg_len > MAX_PACKET_SIZE:
raise PacketMalformedError('message too big (%d)' % msg_len)
if msg_len < MIN_PACKET_SIZE:
raise PacketMalformedError('message too small (%d)' % msg_len)
if len(msg) < msg_len:
# Not enough.
return None
packet = packet_klass()
packet.setContent(msg_type, msg_id, msg[PACKET_HEADER_SIZE:msg_len])
return packet
class Packet(object):
"""
Base class for any packet definition.
Each subclass should override _encode() and _decode() and return a string or
a tuple respectively.
"""
# XXX: use a global a class-attribute
_body = None
_code = None
_args = None
......@@ -252,8 +233,8 @@ class Packet(object):
name = self.__class__.__name__
raise PacketMalformedError("%s fail (%s)" % (name, msg))
def setContent(self, code, msg_id, body):
self._code = code
def setContent(self, msg_id, body):
""" Register the packet content for future decoding """
self._id = msg_id
self._body = body
......@@ -261,11 +242,10 @@ class Packet(object):
self._id = value
def getId(self):
assert self._id is not None
assert self._id is not None, "No identifier applied on the packet"
return self._id
def getCode(self):
assert self._code is not None
return self._code
def getType(self):
......@@ -273,10 +253,9 @@ class Packet(object):
# TODO: replace this with __call__ that take the id as parameter
def __str__(self):
assert self._id is not None
content = self._body
length = PACKET_HEADER_SIZE + len(content)
return pack('!LHL', self._id, self._code, length) + content
return pack('!LHL', self.getId(), self._code, length) + content
def __len__(self):
return PACKET_HEADER_SIZE + len(self._body)
......@@ -290,12 +269,9 @@ class Packet(object):
def _encode(self, *args, **kw):
""" Default encoder, join all arguments """
# XXX: this is a convenient method but there is a lack of argument
# checking (tid length...)
assert kw == {}
if args:
return ''.join([str(i) for i in args])
return ''
args = list(args)
args.extend(kw.values())
return ''.join([str(i) for i in args] or '')
def _decode(self, body):
""" Default decoder, message must be empty """
......@@ -303,7 +279,6 @@ class Packet(object):
return ()
def isResponse(self):
# FIXME: usefull ?
return self._code & 0x8000 == 0x8000
......@@ -1237,6 +1212,7 @@ class Error(Packet):
StaticRegistry = {}
def register(code, cls):
""" Register a packet in the packet registry """
assert code not in StaticRegistry, "Duplicate packet code"
cls._code = code
StaticRegistry[code] = cls
......@@ -1244,17 +1220,35 @@ def register(code, cls):
class PacketRegistry(dict):
"""
Packet registry that check packet code unicity and provide an index
"""
def __init__(self):
dict.__init__(self)
# TODO: self load and lookup cls in module from attr name:
# for attr in self:
# if issubclass(module.get(attr), Packet):
# cls = module.get(attr)
# code = ???
# self[code] = cls
# load packet classes
self.update(StaticRegistry)
def parse(self, msg):
if len(msg) < MIN_PACKET_SIZE:
return None
msg_id, msg_type, msg_len = unpack('!LHL', msg[:PACKET_HEADER_SIZE])
try:
packet_klass = self[msg_type]
except KeyError:
raise PacketMalformedError('Unknown packet type')
if msg_len > MAX_PACKET_SIZE:
raise PacketMalformedError('message too big (%d)' % msg_len)
if msg_len < MIN_PACKET_SIZE:
raise PacketMalformedError('message too small (%d)' % msg_len)
if len(msg) < msg_len:
# Not enough.
return None
packet = packet_klass()
packet.setContent(msg_id, msg[PACKET_HEADER_SIZE:msg_len])
return packet
# packets registration
Error = register(0x8000, Error)
Ping = register(0x0001, Ping)
Pong = register(0x8001, Pong)
......@@ -1320,6 +1314,7 @@ class PacketRegistry(dict):
AnswerClusterState = register(0x8028, AnswerClusterState)
NotifyLastOID = register(0x0030, NotifyLastOID)
# build a "singleton"
Packets = PacketRegistry()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment