Commit 1c30ab2f authored by Serhiy Storchaka

Issue #17711: Fixed unpickling by the persistent ID with protocol 0.

Original patch by Alexandre Vassalotti.
parent c89e255f
...@@ -529,7 +529,11 @@ class _Pickler: ...@@ -529,7 +529,11 @@ class _Pickler:
self.save(pid, save_persistent_id=False) self.save(pid, save_persistent_id=False)
self.write(BINPERSID) self.write(BINPERSID)
else: else:
try:
self.write(PERSID + str(pid).encode("ascii") + b'\n') self.write(PERSID + str(pid).encode("ascii") + b'\n')
except UnicodeEncodeError:
raise PicklingError(
"persistent IDs in protocol 0 must be ASCII strings")
def save_reduce(self, func, args, state=None, listitems=None, def save_reduce(self, func, args, state=None, listitems=None,
dictitems=None, obj=None): dictitems=None, obj=None):
...@@ -1075,7 +1079,11 @@ class _Unpickler: ...@@ -1075,7 +1079,11 @@ class _Unpickler:
dispatch[FRAME[0]] = load_frame dispatch[FRAME[0]] = load_frame
def load_persid(self): def load_persid(self):
try:
pid = self.readline()[:-1].decode("ascii") pid = self.readline()[:-1].decode("ascii")
except UnicodeDecodeError:
raise UnpicklingError(
"persistent IDs in protocol 0 must be ASCII strings")
self.append(self.persistent_load(pid)) self.append(self.persistent_load(pid))
dispatch[PERSID[0]] = load_persid dispatch[PERSID[0]] = load_persid
......
...@@ -2629,6 +2629,35 @@ class AbstractPersistentPicklerTests(unittest.TestCase): ...@@ -2629,6 +2629,35 @@ class AbstractPersistentPicklerTests(unittest.TestCase):
self.assertEqual(self.load_false_count, 1) self.assertEqual(self.load_false_count, 1)
class AbstractIdentityPersistentPicklerTests(unittest.TestCase):
    """Persistent-ID round-trip tests that use each object as its own ID.

    The tests call ``self.dumps(obj, proto)`` and ``self.loads(buf)``, which
    are not defined here; a concrete subclass (or a mixin) has to provide
    them.  Presumably those helpers route through ``persistent_id`` and
    ``persistent_load`` below -- confirm against the mixin that supplies them.
    """

    def persistent_id(self, obj):
        # Identity mapping: the object itself serves as its persistent ID.
        return obj

    def persistent_load(self, pid):
        # Identity mapping: the persistent ID is handed back unchanged.
        return pid

    def _check_return_correct_type(self, obj, proto):
        # Round-trip *obj* and require both the same type and an equal value.
        unpickled = self.loads(self.dumps(obj, proto))
        self.assertIsInstance(unpickled, type(obj))
        self.assertEqual(unpickled, obj)

    def test_return_correct_type(self):
        for proto in protocols:
            # Protocol 0 supports only ASCII strings.
            if proto == 0:
                self._check_return_correct_type("abc", 0)
            else:
                for obj in [b"abc\n", "abc\n", -1, -1.1 * 0.1, str]:
                    self._check_return_correct_type(obj, proto)

    def test_protocol0_is_ascii_only(self):
        non_ascii_str = "\N{EMPTY SET}"
        # Pickling a non-ASCII persistent ID with protocol 0 must fail.
        self.assertRaises(pickle.PicklingError, self.dumps, non_ascii_str, 0)
        # Hand-built protocol 0 stream: PERSID opcode, UTF-8 encoded ID,
        # newline terminator, then the STOP opcode ('.').
        pickled = pickle.PERSID + non_ascii_str.encode('utf-8') + b'\n.'
        self.assertRaises(pickle.UnpicklingError, self.loads, pickled)
class AbstractPicklerUnpicklerObjectTests(unittest.TestCase): class AbstractPicklerUnpicklerObjectTests(unittest.TestCase):
pickler_class = None pickler_class = None
......
...@@ -14,6 +14,7 @@ from test.pickletester import AbstractUnpickleTests ...@@ -14,6 +14,7 @@ from test.pickletester import AbstractUnpickleTests
from test.pickletester import AbstractPickleTests from test.pickletester import AbstractPickleTests
from test.pickletester import AbstractPickleModuleTests from test.pickletester import AbstractPickleModuleTests
from test.pickletester import AbstractPersistentPicklerTests from test.pickletester import AbstractPersistentPicklerTests
from test.pickletester import AbstractIdentityPersistentPicklerTests
from test.pickletester import AbstractPicklerUnpicklerObjectTests from test.pickletester import AbstractPicklerUnpicklerObjectTests
from test.pickletester import AbstractDispatchTableTests from test.pickletester import AbstractDispatchTableTests
from test.pickletester import BigmemPickleTests from test.pickletester import BigmemPickleTests
...@@ -82,10 +83,7 @@ class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests, ...@@ -82,10 +83,7 @@ class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests,
return pickle.loads(buf, **kwds) return pickle.loads(buf, **kwds)
class PyPersPicklerTests(AbstractPersistentPicklerTests): class PersistentPicklerUnpicklerMixin(object):
pickler = pickle._Pickler
unpickler = pickle._Unpickler
def dumps(self, arg, proto=None): def dumps(self, arg, proto=None):
class PersPickler(self.pickler): class PersPickler(self.pickler):
...@@ -94,8 +92,7 @@ class PyPersPicklerTests(AbstractPersistentPicklerTests): ...@@ -94,8 +92,7 @@ class PyPersPicklerTests(AbstractPersistentPicklerTests):
f = io.BytesIO() f = io.BytesIO()
p = PersPickler(f, proto) p = PersPickler(f, proto)
p.dump(arg) p.dump(arg)
f.seek(0) return f.getvalue()
return f.read()
def loads(self, buf, **kwds): def loads(self, buf, **kwds):
class PersUnpickler(self.unpickler): class PersUnpickler(self.unpickler):
...@@ -106,6 +103,20 @@ class PyPersPicklerTests(AbstractPersistentPicklerTests): ...@@ -106,6 +103,20 @@ class PyPersPicklerTests(AbstractPersistentPicklerTests):
return u.load() return u.load()
class PyPersPicklerTests(AbstractPersistentPicklerTests,
                         PersistentPicklerUnpicklerMixin):
    # Binds the persistent-ID tests to the pure-Python pickler/unpickler;
    # the mixin reads these two attributes as self.pickler / self.unpickler.
    pickler = pickle._Pickler
    unpickler = pickle._Unpickler
class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
                           PersistentPicklerUnpicklerMixin):
    # Binds the identity persistent-ID tests to the pure-Python
    # pickler/unpickler; the mixin reads these as self.pickler / self.unpickler.
    pickler = pickle._Pickler
    unpickler = pickle._Unpickler
class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests): class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
pickler_class = pickle._Pickler pickler_class = pickle._Pickler
...@@ -144,6 +155,10 @@ if has_c_implementation: ...@@ -144,6 +155,10 @@ if has_c_implementation:
pickler = _pickle.Pickler pickler = _pickle.Pickler
unpickler = _pickle.Unpickler unpickler = _pickle.Unpickler
class CIdPersPicklerTests(PyIdPersPicklerTests):
    # Re-runs the identity persistent-ID tests with the C implementation.
    # NOTE: the surrounding hunk places this class under the
    # `if has_c_implementation:` guard.
    pickler = _pickle.Pickler
    unpickler = _pickle.Unpickler
class CDumpPickle_LoadPickle(PyPicklerTests): class CDumpPickle_LoadPickle(PyPicklerTests):
pickler = _pickle.Pickler pickler = _pickle.Pickler
unpickler = pickle._Unpickler unpickler = pickle._Unpickler
...@@ -409,11 +424,13 @@ class CompatPickleTests(unittest.TestCase): ...@@ -409,11 +424,13 @@ class CompatPickleTests(unittest.TestCase):
def test_main(): def test_main():
tests = [PickleTests, PyUnpicklerTests, PyPicklerTests, PyPersPicklerTests, tests = [PickleTests, PyUnpicklerTests, PyPicklerTests,
PyPersPicklerTests, PyIdPersPicklerTests,
PyDispatchTableTests, PyChainDispatchTableTests, PyDispatchTableTests, PyChainDispatchTableTests,
CompatPickleTests] CompatPickleTests]
if has_c_implementation: if has_c_implementation:
tests.extend([CUnpicklerTests, CPicklerTests, CPersPicklerTests, tests.extend([CUnpicklerTests, CPicklerTests,
CPersPicklerTests, CIdPersPicklerTests,
CDumpPickle_LoadPickle, DumpPickle_CLoadPickle, CDumpPickle_LoadPickle, DumpPickle_CLoadPickle,
PyPicklerUnpicklerObjectTests, PyPicklerUnpicklerObjectTests,
CPicklerUnpicklerObjectTests, CPicklerUnpicklerObjectTests,
......
...@@ -24,6 +24,9 @@ Core and Builtins ...@@ -24,6 +24,9 @@ Core and Builtins
Library Library
------- -------
- Issue #17711: Fixed unpickling by the persistent ID with protocol 0.
Original patch by Alexandre Vassalotti.
- Issue #27522: Avoid an unintentional reference cycle in email.feedparser. - Issue #27522: Avoid an unintentional reference cycle in email.feedparser.
- Issue #26844: Fix error message for imp.find_module() to refer to 'path' - Issue #26844: Fix error message for imp.find_module() to refer to 'path'
......
...@@ -3406,27 +3406,31 @@ save_pers(PicklerObject *self, PyObject *obj, PyObject *func) ...@@ -3406,27 +3406,31 @@ save_pers(PicklerObject *self, PyObject *obj, PyObject *func)
goto error; goto error;
} }
else { else {
PyObject *pid_str = NULL; PyObject *pid_str;
char *pid_ascii_bytes;
Py_ssize_t size;
pid_str = PyObject_Str(pid); pid_str = PyObject_Str(pid);
if (pid_str == NULL) if (pid_str == NULL)
goto error; goto error;
/* XXX: Should it check whether the persistent id only contains /* XXX: Should it check whether the pid contains embedded
ASCII characters? And what if the pid contains embedded
newlines? */ newlines? */
pid_ascii_bytes = _PyUnicode_AsStringAndSize(pid_str, &size); if (!PyUnicode_IS_ASCII(pid_str)) {
PyErr_SetString(_Pickle_GetGlobalState()->PicklingError,
"persistent IDs in protocol 0 must be "
"ASCII strings");
Py_DECREF(pid_str); Py_DECREF(pid_str);
if (pid_ascii_bytes == NULL)
goto error; goto error;
}
if (_Pickler_Write(self, &persid_op, 1) < 0 || if (_Pickler_Write(self, &persid_op, 1) < 0 ||
_Pickler_Write(self, pid_ascii_bytes, size) < 0 || _Pickler_Write(self, PyUnicode_DATA(pid_str),
_Pickler_Write(self, "\n", 1) < 0) PyUnicode_GET_LENGTH(pid_str)) < 0 ||
_Pickler_Write(self, "\n", 1) < 0) {
Py_DECREF(pid_str);
goto error; goto error;
} }
Py_DECREF(pid_str);
}
status = 1; status = 1;
} }
...@@ -5389,9 +5393,15 @@ load_persid(UnpicklerObject *self) ...@@ -5389,9 +5393,15 @@ load_persid(UnpicklerObject *self)
if (len < 1) if (len < 1)
return bad_readline(); return bad_readline();
pid = PyBytes_FromStringAndSize(s, len - 1); pid = PyUnicode_DecodeASCII(s, len - 1, "strict");
if (pid == NULL) if (pid == NULL) {
if (PyErr_ExceptionMatches(PyExc_UnicodeDecodeError)) {
PyErr_SetString(_Pickle_GetGlobalState()->UnpicklingError,
"persistent IDs in protocol 0 must be "
"ASCII strings");
}
return -1; return -1;
}
/* This does not leak since _Pickle_FastCall() steals the reference /* This does not leak since _Pickle_FastCall() steals the reference
to pid first. */ to pid first. */
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment