Commit a3e32c92 authored by Serhiy Storchaka's avatar Serhiy Storchaka

Closes #16551. Cleanup pickle.py.

parent c8fb047d
......@@ -26,9 +26,10 @@ Misc variables:
from types import FunctionType, BuiltinFunctionType
from copyreg import dispatch_table
from copyreg import _extension_registry, _inverted_registry, _extension_cache
import marshal
from itertools import islice
import sys
import struct
from sys import maxsize
from struct import pack, unpack
import re
import io
import codecs
......@@ -58,11 +59,6 @@ HIGHEST_PROTOCOL = 3
# there are too many issues with that.
DEFAULT_PROTOCOL = 3
# Why use struct.pack() for pickling but marshal.loads() for
# unpickling? struct.pack() is 40% faster than marshal.dumps(), but
# marshal.loads() is twice as fast as struct.unpack()!
mloads = marshal.loads
class PickleError(Exception):
"""A common base class for the other pickling exceptions."""
pass
......@@ -231,7 +227,7 @@ class _Pickler:
raise PicklingError("Pickler.__init__() was not called by "
"%s.__init__()" % (self.__class__.__name__,))
if self.proto >= 2:
self.write(PROTO + bytes([self.proto]))
self.write(PROTO + pack("<B", self.proto))
self.save(obj)
self.write(STOP)
......@@ -258,20 +254,20 @@ class _Pickler:
self.memo[id(obj)] = memo_len, obj
# Return a PUT (BINPUT, LONG_BINPUT) opcode string, with argument i.
def put(self, i, pack=struct.pack):
def put(self, i):
if self.bin:
if i < 256:
return BINPUT + bytes([i])
return BINPUT + pack("<B", i)
else:
return LONG_BINPUT + pack("<I", i)
return PUT + repr(i).encode("ascii") + b'\n'
# Return a GET (BINGET, LONG_BINGET) opcode string, with argument i.
def get(self, i, pack=struct.pack):
def get(self, i):
if self.bin:
if i < 256:
return BINGET + bytes([i])
return BINGET + pack("<B", i)
else:
return LONG_BINGET + pack("<I", i)
......@@ -286,20 +282,20 @@ class _Pickler:
# Check the memo
x = self.memo.get(id(obj))
if x:
if x is not None:
self.write(self.get(x[0]))
return
# Check the type dispatch table
t = type(obj)
f = self.dispatch.get(t)
if f:
if f is not None:
f(self, obj) # Call unbound method with explicit self
return
# Check private dispatch table if any, or else copyreg.dispatch_table
reduce = getattr(self, 'dispatch_table', dispatch_table).get(t)
if reduce:
if reduce is not None:
rv = reduce(obj)
else:
# Check for a class with a custom metaclass; treat as regular class
......@@ -313,11 +309,11 @@ class _Pickler:
# Check for a __reduce_ex__ method, fall back to __reduce__
reduce = getattr(obj, "__reduce_ex__", None)
if reduce:
if reduce is not None:
rv = reduce(self.proto)
else:
reduce = getattr(obj, "__reduce__", None)
if reduce:
if reduce is not None:
rv = reduce()
else:
raise PicklingError("Can't pickle %r object: %r" %
......@@ -448,12 +444,12 @@ class _Pickler:
def save_bool(self, obj):
if self.proto >= 2:
self.write(obj and NEWTRUE or NEWFALSE)
self.write(NEWTRUE if obj else NEWFALSE)
else:
self.write(obj and TRUE or FALSE)
self.write(TRUE if obj else FALSE)
dispatch[bool] = save_bool
def save_long(self, obj, pack=struct.pack):
def save_long(self, obj):
if self.bin:
# If the int is small enough to fit in a signed 4-byte 2's-comp
# format, we can store it more efficiently than the general
......@@ -461,39 +457,36 @@ class _Pickler:
# First one- and two-byte unsigned ints:
if obj >= 0:
if obj <= 0xff:
self.write(BININT1 + bytes([obj]))
self.write(BININT1 + pack("<B", obj))
return
if obj <= 0xffff:
self.write(BININT2 + bytes([obj&0xff, obj>>8]))
self.write(BININT2 + pack("<H", obj))
return
# Next check for 4-byte signed ints:
high_bits = obj >> 31 # note that Python shift sign-extends
if high_bits == 0 or high_bits == -1:
# All high bits are copies of bit 2**31, so the value
# fits in a 4-byte signed int.
if -0x80000000 <= obj <= 0x7fffffff:
self.write(BININT + pack("<i", obj))
return
if self.proto >= 2:
encoded = encode_long(obj)
n = len(encoded)
if n < 256:
self.write(LONG1 + bytes([n]) + encoded)
self.write(LONG1 + pack("<B", n) + encoded)
else:
self.write(LONG4 + pack("<i", n) + encoded)
return
self.write(LONG + repr(obj).encode("ascii") + b'L\n')
dispatch[int] = save_long
def save_float(self, obj, pack=struct.pack):
def save_float(self, obj):
if self.bin:
self.write(BINFLOAT + pack('>d', obj))
else:
self.write(FLOAT + repr(obj).encode("ascii") + b'\n')
dispatch[float] = save_float
def save_bytes(self, obj, pack=struct.pack):
def save_bytes(self, obj):
if self.proto < 3:
if len(obj) == 0:
if not obj: # bytes object is empty
self.save_reduce(bytes, (), obj=obj)
else:
self.save_reduce(codecs.encode,
......@@ -501,13 +494,13 @@ class _Pickler:
return
n = len(obj)
if n < 256:
self.write(SHORT_BINBYTES + bytes([n]) + bytes(obj))
self.write(SHORT_BINBYTES + pack("<B", n) + obj)
else:
self.write(BINBYTES + pack("<I", n) + bytes(obj))
self.write(BINBYTES + pack("<I", n) + obj)
self.memoize(obj)
dispatch[bytes] = save_bytes
def save_str(self, obj, pack=struct.pack):
def save_str(self, obj):
if self.bin:
encoded = obj.encode('utf-8', 'surrogatepass')
n = len(encoded)
......@@ -515,39 +508,36 @@ class _Pickler:
else:
obj = obj.replace("\\", "\\u005c")
obj = obj.replace("\n", "\\u000a")
self.write(UNICODE + bytes(obj.encode('raw-unicode-escape')) +
b'\n')
self.write(UNICODE + obj.encode('raw-unicode-escape') + b'\n')
self.memoize(obj)
dispatch[str] = save_str
def save_tuple(self, obj):
write = self.write
proto = self.proto
n = len(obj)
if n == 0:
if proto:
write(EMPTY_TUPLE)
if not obj: # tuple is empty
if self.bin:
self.write(EMPTY_TUPLE)
else:
write(MARK + TUPLE)
self.write(MARK + TUPLE)
return
n = len(obj)
save = self.save
memo = self.memo
if n <= 3 and proto >= 2:
if n <= 3 and self.proto >= 2:
for element in obj:
save(element)
# Subtle. Same as in the big comment below.
if id(obj) in memo:
get = self.get(memo[id(obj)][0])
write(POP * n + get)
self.write(POP * n + get)
else:
write(_tuplesize2code[n])
self.write(_tuplesize2code[n])
self.memoize(obj)
return
# proto 0 or proto 1 and tuple isn't empty, or proto > 1 and tuple
# has more than 3 elements.
write = self.write
write(MARK)
for element in obj:
save(element)
......@@ -561,25 +551,23 @@ class _Pickler:
# could have been done in the "for element" loop instead, but
# recursive tuples are a rare thing.
get = self.get(memo[id(obj)][0])
if proto:
if self.bin:
write(POP_MARK + get)
else: # proto 0 -- POP_MARK not available
write(POP * (n+1) + get)
return
# No recursion.
self.write(TUPLE)
write(TUPLE)
self.memoize(obj)
dispatch[tuple] = save_tuple
def save_list(self, obj):
write = self.write
if self.bin:
write(EMPTY_LIST)
self.write(EMPTY_LIST)
else: # proto 0 -- can't use EMPTY_LIST
write(MARK + LIST)
self.write(MARK + LIST)
self.memoize(obj)
self._batch_appends(obj)
......@@ -599,17 +587,9 @@ class _Pickler:
write(APPEND)
return
items = iter(items)
r = range(self._BATCHSIZE)
while items is not None:
tmp = []
for i in r:
try:
x = next(items)
tmp.append(x)
except StopIteration:
items = None
break
it = iter(items)
while True:
tmp = list(islice(it, self._BATCHSIZE))
n = len(tmp)
if n > 1:
write(MARK)
......@@ -620,14 +600,14 @@ class _Pickler:
save(tmp[0])
write(APPEND)
# else tmp is empty, and we're done
if n < self._BATCHSIZE:
return
def save_dict(self, obj):
write = self.write
if self.bin:
write(EMPTY_DICT)
self.write(EMPTY_DICT)
else: # proto 0 -- can't use EMPTY_DICT
write(MARK + DICT)
self.write(MARK + DICT)
self.memoize(obj)
self._batch_setitems(obj.items())
......@@ -648,16 +628,9 @@ class _Pickler:
write(SETITEM)
return
items = iter(items)
r = range(self._BATCHSIZE)
while items is not None:
tmp = []
for i in r:
try:
tmp.append(next(items))
except StopIteration:
items = None
break
it = iter(items)
while True:
tmp = list(islice(it, self._BATCHSIZE))
n = len(tmp)
if n > 1:
write(MARK)
......@@ -671,8 +644,10 @@ class _Pickler:
save(v)
write(SETITEM)
# else tmp is empty, and we're done
if n < self._BATCHSIZE:
return
def save_global(self, obj, name=None, pack=struct.pack):
def save_global(self, obj, name=None):
write = self.write
memo = self.memo
......@@ -702,9 +677,9 @@ class _Pickler:
if code:
assert code > 0
if code <= 0xff:
write(EXT1 + bytes([code]))
write(EXT1 + pack("<B", code))
elif code <= 0xffff:
write(EXT2 + bytes([code&0xff, code>>8]))
write(EXT2 + pack("<H", code))
else:
write(EXT4 + pack("<i", code))
return
......@@ -732,25 +707,6 @@ class _Pickler:
dispatch[BuiltinFunctionType] = save_global
dispatch[type] = save_global
# Pickling helpers
def _keep_alive(x, memo):
"""Keeps a reference to the object x in the memo.
Because we remember objects by their id, we have
to assure that possibly temporary objects are kept
alive by referencing them.
We store a reference at the id of the memo, which should
normally not be used unless someone tries to deepcopy
the memo itself...
"""
try:
memo[id(memo)].append(x)
except KeyError:
# aha, this is the first one :-)
memo[id(memo)]=[x]
# A cache for whichmodule(), mapping a function object to the name of
# the module in which the function was found.
......@@ -832,7 +788,7 @@ class _Unpickler:
read = self.read
dispatch = self.dispatch
try:
while 1:
while True:
key = read(1)
if not key:
raise EOFError
......@@ -862,7 +818,7 @@ class _Unpickler:
dispatch = {}
def load_proto(self):
proto = ord(self.read(1))
proto = self.read(1)[0]
if not 0 <= proto <= HIGHEST_PROTOCOL:
raise ValueError("unsupported pickle protocol: %d" % proto)
self.proto = proto
......@@ -897,40 +853,37 @@ class _Unpickler:
elif data == TRUE[1:]:
val = True
else:
try:
val = int(data, 0)
except ValueError:
val = int(data, 0)
val = int(data, 0)
self.append(val)
dispatch[INT[0]] = load_int
def load_binint(self):
self.append(mloads(b'i' + self.read(4)))
self.append(unpack('<i', self.read(4))[0])
dispatch[BININT[0]] = load_binint
def load_binint1(self):
self.append(ord(self.read(1)))
self.append(self.read(1)[0])
dispatch[BININT1[0]] = load_binint1
def load_binint2(self):
self.append(mloads(b'i' + self.read(2) + b'\000\000'))
self.append(unpack('<H', self.read(2))[0])
dispatch[BININT2[0]] = load_binint2
def load_long(self):
val = self.readline()[:-1].decode("ascii")
if val and val[-1] == 'L':
val = self.readline()[:-1]
if val and val[-1] == b'L'[0]:
val = val[:-1]
self.append(int(val, 0))
dispatch[LONG[0]] = load_long
def load_long1(self):
n = ord(self.read(1))
n = self.read(1)[0]
data = self.read(n)
self.append(decode_long(data))
dispatch[LONG1[0]] = load_long1
def load_long4(self):
n = mloads(b'i' + self.read(4))
n, = unpack('<i', self.read(4))
if n < 0:
# Corrupt or hostile pickle -- we never write one like this
raise UnpicklingError("LONG pickle has negative byte count")
......@@ -942,28 +895,25 @@ class _Unpickler:
self.append(float(self.readline()[:-1]))
dispatch[FLOAT[0]] = load_float
def load_binfloat(self, unpack=struct.unpack):
def load_binfloat(self):
self.append(unpack('>d', self.read(8))[0])
dispatch[BINFLOAT[0]] = load_binfloat
def load_string(self):
orig = self.readline()
rep = orig[:-1]
for q in (b'"', b"'"): # double or single quote
if rep.startswith(q):
if not rep.endswith(q):
raise ValueError("insecure string pickle")
rep = rep[len(q):-len(q)]
break
# Strip outermost quotes
if rep[0] == rep[-1] and rep[0] in b'"\'':
rep = rep[1:-1]
else:
raise ValueError("insecure string pickle: %r" % orig)
raise ValueError("insecure string pickle")
self.append(codecs.escape_decode(rep)[0]
.decode(self.encoding, self.errors))
dispatch[STRING[0]] = load_string
def load_binstring(self):
# Deprecated BINSTRING uses signed 32-bit length
len = mloads(b'i' + self.read(4))
len, = unpack('<i', self.read(4))
if len < 0:
raise UnpicklingError("BINSTRING pickle has negative byte count")
data = self.read(len)
......@@ -971,7 +921,7 @@ class _Unpickler:
self.append(value)
dispatch[BINSTRING[0]] = load_binstring
def load_binbytes(self, unpack=struct.unpack, maxsize=sys.maxsize):
def load_binbytes(self):
len, = unpack('<I', self.read(4))
if len > maxsize:
raise UnpicklingError("BINBYTES exceeds system's maximum size "
......@@ -983,7 +933,7 @@ class _Unpickler:
self.append(str(self.readline()[:-1], 'raw-unicode-escape'))
dispatch[UNICODE[0]] = load_unicode
def load_binunicode(self, unpack=struct.unpack, maxsize=sys.maxsize):
def load_binunicode(self):
len, = unpack('<I', self.read(4))
if len > maxsize:
raise UnpicklingError("BINUNICODE exceeds system's maximum size "
......@@ -992,15 +942,15 @@ class _Unpickler:
dispatch[BINUNICODE[0]] = load_binunicode
def load_short_binstring(self):
len = ord(self.read(1))
data = bytes(self.read(len))
len = self.read(1)[0]
data = self.read(len)
value = str(data, self.encoding, self.errors)
self.append(value)
dispatch[SHORT_BINSTRING[0]] = load_short_binstring
def load_short_binbytes(self):
len = ord(self.read(1))
self.append(bytes(self.read(len)))
len = self.read(1)[0]
self.append(self.read(len))
dispatch[SHORT_BINBYTES[0]] = load_short_binbytes
def load_tuple(self):
......@@ -1039,12 +989,9 @@ class _Unpickler:
def load_dict(self):
k = self.marker()
d = {}
items = self.stack[k+1:]
for i in range(0, len(items), 2):
key = items[i]
value = items[i+1]
d[key] = value
d = {items[i]: items[i+1]
for i in range(0, len(items), 2)}
self.stack[k:] = [d]
dispatch[DICT[0]] = load_dict
......@@ -1096,17 +1043,17 @@ class _Unpickler:
dispatch[GLOBAL[0]] = load_global
def load_ext1(self):
code = ord(self.read(1))
code = self.read(1)[0]
self.get_extension(code)
dispatch[EXT1[0]] = load_ext1
def load_ext2(self):
code = mloads(b'i' + self.read(2) + b'\000\000')
code, = unpack('<H', self.read(2))
self.get_extension(code)
dispatch[EXT2[0]] = load_ext2
def load_ext4(self):
code = mloads(b'i' + self.read(4))
code, = unpack('<i', self.read(4))
self.get_extension(code)
dispatch[EXT4[0]] = load_ext4
......@@ -1174,7 +1121,7 @@ class _Unpickler:
self.append(self.memo[i])
dispatch[BINGET[0]] = load_binget
def load_long_binget(self, unpack=struct.unpack):
def load_long_binget(self):
i, = unpack('<I', self.read(4))
self.append(self.memo[i])
dispatch[LONG_BINGET[0]] = load_long_binget
......@@ -1193,7 +1140,7 @@ class _Unpickler:
self.memo[i] = self.stack[-1]
dispatch[BINPUT[0]] = load_binput
def load_long_binput(self, unpack=struct.unpack, maxsize=sys.maxsize):
def load_long_binput(self):
i, = unpack('<I', self.read(4))
if i > maxsize:
raise ValueError("negative LONG_BINPUT argument")
......@@ -1238,7 +1185,7 @@ class _Unpickler:
state = stack.pop()
inst = stack[-1]
setstate = getattr(inst, "__setstate__", None)
if setstate:
if setstate is not None:
setstate(state)
return
slotstate = None
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment