Commit 6dc1b33a authored by Antoine Pitrou

Issue #4574: reading a UTF-16-encoded text file crashes if \r falls on a 64-char boundary.

parent fb579177
...@@ -1282,25 +1282,23 @@ class IncrementalNewlineDecoder(codecs.IncrementalDecoder): ...@@ -1282,25 +1282,23 @@ class IncrementalNewlineDecoder(codecs.IncrementalDecoder):
""" """
def __init__(self, decoder, translate, errors='strict'): def __init__(self, decoder, translate, errors='strict'):
codecs.IncrementalDecoder.__init__(self, errors=errors) codecs.IncrementalDecoder.__init__(self, errors=errors)
self.buffer = b''
self.translate = translate self.translate = translate
self.decoder = decoder self.decoder = decoder
self.seennl = 0 self.seennl = 0
self.pendingcr = False
def decode(self, input, final=False): def decode(self, input, final=False):
# decode input (with the eventual \r from a previous pass) # decode input (with the eventual \r from a previous pass)
if self.buffer:
input = self.buffer + input
output = self.decoder.decode(input, final=final) output = self.decoder.decode(input, final=final)
if self.pendingcr and (output or final):
output = "\r" + output
self.pendingcr = False
# retain last \r even when not translating data: # retain last \r even when not translating data:
# then readline() is sure to get \r\n in one pass # then readline() is sure to get \r\n in one pass
if output.endswith("\r") and not final: if output.endswith("\r") and not final:
output = output[:-1] output = output[:-1]
self.buffer = b'\r' self.pendingcr = True
else:
self.buffer = b''
# Record which newlines are read # Record which newlines are read
crlf = output.count('\r\n') crlf = output.count('\r\n')
...@@ -1319,20 +1317,19 @@ class IncrementalNewlineDecoder(codecs.IncrementalDecoder): ...@@ -1319,20 +1317,19 @@ class IncrementalNewlineDecoder(codecs.IncrementalDecoder):
def getstate(self): def getstate(self):
buf, flag = self.decoder.getstate() buf, flag = self.decoder.getstate()
return buf + self.buffer, flag flag <<= 1
if self.pendingcr:
flag |= 1
return buf, flag
def setstate(self, state): def setstate(self, state):
buf, flag = state buf, flag = state
if buf.endswith(b'\r'): self.pendingcr = bool(flag & 1)
self.buffer = b'\r' self.decoder.setstate((buf, flag >> 1))
buf = buf[:-1]
else:
self.buffer = b''
self.decoder.setstate((buf, flag))
def reset(self): def reset(self):
self.seennl = 0 self.seennl = 0
self.buffer = b'' self.pendingcr = False
self.decoder.reset() self.decoder.reset()
_LF = 1 _LF = 1
......
...@@ -679,8 +679,9 @@ class StatefulIncrementalDecoder(codecs.IncrementalDecoder): ...@@ -679,8 +679,9 @@ class StatefulIncrementalDecoder(codecs.IncrementalDecoder):
@classmethod @classmethod
def lookupTestDecoder(cls, name): def lookupTestDecoder(cls, name):
if cls.codecEnabled and name == 'test_decoder': if cls.codecEnabled and name == 'test_decoder':
latin1 = codecs.lookup('latin-1')
return codecs.CodecInfo( return codecs.CodecInfo(
name='test_decoder', encode=None, decode=None, name='test_decoder', encode=latin1.encode, decode=None,
incrementalencoder=None, incrementalencoder=None,
streamreader=None, streamwriter=None, streamreader=None, streamwriter=None,
incrementaldecoder=cls) incrementaldecoder=cls)
...@@ -840,8 +841,11 @@ class TextIOWrapperTest(unittest.TestCase): ...@@ -840,8 +841,11 @@ class TextIOWrapperTest(unittest.TestCase):
[ '\r\n', [ "unix\nwindows\r\n", "os9\rlast\nnonl" ] ], [ '\r\n', [ "unix\nwindows\r\n", "os9\rlast\nnonl" ] ],
[ '\r', [ "unix\nwindows\r", "\nos9\r", "last\nnonl" ] ], [ '\r', [ "unix\nwindows\r", "\nos9\r", "last\nnonl" ] ],
] ]
encodings = (
encodings = ('utf-8', 'latin-1') 'utf-8', 'latin-1',
'utf-16', 'utf-16-le', 'utf-16-be',
'utf-32', 'utf-32-le', 'utf-32-be',
)
# Try a range of buffer sizes to test the case where \r is the last # Try a range of buffer sizes to test the case where \r is the last
# character in TextIOWrapper._pending_line. # character in TextIOWrapper._pending_line.
...@@ -1195,56 +1199,84 @@ class TextIOWrapperTest(unittest.TestCase): ...@@ -1195,56 +1199,84 @@ class TextIOWrapperTest(unittest.TestCase):
self.assertEqual(buffer.seekable(), txt.seekable()) self.assertEqual(buffer.seekable(), txt.seekable())
def test_newline_decoder(self): def check_newline_decoder_utf8(self, decoder):
import codecs # UTF-8 specific tests for a newline decoder
decoder = codecs.getincrementaldecoder("utf-8")() def _check_decode(b, s, **kwargs):
decoder = io.IncrementalNewlineDecoder(decoder, translate=True) # We exercise getstate() / setstate() as well as decode()
state = decoder.getstate()
self.assertEquals(decoder.decode(b, **kwargs), s)
decoder.setstate(state)
self.assertEquals(decoder.decode(b, **kwargs), s)
self.assertEquals(decoder.decode(b'\xe8\xa2\x88'), "\u8888") _check_decode(b'\xe8\xa2\x88', "\u8888")
self.assertEquals(decoder.decode(b'\xe8'), "") _check_decode(b'\xe8', "")
self.assertEquals(decoder.decode(b'\xa2'), "") _check_decode(b'\xa2', "")
self.assertEquals(decoder.decode(b'\x88'), "\u8888") _check_decode(b'\x88', "\u8888")
self.assertEquals(decoder.decode(b'\xe8'), "") _check_decode(b'\xe8', "")
self.assertRaises(UnicodeDecodeError, decoder.decode, b'', final=True) _check_decode(b'\xa2', "")
_check_decode(b'\x88', "\u8888")
decoder.setstate((b'', 0)) _check_decode(b'\xe8', "")
self.assertEquals(decoder.decode(b'\n'), "\n") self.assertRaises(UnicodeDecodeError, decoder.decode, b'', final=True)
self.assertEquals(decoder.decode(b'\r'), "")
self.assertEquals(decoder.decode(b'', final=True), "\n")
self.assertEquals(decoder.decode(b'\r', final=True), "\n")
self.assertEquals(decoder.decode(b'\r'), "")
self.assertEquals(decoder.decode(b'a'), "\na")
self.assertEquals(decoder.decode(b'\r\r\n'), "\n\n")
self.assertEquals(decoder.decode(b'\r'), "")
self.assertEquals(decoder.decode(b'\r'), "\n")
self.assertEquals(decoder.decode(b'\na'), "\na")
self.assertEquals(decoder.decode(b'\xe8\xa2\x88\r\n'), "\u8888\n")
self.assertEquals(decoder.decode(b'\xe8\xa2\x88'), "\u8888")
self.assertEquals(decoder.decode(b'\n'), "\n")
self.assertEquals(decoder.decode(b'\xe8\xa2\x88\r'), "\u8888")
self.assertEquals(decoder.decode(b'\n'), "\n")
decoder = codecs.getincrementaldecoder("utf-8")() decoder.reset()
decoder = io.IncrementalNewlineDecoder(decoder, translate=True) _check_decode(b'\n', "\n")
_check_decode(b'\r', "")
_check_decode(b'', "\n", final=True)
_check_decode(b'\r', "\n", final=True)
_check_decode(b'\r', "")
_check_decode(b'a', "\na")
_check_decode(b'\r\r\n', "\n\n")
_check_decode(b'\r', "")
_check_decode(b'\r', "\n")
_check_decode(b'\na', "\na")
_check_decode(b'\xe8\xa2\x88\r\n', "\u8888\n")
_check_decode(b'\xe8\xa2\x88', "\u8888")
_check_decode(b'\n', "\n")
_check_decode(b'\xe8\xa2\x88\r', "\u8888")
_check_decode(b'\n', "\n")
def check_newline_decoder(self, decoder, encoding):
result = []
encoder = codecs.getincrementalencoder(encoding)()
def _decode_bytewise(s):
for b in encoder.encode(s):
result.append(decoder.decode(bytes([b])))
self.assertEquals(decoder.newlines, None) self.assertEquals(decoder.newlines, None)
decoder.decode(b"abc\n\r") _decode_bytewise("abc\n\r")
self.assertEquals(decoder.newlines, '\n') self.assertEquals(decoder.newlines, '\n')
decoder.decode(b"\nabc") _decode_bytewise("\nabc")
self.assertEquals(decoder.newlines, ('\n', '\r\n')) self.assertEquals(decoder.newlines, ('\n', '\r\n'))
decoder.decode(b"abc\r") _decode_bytewise("abc\r")
self.assertEquals(decoder.newlines, ('\n', '\r\n')) self.assertEquals(decoder.newlines, ('\n', '\r\n'))
decoder.decode(b"abc") _decode_bytewise("abc")
self.assertEquals(decoder.newlines, ('\r', '\n', '\r\n')) self.assertEquals(decoder.newlines, ('\r', '\n', '\r\n'))
decoder.decode(b"abc\r") _decode_bytewise("abc\r")
self.assertEquals("".join(result), "abc\n\nabcabc\nabcabc")
decoder.reset() decoder.reset()
self.assertEquals(decoder.decode(b"abc"), "abc") self.assertEquals(decoder.decode("abc".encode(encoding)), "abc")
self.assertEquals(decoder.newlines, None) self.assertEquals(decoder.newlines, None)
def test_newline_decoder(self):
encodings = (
'utf-8', 'latin-1',
'utf-16', 'utf-16-le', 'utf-16-be',
'utf-32', 'utf-32-le', 'utf-32-be',
)
for enc in encodings:
decoder = codecs.getincrementaldecoder(enc)()
decoder = io.IncrementalNewlineDecoder(decoder, translate=True)
self.check_newline_decoder(decoder, enc)
decoder = codecs.getincrementaldecoder("utf-8")()
decoder = io.IncrementalNewlineDecoder(decoder, translate=True)
self.check_newline_decoder_utf8(decoder)
# XXX Tests for open() # XXX Tests for open()
class MiscIOTest(unittest.TestCase): class MiscIOTest(unittest.TestCase):
......
...@@ -45,6 +45,9 @@ Core and Builtins ...@@ -45,6 +45,9 @@ Core and Builtins
Library Library
------- -------
- Issue #4574: reading a UTF-16-encoded text file crashes if \r falls on a
64-char boundary.
- Issue #4223: inspect.getsource() will now correctly display source code - Issue #4223: inspect.getsource() will now correctly display source code
for packages loaded via zipimport (or any other conformant PEP 302 for packages loaded via zipimport (or any other conformant PEP 302
loader). Original patch by Alexander Belopolsky. loader). Original patch by Alexander Belopolsky.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment