Commit 275a8ca2 authored by Grégory Wisniewski's avatar Grégory Wisniewski

Implement ReadBuffer to minimize incoming data copies.

ReadBuffer join received strings only when all requested data is available.
This avoid many useless data copies in case of big packets. The gain factor
is about 50x for a 25MB packet.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@1962 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent ee3ec432
...@@ -28,6 +28,7 @@ from neo.util import dump ...@@ -28,6 +28,7 @@ from neo.util import dump
from neo.logger import PACKET_LOGGER from neo.logger import PACKET_LOGGER
from neo import attributeTracker from neo import attributeTracker
from neo.util import ReadBuffer
from neo.profiling import profiler_decorator from neo.profiling import profiler_decorator
PING_DELAY = 5 PING_DELAY = 5
...@@ -284,7 +285,7 @@ class Connection(BaseConnection): ...@@ -284,7 +285,7 @@ class Connection(BaseConnection):
def __init__(self, event_manager, handler, connector, addr=None): def __init__(self, event_manager, handler, connector, addr=None):
BaseConnection.__init__(self, event_manager, handler, BaseConnection.__init__(self, event_manager, handler,
connector=connector, addr=addr) connector=connector, addr=addr)
self.read_buf = [] self.read_buf = ReadBuffer()
self.write_buf = [] self.write_buf = []
self.cur_id = 0 self.cur_id = 0
self.peer_id = 0 self.peer_id = 0
...@@ -327,7 +328,7 @@ class Connection(BaseConnection): ...@@ -327,7 +328,7 @@ class Connection(BaseConnection):
self._on_close() self._on_close()
self._on_close = None self._on_close = None
del self.write_buf[:] del self.write_buf[:]
del self.read_buf[:] self.read_buf.clear()
self._handlers.clear() self._handlers.clear()
def abort(self): def abort(self):
...@@ -355,22 +356,16 @@ class Connection(BaseConnection): ...@@ -355,22 +356,16 @@ class Connection(BaseConnection):
def analyse(self): def analyse(self):
"""Analyse received data.""" """Analyse received data."""
read_buf = self.read_buf
if len(read_buf) == 1:
msg = read_buf[0]
else:
msg = ''.join(self.read_buf)
while True: while True:
# parse a packet # parse a packet
try: try:
packet = Packets.parse(msg) packet = Packets.parse(self.read_buf)
if packet is None: if packet is None:
break break
except PacketMalformedError, msg: except PacketMalformedError, msg:
self.getHandler()._packetMalformed(self, msg) self.getHandler()._packetMalformed(self, msg)
return return
self._timeout.refresh(time()) self._timeout.refresh(time())
msg = msg[len(packet):]
packet_type = packet.getType() packet_type = packet.getType()
if packet_type == Packets.Ping: if packet_type == Packets.Ping:
# Send a pong notification # Send a pong notification
...@@ -379,7 +374,6 @@ class Connection(BaseConnection): ...@@ -379,7 +374,6 @@ class Connection(BaseConnection):
# Skip PONG packets, its only purpose is refresh the timeout # Skip PONG packets, its only purpose is refresh the timeout
# generated upong ping. # generated upong ping.
self._queue.append(packet) self._queue.append(packet)
self.read_buf = [msg]
def hasPendingMessages(self): def hasPendingMessages(self):
""" """
......
...@@ -1525,11 +1525,12 @@ class PacketRegistry(dict): ...@@ -1525,11 +1525,12 @@ class PacketRegistry(dict):
# load packet classes # load packet classes
self.update(StaticRegistry) self.update(StaticRegistry)
def parse(self, msg): def parse(self, buf):
if len(msg) < MIN_PACKET_SIZE: if len(buf) < PACKET_HEADER_SIZE:
return None return None
msg_id, msg_type, msg_len = unpack(PACKET_HEADER_FORMAT, header = buf.peek(PACKET_HEADER_SIZE)
msg[:PACKET_HEADER_SIZE]) assert header is not None
msg_id, msg_type, msg_len = unpack(PACKET_HEADER_FORMAT, header)
try: try:
packet_klass = self[msg_type] packet_klass = self[msg_type]
except KeyError: except KeyError:
...@@ -1538,11 +1539,15 @@ class PacketRegistry(dict): ...@@ -1538,11 +1539,15 @@ class PacketRegistry(dict):
raise PacketMalformedError('message too big (%d)' % msg_len) raise PacketMalformedError('message too big (%d)' % msg_len)
if msg_len < MIN_PACKET_SIZE: if msg_len < MIN_PACKET_SIZE:
raise PacketMalformedError('message too small (%d)' % msg_len) raise PacketMalformedError('message too small (%d)' % msg_len)
if len(msg) < msg_len: if len(buf) < msg_len:
# Not enough. # Not enough.
return None return None
buf.skip(PACKET_HEADER_SIZE)
msg_len -= PACKET_HEADER_SIZE
packet = packet_klass() packet = packet_klass()
packet.setContent(msg_id, msg[PACKET_HEADER_SIZE:msg_len]) data = buf.read(msg_len)
assert data is not None
packet.setContent(msg_id, data)
return packet return packet
# packets registration # packets registration
......
...@@ -128,7 +128,15 @@ class ConnectionTests(NeoTestBase): ...@@ -128,7 +128,15 @@ class ConnectionTests(NeoTestBase):
self.assertEquals(len(calls), n) self.assertEquals(len(calls), n)
def _checkReadBuf(self, bc, data): def _checkReadBuf(self, bc, data):
self.assertEqual(''.join(bc.read_buf), data) content = bc.read_buf.peek(len(bc.read_buf))
self.assertEqual(''.join(content), data)
def _appendToReadBuf(self, bc, data):
bc.read_buf.append(data)
def _appendPacketToReadBuf(self, bc, packet):
data = ''.join(packet.encode())
bc.read_buf.append(data)
def _checkWriteBuf(self, bc, data): def _checkWriteBuf(self, bc, data):
self.assertEqual(''.join(bc.write_buf), data) self.assertEqual(''.join(bc.write_buf), data)
...@@ -392,7 +400,7 @@ class ConnectionTests(NeoTestBase): ...@@ -392,7 +400,7 @@ class ConnectionTests(NeoTestBase):
(("127.0.0.1", 2132), self.getNewUUID())) (("127.0.0.1", 2132), self.getNewUUID()))
p = Packets.AnswerPrimary(self.getNewUUID(), master_list) p = Packets.AnswerPrimary(self.getNewUUID(), master_list)
p.setId(1) p.setId(1)
bc.read_buf += p.encode() self._appendPacketToReadBuf(bc, p)
bc.analyse() bc.analyse()
# check packet decoded # check packet decoded
self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 1) self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 1)
...@@ -419,7 +427,7 @@ class ConnectionTests(NeoTestBase): ...@@ -419,7 +427,7 @@ class ConnectionTests(NeoTestBase):
(("127.0.0.1", 2132), self.getNewUUID())) (("127.0.0.1", 2132), self.getNewUUID()))
p1 = Packets.AnswerPrimary(self.getNewUUID(), master_list) p1 = Packets.AnswerPrimary(self.getNewUUID(), master_list)
p1.setId(1) p1.setId(1)
bc.read_buf += p1.encode() self._appendPacketToReadBuf(bc, p1)
# packet 2 # packet 2
master_list = ( master_list = (
(("127.0.0.1", 2135), self.getNewUUID()), (("127.0.0.1", 2135), self.getNewUUID()),
...@@ -432,8 +440,8 @@ class ConnectionTests(NeoTestBase): ...@@ -432,8 +440,8 @@ class ConnectionTests(NeoTestBase):
(("127.0.0.1", 2132), self.getNewUUID())) (("127.0.0.1", 2132), self.getNewUUID()))
p2 = Packets.AnswerPrimary( self.getNewUUID(), master_list) p2 = Packets.AnswerPrimary( self.getNewUUID(), master_list)
p2.setId(2) p2.setId(2)
bc.read_buf += p2.encode() self._appendPacketToReadBuf(bc, p2)
self.assertEqual(len(''.join(bc.read_buf)), len(p1) + len(p2)) self.assertEqual(len(bc.read_buf), len(p1) + len(p2))
bc.analyse() bc.analyse()
# check two packets decoded # check two packets decoded
self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 2) self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 2)
...@@ -455,7 +463,7 @@ class ConnectionTests(NeoTestBase): ...@@ -455,7 +463,7 @@ class ConnectionTests(NeoTestBase):
# give a bad packet, won't be decoded # give a bad packet, won't be decoded
bc = self._makeConnection() bc = self._makeConnection()
bc._queue = Mock() bc._queue = Mock()
bc.read_buf += "datadatadatadata" self._appendToReadBuf(bc, 'datadatadatadata')
self.assertEqual(len(bc.read_buf), 16) self.assertEqual(len(bc.read_buf), 16)
bc.analyse() bc.analyse()
self.assertEqual(len(bc.read_buf), 16) self.assertEqual(len(bc.read_buf), 16)
...@@ -476,7 +484,7 @@ class ConnectionTests(NeoTestBase): ...@@ -476,7 +484,7 @@ class ConnectionTests(NeoTestBase):
(("127.0.0.1", 2132), self.getNewUUID())) (("127.0.0.1", 2132), self.getNewUUID()))
p = Packets.AnswerPrimary(self.getNewUUID(), master_list) p = Packets.AnswerPrimary(self.getNewUUID(), master_list)
p.setId(1) p.setId(1)
bc.read_buf += p.encode() self._appendPacketToReadBuf(bc, p)
bc.analyse() bc.analyse()
# check packet decoded # check packet decoded
self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 1) self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 1)
...@@ -485,7 +493,7 @@ class ConnectionTests(NeoTestBase): ...@@ -485,7 +493,7 @@ class ConnectionTests(NeoTestBase):
self.assertEqual(data.getType(), p.getType()) self.assertEqual(data.getType(), p.getType())
self.assertEqual(data.getId(), p.getId()) self.assertEqual(data.getId(), p.getId())
self.assertEqual(data.decode(), p.decode()) self.assertEqual(data.decode(), p.decode())
self.assertEqual(''.join(bc.read_buf), '') self._checkReadBuf(bc, '')
def test_Connection_writable1(self): def test_Connection_writable1(self):
# with pending operation after send # with pending operation after send
......
...@@ -18,11 +18,51 @@ ...@@ -18,11 +18,51 @@
import unittest import unittest
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
from neo import util from neo.util import ReadBuffer
class UtilTests(NeoTestBase): class UtilTests(NeoTestBase):
pass def testReadBufferRead(self):
""" Append some chunk then consume the data """
buf = ReadBuffer()
self.assertEqual(len(buf), 0)
buf.append('abc')
self.assertEqual(len(buf), 3)
# no enough data
self.assertEqual(buf.read(4), None)
self.assertEqual(len(buf), 3)
buf.append('def')
# consume a part
self.assertEqual(len(buf), 6)
self.assertEqual(buf.read(4), 'abcd')
self.assertEqual(len(buf), 2)
# consume the rest
self.assertEqual(buf.read(3), None)
self.assertEqual(buf.read(2), 'ef')
def testReadBufferPeek(self):
buf = ReadBuffer()
self.assertEqual(len(buf), 0)
buf.append('abc')
self.assertEqual(len(buf), 3)
# peek some data
self.assertEqual(buf.peek(3), 'abc')
self.assertEqual(buf.peek(5), None) # not enough
buf.append('def')
self.assertEqual(len(buf), 6)
self.assertEqual(buf.peek(3), 'abc') # no change
self.assertEqual(buf.peek(6), 'abcdef')
self.assertEqual(buf.peek(7), None)
def testReadBufferSkip(self):
buf = ReadBuffer()
self.assertEqual(len(buf), 0)
buf.append('abc')
self.assertEqual(len(buf), 3)
buf.skip(1)
self.assertEqual(len(buf), 2)
buf.skip(3) # eat all
self.assertEqual(len(buf), 0)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
import re import re
import socket import socket
from zlib import adler32 from zlib import adler32
from Queue import deque
from struct import pack, unpack from struct import pack, unpack
def u64(s): def u64(s):
...@@ -130,3 +131,83 @@ class Enum(dict): ...@@ -130,3 +131,83 @@ class Enum(dict):
def getByName(self, name): def getByName(self, name):
return getattr(self, name) return getattr(self, name)
class ReadBuffer(object):
"""
Implementation of a lazy buffer. Main purpose if to reduce useless
copies of data by storing chunks and join them only when the requested
size is available.
"""
def __init__(self):
self.size = 0
self.content = deque()
def append(self, data):
""" Append some data and compute the new buffer size """
size = len(data)
self.size += size
self.content.append((size, data))
def __len__(self):
""" Return the current buffer size """
return self.size
def _read(self, size):
""" Join all required chunks to build a string of requested size """
chunk_list = []
pop_chunk = self.content.popleft
append_data = chunk_list.append
# select required chunks
while size > 0:
chunk_size, chunk_data = pop_chunk()
size -= chunk_size
append_data(chunk_data)
if size < 0:
# too many bytes consumed, cut the last chunk
last_chunk = chunk_list[-1]
keep, let = last_chunk[:size], last_chunk[size:]
self.content.appendleft((-size, let))
chunk_list[-1] = keep
# join all chunks (one copy)
return ''.join(chunk_list)
def skip(self, size):
""" Skip at most size bytes """
if self.size <= size:
self.size = 0
self.content.clear()
return
pop_chunk = self.content.popleft
self.size -= size
# skip chunks
while size > 0:
chunk_size, last_chunk = pop_chunk()
size -= chunk_size
if size < 0:
# but keep a part of the last one if needed
self.content.append((-size, last_chunk[size:]))
def peek(self, size):
""" Read size bytes but don't consume """
if self.size < size:
return None
data = self._read(size)
self.content.appendleft((size, data))
assert len(data) == size
return data
def read(self, size):
""" Read and consume size bytes """
if self.size < size:
return None
self.size -= size
data = self._read(size)
assert len(data) == size
return data
def clear(self):
""" Erase all buffer content """
self.size = 0
self.content.clear()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment