Commit cfe5f20f authored by Guido van Rossum

Got test_pickletools and test_pickle working.

(Alas, test_cpickle is still broken.)
parent f9e91c9c
...@@ -465,7 +465,7 @@ class Pickler: ...@@ -465,7 +465,7 @@ class Pickler:
self.write(BININT1 + bytes([obj])) self.write(BININT1 + bytes([obj]))
return return
if obj <= 0xffff: if obj <= 0xffff:
self.write(BININT2, bytes([obj&0xff, obj>>8])) self.write(BININT2 + bytes([obj&0xff, obj>>8]))
return return
# Next check for 4-byte signed ints: # Next check for 4-byte signed ints:
high_bits = obj >> 31 # note that Python shift sign-extends high_bits = obj >> 31 # note that Python shift sign-extends
...@@ -820,6 +820,7 @@ class Unpickler: ...@@ -820,6 +820,7 @@ class Unpickler:
key = read(1) key = read(1)
if not key: if not key:
raise EOFError raise EOFError
assert isinstance(key, bytes)
dispatch[key[0]](self) dispatch[key[0]](self)
except _Stop as stopinst: except _Stop as stopinst:
return stopinst.value return stopinst.value
...@@ -892,7 +893,7 @@ class Unpickler: ...@@ -892,7 +893,7 @@ class Unpickler:
dispatch[BININT1[0]] = load_binint1 dispatch[BININT1[0]] = load_binint1
def load_binint2(self): def load_binint2(self):
self.append(mloads(b'i' + self.read(2) + '\000\000')) self.append(mloads(b'i' + self.read(2) + b'\000\000'))
dispatch[BININT2[0]] = load_binint2 dispatch[BININT2[0]] = load_binint2
def load_long(self): def load_long(self):
...@@ -1111,7 +1112,7 @@ class Unpickler: ...@@ -1111,7 +1112,7 @@ class Unpickler:
dispatch[DUP[0]] = load_dup dispatch[DUP[0]] = load_dup
def load_get(self): def load_get(self):
self.append(self.memo[self.readline()[:-1]]) self.append(self.memo[str8(self.readline())[:-1]])
dispatch[GET[0]] = load_get dispatch[GET[0]] = load_get
def load_binget(self): def load_binget(self):
...@@ -1226,24 +1227,24 @@ def encode_long(x): ...@@ -1226,24 +1227,24 @@ def encode_long(x):
byte in the LONG1 pickling context. byte in the LONG1 pickling context.
>>> encode_long(0) >>> encode_long(0)
'' b''
>>> encode_long(255) >>> encode_long(255)
'\xff\x00' b'\xff\x00'
>>> encode_long(32767) >>> encode_long(32767)
'\xff\x7f' b'\xff\x7f'
>>> encode_long(-256) >>> encode_long(-256)
'\x00\xff' b'\x00\xff'
>>> encode_long(-32768) >>> encode_long(-32768)
'\x00\x80' b'\x00\x80'
>>> encode_long(-128) >>> encode_long(-128)
'\x80' b'\x80'
>>> encode_long(127) >>> encode_long(127)
'\x7f' b'\x7f'
>>> >>>
""" """
if x == 0: if x == 0:
return '' return b''
if x > 0: if x > 0:
ashex = hex(x) ashex = hex(x)
assert ashex.startswith("0x") assert ashex.startswith("0x")
...@@ -1284,24 +1285,24 @@ def encode_long(x): ...@@ -1284,24 +1285,24 @@ def encode_long(x):
ashex = ashex[2:] ashex = ashex[2:]
assert len(ashex) & 1 == 0, (x, ashex) assert len(ashex) & 1 == 0, (x, ashex)
binary = _binascii.unhexlify(ashex) binary = _binascii.unhexlify(ashex)
return binary[::-1] return bytes(binary[::-1])
def decode_long(data): def decode_long(data):
r"""Decode a long from a two's complement little-endian binary string. r"""Decode a long from a two's complement little-endian binary string.
>>> decode_long('') >>> decode_long(b'')
0 0
>>> decode_long("\xff\x00") >>> decode_long(b"\xff\x00")
255 255
>>> decode_long("\xff\x7f") >>> decode_long(b"\xff\x7f")
32767 32767
>>> decode_long("\x00\xff") >>> decode_long(b"\x00\xff")
-256 -256
>>> decode_long("\x00\x80") >>> decode_long(b"\x00\x80")
-32768 -32768
>>> decode_long("\x80") >>> decode_long(b"\x80")
-128 -128
>>> decode_long("\x7f") >>> decode_long(b"\x7f")
127 127
""" """
...@@ -1310,7 +1311,7 @@ def decode_long(data): ...@@ -1310,7 +1311,7 @@ def decode_long(data):
return 0 return 0
ashex = _binascii.hexlify(data[::-1]) ashex = _binascii.hexlify(data[::-1])
n = int(ashex, 16) # quadratic time before Python 2.3; linear now n = int(ashex, 16) # quadratic time before Python 2.3; linear now
if data[-1] >= '\x80': if data[-1] >= 0x80:
n -= 1 << (nbytes * 8) n -= 1 << (nbytes * 8)
return n return n
...@@ -1320,15 +1321,19 @@ def dump(obj, file, protocol=None): ...@@ -1320,15 +1321,19 @@ def dump(obj, file, protocol=None):
Pickler(file, protocol).dump(obj) Pickler(file, protocol).dump(obj)
def dumps(obj, protocol=None): def dumps(obj, protocol=None):
file = io.BytesIO() f = io.BytesIO()
Pickler(file, protocol).dump(obj) Pickler(f, protocol).dump(obj)
return file.getvalue() res = f.getvalue()
assert isinstance(res, bytes)
return res
def load(file): def load(file):
return Unpickler(file).load() return Unpickler(file).load()
def loads(str): def loads(s):
file = io.BytesIO(str) if isinstance(s, str):
raise TypeError("Can't load pickle from unicode string")
file = io.BytesIO(s)
return Unpickler(file).load() return Unpickler(file).load()
# Doctest # Doctest
......
...@@ -202,14 +202,14 @@ from struct import unpack as _unpack ...@@ -202,14 +202,14 @@ from struct import unpack as _unpack
def read_uint1(f): def read_uint1(f):
r""" r"""
>>> import StringIO >>> import io
>>> read_uint1(StringIO.StringIO('\xff')) >>> read_uint1(io.BytesIO(b'\xff'))
255 255
""" """
data = f.read(1) data = f.read(1)
if data: if data:
return ord(data) return data[0]
raise ValueError("not enough data in stream to read uint1") raise ValueError("not enough data in stream to read uint1")
uint1 = ArgumentDescriptor( uint1 = ArgumentDescriptor(
...@@ -221,10 +221,10 @@ uint1 = ArgumentDescriptor( ...@@ -221,10 +221,10 @@ uint1 = ArgumentDescriptor(
def read_uint2(f): def read_uint2(f):
r""" r"""
>>> import StringIO >>> import io
>>> read_uint2(StringIO.StringIO('\xff\x00')) >>> read_uint2(io.BytesIO(b'\xff\x00'))
255 255
>>> read_uint2(StringIO.StringIO('\xff\xff')) >>> read_uint2(io.BytesIO(b'\xff\xff'))
65535 65535
""" """
...@@ -242,10 +242,10 @@ uint2 = ArgumentDescriptor( ...@@ -242,10 +242,10 @@ uint2 = ArgumentDescriptor(
def read_int4(f): def read_int4(f):
r""" r"""
>>> import StringIO >>> import io
>>> read_int4(StringIO.StringIO('\xff\x00\x00\x00')) >>> read_int4(io.BytesIO(b'\xff\x00\x00\x00'))
255 255
>>> read_int4(StringIO.StringIO('\x00\x00\x00\x80')) == -(2**31) >>> read_int4(io.BytesIO(b'\x00\x00\x00\x80')) == -(2**31)
True True
""" """
...@@ -261,34 +261,48 @@ int4 = ArgumentDescriptor( ...@@ -261,34 +261,48 @@ int4 = ArgumentDescriptor(
doc="Four-byte signed integer, little-endian, 2's complement.") doc="Four-byte signed integer, little-endian, 2's complement.")
def readline(f):
    """Read a line from a binary file.

    Bytes are pulled from ``f`` one at a time with ``f.read(1)`` so that
    nothing past the end of the line is consumed from the stream.

    Returns the bytes read, including the terminating newline byte when
    one was found.  If the stream ends first, whatever was read so far is
    returned (an empty bytes object at end of stream).
    """
    # XXX Slow but at least correct
    # Collect the single-byte reads and join once at the end; repeatedly
    # concatenating onto an immutable bytes object copies the whole
    # accumulated line on every byte (quadratic in the line length).
    pieces = []
    while True:
        c = f.read(1)
        if not c:
            # End of stream before a newline was seen.
            break
        pieces.append(c)
        if c == b'\n':
            break
    return b''.join(pieces)
def read_stringnl(f, decode=True, stripquotes=True): def read_stringnl(f, decode=True, stripquotes=True):
r""" r"""
>>> import StringIO >>> import io
>>> read_stringnl(StringIO.StringIO("'abcd'\nefg\n")) >>> read_stringnl(io.BytesIO(b"'abcd'\nefg\n"))
'abcd' 'abcd'
>>> read_stringnl(StringIO.StringIO("\n")) >>> read_stringnl(io.BytesIO(b"\n"))
Traceback (most recent call last): Traceback (most recent call last):
... ...
ValueError: no string quotes around '' ValueError: no string quotes around b''
>>> read_stringnl(StringIO.StringIO("\n"), stripquotes=False) >>> read_stringnl(io.BytesIO(b"\n"), stripquotes=False)
'' ''
>>> read_stringnl(StringIO.StringIO("''\n")) >>> read_stringnl(io.BytesIO(b"''\n"))
'' ''
>>> read_stringnl(StringIO.StringIO('"abcd"')) >>> read_stringnl(io.BytesIO(b'"abcd"'))
Traceback (most recent call last): Traceback (most recent call last):
... ...
ValueError: no newline found when trying to read stringnl ValueError: no newline found when trying to read stringnl
Embedded escapes are undone in the result. Embedded escapes are undone in the result.
>>> read_stringnl(StringIO.StringIO(r"'a\n\\b\x00c\td'" + "\n'e'")) >>> read_stringnl(io.BytesIO(br"'a\n\\b\x00c\td'" + b"\n'e'"))
'a\n\\b\x00c\td' 'a\n\\b\x00c\td'
""" """
data = f.readline() data = readline(f)
if not data.endswith('\n'): if not data.endswith('\n'):
raise ValueError("no newline found when trying to read stringnl") raise ValueError("no newline found when trying to read stringnl")
data = data[:-1] # lose the newline data = data[:-1] # lose the newline
...@@ -336,8 +350,8 @@ stringnl_noescape = ArgumentDescriptor( ...@@ -336,8 +350,8 @@ stringnl_noescape = ArgumentDescriptor(
def read_stringnl_noescape_pair(f): def read_stringnl_noescape_pair(f):
r""" r"""
>>> import StringIO >>> import io
>>> read_stringnl_noescape_pair(StringIO.StringIO("Queue\nEmpty\njunk")) >>> read_stringnl_noescape_pair(io.BytesIO(b"Queue\nEmpty\njunk"))
'Queue Empty' 'Queue Empty'
""" """
...@@ -358,12 +372,12 @@ stringnl_noescape_pair = ArgumentDescriptor( ...@@ -358,12 +372,12 @@ stringnl_noescape_pair = ArgumentDescriptor(
def read_string4(f): def read_string4(f):
r""" r"""
>>> import StringIO >>> import io
>>> read_string4(StringIO.StringIO("\x00\x00\x00\x00abc")) >>> read_string4(io.BytesIO(b"\x00\x00\x00\x00abc"))
'' ''
>>> read_string4(StringIO.StringIO("\x03\x00\x00\x00abcdef")) >>> read_string4(io.BytesIO(b"\x03\x00\x00\x00abcdef"))
'abc' 'abc'
>>> read_string4(StringIO.StringIO("\x00\x00\x00\x03abcdef")) >>> read_string4(io.BytesIO(b"\x00\x00\x00\x03abcdef"))
Traceback (most recent call last): Traceback (most recent call last):
... ...
ValueError: expected 50331648 bytes in a string4, but only 6 remain ValueError: expected 50331648 bytes in a string4, but only 6 remain
...@@ -374,7 +388,7 @@ def read_string4(f): ...@@ -374,7 +388,7 @@ def read_string4(f):
raise ValueError("string4 byte count < 0: %d" % n) raise ValueError("string4 byte count < 0: %d" % n)
data = f.read(n) data = f.read(n)
if len(data) == n: if len(data) == n:
return data return data.decode("latin-1")
raise ValueError("expected %d bytes in a string4, but only %d remain" % raise ValueError("expected %d bytes in a string4, but only %d remain" %
(n, len(data))) (n, len(data)))
...@@ -392,10 +406,10 @@ string4 = ArgumentDescriptor( ...@@ -392,10 +406,10 @@ string4 = ArgumentDescriptor(
def read_string1(f): def read_string1(f):
r""" r"""
>>> import StringIO >>> import io
>>> read_string1(StringIO.StringIO("\x00")) >>> read_string1(io.BytesIO(b"\x00"))
'' ''
>>> read_string1(StringIO.StringIO("\x03abcdef")) >>> read_string1(io.BytesIO(b"\x03abcdef"))
'abc' 'abc'
""" """
...@@ -403,7 +417,7 @@ def read_string1(f): ...@@ -403,7 +417,7 @@ def read_string1(f):
assert n >= 0 assert n >= 0
data = f.read(n) data = f.read(n)
if len(data) == n: if len(data) == n:
return data return data.decode("latin-1")
raise ValueError("expected %d bytes in a string1, but only %d remain" % raise ValueError("expected %d bytes in a string1, but only %d remain" %
(n, len(data))) (n, len(data)))
...@@ -421,12 +435,12 @@ string1 = ArgumentDescriptor( ...@@ -421,12 +435,12 @@ string1 = ArgumentDescriptor(
def read_unicodestringnl(f): def read_unicodestringnl(f):
r""" r"""
>>> import StringIO >>> import io
>>> read_unicodestringnl(StringIO.StringIO("abc\uabcd\njunk")) >>> read_unicodestringnl(io.BytesIO(b"abc\\uabcd\njunk")) == 'abc\uabcd'
u'abc\uabcd' True
""" """
data = f.readline() data = readline(f)
if not data.endswith('\n'): if not data.endswith('\n'):
raise ValueError("no newline found when trying to read " raise ValueError("no newline found when trying to read "
"unicodestringnl") "unicodestringnl")
...@@ -446,17 +460,17 @@ unicodestringnl = ArgumentDescriptor( ...@@ -446,17 +460,17 @@ unicodestringnl = ArgumentDescriptor(
def read_unicodestring4(f): def read_unicodestring4(f):
r""" r"""
>>> import StringIO >>> import io
>>> s = u'abcd\uabcd' >>> s = 'abcd\uabcd'
>>> enc = s.encode('utf-8') >>> enc = s.encode('utf-8')
>>> enc >>> enc
'abcd\xea\xaf\x8d' b'abcd\xea\xaf\x8d'
>>> n = chr(len(enc)) + chr(0) * 3 # little-endian 4-byte length >>> n = bytes([len(enc), 0, 0, 0]) # little-endian 4-byte length
>>> t = read_unicodestring4(StringIO.StringIO(n + enc + 'junk')) >>> t = read_unicodestring4(io.BytesIO(n + enc + b'junk'))
>>> s == t >>> s == t
True True
>>> read_unicodestring4(StringIO.StringIO(n + enc[:-1])) >>> read_unicodestring4(io.BytesIO(n + enc[:-1]))
Traceback (most recent call last): Traceback (most recent call last):
... ...
ValueError: expected 7 bytes in a unicodestring4, but only 6 remain ValueError: expected 7 bytes in a unicodestring4, but only 6 remain
...@@ -486,14 +500,14 @@ unicodestring4 = ArgumentDescriptor( ...@@ -486,14 +500,14 @@ unicodestring4 = ArgumentDescriptor(
def read_decimalnl_short(f): def read_decimalnl_short(f):
r""" r"""
>>> import StringIO >>> import io
>>> read_decimalnl_short(StringIO.StringIO("1234\n56")) >>> read_decimalnl_short(io.BytesIO(b"1234\n56"))
1234 1234
>>> read_decimalnl_short(StringIO.StringIO("1234L\n56")) >>> read_decimalnl_short(io.BytesIO(b"1234L\n56"))
Traceback (most recent call last): Traceback (most recent call last):
... ...
ValueError: trailing 'L' not allowed in '1234L' ValueError: trailing 'L' not allowed in b'1234L'
""" """
s = read_stringnl(f, decode=False, stripquotes=False) s = read_stringnl(f, decode=False, stripquotes=False)
...@@ -515,12 +529,12 @@ def read_decimalnl_short(f): ...@@ -515,12 +529,12 @@ def read_decimalnl_short(f):
def read_decimalnl_long(f): def read_decimalnl_long(f):
r""" r"""
>>> import StringIO >>> import io
>>> read_decimalnl_long(StringIO.StringIO("1234L\n56")) >>> read_decimalnl_long(io.BytesIO(b"1234L\n56"))
1234 1234
>>> read_decimalnl_long(StringIO.StringIO("123456789012345678901234L\n6")) >>> read_decimalnl_long(io.BytesIO(b"123456789012345678901234L\n6"))
123456789012345678901234 123456789012345678901234
""" """
...@@ -554,8 +568,8 @@ decimalnl_long = ArgumentDescriptor( ...@@ -554,8 +568,8 @@ decimalnl_long = ArgumentDescriptor(
def read_floatnl(f): def read_floatnl(f):
r""" r"""
>>> import StringIO >>> import io
>>> read_floatnl(StringIO.StringIO("-1.25\n6")) >>> read_floatnl(io.BytesIO(b"-1.25\n6"))
-1.25 -1.25
""" """
s = read_stringnl(f, decode=False, stripquotes=False) s = read_stringnl(f, decode=False, stripquotes=False)
...@@ -576,11 +590,11 @@ floatnl = ArgumentDescriptor( ...@@ -576,11 +590,11 @@ floatnl = ArgumentDescriptor(
def read_float8(f): def read_float8(f):
r""" r"""
>>> import StringIO, struct >>> import io, struct
>>> raw = struct.pack(">d", -1.25) >>> raw = struct.pack(">d", -1.25)
>>> raw >>> raw
'\xbf\xf4\x00\x00\x00\x00\x00\x00' b'\xbf\xf4\x00\x00\x00\x00\x00\x00'
>>> read_float8(StringIO.StringIO(raw + "\n")) >>> read_float8(io.BytesIO(raw + b"\n"))
-1.25 -1.25
""" """
...@@ -614,16 +628,16 @@ from pickle import decode_long ...@@ -614,16 +628,16 @@ from pickle import decode_long
def read_long1(f): def read_long1(f):
r""" r"""
>>> import StringIO >>> import io
>>> read_long1(StringIO.StringIO("\x00")) >>> read_long1(io.BytesIO(b"\x00"))
0 0
>>> read_long1(StringIO.StringIO("\x02\xff\x00")) >>> read_long1(io.BytesIO(b"\x02\xff\x00"))
255 255
>>> read_long1(StringIO.StringIO("\x02\xff\x7f")) >>> read_long1(io.BytesIO(b"\x02\xff\x7f"))
32767 32767
>>> read_long1(StringIO.StringIO("\x02\x00\xff")) >>> read_long1(io.BytesIO(b"\x02\x00\xff"))
-256 -256
>>> read_long1(StringIO.StringIO("\x02\x00\x80")) >>> read_long1(io.BytesIO(b"\x02\x00\x80"))
-32768 -32768
""" """
...@@ -646,16 +660,16 @@ long1 = ArgumentDescriptor( ...@@ -646,16 +660,16 @@ long1 = ArgumentDescriptor(
def read_long4(f): def read_long4(f):
r""" r"""
>>> import StringIO >>> import io
>>> read_long4(StringIO.StringIO("\x02\x00\x00\x00\xff\x00")) >>> read_long4(io.BytesIO(b"\x02\x00\x00\x00\xff\x00"))
255 255
>>> read_long4(StringIO.StringIO("\x02\x00\x00\x00\xff\x7f")) >>> read_long4(io.BytesIO(b"\x02\x00\x00\x00\xff\x7f"))
32767 32767
>>> read_long4(StringIO.StringIO("\x02\x00\x00\x00\x00\xff")) >>> read_long4(io.BytesIO(b"\x02\x00\x00\x00\x00\xff"))
-256 -256
>>> read_long4(StringIO.StringIO("\x02\x00\x00\x00\x00\x80")) >>> read_long4(io.BytesIO(b"\x02\x00\x00\x00\x00\x80"))
-32768 -32768
>>> read_long1(StringIO.StringIO("\x00\x00\x00\x00")) >>> read_long1(io.BytesIO(b"\x00\x00\x00\x00"))
0 0
""" """
...@@ -701,7 +715,7 @@ class StackObject(object): ...@@ -701,7 +715,7 @@ class StackObject(object):
) )
def __init__(self, name, obtype, doc): def __init__(self, name, obtype, doc):
assert isinstance(name, str) assert isinstance(name, basestring)
self.name = name self.name = name
assert isinstance(obtype, type) or isinstance(obtype, tuple) assert isinstance(obtype, type) or isinstance(obtype, tuple)
...@@ -710,7 +724,7 @@ class StackObject(object): ...@@ -710,7 +724,7 @@ class StackObject(object):
assert isinstance(contained, type) assert isinstance(contained, type)
self.obtype = obtype self.obtype = obtype
assert isinstance(doc, str) assert isinstance(doc, basestring)
self.doc = doc self.doc = doc
def __repr__(self): def __repr__(self):
...@@ -846,10 +860,10 @@ class OpcodeInfo(object): ...@@ -846,10 +860,10 @@ class OpcodeInfo(object):
def __init__(self, name, code, arg, def __init__(self, name, code, arg,
stack_before, stack_after, proto, doc): stack_before, stack_after, proto, doc):
assert isinstance(name, str) assert isinstance(name, basestring)
self.name = name self.name = name
assert isinstance(code, str) assert isinstance(code, basestring)
assert len(code) == 1 assert len(code) == 1
self.code = code self.code = code
...@@ -869,7 +883,7 @@ class OpcodeInfo(object): ...@@ -869,7 +883,7 @@ class OpcodeInfo(object):
assert isinstance(proto, int) and 0 <= proto <= 2 assert isinstance(proto, int) and 0 <= proto <= 2
self.proto = proto self.proto = proto
assert isinstance(doc, str) assert isinstance(doc, basestring)
self.doc = doc self.doc = doc
I = OpcodeInfo I = OpcodeInfo
...@@ -1819,10 +1833,9 @@ def genops(pickle): ...@@ -1819,10 +1833,9 @@ def genops(pickle):
to query its current position) pos is None. to query its current position) pos is None.
""" """
import cStringIO as StringIO if isinstance(pickle, bytes):
import io
if isinstance(pickle, str): pickle = io.BytesIO(pickle)
pickle = StringIO.StringIO(pickle)
if hasattr(pickle, "tell"): if hasattr(pickle, "tell"):
getpos = pickle.tell getpos = pickle.tell
...@@ -1832,9 +1845,9 @@ def genops(pickle): ...@@ -1832,9 +1845,9 @@ def genops(pickle):
while True: while True:
pos = getpos() pos = getpos()
code = pickle.read(1) code = pickle.read(1)
opcode = code2op.get(code) opcode = code2op.get(code.decode("latin-1"))
if opcode is None: if opcode is None:
if code == "": if code == b"":
raise ValueError("pickle exhausted before seeing STOP") raise ValueError("pickle exhausted before seeing STOP")
else: else:
raise ValueError("at position %s, opcode %r unknown" % ( raise ValueError("at position %s, opcode %r unknown" % (
...@@ -1845,7 +1858,7 @@ def genops(pickle): ...@@ -1845,7 +1858,7 @@ def genops(pickle):
else: else:
arg = opcode.arg.reader(pickle) arg = opcode.arg.reader(pickle)
yield opcode, arg, pos yield opcode, arg, pos
if code == '.': if code == b'.':
assert opcode.name == 'STOP' assert opcode.name == 'STOP'
break break
...@@ -1995,7 +2008,7 @@ class _Example: ...@@ -1995,7 +2008,7 @@ class _Example:
_dis_test = r""" _dis_test = r"""
>>> import pickle >>> import pickle
>>> x = [1, 2, (3, 4), {'abc': u"def"}] >>> x = [1, 2, (3, 4), {str8('abc'): "def"}]
>>> pkl = pickle.dumps(x, 0) >>> pkl = pickle.dumps(x, 0)
>>> dis(pkl) >>> dis(pkl)
0: ( MARK 0: ( MARK
...@@ -2016,7 +2029,7 @@ _dis_test = r""" ...@@ -2016,7 +2029,7 @@ _dis_test = r"""
27: p PUT 2 27: p PUT 2
30: S STRING 'abc' 30: S STRING 'abc'
37: p PUT 3 37: p PUT 3
40: V UNICODE u'def' 40: V UNICODE 'def'
45: p PUT 4 45: p PUT 4
48: s SETITEM 48: s SETITEM
49: a APPEND 49: a APPEND
...@@ -2041,7 +2054,7 @@ Try again with a "binary" pickle. ...@@ -2041,7 +2054,7 @@ Try again with a "binary" pickle.
17: q BINPUT 2 17: q BINPUT 2
19: U SHORT_BINSTRING 'abc' 19: U SHORT_BINSTRING 'abc'
24: q BINPUT 3 24: q BINPUT 3
26: X BINUNICODE u'def' 26: X BINUNICODE 'def'
34: q BINPUT 4 34: q BINPUT 4
36: s SETITEM 36: s SETITEM
37: e APPENDS (MARK at 3) 37: e APPENDS (MARK at 3)
...@@ -2216,13 +2229,14 @@ highest protocol among opcodes = 2 ...@@ -2216,13 +2229,14 @@ highest protocol among opcodes = 2
_memo_test = r""" _memo_test = r"""
>>> import pickle >>> import pickle
>>> from StringIO import StringIO >>> import io
>>> f = StringIO() >>> f = io.BytesIO()
>>> p = pickle.Pickler(f, 2) >>> p = pickle.Pickler(f, 2)
>>> x = [1, 2, 3] >>> x = [1, 2, 3]
>>> p.dump(x) >>> p.dump(x)
>>> p.dump(x) >>> p.dump(x)
>>> f.seek(0) >>> f.seek(0)
0
>>> memo = {} >>> memo = {}
>>> dis(f, memo=memo) >>> dis(f, memo=memo)
0: \x80 PROTO 2 0: \x80 PROTO 2
......
...@@ -21,7 +21,7 @@ protocols = range(pickle.HIGHEST_PROTOCOL + 1) ...@@ -21,7 +21,7 @@ protocols = range(pickle.HIGHEST_PROTOCOL + 1)
# Return True if opcode code appears in the pickle, else False. # Return True if opcode code appears in the pickle, else False.
def opcode_in_pickle(code, pickle): def opcode_in_pickle(code, pickle):
for op, dummy, dummy in pickletools.genops(pickle): for op, dummy, dummy in pickletools.genops(pickle):
if op.code == code: if op.code == code.decode("latin-1"):
return True return True
return False return False
...@@ -29,7 +29,7 @@ def opcode_in_pickle(code, pickle): ...@@ -29,7 +29,7 @@ def opcode_in_pickle(code, pickle):
def count_opcode(code, pickle): def count_opcode(code, pickle):
n = 0 n = 0
for op, dummy, dummy in pickletools.genops(pickle): for op, dummy, dummy in pickletools.genops(pickle):
if op.code == code: if op.code == code.decode("latin-1"):
n += 1 n += 1
return n return n
...@@ -95,7 +95,7 @@ class use_metaclass(object, metaclass=metaclass): ...@@ -95,7 +95,7 @@ class use_metaclass(object, metaclass=metaclass):
# the object returned by create_data(). # the object returned by create_data().
# break into multiple strings to avoid confusing font-lock-mode # break into multiple strings to avoid confusing font-lock-mode
DATA0 = """(lp1 DATA0 = b"""(lp1
I0 I0
aL1L aL1L
aF2 aF2
...@@ -103,7 +103,7 @@ ac__builtin__ ...@@ -103,7 +103,7 @@ ac__builtin__
complex complex
p2 p2
""" + \ """ + \
"""(F3 b"""(F3
F0 F0
tRp3 tRp3
aI1 aI1
...@@ -118,15 +118,15 @@ aI2147483647 ...@@ -118,15 +118,15 @@ aI2147483647
aI-2147483647 aI-2147483647
aI-2147483648 aI-2147483648
a""" + \ a""" + \
"""(S'abc' b"""(S'abc'
p4 p4
g4 g4
""" + \ """ + \
"""(i__main__ b"""(i__main__
C C
p5 p5
""" + \ """ + \
"""(dp6 b"""(dp6
S'foo' S'foo'
p7 p7
I1 I1
...@@ -213,14 +213,14 @@ DATA0_DIS = """\ ...@@ -213,14 +213,14 @@ DATA0_DIS = """\
highest protocol among opcodes = 0 highest protocol among opcodes = 0
""" """
DATA1 = (']q\x01(K\x00L1L\nG@\x00\x00\x00\x00\x00\x00\x00' DATA1 = (b']q\x01(K\x00L1L\nG@\x00\x00\x00\x00\x00\x00\x00'
'c__builtin__\ncomplex\nq\x02(G@\x08\x00\x00\x00\x00\x00' b'c__builtin__\ncomplex\nq\x02(G@\x08\x00\x00\x00\x00\x00'
'\x00G\x00\x00\x00\x00\x00\x00\x00\x00tRq\x03K\x01J\xff\xff' b'\x00G\x00\x00\x00\x00\x00\x00\x00\x00tRq\x03K\x01J\xff\xff'
'\xff\xffK\xffJ\x01\xff\xff\xffJ\x00\xff\xff\xffM\xff\xff' b'\xff\xffK\xffJ\x01\xff\xff\xffJ\x00\xff\xff\xffM\xff\xff'
'J\x01\x00\xff\xffJ\x00\x00\xff\xffJ\xff\xff\xff\x7fJ\x01\x00' b'J\x01\x00\xff\xffJ\x00\x00\xff\xffJ\xff\xff\xff\x7fJ\x01\x00'
'\x00\x80J\x00\x00\x00\x80(U\x03abcq\x04h\x04(c__main__\n' b'\x00\x80J\x00\x00\x00\x80(U\x03abcq\x04h\x04(c__main__\n'
'C\nq\x05oq\x06}q\x07(U\x03fooq\x08K\x01U\x03barq\tK\x02ubh' b'C\nq\x05oq\x06}q\x07(U\x03fooq\x08K\x01U\x03barq\tK\x02ubh'
'\x06tq\nh\nK\x05e.' b'\x06tq\nh\nK\x05e.'
) )
# Disassembly of DATA1. # Disassembly of DATA1.
...@@ -280,13 +280,13 @@ DATA1_DIS = """\ ...@@ -280,13 +280,13 @@ DATA1_DIS = """\
highest protocol among opcodes = 1 highest protocol among opcodes = 1
""" """
DATA2 = ('\x80\x02]q\x01(K\x00\x8a\x01\x01G@\x00\x00\x00\x00\x00\x00\x00' DATA2 = (b'\x80\x02]q\x01(K\x00\x8a\x01\x01G@\x00\x00\x00\x00\x00\x00\x00'
'c__builtin__\ncomplex\nq\x02G@\x08\x00\x00\x00\x00\x00\x00G\x00' b'c__builtin__\ncomplex\nq\x02G@\x08\x00\x00\x00\x00\x00\x00G\x00'
'\x00\x00\x00\x00\x00\x00\x00\x86Rq\x03K\x01J\xff\xff\xff\xffK' b'\x00\x00\x00\x00\x00\x00\x00\x86Rq\x03K\x01J\xff\xff\xff\xffK'
'\xffJ\x01\xff\xff\xffJ\x00\xff\xff\xffM\xff\xffJ\x01\x00\xff\xff' b'\xffJ\x01\xff\xff\xffJ\x00\xff\xff\xffM\xff\xffJ\x01\x00\xff\xff'
'J\x00\x00\xff\xffJ\xff\xff\xff\x7fJ\x01\x00\x00\x80J\x00\x00\x00' b'J\x00\x00\xff\xffJ\xff\xff\xff\x7fJ\x01\x00\x00\x80J\x00\x00\x00'
'\x80(U\x03abcq\x04h\x04(c__main__\nC\nq\x05oq\x06}q\x07(U\x03foo' b'\x80(U\x03abcq\x04h\x04(c__main__\nC\nq\x05oq\x06}q\x07(U\x03foo'
'q\x08K\x01U\x03barq\tK\x02ubh\x06tq\nh\nK\x05e.') b'q\x08K\x01U\x03barq\tK\x02ubh\x06tq\nh\nK\x05e.')
# Disassembly of DATA2. # Disassembly of DATA2.
DATA2_DIS = """\ DATA2_DIS = """\
...@@ -465,7 +465,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -465,7 +465,7 @@ class AbstractPickleTests(unittest.TestCase):
self.assert_(x[0].attr[1] is x) self.assert_(x[0].attr[1] is x)
def test_garyp(self): def test_garyp(self):
self.assertRaises(self.error, self.loads, 'garyp') self.assertRaises(self.error, self.loads, b'garyp')
def test_insecure_strings(self): def test_insecure_strings(self):
insecure = ["abc", "2 + 2", # not quoted insecure = ["abc", "2 + 2", # not quoted
...@@ -479,7 +479,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -479,7 +479,7 @@ class AbstractPickleTests(unittest.TestCase):
#"'\\\\a\'\'\'\\\'\\\\\''", #"'\\\\a\'\'\'\\\'\\\\\''",
] ]
for s in insecure: for s in insecure:
buf = "S" + s + "\012p0\012." buf = b"S" + bytes(s) + b"\012p0\012."
self.assertRaises(ValueError, self.loads, buf) self.assertRaises(ValueError, self.loads, buf)
if have_unicode: if have_unicode:
...@@ -505,12 +505,12 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -505,12 +505,12 @@ class AbstractPickleTests(unittest.TestCase):
def test_maxint64(self): def test_maxint64(self):
maxint64 = (1 << 63) - 1 maxint64 = (1 << 63) - 1
data = 'I' + str(maxint64) + '\n.' data = b'I' + bytes(str(maxint64)) + b'\n.'
got = self.loads(data) got = self.loads(data)
self.assertEqual(got, maxint64) self.assertEqual(got, maxint64)
# Try too with a bogus literal. # Try too with a bogus literal.
data = 'I' + str(maxint64) + 'JUNK\n.' data = b'I' + bytes(str(maxint64)) + b'JUNK\n.'
self.assertRaises(ValueError, self.loads, data) self.assertRaises(ValueError, self.loads, data)
def test_long(self): def test_long(self):
...@@ -535,7 +535,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -535,7 +535,7 @@ class AbstractPickleTests(unittest.TestCase):
@run_with_locale('LC_ALL', 'de_DE', 'fr_FR') @run_with_locale('LC_ALL', 'de_DE', 'fr_FR')
def test_float_format(self): def test_float_format(self):
# make sure that floats are formatted locale independent # make sure that floats are formatted locale independent
self.assertEqual(self.dumps(1.2)[0:3], 'F1.') self.assertEqual(self.dumps(1.2)[0:3], b'F1.')
def test_reduce(self): def test_reduce(self):
pass pass
...@@ -577,12 +577,12 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -577,12 +577,12 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols: for proto in protocols:
expected = build_none expected = build_none
if proto >= 2: if proto >= 2:
expected = pickle.PROTO + chr(proto) + expected expected = pickle.PROTO + bytes([proto]) + expected
p = self.dumps(None, proto) p = self.dumps(None, proto)
self.assertEqual(p, expected) self.assertEqual(p, expected)
oob = protocols[-1] + 1 # a future protocol oob = protocols[-1] + 1 # a future protocol
badpickle = pickle.PROTO + chr(oob) + build_none badpickle = pickle.PROTO + bytes([oob]) + build_none
try: try:
self.loads(badpickle) self.loads(badpickle)
except ValueError as detail: except ValueError as detail:
...@@ -708,8 +708,8 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -708,8 +708,8 @@ class AbstractPickleTests(unittest.TestCase):
# Dump using protocol 1 for comparison. # Dump using protocol 1 for comparison.
s1 = self.dumps(x, 1) s1 = self.dumps(x, 1)
self.assert_(__name__ in s1) self.assert_(bytes(__name__) in s1)
self.assert_("MyList" in s1) self.assert_(b"MyList" in s1)
self.assertEqual(opcode_in_pickle(opcode, s1), False) self.assertEqual(opcode_in_pickle(opcode, s1), False)
y = self.loads(s1) y = self.loads(s1)
...@@ -718,9 +718,9 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -718,9 +718,9 @@ class AbstractPickleTests(unittest.TestCase):
# Dump using protocol 2 for test. # Dump using protocol 2 for test.
s2 = self.dumps(x, 2) s2 = self.dumps(x, 2)
self.assert_(__name__ not in s2) self.assert_(bytes(__name__) not in s2)
self.assert_("MyList" not in s2) self.assert_(b"MyList" not in s2)
self.assertEqual(opcode_in_pickle(opcode, s2), True) self.assertEqual(opcode_in_pickle(opcode, s2), True, repr(s2))
y = self.loads(s2) y = self.loads(s2)
self.assertEqual(list(x), list(y)) self.assertEqual(list(x), list(y))
...@@ -770,6 +770,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -770,6 +770,7 @@ class AbstractPickleTests(unittest.TestCase):
x = dict.fromkeys(range(n)) x = dict.fromkeys(range(n))
for proto in protocols: for proto in protocols:
s = self.dumps(x, proto) s = self.dumps(x, proto)
assert isinstance(s, bytes)
y = self.loads(s) y = self.loads(s)
self.assertEqual(x, y) self.assertEqual(x, y)
num_setitems = count_opcode(pickle.SETITEMS, s) num_setitems = count_opcode(pickle.SETITEMS, s)
......
import pickle import pickle
import unittest import unittest
from cStringIO import StringIO import io
from test import test_support from test import test_support
...@@ -26,16 +26,16 @@ class PicklerTests(AbstractPickleTests): ...@@ -26,16 +26,16 @@ class PicklerTests(AbstractPickleTests):
error = KeyError error = KeyError
def dumps(self, arg, proto=0, fast=0): def dumps(self, arg, proto=0, fast=0):
f = StringIO() f = io.BytesIO()
p = pickle.Pickler(f, proto) p = pickle.Pickler(f, proto)
if fast: if fast:
p.fast = fast p.fast = fast
p.dump(arg) p.dump(arg)
f.seek(0) f.seek(0)
return f.read() return bytes(f.read())
def loads(self, buf): def loads(self, buf):
f = StringIO(buf) f = io.BytesIO(buf)
u = pickle.Unpickler(f) u = pickle.Unpickler(f)
return u.load() return u.load()
...@@ -45,7 +45,7 @@ class PersPicklerTests(AbstractPersistentPicklerTests): ...@@ -45,7 +45,7 @@ class PersPicklerTests(AbstractPersistentPicklerTests):
class PersPickler(pickle.Pickler): class PersPickler(pickle.Pickler):
def persistent_id(subself, obj): def persistent_id(subself, obj):
return self.persistent_id(obj) return self.persistent_id(obj)
f = StringIO() f = io.BytesIO()
p = PersPickler(f, proto) p = PersPickler(f, proto)
if fast: if fast:
p.fast = fast p.fast = fast
...@@ -57,7 +57,7 @@ class PersPicklerTests(AbstractPersistentPicklerTests): ...@@ -57,7 +57,7 @@ class PersPicklerTests(AbstractPersistentPicklerTests):
class PersUnpickler(pickle.Unpickler): class PersUnpickler(pickle.Unpickler):
def persistent_load(subself, obj): def persistent_load(subself, obj):
return self.persistent_load(obj) return self.persistent_load(obj)
f = StringIO(buf) f = io.BytesIO(buf)
u = PersUnpickler(f) u = PersUnpickler(f)
return u.load() return u.load()
......
...@@ -5241,6 +5241,13 @@ cpm_dumps(PyObject *self, PyObject *args, PyObject *kwds) ...@@ -5241,6 +5241,13 @@ cpm_dumps(PyObject *self, PyObject *args, PyObject *kwds)
goto finally; goto finally;
res = PycStringIO->cgetvalue(file); res = PycStringIO->cgetvalue(file);
if (res == NULL)
goto finally;
if (!PyBytes_Check(res)) {
PyObject *tmp = res;
res = PyBytes_FromObject(res);
Py_DECREF(tmp);
}
finally: finally:
Py_XDECREF(pickler); Py_XDECREF(pickler);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment