Commit cc313061 authored by Alexandre Vassalotti's avatar Alexandre Vassalotti

Issue 2917: Merge the pickle and cPickle module.

parent 1e637b73
......@@ -174,7 +174,7 @@ __all__.extend([x for x in dir() if re.match("[A-Z][A-Z0-9_]+$",x)])
# Pickling machinery
class Pickler:
class _Pickler:
def __init__(self, file, protocol=None):
"""This takes a binary file for writing a pickle data stream.
......@@ -182,21 +182,19 @@ class Pickler:
All protocols now read and write bytes.
The optional protocol argument tells the pickler to use the
given protocol; supported protocols are 0, 1, 2. The default
protocol is 2; it's been supported for many years now.
Protocol 1 is more efficient than protocol 0; protocol 2 is
more efficient than protocol 1.
given protocol; supported protocols are 0, 1, 2, 3. The default
protocol is 3; a backward-incompatible protocol designed for
Python 3.0.
Specifying a negative protocol version selects the highest
protocol version supported. The higher the protocol used, the
more recent the version of Python needed to read the pickle
produced.
The file parameter must have a write() method that accepts a single
string argument. It can thus be an open file object, a StringIO
object, or any other custom object that meets this interface.
The file argument must have a write() method that accepts a single
bytes argument. It can thus be a file object opened for binary
writing, a io.BytesIO instance, or any other custom object that
meets this interface.
"""
if protocol is None:
protocol = DEFAULT_PROTOCOL
......@@ -204,7 +202,10 @@ class Pickler:
protocol = HIGHEST_PROTOCOL
elif not 0 <= protocol <= HIGHEST_PROTOCOL:
raise ValueError("pickle protocol must be <= %d" % HIGHEST_PROTOCOL)
self.write = file.write
try:
self.write = file.write
except AttributeError:
raise TypeError("file must have a 'write' attribute")
self.memo = {}
self.proto = int(protocol)
self.bin = protocol >= 1
......@@ -270,10 +271,10 @@ class Pickler:
return GET + repr(i).encode("ascii") + b'\n'
def save(self, obj):
def save(self, obj, save_persistent_id=True):
# Check for persistent id (defined by a subclass)
pid = self.persistent_id(obj)
if pid:
if pid is not None and save_persistent_id:
self.save_pers(pid)
return
......@@ -341,7 +342,7 @@ class Pickler:
def save_pers(self, pid):
# Save a persistent id reference
if self.bin:
self.save(pid)
self.save(pid, save_persistent_id=False)
self.write(BINPERSID)
else:
self.write(PERSID + str(pid).encode("ascii") + b'\n')
......@@ -350,13 +351,13 @@ class Pickler:
listitems=None, dictitems=None, obj=None):
# This API is called by some subclasses
# Assert that args is a tuple or None
# Assert that args is a tuple
if not isinstance(args, tuple):
raise PicklingError("args from reduce() should be a tuple")
raise PicklingError("args from save_reduce() should be a tuple")
# Assert that func is callable
if not hasattr(func, '__call__'):
raise PicklingError("func from reduce should be callable")
raise PicklingError("func from save_reduce() should be callable")
save = self.save
write = self.write
......@@ -438,31 +439,6 @@ class Pickler:
self.write(obj and TRUE or FALSE)
dispatch[bool] = save_bool
def save_int(self, obj, pack=struct.pack):
if self.bin:
# If the int is small enough to fit in a signed 4-byte 2's-comp
# format, we can store it more efficiently than the general
# case.
# First one- and two-byte unsigned ints:
if obj >= 0:
if obj <= 0xff:
self.write(BININT1 + bytes([obj]))
return
if obj <= 0xffff:
self.write(BININT2 + bytes([obj&0xff, obj>>8]))
return
# Next check for 4-byte signed ints:
high_bits = obj >> 31 # note that Python shift sign-extends
if high_bits == 0 or high_bits == -1:
# All high bits are copies of bit 2**31, so the value
# fits in a 4-byte signed int.
self.write(BININT + pack("<i", obj))
return
# Text pickle, or int too big to fit in signed 4-byte format.
self.write(INT + repr(obj).encode("ascii") + b'\n')
# XXX save_int is merged into save_long
# dispatch[int] = save_int
def save_long(self, obj, pack=struct.pack):
if self.bin:
# If the int is small enough to fit in a signed 4-byte 2's-comp
......@@ -503,7 +479,7 @@ class Pickler:
def save_bytes(self, obj, pack=struct.pack):
if self.proto < 3:
self.save_reduce(bytes, (list(obj),))
self.save_reduce(bytes, (list(obj),), obj=obj)
return
n = len(obj)
if n < 256:
......@@ -579,12 +555,6 @@ class Pickler:
dispatch[tuple] = save_tuple
# save_empty_tuple() isn't used by anything in Python 2.3. However, I
# found a Pickler subclass in Zope3 that calls it, so it's not harmless
# to remove it.
def save_empty_tuple(self, obj):
self.write(EMPTY_TUPLE)
def save_list(self, obj):
write = self.write
......@@ -696,7 +666,7 @@ class Pickler:
module = whichmodule(obj, name)
try:
__import__(module)
__import__(module, level=0)
mod = sys.modules[module]
klass = getattr(mod, name)
except (ImportError, KeyError, AttributeError):
......@@ -720,9 +690,19 @@ class Pickler:
else:
write(EXT4 + pack("<i", code))
return
# Non-ASCII identifiers are supported only with protocols >= 3.
if self.proto >= 3:
write(GLOBAL + bytes(module, "utf-8") + b'\n' +
bytes(name, "utf-8") + b'\n')
else:
try:
write(GLOBAL + bytes(module, "ascii") + b'\n' +
bytes(name, "ascii") + b'\n')
except UnicodeEncodeError:
raise PicklingError(
"can't pickle global identifier '%s.%s' using "
"pickle protocol %i" % (module, name, self.proto))
write(GLOBAL + bytes(module, "utf-8") + b'\n' +
bytes(name, "utf-8") + b'\n')
self.memoize(obj)
dispatch[FunctionType] = save_global
......@@ -781,7 +761,7 @@ def whichmodule(func, funcname):
# Unpickling machinery
class Unpickler:
class _Unpickler:
def __init__(self, file, *, encoding="ASCII", errors="strict"):
"""This takes a binary file for reading a pickle data stream.
......@@ -841,6 +821,9 @@ class Unpickler:
while stack[k] is not mark: k = k-1
return k
def persistent_load(self, pid):
raise UnpickingError("unsupported persistent id encountered")
dispatch = {}
def load_proto(self):
......@@ -850,7 +833,7 @@ class Unpickler:
dispatch[PROTO[0]] = load_proto
def load_persid(self):
pid = self.readline()[:-1]
pid = self.readline()[:-1].decode("ascii")
self.append(self.persistent_load(pid))
dispatch[PERSID[0]] = load_persid
......@@ -879,9 +862,9 @@ class Unpickler:
val = True
else:
try:
val = int(data)
val = int(data, 0)
except ValueError:
val = int(data)
val = int(data, 0)
self.append(val)
dispatch[INT[0]] = load_int
......@@ -933,7 +916,8 @@ class Unpickler:
break
else:
raise ValueError("insecure string pickle: %r" % orig)
self.append(codecs.escape_decode(rep)[0])
self.append(codecs.escape_decode(rep)[0]
.decode(self.encoding, self.errors))
dispatch[STRING[0]] = load_string
def load_binstring(self):
......@@ -975,7 +959,7 @@ class Unpickler:
dispatch[TUPLE[0]] = load_tuple
def load_empty_tuple(self):
self.stack.append(())
self.append(())
dispatch[EMPTY_TUPLE[0]] = load_empty_tuple
def load_tuple1(self):
......@@ -991,11 +975,11 @@ class Unpickler:
dispatch[TUPLE3[0]] = load_tuple3
def load_empty_list(self):
self.stack.append([])
self.append([])
dispatch[EMPTY_LIST[0]] = load_empty_list
def load_empty_dictionary(self):
self.stack.append({})
self.append({})
dispatch[EMPTY_DICT[0]] = load_empty_dictionary
def load_list(self):
......@@ -1022,13 +1006,13 @@ class Unpickler:
def _instantiate(self, klass, k):
args = tuple(self.stack[k+1:])
del self.stack[k:]
instantiated = 0
instantiated = False
if (not args and
isinstance(klass, type) and
not hasattr(klass, "__getinitargs__")):
value = _EmptyClass()
value.__class__ = klass
instantiated = 1
instantiated = True
if not instantiated:
try:
value = klass(*args)
......@@ -1038,8 +1022,8 @@ class Unpickler:
self.append(value)
def load_inst(self):
module = self.readline()[:-1]
name = self.readline()[:-1]
module = self.readline()[:-1].decode("ascii")
name = self.readline()[:-1].decode("ascii")
klass = self.find_class(module, name)
self._instantiate(klass, self.marker())
dispatch[INST[0]] = load_inst
......@@ -1059,8 +1043,8 @@ class Unpickler:
dispatch[NEWOBJ[0]] = load_newobj
def load_global(self):
module = self.readline()[:-1]
name = self.readline()[:-1]
module = self.readline()[:-1].decode("utf-8")
name = self.readline()[:-1].decode("utf-8")
klass = self.find_class(module, name)
self.append(klass)
dispatch[GLOBAL[0]] = load_global
......@@ -1095,11 +1079,7 @@ class Unpickler:
def find_class(self, module, name):
# Subclasses may override this
if isinstance(module, bytes_types):
module = module.decode("utf-8")
if isinstance(name, bytes_types):
name = name.decode("utf-8")
__import__(module)
__import__(module, level=0)
mod = sys.modules[module]
klass = getattr(mod, name)
return klass
......@@ -1131,31 +1111,33 @@ class Unpickler:
dispatch[DUP[0]] = load_dup
def load_get(self):
self.append(self.memo[self.readline()[:-1].decode("ascii")])
i = int(self.readline()[:-1])
self.append(self.memo[i])
dispatch[GET[0]] = load_get
def load_binget(self):
i = ord(self.read(1))
self.append(self.memo[repr(i)])
i = self.read(1)[0]
self.append(self.memo[i])
dispatch[BINGET[0]] = load_binget
def load_long_binget(self):
i = mloads(b'i' + self.read(4))
self.append(self.memo[repr(i)])
self.append(self.memo[i])
dispatch[LONG_BINGET[0]] = load_long_binget
def load_put(self):
self.memo[self.readline()[:-1].decode("ascii")] = self.stack[-1]
i = int(self.readline()[:-1])
self.memo[i] = self.stack[-1]
dispatch[PUT[0]] = load_put
def load_binput(self):
i = ord(self.read(1))
self.memo[repr(i)] = self.stack[-1]
i = self.read(1)[0]
self.memo[i] = self.stack[-1]
dispatch[BINPUT[0]] = load_binput
def load_long_binput(self):
i = mloads(b'i' + self.read(4))
self.memo[repr(i)] = self.stack[-1]
self.memo[i] = self.stack[-1]
dispatch[LONG_BINPUT[0]] = load_long_binput
def load_append(self):
......@@ -1321,6 +1303,15 @@ def decode_long(data):
n -= 1 << (nbytes * 8)
return n
# Use the faster _pickle if possible
try:
from _pickle import *
except ImportError:
Pickler, Unpickler = _Pickler, _Unpickler
PickleError = _PickleError
PicklingError = _PicklingError
UnpicklingError = _UnpicklingError
# Shorthands
def dump(obj, file, protocol=None):
......@@ -1333,14 +1324,14 @@ def dumps(obj, protocol=None):
assert isinstance(res, bytes_types)
return res
def load(file):
return Unpickler(file).load()
def load(file, *, encoding="ASCII", errors="strict"):
return Unpickler(file, encoding=encoding, errors=errors).load()
def loads(s):
def loads(s, *, encoding="ASCII", errors="strict"):
if isinstance(s, str):
raise TypeError("Can't load pickle from unicode string")
file = io.BytesIO(s)
return Unpickler(file).load()
return Unpickler(file, encoding=encoding, errors=errors).load()
# Doctest
......
......@@ -2079,11 +2079,12 @@ _dis_test = r"""
70: t TUPLE (MARK at 49)
71: p PUT 5
74: R REDUCE
75: V UNICODE 'def'
80: p PUT 6
83: s SETITEM
84: a APPEND
85: . STOP
75: p PUT 6
78: V UNICODE 'def'
83: p PUT 7
86: s SETITEM
87: a APPEND
88: . STOP
highest protocol among opcodes = 0
Try again with a "binary" pickle.
......@@ -2115,11 +2116,12 @@ Try again with a "binary" pickle.
49: t TUPLE (MARK at 37)
50: q BINPUT 5
52: R REDUCE
53: X BINUNICODE 'def'
61: q BINPUT 6
63: s SETITEM
64: e APPENDS (MARK at 3)
65: . STOP
53: q BINPUT 6
55: X BINUNICODE 'def'
63: q BINPUT 7
65: s SETITEM
66: e APPENDS (MARK at 3)
67: . STOP
highest protocol among opcodes = 1
Exercise the INST/OBJ/BUILD family.
......
......@@ -362,7 +362,7 @@ def create_data():
return x
class AbstractPickleTests(unittest.TestCase):
# Subclass must define self.dumps, self.loads, self.error.
# Subclass must define self.dumps, self.loads.
_testdata = create_data()
......@@ -463,8 +463,9 @@ class AbstractPickleTests(unittest.TestCase):
self.assertEqual(list(x[0].attr.keys()), [1])
self.assert_(x[0].attr[1] is x)
def test_garyp(self):
self.assertRaises(self.error, self.loads, b'garyp')
def test_get(self):
self.assertRaises(KeyError, self.loads, b'g0\np0')
self.assertEquals(self.loads(b'((Kdtp0\nh\x00l.))'), [(100,), (100,)])
def test_insecure_strings(self):
# XXX Some of these tests are temporarily disabled
......@@ -955,7 +956,7 @@ class AbstractPickleModuleTests(unittest.TestCase):
f = open(TESTFN, "wb")
try:
f.close()
self.assertRaises(ValueError, self.module.dump, 123, f)
self.assertRaises(ValueError, pickle.dump, 123, f)
finally:
os.remove(TESTFN)
......@@ -964,24 +965,24 @@ class AbstractPickleModuleTests(unittest.TestCase):
f = open(TESTFN, "wb")
try:
f.close()
self.assertRaises(ValueError, self.module.dump, 123, f)
self.assertRaises(ValueError, pickle.dump, 123, f)
finally:
os.remove(TESTFN)
def test_highest_protocol(self):
# Of course this needs to be changed when HIGHEST_PROTOCOL changes.
self.assertEqual(self.module.HIGHEST_PROTOCOL, 3)
self.assertEqual(pickle.HIGHEST_PROTOCOL, 3)
def test_callapi(self):
from io import BytesIO
f = BytesIO()
# With and without keyword arguments
self.module.dump(123, f, -1)
self.module.dump(123, file=f, protocol=-1)
self.module.dumps(123, -1)
self.module.dumps(123, protocol=-1)
self.module.Pickler(f, -1)
self.module.Pickler(f, protocol=-1)
pickle.dump(123, f, -1)
pickle.dump(123, file=f, protocol=-1)
pickle.dumps(123, -1)
pickle.dumps(123, protocol=-1)
pickle.Pickler(f, -1)
pickle.Pickler(f, protocol=-1)
class AbstractPersistentPicklerTests(unittest.TestCase):
......
......@@ -7,37 +7,42 @@ from test.pickletester import AbstractPickleTests
from test.pickletester import AbstractPickleModuleTests
from test.pickletester import AbstractPersistentPicklerTests
class PickleTests(AbstractPickleTests, AbstractPickleModuleTests):
try:
import _pickle
has_c_implementation = True
except ImportError:
has_c_implementation = False
module = pickle
error = KeyError
def dumps(self, arg, proto=None):
return pickle.dumps(arg, proto)
class PickleTests(AbstractPickleModuleTests):
pass
def loads(self, buf):
return pickle.loads(buf)
class PicklerTests(AbstractPickleTests):
class PyPicklerTests(AbstractPickleTests):
error = KeyError
pickler = pickle._Pickler
unpickler = pickle._Unpickler
def dumps(self, arg, proto=None):
f = io.BytesIO()
p = pickle.Pickler(f, proto)
p = self.pickler(f, proto)
p.dump(arg)
f.seek(0)
return bytes(f.read())
def loads(self, buf):
f = io.BytesIO(buf)
u = pickle.Unpickler(f)
u = self.unpickler(f)
return u.load()
class PersPicklerTests(AbstractPersistentPicklerTests):
class PyPersPicklerTests(AbstractPersistentPicklerTests):
pickler = pickle._Pickler
unpickler = pickle._Unpickler
def dumps(self, arg, proto=None):
class PersPickler(pickle.Pickler):
class PersPickler(self.pickler):
def persistent_id(subself, obj):
return self.persistent_id(obj)
f = io.BytesIO()
......@@ -47,19 +52,29 @@ class PersPicklerTests(AbstractPersistentPicklerTests):
return f.read()
def loads(self, buf):
class PersUnpickler(pickle.Unpickler):
class PersUnpickler(self.unpickler):
def persistent_load(subself, obj):
return self.persistent_load(obj)
f = io.BytesIO(buf)
u = PersUnpickler(f)
return u.load()
if has_c_implementation:
class CPicklerTests(PyPicklerTests):
pickler = _pickle.Pickler
unpickler = _pickle.Unpickler
class CPersPicklerTests(PyPersPicklerTests):
pickler = _pickle.Pickler
unpickler = _pickle.Unpickler
def test_main():
support.run_unittest(
PickleTests,
PicklerTests,
PersPicklerTests
)
tests = [PickleTests, PyPicklerTests, PyPersPicklerTests]
if has_c_implementation:
tests.extend([CPicklerTests, CPersPicklerTests])
support.run_unittest(*tests)
support.run_doctest(pickle)
if __name__ == "__main__":
......
......@@ -12,8 +12,6 @@ class OptimizedPickleTests(AbstractPickleTests, AbstractPickleModuleTests):
def loads(self, buf):
return pickle.loads(buf)
module = pickle
error = KeyError
def test_main():
support.run_unittest(OptimizedPickleTests)
......
......@@ -78,6 +78,10 @@ Extension Modules
Library
-------
- The ``pickle`` module is now automatically use an optimized C
implementation of Pickler and Unpickler when available. The
``cPickle`` module is no longer needed.
- Removed the ``htmllib`` and ``sgmllib`` modules.
- The deprecated ``SmartCookie`` and ``SimpleCookie`` classes have
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -422,6 +422,8 @@ class PyBuildExt(build_ext):
exts.append( Extension("_functools", ["_functoolsmodule.c"]) )
# Memory-based IO accelerator modules
exts.append( Extension("_bytesio", ["_bytesio.c"]) )
# C-optimized pickle replacement
exts.append( Extension("_pickle", ["_pickle.c"]) )
# atexit
exts.append( Extension("atexit", ["atexitmodule.c"]) )
# _json speedups
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment