Commit 771d8341 authored by Collin Winter's avatar Collin Winter

Port r71408 to py3k: issue 5665, add more pickling tests.

parent f4ac1494
import io
import unittest import unittest
import pickle import pickle
import pickletools import pickletools
...@@ -842,7 +843,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -842,7 +843,7 @@ class AbstractPickleTests(unittest.TestCase):
self.assertEqual(x.bar, y.bar) self.assertEqual(x.bar, y.bar)
def test_reduce_overrides_default_reduce_ex(self): def test_reduce_overrides_default_reduce_ex(self):
for proto in 0, 1, 2: for proto in protocols:
x = REX_one() x = REX_one()
self.assertEqual(x._reduce_called, 0) self.assertEqual(x._reduce_called, 0)
s = self.dumps(x, proto) s = self.dumps(x, proto)
...@@ -851,7 +852,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -851,7 +852,7 @@ class AbstractPickleTests(unittest.TestCase):
self.assertEqual(y._reduce_called, 0) self.assertEqual(y._reduce_called, 0)
def test_reduce_ex_called(self): def test_reduce_ex_called(self):
for proto in 0, 1, 2: for proto in protocols:
x = REX_two() x = REX_two()
self.assertEqual(x._proto, None) self.assertEqual(x._proto, None)
s = self.dumps(x, proto) s = self.dumps(x, proto)
...@@ -860,7 +861,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -860,7 +861,7 @@ class AbstractPickleTests(unittest.TestCase):
self.assertEqual(y._proto, None) self.assertEqual(y._proto, None)
def test_reduce_ex_overrides_reduce(self): def test_reduce_ex_overrides_reduce(self):
for proto in 0, 1, 2: for proto in protocols:
x = REX_three() x = REX_three()
self.assertEqual(x._proto, None) self.assertEqual(x._proto, None)
s = self.dumps(x, proto) s = self.dumps(x, proto)
...@@ -869,7 +870,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -869,7 +870,7 @@ class AbstractPickleTests(unittest.TestCase):
self.assertEqual(y._proto, None) self.assertEqual(y._proto, None)
def test_reduce_ex_calls_base(self): def test_reduce_ex_calls_base(self):
for proto in 0, 1, 2: for proto in protocols:
x = REX_four() x = REX_four()
self.assertEqual(x._proto, None) self.assertEqual(x._proto, None)
s = self.dumps(x, proto) s = self.dumps(x, proto)
...@@ -878,7 +879,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -878,7 +879,7 @@ class AbstractPickleTests(unittest.TestCase):
self.assertEqual(y._proto, proto) self.assertEqual(y._proto, proto)
def test_reduce_calls_base(self): def test_reduce_calls_base(self):
for proto in 0, 1, 2: for proto in protocols:
x = REX_five() x = REX_five()
self.assertEqual(x._reduce_called, 0) self.assertEqual(x._reduce_called, 0)
s = self.dumps(x, proto) s = self.dumps(x, proto)
...@@ -917,6 +918,20 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -917,6 +918,20 @@ class AbstractPickleTests(unittest.TestCase):
except (pickle.PickleError): except (pickle.PickleError):
pass pass
def test_many_puts_and_gets(self):
# Test that internal data structures correctly deal with lots of
# puts/gets.
keys = ("aaa" + str(i) for i in range(100))
large_dict = dict((k, [4, 5, 6]) for k in keys)
obj = [dict(large_dict), dict(large_dict), dict(large_dict)]
for proto in protocols:
dumped = self.dumps(obj, proto)
loaded = self.loads(dumped)
self.assertEqual(loaded, obj,
"Failed protocol %d: %r != %r"
% (proto, obj, loaded))
# Test classes for reduce_ex # Test classes for reduce_ex
class REX_one(object): class REX_one(object):
...@@ -1002,6 +1017,7 @@ class BadGetattr: ...@@ -1002,6 +1017,7 @@ class BadGetattr:
def __getattr__(self, key): def __getattr__(self, key):
self.foo self.foo
class AbstractPickleModuleTests(unittest.TestCase): class AbstractPickleModuleTests(unittest.TestCase):
def test_dump_closed_file(self): def test_dump_closed_file(self):
...@@ -1022,13 +1038,20 @@ class AbstractPickleModuleTests(unittest.TestCase): ...@@ -1022,13 +1038,20 @@ class AbstractPickleModuleTests(unittest.TestCase):
finally: finally:
os.remove(TESTFN) os.remove(TESTFN)
def test_load_from_and_dump_to_file(self):
stream = io.BytesIO()
data = [123, {}, 124]
pickle.dump(data, stream)
stream.seek(0)
unpickled = pickle.load(stream)
self.assertEqual(unpickled, data)
def test_highest_protocol(self): def test_highest_protocol(self):
# Of course this needs to be changed when HIGHEST_PROTOCOL changes. # Of course this needs to be changed when HIGHEST_PROTOCOL changes.
self.assertEqual(pickle.HIGHEST_PROTOCOL, 3) self.assertEqual(pickle.HIGHEST_PROTOCOL, 3)
def test_callapi(self): def test_callapi(self):
from io import BytesIO f = io.BytesIO()
f = BytesIO()
# With and without keyword arguments # With and without keyword arguments
pickle.dump(123, f, -1) pickle.dump(123, f, -1)
pickle.dump(123, file=f, protocol=-1) pickle.dump(123, file=f, protocol=-1)
...@@ -1039,7 +1062,6 @@ class AbstractPickleModuleTests(unittest.TestCase): ...@@ -1039,7 +1062,6 @@ class AbstractPickleModuleTests(unittest.TestCase):
def test_bad_init(self): def test_bad_init(self):
# Test issue3664 (pickle can segfault from a badly initialized Pickler). # Test issue3664 (pickle can segfault from a badly initialized Pickler).
from io import BytesIO
# Override initialization without calling __init__() of the superclass. # Override initialization without calling __init__() of the superclass.
class BadPickler(pickle.Pickler): class BadPickler(pickle.Pickler):
def __init__(self): pass def __init__(self): pass
...@@ -1091,6 +1113,121 @@ class AbstractPersistentPicklerTests(unittest.TestCase): ...@@ -1091,6 +1113,121 @@ class AbstractPersistentPicklerTests(unittest.TestCase):
self.assertEqual(self.id_count, 5) self.assertEqual(self.id_count, 5)
self.assertEqual(self.load_count, 5) self.assertEqual(self.load_count, 5)
class AbstractPicklerUnpicklerObjectTests(unittest.TestCase):
pickler_class = None
unpickler_class = None
def setUp(self):
assert self.pickler_class
assert self.unpickler_class
def test_clear_pickler_memo(self):
# To test whether clear_memo() has any effect, we pickle an object,
# then pickle it again without clearing the memo; the two serialized
# forms should be different. If we clear_memo() and then pickle the
# object again, the third serialized form should be identical to the
# first one we obtained.
data = ["abcdefg", "abcdefg", 44]
f = io.BytesIO()
pickler = self.pickler_class(f)
pickler.dump(data)
first_pickled = f.getvalue()
# Reset StringIO object.
f.seek(0)
f.truncate()
pickler.dump(data)
second_pickled = f.getvalue()
# Reset the Pickler and StringIO objects.
pickler.clear_memo()
f.seek(0)
f.truncate()
pickler.dump(data)
third_pickled = f.getvalue()
self.assertNotEqual(first_pickled, second_pickled)
self.assertEqual(first_pickled, third_pickled)
def test_priming_pickler_memo(self):
# Verify that we can set the Pickler's memo attribute.
data = ["abcdefg", "abcdefg", 44]
f = io.BytesIO()
pickler = self.pickler_class(f)
pickler.dump(data)
first_pickled = f.getvalue()
f = io.BytesIO()
primed = self.pickler_class(f)
primed.memo = pickler.memo
primed.dump(data)
primed_pickled = f.getvalue()
self.assertNotEqual(first_pickled, primed_pickled)
def test_priming_unpickler_memo(self):
# Verify that we can set the Unpickler's memo attribute.
data = ["abcdefg", "abcdefg", 44]
f = io.BytesIO()
pickler = self.pickler_class(f)
pickler.dump(data)
first_pickled = f.getvalue()
f = io.BytesIO()
primed = self.pickler_class(f)
primed.memo = pickler.memo
primed.dump(data)
primed_pickled = f.getvalue()
unpickler = self.unpickler_class(io.BytesIO(first_pickled))
unpickled_data1 = unpickler.load()
self.assertEqual(unpickled_data1, data)
primed = self.unpickler_class(io.BytesIO(primed_pickled))
primed.memo = unpickler.memo
unpickled_data2 = primed.load()
primed.memo.clear()
self.assertEqual(unpickled_data2, data)
self.assertTrue(unpickled_data2 is unpickled_data1)
def test_reusing_unpickler_objects(self):
data1 = ["abcdefg", "abcdefg", 44]
f = io.BytesIO()
pickler = self.pickler_class(f)
pickler.dump(data1)
pickled1 = f.getvalue()
data2 = ["abcdefg", 44, 44]
f = io.BytesIO()
pickler = self.pickler_class(f)
pickler.dump(data2)
pickled2 = f.getvalue()
f = io.BytesIO()
f.write(pickled1)
f.seek(0)
unpickler = self.unpickler_class(f)
self.assertEqual(unpickler.load(), data1)
f.seek(0)
f.truncate()
f.write(pickled2)
f.seek(0)
self.assertEqual(unpickler.load(), data2)
if __name__ == "__main__": if __name__ == "__main__":
# Print some stuff that can be used to rewrite DATA{0,1,2} # Print some stuff that can be used to rewrite DATA{0,1,2}
from pickletools import dis from pickletools import dis
......
...@@ -6,6 +6,7 @@ from test import support ...@@ -6,6 +6,7 @@ from test import support
from test.pickletester import AbstractPickleTests from test.pickletester import AbstractPickleTests
from test.pickletester import AbstractPickleModuleTests from test.pickletester import AbstractPickleModuleTests
from test.pickletester import AbstractPersistentPicklerTests from test.pickletester import AbstractPersistentPicklerTests
from test.pickletester import AbstractPicklerUnpicklerObjectTests
try: try:
import _pickle import _pickle
...@@ -60,6 +61,12 @@ class PyPersPicklerTests(AbstractPersistentPicklerTests): ...@@ -60,6 +61,12 @@ class PyPersPicklerTests(AbstractPersistentPicklerTests):
return u.load() return u.load()
class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
pickler_class = pickle._Pickler
unpickler_class = pickle._Unpickler
if has_c_implementation: if has_c_implementation:
class CPicklerTests(PyPicklerTests): class CPicklerTests(PyPicklerTests):
pickler = _pickle.Pickler pickler = _pickle.Pickler
...@@ -69,11 +76,26 @@ if has_c_implementation: ...@@ -69,11 +76,26 @@ if has_c_implementation:
pickler = _pickle.Pickler pickler = _pickle.Pickler
unpickler = _pickle.Unpickler unpickler = _pickle.Unpickler
class CDumpPickle_LoadPickle(PyPicklerTests):
pickler = _pickle.Pickler
unpickler = pickle._Unpickler
class DumpPickle_CLoadPickle(PyPicklerTests):
pickler = pickle._Pickler
unpickler = _pickle.Unpickler
class CPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
pickler_class = _pickle.Pickler
unpickler_class = _pickle.Unpickler
def test_main(): def test_main():
tests = [PickleTests, PyPicklerTests, PyPersPicklerTests] tests = [PickleTests, PyPicklerTests, PyPersPicklerTests]
if has_c_implementation: if has_c_implementation:
tests.extend([CPicklerTests, CPersPicklerTests]) tests.extend([CPicklerTests, CPersPicklerTests,
CDumpPickle_LoadPickle, DumpPickle_CLoadPickle,
PyPicklerUnpicklerObjectTests,
CPicklerUnpicklerObjectTests])
support.run_unittest(*tests) support.run_unittest(*tests)
support.run_doctest(pickle) support.run_doctest(pickle)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment