Commit ad8b51f0 authored by Yury Selivanov

asyncio.streams: Use bytearray in StreamReader; Add assertion in feed_data

parent 7a49d33d
...@@ -4,8 +4,6 @@ __all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol', ...@@ -4,8 +4,6 @@ __all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
'open_connection', 'start_server', 'IncompleteReadError', 'open_connection', 'start_server', 'IncompleteReadError',
] ]
import collections
from . import events from . import events
from . import futures from . import futures
from . import protocols from . import protocols
...@@ -259,9 +257,7 @@ class StreamReader: ...@@ -259,9 +257,7 @@ class StreamReader:
if loop is None: if loop is None:
loop = events.get_event_loop() loop = events.get_event_loop()
self._loop = loop self._loop = loop
# TODO: Use a bytearray for a buffer, like the transport. self._buffer = bytearray()
self._buffer = collections.deque() # Deque of bytes objects.
self._byte_count = 0 # Bytes in buffer.
self._eof = False # Whether we're done. self._eof = False # Whether we're done.
self._waiter = None # A future. self._waiter = None # A future.
self._exception = None self._exception = None
...@@ -285,7 +281,7 @@ class StreamReader: ...@@ -285,7 +281,7 @@ class StreamReader:
self._transport = transport self._transport = transport
def _maybe_resume_transport(self): def _maybe_resume_transport(self):
if self._paused and self._byte_count <= self._limit: if self._paused and len(self._buffer) <= self._limit:
self._paused = False self._paused = False
self._transport.resume_reading() self._transport.resume_reading()
...@@ -298,11 +294,12 @@ class StreamReader: ...@@ -298,11 +294,12 @@ class StreamReader:
waiter.set_result(True) waiter.set_result(True)
def feed_data(self, data): def feed_data(self, data):
assert not self._eof, 'feed_data after feed_eof'
if not data: if not data:
return return
self._buffer.append(data) self._buffer.extend(data)
self._byte_count += len(data)
waiter = self._waiter waiter = self._waiter
if waiter is not None: if waiter is not None:
...@@ -312,7 +309,7 @@ class StreamReader: ...@@ -312,7 +309,7 @@ class StreamReader:
if (self._transport is not None and if (self._transport is not None and
not self._paused and not self._paused and
self._byte_count > 2*self._limit): len(self._buffer) > 2*self._limit):
try: try:
self._transport.pause_reading() self._transport.pause_reading()
except NotImplementedError: except NotImplementedError:
...@@ -338,28 +335,22 @@ class StreamReader: ...@@ -338,28 +335,22 @@ class StreamReader:
if self._exception is not None: if self._exception is not None:
raise self._exception raise self._exception
parts = [] line = bytearray()
parts_size = 0
not_enough = True not_enough = True
while not_enough: while not_enough:
while self._buffer and not_enough: while self._buffer and not_enough:
data = self._buffer.popleft() ichar = self._buffer.find(b'\n')
ichar = data.find(b'\n')
if ichar < 0: if ichar < 0:
parts.append(data) line.extend(self._buffer)
parts_size += len(data) self._buffer.clear()
else: else:
ichar += 1 ichar += 1
head, tail = data[:ichar], data[ichar:] line.extend(self._buffer[:ichar])
if tail: del self._buffer[:ichar]
self._buffer.appendleft(tail)
not_enough = False not_enough = False
parts.append(head)
parts_size += len(head)
if parts_size > self._limit: if len(line) > self._limit:
self._byte_count -= parts_size
self._maybe_resume_transport() self._maybe_resume_transport()
raise ValueError('Line is too long') raise ValueError('Line is too long')
...@@ -373,11 +364,8 @@ class StreamReader: ...@@ -373,11 +364,8 @@ class StreamReader:
finally: finally:
self._waiter = None self._waiter = None
line = b''.join(parts)
self._byte_count -= parts_size
self._maybe_resume_transport() self._maybe_resume_transport()
return bytes(line)
return line
@tasks.coroutine @tasks.coroutine
def read(self, n=-1): def read(self, n=-1):
...@@ -395,36 +383,23 @@ class StreamReader: ...@@ -395,36 +383,23 @@ class StreamReader:
finally: finally:
self._waiter = None self._waiter = None
else: else:
if not self._byte_count and not self._eof: if not self._buffer and not self._eof:
self._waiter = self._create_waiter('read') self._waiter = self._create_waiter('read')
try: try:
yield from self._waiter yield from self._waiter
finally: finally:
self._waiter = None self._waiter = None
if n < 0 or self._byte_count <= n: if n < 0 or len(self._buffer) <= n:
data = b''.join(self._buffer) data = bytes(self._buffer)
self._buffer.clear() self._buffer.clear()
self._byte_count = 0 else:
self._maybe_resume_transport() # n > 0 and len(self._buffer) > n
return data data = bytes(self._buffer[:n])
del self._buffer[:n]
parts = []
parts_bytes = 0
while self._buffer and parts_bytes < n:
data = self._buffer.popleft()
data_bytes = len(data)
if n < parts_bytes + data_bytes:
data_bytes = n - parts_bytes
data, rest = data[:data_bytes], data[data_bytes:]
self._buffer.appendleft(rest)
parts.append(data)
parts_bytes += data_bytes
self._byte_count -= data_bytes
self._maybe_resume_transport() self._maybe_resume_transport()
return data
return b''.join(parts)
@tasks.coroutine @tasks.coroutine
def readexactly(self, n): def readexactly(self, n):
......
...@@ -79,13 +79,13 @@ class StreamReaderTests(unittest.TestCase): ...@@ -79,13 +79,13 @@ class StreamReaderTests(unittest.TestCase):
stream = asyncio.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(b'') stream.feed_data(b'')
self.assertEqual(0, stream._byte_count) self.assertEqual(b'', stream._buffer)
def test_feed_data_byte_count(self): def test_feed_nonempty_data(self):
stream = asyncio.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(self.DATA) stream.feed_data(self.DATA)
self.assertEqual(len(self.DATA), stream._byte_count) self.assertEqual(self.DATA, stream._buffer)
def test_read_zero(self): def test_read_zero(self):
# Read zero bytes. # Read zero bytes.
...@@ -94,7 +94,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -94,7 +94,7 @@ class StreamReaderTests(unittest.TestCase):
data = self.loop.run_until_complete(stream.read(0)) data = self.loop.run_until_complete(stream.read(0))
self.assertEqual(b'', data) self.assertEqual(b'', data)
self.assertEqual(len(self.DATA), stream._byte_count) self.assertEqual(self.DATA, stream._buffer)
def test_read(self): def test_read(self):
# Read bytes. # Read bytes.
...@@ -107,7 +107,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -107,7 +107,7 @@ class StreamReaderTests(unittest.TestCase):
data = self.loop.run_until_complete(read_task) data = self.loop.run_until_complete(read_task)
self.assertEqual(self.DATA, data) self.assertEqual(self.DATA, data)
self.assertFalse(stream._byte_count) self.assertEqual(b'', stream._buffer)
def test_read_line_breaks(self): def test_read_line_breaks(self):
# Read bytes without line breaks. # Read bytes without line breaks.
...@@ -118,7 +118,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -118,7 +118,7 @@ class StreamReaderTests(unittest.TestCase):
data = self.loop.run_until_complete(stream.read(5)) data = self.loop.run_until_complete(stream.read(5))
self.assertEqual(b'line1', data) self.assertEqual(b'line1', data)
self.assertEqual(5, stream._byte_count) self.assertEqual(b'line2', stream._buffer)
def test_read_eof(self): def test_read_eof(self):
# Read bytes, stop at eof. # Read bytes, stop at eof.
...@@ -131,7 +131,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -131,7 +131,7 @@ class StreamReaderTests(unittest.TestCase):
data = self.loop.run_until_complete(read_task) data = self.loop.run_until_complete(read_task)
self.assertEqual(b'', data) self.assertEqual(b'', data)
self.assertFalse(stream._byte_count) self.assertEqual(b'', stream._buffer)
def test_read_until_eof(self): def test_read_until_eof(self):
# Read all bytes until eof. # Read all bytes until eof.
...@@ -147,7 +147,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -147,7 +147,7 @@ class StreamReaderTests(unittest.TestCase):
data = self.loop.run_until_complete(read_task) data = self.loop.run_until_complete(read_task)
self.assertEqual(b'chunk1\nchunk2', data) self.assertEqual(b'chunk1\nchunk2', data)
self.assertFalse(stream._byte_count) self.assertEqual(b'', stream._buffer)
def test_read_exception(self): def test_read_exception(self):
stream = asyncio.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
...@@ -161,7 +161,8 @@ class StreamReaderTests(unittest.TestCase): ...@@ -161,7 +161,8 @@ class StreamReaderTests(unittest.TestCase):
ValueError, self.loop.run_until_complete, stream.read(2)) ValueError, self.loop.run_until_complete, stream.read(2))
def test_readline(self): def test_readline(self):
# Read one line. # Read one line. 'readline' will need to wait for the data
# to come from 'cb'
stream = asyncio.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(b'chunk1 ') stream.feed_data(b'chunk1 ')
read_task = asyncio.Task(stream.readline(), loop=self.loop) read_task = asyncio.Task(stream.readline(), loop=self.loop)
...@@ -174,30 +175,40 @@ class StreamReaderTests(unittest.TestCase): ...@@ -174,30 +175,40 @@ class StreamReaderTests(unittest.TestCase):
line = self.loop.run_until_complete(read_task) line = self.loop.run_until_complete(read_task)
self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) self.assertEqual(b'chunk1 chunk2 chunk3 \n', line)
self.assertEqual(len(b'\n chunk4')-1, stream._byte_count) self.assertEqual(b' chunk4', stream._buffer)
def test_readline_limit_with_existing_data(self): def test_readline_limit_with_existing_data(self):
stream = asyncio.StreamReader(3, loop=self.loop) # Read one line. The data is in StreamReader's buffer
# before the event loop is run.
stream = asyncio.StreamReader(limit=3, loop=self.loop)
stream.feed_data(b'li') stream.feed_data(b'li')
stream.feed_data(b'ne1\nline2\n') stream.feed_data(b'ne1\nline2\n')
self.assertRaises( self.assertRaises(
ValueError, self.loop.run_until_complete, stream.readline()) ValueError, self.loop.run_until_complete, stream.readline())
self.assertEqual([b'line2\n'], list(stream._buffer)) # The buffer should contain the remaining data after exception
self.assertEqual(b'line2\n', stream._buffer)
stream = asyncio.StreamReader(3, loop=self.loop) stream = asyncio.StreamReader(limit=3, loop=self.loop)
stream.feed_data(b'li') stream.feed_data(b'li')
stream.feed_data(b'ne1') stream.feed_data(b'ne1')
stream.feed_data(b'li') stream.feed_data(b'li')
self.assertRaises( self.assertRaises(
ValueError, self.loop.run_until_complete, stream.readline()) ValueError, self.loop.run_until_complete, stream.readline())
self.assertEqual([b'li'], list(stream._buffer)) # No b'\n' at the end. The 'limit' is set to 3. So before
self.assertEqual(2, stream._byte_count) # waiting for the new data in buffer, 'readline' will consume
# the entire buffer, and since the length of the consumed data
# is more than 3, it will raise a ValueError. The buffer is
# expected to be empty now.
self.assertEqual(b'', stream._buffer)
def test_readline_limit(self): def test_readline_limit(self):
stream = asyncio.StreamReader(7, loop=self.loop) # Read one line. StreamReaders are fed with data after
# their 'readline' methods are called.
stream = asyncio.StreamReader(limit=7, loop=self.loop)
def cb(): def cb():
stream.feed_data(b'chunk1') stream.feed_data(b'chunk1')
stream.feed_data(b'chunk2') stream.feed_data(b'chunk2')
...@@ -207,10 +218,25 @@ class StreamReaderTests(unittest.TestCase): ...@@ -207,10 +218,25 @@ class StreamReaderTests(unittest.TestCase):
self.assertRaises( self.assertRaises(
ValueError, self.loop.run_until_complete, stream.readline()) ValueError, self.loop.run_until_complete, stream.readline())
self.assertEqual([b'chunk3\n'], list(stream._buffer)) # The buffer had just one line of data, and after raising
self.assertEqual(7, stream._byte_count) # a ValueError it should be empty.
self.assertEqual(b'', stream._buffer)
stream = asyncio.StreamReader(limit=7, loop=self.loop)
def cb():
stream.feed_data(b'chunk1')
stream.feed_data(b'chunk2\n')
stream.feed_data(b'chunk3\n')
stream.feed_eof()
self.loop.call_soon(cb)
self.assertRaises(
ValueError, self.loop.run_until_complete, stream.readline())
self.assertEqual(b'chunk3\n', stream._buffer)
def test_readline_line_byte_count(self): def test_readline_nolimit_nowait(self):
# All needed data for the first 'readline' call will be
# in the buffer.
stream = asyncio.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(self.DATA[:6]) stream.feed_data(self.DATA[:6])
stream.feed_data(self.DATA[6:]) stream.feed_data(self.DATA[6:])
...@@ -218,7 +244,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -218,7 +244,7 @@ class StreamReaderTests(unittest.TestCase):
line = self.loop.run_until_complete(stream.readline()) line = self.loop.run_until_complete(stream.readline())
self.assertEqual(b'line1\n', line) self.assertEqual(b'line1\n', line)
self.assertEqual(len(self.DATA) - len(b'line1\n'), stream._byte_count) self.assertEqual(b'line2\nline3\n', stream._buffer)
def test_readline_eof(self): def test_readline_eof(self):
stream = asyncio.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
...@@ -244,9 +270,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -244,9 +270,7 @@ class StreamReaderTests(unittest.TestCase):
data = self.loop.run_until_complete(stream.read(7)) data = self.loop.run_until_complete(stream.read(7))
self.assertEqual(b'line2\nl', data) self.assertEqual(b'line2\nl', data)
self.assertEqual( self.assertEqual(b'ine3\n', stream._buffer)
len(self.DATA) - len(b'line1\n') - len(b'line2\nl'),
stream._byte_count)
def test_readline_exception(self): def test_readline_exception(self):
stream = asyncio.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
...@@ -258,6 +282,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -258,6 +282,7 @@ class StreamReaderTests(unittest.TestCase):
stream.set_exception(ValueError()) stream.set_exception(ValueError())
self.assertRaises( self.assertRaises(
ValueError, self.loop.run_until_complete, stream.readline()) ValueError, self.loop.run_until_complete, stream.readline())
self.assertEqual(b'', stream._buffer)
def test_readexactly_zero_or_less(self): def test_readexactly_zero_or_less(self):
# Read exact number of bytes (zero or less). # Read exact number of bytes (zero or less).
...@@ -266,11 +291,11 @@ class StreamReaderTests(unittest.TestCase): ...@@ -266,11 +291,11 @@ class StreamReaderTests(unittest.TestCase):
data = self.loop.run_until_complete(stream.readexactly(0)) data = self.loop.run_until_complete(stream.readexactly(0))
self.assertEqual(b'', data) self.assertEqual(b'', data)
self.assertEqual(len(self.DATA), stream._byte_count) self.assertEqual(self.DATA, stream._buffer)
data = self.loop.run_until_complete(stream.readexactly(-1)) data = self.loop.run_until_complete(stream.readexactly(-1))
self.assertEqual(b'', data) self.assertEqual(b'', data)
self.assertEqual(len(self.DATA), stream._byte_count) self.assertEqual(self.DATA, stream._buffer)
def test_readexactly(self): def test_readexactly(self):
# Read exact number of bytes. # Read exact number of bytes.
...@@ -287,7 +312,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -287,7 +312,7 @@ class StreamReaderTests(unittest.TestCase):
data = self.loop.run_until_complete(read_task) data = self.loop.run_until_complete(read_task)
self.assertEqual(self.DATA + self.DATA, data) self.assertEqual(self.DATA + self.DATA, data)
self.assertEqual(len(self.DATA), stream._byte_count) self.assertEqual(self.DATA, stream._buffer)
def test_readexactly_eof(self): def test_readexactly_eof(self):
# Read exact number of bytes (eof). # Read exact number of bytes (eof).
...@@ -306,7 +331,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -306,7 +331,7 @@ class StreamReaderTests(unittest.TestCase):
self.assertEqual(cm.exception.expected, n) self.assertEqual(cm.exception.expected, n)
self.assertEqual(str(cm.exception), self.assertEqual(str(cm.exception),
'18 bytes read on a total of 36 expected bytes') '18 bytes read on a total of 36 expected bytes')
self.assertFalse(stream._byte_count) self.assertEqual(b'', stream._buffer)
def test_readexactly_exception(self): def test_readexactly_exception(self):
stream = asyncio.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment