Commit c9dc4a2a authored by Antoine Pitrou's avatar Antoine Pitrou

Issue #17810: Implement PEP 3154, pickle protocol 4.

Most of the work is by Alexandre.
parent 95401c5f
...@@ -459,12 +459,29 @@ implementation of this behaviour::
Classes can alter the default behaviour by providing one or several special
methods:
.. method:: object.__getnewargs_ex__()
In protocols 4 and newer, classes that implement the
:meth:`__getnewargs_ex__` method can dictate the values passed to the
:meth:`__new__` method upon unpickling. The method must return a pair
``(args, kwargs)`` where *args* is a tuple of positional arguments
and *kwargs* a dictionary of named arguments for constructing the
object. Those will be passed to the :meth:`__new__` method upon
unpickling.
You should implement this method if the :meth:`__new__` method of your
class requires keyword-only arguments. Otherwise, it is recommended for
compatibility to implement :meth:`__getnewargs__`.
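As an illustration only (a hedged sketch, not part of this patch; the ``Box``
class and its ``content`` attribute are invented for the example), a class
whose :meth:`__new__` takes a keyword-only argument could cooperate with
protocol 4 like this::

   class Box:
       def __new__(cls, *, content):
           # __new__ needs a keyword-only argument, so __getnewargs__
           # alone could not describe how to rebuild the object.
           self = super().__new__(cls)
           self.content = content
           return self

       def __getnewargs_ex__(self):
           # No positional arguments, one keyword argument.
           return ((), {'content': self.content})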
.. method:: object.__getnewargs__()
This method serves a similar purpose as :meth:`__getnewargs_ex__`, but
for protocols 2 and newer. It must return a tuple of arguments *args*
which will be passed to the :meth:`__new__` method upon unpickling.
In protocols 4 and newer, :meth:`__getnewargs__` will not be called if
:meth:`__getnewargs_ex__` is defined.
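For comparison, a minimal sketch of the positional-only case (the ``Point``
class is likewise invented for the example)::

   class Point(tuple):
       def __new__(cls, x, y):
           return super().__new__(cls, (x, y))

       def __getnewargs__(self):
           # Re-supply the positional arguments __new__ expects.
           return tuple(self)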
.. method:: object.__getstate__()
...@@ -496,10 +513,10 @@ the methods :meth:`__getstate__` and :meth:`__setstate__`.
At unpickling time, some methods like :meth:`__getattr__`,
:meth:`__getattribute__`, or :meth:`__setattr__` may be called upon the
instance. In case those methods rely on some internal invariant being
true, the type should implement :meth:`__getnewargs__` or
:meth:`__getnewargs_ex__` to establish such an invariant; otherwise,
neither :meth:`__new__` nor :meth:`__init__` will be called.
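A hedged illustration of the pitfall (the ``Lazy`` class and its ``_cache``
attribute are invented for the example): when :meth:`__getattr__` dereferences
an attribute that only :meth:`__new__` creates, providing
:meth:`__getnewargs__` guarantees the attribute exists before the unpickler
starts restoring state::

   class Lazy:
       def __new__(cls, cache=None):
           self = super().__new__(cls)
           self._cache = cache or {}   # invariant: _cache always exists
           return self

       def __getattr__(self, name):
           try:
               return self._cache[name]   # relies on the invariant above
           except KeyError:
               raise AttributeError(name)

       def __getnewargs__(self):
           return (self._cache,)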
.. index:: pair: copy; protocol
...@@ -511,7 +528,7 @@ objects. [#]_
Although powerful, implementing :meth:`__reduce__` directly in your classes is
error prone. For this reason, class designers should use the high-level
interface (i.e., :meth:`__getnewargs_ex__`, :meth:`__getstate__` and
:meth:`__setstate__`) whenever possible. We will show, however, cases where
using :meth:`__reduce__` is the only option or leads to more efficient pickling
or both.
......
...@@ -109,6 +109,7 @@ New expected features for Python implementations:
Significantly Improved Library Modules:
* Single-dispatch generic functions in :mod:`functools` (:pep:`443`)
* New :mod:`pickle` protocol 4 (:pep:`3154`)
* SHA-3 (Keccak) support for :mod:`hashlib`.
* TLSv1.1 and TLSv1.2 support for :mod:`ssl`.
* :mod:`multiprocessing` now has an option to avoid using :func:`os.fork`
...@@ -285,6 +286,20 @@ described in the PEP. Existing importers should be updated to implement
the new methods.
Pickle protocol 4
=================
The new :mod:`pickle` protocol addresses a number of issues that were present
in previous protocols, such as the serialization of nested classes, very
large strings and containers, or classes whose :meth:`__new__` method takes
keyword-only arguments. It also brings a couple of efficiency improvements.
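A small, hedged illustration of one new capability (the ``Outer`` and
``Inner`` names are invented; the classes must live at module level so they
can be found again by reference): protocol 4 pickles a nested class by its
qualified name, which earlier protocols could not do::

   >>> import pickle
   >>> class Outer:
   ...     class Inner:
   ...         pass
   ...
   >>> data = pickle.dumps(Outer.Inner, protocol=4)   # uses __qualname__
   >>> pickle.loads(data) is Outer.Inner
   True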
.. seealso::
:pep:`3154` - Pickle protocol 4
PEP written by Antoine Pitrou and implemented by Alexandre Vassalotti.
Other Language Changes
======================
......
...@@ -87,6 +87,12 @@ def _reduce_ex(self, proto):
def __newobj__(cls, *args):
    return cls.__new__(cls, *args)
def __newobj_ex__(cls, args, kwargs):
"""Used by pickle protocol 4, instead of __newobj__ to allow classes with
keyword-only arguments to be pickled correctly.
"""
return cls.__new__(cls, *args, **kwargs)
def _slotnames(cls):
    """Return a list of slot names for a given class.
......
...@@ -23,7 +23,7 @@ Misc variables:
"""
from types import FunctionType, BuiltinFunctionType, ModuleType
from copyreg import dispatch_table
from copyreg import _extension_registry, _inverted_registry, _extension_cache
from itertools import islice
...@@ -42,17 +42,18 @@ __all__ = ["PickleError", "PicklingError", "UnpicklingError", "Pickler",
bytes_types = (bytes, bytearray)
# These are purely informational; no code uses these.
format_version = "4.0"                  # File format version we write
compatible_formats = ["1.0",            # Original protocol 0
                      "1.1",            # Protocol 0 with INST added
                      "1.2",            # Original protocol 1
                      "1.3",            # Protocol 1 with BINFLOAT added
                      "2.0",            # Protocol 2
                      "3.0",            # Protocol 3
                      "4.0",            # Protocol 4
                      ]                 # Old format versions we can read
# This is the highest protocol number we know how to read.
HIGHEST_PROTOCOL = 4
# The protocol we write by default.  May be less than HIGHEST_PROTOCOL.
# We intentionally write a protocol that Python 2.x cannot read;
...@@ -164,7 +165,196 @@ _tuplesize2code = [EMPTY_TUPLE, TUPLE1, TUPLE2, TUPLE3]
BINBYTES = b'B'         # push bytes; counted binary string argument
SHORT_BINBYTES = b'C'   # " " ; " " " " < 256 bytes
# Protocol 4
SHORT_BINUNICODE = b'\x8c' # push short string; UTF-8 length < 256 bytes
BINUNICODE8 = b'\x8d' # push very long string
BINBYTES8 = b'\x8e' # push very long bytes string
EMPTY_SET = b'\x8f' # push empty set on the stack
ADDITEMS = b'\x90' # modify set by adding topmost stack items
FROZENSET = b'\x91' # build frozenset from topmost stack items
NEWOBJ_EX = b'\x92'     # like NEWOBJ but works with keyword-only arguments
STACK_GLOBAL = b'\x93' # same as GLOBAL but using names on the stacks
MEMOIZE = b'\x94' # store top of the stack in memo
FRAME = b'\x95' # indicate the beginning of a new frame
__all__.extend([x for x in dir() if re.match("[A-Z][A-Z0-9_]+$", x)])
class _Framer:
_FRAME_SIZE_TARGET = 64 * 1024
def __init__(self, file_write):
self.file_write = file_write
self.current_frame = None
def _commit_frame(self):
f = self.current_frame
with f.getbuffer() as data:
n = len(data)
write = self.file_write
write(FRAME)
write(pack("<Q", n))
write(data)
f.seek(0)
f.truncate()
def start_framing(self):
self.current_frame = io.BytesIO()
def end_framing(self):
if self.current_frame is not None:
self._commit_frame()
self.current_frame = None
def write(self, data):
f = self.current_frame
if f is None:
return self.file_write(data)
else:
n = len(data)
if f.tell() >= self._FRAME_SIZE_TARGET:
self._commit_frame()
return f.write(data)
class _Unframer:
def __init__(self, file_read, file_readline, file_tell=None):
self.file_read = file_read
self.file_readline = file_readline
self.file_tell = file_tell
self.framing_enabled = False
self.current_frame = None
self.frame_start = None
def read(self, n):
if n == 0:
return b''
_file_read = self.file_read
if not self.framing_enabled:
return _file_read(n)
f = self.current_frame
if f is not None:
data = f.read(n)
if data:
if len(data) < n:
raise UnpicklingError(
"pickle exhausted before end of frame")
return data
frame_opcode = _file_read(1)
if frame_opcode != FRAME:
raise UnpicklingError(
"expected a FRAME opcode, got {} instead".format(frame_opcode))
frame_size, = unpack("<Q", _file_read(8))
if frame_size > sys.maxsize:
raise ValueError("frame size > sys.maxsize: %d" % frame_size)
if self.file_tell is not None:
self.frame_start = self.file_tell()
f = self.current_frame = io.BytesIO(_file_read(frame_size))
self.readline = f.readline
data = f.read(n)
assert len(data) == n, (len(data), n)
return data
def readline(self):
if not self.framing_enabled:
return self.file_readline()
else:
return self.current_frame.readline()
def tell(self):
if self.file_tell is None:
return None
elif self.current_frame is None:
return self.file_tell()
else:
return self.frame_start + self.current_frame.tell()
# Tools used for pickling.
def _getattribute(obj, name, allow_qualname=False):
dotted_path = name.split(".")
if not allow_qualname and len(dotted_path) > 1:
raise AttributeError("Can't get qualified attribute {!r} on {!r}; " +
"use protocols >= 4 to enable support"
.format(name, obj))
for subpath in dotted_path:
if subpath == '<locals>':
raise AttributeError("Can't get local attribute {!r} on {!r}"
.format(name, obj))
try:
obj = getattr(obj, subpath)
except AttributeError:
raise AttributeError("Can't get attribute {!r} on {!r}"
.format(name, obj))
return obj
def whichmodule(obj, name, allow_qualname=False):
"""Find the module an object belong to."""
module_name = getattr(obj, '__module__', None)
if module_name is not None:
return module_name
for module_name, module in sys.modules.items():
if module_name == '__main__' or module is None:
continue
try:
if _getattribute(module, name, allow_qualname) is obj:
return module_name
except AttributeError:
pass
return '__main__'
def encode_long(x):
r"""Encode a long to a two's complement little-endian binary string.
Note that 0 is a special case, returning an empty string, to save a
byte in the LONG1 pickling context.
>>> encode_long(0)
b''
>>> encode_long(255)
b'\xff\x00'
>>> encode_long(32767)
b'\xff\x7f'
>>> encode_long(-256)
b'\x00\xff'
>>> encode_long(-32768)
b'\x00\x80'
>>> encode_long(-128)
b'\x80'
>>> encode_long(127)
b'\x7f'
>>>
"""
if x == 0:
return b''
nbytes = (x.bit_length() >> 3) + 1
result = x.to_bytes(nbytes, byteorder='little', signed=True)
if x < 0 and nbytes > 1:
if result[-1] == 0xff and (result[-2] & 0x80) != 0:
result = result[:-1]
return result
def decode_long(data):
r"""Decode a long from a two's complement little-endian binary string.
>>> decode_long(b'')
0
>>> decode_long(b"\xff\x00")
255
>>> decode_long(b"\xff\x7f")
32767
>>> decode_long(b"\x00\xff")
-256
>>> decode_long(b"\x00\x80")
-32768
>>> decode_long(b"\x80")
-128
>>> decode_long(b"\x7f")
127
"""
return int.from_bytes(data, byteorder='little', signed=True)
# Pickling machinery # Pickling machinery
...@@ -174,9 +364,9 @@ class _Pickler: ...@@ -174,9 +364,9 @@ class _Pickler:
"""This takes a binary file for writing a pickle data stream. """This takes a binary file for writing a pickle data stream.
The optional protocol argument tells the pickler to use the The optional protocol argument tells the pickler to use the
given protocol; supported protocols are 0, 1, 2, 3. The default given protocol; supported protocols are 0, 1, 2, 3 and 4. The
protocol is 3; a backward-incompatible protocol designed for default protocol is 3; a backward-incompatible protocol designed for
Python 3.0. Python 3.
Specifying a negative protocol version selects the highest Specifying a negative protocol version selects the highest
protocol version supported. The higher the protocol used, the protocol version supported. The higher the protocol used, the
...@@ -189,8 +379,8 @@ class _Pickler: ...@@ -189,8 +379,8 @@ class _Pickler:
meets this interface. meets this interface.
If fix_imports is True and protocol is less than 3, pickle will try to If fix_imports is True and protocol is less than 3, pickle will try to
map the new Python 3.x names to the old module names used in Python map the new Python 3 names to the old module names used in Python 2,
2.x, so that the pickle data stream is readable with Python 2.x. so that the pickle data stream is readable with Python 2.
""" """
if protocol is None: if protocol is None:
protocol = DEFAULT_PROTOCOL protocol = DEFAULT_PROTOCOL
...@@ -199,7 +389,7 @@ class _Pickler: ...@@ -199,7 +389,7 @@ class _Pickler:
elif not 0 <= protocol <= HIGHEST_PROTOCOL: elif not 0 <= protocol <= HIGHEST_PROTOCOL:
raise ValueError("pickle protocol must be <= %d" % HIGHEST_PROTOCOL) raise ValueError("pickle protocol must be <= %d" % HIGHEST_PROTOCOL)
try: try:
self.write = file.write self._file_write = file.write
except AttributeError: except AttributeError:
raise TypeError("file must have a 'write' attribute") raise TypeError("file must have a 'write' attribute")
self.memo = {} self.memo = {}
...@@ -223,13 +413,22 @@ class _Pickler: ...@@ -223,13 +413,22 @@ class _Pickler:
"""Write a pickled representation of obj to the open file.""" """Write a pickled representation of obj to the open file."""
# Check whether Pickler was initialized correctly. This is # Check whether Pickler was initialized correctly. This is
# only needed to mimic the behavior of _pickle.Pickler.dump(). # only needed to mimic the behavior of _pickle.Pickler.dump().
if not hasattr(self, "write"): if not hasattr(self, "_file_write"):
raise PicklingError("Pickler.__init__() was not called by " raise PicklingError("Pickler.__init__() was not called by "
"%s.__init__()" % (self.__class__.__name__,)) "%s.__init__()" % (self.__class__.__name__,))
if self.proto >= 2: if self.proto >= 2:
self.write(PROTO + pack("<B", self.proto)) self._file_write(PROTO + pack("<B", self.proto))
if self.proto >= 4:
framer = _Framer(self._file_write)
framer.start_framing()
self.write = framer.write
else:
framer = None
self.write = self._file_write
self.save(obj) self.save(obj)
self.write(STOP) self.write(STOP)
if framer is not None:
framer.end_framing()
def memoize(self, obj): def memoize(self, obj):
"""Store an object in the memo.""" """Store an object in the memo."""
...@@ -249,19 +448,21 @@ class _Pickler: ...@@ -249,19 +448,21 @@ class _Pickler:
if self.fast: if self.fast:
return return
assert id(obj) not in self.memo assert id(obj) not in self.memo
memo_len = len(self.memo) idx = len(self.memo)
self.write(self.put(memo_len)) self.write(self.put(idx))
self.memo[id(obj)] = memo_len, obj self.memo[id(obj)] = idx, obj
# Return a PUT (BINPUT, LONG_BINPUT) opcode string, with argument i. # Return a PUT (BINPUT, LONG_BINPUT) opcode string, with argument i.
def put(self, i): def put(self, idx):
if self.bin: if self.proto >= 4:
if i < 256: return MEMOIZE
return BINPUT + pack("<B", i) elif self.bin:
if idx < 256:
return BINPUT + pack("<B", idx)
else: else:
return LONG_BINPUT + pack("<I", i) return LONG_BINPUT + pack("<I", idx)
else:
return PUT + repr(i).encode("ascii") + b'\n' return PUT + repr(idx).encode("ascii") + b'\n'
# Return a GET (BINGET, LONG_BINGET) opcode string, with argument i. # Return a GET (BINGET, LONG_BINGET) opcode string, with argument i.
def get(self, i): def get(self, i):
...@@ -349,24 +550,33 @@ class _Pickler: ...@@ -349,24 +550,33 @@ class _Pickler:
else: else:
self.write(PERSID + str(pid).encode("ascii") + b'\n') self.write(PERSID + str(pid).encode("ascii") + b'\n')
def save_reduce(self, func, args, state=None, def save_reduce(self, func, args, state=None, listitems=None,
listitems=None, dictitems=None, obj=None): dictitems=None, obj=None):
# This API is called by some subclasses # This API is called by some subclasses
# Assert that args is a tuple
if not isinstance(args, tuple): if not isinstance(args, tuple):
raise PicklingError("args from save_reduce() should be a tuple") raise PicklingError("args from save_reduce() must be a tuple")
# Assert that func is callable
if not callable(func): if not callable(func):
raise PicklingError("func from save_reduce() should be callable") raise PicklingError("func from save_reduce() must be callable")
save = self.save save = self.save
write = self.write write = self.write
# Protocol 2 special case: if func's name is __newobj__, use NEWOBJ func_name = getattr(func, "__name__", "")
if self.proto >= 2 and getattr(func, "__name__", "") == "__newobj__": if self.proto >= 4 and func_name == "__newobj_ex__":
# A __reduce__ implementation can direct protocol 2 to cls, args, kwargs = args
if not hasattr(cls, "__new__"):
raise PicklingError("args[0] from {} args has no __new__"
.format(func_name))
if obj is not None and cls is not obj.__class__:
raise PicklingError("args[0] from {} args has the wrong class"
.format(func_name))
save(cls)
save(args)
save(kwargs)
write(NEWOBJ_EX)
elif self.proto >= 2 and func_name == "__newobj__":
# A __reduce__ implementation can direct protocol 2 or newer to
# use the more efficient NEWOBJ opcode, while still # use the more efficient NEWOBJ opcode, while still
# allowing protocol 0 and 1 to work normally. For this to # allowing protocol 0 and 1 to work normally. For this to
# work, the function returned by __reduce__ should be # work, the function returned by __reduce__ should be
...@@ -409,6 +619,12 @@ class _Pickler: ...@@ -409,6 +619,12 @@ class _Pickler:
write(REDUCE) write(REDUCE)
if obj is not None: if obj is not None:
# If the object is already in the memo, this means it is
# recursive. In this case, throw away everything we put on the
# stack, and fetch the object back from the memo.
if id(obj) in self.memo:
write(POP + self.get(self.memo[id(obj)][0]))
else:
self.memoize(obj) self.memoize(obj)
# More new special cases (that work with older protocols as # More new special cases (that work with older protocols as
...@@ -493,8 +709,10 @@ class _Pickler: ...@@ -493,8 +709,10 @@ class _Pickler:
(str(obj, 'latin1'), 'latin1'), obj=obj) (str(obj, 'latin1'), 'latin1'), obj=obj)
return return
n = len(obj) n = len(obj)
if n < 256: if n <= 0xff:
self.write(SHORT_BINBYTES + pack("<B", n) + obj) self.write(SHORT_BINBYTES + pack("<B", n) + obj)
elif n > 0xffffffff and self.proto >= 4:
self.write(BINBYTES8 + pack("<Q", n) + obj)
else: else:
self.write(BINBYTES + pack("<I", n) + obj) self.write(BINBYTES + pack("<I", n) + obj)
self.memoize(obj) self.memoize(obj)
...@@ -504,11 +722,17 @@ class _Pickler: ...@@ -504,11 +722,17 @@ class _Pickler:
if self.bin: if self.bin:
encoded = obj.encode('utf-8', 'surrogatepass') encoded = obj.encode('utf-8', 'surrogatepass')
n = len(encoded) n = len(encoded)
if n <= 0xff and self.proto >= 4:
self.write(SHORT_BINUNICODE + pack("<B", n) + encoded)
elif n > 0xffffffff and self.proto >= 4:
self.write(BINUNICODE8 + pack("<Q", n) + encoded)
else:
self.write(BINUNICODE + pack("<I", n) + encoded) self.write(BINUNICODE + pack("<I", n) + encoded)
else: else:
obj = obj.replace("\\", "\\u005c") obj = obj.replace("\\", "\\u005c")
obj = obj.replace("\n", "\\u000a") obj = obj.replace("\n", "\\u000a")
self.write(UNICODE + obj.encode('raw-unicode-escape') + b'\n') self.write(UNICODE + obj.encode('raw-unicode-escape') +
b'\n')
self.memoize(obj) self.memoize(obj)
dispatch[str] = save_str dispatch[str] = save_str
...@@ -647,33 +871,79 @@ class _Pickler: ...@@ -647,33 +871,79 @@ class _Pickler:
if n < self._BATCHSIZE: if n < self._BATCHSIZE:
return return
def save_set(self, obj):
save = self.save
write = self.write
if self.proto < 4:
self.save_reduce(set, (list(obj),), obj=obj)
return
write(EMPTY_SET)
self.memoize(obj)
it = iter(obj)
while True:
batch = list(islice(it, self._BATCHSIZE))
n = len(batch)
if n > 0:
write(MARK)
for item in batch:
save(item)
write(ADDITEMS)
if n < self._BATCHSIZE:
return
dispatch[set] = save_set
def save_frozenset(self, obj):
save = self.save
write = self.write
if self.proto < 4:
self.save_reduce(frozenset, (list(obj),), obj=obj)
return
write(MARK)
for item in obj:
save(item)
if id(obj) in self.memo:
# If the object is already in the memo, this means it is
# recursive. In this case, throw away everything we put on the
# stack, and fetch the object back from the memo.
write(POP_MARK + self.get(self.memo[id(obj)][0]))
return
write(FROZENSET)
self.memoize(obj)
dispatch[frozenset] = save_frozenset
def save_global(self, obj, name=None): def save_global(self, obj, name=None):
write = self.write write = self.write
memo = self.memo memo = self.memo
if name is None and self.proto >= 4:
name = getattr(obj, '__qualname__', None)
if name is None: if name is None:
name = obj.__name__ name = obj.__name__
module = getattr(obj, "__module__", None) module_name = whichmodule(obj, name, allow_qualname=self.proto >= 4)
if module is None:
module = whichmodule(obj, name)
try: try:
__import__(module, level=0) __import__(module_name, level=0)
mod = sys.modules[module] module = sys.modules[module_name]
klass = getattr(mod, name) obj2 = _getattribute(module, name, allow_qualname=self.proto >= 4)
except (ImportError, KeyError, AttributeError): except (ImportError, KeyError, AttributeError):
raise PicklingError( raise PicklingError(
"Can't pickle %r: it's not found as %s.%s" % "Can't pickle %r: it's not found as %s.%s" %
(obj, module, name)) (obj, module_name, name))
else: else:
if klass is not obj: if obj2 is not obj:
raise PicklingError( raise PicklingError(
"Can't pickle %r: it's not the same object as %s.%s" % "Can't pickle %r: it's not the same object as %s.%s" %
(obj, module, name)) (obj, module_name, name))
if self.proto >= 2: if self.proto >= 2:
code = _extension_registry.get((module, name)) code = _extension_registry.get((module_name, name))
if code: if code:
assert code > 0 assert code > 0
if code <= 0xff: if code <= 0xff:
...@@ -684,17 +954,23 @@ class _Pickler: ...@@ -684,17 +954,23 @@ class _Pickler:
write(EXT4 + pack("<i", code)) write(EXT4 + pack("<i", code))
return return
# Non-ASCII identifiers are supported only with protocols >= 3. # Non-ASCII identifiers are supported only with protocols >= 3.
if self.proto >= 3: if self.proto >= 4:
write(GLOBAL + bytes(module, "utf-8") + b'\n' + self.save(module_name)
self.save(name)
write(STACK_GLOBAL)
elif self.proto >= 3:
write(GLOBAL + bytes(module_name, "utf-8") + b'\n' +
bytes(name, "utf-8") + b'\n') bytes(name, "utf-8") + b'\n')
else: else:
if self.fix_imports: if self.fix_imports:
if (module, name) in _compat_pickle.REVERSE_NAME_MAPPING: r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
module, name = _compat_pickle.REVERSE_NAME_MAPPING[(module, name)] r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
if module in _compat_pickle.REVERSE_IMPORT_MAPPING: if (module_name, name) in r_name_mapping:
module = _compat_pickle.REVERSE_IMPORT_MAPPING[module] module_name, name = r_name_mapping[(module_name, name)]
if module_name in r_import_mapping:
module_name = r_import_mapping[module_name]
try: try:
write(GLOBAL + bytes(module, "ascii") + b'\n' + write(GLOBAL + bytes(module_name, "ascii") + b'\n' +
bytes(name, "ascii") + b'\n') bytes(name, "ascii") + b'\n')
except UnicodeEncodeError: except UnicodeEncodeError:
raise PicklingError( raise PicklingError(
...@@ -703,40 +979,16 @@ class _Pickler: ...@@ -703,40 +979,16 @@ class _Pickler:
self.memoize(obj) self.memoize(obj)
def save_method(self, obj):
if obj.__self__ is None or type(obj.__self__) is ModuleType:
self.save_global(obj)
else:
self.save_reduce(getattr, (obj.__self__, obj.__name__), obj=obj)
    dispatch[FunctionType] = save_global
    dispatch[BuiltinFunctionType] = save_method
    dispatch[type] = save_global
# A cache for whichmodule(), mapping a function object to the name of
# the module in which the function was found.
classmap = {} # called classmap for backwards compatibility
def whichmodule(func, funcname):
"""Figure out the module in which a function occurs.
Search sys.modules for the module.
Cache in classmap.
Return a module name.
If the function cannot be found, return "__main__".
"""
# Python functions should always get an __module__ from their globals.
mod = getattr(func, "__module__", None)
if mod is not None:
return mod
if func in classmap:
return classmap[func]
for name, module in list(sys.modules.items()):
if module is None:
continue # skip dummy package entries
if name != '__main__' and getattr(module, funcname, None) is func:
break
else:
name = '__main__'
classmap[func] = name
return name
# Unpickling machinery # Unpickling machinery
...@@ -764,8 +1016,8 @@ class _Unpickler: ...@@ -764,8 +1016,8 @@ class _Unpickler:
instances pickled by Python 2.x; these default to 'ASCII' and instances pickled by Python 2.x; these default to 'ASCII' and
'strict', respectively. 'strict', respectively.
""" """
self.readline = file.readline self._file_readline = file.readline
self.read = file.read self._file_read = file.read
self.memo = {} self.memo = {}
self.encoding = encoding self.encoding = encoding
self.errors = errors self.errors = errors
...@@ -779,12 +1031,16 @@ class _Unpickler: ...@@ -779,12 +1031,16 @@ class _Unpickler:
""" """
# Check whether Unpickler was initialized correctly. This is # Check whether Unpickler was initialized correctly. This is
# only needed to mimic the behavior of _pickle.Unpickler.dump(). # only needed to mimic the behavior of _pickle.Unpickler.dump().
if not hasattr(self, "read"): if not hasattr(self, "_file_read"):
raise UnpicklingError("Unpickler.__init__() was not called by " raise UnpicklingError("Unpickler.__init__() was not called by "
"%s.__init__()" % (self.__class__.__name__,)) "%s.__init__()" % (self.__class__.__name__,))
self._unframer = _Unframer(self._file_read, self._file_readline)
self.read = self._unframer.read
self.readline = self._unframer.readline
self.mark = object() # any new unique object self.mark = object() # any new unique object
self.stack = [] self.stack = []
self.append = self.stack.append self.append = self.stack.append
self.proto = 0
read = self.read read = self.read
dispatch = self.dispatch dispatch = self.dispatch
try: try:
...@@ -822,6 +1078,8 @@ class _Unpickler: ...@@ -822,6 +1078,8 @@ class _Unpickler:
if not 0 <= proto <= HIGHEST_PROTOCOL: if not 0 <= proto <= HIGHEST_PROTOCOL:
raise ValueError("unsupported pickle protocol: %d" % proto) raise ValueError("unsupported pickle protocol: %d" % proto)
self.proto = proto self.proto = proto
if proto >= 4:
self._unframer.framing_enabled = True
dispatch[PROTO[0]] = load_proto dispatch[PROTO[0]] = load_proto
def load_persid(self): def load_persid(self):
...@@ -940,6 +1198,14 @@ class _Unpickler: ...@@ -940,6 +1198,14 @@ class _Unpickler:
self.append(str(self.read(len), 'utf-8', 'surrogatepass')) self.append(str(self.read(len), 'utf-8', 'surrogatepass'))
dispatch[BINUNICODE[0]] = load_binunicode dispatch[BINUNICODE[0]] = load_binunicode
def load_binunicode8(self):
len, = unpack('<Q', self.read(8))
if len > maxsize:
raise UnpicklingError("BINUNICODE8 exceeds system's maximum size "
"of %d bytes" % maxsize)
self.append(str(self.read(len), 'utf-8', 'surrogatepass'))
dispatch[BINUNICODE8[0]] = load_binunicode8
def load_short_binstring(self): def load_short_binstring(self):
len = self.read(1)[0] len = self.read(1)[0]
data = self.read(len) data = self.read(len)
...@@ -952,6 +1218,11 @@ class _Unpickler: ...@@ -952,6 +1218,11 @@ class _Unpickler:
self.append(self.read(len)) self.append(self.read(len))
dispatch[SHORT_BINBYTES[0]] = load_short_binbytes dispatch[SHORT_BINBYTES[0]] = load_short_binbytes
def load_short_binunicode(self):
len = self.read(1)[0]
self.append(str(self.read(len), 'utf-8', 'surrogatepass'))
dispatch[SHORT_BINUNICODE[0]] = load_short_binunicode
def load_tuple(self): def load_tuple(self):
k = self.marker() k = self.marker()
self.stack[k:] = [tuple(self.stack[k+1:])] self.stack[k:] = [tuple(self.stack[k+1:])]
...@@ -981,6 +1252,15 @@ class _Unpickler: ...@@ -981,6 +1252,15 @@ class _Unpickler:
self.append({}) self.append({})
dispatch[EMPTY_DICT[0]] = load_empty_dictionary dispatch[EMPTY_DICT[0]] = load_empty_dictionary
def load_empty_set(self):
self.append(set())
dispatch[EMPTY_SET[0]] = load_empty_set
def load_frozenset(self):
k = self.marker()
self.stack[k:] = [frozenset(self.stack[k+1:])]
dispatch[FROZENSET[0]] = load_frozenset
def load_list(self): def load_list(self):
k = self.marker() k = self.marker()
self.stack[k:] = [self.stack[k+1:]] self.stack[k:] = [self.stack[k+1:]]
...@@ -1029,11 +1309,19 @@ class _Unpickler: ...@@ -1029,11 +1309,19 @@ class _Unpickler:
    def load_newobj(self):
        args = self.stack.pop()
        cls = self.stack.pop()
        obj = cls.__new__(cls, *args)
        self.append(obj)
    dispatch[NEWOBJ[0]] = load_newobj
def load_newobj_ex(self):
kwargs = self.stack.pop()
args = self.stack.pop()
cls = self.stack.pop()
obj = cls.__new__(cls, *args, **kwargs)
self.append(obj)
dispatch[NEWOBJ_EX[0]] = load_newobj_ex
def load_global(self): def load_global(self):
module = self.readline()[:-1].decode("utf-8") module = self.readline()[:-1].decode("utf-8")
name = self.readline()[:-1].decode("utf-8") name = self.readline()[:-1].decode("utf-8")
...@@ -1041,6 +1329,14 @@ class _Unpickler: ...@@ -1041,6 +1329,14 @@ class _Unpickler:
self.append(klass) self.append(klass)
dispatch[GLOBAL[0]] = load_global dispatch[GLOBAL[0]] = load_global
def load_stack_global(self):
name = self.stack.pop()
module = self.stack.pop()
if type(name) is not str or type(module) is not str:
raise UnpicklingError("STACK_GLOBAL requires str")
self.append(self.find_class(module, name))
dispatch[STACK_GLOBAL[0]] = load_stack_global
def load_ext1(self): def load_ext1(self):
code = self.read(1)[0] code = self.read(1)[0]
self.get_extension(code) self.get_extension(code)
...@@ -1080,9 +1376,8 @@ class _Unpickler: ...@@ -1080,9 +1376,8 @@ class _Unpickler:
if module in _compat_pickle.IMPORT_MAPPING: if module in _compat_pickle.IMPORT_MAPPING:
module = _compat_pickle.IMPORT_MAPPING[module] module = _compat_pickle.IMPORT_MAPPING[module]
__import__(module, level=0) __import__(module, level=0)
mod = sys.modules[module] return _getattribute(sys.modules[module], name,
klass = getattr(mod, name) allow_qualname=self.proto >= 4)
return klass
def load_reduce(self): def load_reduce(self):
stack = self.stack stack = self.stack
...@@ -1146,6 +1441,11 @@ class _Unpickler: ...@@ -1146,6 +1441,11 @@ class _Unpickler:
self.memo[i] = self.stack[-1] self.memo[i] = self.stack[-1]
dispatch[LONG_BINPUT[0]] = load_long_binput dispatch[LONG_BINPUT[0]] = load_long_binput
def load_memoize(self):
memo = self.memo
memo[len(memo)] = self.stack[-1]
dispatch[MEMOIZE[0]] = load_memoize
def load_append(self): def load_append(self):
stack = self.stack stack = self.stack
value = stack.pop() value = stack.pop()
...@@ -1185,6 +1485,20 @@ class _Unpickler: ...@@ -1185,6 +1485,20 @@ class _Unpickler:
del stack[mark:] del stack[mark:]
dispatch[SETITEMS[0]] = load_setitems dispatch[SETITEMS[0]] = load_setitems
def load_additems(self):
stack = self.stack
mark = self.marker()
set_obj = stack[mark - 1]
items = stack[mark + 1:]
if isinstance(set_obj, set):
set_obj.update(items)
else:
add = set_obj.add
for item in items:
add(item)
del stack[mark:]
dispatch[ADDITEMS[0]] = load_additems
def load_build(self): def load_build(self):
stack = self.stack stack = self.stack
state = stack.pop() state = stack.pop()
...@@ -1218,86 +1532,46 @@ class _Unpickler: ...@@ -1218,86 +1532,46 @@ class _Unpickler:
raise _Stop(value) raise _Stop(value)
dispatch[STOP[0]] = load_stop dispatch[STOP[0]] = load_stop
# Encode/decode ints.
def encode_long(x):
r"""Encode a long to a two's complement little-endian binary string.
Note that 0 is a special case, returning an empty string, to save a
byte in the LONG1 pickling context.
>>> encode_long(0)
b''
>>> encode_long(255)
b'\xff\x00'
>>> encode_long(32767)
b'\xff\x7f'
>>> encode_long(-256)
b'\x00\xff'
>>> encode_long(-32768)
b'\x00\x80'
>>> encode_long(-128)
b'\x80'
>>> encode_long(127)
b'\x7f'
>>>
"""
if x == 0:
return b''
nbytes = (x.bit_length() >> 3) + 1
result = x.to_bytes(nbytes, byteorder='little', signed=True)
if x < 0 and nbytes > 1:
if result[-1] == 0xff and (result[-2] & 0x80) != 0:
result = result[:-1]
return result
def decode_long(data):
r"""Decode an int from a two's complement little-endian binary string.
>>> decode_long(b'')
0
>>> decode_long(b"\xff\x00")
255
>>> decode_long(b"\xff\x7f")
32767
>>> decode_long(b"\x00\xff")
-256
>>> decode_long(b"\x00\x80")
-32768
>>> decode_long(b"\x80")
-128
>>> decode_long(b"\x7f")
127
"""
return int.from_bytes(data, byteorder='little', signed=True)
# Shorthands

def _dump(obj, file, protocol=None, *, fix_imports=True):
    _Pickler(file, protocol, fix_imports=fix_imports).dump(obj)

def _dumps(obj, protocol=None, *, fix_imports=True):
    f = io.BytesIO()
    _Pickler(f, protocol, fix_imports=fix_imports).dump(obj)
    res = f.getvalue()
    assert isinstance(res, bytes_types)
    return res

def _load(file, *, fix_imports=True, encoding="ASCII", errors="strict"):
    return _Unpickler(file, fix_imports=fix_imports,
                      encoding=encoding, errors=errors).load()

def _loads(s, *, fix_imports=True, encoding="ASCII", errors="strict"):
    if isinstance(s, str):
        raise TypeError("Can't load pickle from unicode string")
    file = io.BytesIO(s)
    return _Unpickler(file, fix_imports=fix_imports,
                      encoding=encoding, errors=errors).load()

# Use the faster _pickle if possible
try:
    from _pickle import (
        PickleError,
        PicklingError,
        UnpicklingError,
        Pickler,
        Unpickler,
        dump,
        dumps,
        load,
        loads
    )
except ImportError:
    Pickler, Unpickler = _Pickler, _Unpickler
    dump, dumps, load, loads = _dump, _dumps, _load, _loads
# Doctest # Doctest
def _test(): def _test():
......
...@@ -11,6 +11,7 @@ dis(pickle, out=None, memo=None, indentlevel=4)
'''
import codecs
import io
import pickle
import re
import sys
...@@ -168,6 +169,7 @@ UP_TO_NEWLINE = -1 ...@@ -168,6 +169,7 @@ UP_TO_NEWLINE = -1
TAKEN_FROM_ARGUMENT1 = -2 # num bytes is 1-byte unsigned int TAKEN_FROM_ARGUMENT1 = -2 # num bytes is 1-byte unsigned int
TAKEN_FROM_ARGUMENT4 = -3 # num bytes is 4-byte signed little-endian int TAKEN_FROM_ARGUMENT4 = -3 # num bytes is 4-byte signed little-endian int
TAKEN_FROM_ARGUMENT4U = -4 # num bytes is 4-byte unsigned little-endian int TAKEN_FROM_ARGUMENT4U = -4 # num bytes is 4-byte unsigned little-endian int
TAKEN_FROM_ARGUMENT8U = -5 # num bytes is 8-byte unsigned little-endian int
class ArgumentDescriptor(object): class ArgumentDescriptor(object):
__slots__ = ( __slots__ = (
...@@ -175,7 +177,7 @@ class ArgumentDescriptor(object): ...@@ -175,7 +177,7 @@ class ArgumentDescriptor(object):
'name', 'name',
# length of argument, in bytes; an int; UP_TO_NEWLINE and # length of argument, in bytes; an int; UP_TO_NEWLINE and
# TAKEN_FROM_ARGUMENT{1,4} are negative values for variable-length # TAKEN_FROM_ARGUMENT{1,4,8} are negative values for variable-length
# cases # cases
'n', 'n',
...@@ -196,7 +198,8 @@ class ArgumentDescriptor(object): ...@@ -196,7 +198,8 @@ class ArgumentDescriptor(object):
n in (UP_TO_NEWLINE, n in (UP_TO_NEWLINE,
TAKEN_FROM_ARGUMENT1, TAKEN_FROM_ARGUMENT1,
TAKEN_FROM_ARGUMENT4, TAKEN_FROM_ARGUMENT4,
TAKEN_FROM_ARGUMENT4U)) TAKEN_FROM_ARGUMENT4U,
TAKEN_FROM_ARGUMENT8U))
self.n = n self.n = n
self.reader = reader self.reader = reader
...@@ -288,6 +291,27 @@ uint4 = ArgumentDescriptor( ...@@ -288,6 +291,27 @@ uint4 = ArgumentDescriptor(
doc="Four-byte unsigned integer, little-endian.") doc="Four-byte unsigned integer, little-endian.")
def read_uint8(f):
r"""
>>> import io
>>> read_uint8(io.BytesIO(b'\xff\x00\x00\x00\x00\x00\x00\x00'))
255
>>> read_uint8(io.BytesIO(b'\xff' * 8)) == 2**64-1
True
"""
data = f.read(8)
if len(data) == 8:
return _unpack("<Q", data)[0]
raise ValueError("not enough data in stream to read uint8")
uint8 = ArgumentDescriptor(
name='uint8',
n=8,
reader=read_uint8,
doc="Eight-byte unsigned integer, little-endian.")
def read_stringnl(f, decode=True, stripquotes=True): def read_stringnl(f, decode=True, stripquotes=True):
r""" r"""
>>> import io >>> import io
...@@ -381,6 +405,36 @@ stringnl_noescape_pair = ArgumentDescriptor( ...@@ -381,6 +405,36 @@ stringnl_noescape_pair = ArgumentDescriptor(
a single blank separating the two strings. a single blank separating the two strings.
""") """)
def read_string1(f):
r"""
>>> import io
>>> read_string1(io.BytesIO(b"\x00"))
''
>>> read_string1(io.BytesIO(b"\x03abcdef"))
'abc'
"""
n = read_uint1(f)
assert n >= 0
data = f.read(n)
if len(data) == n:
return data.decode("latin-1")
raise ValueError("expected %d bytes in a string1, but only %d remain" %
(n, len(data)))
string1 = ArgumentDescriptor(
name="string1",
n=TAKEN_FROM_ARGUMENT1,
reader=read_string1,
doc="""A counted string.
The first argument is a 1-byte unsigned int giving the number
of bytes in the string, and the second argument is that many
bytes.
""")
def read_string4(f): def read_string4(f):
r""" r"""
>>> import io >>> import io
...@@ -415,28 +469,28 @@ string4 = ArgumentDescriptor( ...@@ -415,28 +469,28 @@ string4 = ArgumentDescriptor(
""") """)
def read_string1(f): def read_bytes1(f):
r""" r"""
>>> import io >>> import io
>>> read_string1(io.BytesIO(b"\x00")) >>> read_bytes1(io.BytesIO(b"\x00"))
'' b''
>>> read_string1(io.BytesIO(b"\x03abcdef")) >>> read_bytes1(io.BytesIO(b"\x03abcdef"))
'abc' b'abc'
""" """
n = read_uint1(f) n = read_uint1(f)
assert n >= 0 assert n >= 0
data = f.read(n) data = f.read(n)
if len(data) == n: if len(data) == n:
return data.decode("latin-1") return data
raise ValueError("expected %d bytes in a string1, but only %d remain" % raise ValueError("expected %d bytes in a bytes1, but only %d remain" %
(n, len(data))) (n, len(data)))
string1 = ArgumentDescriptor( bytes1 = ArgumentDescriptor(
name="string1", name="bytes1",
n=TAKEN_FROM_ARGUMENT1, n=TAKEN_FROM_ARGUMENT1,
reader=read_string1, reader=read_bytes1,
doc="""A counted string. doc="""A counted bytes string.
The first argument is a 1-byte unsigned int giving the number The first argument is a 1-byte unsigned int giving the number
of bytes in the string, and the second argument is that many of bytes in the string, and the second argument is that many
...@@ -486,6 +540,7 @@ def read_bytes4(f): ...@@ -486,6 +540,7 @@ def read_bytes4(f):
""" """
n = read_uint4(f) n = read_uint4(f)
assert n >= 0
if n > sys.maxsize: if n > sys.maxsize:
raise ValueError("bytes4 byte count > sys.maxsize: %d" % n) raise ValueError("bytes4 byte count > sys.maxsize: %d" % n)
data = f.read(n) data = f.read(n)
...@@ -505,6 +560,39 @@ bytes4 = ArgumentDescriptor( ...@@ -505,6 +560,39 @@ bytes4 = ArgumentDescriptor(
""") """)
def read_bytes8(f):
r"""
>>> import io
>>> read_bytes8(io.BytesIO(b"\x00\x00\x00\x00\x00\x00\x00\x00abc"))
b''
>>> read_bytes8(io.BytesIO(b"\x03\x00\x00\x00\x00\x00\x00\x00abcdef"))
b'abc'
>>> read_bytes8(io.BytesIO(b"\x00\x00\x00\x00\x00\x00\x03\x00abcdef"))
Traceback (most recent call last):
...
ValueError: expected 844424930131968 bytes in a bytes8, but only 6 remain
"""
n = read_uint8(f)
assert n >= 0
if n > sys.maxsize:
raise ValueError("bytes8 byte count > sys.maxsize: %d" % n)
data = f.read(n)
if len(data) == n:
return data
raise ValueError("expected %d bytes in a bytes8, but only %d remain" %
(n, len(data)))
bytes8 = ArgumentDescriptor(
name="bytes8",
n=TAKEN_FROM_ARGUMENT8U,
reader=read_bytes8,
doc="""A counted bytes string.
    The first argument is an 8-byte little-endian unsigned int giving
the number of bytes, and the second argument is that many bytes.
""")
def read_unicodestringnl(f): def read_unicodestringnl(f):
r""" r"""
>>> import io >>> import io
...@@ -530,6 +618,46 @@ unicodestringnl = ArgumentDescriptor( ...@@ -530,6 +618,46 @@ unicodestringnl = ArgumentDescriptor(
escape sequences. escape sequences.
""") """)
def read_unicodestring1(f):
r"""
>>> import io
>>> s = 'abcd\uabcd'
>>> enc = s.encode('utf-8')
>>> enc
b'abcd\xea\xaf\x8d'
>>> n = bytes([len(enc)]) # little-endian 1-byte length
>>> t = read_unicodestring1(io.BytesIO(n + enc + b'junk'))
>>> s == t
True
>>> read_unicodestring1(io.BytesIO(n + enc[:-1]))
Traceback (most recent call last):
...
ValueError: expected 7 bytes in a unicodestring1, but only 6 remain
"""
n = read_uint1(f)
assert n >= 0
data = f.read(n)
if len(data) == n:
return str(data, 'utf-8', 'surrogatepass')
raise ValueError("expected %d bytes in a unicodestring1, but only %d "
"remain" % (n, len(data)))
unicodestring1 = ArgumentDescriptor(
name="unicodestring1",
n=TAKEN_FROM_ARGUMENT1,
reader=read_unicodestring1,
doc="""A counted Unicode string.
    The first argument is a 1-byte little-endian unsigned int
giving the number of bytes in the string, and the second
argument-- the UTF-8 encoding of the Unicode string --
contains that many bytes.
""")
def read_unicodestring4(f): def read_unicodestring4(f):
r""" r"""
>>> import io >>> import io
...@@ -549,6 +677,7 @@ def read_unicodestring4(f): ...@@ -549,6 +677,7 @@ def read_unicodestring4(f):
""" """
n = read_uint4(f) n = read_uint4(f)
assert n >= 0
if n > sys.maxsize: if n > sys.maxsize:
raise ValueError("unicodestring4 byte count > sys.maxsize: %d" % n) raise ValueError("unicodestring4 byte count > sys.maxsize: %d" % n)
data = f.read(n) data = f.read(n)
...@@ -570,6 +699,47 @@ unicodestring4 = ArgumentDescriptor( ...@@ -570,6 +699,47 @@ unicodestring4 = ArgumentDescriptor(
""") """)
def read_unicodestring8(f):
r"""
>>> import io
>>> s = 'abcd\uabcd'
>>> enc = s.encode('utf-8')
>>> enc
b'abcd\xea\xaf\x8d'
>>> n = bytes([len(enc)]) + bytes(7) # little-endian 8-byte length
>>> t = read_unicodestring8(io.BytesIO(n + enc + b'junk'))
>>> s == t
True
>>> read_unicodestring8(io.BytesIO(n + enc[:-1]))
Traceback (most recent call last):
...
ValueError: expected 7 bytes in a unicodestring8, but only 6 remain
"""
n = read_uint8(f)
assert n >= 0
if n > sys.maxsize:
raise ValueError("unicodestring8 byte count > sys.maxsize: %d" % n)
data = f.read(n)
if len(data) == n:
return str(data, 'utf-8', 'surrogatepass')
raise ValueError("expected %d bytes in a unicodestring8, but only %d "
"remain" % (n, len(data)))
unicodestring8 = ArgumentDescriptor(
name="unicodestring8",
n=TAKEN_FROM_ARGUMENT8U,
reader=read_unicodestring8,
doc="""A counted Unicode string.
    The first argument is an 8-byte little-endian unsigned int
giving the number of bytes in the string, and the second
argument-- the UTF-8 encoding of the Unicode string --
contains that many bytes.
""")
def read_decimalnl_short(f): def read_decimalnl_short(f):
r""" r"""
>>> import io >>> import io
...@@ -859,6 +1029,16 @@ pydict = StackObject( ...@@ -859,6 +1029,16 @@ pydict = StackObject(
obtype=dict, obtype=dict,
doc="A Python dict object.") doc="A Python dict object.")
pyset = StackObject(
name="set",
obtype=set,
doc="A Python set object.")
pyfrozenset = StackObject(
name="frozenset",
    obtype=frozenset,
doc="A Python frozenset object.")
anyobject = StackObject( anyobject = StackObject(
name='any', name='any',
obtype=object, obtype=object,
...@@ -1142,6 +1322,19 @@ opcodes = [ ...@@ -1142,6 +1322,19 @@ opcodes = [
literally as the string content. literally as the string content.
"""), """),
I(name='BINBYTES8',
code='\x8e',
arg=bytes8,
stack_before=[],
stack_after=[pybytes],
proto=4,
doc="""Push a Python bytes object.
      There are two arguments: the first is an 8-byte unsigned int giving
the number of bytes in the string, and the second is that many bytes,
which are taken literally as the string content.
"""),
# Ways to spell None. # Ways to spell None.
I(name='NONE', I(name='NONE',
...@@ -1190,6 +1383,19 @@ opcodes = [ ...@@ -1190,6 +1383,19 @@ opcodes = [
until the next newline character. until the next newline character.
"""), """),
I(name='SHORT_BINUNICODE',
code='\x8c',
arg=unicodestring1,
stack_before=[],
stack_after=[pyunicode],
proto=4,
doc="""Push a Python Unicode string object.
      There are two arguments: the first is a 1-byte little-endian unsigned int
giving the number of bytes in the string. The second is that many
bytes, and is the UTF-8 encoding of the Unicode string.
"""),
I(name='BINUNICODE', I(name='BINUNICODE',
code='X', code='X',
arg=unicodestring4, arg=unicodestring4,
...@@ -1203,6 +1409,19 @@ opcodes = [ ...@@ -1203,6 +1409,19 @@ opcodes = [
bytes, and is the UTF-8 encoding of the Unicode string. bytes, and is the UTF-8 encoding of the Unicode string.
"""), """),
I(name='BINUNICODE8',
code='\x8d',
arg=unicodestring8,
stack_before=[],
stack_after=[pyunicode],
proto=4,
doc="""Push a Python Unicode string object.
      There are two arguments: the first is an 8-byte little-endian unsigned int
giving the number of bytes in the string. The second is that many
bytes, and is the UTF-8 encoding of the Unicode string.
"""),
# Ways to spell floats. # Ways to spell floats.
I(name='FLOAT', I(name='FLOAT',
...@@ -1428,6 +1647,54 @@ opcodes = [ ...@@ -1428,6 +1647,54 @@ opcodes = [
1, 2, ..., n, and in that order. 1, 2, ..., n, and in that order.
"""), """),
# Ways to build sets
I(name='EMPTY_SET',
code='\x8f',
arg=None,
stack_before=[],
stack_after=[pyset],
proto=4,
doc="Push an empty set."),
I(name='ADDITEMS',
code='\x90',
arg=None,
stack_before=[pyset, markobject, stackslice],
stack_after=[pyset],
proto=4,
doc="""Add an arbitrary number of items to an existing set.
The slice of the stack following the topmost markobject is taken as
a sequence of items, added to the set immediately under the topmost
markobject. Everything at and after the topmost markobject is popped,
leaving the mutated set at the top of the stack.
Stack before: ... pyset markobject item_1 ... item_n
Stack after: ... pyset
      where pyset has been modified via pyset.add(item_i) for i in
1, 2, ..., n, and in that order.
"""),
# Way to build frozensets
I(name='FROZENSET',
code='\x91',
arg=None,
stack_before=[markobject, stackslice],
stack_after=[pyfrozenset],
proto=4,
doc="""Build a frozenset out of the topmost slice, after markobject.
All the stack entries following the topmost markobject are placed into
a single Python frozenset, which single frozenset object replaces all
of the stack from the topmost markobject onward. For example,
Stack before: ... markobject 1 2 3
Stack after: ... frozenset({1, 2, 3})
"""),
# Stack manipulation. # Stack manipulation.
I(name='POP', I(name='POP',
...@@ -1549,6 +1816,18 @@ opcodes = [ ...@@ -1549,6 +1816,18 @@ opcodes = [
unsigned little-endian integer following. unsigned little-endian integer following.
"""), """),
I(name='MEMOIZE',
code='\x94',
arg=None,
stack_before=[anyobject],
stack_after=[anyobject],
proto=4,
doc="""Store the stack top into the memo. The stack is not popped.
The index of the memo location to write is the number of
elements currently present in the memo.
"""),
# Access the extension registry (predefined objects). Akin to the GET # Access the extension registry (predefined objects). Akin to the GET
# family. # family.
...@@ -1614,6 +1893,15 @@ opcodes = [ ...@@ -1614,6 +1893,15 @@ opcodes = [
stack, so unpickling subclasses can override this form of lookup. stack, so unpickling subclasses can override this form of lookup.
"""), """),
I(name='STACK_GLOBAL',
code='\x93',
arg=None,
stack_before=[pyunicode, pyunicode],
stack_after=[anyobject],
      proto=4,
doc="""Push a global object (module.attr) on the stack.
"""),
# Ways to build objects of classes pickle doesn't know about directly # Ways to build objects of classes pickle doesn't know about directly
# (user-defined classes). I despair of documenting this accurately # (user-defined classes). I despair of documenting this accurately
# and comprehensibly -- you really have to read the pickle code to # and comprehensibly -- you really have to read the pickle code to
...@@ -1770,6 +2058,21 @@ opcodes = [ ...@@ -1770,6 +2058,21 @@ opcodes = [
onto the stack. onto the stack.
"""), """),
I(name='NEWOBJ_EX',
code='\x92',
arg=None,
stack_before=[anyobject, anyobject, anyobject],
stack_after=[anyobject],
proto=4,
doc="""Build an object instance.
The stack before should be thought of as containing a class
object followed by an argument tuple and by a keyword argument dict
      (the dict being the stack top). Call these cls, args, and kwargs. They are
      popped off the stack, and the value returned by
      cls.__new__(cls, *args, **kwargs) is pushed back onto the stack.
"""),
# Machine control. # Machine control.
I(name='PROTO', I(name='PROTO',
...@@ -1797,6 +2100,20 @@ opcodes = [ ...@@ -1797,6 +2100,20 @@ opcodes = [
empty then. empty then.
"""), """),
# Framing support.
I(name='FRAME',
code='\x95',
arg=uint8,
stack_before=[],
stack_after=[],
proto=4,
doc="""Indicate the beginning of a new frame.
The unpickler may use this opcode to safely prefetch data from its
underlying stream.
"""),
# Ways to deal with persistent IDs. # Ways to deal with persistent IDs.
I(name='PERSID', I(name='PERSID',
...@@ -1903,6 +2220,38 @@ del assure_pickle_consistency ...@@ -1903,6 +2220,38 @@ del assure_pickle_consistency
############################################################################## ##############################################################################
# A pickle opcode generator. # A pickle opcode generator.
def _genops(data, yield_end_pos=False):
if isinstance(data, bytes_types):
data = io.BytesIO(data)
if hasattr(data, "tell"):
getpos = data.tell
else:
getpos = lambda: None
while True:
pos = getpos()
code = data.read(1)
opcode = code2op.get(code.decode("latin-1"))
if opcode is None:
if code == b"":
raise ValueError("pickle exhausted before seeing STOP")
else:
raise ValueError("at position %s, opcode %r unknown" % (
"<unknown>" if pos is None else pos,
code))
if opcode.arg is None:
arg = None
else:
arg = opcode.arg.reader(data)
if yield_end_pos:
yield opcode, arg, pos, getpos()
else:
yield opcode, arg, pos
if code == b'.':
assert opcode.name == 'STOP'
break
def genops(pickle): def genops(pickle):
"""Generate all the opcodes in a pickle. """Generate all the opcodes in a pickle.
...@@ -1926,62 +2275,47 @@ def genops(pickle): ...@@ -1926,62 +2275,47 @@ def genops(pickle):
used. Else (the pickle doesn't have a tell(), and it's not obvious how used. Else (the pickle doesn't have a tell(), and it's not obvious how
to query its current position) pos is None. to query its current position) pos is None.
""" """
return _genops(pickle)
if isinstance(pickle, bytes_types):
import io
pickle = io.BytesIO(pickle)
if hasattr(pickle, "tell"):
getpos = pickle.tell
else:
getpos = lambda: None
while True:
pos = getpos()
code = pickle.read(1)
opcode = code2op.get(code.decode("latin-1"))
if opcode is None:
if code == b"":
raise ValueError("pickle exhausted before seeing STOP")
else:
raise ValueError("at position %s, opcode %r unknown" % (
pos is None and "<unknown>" or pos,
code))
if opcode.arg is None:
arg = None
else:
arg = opcode.arg.reader(pickle)
yield opcode, arg, pos
if code == b'.':
assert opcode.name == 'STOP'
break
############################################################################## ##############################################################################
# A pickle optimizer. # A pickle optimizer.
def optimize(p): def optimize(p):
'Optimize a pickle string by removing unused PUT opcodes' 'Optimize a pickle string by removing unused PUT opcodes'
gets = set() # set of args used by a GET opcode not_a_put = object()
puts = [] # (arg, startpos, stoppos) for the PUT opcodes gets = { not_a_put } # set of args used by a GET opcode
prevpos = None # set to pos if previous opcode was a PUT opcodes = [] # (startpos, stoppos, putid)
for opcode, arg, pos in genops(p): proto = 0
if prevpos is not None: for opcode, arg, pos, end_pos in _genops(p, yield_end_pos=True):
puts.append((prevarg, prevpos, pos))
prevpos = None
if 'PUT' in opcode.name: if 'PUT' in opcode.name:
prevarg, prevpos = arg, pos opcodes.append((pos, end_pos, arg))
elif 'GET' in opcode.name: elif 'FRAME' in opcode.name:
pass
else:
if 'GET' in opcode.name:
gets.add(arg) gets.add(arg)
elif opcode.name == 'PROTO':
# Copy the pickle string except for PUTS without a corresponding GET assert pos == 0, pos
s = [] proto = arg
i = 0 opcodes.append((pos, end_pos, not_a_put))
for arg, start, stop in puts: prevpos, prevarg = pos, None
j = stop if (arg in gets) else start
s.append(p[i:j]) # Copy the opcodes except for PUTS without a corresponding GET
i = stop out = io.BytesIO()
s.append(p[i:]) opcodes = iter(opcodes)
return b''.join(s) if proto >= 2:
# Write the PROTO header before any framing
start, stop, _ = next(opcodes)
out.write(p[start:stop])
buf = pickle._Framer(out.write)
if proto >= 4:
buf.start_framing()
for start, stop, putid in opcodes:
if putid in gets:
buf.write(p[start:stop])
if proto >= 4:
buf.end_framing()
return out.getvalue()
############################################################################## ##############################################################################
# A symbolic pickle disassembler. # A symbolic pickle disassembler.
...@@ -2081,17 +2415,20 @@ def dis(pickle, out=None, memo=None, indentlevel=4, annotate=0): ...@@ -2081,17 +2415,20 @@ def dis(pickle, out=None, memo=None, indentlevel=4, annotate=0):
errormsg = markmsg = "no MARK exists on stack" errormsg = markmsg = "no MARK exists on stack"
# Check for correct memo usage. # Check for correct memo usage.
if opcode.name in ("PUT", "BINPUT", "LONG_BINPUT"): if opcode.name in ("PUT", "BINPUT", "LONG_BINPUT", "MEMOIZE"):
if opcode.name == "MEMOIZE":
memo_idx = len(memo)
else:
assert arg is not None assert arg is not None
if arg in memo: memo_idx = arg
if memo_idx in memo:
errormsg = "memo key %r already defined" % arg errormsg = "memo key %r already defined" % arg
elif not stack: elif not stack:
errormsg = "stack is empty -- can't store into memo" errormsg = "stack is empty -- can't store into memo"
elif stack[-1] is markobject: elif stack[-1] is markobject:
errormsg = "can't store markobject in the memo" errormsg = "can't store markobject in the memo"
else: else:
memo[arg] = stack[-1] memo[memo_idx] = stack[-1]
elif opcode.name in ("GET", "BINGET", "LONG_BINGET"): elif opcode.name in ("GET", "BINGET", "LONG_BINGET"):
if arg in memo: if arg in memo:
assert len(after) == 1 assert len(after) == 1
......
import copyreg
import io import io
import unittest
import pickle import pickle
import pickletools import pickletools
import random
import sys import sys
import copyreg import unittest
import weakref import weakref
from http.cookies import SimpleCookie from http.cookies import SimpleCookie
...@@ -95,6 +96,9 @@ class E(C): ...@@ -95,6 +96,9 @@ class E(C):
def __getinitargs__(self): def __getinitargs__(self):
return () return ()
class H(object):
pass
import __main__ import __main__
__main__.C = C __main__.C = C
C.__module__ = "__main__" C.__module__ = "__main__"
...@@ -102,6 +106,8 @@ __main__.D = D ...@@ -102,6 +106,8 @@ __main__.D = D
D.__module__ = "__main__" D.__module__ = "__main__"
__main__.E = E __main__.E = E
E.__module__ = "__main__" E.__module__ = "__main__"
__main__.H = H
H.__module__ = "__main__"
class myint(int): class myint(int):
def __init__(self, x): def __init__(self, x):
...@@ -428,6 +434,7 @@ def create_data(): ...@@ -428,6 +434,7 @@ def create_data():
x.append(5) x.append(5)
return x return x
class AbstractPickleTests(unittest.TestCase): class AbstractPickleTests(unittest.TestCase):
# Subclass must define self.dumps, self.loads. # Subclass must define self.dumps, self.loads.
...@@ -436,23 +443,41 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -436,23 +443,41 @@ class AbstractPickleTests(unittest.TestCase):
def setUp(self): def setUp(self):
pass pass
def assert_is_copy(self, obj, objcopy, msg=None):
"""Utility method to verify if two objects are copies of each others.
"""
if msg is None:
msg = "{!r} is not a copy of {!r}".format(obj, objcopy)
self.assertEqual(obj, objcopy, msg=msg)
self.assertIs(type(obj), type(objcopy), msg=msg)
if hasattr(obj, '__dict__'):
self.assertDictEqual(obj.__dict__, objcopy.__dict__, msg=msg)
self.assertIsNot(obj.__dict__, objcopy.__dict__, msg=msg)
if hasattr(obj, '__slots__'):
self.assertListEqual(obj.__slots__, objcopy.__slots__, msg=msg)
for slot in obj.__slots__:
self.assertEqual(
hasattr(obj, slot), hasattr(objcopy, slot), msg=msg)
self.assertEqual(getattr(obj, slot, None),
getattr(objcopy, slot, None), msg=msg)
def test_misc(self): def test_misc(self):
# test various datatypes not tested by testdata # test various datatypes not tested by testdata
for proto in protocols: for proto in protocols:
x = myint(4) x = myint(4)
s = self.dumps(x, proto) s = self.dumps(x, proto)
y = self.loads(s) y = self.loads(s)
self.assertEqual(x, y) self.assert_is_copy(x, y)
x = (1, ()) x = (1, ())
s = self.dumps(x, proto) s = self.dumps(x, proto)
y = self.loads(s) y = self.loads(s)
self.assertEqual(x, y) self.assert_is_copy(x, y)
x = initarg(1, x) x = initarg(1, x)
s = self.dumps(x, proto) s = self.dumps(x, proto)
y = self.loads(s) y = self.loads(s)
self.assertEqual(x, y) self.assert_is_copy(x, y)
# XXX test __reduce__ protocol? # XXX test __reduce__ protocol?
...@@ -461,16 +486,16 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -461,16 +486,16 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols: for proto in protocols:
s = self.dumps(expected, proto) s = self.dumps(expected, proto)
got = self.loads(s) got = self.loads(s)
self.assertEqual(expected, got) self.assert_is_copy(expected, got)
def test_load_from_data0(self): def test_load_from_data0(self):
self.assertEqual(self._testdata, self.loads(DATA0)) self.assert_is_copy(self._testdata, self.loads(DATA0))
def test_load_from_data1(self): def test_load_from_data1(self):
self.assertEqual(self._testdata, self.loads(DATA1)) self.assert_is_copy(self._testdata, self.loads(DATA1))
def test_load_from_data2(self): def test_load_from_data2(self):
self.assertEqual(self._testdata, self.loads(DATA2)) self.assert_is_copy(self._testdata, self.loads(DATA2))
def test_load_classic_instance(self): def test_load_classic_instance(self):
# See issue5180. Test loading 2.x pickles that # See issue5180. Test loading 2.x pickles that
...@@ -492,7 +517,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -492,7 +517,7 @@ class AbstractPickleTests(unittest.TestCase):
b"X\n" b"X\n"
b"p0\n" b"p0\n"
b"(dp1\nb.").replace(b'X', xname) b"(dp1\nb.").replace(b'X', xname)
self.assertEqual(X(*args), self.loads(pickle0)) self.assert_is_copy(X(*args), self.loads(pickle0))
# Protocol 1 (binary mode pickle) # Protocol 1 (binary mode pickle)
""" """
...@@ -509,7 +534,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -509,7 +534,7 @@ class AbstractPickleTests(unittest.TestCase):
pickle1 = (b'(c__main__\n' pickle1 = (b'(c__main__\n'
b'X\n' b'X\n'
b'q\x00oq\x01}q\x02b.').replace(b'X', xname) b'q\x00oq\x01}q\x02b.').replace(b'X', xname)
self.assertEqual(X(*args), self.loads(pickle1)) self.assert_is_copy(X(*args), self.loads(pickle1))
# Protocol 2 (pickle2 = b'\x80\x02' + pickle1) # Protocol 2 (pickle2 = b'\x80\x02' + pickle1)
""" """
...@@ -527,7 +552,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -527,7 +552,7 @@ class AbstractPickleTests(unittest.TestCase):
pickle2 = (b'\x80\x02(c__main__\n' pickle2 = (b'\x80\x02(c__main__\n'
b'X\n' b'X\n'
b'q\x00oq\x01}q\x02b.').replace(b'X', xname) b'q\x00oq\x01}q\x02b.').replace(b'X', xname)
self.assertEqual(X(*args), self.loads(pickle2)) self.assert_is_copy(X(*args), self.loads(pickle2))
# There are gratuitous differences between pickles produced by # There are gratuitous differences between pickles produced by
# pickle and cPickle, largely because cPickle starts PUT indices at # pickle and cPickle, largely because cPickle starts PUT indices at
...@@ -552,6 +577,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -552,6 +577,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols: for proto in protocols:
s = self.dumps(l, proto) s = self.dumps(l, proto)
x = self.loads(s) x = self.loads(s)
self.assertIsInstance(x, list)
self.assertEqual(len(x), 1) self.assertEqual(len(x), 1)
self.assertTrue(x is x[0]) self.assertTrue(x is x[0])
...@@ -561,6 +587,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -561,6 +587,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols: for proto in protocols:
s = self.dumps(t, proto) s = self.dumps(t, proto)
x = self.loads(s) x = self.loads(s)
self.assertIsInstance(x, tuple)
self.assertEqual(len(x), 1) self.assertEqual(len(x), 1)
self.assertEqual(len(x[0]), 1) self.assertEqual(len(x[0]), 1)
self.assertTrue(x is x[0][0]) self.assertTrue(x is x[0][0])
...@@ -571,15 +598,39 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -571,15 +598,39 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols: for proto in protocols:
s = self.dumps(d, proto) s = self.dumps(d, proto)
x = self.loads(s) x = self.loads(s)
self.assertIsInstance(x, dict)
self.assertEqual(list(x.keys()), [1]) self.assertEqual(list(x.keys()), [1])
self.assertTrue(x[1] is x) self.assertTrue(x[1] is x)
def test_recursive_set(self):
h = H()
y = set({h})
h.attr = y
for proto in protocols:
s = self.dumps(y, proto)
x = self.loads(s)
self.assertIsInstance(x, set)
self.assertIs(list(x)[0].attr, x)
self.assertEqual(len(x), 1)
def test_recursive_frozenset(self):
h = H()
y = frozenset({h})
h.attr = y
for proto in protocols:
s = self.dumps(y, proto)
x = self.loads(s)
self.assertIsInstance(x, frozenset)
self.assertIs(list(x)[0].attr, x)
self.assertEqual(len(x), 1)
def test_recursive_inst(self): def test_recursive_inst(self):
i = C() i = C()
i.attr = i i.attr = i
for proto in protocols: for proto in protocols:
s = self.dumps(i, proto) s = self.dumps(i, proto)
x = self.loads(s) x = self.loads(s)
self.assertIsInstance(x, C)
self.assertEqual(dir(x), dir(i)) self.assertEqual(dir(x), dir(i))
self.assertIs(x.attr, x) self.assertIs(x.attr, x)
...@@ -592,6 +643,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -592,6 +643,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols: for proto in protocols:
s = self.dumps(l, proto) s = self.dumps(l, proto)
x = self.loads(s) x = self.loads(s)
self.assertIsInstance(x, list)
self.assertEqual(len(x), 1) self.assertEqual(len(x), 1)
self.assertEqual(dir(x[0]), dir(i)) self.assertEqual(dir(x[0]), dir(i))
self.assertEqual(list(x[0].attr.keys()), [1]) self.assertEqual(list(x[0].attr.keys()), [1])
...@@ -599,7 +651,8 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -599,7 +651,8 @@ class AbstractPickleTests(unittest.TestCase):
def test_get(self): def test_get(self):
self.assertRaises(KeyError, self.loads, b'g0\np0') self.assertRaises(KeyError, self.loads, b'g0\np0')
self.assertEqual(self.loads(b'((Kdtp0\nh\x00l.))'), [(100,), (100,)]) self.assert_is_copy([(100,), (100,)],
self.loads(b'((Kdtp0\nh\x00l.))'))
def test_unicode(self): def test_unicode(self):
endcases = ['', '<\\u>', '<\\\u1234>', '<\n>', endcases = ['', '<\\u>', '<\\\u1234>', '<\n>',
...@@ -610,26 +663,26 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -610,26 +663,26 @@ class AbstractPickleTests(unittest.TestCase):
for u in endcases: for u in endcases:
p = self.dumps(u, proto) p = self.dumps(u, proto)
u2 = self.loads(p) u2 = self.loads(p)
self.assertEqual(u2, u) self.assert_is_copy(u, u2)
def test_unicode_high_plane(self): def test_unicode_high_plane(self):
t = '\U00012345' t = '\U00012345'
for proto in protocols: for proto in protocols:
p = self.dumps(t, proto) p = self.dumps(t, proto)
t2 = self.loads(p) t2 = self.loads(p)
self.assertEqual(t2, t) self.assert_is_copy(t, t2)
def test_bytes(self): def test_bytes(self):
for proto in protocols: for proto in protocols:
for s in b'', b'xyz', b'xyz'*100: for s in b'', b'xyz', b'xyz'*100:
p = self.dumps(s, proto) p = self.dumps(s, proto)
self.assertEqual(self.loads(p), s) self.assert_is_copy(s, self.loads(p))
for s in [bytes([i]) for i in range(256)]: for s in [bytes([i]) for i in range(256)]:
p = self.dumps(s, proto) p = self.dumps(s, proto)
self.assertEqual(self.loads(p), s) self.assert_is_copy(s, self.loads(p))
for s in [bytes([i, i]) for i in range(256)]: for s in [bytes([i, i]) for i in range(256)]:
p = self.dumps(s, proto) p = self.dumps(s, proto)
self.assertEqual(self.loads(p), s) self.assert_is_copy(s, self.loads(p))
def test_ints(self): def test_ints(self):
import sys import sys
...@@ -639,14 +692,14 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -639,14 +692,14 @@ class AbstractPickleTests(unittest.TestCase):
for expected in (-n, n): for expected in (-n, n):
s = self.dumps(expected, proto) s = self.dumps(expected, proto)
n2 = self.loads(s) n2 = self.loads(s)
self.assertEqual(expected, n2) self.assert_is_copy(expected, n2)
n = n >> 1 n = n >> 1
def test_maxint64(self): def test_maxint64(self):
maxint64 = (1 << 63) - 1 maxint64 = (1 << 63) - 1
data = b'I' + str(maxint64).encode("ascii") + b'\n.' data = b'I' + str(maxint64).encode("ascii") + b'\n.'
got = self.loads(data) got = self.loads(data)
self.assertEqual(got, maxint64) self.assert_is_copy(maxint64, got)
# Try too with a bogus literal. # Try too with a bogus literal.
data = b'I' + str(maxint64).encode("ascii") + b'JUNK\n.' data = b'I' + str(maxint64).encode("ascii") + b'JUNK\n.'
...@@ -661,7 +714,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -661,7 +714,7 @@ class AbstractPickleTests(unittest.TestCase):
for n in npos, -npos: for n in npos, -npos:
pickle = self.dumps(n, proto) pickle = self.dumps(n, proto)
got = self.loads(pickle) got = self.loads(pickle)
self.assertEqual(n, got) self.assert_is_copy(n, got)
# Try a monster. This is quadratic-time in protos 0 & 1, so don't # Try a monster. This is quadratic-time in protos 0 & 1, so don't
# bother with those. # bother with those.
nbase = int("deadbeeffeedface", 16) nbase = int("deadbeeffeedface", 16)
...@@ -669,7 +722,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -669,7 +722,7 @@ class AbstractPickleTests(unittest.TestCase):
for n in nbase, -nbase: for n in nbase, -nbase:
p = self.dumps(n, 2) p = self.dumps(n, 2)
got = self.loads(p) got = self.loads(p)
self.assertEqual(n, got) self.assert_is_copy(n, got)
def test_float(self): def test_float(self):
test_values = [0.0, 4.94e-324, 1e-310, 7e-308, 6.626e-34, 0.1, 0.5, test_values = [0.0, 4.94e-324, 1e-310, 7e-308, 6.626e-34, 0.1, 0.5,
...@@ -679,7 +732,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -679,7 +732,7 @@ class AbstractPickleTests(unittest.TestCase):
for value in test_values: for value in test_values:
pickle = self.dumps(value, proto) pickle = self.dumps(value, proto)
got = self.loads(pickle) got = self.loads(pickle)
self.assertEqual(value, got) self.assert_is_copy(value, got)
@run_with_locale('LC_ALL', 'de_DE', 'fr_FR') @run_with_locale('LC_ALL', 'de_DE', 'fr_FR')
def test_float_format(self): def test_float_format(self):
...@@ -711,6 +764,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -711,6 +764,7 @@ class AbstractPickleTests(unittest.TestCase):
s = self.dumps(a, proto) s = self.dumps(a, proto)
b = self.loads(s) b = self.loads(s)
self.assertEqual(a, b) self.assertEqual(a, b)
self.assertIs(type(a), type(b))
def test_structseq(self): def test_structseq(self):
import time import time
...@@ -720,48 +774,48 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -720,48 +774,48 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols: for proto in protocols:
s = self.dumps(t, proto) s = self.dumps(t, proto)
u = self.loads(s) u = self.loads(s)
self.assertEqual(t, u) self.assert_is_copy(t, u)
if hasattr(os, "stat"): if hasattr(os, "stat"):
t = os.stat(os.curdir) t = os.stat(os.curdir)
s = self.dumps(t, proto) s = self.dumps(t, proto)
u = self.loads(s) u = self.loads(s)
self.assertEqual(t, u) self.assert_is_copy(t, u)
if hasattr(os, "statvfs"): if hasattr(os, "statvfs"):
t = os.statvfs(os.curdir) t = os.statvfs(os.curdir)
s = self.dumps(t, proto) s = self.dumps(t, proto)
u = self.loads(s) u = self.loads(s)
self.assertEqual(t, u) self.assert_is_copy(t, u)
def test_ellipsis(self): def test_ellipsis(self):
for proto in protocols: for proto in protocols:
s = self.dumps(..., proto) s = self.dumps(..., proto)
u = self.loads(s) u = self.loads(s)
self.assertEqual(..., u) self.assertIs(..., u)
def test_notimplemented(self): def test_notimplemented(self):
for proto in protocols: for proto in protocols:
s = self.dumps(NotImplemented, proto) s = self.dumps(NotImplemented, proto)
u = self.loads(s) u = self.loads(s)
self.assertEqual(NotImplemented, u) self.assertIs(NotImplemented, u)
# Tests for protocol 2 # Tests for protocol 2
def test_proto(self): def test_proto(self):
build_none = pickle.NONE + pickle.STOP
for proto in protocols: for proto in protocols:
expected = build_none pickled = self.dumps(None, proto)
if proto >= 2: if proto >= 2:
expected = pickle.PROTO + bytes([proto]) + expected proto_header = pickle.PROTO + bytes([proto])
p = self.dumps(None, proto) self.assertTrue(pickled.startswith(proto_header))
self.assertEqual(p, expected) else:
self.assertEqual(count_opcode(pickle.PROTO, pickled), 0)
oob = protocols[-1] + 1 # a future protocol oob = protocols[-1] + 1 # a future protocol
build_none = pickle.NONE + pickle.STOP
badpickle = pickle.PROTO + bytes([oob]) + build_none badpickle = pickle.PROTO + bytes([oob]) + build_none
try: try:
self.loads(badpickle) self.loads(badpickle)
except ValueError as detail: except ValueError as err:
self.assertTrue(str(detail).startswith( self.assertIn("unsupported pickle protocol", str(err))
"unsupported pickle protocol"))
else: else:
self.fail("expected bad protocol number to raise ValueError") self.fail("expected bad protocol number to raise ValueError")
...@@ -770,7 +824,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -770,7 +824,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols: for proto in protocols:
s = self.dumps(x, proto) s = self.dumps(x, proto)
y = self.loads(s) y = self.loads(s)
self.assertEqual(x, y) self.assert_is_copy(x, y)
self.assertEqual(opcode_in_pickle(pickle.LONG1, s), proto >= 2) self.assertEqual(opcode_in_pickle(pickle.LONG1, s), proto >= 2)
def test_long4(self): def test_long4(self):
...@@ -778,7 +832,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -778,7 +832,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols: for proto in protocols:
s = self.dumps(x, proto) s = self.dumps(x, proto)
y = self.loads(s) y = self.loads(s)
self.assertEqual(x, y) self.assert_is_copy(x, y)
self.assertEqual(opcode_in_pickle(pickle.LONG4, s), proto >= 2) self.assertEqual(opcode_in_pickle(pickle.LONG4, s), proto >= 2)
def test_short_tuples(self): def test_short_tuples(self):
...@@ -816,9 +870,9 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -816,9 +870,9 @@ class AbstractPickleTests(unittest.TestCase):
for x in a, b, c, d, e: for x in a, b, c, d, e:
s = self.dumps(x, proto) s = self.dumps(x, proto)
y = self.loads(s) y = self.loads(s)
self.assertEqual(x, y, (proto, x, s, y)) self.assert_is_copy(x, y)
expected = expected_opcode[proto, len(x)] expected = expected_opcode[min(proto, 3), len(x)]
self.assertEqual(opcode_in_pickle(expected, s), True) self.assertTrue(opcode_in_pickle(expected, s))
def test_singletons(self): def test_singletons(self):
# Map (proto, singleton) to expected opcode. # Map (proto, singleton) to expected opcode.
...@@ -842,8 +896,8 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -842,8 +896,8 @@ class AbstractPickleTests(unittest.TestCase):
s = self.dumps(x, proto) s = self.dumps(x, proto)
y = self.loads(s) y = self.loads(s)
self.assertTrue(x is y, (proto, x, s, y)) self.assertTrue(x is y, (proto, x, s, y))
expected = expected_opcode[proto, x] expected = expected_opcode[min(proto, 3), x]
self.assertEqual(opcode_in_pickle(expected, s), True) self.assertTrue(opcode_in_pickle(expected, s))
def test_newobj_tuple(self): def test_newobj_tuple(self):
x = MyTuple([1, 2, 3]) x = MyTuple([1, 2, 3])
...@@ -852,8 +906,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -852,8 +906,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols: for proto in protocols:
s = self.dumps(x, proto) s = self.dumps(x, proto)
y = self.loads(s) y = self.loads(s)
self.assertEqual(tuple(x), tuple(y)) self.assert_is_copy(x, y)
self.assertEqual(x.__dict__, y.__dict__)
def test_newobj_list(self): def test_newobj_list(self):
x = MyList([1, 2, 3]) x = MyList([1, 2, 3])
...@@ -862,8 +915,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -862,8 +915,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols: for proto in protocols:
s = self.dumps(x, proto) s = self.dumps(x, proto)
y = self.loads(s) y = self.loads(s)
self.assertEqual(list(x), list(y)) self.assert_is_copy(x, y)
self.assertEqual(x.__dict__, y.__dict__)
def test_newobj_generic(self): def test_newobj_generic(self):
for proto in protocols: for proto in protocols:
...@@ -874,6 +926,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -874,6 +926,7 @@ class AbstractPickleTests(unittest.TestCase):
s = self.dumps(x, proto) s = self.dumps(x, proto)
y = self.loads(s) y = self.loads(s)
detail = (proto, C, B, x, y, type(y)) detail = (proto, C, B, x, y, type(y))
self.assert_is_copy(x, y) # XXX revisit
self.assertEqual(B(x), B(y), detail) self.assertEqual(B(x), B(y), detail)
self.assertEqual(x.__dict__, y.__dict__, detail) self.assertEqual(x.__dict__, y.__dict__, detail)
...@@ -912,11 +965,10 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -912,11 +965,10 @@ class AbstractPickleTests(unittest.TestCase):
s1 = self.dumps(x, 1) s1 = self.dumps(x, 1)
self.assertIn(__name__.encode("utf-8"), s1) self.assertIn(__name__.encode("utf-8"), s1)
self.assertIn(b"MyList", s1) self.assertIn(b"MyList", s1)
self.assertEqual(opcode_in_pickle(opcode, s1), False) self.assertFalse(opcode_in_pickle(opcode, s1))
y = self.loads(s1) y = self.loads(s1)
self.assertEqual(list(x), list(y)) self.assert_is_copy(x, y)
self.assertEqual(x.__dict__, y.__dict__)
# Dump using protocol 2 for test. # Dump using protocol 2 for test.
s2 = self.dumps(x, 2) s2 = self.dumps(x, 2)
...@@ -925,9 +977,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -925,9 +977,7 @@ class AbstractPickleTests(unittest.TestCase):
self.assertEqual(opcode_in_pickle(opcode, s2), True, repr(s2)) self.assertEqual(opcode_in_pickle(opcode, s2), True, repr(s2))
y = self.loads(s2) y = self.loads(s2)
self.assertEqual(list(x), list(y)) self.assert_is_copy(x, y)
self.assertEqual(x.__dict__, y.__dict__)
finally: finally:
e.restore() e.restore()
...@@ -951,7 +1001,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -951,7 +1001,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols: for proto in protocols:
s = self.dumps(x, proto) s = self.dumps(x, proto)
y = self.loads(s) y = self.loads(s)
self.assertEqual(x, y) self.assert_is_copy(x, y)
num_appends = count_opcode(pickle.APPENDS, s) num_appends = count_opcode(pickle.APPENDS, s)
self.assertEqual(num_appends, proto > 0) self.assertEqual(num_appends, proto > 0)
...@@ -960,7 +1010,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -960,7 +1010,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols: for proto in protocols:
s = self.dumps(x, proto) s = self.dumps(x, proto)
y = self.loads(s) y = self.loads(s)
self.assertEqual(x, y) self.assert_is_copy(x, y)
num_appends = count_opcode(pickle.APPENDS, s) num_appends = count_opcode(pickle.APPENDS, s)
if proto == 0: if proto == 0:
self.assertEqual(num_appends, 0) self.assertEqual(num_appends, 0)
...@@ -974,7 +1024,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -974,7 +1024,7 @@ class AbstractPickleTests(unittest.TestCase):
s = self.dumps(x, proto) s = self.dumps(x, proto)
self.assertIsInstance(s, bytes_types) self.assertIsInstance(s, bytes_types)
y = self.loads(s) y = self.loads(s)
self.assertEqual(x, y) self.assert_is_copy(x, y)
num_setitems = count_opcode(pickle.SETITEMS, s) num_setitems = count_opcode(pickle.SETITEMS, s)
self.assertEqual(num_setitems, proto > 0) self.assertEqual(num_setitems, proto > 0)
...@@ -983,22 +1033,49 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -983,22 +1033,49 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols: for proto in protocols:
s = self.dumps(x, proto) s = self.dumps(x, proto)
y = self.loads(s) y = self.loads(s)
self.assertEqual(x, y) self.assert_is_copy(x, y)
num_setitems = count_opcode(pickle.SETITEMS, s) num_setitems = count_opcode(pickle.SETITEMS, s)
if proto == 0: if proto == 0:
self.assertEqual(num_setitems, 0) self.assertEqual(num_setitems, 0)
else: else:
self.assertTrue(num_setitems >= 2) self.assertTrue(num_setitems >= 2)
def test_set_chunking(self):
n = 10 # too small to chunk
x = set(range(n))
for proto in protocols:
s = self.dumps(x, proto)
y = self.loads(s)
self.assert_is_copy(x, y)
num_additems = count_opcode(pickle.ADDITEMS, s)
if proto < 4:
self.assertEqual(num_additems, 0)
else:
self.assertEqual(num_additems, 1)
n = 2500 # expect at least two chunks when proto >= 4
x = set(range(n))
for proto in protocols:
s = self.dumps(x, proto)
y = self.loads(s)
self.assert_is_copy(x, y)
num_additems = count_opcode(pickle.ADDITEMS, s)
if proto < 4:
self.assertEqual(num_additems, 0)
else:
self.assertGreaterEqual(num_additems, 2)
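The chunking opcodes counted above can also be listed directly; a sketch assuming the opcode names this patch adds to pickletools:

import pickle
import pickletools

ops = [op.name for op, arg, pos in pickletools.genops(pickle.dumps({1, 2, 3}, 4))]
# Protocol 4 builds a set with EMPTY_SET followed by one or more ADDITEMS chunks.
assert "EMPTY_SET" in ops and "ADDITEMS" in ops

ops = [op.name for op, arg, pos in pickletools.genops(pickle.dumps(frozenset({1, 2}), 4))]
assert "FROZENSET" in ops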
def test_simple_newobj(self): def test_simple_newobj(self):
x = object.__new__(SimpleNewObj) # avoid __init__ x = object.__new__(SimpleNewObj) # avoid __init__
x.abc = 666 x.abc = 666
for proto in protocols: for proto in protocols:
s = self.dumps(x, proto) s = self.dumps(x, proto)
self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), proto >= 2) self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s),
2 <= proto < 4)
self.assertEqual(opcode_in_pickle(pickle.NEWOBJ_EX, s),
proto >= 4)
y = self.loads(s) # will raise TypeError if __init__ called y = self.loads(s) # will raise TypeError if __init__ called
self.assertEqual(y.abc, 666) self.assert_is_copy(x, y)
self.assertEqual(x.__dict__, y.__dict__)
def test_newobj_list_slots(self): def test_newobj_list_slots(self):
x = SlotList([1, 2, 3]) x = SlotList([1, 2, 3])
...@@ -1006,10 +1083,7 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -1006,10 +1083,7 @@ class AbstractPickleTests(unittest.TestCase):
x.bar = "hello" x.bar = "hello"
s = self.dumps(x, 2) s = self.dumps(x, 2)
y = self.loads(s) y = self.loads(s)
self.assertEqual(list(x), list(y)) self.assert_is_copy(x, y)
self.assertEqual(x.__dict__, y.__dict__)
self.assertEqual(x.foo, y.foo)
self.assertEqual(x.bar, y.bar)
def test_reduce_overrides_default_reduce_ex(self): def test_reduce_overrides_default_reduce_ex(self):
for proto in protocols: for proto in protocols:
...@@ -1058,11 +1132,10 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -1058,11 +1132,10 @@ class AbstractPickleTests(unittest.TestCase):
@no_tracing @no_tracing
def test_bad_getattr(self): def test_bad_getattr(self):
# Issue #3514: crash when there is an infinite loop in __getattr__
x = BadGetattr() x = BadGetattr()
for proto in 0, 1: for proto in protocols:
self.assertRaises(RuntimeError, self.dumps, x, proto) self.assertRaises(RuntimeError, self.dumps, x, proto)
# protocol 2 don't raise a RuntimeError.
d = self.dumps(x, 2)
def test_reduce_bad_iterator(self): def test_reduce_bad_iterator(self):
# Issue4176: crash when 4th and 5th items of __reduce__() # Issue4176: crash when 4th and 5th items of __reduce__()
...@@ -1095,11 +1168,10 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -1095,11 +1168,10 @@ class AbstractPickleTests(unittest.TestCase):
obj = [dict(large_dict), dict(large_dict), dict(large_dict)] obj = [dict(large_dict), dict(large_dict), dict(large_dict)]
for proto in protocols: for proto in protocols:
with self.subTest(proto=proto):
dumped = self.dumps(obj, proto) dumped = self.dumps(obj, proto)
loaded = self.loads(dumped) loaded = self.loads(dumped)
self.assertEqual(loaded, obj, self.assert_is_copy(obj, loaded)
"Failed protocol %d: %r != %r"
% (proto, obj, loaded))
def test_attribute_name_interning(self): def test_attribute_name_interning(self):
# Test that attribute names of pickled objects are interned when # Test that attribute names of pickled objects are interned when
...@@ -1155,11 +1227,14 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -1155,11 +1227,14 @@ class AbstractPickleTests(unittest.TestCase):
def test_int_pickling_efficiency(self): def test_int_pickling_efficiency(self):
        # Test compactness of int representation (see issue #12744) # Test compactness of int representation (see issue #12744)
for proto in protocols: for proto in protocols:
sizes = [len(self.dumps(2**n, proto)) for n in range(70)] with self.subTest(proto=proto):
pickles = [self.dumps(2**n, proto) for n in range(70)]
sizes = list(map(len, pickles))
# the size function is monotonic # the size function is monotonic
self.assertEqual(sorted(sizes), sizes) self.assertEqual(sorted(sizes), sizes)
if proto >= 2: if proto >= 2:
self.assertLessEqual(sizes[-1], 14) for p in pickles:
self.assertFalse(opcode_in_pickle(pickle.LONG, p))
def check_negative_32b_binXXX(self, dumped): def check_negative_32b_binXXX(self, dumped):
if sys.maxsize > 2**32: if sys.maxsize > 2**32:
...@@ -1242,6 +1317,137 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -1242,6 +1317,137 @@ class AbstractPickleTests(unittest.TestCase):
else: else:
self._check_pickling_with_opcode(obj, pickle.SETITEMS, proto) self._check_pickling_with_opcode(obj, pickle.SETITEMS, proto)
# Exercise framing (proto >= 4) for significant workloads
FRAME_SIZE_TARGET = 64 * 1024
def test_framing_many_objects(self):
obj = list(range(10**5))
for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(proto=proto):
pickled = self.dumps(obj, proto)
unpickled = self.loads(pickled)
self.assertEqual(obj, unpickled)
# Test the framing heuristic is sane,
# assuming a given frame size target.
bytes_per_frame = (len(pickled) /
pickled.count(b'\x00\x00\x00\x00\x00'))
self.assertGreater(bytes_per_frame,
self.FRAME_SIZE_TARGET / 2)
self.assertLessEqual(bytes_per_frame,
self.FRAME_SIZE_TARGET * 1)
def test_framing_large_objects(self):
N = 1024 * 1024
obj = [b'x' * N, b'y' * N, b'z' * N]
for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(proto=proto):
pickled = self.dumps(obj, proto)
unpickled = self.loads(pickled)
self.assertEqual(obj, unpickled)
# At least one frame was emitted per large bytes object.
n_frames = pickled.count(b'\x00\x00\x00\x00\x00')
self.assertGreaterEqual(n_frames, len(obj))
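The frame counting above keys off the high bytes of the length field; a sketch of the actual frame header layout (FRAME opcode followed by an 8-byte little-endian length):

import pickle
import struct

p = pickle.dumps([b"x" * (1024 * 1024)], 4)
assert p[:2] == pickle.PROTO + bytes([4])
assert p[2:3] == pickle.FRAME
first_frame_len = struct.unpack("<Q", p[3:11])[0]
# The first frame covers at most everything after the 2-byte PROTO header
# and its own 9-byte FRAME header.
assert first_frame_len <= len(p) - 11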
def test_nested_names(self):
global Nested
class Nested:
class A:
class B:
class C:
pass
for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
for obj in [Nested.A, Nested.A.B, Nested.A.B.C]:
with self.subTest(proto=proto, obj=obj):
unpickled = self.loads(self.dumps(obj, proto))
self.assertIs(obj, unpickled)
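This works because protocol 4 records the qualified name of a global (via the STACK_GLOBAL opcode); a minimal sketch:

import pickle

class Outer:
    # Illustrative nested class; it must live at module level to be picklable.
    class Inner:
        pass

# The full dotted path "Outer.Inner" is recorded, so the nested class is
# found again on unpickling; with this patch, protocols below 4 refuse it.
assert pickle.loads(pickle.dumps(Outer.Inner, 4)) is Outer.Inner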
def test_py_methods(self):
global PyMethodsTest
class PyMethodsTest:
@staticmethod
def cheese():
return "cheese"
@classmethod
def wine(cls):
assert cls is PyMethodsTest
return "wine"
def biscuits(self):
assert isinstance(self, PyMethodsTest)
return "biscuits"
class Nested:
"Nested class"
@staticmethod
def ketchup():
return "ketchup"
@classmethod
def maple(cls):
assert cls is PyMethodsTest.Nested
return "maple"
def pie(self):
assert isinstance(self, PyMethodsTest.Nested)
return "pie"
py_methods = (
PyMethodsTest.cheese,
PyMethodsTest.wine,
PyMethodsTest().biscuits,
PyMethodsTest.Nested.ketchup,
PyMethodsTest.Nested.maple,
PyMethodsTest.Nested().pie
)
py_unbound_methods = (
(PyMethodsTest.biscuits, PyMethodsTest),
(PyMethodsTest.Nested.pie, PyMethodsTest.Nested)
)
for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
for method in py_methods:
with self.subTest(proto=proto, method=method):
unpickled = self.loads(self.dumps(method, proto))
self.assertEqual(method(), unpickled())
for method, cls in py_unbound_methods:
obj = cls()
with self.subTest(proto=proto, method=method):
unpickled = self.loads(self.dumps(method, proto))
self.assertEqual(method(obj), unpickled(obj))
def test_c_methods(self):
global Subclass
class Subclass(tuple):
class Nested(str):
pass
c_methods = (
# bound built-in method
("abcd".index, ("c",)),
# unbound built-in method
(str.index, ("abcd", "c")),
# bound "slot" method
([1, 2, 3].__len__, ()),
# unbound "slot" method
(list.__len__, ([1, 2, 3],)),
# bound "coexist" method
({1, 2}.__contains__, (2,)),
# unbound "coexist" method
(set.__contains__, ({1, 2}, 2)),
# built-in class method
(dict.fromkeys, (("a", 1), ("b", 2))),
# built-in static method
(bytearray.maketrans, (b"abc", b"xyz")),
# subclass methods
(Subclass([1,2,2]).count, (2,)),
(Subclass.count, (Subclass([1,2,2]), 2)),
(Subclass.Nested("sweet").count, ("e",)),
(Subclass.Nested.count, (Subclass.Nested("sweet"), "e")),
)
for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
for method, args in c_methods:
with self.subTest(proto=proto, method=method):
unpickled = self.loads(self.dumps(method, proto))
self.assertEqual(method(*args), unpickled(*args))
class BigmemPickleTests(unittest.TestCase): class BigmemPickleTests(unittest.TestCase):
...@@ -1252,6 +1458,7 @@ class BigmemPickleTests(unittest.TestCase): ...@@ -1252,6 +1458,7 @@ class BigmemPickleTests(unittest.TestCase):
data = 1 << (8 * size) data = 1 << (8 * size)
try: try:
for proto in protocols: for proto in protocols:
with self.subTest(proto=proto):
if proto < 2: if proto < 2:
continue continue
with self.assertRaises((ValueError, OverflowError)): with self.assertRaises((ValueError, OverflowError)):
...@@ -1268,12 +1475,13 @@ class BigmemPickleTests(unittest.TestCase): ...@@ -1268,12 +1475,13 @@ class BigmemPickleTests(unittest.TestCase):
data = b"abcd" * (size // 4) data = b"abcd" * (size // 4)
try: try:
for proto in protocols: for proto in protocols:
with self.subTest(proto=proto):
if proto < 3: if proto < 3:
continue continue
try: try:
pickled = self.dumps(data, protocol=proto) pickled = self.dumps(data, protocol=proto)
self.assertTrue(b"abcd" in pickled[:15]) self.assertTrue(b"abcd" in pickled[:19])
self.assertTrue(b"abcd" in pickled[-15:]) self.assertTrue(b"abcd" in pickled[-18:])
finally: finally:
pickled = None pickled = None
finally: finally:
...@@ -1284,6 +1492,7 @@ class BigmemPickleTests(unittest.TestCase): ...@@ -1284,6 +1492,7 @@ class BigmemPickleTests(unittest.TestCase):
data = b"a" * size data = b"a" * size
try: try:
for proto in protocols: for proto in protocols:
with self.subTest(proto=proto):
if proto < 3: if proto < 3:
continue continue
with self.assertRaises((ValueError, OverflowError)): with self.assertRaises((ValueError, OverflowError)):
...@@ -1299,27 +1508,38 @@ class BigmemPickleTests(unittest.TestCase): ...@@ -1299,27 +1508,38 @@ class BigmemPickleTests(unittest.TestCase):
data = "abcd" * (size // 4) data = "abcd" * (size // 4)
try: try:
for proto in protocols: for proto in protocols:
with self.subTest(proto=proto):
try: try:
pickled = self.dumps(data, protocol=proto) pickled = self.dumps(data, protocol=proto)
self.assertTrue(b"abcd" in pickled[:15]) self.assertTrue(b"abcd" in pickled[:19])
self.assertTrue(b"abcd" in pickled[-15:]) self.assertTrue(b"abcd" in pickled[-18:])
finally: finally:
pickled = None pickled = None
finally: finally:
data = None data = None
# BINUNICODE (protocols 1, 2 and 3) cannot carry more than # BINUNICODE (protocols 1, 2 and 3) cannot carry more than 2**32 - 1 bytes
# 2**32 - 1 bytes of utf-8 encoded unicode. # of utf-8 encoded unicode. BINUNICODE8 (protocol 4) supports these huge
# unicode strings however.
@bigmemtest(size=_4G, memuse=1 + ascii_char_size, dry_run=False) @bigmemtest(size=_4G, memuse=2 + ascii_char_size, dry_run=False)
def test_huge_str_64b(self, size): def test_huge_str_64b(self, size):
data = "a" * size data = "abcd" * (size // 4)
try: try:
for proto in protocols: for proto in protocols:
with self.subTest(proto=proto):
if proto == 0: if proto == 0:
continue continue
if proto < 4:
with self.assertRaises((ValueError, OverflowError)): with self.assertRaises((ValueError, OverflowError)):
self.dumps(data, protocol=proto) self.dumps(data, protocol=proto)
else:
try:
pickled = self.dumps(data, protocol=proto)
self.assertTrue(b"abcd" in pickled[:19])
self.assertTrue(b"abcd" in pickled[-18:])
finally:
pickled = None
finally: finally:
data = None data = None
...@@ -1363,8 +1583,8 @@ class REX_five(object): ...@@ -1363,8 +1583,8 @@ class REX_five(object):
return object.__reduce__(self) return object.__reduce__(self)
class REX_six(object): class REX_six(object):
"""This class is used to check the 4th argument (list iterator) of the reduce """This class is used to check the 4th argument (list iterator) of
protocol. the reduce protocol.
""" """
def __init__(self, items=None): def __init__(self, items=None):
self.items = items if items is not None else [] self.items = items if items is not None else []
...@@ -1376,8 +1596,8 @@ class REX_six(object): ...@@ -1376,8 +1596,8 @@ class REX_six(object):
return type(self), (), None, iter(self.items), None return type(self), (), None, iter(self.items), None
class REX_seven(object): class REX_seven(object):
"""This class is used to check the 5th argument (dict iterator) of the reduce """This class is used to check the 5th argument (dict iterator) of
protocol. the reduce protocol.
""" """
def __init__(self, table=None): def __init__(self, table=None):
self.table = table if table is not None else {} self.table = table if table is not None else {}
...@@ -1415,10 +1635,16 @@ class MyList(list): ...@@ -1415,10 +1635,16 @@ class MyList(list):
class MyDict(dict): class MyDict(dict):
sample = {"a": 1, "b": 2} sample = {"a": 1, "b": 2}
class MySet(set):
sample = {"a", "b"}
class MyFrozenSet(frozenset):
sample = frozenset({"a", "b"})
myclasses = [MyInt, MyFloat, myclasses = [MyInt, MyFloat,
MyComplex, MyComplex,
MyStr, MyUnicode, MyStr, MyUnicode,
MyTuple, MyList, MyDict] MyTuple, MyList, MyDict, MySet, MyFrozenSet]
class SlotList(MyList): class SlotList(MyList):
...@@ -1428,6 +1654,8 @@ class SimpleNewObj(object): ...@@ -1428,6 +1654,8 @@ class SimpleNewObj(object):
def __init__(self, a, b, c): def __init__(self, a, b, c):
# raise an error, to make sure this isn't called # raise an error, to make sure this isn't called
raise TypeError("SimpleNewObj.__init__() didn't expect to get called") raise TypeError("SimpleNewObj.__init__() didn't expect to get called")
def __eq__(self, other):
return self.__dict__ == other.__dict__
class BadGetattr: class BadGetattr:
def __getattr__(self, key): def __getattr__(self, key):
...@@ -1464,7 +1692,7 @@ class AbstractPickleModuleTests(unittest.TestCase): ...@@ -1464,7 +1692,7 @@ class AbstractPickleModuleTests(unittest.TestCase):
def test_highest_protocol(self): def test_highest_protocol(self):
# Of course this needs to be changed when HIGHEST_PROTOCOL changes. # Of course this needs to be changed when HIGHEST_PROTOCOL changes.
self.assertEqual(pickle.HIGHEST_PROTOCOL, 3) self.assertEqual(pickle.HIGHEST_PROTOCOL, 4)
def test_callapi(self): def test_callapi(self):
f = io.BytesIO() f = io.BytesIO()
...@@ -1645,6 +1873,7 @@ class AbstractPicklerUnpicklerObjectTests(unittest.TestCase): ...@@ -1645,6 +1873,7 @@ class AbstractPicklerUnpicklerObjectTests(unittest.TestCase):
def _check_multiple_unpicklings(self, ioclass): def _check_multiple_unpicklings(self, ioclass):
for proto in protocols: for proto in protocols:
with self.subTest(proto=proto):
data1 = [(x, str(x)) for x in range(2000)] + [b"abcde", len] data1 = [(x, str(x)) for x in range(2000)] + [b"abcde", len]
f = ioclass() f = ioclass()
pickler = self.pickler_class(f, protocol=proto) pickler = self.pickler_class(f, protocol=proto)
......
import builtins import builtins
import copyreg
import gc import gc
import itertools
import math
import pickle
import sys import sys
import types import types
import math
import unittest import unittest
import weakref import weakref
...@@ -3153,176 +3156,6 @@ order (MRO) for bases """ ...@@ -3153,176 +3156,6 @@ order (MRO) for bases """
self.assertEqual(e.a, 1) self.assertEqual(e.a, 1)
self.assertEqual(can_delete_dict(e), can_delete_dict(ValueError())) self.assertEqual(can_delete_dict(e), can_delete_dict(ValueError()))
def test_pickles(self):
# Testing pickling and copying new-style classes and objects...
import pickle
def sorteditems(d):
L = list(d.items())
L.sort()
return L
global C
class C(object):
def __init__(self, a, b):
super(C, self).__init__()
self.a = a
self.b = b
def __repr__(self):
return "C(%r, %r)" % (self.a, self.b)
global C1
class C1(list):
def __new__(cls, a, b):
return super(C1, cls).__new__(cls)
def __getnewargs__(self):
return (self.a, self.b)
def __init__(self, a, b):
self.a = a
self.b = b
def __repr__(self):
return "C1(%r, %r)<%r>" % (self.a, self.b, list(self))
global C2
class C2(int):
def __new__(cls, a, b, val=0):
return super(C2, cls).__new__(cls, val)
def __getnewargs__(self):
return (self.a, self.b, int(self))
def __init__(self, a, b, val=0):
self.a = a
self.b = b
def __repr__(self):
return "C2(%r, %r)<%r>" % (self.a, self.b, int(self))
global C3
class C3(object):
def __init__(self, foo):
self.foo = foo
def __getstate__(self):
return self.foo
def __setstate__(self, foo):
self.foo = foo
global C4classic, C4
class C4classic: # classic
pass
class C4(C4classic, object): # mixed inheritance
pass
for bin in 0, 1:
for cls in C, C1, C2:
s = pickle.dumps(cls, bin)
cls2 = pickle.loads(s)
self.assertIs(cls2, cls)
a = C1(1, 2); a.append(42); a.append(24)
b = C2("hello", "world", 42)
s = pickle.dumps((a, b), bin)
x, y = pickle.loads(s)
self.assertEqual(x.__class__, a.__class__)
self.assertEqual(sorteditems(x.__dict__), sorteditems(a.__dict__))
self.assertEqual(y.__class__, b.__class__)
self.assertEqual(sorteditems(y.__dict__), sorteditems(b.__dict__))
self.assertEqual(repr(x), repr(a))
self.assertEqual(repr(y), repr(b))
# Test for __getstate__ and __setstate__ on new style class
u = C3(42)
s = pickle.dumps(u, bin)
v = pickle.loads(s)
self.assertEqual(u.__class__, v.__class__)
self.assertEqual(u.foo, v.foo)
# Test for picklability of hybrid class
u = C4()
u.foo = 42
s = pickle.dumps(u, bin)
v = pickle.loads(s)
self.assertEqual(u.__class__, v.__class__)
self.assertEqual(u.foo, v.foo)
# Testing copy.deepcopy()
import copy
for cls in C, C1, C2:
cls2 = copy.deepcopy(cls)
self.assertIs(cls2, cls)
a = C1(1, 2); a.append(42); a.append(24)
b = C2("hello", "world", 42)
x, y = copy.deepcopy((a, b))
self.assertEqual(x.__class__, a.__class__)
self.assertEqual(sorteditems(x.__dict__), sorteditems(a.__dict__))
self.assertEqual(y.__class__, b.__class__)
self.assertEqual(sorteditems(y.__dict__), sorteditems(b.__dict__))
self.assertEqual(repr(x), repr(a))
self.assertEqual(repr(y), repr(b))
def test_pickle_slots(self):
# Testing pickling of classes with __slots__ ...
import pickle
# Pickling of classes with __slots__ but without __getstate__ should fail
# (if using protocol 0 or 1)
global B, C, D, E
class B(object):
pass
for base in [object, B]:
class C(base):
__slots__ = ['a']
class D(C):
pass
try:
pickle.dumps(C(), 0)
except TypeError:
pass
else:
self.fail("should fail: pickle C instance - %s" % base)
try:
pickle.dumps(C(), 0)
except TypeError:
pass
else:
self.fail("should fail: pickle D instance - %s" % base)
# Give C a nice generic __getstate__ and __setstate__
class C(base):
__slots__ = ['a']
def __getstate__(self):
try:
d = self.__dict__.copy()
except AttributeError:
d = {}
for cls in self.__class__.__mro__:
for sn in cls.__dict__.get('__slots__', ()):
try:
d[sn] = getattr(self, sn)
except AttributeError:
pass
return d
def __setstate__(self, d):
for k, v in list(d.items()):
setattr(self, k, v)
class D(C):
pass
# Now it should work
x = C()
y = pickle.loads(pickle.dumps(x))
self.assertNotHasAttr(y, 'a')
x.a = 42
y = pickle.loads(pickle.dumps(x))
self.assertEqual(y.a, 42)
x = D()
x.a = 42
x.b = 100
y = pickle.loads(pickle.dumps(x))
self.assertEqual(y.a + y.b, 142)
# A subclass that adds a slot should also work
class E(C):
__slots__ = ['b']
x = E()
x.a = 42
x.b = "foo"
y = pickle.loads(pickle.dumps(x))
self.assertEqual(y.a, x.a)
self.assertEqual(y.b, x.b)
def test_binary_operator_override(self): def test_binary_operator_override(self):
# Testing overrides of binary operations... # Testing overrides of binary operations...
class I(int): class I(int):
...@@ -4690,11 +4523,439 @@ class MiscTests(unittest.TestCase): ...@@ -4690,11 +4523,439 @@ class MiscTests(unittest.TestCase):
self.assertEqual(X.mykey2, 'from Base2') self.assertEqual(X.mykey2, 'from Base2')
class PicklingTests(unittest.TestCase):
def _check_reduce(self, proto, obj, args=(), kwargs={}, state=None,
listitems=None, dictitems=None):
if proto >= 4:
reduce_value = obj.__reduce_ex__(proto)
self.assertEqual(reduce_value[:3],
(copyreg.__newobj_ex__,
(type(obj), args, kwargs),
state))
if listitems is not None:
self.assertListEqual(list(reduce_value[3]), listitems)
else:
self.assertIsNone(reduce_value[3])
if dictitems is not None:
self.assertDictEqual(dict(reduce_value[4]), dictitems)
else:
self.assertIsNone(reduce_value[4])
elif proto >= 2:
reduce_value = obj.__reduce_ex__(proto)
self.assertEqual(reduce_value[:3],
(copyreg.__newobj__,
(type(obj),) + args,
state))
if listitems is not None:
self.assertListEqual(list(reduce_value[3]), listitems)
else:
self.assertIsNone(reduce_value[3])
if dictitems is not None:
self.assertDictEqual(dict(reduce_value[4]), dictitems)
else:
self.assertIsNone(reduce_value[4])
else:
base_type = type(obj).__base__
reduce_value = (copyreg._reconstructor,
(type(obj),
base_type,
None if base_type is object else base_type(obj)))
if state is not None:
reduce_value += (state,)
self.assertEqual(obj.__reduce_ex__(proto), reduce_value)
self.assertEqual(obj.__reduce__(), reduce_value)
def test_reduce(self):
protocols = range(pickle.HIGHEST_PROTOCOL + 1)
args = (-101, "spam")
kwargs = {'bacon': -201, 'fish': -301}
state = {'cheese': -401}
class C1:
def __getnewargs__(self):
return args
obj = C1()
for proto in protocols:
self._check_reduce(proto, obj, args)
for name, value in state.items():
setattr(obj, name, value)
for proto in protocols:
self._check_reduce(proto, obj, args, state=state)
class C2:
def __getnewargs__(self):
return "bad args"
obj = C2()
for proto in protocols:
if proto >= 2:
with self.assertRaises(TypeError):
obj.__reduce_ex__(proto)
class C3:
def __getnewargs_ex__(self):
return (args, kwargs)
obj = C3()
for proto in protocols:
if proto >= 4:
self._check_reduce(proto, obj, args, kwargs)
elif proto >= 2:
with self.assertRaises(ValueError):
obj.__reduce_ex__(proto)
class C4:
def __getnewargs_ex__(self):
return (args, "bad dict")
class C5:
def __getnewargs_ex__(self):
return ("bad tuple", kwargs)
class C6:
def __getnewargs_ex__(self):
return ()
class C7:
def __getnewargs_ex__(self):
return "bad args"
for proto in protocols:
for cls in C4, C5, C6, C7:
obj = cls()
if proto >= 2:
with self.assertRaises((TypeError, ValueError)):
obj.__reduce_ex__(proto)
class C8:
def __getnewargs_ex__(self):
return (args, kwargs)
obj = C8()
for proto in protocols:
if 2 <= proto < 4:
with self.assertRaises(ValueError):
obj.__reduce_ex__(proto)
class C9:
def __getnewargs_ex__(self):
return (args, {})
obj = C9()
for proto in protocols:
self._check_reduce(proto, obj, args)
class C10:
def __getnewargs_ex__(self):
raise IndexError
obj = C10()
for proto in protocols:
if proto >= 2:
with self.assertRaises(IndexError):
obj.__reduce_ex__(proto)
class C11:
def __getstate__(self):
return state
obj = C11()
for proto in protocols:
self._check_reduce(proto, obj, state=state)
class C12:
def __getstate__(self):
return "not dict"
obj = C12()
for proto in protocols:
self._check_reduce(proto, obj, state="not dict")
class C13:
def __getstate__(self):
raise IndexError
obj = C13()
for proto in protocols:
with self.assertRaises(IndexError):
obj.__reduce_ex__(proto)
if proto < 2:
with self.assertRaises(IndexError):
obj.__reduce__()
class C14:
__slots__ = tuple(state)
def __init__(self):
for name, value in state.items():
setattr(self, name, value)
obj = C14()
for proto in protocols:
if proto >= 2:
self._check_reduce(proto, obj, state=(None, state))
else:
with self.assertRaises(TypeError):
obj.__reduce_ex__(proto)
with self.assertRaises(TypeError):
obj.__reduce__()
class C15(dict):
pass
obj = C15({"quebec": -601})
for proto in protocols:
self._check_reduce(proto, obj, dictitems=dict(obj))
class C16(list):
pass
obj = C16(["yukon"])
for proto in protocols:
self._check_reduce(proto, obj, listitems=list(obj))
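The C3/C5 cases above reflect the intended use of __getnewargs_ex__: a __new__ that takes keyword-only arguments. A minimal sketch (the Point class is illustrative, not part of the test suite):

import copyreg
import pickle

class Point:
    # Illustrative class; it must be defined at module level to be picklable.
    def __new__(cls, *, x=0, y=0):
        self = super().__new__(cls)
        self.x, self.y = x, y
        return self
    def __getnewargs_ex__(self):
        return ((), {'x': self.x, 'y': self.y})

p = Point(x=1, y=2)
# Protocol 4 expresses the keyword arguments through copyreg.__newobj_ex__;
# with this patch, protocols 2 and 3 raise ValueError instead (see C3 above).
func, args = p.__reduce_ex__(4)[:2]
assert func is copyreg.__newobj_ex__
assert args == (Point, (), {'x': 1, 'y': 2})
q = pickle.loads(pickle.dumps(p, 4))
assert (q.x, q.y) == (1, 2)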
def _assert_is_copy(self, obj, objcopy, msg=None):
"""Utility method to verify if two objects are copies of each others.
"""
if msg is None:
msg = "{!r} is not a copy of {!r}".format(obj, objcopy)
if type(obj).__repr__ is object.__repr__:
# We have this limitation for now because we use the object's repr
# to help us verify that the two objects are copies. This allows
# us to delegate the non-generic verification logic to the objects
# themselves.
raise ValueError("object passed to _assert_is_copy must " +
"override the __repr__ method.")
self.assertIsNot(obj, objcopy, msg=msg)
self.assertIs(type(obj), type(objcopy), msg=msg)
if hasattr(obj, '__dict__'):
self.assertDictEqual(obj.__dict__, objcopy.__dict__, msg=msg)
self.assertIsNot(obj.__dict__, objcopy.__dict__, msg=msg)
if hasattr(obj, '__slots__'):
self.assertListEqual(obj.__slots__, objcopy.__slots__, msg=msg)
for slot in obj.__slots__:
self.assertEqual(
hasattr(obj, slot), hasattr(objcopy, slot), msg=msg)
self.assertEqual(getattr(obj, slot, None),
getattr(objcopy, slot, None), msg=msg)
self.assertEqual(repr(obj), repr(objcopy), msg=msg)
@staticmethod
def _generate_pickle_copiers():
"""Utility method to generate the many possible pickle configurations.
"""
class PickleCopier:
"This class copies object using pickle."
def __init__(self, proto, dumps, loads):
self.proto = proto
self.dumps = dumps
self.loads = loads
def copy(self, obj):
return self.loads(self.dumps(obj, self.proto))
def __repr__(self):
# We try to be as descriptive as possible here since this is
                # the string which will allow us to tell which pickle
                # configuration we are using during debugging.
return ("PickleCopier(proto={}, dumps={}.{}, loads={}.{})"
.format(self.proto,
self.dumps.__module__, self.dumps.__qualname__,
self.loads.__module__, self.loads.__qualname__))
return (PickleCopier(*args) for args in
itertools.product(range(pickle.HIGHEST_PROTOCOL + 1),
{pickle.dumps, pickle._dumps},
{pickle.loads, pickle._loads}))
def test_pickle_slots(self):
# Tests pickling of classes with __slots__.
# Pickling of classes with __slots__ but without __getstate__ should
# fail (if using protocol 0 or 1)
global C
class C:
__slots__ = ['a']
with self.assertRaises(TypeError):
pickle.dumps(C(), 0)
global D
class D(C):
pass
with self.assertRaises(TypeError):
pickle.dumps(D(), 0)
class C:
"A class with __getstate__ and __setstate__ implemented."
__slots__ = ['a']
def __getstate__(self):
state = getattr(self, '__dict__', {}).copy()
for cls in type(self).__mro__:
for slot in cls.__dict__.get('__slots__', ()):
try:
state[slot] = getattr(self, slot)
except AttributeError:
pass
return state
def __setstate__(self, state):
for k, v in state.items():
setattr(self, k, v)
def __repr__(self):
return "%s()<%r>" % (type(self).__name__, self.__getstate__())
class D(C):
"A subclass of a class with slots."
pass
global E
class E(C):
"A subclass with an extra slot."
__slots__ = ['b']
# Now it should work
for pickle_copier in self._generate_pickle_copiers():
with self.subTest(pickle_copier=pickle_copier):
x = C()
y = pickle_copier.copy(x)
self._assert_is_copy(x, y)
x.a = 42
y = pickle_copier.copy(x)
self._assert_is_copy(x, y)
x = D()
x.a = 42
x.b = 100
y = pickle_copier.copy(x)
self._assert_is_copy(x, y)
x = E()
x.a = 42
x.b = "foo"
y = pickle_copier.copy(x)
self._assert_is_copy(x, y)
def test_reduce_copying(self):
# Tests pickling and copying new-style classes and objects.
global C1
class C1:
"The state of this class is copyable via its instance dict."
ARGS = (1, 2)
NEED_DICT_COPYING = True
def __init__(self, a, b):
super().__init__()
self.a = a
self.b = b
def __repr__(self):
return "C1(%r, %r)" % (self.a, self.b)
global C2
class C2(list):
"A list subclass copyable via __getnewargs__."
ARGS = (1, 2)
NEED_DICT_COPYING = False
def __new__(cls, a, b):
self = super().__new__(cls)
self.a = a
self.b = b
return self
def __init__(self, *args):
super().__init__()
            # This helps test that __init__ is not called during the
            # unpickling process, which would cause extra appends.
self.append("cheese")
@classmethod
def __getnewargs__(cls):
return cls.ARGS
def __repr__(self):
return "C2(%r, %r)<%r>" % (self.a, self.b, list(self))
global C3
class C3(list):
"A list subclass copyable via __getstate__."
ARGS = (1, 2)
NEED_DICT_COPYING = False
def __init__(self, a, b):
self.a = a
self.b = b
            # This helps test that __init__ is not called during the
            # unpickling process, which would cause extra appends.
self.append("cheese")
@classmethod
def __getstate__(cls):
return cls.ARGS
def __setstate__(self, state):
a, b = state
self.a = a
self.b = b
def __repr__(self):
return "C3(%r, %r)<%r>" % (self.a, self.b, list(self))
global C4
class C4(int):
"An int subclass copyable via __getnewargs__."
ARGS = ("hello", "world", 1)
NEED_DICT_COPYING = False
def __new__(cls, a, b, value):
self = super().__new__(cls, value)
self.a = a
self.b = b
return self
@classmethod
def __getnewargs__(cls):
return cls.ARGS
def __repr__(self):
return "C4(%r, %r)<%r>" % (self.a, self.b, int(self))
global C5
class C5(int):
"An int subclass copyable via __getnewargs_ex__."
ARGS = (1, 2)
KWARGS = {'value': 3}
NEED_DICT_COPYING = False
def __new__(cls, a, b, *, value=0):
self = super().__new__(cls, value)
self.a = a
self.b = b
return self
@classmethod
def __getnewargs_ex__(cls):
return (cls.ARGS, cls.KWARGS)
def __repr__(self):
return "C5(%r, %r)<%r>" % (self.a, self.b, int(self))
test_classes = (C1, C2, C3, C4, C5)
# Testing copying through pickle
pickle_copiers = self._generate_pickle_copiers()
for cls, pickle_copier in itertools.product(test_classes, pickle_copiers):
with self.subTest(cls=cls, pickle_copier=pickle_copier):
kwargs = getattr(cls, 'KWARGS', {})
obj = cls(*cls.ARGS, **kwargs)
proto = pickle_copier.proto
if 2 <= proto < 4 and hasattr(cls, '__getnewargs_ex__'):
with self.assertRaises(ValueError):
pickle_copier.dumps(obj, proto)
continue
objcopy = pickle_copier.copy(obj)
self._assert_is_copy(obj, objcopy)
                # For test classes that support this, make sure we didn't
                # bypass the reduce protocol by simply copying the attribute
                # dictionary. We clear the attributes on the previous copy so
                # as not to mutate the original object.
if proto >= 2 and not cls.NEED_DICT_COPYING:
objcopy.__dict__.clear()
objcopy2 = pickle_copier.copy(objcopy)
self._assert_is_copy(obj, objcopy2)
# Testing copying through copy.deepcopy()
for cls in test_classes:
with self.subTest(cls=cls):
kwargs = getattr(cls, 'KWARGS', {})
obj = cls(*cls.ARGS, **kwargs)
# XXX: We need to modify the copy module to support PEP 3154's
# reduce protocol 4.
if hasattr(cls, '__getnewargs_ex__'):
continue
objcopy = deepcopy(obj)
self._assert_is_copy(obj, objcopy)
                # For test classes that support this, make sure we didn't
                # bypass the reduce protocol by simply copying the attribute
                # dictionary. We clear the attributes on the previous copy so
                # as not to mutate the original object.
if not cls.NEED_DICT_COPYING:
objcopy.__dict__.clear()
objcopy2 = deepcopy(objcopy)
self._assert_is_copy(obj, objcopy2)
def test_main(): def test_main():
# Run all local test cases, with PTypesLongInitTest first. # Run all local test cases, with PTypesLongInitTest first.
support.run_unittest(PTypesLongInitTest, OperatorsTest, support.run_unittest(PTypesLongInitTest, OperatorsTest,
ClassPropertiesAndMethods, DictProxyTests, ClassPropertiesAndMethods, DictProxyTests,
MiscTests) MiscTests, PicklingTests)
if __name__ == "__main__": if __name__ == "__main__":
test_main() test_main()
...@@ -68,6 +68,8 @@ Core and Builtins ...@@ -68,6 +68,8 @@ Core and Builtins
Library Library
------- -------
- Issue #17810: Implement PEP 3154, pickle protocol 4.
- Issue #19668: Added support for the cp1125 encoding. - Issue #19668: Added support for the cp1125 encoding.
- Issue #19689: Add ssl.create_default_context() factory function. It creates - Issue #19689: Add ssl.create_default_context() factory function. It creates
......
...@@ -6,7 +6,7 @@ PyDoc_STRVAR(pickle_module_doc, ...@@ -6,7 +6,7 @@ PyDoc_STRVAR(pickle_module_doc,
/* Bump this when new opcodes are added to the pickle protocol. */ /* Bump this when new opcodes are added to the pickle protocol. */
enum { enum {
HIGHEST_PROTOCOL = 3, HIGHEST_PROTOCOL = 4,
DEFAULT_PROTOCOL = 3 DEFAULT_PROTOCOL = 3
}; };
...@@ -71,7 +71,19 @@ enum opcode { ...@@ -71,7 +71,19 @@ enum opcode {
/* Protocol 3 (Python 3.x) */ /* Protocol 3 (Python 3.x) */
BINBYTES = 'B', BINBYTES = 'B',
SHORT_BINBYTES = 'C' SHORT_BINBYTES = 'C',
/* Protocol 4 */
SHORT_BINUNICODE = '\x8c',
BINUNICODE8 = '\x8d',
BINBYTES8 = '\x8e',
EMPTY_SET = '\x8f',
ADDITEMS = '\x90',
FROZENSET = '\x91',
NEWOBJ_EX = '\x92',
STACK_GLOBAL = '\x93',
MEMOIZE = '\x94',
FRAME = '\x95'
}; };
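These byte values are mirrored by module-level constants in Lib/pickle.py; a quick cross-check sketch:

import pickle

assert pickle.SHORT_BINUNICODE == b'\x8c'
assert pickle.BINUNICODE8 == b'\x8d'
assert pickle.BINBYTES8 == b'\x8e'
assert pickle.EMPTY_SET == b'\x8f'
assert pickle.ADDITEMS == b'\x90'
assert pickle.FROZENSET == b'\x91'
assert pickle.NEWOBJ_EX == b'\x92'
assert pickle.STACK_GLOBAL == b'\x93'
assert pickle.MEMOIZE == b'\x94'
assert pickle.FRAME == b'\x95'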
/* These aren't opcodes -- they're ways to pickle bools before protocol 2 /* These aren't opcodes -- they're ways to pickle bools before protocol 2
...@@ -103,7 +115,11 @@ enum { ...@@ -103,7 +115,11 @@ enum {
MAX_WRITE_BUF_SIZE = 64 * 1024, MAX_WRITE_BUF_SIZE = 64 * 1024,
/* Prefetch size when unpickling (disabled on unpeekable streams) */ /* Prefetch size when unpickling (disabled on unpeekable streams) */
PREFETCH = 8192 * 16 PREFETCH = 8192 * 16,
FRAME_SIZE_TARGET = 64 * 1024,
FRAME_HEADER_SIZE = 9
}; };
/* Exception classes for pickle. These should override the ones defined in /* Exception classes for pickle. These should override the ones defined in
...@@ -136,9 +152,6 @@ static PyObject *empty_tuple = NULL; ...@@ -136,9 +152,6 @@ static PyObject *empty_tuple = NULL;
/* For looking up name pairs in copyreg._extension_registry. */ /* For looking up name pairs in copyreg._extension_registry. */
static PyObject *two_tuple = NULL; static PyObject *two_tuple = NULL;
_Py_IDENTIFIER(__name__);
_Py_IDENTIFIER(modules);
static int static int
stack_underflow(void) stack_underflow(void)
{ {
...@@ -332,6 +345,11 @@ typedef struct PicklerObject { ...@@ -332,6 +345,11 @@ typedef struct PicklerObject {
Py_ssize_t max_output_len; /* Allocation size of output_buffer. */ Py_ssize_t max_output_len; /* Allocation size of output_buffer. */
int proto; /* Pickle protocol number, >= 0 */ int proto; /* Pickle protocol number, >= 0 */
int bin; /* Boolean, true if proto > 0 */ int bin; /* Boolean, true if proto > 0 */
int framing; /* True when framing is enabled, proto >= 4 */
    Py_ssize_t frame_start;     /* Position in output_buffer where the
                                   current frame begins. -1 if there
                                   is no frame currently open. */
Py_ssize_t buf_size; /* Size of the current buffered pickle data */ Py_ssize_t buf_size; /* Size of the current buffered pickle data */
int fast; /* Enable fast mode if set to a true value. int fast; /* Enable fast mode if set to a true value.
The fast mode disable the usage of memo, The fast mode disable the usage of memo,
...@@ -352,7 +370,8 @@ typedef struct UnpicklerObject { ...@@ -352,7 +370,8 @@ typedef struct UnpicklerObject {
/* The unpickler memo is just an array of PyObject *s. Using a dict /* The unpickler memo is just an array of PyObject *s. Using a dict
is unnecessary, since the keys are contiguous ints. */ is unnecessary, since the keys are contiguous ints. */
PyObject **memo; PyObject **memo;
Py_ssize_t memo_size; Py_ssize_t memo_size; /* Capacity of the memo array */
Py_ssize_t memo_len; /* Number of objects in the memo */
PyObject *arg; PyObject *arg;
PyObject *pers_func; /* persistent_load() method, can be NULL. */ PyObject *pers_func; /* persistent_load() method, can be NULL. */
...@@ -362,7 +381,9 @@ typedef struct UnpicklerObject { ...@@ -362,7 +381,9 @@ typedef struct UnpicklerObject {
char *input_line; char *input_line;
Py_ssize_t input_len; Py_ssize_t input_len;
Py_ssize_t next_read_idx; Py_ssize_t next_read_idx;
Py_ssize_t frame_end_idx;
Py_ssize_t prefetched_idx; /* index of first prefetched byte */ Py_ssize_t prefetched_idx; /* index of first prefetched byte */
PyObject *read; /* read() method of the input stream. */ PyObject *read; /* read() method of the input stream. */
PyObject *readline; /* readline() method of the input stream. */ PyObject *readline; /* readline() method of the input stream. */
PyObject *peek; /* peek() method of the input stream, or NULL */ PyObject *peek; /* peek() method of the input stream, or NULL */
...@@ -380,6 +401,7 @@ typedef struct UnpicklerObject { ...@@ -380,6 +401,7 @@ typedef struct UnpicklerObject {
int proto; /* Protocol of the pickle loaded. */ int proto; /* Protocol of the pickle loaded. */
int fix_imports; /* Indicate whether Unpickler should fix int fix_imports; /* Indicate whether Unpickler should fix
the name of globals pickled by Python 2.x. */ the name of globals pickled by Python 2.x. */
int framing; /* True when framing is enabled, proto >= 4 */
} UnpicklerObject; } UnpicklerObject;
/* Forward declarations */ /* Forward declarations */
...@@ -673,6 +695,50 @@ _Pickler_ClearBuffer(PicklerObject *self) ...@@ -673,6 +695,50 @@ _Pickler_ClearBuffer(PicklerObject *self)
if (self->output_buffer == NULL) if (self->output_buffer == NULL)
return -1; return -1;
self->output_len = 0; self->output_len = 0;
self->frame_start = -1;
return 0;
}
static void
_Pickler_WriteFrameHeader(PicklerObject *self, char *qdata, size_t frame_len)
{
qdata[0] = (unsigned char)FRAME;
qdata[1] = (unsigned char)(frame_len & 0xff);
qdata[2] = (unsigned char)((frame_len >> 8) & 0xff);
qdata[3] = (unsigned char)((frame_len >> 16) & 0xff);
qdata[4] = (unsigned char)((frame_len >> 24) & 0xff);
qdata[5] = (unsigned char)((frame_len >> 32) & 0xff);
qdata[6] = (unsigned char)((frame_len >> 40) & 0xff);
qdata[7] = (unsigned char)((frame_len >> 48) & 0xff);
qdata[8] = (unsigned char)((frame_len >> 56) & 0xff);
}
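In Python terms, the header written above is just the FRAME opcode followed by the frame length packed as an unsigned 64-bit little-endian integer; a minimal sketch, assuming the 0x95 opcode value from the enum earlier:

    import struct

    def frame_header(frame_len):
        # FRAME (0x95) + 8-byte little-endian length == FRAME_HEADER_SIZE (9) bytes
        return b'\x95' + struct.pack('<Q', frame_len)

    assert len(frame_header(0)) == 9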
static int
_Pickler_CommitFrame(PicklerObject *self)
{
size_t frame_len;
char *qdata;
if (!self->framing || self->frame_start == -1)
return 0;
frame_len = self->output_len - self->frame_start - FRAME_HEADER_SIZE;
qdata = PyBytes_AS_STRING(self->output_buffer) + self->frame_start;
_Pickler_WriteFrameHeader(self, qdata, frame_len);
self->frame_start = -1;
return 0;
}
static int
_Pickler_OpcodeBoundary(PicklerObject *self)
{
Py_ssize_t frame_len;
if (!self->framing || self->frame_start == -1)
return 0;
frame_len = self->output_len - self->frame_start - FRAME_HEADER_SIZE;
if (frame_len >= FRAME_SIZE_TARGET)
return _Pickler_CommitFrame(self);
else
return 0; return 0;
} }
...@@ -682,6 +748,10 @@ _Pickler_GetString(PicklerObject *self) ...@@ -682,6 +748,10 @@ _Pickler_GetString(PicklerObject *self)
PyObject *output_buffer = self->output_buffer; PyObject *output_buffer = self->output_buffer;
assert(self->output_buffer != NULL); assert(self->output_buffer != NULL);
if (_Pickler_CommitFrame(self))
return NULL;
self->output_buffer = NULL; self->output_buffer = NULL;
/* Resize down to exact size */ /* Resize down to exact size */
if (_PyBytes_Resize(&output_buffer, self->output_len) < 0) if (_PyBytes_Resize(&output_buffer, self->output_len) < 0)
...@@ -696,6 +766,7 @@ _Pickler_FlushToFile(PicklerObject *self) ...@@ -696,6 +766,7 @@ _Pickler_FlushToFile(PicklerObject *self)
assert(self->write != NULL); assert(self->write != NULL);
/* This will commit the frame first */
output = _Pickler_GetString(self); output = _Pickler_GetString(self);
if (output == NULL) if (output == NULL)
return -1; return -1;
...@@ -706,15 +777,21 @@ _Pickler_FlushToFile(PicklerObject *self) ...@@ -706,15 +777,21 @@ _Pickler_FlushToFile(PicklerObject *self)
} }
static Py_ssize_t static Py_ssize_t
_Pickler_Write(PicklerObject *self, const char *s, Py_ssize_t n) _Pickler_Write(PicklerObject *self, const char *s, Py_ssize_t data_len)
{ {
Py_ssize_t i, required; Py_ssize_t i, n, required;
char *buffer; char *buffer;
int need_new_frame;
assert(s != NULL); assert(s != NULL);
need_new_frame = (self->framing && self->frame_start == -1);
if (need_new_frame)
n = data_len + FRAME_HEADER_SIZE;
else
n = data_len;
required = self->output_len + n; required = self->output_len + n;
if (required > self->max_output_len) {
if (self->write != NULL && required > MAX_WRITE_BUF_SIZE) { if (self->write != NULL && required > MAX_WRITE_BUF_SIZE) {
/* XXX This reallocates a new buffer every time, which is a bit /* XXX This reallocates a new buffer every time, which is a bit
wasteful. */ wasteful. */
...@@ -722,20 +799,41 @@ _Pickler_Write(PicklerObject *self, const char *s, Py_ssize_t n) ...@@ -722,20 +799,41 @@ _Pickler_Write(PicklerObject *self, const char *s, Py_ssize_t n)
return -1; return -1;
if (_Pickler_ClearBuffer(self) < 0) if (_Pickler_ClearBuffer(self) < 0)
return -1; return -1;
/* The previous frame was just committed by _Pickler_FlushToFile */
need_new_frame = self->framing;
if (need_new_frame)
n = data_len + FRAME_HEADER_SIZE;
else
n = data_len;
required = self->output_len + n;
} }
if (self->write != NULL && n > MAX_WRITE_BUF_SIZE) { if (self->write != NULL && n > MAX_WRITE_BUF_SIZE) {
/* we already flushed above, so the buffer is empty */ /* For large pickle chunks, we write directly to the output
PyObject *result; file instead of buffering. Note the buffer is empty at this
point (it was flushed above, since required >= n). */
PyObject *output, *result;
if (need_new_frame) {
char frame_header[FRAME_HEADER_SIZE];
_Pickler_WriteFrameHeader(self, frame_header, (size_t) data_len);
output = PyBytes_FromStringAndSize(frame_header, FRAME_HEADER_SIZE);
if (output == NULL)
return -1;
result = _Pickler_FastCall(self, self->write, output);
Py_XDECREF(result);
if (result == NULL)
return -1;
}
/* XXX we could spare an intermediate copy and pass /* XXX we could spare an intermediate copy and pass
a memoryview instead */ a memoryview instead */
PyObject *output = PyBytes_FromStringAndSize(s, n); output = PyBytes_FromStringAndSize(s, data_len);
if (s == NULL) if (output == NULL)
return -1; return -1;
result = _Pickler_FastCall(self, self->write, output); result = _Pickler_FastCall(self, self->write, output);
Py_XDECREF(result); Py_XDECREF(result);
return (result == NULL) ? -1 : 0; return (result == NULL) ? -1 : 0;
} }
else { if (required > self->max_output_len) {
/* Make place in buffer for the pickle chunk */
if (self->output_len >= PY_SSIZE_T_MAX / 2 - n) { if (self->output_len >= PY_SSIZE_T_MAX / 2 - n) {
PyErr_NoMemory(); PyErr_NoMemory();
return -1; return -1;
...@@ -744,19 +842,28 @@ _Pickler_Write(PicklerObject *self, const char *s, Py_ssize_t n) ...@@ -744,19 +842,28 @@ _Pickler_Write(PicklerObject *self, const char *s, Py_ssize_t n)
if (_PyBytes_Resize(&self->output_buffer, self->max_output_len) < 0) if (_PyBytes_Resize(&self->output_buffer, self->max_output_len) < 0)
return -1; return -1;
} }
}
buffer = PyBytes_AS_STRING(self->output_buffer); buffer = PyBytes_AS_STRING(self->output_buffer);
if (n < 8) { if (need_new_frame) {
/* Setup new frame */
Py_ssize_t frame_start = self->output_len;
self->frame_start = frame_start;
for (i = 0; i < FRAME_HEADER_SIZE; i++) {
/* Write an invalid value, for debugging */
buffer[frame_start + i] = 0xFE;
}
self->output_len += FRAME_HEADER_SIZE;
}
if (data_len < 8) {
/* This is faster than memcpy when the string is short. */ /* This is faster than memcpy when the string is short. */
for (i = 0; i < n; i++) { for (i = 0; i < data_len; i++) {
buffer[self->output_len + i] = s[i]; buffer[self->output_len + i] = s[i];
} }
} }
else { else {
memcpy(buffer + self->output_len, s, n); memcpy(buffer + self->output_len, s, data_len);
} }
self->output_len += n; self->output_len += data_len;
return n; return data_len;
} }
static PicklerObject * static PicklerObject *
...@@ -774,6 +881,8 @@ _Pickler_New(void) ...@@ -774,6 +881,8 @@ _Pickler_New(void)
self->write = NULL; self->write = NULL;
self->proto = 0; self->proto = 0;
self->bin = 0; self->bin = 0;
self->framing = 0;
self->frame_start = -1;
self->fast = 0; self->fast = 0;
self->fast_nesting = 0; self->fast_nesting = 0;
self->fix_imports = 0; self->fix_imports = 0;
...@@ -868,6 +977,7 @@ _Unpickler_SetStringInput(UnpicklerObject *self, PyObject *input) ...@@ -868,6 +977,7 @@ _Unpickler_SetStringInput(UnpicklerObject *self, PyObject *input)
self->input_buffer = self->buffer.buf; self->input_buffer = self->buffer.buf;
self->input_len = self->buffer.len; self->input_len = self->buffer.len;
self->next_read_idx = 0; self->next_read_idx = 0;
self->frame_end_idx = -1;
self->prefetched_idx = self->input_len; self->prefetched_idx = self->input_len;
return self->input_len; return self->input_len;
} }
...@@ -932,7 +1042,7 @@ _Unpickler_ReadFromFile(UnpicklerObject *self, Py_ssize_t n) ...@@ -932,7 +1042,7 @@ _Unpickler_ReadFromFile(UnpicklerObject *self, Py_ssize_t n)
return -1; return -1;
/* Prefetch some data without advancing the file pointer, if possible */ /* Prefetch some data without advancing the file pointer, if possible */
if (self->peek) { if (self->peek && !self->framing) {
PyObject *len, *prefetched; PyObject *len, *prefetched;
len = PyLong_FromSsize_t(PREFETCH); len = PyLong_FromSsize_t(PREFETCH);
if (len == NULL) { if (len == NULL) {
...@@ -980,7 +1090,7 @@ _Unpickler_ReadFromFile(UnpicklerObject *self, Py_ssize_t n) ...@@ -980,7 +1090,7 @@ _Unpickler_ReadFromFile(UnpicklerObject *self, Py_ssize_t n)
Returns -1 (with an exception set) on failure. On success, return the Returns -1 (with an exception set) on failure. On success, return the
number of chars read. */ number of chars read. */
static Py_ssize_t static Py_ssize_t
_Unpickler_Read(UnpicklerObject *self, char **s, Py_ssize_t n) _Unpickler_ReadUnframed(UnpicklerObject *self, char **s, Py_ssize_t n)
{ {
Py_ssize_t num_read; Py_ssize_t num_read;
...@@ -1005,6 +1115,67 @@ _Unpickler_Read(UnpicklerObject *self, char **s, Py_ssize_t n) ...@@ -1005,6 +1115,67 @@ _Unpickler_Read(UnpicklerObject *self, char **s, Py_ssize_t n)
return n; return n;
} }
static Py_ssize_t
_Unpickler_Read(UnpicklerObject *self, char **s, Py_ssize_t n)
{
if (self->framing &&
(self->frame_end_idx == -1 ||
self->frame_end_idx <= self->next_read_idx)) {
/* Need to read new frame */
char *dummy;
unsigned char *frame_start;
size_t frame_len;
if (_Unpickler_ReadUnframed(self, &dummy, FRAME_HEADER_SIZE) < 0)
return -1;
frame_start = (unsigned char *) dummy;
if (frame_start[0] != (unsigned char)FRAME) {
PyErr_Format(UnpicklingError,
"expected FRAME opcode, got 0x%x instead",
frame_start[0]);
return -1;
}
frame_len = (size_t) frame_start[1];
frame_len |= (size_t) frame_start[2] << 8;
frame_len |= (size_t) frame_start[3] << 16;
frame_len |= (size_t) frame_start[4] << 24;
#if SIZEOF_SIZE_T >= 8
frame_len |= (size_t) frame_start[5] << 32;
frame_len |= (size_t) frame_start[6] << 40;
frame_len |= (size_t) frame_start[7] << 48;
frame_len |= (size_t) frame_start[8] << 56;
#else
if (frame_start[5] || frame_start[6] ||
frame_start[7] || frame_start[8]) {
PyErr_Format(PyExc_OverflowError,
"Frame size too large for 32-bit build");
return -1;
}
#endif
if (frame_len > PY_SSIZE_T_MAX) {
PyErr_Format(UnpicklingError, "Invalid frame length");
return -1;
}
if (frame_len < n) {
PyErr_Format(UnpicklingError, "Bad framing");
return -1;
}
if (_Unpickler_ReadUnframed(self, &dummy /* unused */,
frame_len) < 0)
return -1;
/* Rewind to start of frame */
self->frame_end_idx = self->next_read_idx;
self->next_read_idx -= frame_len;
}
if (self->framing) {
/* Check for bad input */
if (n + self->next_read_idx > self->frame_end_idx) {
PyErr_Format(UnpicklingError, "Bad framing");
return -1;
}
}
return _Unpickler_ReadUnframed(self, s, n);
}
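The framed read path above does the inverse; a rough Python sketch of the same header parsing (read_frame_header is a hypothetical helper, for illustration only):

    import struct

    def read_frame_header(buf, pos):
        # Validate the FRAME opcode, then decode the 8-byte little-endian length.
        if buf[pos:pos + 1] != b'\x95':
            raise ValueError("expected FRAME opcode, got %r" % buf[pos:pos + 1])
        frame_len, = struct.unpack('<Q', buf[pos + 1:pos + 9])
        return frame_len, pos + 9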
static Py_ssize_t static Py_ssize_t
_Unpickler_CopyLine(UnpicklerObject *self, char *line, Py_ssize_t len, _Unpickler_CopyLine(UnpicklerObject *self, char *line, Py_ssize_t len,
char **result) char **result)
...@@ -1102,7 +1273,12 @@ _Unpickler_MemoPut(UnpicklerObject *self, Py_ssize_t idx, PyObject *value) ...@@ -1102,7 +1273,12 @@ _Unpickler_MemoPut(UnpicklerObject *self, Py_ssize_t idx, PyObject *value)
Py_INCREF(value); Py_INCREF(value);
old_item = self->memo[idx]; old_item = self->memo[idx];
self->memo[idx] = value; self->memo[idx] = value;
Py_XDECREF(old_item); if (old_item != NULL) {
Py_DECREF(old_item);
}
else {
self->memo_len++;
}
return 0; return 0;
} }
...@@ -1150,6 +1326,7 @@ _Unpickler_New(void) ...@@ -1150,6 +1326,7 @@ _Unpickler_New(void)
self->input_line = NULL; self->input_line = NULL;
self->input_len = 0; self->input_len = 0;
self->next_read_idx = 0; self->next_read_idx = 0;
self->frame_end_idx = -1;
self->prefetched_idx = 0; self->prefetched_idx = 0;
self->read = NULL; self->read = NULL;
self->readline = NULL; self->readline = NULL;
...@@ -1160,9 +1337,11 @@ _Unpickler_New(void) ...@@ -1160,9 +1337,11 @@ _Unpickler_New(void)
self->num_marks = 0; self->num_marks = 0;
self->marks_size = 0; self->marks_size = 0;
self->proto = 0; self->proto = 0;
self->framing = 0;
self->fix_imports = 0; self->fix_imports = 0;
memset(&self->buffer, 0, sizeof(Py_buffer)); memset(&self->buffer, 0, sizeof(Py_buffer));
self->memo_size = 32; self->memo_size = 32;
self->memo_len = 0;
self->memo = _Unpickler_NewMemo(self->memo_size); self->memo = _Unpickler_NewMemo(self->memo_size);
self->stack = (Pdata *)Pdata_New(); self->stack = (Pdata *)Pdata_New();
...@@ -1277,36 +1456,44 @@ memo_get(PicklerObject *self, PyObject *key) ...@@ -1277,36 +1456,44 @@ memo_get(PicklerObject *self, PyObject *key)
static int static int
memo_put(PicklerObject *self, PyObject *obj) memo_put(PicklerObject *self, PyObject *obj)
{ {
Py_ssize_t x;
char pdata[30]; char pdata[30];
Py_ssize_t len; Py_ssize_t len;
int status = 0; Py_ssize_t idx;
const char memoize_op = MEMOIZE;
if (self->fast) if (self->fast)
return 0; return 0;
if (_Pickler_OpcodeBoundary(self))
return -1;
x = PyMemoTable_Size(self->memo); idx = PyMemoTable_Size(self->memo);
if (PyMemoTable_Set(self->memo, obj, x) < 0) if (PyMemoTable_Set(self->memo, obj, idx) < 0)
goto error; return -1;
if (!self->bin) { if (self->proto >= 4) {
if (_Pickler_Write(self, &memoize_op, 1) < 0)
return -1;
return 0;
}
else if (!self->bin) {
pdata[0] = PUT; pdata[0] = PUT;
PyOS_snprintf(pdata + 1, sizeof(pdata) - 1, PyOS_snprintf(pdata + 1, sizeof(pdata) - 1,
"%" PY_FORMAT_SIZE_T "d\n", x); "%" PY_FORMAT_SIZE_T "d\n", idx);
len = strlen(pdata); len = strlen(pdata);
} }
else { else {
if (x < 256) { if (idx < 256) {
pdata[0] = BINPUT; pdata[0] = BINPUT;
pdata[1] = (unsigned char)x; pdata[1] = (unsigned char)idx;
len = 2; len = 2;
} }
else if (x <= 0xffffffffL) { else if (idx <= 0xffffffffL) {
pdata[0] = LONG_BINPUT; pdata[0] = LONG_BINPUT;
pdata[1] = (unsigned char)(x & 0xff); pdata[1] = (unsigned char)(idx & 0xff);
pdata[2] = (unsigned char)((x >> 8) & 0xff); pdata[2] = (unsigned char)((idx >> 8) & 0xff);
pdata[3] = (unsigned char)((x >> 16) & 0xff); pdata[3] = (unsigned char)((idx >> 16) & 0xff);
pdata[4] = (unsigned char)((x >> 24) & 0xff); pdata[4] = (unsigned char)((idx >> 24) & 0xff);
len = 5; len = 5;
} }
else { /* unlikely */ else { /* unlikely */
...@@ -1315,57 +1502,94 @@ memo_put(PicklerObject *self, PyObject *obj) ...@@ -1315,57 +1502,94 @@ memo_put(PicklerObject *self, PyObject *obj)
return -1; return -1;
} }
} }
if (_Pickler_Write(self, pdata, len) < 0) if (_Pickler_Write(self, pdata, len) < 0)
goto error; return -1;
if (0) { return 0;
error: }
status = -1;
}
return status; static PyObject *
getattribute(PyObject *obj, PyObject *name, int allow_qualname) {
PyObject *dotted_path;
Py_ssize_t i;
_Py_static_string(PyId_dot, ".");
_Py_static_string(PyId_locals, "<locals>");
dotted_path = PyUnicode_Split(name, _PyUnicode_FromId(&PyId_dot), -1);
if (dotted_path == NULL) {
return NULL;
}
assert(Py_SIZE(dotted_path) >= 1);
if (!allow_qualname && Py_SIZE(dotted_path) > 1) {
PyErr_Format(PyExc_AttributeError,
"Can't get qualified attribute %R on %R;"
"use protocols >= 4 to enable support",
name, obj);
Py_DECREF(dotted_path);
return NULL;
}
Py_INCREF(obj);
for (i = 0; i < Py_SIZE(dotted_path); i++) {
PyObject *subpath = PyList_GET_ITEM(dotted_path, i);
PyObject *tmp;
PyObject *result = PyUnicode_RichCompare(
subpath, _PyUnicode_FromId(&PyId_locals), Py_EQ);
int is_equal = (result == Py_True);
assert(PyBool_Check(result));
Py_DECREF(result);
if (is_equal) {
PyErr_Format(PyExc_AttributeError,
"Can't get local attribute %R on %R", name, obj);
Py_DECREF(dotted_path);
Py_DECREF(obj);
return NULL;
}
tmp = PyObject_GetAttr(obj, subpath);
Py_DECREF(obj);
if (tmp == NULL) {
if (PyErr_ExceptionMatches(PyExc_AttributeError)) {
PyErr_Clear();
PyErr_Format(PyExc_AttributeError,
"Can't get attribute %R on %R", name, obj);
}
Py_DECREF(dotted_path);
return NULL;
}
obj = tmp;
}
Py_DECREF(dotted_path);
return obj;
} }
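The dotted-path lookup above is what lets protocol 4 pickle objects that are only reachable through their __qualname__; a small illustration (hypothetical classes, run at module level):

    import pickle

    class Outer:
        class Inner:
            pass

    # Protocol 4 records "Outer.Inner" and resolves it with a getattr walk,
    # mirroring getattribute() above.
    data = pickle.dumps(Outer.Inner, protocol=4)
    assert pickle.loads(data) is Outer.Inner

    # Protocol 3 only records the bare __name__ "Inner", so the lookup fails.
    try:
        pickle.dumps(Outer.Inner, protocol=3)
    except pickle.PicklingError:
        pass    # "Can't pickle ...: attribute lookup Inner on __main__ failed"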
static PyObject * static PyObject *
whichmodule(PyObject *global, PyObject *global_name) whichmodule(PyObject *global, PyObject *global_name, int allow_qualname)
{ {
Py_ssize_t i, j;
static PyObject *module_str = NULL;
static PyObject *main_str = NULL;
PyObject *module_name; PyObject *module_name;
PyObject *modules_dict; PyObject *modules_dict;
PyObject *module; PyObject *module;
PyObject *obj; PyObject *obj;
Py_ssize_t i, j;
_Py_IDENTIFIER(__module__);
_Py_IDENTIFIER(modules);
_Py_IDENTIFIER(__main__);
if (module_str == NULL) { module_name = _PyObject_GetAttrId(global, &PyId___module__);
module_str = PyUnicode_InternFromString("__module__");
if (module_str == NULL) if (module_name == NULL) {
return NULL; if (!PyErr_ExceptionMatches(PyExc_AttributeError))
main_str = PyUnicode_InternFromString("__main__");
if (main_str == NULL)
return NULL; return NULL;
PyErr_Clear();
} }
else {
module_name = PyObject_GetAttr(global, module_str);
/* In some rare cases (e.g., bound methods of extension types), /* In some rare cases (e.g., bound methods of extension types),
__module__ can be None. If it is so, then search sys.modules __module__ can be None. If it is so, then search sys.modules for
for the module of global. */ the module of global. */
if (module_name == Py_None) { if (module_name != Py_None)
Py_DECREF(module_name);
goto search;
}
if (module_name) {
return module_name; return module_name;
Py_CLEAR(module_name);
} }
if (PyErr_ExceptionMatches(PyExc_AttributeError)) assert(module_name == NULL);
PyErr_Clear();
else
return NULL;
search:
modules_dict = _PySys_GetObjectId(&PyId_modules); modules_dict = _PySys_GetObjectId(&PyId_modules);
if (modules_dict == NULL) { if (modules_dict == NULL) {
PyErr_SetString(PyExc_RuntimeError, "unable to get sys.modules"); PyErr_SetString(PyExc_RuntimeError, "unable to get sys.modules");
...@@ -1373,34 +1597,35 @@ whichmodule(PyObject *global, PyObject *global_name) ...@@ -1373,34 +1597,35 @@ whichmodule(PyObject *global, PyObject *global_name)
} }
i = 0; i = 0;
module_name = NULL;
while ((j = PyDict_Next(modules_dict, &i, &module_name, &module))) { while ((j = PyDict_Next(modules_dict, &i, &module_name, &module))) {
if (PyObject_RichCompareBool(module_name, main_str, Py_EQ) == 1) PyObject *result = PyUnicode_RichCompare(
module_name, _PyUnicode_FromId(&PyId___main__), Py_EQ);
int is_equal = (result == Py_True);
assert(PyBool_Check(result));
Py_DECREF(result);
if (is_equal)
continue;
if (module == Py_None)
continue; continue;
obj = PyObject_GetAttr(module, global_name); obj = getattribute(module, global_name, allow_qualname);
if (obj == NULL) { if (obj == NULL) {
if (PyErr_ExceptionMatches(PyExc_AttributeError)) if (!PyErr_ExceptionMatches(PyExc_AttributeError))
PyErr_Clear();
else
return NULL; return NULL;
PyErr_Clear();
continue; continue;
} }
if (obj != global) { if (obj == global) {
Py_DECREF(obj); Py_DECREF(obj);
continue; Py_INCREF(module_name);
return module_name;
} }
Py_DECREF(obj); Py_DECREF(obj);
break;
} }
/* If no module is found, use __main__. */ /* If no module is found, use __main__. */
if (!j) { module_name = _PyUnicode_FromId(&PyId___main__);
module_name = main_str;
}
Py_INCREF(module_name); Py_INCREF(module_name);
return module_name; return module_name;
} }
...@@ -1744,22 +1969,17 @@ save_bytes(PicklerObject *self, PyObject *obj) ...@@ -1744,22 +1969,17 @@ save_bytes(PicklerObject *self, PyObject *obj)
reduce_value = Py_BuildValue("(O())", (PyObject*)&PyBytes_Type); reduce_value = Py_BuildValue("(O())", (PyObject*)&PyBytes_Type);
} }
else { else {
static PyObject *latin1 = NULL;
PyObject *unicode_str = PyObject *unicode_str =
PyUnicode_DecodeLatin1(PyBytes_AS_STRING(obj), PyUnicode_DecodeLatin1(PyBytes_AS_STRING(obj),
PyBytes_GET_SIZE(obj), PyBytes_GET_SIZE(obj),
"strict"); "strict");
_Py_IDENTIFIER(latin1);
if (unicode_str == NULL) if (unicode_str == NULL)
return -1; return -1;
if (latin1 == NULL) {
latin1 = PyUnicode_InternFromString("latin1");
if (latin1 == NULL) {
Py_DECREF(unicode_str);
return -1;
}
}
reduce_value = Py_BuildValue("(O(OO))", reduce_value = Py_BuildValue("(O(OO))",
codecs_encode, unicode_str, latin1); codecs_encode, unicode_str,
_PyUnicode_FromId(&PyId_latin1));
Py_DECREF(unicode_str); Py_DECREF(unicode_str);
} }
...@@ -1773,14 +1993,14 @@ save_bytes(PicklerObject *self, PyObject *obj) ...@@ -1773,14 +1993,14 @@ save_bytes(PicklerObject *self, PyObject *obj)
} }
else { else {
Py_ssize_t size; Py_ssize_t size;
char header[5]; char header[9];
Py_ssize_t len; Py_ssize_t len;
size = PyBytes_GET_SIZE(obj); size = PyBytes_GET_SIZE(obj);
if (size < 0) if (size < 0)
return -1; return -1;
if (size < 256) { if (size <= 0xff) {
header[0] = SHORT_BINBYTES; header[0] = SHORT_BINBYTES;
header[1] = (unsigned char)size; header[1] = (unsigned char)size;
len = 2; len = 2;
...@@ -1793,6 +2013,14 @@ save_bytes(PicklerObject *self, PyObject *obj) ...@@ -1793,6 +2013,14 @@ save_bytes(PicklerObject *self, PyObject *obj)
header[4] = (unsigned char)((size >> 24) & 0xff); header[4] = (unsigned char)((size >> 24) & 0xff);
len = 5; len = 5;
} }
else if (self->proto >= 4) {
int i;
header[0] = BINBYTES8;
for (i = 0; i < 8; i++) {
header[i+1] = (unsigned char)((size >> (8 * i)) & 0xff);
}
            len = 9;
}
else { else {
PyErr_SetString(PyExc_OverflowError, PyErr_SetString(PyExc_OverflowError,
"cannot serialize a bytes object larger than 4 GiB"); "cannot serialize a bytes object larger than 4 GiB");
...@@ -1882,26 +2110,39 @@ done: ...@@ -1882,26 +2110,39 @@ done:
static int static int
write_utf8(PicklerObject *self, char *data, Py_ssize_t size) write_utf8(PicklerObject *self, char *data, Py_ssize_t size)
{ {
char pdata[5]; char header[9];
Py_ssize_t len;
if (size <= 0xff && self->proto >= 4) {
header[0] = SHORT_BINUNICODE;
header[1] = (unsigned char)(size & 0xff);
len = 2;
}
else if (size <= 0xffffffffUL) {
header[0] = BINUNICODE;
header[1] = (unsigned char)(size & 0xff);
header[2] = (unsigned char)((size >> 8) & 0xff);
header[3] = (unsigned char)((size >> 16) & 0xff);
header[4] = (unsigned char)((size >> 24) & 0xff);
len = 5;
}
else if (self->proto >= 4) {
int i;
#if SIZEOF_SIZE_T > 4 header[0] = BINUNICODE8;
if (size > 0xffffffffUL) { for (i = 0; i < 8; i++) {
/* string too large */ header[i+1] = (unsigned char)((size >> (8 * i)) & 0xff);
}
len = 9;
}
else {
PyErr_SetString(PyExc_OverflowError, PyErr_SetString(PyExc_OverflowError,
"cannot serialize a string larger than 4GiB"); "cannot serialize a string larger than 4GiB");
return -1; return -1;
} }
#endif
pdata[0] = BINUNICODE;
pdata[1] = (unsigned char)(size & 0xff);
pdata[2] = (unsigned char)((size >> 8) & 0xff);
pdata[3] = (unsigned char)((size >> 16) & 0xff);
pdata[4] = (unsigned char)((size >> 24) & 0xff);
if (_Pickler_Write(self, pdata, sizeof(pdata)) < 0) if (_Pickler_Write(self, header, len) < 0)
return -1; return -1;
if (_Pickler_Write(self, data, size) < 0) if (_Pickler_Write(self, data, size) < 0)
return -1; return -1;
...@@ -2597,6 +2838,214 @@ save_dict(PicklerObject *self, PyObject *obj) ...@@ -2597,6 +2838,214 @@ save_dict(PicklerObject *self, PyObject *obj)
return status; return status;
} }
static int
save_set(PicklerObject *self, PyObject *obj)
{
PyObject *item;
int i;
Py_ssize_t set_size, ppos = 0;
Py_hash_t hash;
const char empty_set_op = EMPTY_SET;
const char mark_op = MARK;
const char additems_op = ADDITEMS;
if (self->proto < 4) {
PyObject *items;
PyObject *reduce_value;
int status;
items = PySequence_List(obj);
if (items == NULL) {
return -1;
}
reduce_value = Py_BuildValue("(O(O))", (PyObject*)&PySet_Type, items);
Py_DECREF(items);
if (reduce_value == NULL) {
return -1;
}
/* save_reduce() will memoize the object automatically. */
status = save_reduce(self, reduce_value, obj);
Py_DECREF(reduce_value);
return status;
}
if (_Pickler_Write(self, &empty_set_op, 1) < 0)
return -1;
if (memo_put(self, obj) < 0)
return -1;
set_size = PySet_GET_SIZE(obj);
if (set_size == 0)
return 0; /* nothing to do */
/* Write in batches of BATCHSIZE. */
do {
i = 0;
if (_Pickler_Write(self, &mark_op, 1) < 0)
return -1;
while (_PySet_NextEntry(obj, &ppos, &item, &hash)) {
if (save(self, item, 0) < 0)
return -1;
if (++i == BATCHSIZE)
break;
}
if (_Pickler_Write(self, &additems_op, 1) < 0)
return -1;
if (PySet_GET_SIZE(obj) != set_size) {
PyErr_Format(
PyExc_RuntimeError,
"set changed size during iteration");
return -1;
}
} while (i == BATCHSIZE);
return 0;
}
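Both branches above produce equivalent results; a quick sanity check of the two code paths (illustrative only):

    import pickle

    s = {1, 2, 3}
    p4 = pickle.dumps(s, protocol=4)   # native EMPTY_SET / ADDITEMS opcodes
    p2 = pickle.dumps(s, protocol=2)   # falls back to reducing to set(list(s))
    assert pickle.loads(p4) == pickle.loads(p2) == s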
static int
save_frozenset(PicklerObject *self, PyObject *obj)
{
PyObject *iter;
const char mark_op = MARK;
const char frozenset_op = FROZENSET;
if (self->fast && !fast_save_enter(self, obj))
return -1;
if (self->proto < 4) {
PyObject *items;
PyObject *reduce_value;
int status;
items = PySequence_List(obj);
if (items == NULL) {
return -1;
}
reduce_value = Py_BuildValue("(O(O))", (PyObject*)&PyFrozenSet_Type,
items);
Py_DECREF(items);
if (reduce_value == NULL) {
return -1;
}
/* save_reduce() will memoize the object automatically. */
status = save_reduce(self, reduce_value, obj);
Py_DECREF(reduce_value);
return status;
}
if (_Pickler_Write(self, &mark_op, 1) < 0)
return -1;
    iter = PyObject_GetIter(obj);
    if (iter == NULL)
        return -1;
for (;;) {
PyObject *item;
item = PyIter_Next(iter);
if (item == NULL) {
if (PyErr_Occurred()) {
Py_DECREF(iter);
return -1;
}
break;
}
if (save(self, item, 0) < 0) {
Py_DECREF(item);
Py_DECREF(iter);
return -1;
}
Py_DECREF(item);
}
Py_DECREF(iter);
/* If the object is already in the memo, this means it is
recursive. In this case, throw away everything we put on the
stack, and fetch the object back from the memo. */
if (PyMemoTable_Get(self->memo, obj)) {
const char pop_mark_op = POP_MARK;
if (_Pickler_Write(self, &pop_mark_op, 1) < 0)
return -1;
if (memo_get(self, obj) < 0)
return -1;
return 0;
}
if (_Pickler_Write(self, &frozenset_op, 1) < 0)
return -1;
if (memo_put(self, obj) < 0)
return -1;
return 0;
}
static int
fix_imports(PyObject **module_name, PyObject **global_name)
{
PyObject *key;
PyObject *item;
key = PyTuple_Pack(2, *module_name, *global_name);
if (key == NULL)
return -1;
item = PyDict_GetItemWithError(name_mapping_3to2, key);
Py_DECREF(key);
if (item) {
PyObject *fixed_module_name;
PyObject *fixed_global_name;
if (!PyTuple_Check(item) || PyTuple_GET_SIZE(item) != 2) {
PyErr_Format(PyExc_RuntimeError,
"_compat_pickle.REVERSE_NAME_MAPPING values "
"should be 2-tuples, not %.200s",
Py_TYPE(item)->tp_name);
return -1;
}
fixed_module_name = PyTuple_GET_ITEM(item, 0);
fixed_global_name = PyTuple_GET_ITEM(item, 1);
if (!PyUnicode_Check(fixed_module_name) ||
!PyUnicode_Check(fixed_global_name)) {
PyErr_Format(PyExc_RuntimeError,
"_compat_pickle.REVERSE_NAME_MAPPING values "
"should be pairs of str, not (%.200s, %.200s)",
Py_TYPE(fixed_module_name)->tp_name,
Py_TYPE(fixed_global_name)->tp_name);
return -1;
}
Py_CLEAR(*module_name);
Py_CLEAR(*global_name);
Py_INCREF(fixed_module_name);
Py_INCREF(fixed_global_name);
*module_name = fixed_module_name;
*global_name = fixed_global_name;
}
else if (PyErr_Occurred()) {
return -1;
}
item = PyDict_GetItemWithError(import_mapping_3to2, *module_name);
if (item) {
if (!PyUnicode_Check(item)) {
PyErr_Format(PyExc_RuntimeError,
"_compat_pickle.REVERSE_IMPORT_MAPPING values "
"should be strings, not %.200s",
Py_TYPE(item)->tp_name);
return -1;
}
Py_CLEAR(*module_name);
Py_INCREF(item);
*module_name = item;
}
else if (PyErr_Occurred()) {
return -1;
}
return 0;
}
static int static int
save_global(PicklerObject *self, PyObject *obj, PyObject *name) save_global(PicklerObject *self, PyObject *obj, PyObject *name)
{ {
...@@ -2605,20 +3054,32 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name) ...@@ -2605,20 +3054,32 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name)
PyObject *module = NULL; PyObject *module = NULL;
PyObject *cls; PyObject *cls;
int status = 0; int status = 0;
_Py_IDENTIFIER(__name__);
_Py_IDENTIFIER(__qualname__);
const char global_op = GLOBAL; const char global_op = GLOBAL;
if (name) { if (name) {
Py_INCREF(name);
global_name = name; global_name = name;
Py_INCREF(global_name);
} }
else { else {
if (self->proto >= 4) {
global_name = _PyObject_GetAttrId(obj, &PyId___qualname__);
if (global_name == NULL) {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
goto error;
PyErr_Clear();
}
}
if (global_name == NULL) {
global_name = _PyObject_GetAttrId(obj, &PyId___name__); global_name = _PyObject_GetAttrId(obj, &PyId___name__);
if (global_name == NULL) if (global_name == NULL)
goto error; goto error;
} }
}
module_name = whichmodule(obj, global_name); module_name = whichmodule(obj, global_name, self->proto >= 4);
if (module_name == NULL) if (module_name == NULL)
goto error; goto error;
...@@ -2637,11 +3098,11 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name) ...@@ -2637,11 +3098,11 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name)
obj, module_name); obj, module_name);
goto error; goto error;
} }
cls = PyObject_GetAttr(module, global_name); cls = getattribute(module, global_name, self->proto >= 4);
if (cls == NULL) { if (cls == NULL) {
PyErr_Format(PicklingError, PyErr_Format(PicklingError,
"Can't pickle %R: attribute lookup %S.%S failed", "Can't pickle %R: attribute lookup %S on %S failed",
obj, module_name, global_name); obj, global_name, module_name);
goto error; goto error;
} }
if (cls != obj) { if (cls != obj) {
...@@ -2714,92 +3175,53 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name) ...@@ -2714,92 +3175,53 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name)
if (_Pickler_Write(self, pdata, n) < 0) if (_Pickler_Write(self, pdata, n) < 0)
goto error; goto error;
} }
else {
gen_global:
if (self->proto >= 4) {
const char stack_global_op = STACK_GLOBAL;
            if (save(self, module_name, 0) < 0)
                goto error;
            if (save(self, global_name, 0) < 0)
                goto error;
if (_Pickler_Write(self, &stack_global_op, 1) < 0)
goto error;
}
else { else {
/* Generate a normal global opcode if we are using a pickle /* Generate a normal global opcode if we are using a pickle
protocol <= 2, or if the object is not registered in the protocol < 4, or if the object is not registered in the
extension registry. */ extension registry. */
PyObject *encoded; PyObject *encoded;
PyObject *(*unicode_encoder)(PyObject *); PyObject *(*unicode_encoder)(PyObject *);
gen_global:
if (_Pickler_Write(self, &global_op, 1) < 0) if (_Pickler_Write(self, &global_op, 1) < 0)
goto error; goto error;
/* Since Python 3.0 now supports non-ASCII identifiers, we encode both /* For protocol < 3 and if the user didn't request against doing
the module name and the global name using UTF-8. We do so only when so, we convert module names to the old 2.x module names. */
we are using the pickle protocol newer than version 3. This is to if (self->proto < 3 && self->fix_imports) {
ensure compatibility with older Unpickler running on Python 2.x. */ if (fix_imports(&module_name, &global_name) < 0) {
if (self->proto >= 3) {
unicode_encoder = PyUnicode_AsUTF8String;
}
else {
unicode_encoder = PyUnicode_AsASCIIString;
}
/* For protocol < 3 and if the user didn't request against doing so,
we convert module names to the old 2.x module names. */
if (self->fix_imports) {
PyObject *key;
PyObject *item;
key = PyTuple_Pack(2, module_name, global_name);
if (key == NULL)
goto error;
item = PyDict_GetItemWithError(name_mapping_3to2, key);
Py_DECREF(key);
if (item) {
if (!PyTuple_Check(item) || PyTuple_GET_SIZE(item) != 2) {
PyErr_Format(PyExc_RuntimeError,
"_compat_pickle.REVERSE_NAME_MAPPING values "
"should be 2-tuples, not %.200s",
Py_TYPE(item)->tp_name);
goto error;
}
Py_CLEAR(module_name);
Py_CLEAR(global_name);
module_name = PyTuple_GET_ITEM(item, 0);
global_name = PyTuple_GET_ITEM(item, 1);
if (!PyUnicode_Check(module_name) ||
!PyUnicode_Check(global_name)) {
PyErr_Format(PyExc_RuntimeError,
"_compat_pickle.REVERSE_NAME_MAPPING values "
"should be pairs of str, not (%.200s, %.200s)",
Py_TYPE(module_name)->tp_name,
Py_TYPE(global_name)->tp_name);
goto error;
}
Py_INCREF(module_name);
Py_INCREF(global_name);
}
else if (PyErr_Occurred()) {
goto error;
}
item = PyDict_GetItemWithError(import_mapping_3to2, module_name);
if (item) {
if (!PyUnicode_Check(item)) {
PyErr_Format(PyExc_RuntimeError,
"_compat_pickle.REVERSE_IMPORT_MAPPING values "
"should be strings, not %.200s",
Py_TYPE(item)->tp_name);
goto error;
}
Py_CLEAR(module_name);
module_name = item;
Py_INCREF(module_name);
}
else if (PyErr_Occurred()) {
goto error; goto error;
} }
} }
/* Save the name of the module. */ /* Since Python 3.0 now supports non-ASCII identifiers, we encode
both the module name and the global name using UTF-8. We do so
only when we are using the pickle protocol newer than version
3. This is to ensure compatibility with older Unpickler running
on Python 2.x. */
if (self->proto == 3) {
unicode_encoder = PyUnicode_AsUTF8String;
}
else {
unicode_encoder = PyUnicode_AsASCIIString;
}
encoded = unicode_encoder(module_name); encoded = unicode_encoder(module_name);
if (encoded == NULL) { if (encoded == NULL) {
if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError)) if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError))
PyErr_Format(PicklingError, PyErr_Format(PicklingError,
"can't pickle module identifier '%S' using " "can't pickle module identifier '%S' using "
"pickle protocol %i", module_name, self->proto); "pickle protocol %i",
module_name, self->proto);
goto error; goto error;
} }
if (_Pickler_Write(self, PyBytes_AS_STRING(encoded), if (_Pickler_Write(self, PyBytes_AS_STRING(encoded),
...@@ -2817,7 +3239,8 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name) ...@@ -2817,7 +3239,8 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name)
if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError)) if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError))
PyErr_Format(PicklingError, PyErr_Format(PicklingError,
"can't pickle global identifier '%S' using " "can't pickle global identifier '%S' using "
"pickle protocol %i", global_name, self->proto); "pickle protocol %i",
global_name, self->proto);
goto error; goto error;
} }
if (_Pickler_Write(self, PyBytes_AS_STRING(encoded), if (_Pickler_Write(self, PyBytes_AS_STRING(encoded),
...@@ -2826,9 +3249,9 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name) ...@@ -2826,9 +3249,9 @@ save_global(PicklerObject *self, PyObject *obj, PyObject *name)
goto error; goto error;
} }
Py_DECREF(encoded); Py_DECREF(encoded);
if(_Pickler_Write(self, "\n", 1) < 0) if (_Pickler_Write(self, "\n", 1) < 0)
goto error; goto error;
}
/* Memoize the object. */ /* Memoize the object. */
if (memo_put(self, obj) < 0) if (memo_put(self, obj) < 0)
goto error; goto error;
...@@ -2927,14 +3350,9 @@ static PyObject * ...@@ -2927,14 +3350,9 @@ static PyObject *
get_class(PyObject *obj) get_class(PyObject *obj)
{ {
PyObject *cls; PyObject *cls;
static PyObject *str_class; _Py_IDENTIFIER(__class__);
if (str_class == NULL) { cls = _PyObject_GetAttrId(obj, &PyId___class__);
str_class = PyUnicode_InternFromString("__class__");
if (str_class == NULL)
return NULL;
}
cls = PyObject_GetAttr(obj, str_class);
if (cls == NULL) { if (cls == NULL) {
if (PyErr_ExceptionMatches(PyExc_AttributeError)) { if (PyErr_ExceptionMatches(PyExc_AttributeError)) {
PyErr_Clear(); PyErr_Clear();
...@@ -2957,12 +3375,12 @@ save_reduce(PicklerObject *self, PyObject *args, PyObject *obj) ...@@ -2957,12 +3375,12 @@ save_reduce(PicklerObject *self, PyObject *args, PyObject *obj)
PyObject *listitems = Py_None; PyObject *listitems = Py_None;
PyObject *dictitems = Py_None; PyObject *dictitems = Py_None;
Py_ssize_t size; Py_ssize_t size;
int use_newobj = 0, use_newobj_ex = 0;
int use_newobj = self->proto >= 2;
const char reduce_op = REDUCE; const char reduce_op = REDUCE;
const char build_op = BUILD; const char build_op = BUILD;
const char newobj_op = NEWOBJ; const char newobj_op = NEWOBJ;
const char newobj_ex_op = NEWOBJ_EX;
size = PyTuple_Size(args); size = PyTuple_Size(args);
if (size < 2 || size > 5) { if (size < 2 || size > 5) {
...@@ -3007,33 +3425,75 @@ save_reduce(PicklerObject *self, PyObject *args, PyObject *obj) ...@@ -3007,33 +3425,75 @@ save_reduce(PicklerObject *self, PyObject *args, PyObject *obj)
return -1; return -1;
} }
/* Protocol 2 special case: if callable's name is __newobj__, use if (self->proto >= 2) {
NEWOBJ. */
if (use_newobj) {
static PyObject *newobj_str = NULL;
PyObject *name; PyObject *name;
_Py_IDENTIFIER(__name__);
if (newobj_str == NULL) {
newobj_str = PyUnicode_InternFromString("__newobj__");
if (newobj_str == NULL)
return -1;
}
name = _PyObject_GetAttrId(callable, &PyId___name__); name = _PyObject_GetAttrId(callable, &PyId___name__);
if (name == NULL) { if (name == NULL) {
if (PyErr_ExceptionMatches(PyExc_AttributeError)) if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
PyErr_Clear();
else
return -1; return -1;
use_newobj = 0; }
PyErr_Clear();
}
else if (self->proto >= 4) {
_Py_IDENTIFIER(__newobj_ex__);
use_newobj_ex = PyUnicode_Check(name) &&
PyUnicode_Compare(
name, _PyUnicode_FromId(&PyId___newobj_ex__)) == 0;
Py_DECREF(name);
} }
else { else {
_Py_IDENTIFIER(__newobj__);
use_newobj = PyUnicode_Check(name) && use_newobj = PyUnicode_Check(name) &&
PyUnicode_Compare(name, newobj_str) == 0; PyUnicode_Compare(
name, _PyUnicode_FromId(&PyId___newobj__)) == 0;
Py_DECREF(name); Py_DECREF(name);
} }
} }
if (use_newobj) {
if (use_newobj_ex) {
PyObject *cls;
PyObject *args;
PyObject *kwargs;
if (Py_SIZE(argtup) != 3) {
PyErr_Format(PicklingError,
"length of the NEWOBJ_EX argument tuple must be "
"exactly 3, not %zd", Py_SIZE(argtup));
return -1;
}
cls = PyTuple_GET_ITEM(argtup, 0);
if (!PyType_Check(cls)) {
PyErr_Format(PicklingError,
"first item from NEWOBJ_EX argument tuple must "
"be a class, not %.200s", Py_TYPE(cls)->tp_name);
return -1;
}
args = PyTuple_GET_ITEM(argtup, 1);
if (!PyTuple_Check(args)) {
PyErr_Format(PicklingError,
"second item from NEWOBJ_EX argument tuple must "
"be a tuple, not %.200s", Py_TYPE(args)->tp_name);
return -1;
}
kwargs = PyTuple_GET_ITEM(argtup, 2);
if (!PyDict_Check(kwargs)) {
PyErr_Format(PicklingError,
"third item from NEWOBJ_EX argument tuple must "
"be a dict, not %.200s", Py_TYPE(kwargs)->tp_name);
return -1;
}
if (save(self, cls, 0) < 0 ||
save(self, args, 0) < 0 ||
save(self, kwargs, 0) < 0 ||
_Pickler_Write(self, &newobj_ex_op, 1) < 0) {
return -1;
}
}
else if (use_newobj) {
PyObject *cls; PyObject *cls;
PyObject *newargtup; PyObject *newargtup;
PyObject *obj_class; PyObject *obj_class;
...@@ -3117,8 +3577,23 @@ save_reduce(PicklerObject *self, PyObject *args, PyObject *obj) ...@@ -3117,8 +3577,23 @@ save_reduce(PicklerObject *self, PyObject *args, PyObject *obj)
the caller do not want to memoize the object. Not particularly useful, the caller do not want to memoize the object. Not particularly useful,
but that is to mimic the behavior save_reduce() in pickle.py when but that is to mimic the behavior save_reduce() in pickle.py when
obj is None. */ obj is None. */
if (obj && memo_put(self, obj) < 0) if (obj != NULL) {
/* If the object is already in the memo, this means it is
recursive. In this case, throw away everything we put on the
stack, and fetch the object back from the memo. */
if (PyMemoTable_Get(self->memo, obj)) {
const char pop_op = POP;
if (_Pickler_Write(self, &pop_op, 1) < 0)
return -1;
if (memo_get(self, obj) < 0)
return -1;
return 0;
}
else if (memo_put(self, obj) < 0)
return -1; return -1;
}
if (listitems && batch_list(self, listitems) < 0) if (listitems && batch_list(self, listitems) < 0)
return -1; return -1;
...@@ -3135,6 +3610,34 @@ save_reduce(PicklerObject *self, PyObject *args, PyObject *obj) ...@@ -3135,6 +3610,34 @@ save_reduce(PicklerObject *self, PyObject *args, PyObject *obj)
return 0; return 0;
} }
static int
save_method(PicklerObject *self, PyObject *obj)
{
PyObject *method_self = PyCFunction_GET_SELF(obj);
if (method_self == NULL || PyModule_Check(method_self)) {
return save_global(self, obj, NULL);
}
else {
PyObject *builtins;
PyObject *getattr;
PyObject *reduce_value;
int status = -1;
_Py_IDENTIFIER(getattr);
builtins = PyEval_GetBuiltins();
getattr = _PyDict_GetItemId(builtins, &PyId_getattr);
reduce_value = \
Py_BuildValue("O(Os)", getattr, method_self,
((PyCFunctionObject *)obj)->m_ml->ml_name);
if (reduce_value != NULL) {
status = save_reduce(self, reduce_value, obj);
Py_DECREF(reduce_value);
}
return status;
}
}
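save_method() above makes bound methods of built-in objects picklable by reducing them to a getattr() call; a sketch of the observable behaviour, assuming the C pickler above is in use:

    import pickle

    lst = [1, 2]
    data = pickle.dumps(lst.append, protocol=4)   # reduced to getattr(<copy of lst>, 'append')
    loaded = pickle.loads(data)
    loaded(3)
    assert loaded.__self__ == [1, 2, 3]   # bound to a copy of the list
    assert lst == [1, 2]                  # the original is untouched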
static int static int
save(PicklerObject *self, PyObject *obj, int pers_save) save(PicklerObject *self, PyObject *obj, int pers_save)
{ {
...@@ -3213,6 +3716,14 @@ save(PicklerObject *self, PyObject *obj, int pers_save) ...@@ -3213,6 +3716,14 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
status = save_dict(self, obj); status = save_dict(self, obj);
goto done; goto done;
} }
else if (type == &PySet_Type) {
status = save_set(self, obj);
goto done;
}
else if (type == &PyFrozenSet_Type) {
status = save_frozenset(self, obj);
goto done;
}
else if (type == &PyList_Type) { else if (type == &PyList_Type) {
status = save_list(self, obj); status = save_list(self, obj);
goto done; goto done;
...@@ -3236,7 +3747,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save) ...@@ -3236,7 +3747,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
} }
} }
else if (type == &PyCFunction_Type) { else if (type == &PyCFunction_Type) {
status = save_global(self, obj, NULL); status = save_method(self, obj);
goto done; goto done;
} }
...@@ -3269,18 +3780,9 @@ save(PicklerObject *self, PyObject *obj, int pers_save) ...@@ -3269,18 +3780,9 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
goto done; goto done;
} }
else { else {
static PyObject *reduce_str = NULL; _Py_IDENTIFIER(__reduce__);
static PyObject *reduce_ex_str = NULL; _Py_IDENTIFIER(__reduce_ex__);
/* Cache the name of the reduce methods. */
if (reduce_str == NULL) {
reduce_str = PyUnicode_InternFromString("__reduce__");
if (reduce_str == NULL)
goto error;
reduce_ex_str = PyUnicode_InternFromString("__reduce_ex__");
if (reduce_ex_str == NULL)
goto error;
}
/* XXX: If the __reduce__ method is defined, __reduce_ex__ is /* XXX: If the __reduce__ method is defined, __reduce_ex__ is
automatically defined as __reduce__. While this is convenient, this automatically defined as __reduce__. While this is convenient, this
...@@ -3291,7 +3793,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save) ...@@ -3291,7 +3793,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
don't actually have to check for a __reduce__ method. */ don't actually have to check for a __reduce__ method. */
/* Check for a __reduce_ex__ method. */ /* Check for a __reduce_ex__ method. */
reduce_func = PyObject_GetAttr(obj, reduce_ex_str); reduce_func = _PyObject_GetAttrId(obj, &PyId___reduce_ex__);
if (reduce_func != NULL) { if (reduce_func != NULL) {
PyObject *proto; PyObject *proto;
proto = PyLong_FromLong(self->proto); proto = PyLong_FromLong(self->proto);
...@@ -3305,7 +3807,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save) ...@@ -3305,7 +3807,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
else else
goto error; goto error;
/* Check for a __reduce__ method. */ /* Check for a __reduce__ method. */
reduce_func = PyObject_GetAttr(obj, reduce_str); reduce_func = _PyObject_GetAttrId(obj, &PyId___reduce__);
if (reduce_func != NULL) { if (reduce_func != NULL) {
reduce_value = PyObject_Call(reduce_func, empty_tuple, NULL); reduce_value = PyObject_Call(reduce_func, empty_tuple, NULL);
} }
...@@ -3338,6 +3840,8 @@ save(PicklerObject *self, PyObject *obj, int pers_save) ...@@ -3338,6 +3840,8 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
status = -1; status = -1;
} }
done: done:
if (status == 0)
status = _Pickler_OpcodeBoundary(self);
Py_LeaveRecursiveCall(); Py_LeaveRecursiveCall();
Py_XDECREF(reduce_func); Py_XDECREF(reduce_func);
Py_XDECREF(reduce_value); Py_XDECREF(reduce_value);
...@@ -3358,6 +3862,8 @@ dump(PicklerObject *self, PyObject *obj) ...@@ -3358,6 +3862,8 @@ dump(PicklerObject *self, PyObject *obj)
header[1] = (unsigned char)self->proto; header[1] = (unsigned char)self->proto;
if (_Pickler_Write(self, header, 2) < 0) if (_Pickler_Write(self, header, 2) < 0)
return -1; return -1;
if (self->proto >= 4)
self->framing = 1;
} }
if (save(self, obj, 0) < 0 || if (save(self, obj, 0) < 0 ||
...@@ -3478,9 +3984,9 @@ PyDoc_STRVAR(Pickler_doc, ...@@ -3478,9 +3984,9 @@ PyDoc_STRVAR(Pickler_doc,
"This takes a binary file for writing a pickle data stream.\n" "This takes a binary file for writing a pickle data stream.\n"
"\n" "\n"
"The optional protocol argument tells the pickler to use the\n" "The optional protocol argument tells the pickler to use the\n"
"given protocol; supported protocols are 0, 1, 2, 3. The default\n" "given protocol; supported protocols are 0, 1, 2, 3 and 4. The\n"
"protocol is 3; a backward-incompatible protocol designed for\n" "default protocol is 3; a backward-incompatible protocol designed for\n"
"Python 3.0.\n" "Python 3.\n"
"\n" "\n"
"Specifying a negative protocol version selects the highest\n" "Specifying a negative protocol version selects the highest\n"
"protocol version supported. The higher the protocol used, the\n" "protocol version supported. The higher the protocol used, the\n"
...@@ -3493,8 +3999,8 @@ PyDoc_STRVAR(Pickler_doc, ...@@ -3493,8 +3999,8 @@ PyDoc_STRVAR(Pickler_doc,
"meets this interface.\n" "meets this interface.\n"
"\n" "\n"
"If fix_imports is True and protocol is less than 3, pickle will try to\n" "If fix_imports is True and protocol is less than 3, pickle will try to\n"
"map the new Python 3.x names to the old module names used in Python\n" "map the new Python 3 names to the old module names used in Python 2,\n"
"2.x, so that the pickle data stream is readable with Python 2.x.\n"); "so that the pickle data stream is readable with Python 2.\n");
static int static int
Pickler_init(PicklerObject *self, PyObject *args, PyObject *kwds) Pickler_init(PicklerObject *self, PyObject *args, PyObject *kwds)
...@@ -3987,17 +4493,15 @@ load_bool(UnpicklerObject *self, PyObject *boolean) ...@@ -3987,17 +4493,15 @@ load_bool(UnpicklerObject *self, PyObject *boolean)
* as a C Py_ssize_t, or -1 if it's higher than PY_SSIZE_T_MAX. * as a C Py_ssize_t, or -1 if it's higher than PY_SSIZE_T_MAX.
*/ */
static Py_ssize_t static Py_ssize_t
calc_binsize(char *bytes, int size) calc_binsize(char *bytes, int nbytes)
{ {
unsigned char *s = (unsigned char *)bytes; unsigned char *s = (unsigned char *)bytes;
int i;
size_t x = 0; size_t x = 0;
assert(size == 4); for (i = 0; i < nbytes; i++) {
x |= (size_t) s[i] << (8 * i);
x = (size_t) s[0]; }
x |= (size_t) s[1] << 8;
x |= (size_t) s[2] << 16;
x |= (size_t) s[3] << 24;
if (x > PY_SSIZE_T_MAX) if (x > PY_SSIZE_T_MAX)
return -1; return -1;
...@@ -4011,21 +4515,21 @@ calc_binsize(char *bytes, int size) ...@@ -4011,21 +4515,21 @@ calc_binsize(char *bytes, int size)
* of x-platform bugs. * of x-platform bugs.
*/ */
static long static long
calc_binint(char *bytes, int size) calc_binint(char *bytes, int nbytes)
{ {
unsigned char *s = (unsigned char *)bytes; unsigned char *s = (unsigned char *)bytes;
int i = size; int i;
long x = 0; long x = 0;
for (i = 0; i < size; i++) { for (i = 0; i < nbytes; i++) {
x |= (long)s[i] << (i * 8); x |= (long)s[i] << (8 * i);
} }
/* Unlike BININT1 and BININT2, BININT (more accurately BININT4) /* Unlike BININT1 and BININT2, BININT (more accurately BININT4)
* is signed, so on a box with longs bigger than 4 bytes we need * is signed, so on a box with longs bigger than 4 bytes we need
* to extend a BININT's sign bit to the full width. * to extend a BININT's sign bit to the full width.
*/ */
if (SIZEOF_LONG > 4 && size == 4) { if (SIZEOF_LONG > 4 && nbytes == 4) {
x |= -(x & (1L << 31)); x |= -(x & (1L << 31));
} }
...@@ -4233,49 +4737,27 @@ load_string(UnpicklerObject *self) ...@@ -4233,49 +4737,27 @@ load_string(UnpicklerObject *self)
} }
static int static int
load_binbytes(UnpicklerObject *self) load_counted_binbytes(UnpicklerObject *self, int nbytes)
{ {
PyObject *bytes; PyObject *bytes;
Py_ssize_t x; Py_ssize_t size;
char *s; char *s;
if (_Unpickler_Read(self, &s, 4) < 0) if (_Unpickler_Read(self, &s, nbytes) < 0)
return -1; return -1;
x = calc_binsize(s, 4); size = calc_binsize(s, nbytes);
if (x < 0) { if (size < 0) {
PyErr_Format(PyExc_OverflowError, PyErr_Format(PyExc_OverflowError,
"BINBYTES exceeds system's maximum size of %zd bytes", "BINBYTES exceeds system's maximum size of %zd bytes",
PY_SSIZE_T_MAX); PY_SSIZE_T_MAX);
return -1; return -1;
} }
if (_Unpickler_Read(self, &s, x) < 0) if (_Unpickler_Read(self, &s, size) < 0)
return -1;
bytes = PyBytes_FromStringAndSize(s, x);
if (bytes == NULL)
return -1;
PDATA_PUSH(self->stack, bytes, -1);
return 0;
}
static int
load_short_binbytes(UnpicklerObject *self)
{
PyObject *bytes;
Py_ssize_t x;
char *s;
if (_Unpickler_Read(self, &s, 1) < 0)
return -1;
x = (unsigned char)s[0];
if (_Unpickler_Read(self, &s, x) < 0)
return -1; return -1;
bytes = PyBytes_FromStringAndSize(s, x); bytes = PyBytes_FromStringAndSize(s, size);
if (bytes == NULL) if (bytes == NULL)
return -1; return -1;
...@@ -4284,51 +4766,27 @@ load_short_binbytes(UnpicklerObject *self) ...@@ -4284,51 +4766,27 @@ load_short_binbytes(UnpicklerObject *self)
} }
static int static int
load_binstring(UnpicklerObject *self) load_counted_binstring(UnpicklerObject *self, int nbytes)
{ {
PyObject *str; PyObject *str;
Py_ssize_t x; Py_ssize_t size;
char *s; char *s;
if (_Unpickler_Read(self, &s, 4) < 0) if (_Unpickler_Read(self, &s, nbytes) < 0)
return -1; return -1;
x = calc_binint(s, 4); size = calc_binsize(s, nbytes);
if (x < 0) { if (size < 0) {
PyErr_SetString(UnpicklingError, PyErr_Format(UnpicklingError,
"BINSTRING pickle has negative byte count"); "BINSTRING exceeds system's maximum size of %zd bytes",
PY_SSIZE_T_MAX);
return -1; return -1;
} }
if (_Unpickler_Read(self, &s, x) < 0) if (_Unpickler_Read(self, &s, size) < 0)
return -1;
/* Convert Python 2.x strings to unicode. */
str = PyUnicode_Decode(s, x, self->encoding, self->errors);
if (str == NULL)
return -1;
PDATA_PUSH(self->stack, str, -1);
return 0;
}
static int
load_short_binstring(UnpicklerObject *self)
{
PyObject *str;
Py_ssize_t x;
char *s;
if (_Unpickler_Read(self, &s, 1) < 0)
return -1;
x = (unsigned char)s[0];
if (_Unpickler_Read(self, &s, x) < 0)
return -1; return -1;
/* Convert Python 2.x strings to unicode. */ /* Convert Python 2.x strings to unicode. */
str = PyUnicode_Decode(s, x, self->encoding, self->errors); str = PyUnicode_Decode(s, size, self->encoding, self->errors);
if (str == NULL) if (str == NULL)
return -1; return -1;
...@@ -4357,16 +4815,16 @@ load_unicode(UnpicklerObject *self) ...@@ -4357,16 +4815,16 @@ load_unicode(UnpicklerObject *self)
} }
static int static int
load_binunicode(UnpicklerObject *self) load_counted_binunicode(UnpicklerObject *self, int nbytes)
{ {
PyObject *str; PyObject *str;
Py_ssize_t size; Py_ssize_t size;
char *s; char *s;
if (_Unpickler_Read(self, &s, 4) < 0) if (_Unpickler_Read(self, &s, nbytes) < 0)
return -1; return -1;
size = calc_binsize(s, 4); size = calc_binsize(s, nbytes);
if (size < 0) { if (size < 0) {
PyErr_Format(PyExc_OverflowError, PyErr_Format(PyExc_OverflowError,
"BINUNICODE exceeds system's maximum size of %zd bytes", "BINUNICODE exceeds system's maximum size of %zd bytes",
...@@ -4374,7 +4832,6 @@ load_binunicode(UnpicklerObject *self) ...@@ -4374,7 +4832,6 @@ load_binunicode(UnpicklerObject *self)
return -1; return -1;
} }
if (_Unpickler_Read(self, &s, size) < 0) if (_Unpickler_Read(self, &s, size) < 0)
return -1; return -1;
...@@ -4445,6 +4902,17 @@ load_empty_dict(UnpicklerObject *self) ...@@ -4445,6 +4902,17 @@ load_empty_dict(UnpicklerObject *self)
return 0; return 0;
} }
static int
load_empty_set(UnpicklerObject *self)
{
PyObject *set;
if ((set = PySet_New(NULL)) == NULL)
return -1;
PDATA_PUSH(self->stack, set, -1);
return 0;
}
static int static int
load_list(UnpicklerObject *self) load_list(UnpicklerObject *self)
{ {
...@@ -4487,6 +4955,29 @@ load_dict(UnpicklerObject *self) ...@@ -4487,6 +4955,29 @@ load_dict(UnpicklerObject *self)
return 0; return 0;
} }
static int
load_frozenset(UnpicklerObject *self)
{
PyObject *items;
PyObject *frozenset;
Py_ssize_t i;
if ((i = marker(self)) < 0)
return -1;
items = Pdata_poptuple(self->stack, i);
if (items == NULL)
return -1;
frozenset = PyFrozenSet_New(items);
Py_DECREF(items);
if (frozenset == NULL)
return -1;
PDATA_PUSH(self->stack, frozenset, -1);
return 0;
}
static PyObject * static PyObject *
instantiate(PyObject *cls, PyObject *args) instantiate(PyObject *cls, PyObject *args)
{ {
...@@ -4637,6 +5128,57 @@ load_newobj(UnpicklerObject *self) ...@@ -4637,6 +5128,57 @@ load_newobj(UnpicklerObject *self)
return -1; return -1;
} }
static int
load_newobj_ex(UnpicklerObject *self)
{
PyObject *cls, *args, *kwargs;
PyObject *obj;
PDATA_POP(self->stack, kwargs);
if (kwargs == NULL) {
return -1;
}
PDATA_POP(self->stack, args);
if (args == NULL) {
Py_DECREF(kwargs);
return -1;
}
PDATA_POP(self->stack, cls);
if (cls == NULL) {
Py_DECREF(kwargs);
Py_DECREF(args);
return -1;
}
if (!PyType_Check(cls)) {
Py_DECREF(kwargs);
Py_DECREF(args);
Py_DECREF(cls);
PyErr_Format(UnpicklingError,
"NEWOBJ_EX class argument must be a type, not %.200s",
Py_TYPE(cls)->tp_name);
return -1;
}
if (((PyTypeObject *)cls)->tp_new == NULL) {
Py_DECREF(kwargs);
Py_DECREF(args);
Py_DECREF(cls);
PyErr_SetString(UnpicklingError,
"NEWOBJ_EX class argument doesn't have __new__");
return -1;
}
obj = ((PyTypeObject *)cls)->tp_new((PyTypeObject *)cls, args, kwargs);
Py_DECREF(kwargs);
Py_DECREF(args);
Py_DECREF(cls);
if (obj == NULL) {
return -1;
}
PDATA_PUSH(self->stack, obj, -1);
return 0;
}
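Seen from Python, NEWOBJ_EX is what carries keyword-only constructor arguments across a pickle; a minimal sketch (hypothetical class, assuming the default __reduce_ex__ wiring for __getnewargs_ex__ added elsewhere in this commit):

    import pickle

    class Point:
        def __new__(cls, *, x, y):                 # keyword-only arguments to __new__
            self = super().__new__(cls)
            self.x, self.y = x, y
            return self
        def __getnewargs_ex__(self):
            return (), {'x': self.x, 'y': self.y}

    p = pickle.loads(pickle.dumps(Point(x=1, y=2), protocol=4))
    assert (p.x, p.y) == (1, 2)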
static int static int
load_global(UnpicklerObject *self) load_global(UnpicklerObject *self)
{ {
...@@ -4673,6 +5215,31 @@ load_global(UnpicklerObject *self) ...@@ -4673,6 +5215,31 @@ load_global(UnpicklerObject *self)
return 0; return 0;
} }
static int
load_stack_global(UnpicklerObject *self)
{
PyObject *global;
PyObject *module_name;
PyObject *global_name;
PDATA_POP(self->stack, global_name);
PDATA_POP(self->stack, module_name);
if (module_name == NULL || !PyUnicode_CheckExact(module_name) ||
global_name == NULL || !PyUnicode_CheckExact(global_name)) {
PyErr_SetString(UnpicklingError, "STACK_GLOBAL requires str");
Py_XDECREF(global_name);
Py_XDECREF(module_name);
return -1;
}
global = find_class(self, module_name, global_name);
Py_DECREF(global_name);
Py_DECREF(module_name);
if (global == NULL)
return -1;
PDATA_PUSH(self->stack, global, -1);
return 0;
}
static int static int
load_persid(UnpicklerObject *self) load_persid(UnpicklerObject *self)
{ {
...@@ -5016,6 +5583,18 @@ load_long_binput(UnpicklerObject *self) ...@@ -5016,6 +5583,18 @@ load_long_binput(UnpicklerObject *self)
return _Unpickler_MemoPut(self, idx, value); return _Unpickler_MemoPut(self, idx, value);
} }
static int
load_memoize(UnpicklerObject *self)
{
PyObject *value;
if (Py_SIZE(self->stack) <= 0)
return stack_underflow();
value = self->stack->data[Py_SIZE(self->stack) - 1];
return _Unpickler_MemoPut(self, self->memo_len, value);
}
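MEMOIZE pairs with the pickler-side memo_put() earlier in this file: the memo index is implicit (the current memo length), so a single byte replaces the 2- or 5-byte PUT variants. A tiny illustration of the per-object cost (illustrative only):

    def memo_opcode_size(idx, proto):
        # protocol >= 4: MEMOIZE, the index is implicitly len(memo)
        if proto >= 4:
            return 1
        # older binary protocols: BINPUT (1-byte index) or LONG_BINPUT (4-byte index)
        return 2 if idx < 256 else 5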
static int
do_append(UnpicklerObject *self, Py_ssize_t x)
{
@@ -5131,6 +5710,59 @@ load_setitems(UnpicklerObject *self)
return do_setitems(self, marker(self));
}
static int
load_additems(UnpicklerObject *self)
{
PyObject *set;
Py_ssize_t mark, len, i;
mark = marker(self);
len = Py_SIZE(self->stack);
if (mark > len || mark <= 0)
return stack_underflow();
if (len == mark) /* nothing to do */
return 0;
set = self->stack->data[mark - 1];
if (PySet_Check(set)) {
PyObject *items;
int status;
items = Pdata_poptuple(self->stack, mark);
if (items == NULL)
return -1;
status = _PySet_Update(set, items);
Py_DECREF(items);
return status;
}
else {
PyObject *add_func;
_Py_IDENTIFIER(add);
add_func = _PyObject_GetAttrId(set, &PyId_add);
if (add_func == NULL)
return -1;
for (i = mark; i < len; i++) {
PyObject *result;
PyObject *item;
item = self->stack->data[i];
result = _Unpickler_FastCall(self, add_func, item);
if (result == NULL) {
Pdata_clear(self->stack, i + 1);
Py_SIZE(self->stack) = mark;
return -1;
}
Py_DECREF(result);
}
Py_SIZE(self->stack) = mark;
}
return 0;
}
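
ADDITEMS pops everything pushed since the last MARK and adds it to the set sitting just below the mark; real sets (including subclasses) are updated in bulk, while any other object below the mark is fed one item at a time through its add method. A rough Python equivalent of the loader (a hypothetical helper, for illustration only):

def additems(stack, mark):
    items = stack[mark:]              # everything pushed since the MARK
    del stack[mark:]
    target = stack[-1]                # e.g. the set built by EMPTY_SET
    if isinstance(target, set):
        target.update(items)          # bulk fast path (mirrors _PySet_Update)
    else:
        for item in items:            # generic path: anything with an add()
            target.add(item)

stack = [set(), 1, 2, 3]
additems(stack, mark=1)
assert stack == [{1, 2, 3}]
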
static int
load_build(UnpicklerObject *self)
{
@@ -5325,6 +5957,7 @@ load_proto(UnpicklerObject *self)
i = (unsigned char)s[0];
if (i <= HIGHEST_PROTOCOL) {
self->proto = i;
+self->framing = (self->proto >= 4);
return 0;
}
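
The new framing flag is switched on as soon as the PROTO opcode announces protocol 4 or higher: from that point the stream is read through length-prefixed frames (the FRAME opcode), which lets the unpickler issue fewer, larger reads. The frames are visible with pickletools (a sketch, assuming CPython 3.4+; very small payloads may be framed differently):

import pickle
import pickletools

data = pickle.dumps(list(range(100)), protocol=4)
ops = [op.name for op, arg, pos in pickletools.genops(data)]
assert ops[0] == "PROTO"      # announces protocol 4 ...
assert "FRAME" in ops         # ... after which the payload is framed
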
@@ -5340,6 +5973,8 @@ load(UnpicklerObject *self)
char *s;
self->num_marks = 0;
+self->proto = 0;
+self->framing = 0;
if (Py_SIZE(self->stack))
Pdata_clear(self->stack, 0);
@@ -5365,13 +6000,16 @@ load(UnpicklerObject *self)
OP_ARG(LONG4, load_counted_long, 4)
OP(FLOAT, load_float)
OP(BINFLOAT, load_binfloat)
-OP(BINBYTES, load_binbytes)
-OP(SHORT_BINBYTES, load_short_binbytes)
-OP(BINSTRING, load_binstring)
-OP(SHORT_BINSTRING, load_short_binstring)
+OP_ARG(SHORT_BINBYTES, load_counted_binbytes, 1)
+OP_ARG(BINBYTES, load_counted_binbytes, 4)
+OP_ARG(BINBYTES8, load_counted_binbytes, 8)
+OP_ARG(SHORT_BINSTRING, load_counted_binstring, 1)
+OP_ARG(BINSTRING, load_counted_binstring, 4)
OP(STRING, load_string)
OP(UNICODE, load_unicode)
-OP(BINUNICODE, load_binunicode)
+OP_ARG(SHORT_BINUNICODE, load_counted_binunicode, 1)
+OP_ARG(BINUNICODE, load_counted_binunicode, 4)
+OP_ARG(BINUNICODE8, load_counted_binunicode, 8)
OP_ARG(EMPTY_TUPLE, load_counted_tuple, 0)
OP_ARG(TUPLE1, load_counted_tuple, 1)
OP_ARG(TUPLE2, load_counted_tuple, 2)
@@ -5381,10 +6019,15 @@ load(UnpicklerObject *self)
OP(LIST, load_list)
OP(EMPTY_DICT, load_empty_dict)
OP(DICT, load_dict)
+OP(EMPTY_SET, load_empty_set)
+OP(ADDITEMS, load_additems)
+OP(FROZENSET, load_frozenset)
OP(OBJ, load_obj)
OP(INST, load_inst)
OP(NEWOBJ, load_newobj)
+OP(NEWOBJ_EX, load_newobj_ex)
OP(GLOBAL, load_global)
+OP(STACK_GLOBAL, load_stack_global)
OP(APPEND, load_append)
OP(APPENDS, load_appends)
OP(BUILD, load_build)
@@ -5396,6 +6039,7 @@ load(UnpicklerObject *self)
OP(BINPUT, load_binput)
OP(LONG_BINPUT, load_long_binput)
OP(PUT, load_put)
+OP(MEMOIZE, load_memoize)
OP(POP, load_pop)
OP(POP_MARK, load_pop_mark)
OP(SETITEM, load_setitem)
@@ -5485,6 +6129,7 @@ Unpickler_find_class(UnpicklerObject *self, PyObject *args)
PyObject *modules_dict;
PyObject *module;
PyObject *module_name, *global_name;
+_Py_IDENTIFIER(modules);
if (!PyArg_UnpackTuple(args, "find_class", 2, 2,
&module_name, &global_name))
@@ -5556,11 +6201,11 @@ Unpickler_find_class(UnpicklerObject *self, PyObject *args)
module = PyImport_Import(module_name);
if (module == NULL)
return NULL;
-global = PyObject_GetAttr(module, global_name);
+global = getattribute(module, global_name, self->proto >= 4);
Py_DECREF(module);
}
else {
-global = PyObject_GetAttr(module, global_name);
+global = getattribute(module, global_name, self->proto >= 4);
}
return global;
}
@@ -5723,6 +6368,7 @@ Unpickler_init(UnpicklerObject *self, PyObject *args, PyObject *kwds)
self->arg = NULL;
self->proto = 0;
+self->framing = 0;
return 0;
}
......
@@ -69,6 +69,30 @@ PyMethod_New(PyObject *func, PyObject *self)
return (PyObject *)im;
}
static PyObject *
method_reduce(PyMethodObject *im)
{
PyObject *self = PyMethod_GET_SELF(im);
PyObject *func = PyMethod_GET_FUNCTION(im);
PyObject *builtins;
PyObject *getattr;
PyObject *funcname;
_Py_IDENTIFIER(getattr);
funcname = _PyObject_GetAttrId(func, &PyId___name__);
if (funcname == NULL) {
return NULL;
}
builtins = PyEval_GetBuiltins();
getattr = _PyDict_GetItemId(builtins, &PyId_getattr);
return Py_BuildValue("O(ON)", getattr, self, funcname);
}
static PyMethodDef method_methods[] = {
{"__reduce__", (PyCFunction)method_reduce, METH_NOARGS, NULL},
{NULL, NULL}
};
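
This __reduce__ makes bound methods picklable: the pair (getattr, (self, name)) is stored, and unpickling simply looks the method up again on the reconstructed instance. A short sketch of the observable behaviour (the Greeter class is made up; the instance itself must of course be picklable):

import pickle

class Greeter:
    def __init__(self, name):
        self.name = name

    def hello(self):
        return "hello, " + self.name

bound = Greeter("world").hello
func, args = bound.__reduce__()
assert func is getattr and args[1] == "hello"   # (getattr, (<Greeter>, 'hello'))

clone = pickle.loads(pickle.dumps(bound))       # re-bound on the copied instance
assert clone() == "hello, world"
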
/* Descriptors for PyMethod attributes */
/* im_func and im_self are stored in the PyMethod object */
@@ -367,7 +391,7 @@ PyTypeObject PyMethod_Type = {
offsetof(PyMethodObject, im_weakreflist), /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
-0, /* tp_methods */
+method_methods, /* tp_methods */
method_memberlist, /* tp_members */
method_getset, /* tp_getset */
0, /* tp_base */
......
@@ -398,6 +398,24 @@ descr_get_qualname(PyDescrObject *descr)
return descr->d_qualname;
}
static PyObject *
descr_reduce(PyDescrObject *descr)
{
PyObject *builtins;
PyObject *getattr;
_Py_IDENTIFIER(getattr);
builtins = PyEval_GetBuiltins();
getattr = _PyDict_GetItemId(builtins, &PyId_getattr);
return Py_BuildValue("O(OO)", getattr, PyDescr_TYPE(descr),
PyDescr_NAME(descr));
}
static PyMethodDef descr_methods[] = {
{"__reduce__", (PyCFunction)descr_reduce, METH_NOARGS, NULL},
{NULL, NULL}
};
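
Descriptor objects get the same treatment: method descriptors, slot wrappers and member descriptors reduce to getattr(objclass, name), which makes references like list.append picklable. For example (a sketch of the observable behaviour, assuming CPython 3.4+):

import pickle

for descr in (list.append, int.__add__):
    func, args = descr.__reduce__()
    assert func is getattr
    assert args == (descr.__objclass__, descr.__name__)

assert pickle.loads(pickle.dumps(list.append)) is list.append
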
static PyMemberDef descr_members[] = {
{"__objclass__", T_OBJECT, offsetof(PyDescrObject, d_type), READONLY},
{"__name__", T_OBJECT, offsetof(PyDescrObject, d_name), READONLY},
@@ -494,7 +512,7 @@ PyTypeObject PyMethodDescr_Type = {
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
-0, /* tp_methods */
+descr_methods, /* tp_methods */
descr_members, /* tp_members */
method_getset, /* tp_getset */
0, /* tp_base */
@@ -532,7 +550,7 @@ PyTypeObject PyClassMethodDescr_Type = {
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
-0, /* tp_methods */
+descr_methods, /* tp_methods */
descr_members, /* tp_members */
method_getset, /* tp_getset */
0, /* tp_base */
@@ -569,7 +587,7 @@ PyTypeObject PyMemberDescr_Type = {
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
-0, /* tp_methods */
+descr_methods, /* tp_methods */
descr_members, /* tp_members */
member_getset, /* tp_getset */
0, /* tp_base */
@@ -643,7 +661,7 @@ PyTypeObject PyWrapperDescr_Type = {
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
-0, /* tp_methods */
+descr_methods, /* tp_methods */
descr_members, /* tp_members */
wrapperdescr_getset, /* tp_getset */
0, /* tp_base */
@@ -1085,6 +1103,23 @@ wrapper_repr(wrapperobject *wp)
wp->self);
}
static PyObject *
wrapper_reduce(wrapperobject *wp)
{
PyObject *builtins;
PyObject *getattr;
_Py_IDENTIFIER(getattr);
builtins = PyEval_GetBuiltins();
getattr = _PyDict_GetItemId(builtins, &PyId_getattr);
return Py_BuildValue("O(OO)", getattr, wp->self, PyDescr_NAME(wp->descr));
}
static PyMethodDef wrapper_methods[] = {
{"__reduce__", (PyCFunction)wrapper_reduce, METH_NOARGS, NULL},
{NULL, NULL}
};
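
And the bound counterpart, a method-wrapper such as (1).__add__, reduces to getattr(instance, name) in the same way (a sketch, assuming CPython 3.4+):

import pickle

bound = (1).__add__                    # a 'method-wrapper' object
func, args = bound.__reduce__()
assert func is getattr and args == (1, "__add__")

clone = pickle.loads(pickle.dumps(bound))
assert clone(2) == 3
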
static PyMemberDef wrapper_members[] = {
{"__self__", T_OBJECT, offsetof(wrapperobject, self), READONLY},
{0}
@@ -1193,7 +1228,7 @@ PyTypeObject _PyMethodWrapper_Type = {
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
-0, /* tp_methods */
+wrapper_methods, /* tp_methods */
wrapper_members, /* tp_members */
wrapper_getsets, /* tp_getset */
0, /* tp_base */
......
@@ -3405,150 +3405,429 @@ import_copyreg(void)
return cached_copyreg_module;
}
-static PyObject *
-slotnames(PyObject *cls)
+Py_LOCAL(PyObject *)
+_PyType_GetSlotNames(PyTypeObject *cls)
{
-PyObject *clsdict;
PyObject *copyreg; PyObject *copyreg;
PyObject *slotnames; PyObject *slotnames;
_Py_IDENTIFIER(__slotnames__); _Py_IDENTIFIER(__slotnames__);
_Py_IDENTIFIER(_slotnames); _Py_IDENTIFIER(_slotnames);
clsdict = ((PyTypeObject *)cls)->tp_dict; assert(PyType_Check(cls));
slotnames = _PyDict_GetItemId(clsdict, &PyId___slotnames__);
if (slotnames != NULL && PyList_Check(slotnames)) { /* Get the slot names from the cache in the class if possible. */
slotnames = _PyDict_GetItemIdWithError(cls->tp_dict, &PyId___slotnames__);
if (slotnames != NULL) {
if (slotnames != Py_None && !PyList_Check(slotnames)) {
PyErr_Format(PyExc_TypeError,
"%.200s.__slotnames__ should be a list or None, "
"not %.200s",
cls->tp_name, Py_TYPE(slotnames)->tp_name);
return NULL;
}
Py_INCREF(slotnames); Py_INCREF(slotnames);
return slotnames; return slotnames;
} }
else {
if (PyErr_Occurred()) {
return NULL;
}
/* The class does not have the slot names cached yet. */
}
copyreg = import_copyreg(); copyreg = import_copyreg();
if (copyreg == NULL) if (copyreg == NULL)
return NULL; return NULL;
slotnames = _PyObject_CallMethodId(copyreg, &PyId__slotnames, "O", cls); /* Use _slotnames function from the copyreg module to find the slots
by this class and its bases. This function will cache the result
in __slotnames__. */
slotnames = _PyObject_CallMethodIdObjArgs(copyreg, &PyId__slotnames,
cls, NULL);
Py_DECREF(copyreg); Py_DECREF(copyreg);
if (slotnames != NULL && if (slotnames == NULL)
slotnames != Py_None && return NULL;
!PyList_Check(slotnames))
{ if (slotnames != Py_None && !PyList_Check(slotnames)) {
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError,
"copyreg._slotnames didn't return a list or None"); "copyreg._slotnames didn't return a list or None");
Py_DECREF(slotnames); Py_DECREF(slotnames);
slotnames = NULL; return NULL;
} }
return slotnames; return slotnames;
} }
-static PyObject *
-reduce_2(PyObject *obj)
+Py_LOCAL(PyObject *)
+_PyObject_GetState(PyObject *obj)
{
-PyObject *cls, *getnewargs;
-PyObject *args = NULL, *args2 = NULL;
-PyObject *getstate = NULL, *state = NULL, *names = NULL;
-PyObject *slots = NULL, *listitems = NULL, *dictitems = NULL;
-PyObject *copyreg = NULL, *newobj = NULL, *res = NULL;
-Py_ssize_t i, n;
-_Py_IDENTIFIER(__getnewargs__);
+PyObject *state;
+PyObject *getstate;
_Py_IDENTIFIER(__getstate__);
-_Py_IDENTIFIER(__newobj__);
-cls = (PyObject *) Py_TYPE(obj);
+getstate = _PyObject_GetAttrId(obj, &PyId___getstate__);
if (getstate == NULL) {
PyObject *slotnames;
getnewargs = _PyObject_GetAttrId(obj, &PyId___getnewargs__); if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
if (getnewargs != NULL) { return NULL;
args = PyObject_CallObject(getnewargs, NULL);
Py_DECREF(getnewargs);
if (args != NULL && !PyTuple_Check(args)) {
PyErr_Format(PyExc_TypeError,
"__getnewargs__ should return a tuple, "
"not '%.200s'", Py_TYPE(args)->tp_name);
goto end;
}
} }
else {
PyErr_Clear(); PyErr_Clear();
args = PyTuple_New(0);
}
if (args == NULL)
goto end;
getstate = _PyObject_GetAttrId(obj, &PyId___getstate__); {
if (getstate != NULL) {
state = PyObject_CallObject(getstate, NULL);
Py_DECREF(getstate);
if (state == NULL)
goto end;
}
else {
PyObject **dict; PyObject **dict;
PyErr_Clear();
dict = _PyObject_GetDictPtr(obj); dict = _PyObject_GetDictPtr(obj);
if (dict && *dict) /* It is possible that the object's dict is not initialized
yet. In this case, we will return None for the state.
We also return None if the dict is empty to make the behavior
consistent regardless whether the dict was initialized or not.
This make unit testing easier. */
if (dict != NULL && *dict != NULL && PyDict_Size(*dict) > 0) {
state = *dict; state = *dict;
else }
else {
state = Py_None; state = Py_None;
}
Py_INCREF(state); Py_INCREF(state);
names = slotnames(cls); }
if (names == NULL)
goto end; slotnames = _PyType_GetSlotNames(Py_TYPE(obj));
if (names != Py_None && PyList_GET_SIZE(names) > 0) { if (slotnames == NULL) {
assert(PyList_Check(names)); Py_DECREF(state);
return NULL;
}
assert(slotnames == Py_None || PyList_Check(slotnames));
if (slotnames != Py_None && Py_SIZE(slotnames) > 0) {
PyObject *slots;
Py_ssize_t slotnames_size, i;
slots = PyDict_New(); slots = PyDict_New();
if (slots == NULL) if (slots == NULL) {
goto end; Py_DECREF(slotnames);
n = 0; Py_DECREF(state);
/* Can't pre-compute the list size; the list return NULL;
is stored on the class so accessible to other }
threads, which may be run by DECREF */
for (i = 0; i < PyList_GET_SIZE(names); i++) { slotnames_size = Py_SIZE(slotnames);
for (i = 0; i < slotnames_size; i++) {
PyObject *name, *value; PyObject *name, *value;
name = PyList_GET_ITEM(names, i);
name = PyList_GET_ITEM(slotnames, i);
value = PyObject_GetAttr(obj, name); value = PyObject_GetAttr(obj, name);
if (value == NULL) if (value == NULL) {
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
goto error;
}
/* It is not an error if the attribute is not present. */
PyErr_Clear(); PyErr_Clear();
}
else { else {
int err = PyDict_SetItem(slots, name, int err = PyDict_SetItem(slots, name, value);
value);
Py_DECREF(value); Py_DECREF(value);
if (err) if (err) {
goto end; goto error;
n++; }
} }
/* The list is stored on the class so it may mutate while we
iterate over it */
if (slotnames_size != Py_SIZE(slotnames)) {
PyErr_Format(PyExc_RuntimeError,
"__slotnames__ changed size during iteration");
goto error;
}
/* We handle errors within the loop here. */
if (0) {
error:
Py_DECREF(slotnames);
Py_DECREF(slots);
Py_DECREF(state);
return NULL;
}
}
/* If we found some slot attributes, pack them in a tuple along
with the original attribute dictionary. */
if (PyDict_Size(slots) > 0) {
PyObject *state2;
state2 = PyTuple_Pack(2, state, slots);
Py_DECREF(state);
if (state2 == NULL) {
Py_DECREF(slotnames);
Py_DECREF(slots);
return NULL;
} }
if (n) { state = state2;
state = Py_BuildValue("(NO)", state, slots); }
Py_DECREF(slots);
}
Py_DECREF(slotnames);
}
else { /* getstate != NULL */
state = PyObject_CallObject(getstate, NULL);
Py_DECREF(getstate);
if (state == NULL) if (state == NULL)
goto end; return NULL;
}
return state;
}
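
In Python terms, the state gathered here is the instance __dict__ (or None when that is absent or empty), optionally paired with a dictionary of __slots__ values when the class defines slots; unset slots are simply skipped rather than treated as errors. A sketch of what this looks like through the reduce protocol (example class invented for illustration):

class Slotted:
    __slots__ = ("x", "y")

    def __init__(self, x):
        self.x = x                     # 'y' is deliberately left unset

state = Slotted(1).__reduce_ex__(2)[2]
# No instance __dict__ on a __slots__-only class, so the first half is None;
# the unset slot 'y' does not appear in the second half.
assert state == (None, {"x": 1})
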
Py_LOCAL(int)
_PyObject_GetNewArguments(PyObject *obj, PyObject **args, PyObject **kwargs)
{
PyObject *getnewargs, *getnewargs_ex;
_Py_IDENTIFIER(__getnewargs_ex__);
_Py_IDENTIFIER(__getnewargs__);
if (args == NULL || kwargs == NULL) {
PyErr_BadInternalCall();
return -1;
} }
/* We first attempt to fetch the arguments for __new__ by calling
__getnewargs_ex__ on the object. */
getnewargs_ex = _PyObject_GetAttrId(obj, &PyId___getnewargs_ex__);
if (getnewargs_ex != NULL) {
PyObject *newargs = PyObject_CallObject(getnewargs_ex, NULL);
Py_DECREF(getnewargs_ex);
if (newargs == NULL) {
return -1;
}
if (!PyTuple_Check(newargs)) {
PyErr_Format(PyExc_TypeError,
"__getnewargs_ex__ should return a tuple, "
"not '%.200s'", Py_TYPE(newargs)->tp_name);
Py_DECREF(newargs);
return -1;
} }
if (Py_SIZE(newargs) != 2) {
PyErr_Format(PyExc_ValueError,
"__getnewargs_ex__ should return a tuple of "
"length 2, not %zd", Py_SIZE(newargs));
Py_DECREF(newargs);
return -1;
}
*args = PyTuple_GET_ITEM(newargs, 0);
Py_INCREF(*args);
*kwargs = PyTuple_GET_ITEM(newargs, 1);
Py_INCREF(*kwargs);
Py_DECREF(newargs);
/* XXX We should perhaps allow None to be passed here. */
if (!PyTuple_Check(*args)) {
PyErr_Format(PyExc_TypeError,
"first item of the tuple returned by "
"__getnewargs_ex__ must be a tuple, not '%.200s'",
Py_TYPE(*args)->tp_name);
Py_CLEAR(*args);
Py_CLEAR(*kwargs);
return -1;
}
if (!PyDict_Check(*kwargs)) {
PyErr_Format(PyExc_TypeError,
"second item of the tuple returned by "
"__getnewargs_ex__ must be a dict, not '%.200s'",
Py_TYPE(*kwargs)->tp_name);
Py_CLEAR(*args);
Py_CLEAR(*kwargs);
return -1;
}
return 0;
} else {
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
return -1;
}
PyErr_Clear();
}
/* The object does not have __getnewargs_ex__ so we fallback on using
__getnewargs__ instead. */
getnewargs = _PyObject_GetAttrId(obj, &PyId___getnewargs__);
if (getnewargs != NULL) {
*args = PyObject_CallObject(getnewargs, NULL);
Py_DECREF(getnewargs);
if (*args == NULL) {
return -1;
}
if (!PyTuple_Check(*args)) {
PyErr_Format(PyExc_TypeError,
"__getnewargs__ should return a tuple, "
"not '%.200s'", Py_TYPE(*args)->tp_name);
Py_CLEAR(*args);
return -1;
}
*kwargs = NULL;
return 0;
} else {
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
return -1;
}
PyErr_Clear();
}
/* The object has neither __getnewargs_ex__ nor __getnewargs__. This may
mean __new__ does not take any arguments on this object, or that the
object does not implement the reduce protocol for pickling or
copying. */
*args = NULL;
*kwargs = NULL;
return 0;
}
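
The lookup order implemented here is: __getnewargs_ex__ first, then __getnewargs__ (positional arguments only), and finally no arguments at all; the returned values are type-checked before they are trusted. Sketched from the caller's side (example classes invented for illustration):

class WithEx:
    def __getnewargs_ex__(self):
        return ("positional",), {"kw": 1}

class Broken:
    def __getnewargs_ex__(self):
        return "not a tuple"           # wrong type, rejected before use

# __reduce_ex__(4) embeds (cls, args, kwargs); args and kwargs come straight
# from __getnewargs_ex__.
assert WithEx().__reduce_ex__(4)[1][1:] == (("positional",), {"kw": 1})

try:
    Broken().__reduce_ex__(4)
except TypeError as exc:
    assert "__getnewargs_ex__" in str(exc)
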
Py_LOCAL(int)
_PyObject_GetItemsIter(PyObject *obj, PyObject **listitems,
PyObject **dictitems)
{
if (listitems == NULL || dictitems == NULL) {
PyErr_BadInternalCall();
return -1;
} }
if (!PyList_Check(obj)) { if (!PyList_Check(obj)) {
listitems = Py_None; *listitems = Py_None;
Py_INCREF(listitems); Py_INCREF(*listitems);
} }
else { else {
listitems = PyObject_GetIter(obj); *listitems = PyObject_GetIter(obj);
if (listitems == NULL) if (listitems == NULL)
goto end; return -1;
} }
if (!PyDict_Check(obj)) { if (!PyDict_Check(obj)) {
dictitems = Py_None; *dictitems = Py_None;
Py_INCREF(dictitems); Py_INCREF(*dictitems);
} }
else { else {
PyObject *items;
_Py_IDENTIFIER(items); _Py_IDENTIFIER(items);
PyObject *items = _PyObject_CallMethodId(obj, &PyId_items, "");
if (items == NULL) items = _PyObject_CallMethodIdObjArgs(obj, &PyId_items, NULL);
goto end; if (items == NULL) {
dictitems = PyObject_GetIter(items); Py_CLEAR(*listitems);
return -1;
}
*dictitems = PyObject_GetIter(items);
Py_DECREF(items); Py_DECREF(items);
if (dictitems == NULL) if (*dictitems == NULL) {
goto end; Py_CLEAR(*listitems);
return -1;
}
}
assert(*listitems != NULL && *dictitems != NULL);
return 0;
}
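
List and dict subclasses additionally contribute an iterator over their items, so the contained elements are re-appended (or re-assigned by key) onto the freshly built object at unpickling time instead of being captured in the state. For example (a sketch):

class MyList(list):
    pass

ml = MyList([1, 2, 3])
ml.tag = "extra"

reconstructor, newargs, state, listitems, dictitems = ml.__reduce_ex__(4)
assert list(listitems) == [1, 2, 3]    # items are replayed with append()
assert dictitems is None               # not a dict subclass
assert state == {"tag": "extra"}       # only the instance attributes remain
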
static PyObject *
reduce_4(PyObject *obj)
{
PyObject *args = NULL, *kwargs = NULL;
PyObject *copyreg;
PyObject *newobj, *newargs, *state, *listitems, *dictitems;
PyObject *result;
_Py_IDENTIFIER(__newobj_ex__);
if (_PyObject_GetNewArguments(obj, &args, &kwargs) < 0) {
return NULL;
}
if (args == NULL) {
args = PyTuple_New(0);
if (args == NULL)
return NULL;
}
if (kwargs == NULL) {
kwargs = PyDict_New();
if (kwargs == NULL)
return NULL;
} }
copyreg = import_copyreg();
if (copyreg == NULL) {
Py_DECREF(args);
Py_DECREF(kwargs);
return NULL;
}
newobj = _PyObject_GetAttrId(copyreg, &PyId___newobj_ex__);
Py_DECREF(copyreg);
if (newobj == NULL) {
Py_DECREF(args);
Py_DECREF(kwargs);
return NULL;
}
newargs = PyTuple_Pack(3, Py_TYPE(obj), args, kwargs);
Py_DECREF(args);
Py_DECREF(kwargs);
if (newargs == NULL) {
Py_DECREF(newobj);
return NULL;
}
state = _PyObject_GetState(obj);
if (state == NULL) {
Py_DECREF(newobj);
Py_DECREF(newargs);
return NULL;
}
if (_PyObject_GetItemsIter(obj, &listitems, &dictitems) < 0) {
Py_DECREF(newobj);
Py_DECREF(newargs);
Py_DECREF(state);
return NULL;
}
result = PyTuple_Pack(5, newobj, newargs, state, listitems, dictitems);
Py_DECREF(newobj);
Py_DECREF(newargs);
Py_DECREF(state);
Py_DECREF(listitems);
Py_DECREF(dictitems);
return result;
}
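
Put together, reduce_4 returns the usual five-element reduce tuple, but with copyreg.__newobj_ex__ as the reconstructor so that keyword arguments for __new__ survive the round trip: (copyreg.__newobj_ex__, (cls, args, kwargs), state, listitems, dictitems). Inspecting it directly (a sketch on CPython 3.4+):

import copyreg

class Thing:
    pass

t = Thing()
t.value = 42

reconstructor, newargs, state, listitems, dictitems = t.__reduce_ex__(4)
assert reconstructor is copyreg.__newobj_ex__
assert newargs == (Thing, (), {})      # class, positional args, keyword args
assert state == {"value": 42}
assert listitems is None and dictitems is None   # not a list/dict subclass
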
static PyObject *
reduce_2(PyObject *obj)
{
PyObject *cls;
PyObject *args = NULL, *args2 = NULL, *kwargs = NULL;
PyObject *state = NULL, *listitems = NULL, *dictitems = NULL;
PyObject *copyreg = NULL, *newobj = NULL, *res = NULL;
Py_ssize_t i, n;
_Py_IDENTIFIER(__newobj__);
if (_PyObject_GetNewArguments(obj, &args, &kwargs) < 0) {
return NULL;
}
if (args == NULL) {
assert(kwargs == NULL);
args = PyTuple_New(0);
if (args == NULL) {
return NULL;
}
}
else if (kwargs != NULL) {
if (PyDict_Size(kwargs) > 0) {
PyErr_SetString(PyExc_ValueError,
"must use protocol 4 or greater to copy this "
"object; since __getnewargs_ex__ returned "
"keyword arguments.");
Py_DECREF(args);
Py_DECREF(kwargs);
return NULL;
}
Py_CLEAR(kwargs);
}
state = _PyObject_GetState(obj);
if (state == NULL)
goto end;
if (_PyObject_GetItemsIter(obj, &listitems, &dictitems) < 0)
goto end;
copyreg = import_copyreg(); copyreg = import_copyreg();
if (copyreg == NULL) if (copyreg == NULL)
goto end; goto end;
@@ -3560,6 +3839,7 @@ reduce_2(PyObject *obj)
args2 = PyTuple_New(n+1);
if (args2 == NULL)
goto end;
+cls = (PyObject *) Py_TYPE(obj);
Py_INCREF(cls);
PyTuple_SET_ITEM(args2, 0, cls);
for (i = 0; i < n; i++) {
@@ -3573,9 +3853,7 @@ reduce_2(PyObject *obj)
end:
Py_XDECREF(args);
Py_XDECREF(args2);
-Py_XDECREF(slots);
Py_XDECREF(state);
-Py_XDECREF(names);
Py_XDECREF(listitems);
Py_XDECREF(dictitems);
Py_XDECREF(copyreg);
@@ -3603,7 +3881,9 @@ _common_reduce(PyObject *self, int proto)
{
PyObject *copyreg, *res;
-if (proto >= 2)
+if (proto >= 4)
+return reduce_4(self);
+else if (proto >= 2)
return reduce_2(self);
copyreg = import_copyreg();
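
_common_reduce now routes protocol 4 to reduce_4 and protocols 2-3 to reduce_2; reduce_2 in turn refuses keyword arguments coming from __getnewargs_ex__, which is exactly why such objects require protocol 4. A short sketch of both sides (example class invented for illustration):

import copyreg

class KwOnly:
    def __getnewargs_ex__(self):
        return (), {"flag": True}

obj = KwOnly()

# Protocol 4 goes through reduce_4 and copyreg.__newobj_ex__ ...
assert obj.__reduce_ex__(4)[0] is copyreg.__newobj_ex__

# ... while protocol 2 cannot express keyword-only construction and refuses.
try:
    obj.__reduce_ex__(2)
except ValueError as exc:
    assert "protocol 4" in str(exc)
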
......