Commit 275a8ca2 authored by Grégory Wisniewski's avatar Grégory Wisniewski

Implement ReadBuffer to minimize incoming data copies.

ReadBuffer join received strings only when all requested data is available.
This avoid many useless data copies in case of big packets. The gain factor
is about 50x for a 25MB packet.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@1962 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent ee3ec432
......@@ -28,6 +28,7 @@ from neo.util import dump
from neo.logger import PACKET_LOGGER
from neo import attributeTracker
from neo.util import ReadBuffer
from neo.profiling import profiler_decorator
PING_DELAY = 5
......@@ -284,7 +285,7 @@ class Connection(BaseConnection):
def __init__(self, event_manager, handler, connector, addr=None):
BaseConnection.__init__(self, event_manager, handler,
connector=connector, addr=addr)
self.read_buf = []
self.read_buf = ReadBuffer()
self.write_buf = []
self.cur_id = 0
self.peer_id = 0
......@@ -327,7 +328,7 @@ class Connection(BaseConnection):
self._on_close()
self._on_close = None
del self.write_buf[:]
del self.read_buf[:]
self.read_buf.clear()
self._handlers.clear()
def abort(self):
......@@ -355,22 +356,16 @@ class Connection(BaseConnection):
def analyse(self):
"""Analyse received data."""
read_buf = self.read_buf
if len(read_buf) == 1:
msg = read_buf[0]
else:
msg = ''.join(self.read_buf)
while True:
# parse a packet
try:
packet = Packets.parse(msg)
packet = Packets.parse(self.read_buf)
if packet is None:
break
except PacketMalformedError, msg:
self.getHandler()._packetMalformed(self, msg)
return
self._timeout.refresh(time())
msg = msg[len(packet):]
packet_type = packet.getType()
if packet_type == Packets.Ping:
# Send a pong notification
......@@ -379,7 +374,6 @@ class Connection(BaseConnection):
# Skip PONG packets, its only purpose is refresh the timeout
# generated upong ping.
self._queue.append(packet)
self.read_buf = [msg]
def hasPendingMessages(self):
"""
......
......@@ -1525,11 +1525,12 @@ class PacketRegistry(dict):
# load packet classes
self.update(StaticRegistry)
def parse(self, msg):
if len(msg) < MIN_PACKET_SIZE:
def parse(self, buf):
if len(buf) < PACKET_HEADER_SIZE:
return None
msg_id, msg_type, msg_len = unpack(PACKET_HEADER_FORMAT,
msg[:PACKET_HEADER_SIZE])
header = buf.peek(PACKET_HEADER_SIZE)
assert header is not None
msg_id, msg_type, msg_len = unpack(PACKET_HEADER_FORMAT, header)
try:
packet_klass = self[msg_type]
except KeyError:
......@@ -1538,11 +1539,15 @@ class PacketRegistry(dict):
raise PacketMalformedError('message too big (%d)' % msg_len)
if msg_len < MIN_PACKET_SIZE:
raise PacketMalformedError('message too small (%d)' % msg_len)
if len(msg) < msg_len:
if len(buf) < msg_len:
# Not enough.
return None
buf.skip(PACKET_HEADER_SIZE)
msg_len -= PACKET_HEADER_SIZE
packet = packet_klass()
packet.setContent(msg_id, msg[PACKET_HEADER_SIZE:msg_len])
data = buf.read(msg_len)
assert data is not None
packet.setContent(msg_id, data)
return packet
# packets registration
......
......@@ -128,7 +128,15 @@ class ConnectionTests(NeoTestBase):
self.assertEquals(len(calls), n)
def _checkReadBuf(self, bc, data):
self.assertEqual(''.join(bc.read_buf), data)
content = bc.read_buf.peek(len(bc.read_buf))
self.assertEqual(''.join(content), data)
def _appendToReadBuf(self, bc, data):
bc.read_buf.append(data)
def _appendPacketToReadBuf(self, bc, packet):
data = ''.join(packet.encode())
bc.read_buf.append(data)
def _checkWriteBuf(self, bc, data):
self.assertEqual(''.join(bc.write_buf), data)
......@@ -392,7 +400,7 @@ class ConnectionTests(NeoTestBase):
(("127.0.0.1", 2132), self.getNewUUID()))
p = Packets.AnswerPrimary(self.getNewUUID(), master_list)
p.setId(1)
bc.read_buf += p.encode()
self._appendPacketToReadBuf(bc, p)
bc.analyse()
# check packet decoded
self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 1)
......@@ -419,7 +427,7 @@ class ConnectionTests(NeoTestBase):
(("127.0.0.1", 2132), self.getNewUUID()))
p1 = Packets.AnswerPrimary(self.getNewUUID(), master_list)
p1.setId(1)
bc.read_buf += p1.encode()
self._appendPacketToReadBuf(bc, p1)
# packet 2
master_list = (
(("127.0.0.1", 2135), self.getNewUUID()),
......@@ -432,8 +440,8 @@ class ConnectionTests(NeoTestBase):
(("127.0.0.1", 2132), self.getNewUUID()))
p2 = Packets.AnswerPrimary( self.getNewUUID(), master_list)
p2.setId(2)
bc.read_buf += p2.encode()
self.assertEqual(len(''.join(bc.read_buf)), len(p1) + len(p2))
self._appendPacketToReadBuf(bc, p2)
self.assertEqual(len(bc.read_buf), len(p1) + len(p2))
bc.analyse()
# check two packets decoded
self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 2)
......@@ -455,7 +463,7 @@ class ConnectionTests(NeoTestBase):
# give a bad packet, won't be decoded
bc = self._makeConnection()
bc._queue = Mock()
bc.read_buf += "datadatadatadata"
self._appendToReadBuf(bc, 'datadatadatadata')
self.assertEqual(len(bc.read_buf), 16)
bc.analyse()
self.assertEqual(len(bc.read_buf), 16)
......@@ -476,7 +484,7 @@ class ConnectionTests(NeoTestBase):
(("127.0.0.1", 2132), self.getNewUUID()))
p = Packets.AnswerPrimary(self.getNewUUID(), master_list)
p.setId(1)
bc.read_buf += p.encode()
self._appendPacketToReadBuf(bc, p)
bc.analyse()
# check packet decoded
self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 1)
......@@ -485,7 +493,7 @@ class ConnectionTests(NeoTestBase):
self.assertEqual(data.getType(), p.getType())
self.assertEqual(data.getId(), p.getId())
self.assertEqual(data.decode(), p.decode())
self.assertEqual(''.join(bc.read_buf), '')
self._checkReadBuf(bc, '')
def test_Connection_writable1(self):
# with pending operation after send
......
......@@ -18,11 +18,51 @@
import unittest
from neo.tests import NeoTestBase
from neo import util
from neo.util import ReadBuffer
class UtilTests(NeoTestBase):
pass
def testReadBufferRead(self):
""" Append some chunk then consume the data """
buf = ReadBuffer()
self.assertEqual(len(buf), 0)
buf.append('abc')
self.assertEqual(len(buf), 3)
# no enough data
self.assertEqual(buf.read(4), None)
self.assertEqual(len(buf), 3)
buf.append('def')
# consume a part
self.assertEqual(len(buf), 6)
self.assertEqual(buf.read(4), 'abcd')
self.assertEqual(len(buf), 2)
# consume the rest
self.assertEqual(buf.read(3), None)
self.assertEqual(buf.read(2), 'ef')
def testReadBufferPeek(self):
buf = ReadBuffer()
self.assertEqual(len(buf), 0)
buf.append('abc')
self.assertEqual(len(buf), 3)
# peek some data
self.assertEqual(buf.peek(3), 'abc')
self.assertEqual(buf.peek(5), None) # not enough
buf.append('def')
self.assertEqual(len(buf), 6)
self.assertEqual(buf.peek(3), 'abc') # no change
self.assertEqual(buf.peek(6), 'abcdef')
self.assertEqual(buf.peek(7), None)
def testReadBufferSkip(self):
buf = ReadBuffer()
self.assertEqual(len(buf), 0)
buf.append('abc')
self.assertEqual(len(buf), 3)
buf.skip(1)
self.assertEqual(len(buf), 2)
buf.skip(3) # eat all
self.assertEqual(len(buf), 0)
if __name__ == "__main__":
unittest.main()
......
......@@ -19,6 +19,7 @@
import re
import socket
from zlib import adler32
from Queue import deque
from struct import pack, unpack
def u64(s):
......@@ -130,3 +131,83 @@ class Enum(dict):
def getByName(self, name):
return getattr(self, name)
class ReadBuffer(object):
"""
Implementation of a lazy buffer. Main purpose if to reduce useless
copies of data by storing chunks and join them only when the requested
size is available.
"""
def __init__(self):
self.size = 0
self.content = deque()
def append(self, data):
""" Append some data and compute the new buffer size """
size = len(data)
self.size += size
self.content.append((size, data))
def __len__(self):
""" Return the current buffer size """
return self.size
def _read(self, size):
""" Join all required chunks to build a string of requested size """
chunk_list = []
pop_chunk = self.content.popleft
append_data = chunk_list.append
# select required chunks
while size > 0:
chunk_size, chunk_data = pop_chunk()
size -= chunk_size
append_data(chunk_data)
if size < 0:
# too many bytes consumed, cut the last chunk
last_chunk = chunk_list[-1]
keep, let = last_chunk[:size], last_chunk[size:]
self.content.appendleft((-size, let))
chunk_list[-1] = keep
# join all chunks (one copy)
return ''.join(chunk_list)
def skip(self, size):
""" Skip at most size bytes """
if self.size <= size:
self.size = 0
self.content.clear()
return
pop_chunk = self.content.popleft
self.size -= size
# skip chunks
while size > 0:
chunk_size, last_chunk = pop_chunk()
size -= chunk_size
if size < 0:
# but keep a part of the last one if needed
self.content.append((-size, last_chunk[size:]))
def peek(self, size):
""" Read size bytes but don't consume """
if self.size < size:
return None
data = self._read(size)
self.content.appendleft((size, data))
assert len(data) == size
return data
def read(self, size):
""" Read and consume size bytes """
if self.size < size:
return None
self.size -= size
data = self._read(size)
assert len(data) == size
return data
def clear(self):
""" Erase all buffer content """
self.size = 0
self.content.clear()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment