Commit 986375eb authored by Serhiy Storchaka's avatar Serhiy Storchaka Committed by GitHub

bpo-28416: Break reference cycles in Pickler and Unpickler subclasses (#4080)

with the persistent_id() and persistent_load() methods.
parent bc8ac6b0
...@@ -6,6 +6,7 @@ import io ...@@ -6,6 +6,7 @@ import io
import collections import collections
import struct import struct
import sys import sys
import weakref
import unittest import unittest
from test import support from test import support
...@@ -117,6 +118,66 @@ class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests, ...@@ -117,6 +118,66 @@ class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
pickler = pickle._Pickler pickler = pickle._Pickler
unpickler = pickle._Unpickler unpickler = pickle._Unpickler
@support.cpython_only
def test_pickler_reference_cycle(self):
def check(Pickler):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
f = io.BytesIO()
pickler = Pickler(f, proto)
pickler.dump('abc')
self.assertEqual(self.loads(f.getvalue()), 'abc')
pickler = Pickler(io.BytesIO())
self.assertEqual(pickler.persistent_id('def'), 'def')
r = weakref.ref(pickler)
del pickler
self.assertIsNone(r())
class PersPickler(self.pickler):
def persistent_id(subself, obj):
return obj
check(PersPickler)
class PersPickler(self.pickler):
@classmethod
def persistent_id(cls, obj):
return obj
check(PersPickler)
class PersPickler(self.pickler):
@staticmethod
def persistent_id(obj):
return obj
check(PersPickler)
@support.cpython_only
def test_unpickler_reference_cycle(self):
def check(Unpickler):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
unpickler = Unpickler(io.BytesIO(self.dumps('abc', proto)))
self.assertEqual(unpickler.load(), 'abc')
unpickler = Unpickler(io.BytesIO())
self.assertEqual(unpickler.persistent_load('def'), 'def')
r = weakref.ref(unpickler)
del unpickler
self.assertIsNone(r())
class PersUnpickler(self.unpickler):
def persistent_load(subself, pid):
return pid
check(PersUnpickler)
class PersUnpickler(self.unpickler):
@classmethod
def persistent_load(cls, pid):
return pid
check(PersUnpickler)
class PersUnpickler(self.unpickler):
@staticmethod
def persistent_load(pid):
return pid
check(PersUnpickler)
class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests): class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
...@@ -197,7 +258,7 @@ if has_c_implementation: ...@@ -197,7 +258,7 @@ if has_c_implementation:
check_sizeof = support.check_sizeof check_sizeof = support.check_sizeof
def test_pickler(self): def test_pickler(self):
basesize = support.calcobjsize('5P2n3i2n3iP') basesize = support.calcobjsize('6P2n3i2n3iP')
p = _pickle.Pickler(io.BytesIO()) p = _pickle.Pickler(io.BytesIO())
self.assertEqual(object.__sizeof__(p), basesize) self.assertEqual(object.__sizeof__(p), basesize)
MT_size = struct.calcsize('3nP0n') MT_size = struct.calcsize('3nP0n')
...@@ -214,7 +275,7 @@ if has_c_implementation: ...@@ -214,7 +275,7 @@ if has_c_implementation:
0) # Write buffer is cleared after every dump(). 0) # Write buffer is cleared after every dump().
def test_unpickler(self): def test_unpickler(self):
basesize = support.calcobjsize('2Pn2P 2P2n2i5P 2P3n6P2n2i') basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n6P2n2i')
unpickler = _pickle.Unpickler unpickler = _pickle.Unpickler
P = struct.calcsize('P') # Size of memo table entry. P = struct.calcsize('P') # Size of memo table entry.
n = struct.calcsize('n') # Size of mark table entry. n = struct.calcsize('n') # Size of mark table entry.
......
Instances of pickle.Pickler subclass with the persistent_id() method and
pickle.Unpickler subclass with the persistent_load() method no longer create
reference cycles.
...@@ -360,6 +360,69 @@ _Pickle_FastCall(PyObject *func, PyObject *obj) ...@@ -360,6 +360,69 @@ _Pickle_FastCall(PyObject *func, PyObject *obj)
/*************************************************************************/ /*************************************************************************/
/* Retrieve and deconstruct a method for avoiding a reference cycle
(pickler -> bound method of pickler -> pickler) */
static int
init_method_ref(PyObject *self, _Py_Identifier *name,
PyObject **method_func, PyObject **method_self)
{
PyObject *func, *func2;
/* *method_func and *method_self should be consistent. All refcount decrements
should be occurred after setting *method_self and *method_func. */
func = _PyObject_GetAttrId(self, name);
if (func == NULL) {
*method_self = NULL;
Py_CLEAR(*method_func);
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
return -1;
}
PyErr_Clear();
return 0;
}
if (PyMethod_Check(func) && PyMethod_GET_SELF(func) == self) {
/* Deconstruct a bound Python method */
func2 = PyMethod_GET_FUNCTION(func);
Py_INCREF(func2);
*method_self = self; /* borrowed */
Py_XSETREF(*method_func, func2);
Py_DECREF(func);
return 0;
}
else {
*method_self = NULL;
Py_XSETREF(*method_func, func);
return 0;
}
}
/* Bind a method if it was deconstructed */
static PyObject *
reconstruct_method(PyObject *func, PyObject *self)
{
if (self) {
return PyMethod_New(func, self);
}
else {
Py_INCREF(func);
return func;
}
}
static PyObject *
call_method(PyObject *func, PyObject *self, PyObject *obj)
{
if (self) {
return PyObject_CallFunctionObjArgs(func, self, obj, NULL);
}
else {
return PyObject_CallFunctionObjArgs(func, obj, NULL);
}
}
/*************************************************************************/
/* Internal data type used as the unpickling stack. */ /* Internal data type used as the unpickling stack. */
typedef struct { typedef struct {
PyObject_VAR_HEAD PyObject_VAR_HEAD
...@@ -552,6 +615,8 @@ typedef struct PicklerObject { ...@@ -552,6 +615,8 @@ typedef struct PicklerObject {
objects to support self-referential objects objects to support self-referential objects
pickling. */ pickling. */
PyObject *pers_func; /* persistent_id() method, can be NULL */ PyObject *pers_func; /* persistent_id() method, can be NULL */
PyObject *pers_func_self; /* borrowed reference to self if pers_func
is an unbound method, NULL otherwise */
PyObject *dispatch_table; /* private dispatch_table, can be NULL */ PyObject *dispatch_table; /* private dispatch_table, can be NULL */
PyObject *write; /* write() method of the output stream. */ PyObject *write; /* write() method of the output stream. */
...@@ -590,6 +655,8 @@ typedef struct UnpicklerObject { ...@@ -590,6 +655,8 @@ typedef struct UnpicklerObject {
Py_ssize_t memo_len; /* Number of objects in the memo */ Py_ssize_t memo_len; /* Number of objects in the memo */
PyObject *pers_func; /* persistent_load() method, can be NULL. */ PyObject *pers_func; /* persistent_load() method, can be NULL. */
PyObject *pers_func_self; /* borrowed reference to self if pers_func
is an unbound method, NULL otherwise */
Py_buffer buffer; Py_buffer buffer;
char *input_buffer; char *input_buffer;
...@@ -3444,7 +3511,7 @@ save_type(PicklerObject *self, PyObject *obj) ...@@ -3444,7 +3511,7 @@ save_type(PicklerObject *self, PyObject *obj)
} }
static int static int
save_pers(PicklerObject *self, PyObject *obj, PyObject *func) save_pers(PicklerObject *self, PyObject *obj)
{ {
PyObject *pid = NULL; PyObject *pid = NULL;
int status = 0; int status = 0;
...@@ -3452,8 +3519,7 @@ save_pers(PicklerObject *self, PyObject *obj, PyObject *func) ...@@ -3452,8 +3519,7 @@ save_pers(PicklerObject *self, PyObject *obj, PyObject *func)
const char persid_op = PERSID; const char persid_op = PERSID;
const char binpersid_op = BINPERSID; const char binpersid_op = BINPERSID;
Py_INCREF(obj); pid = call_method(self->pers_func, self->pers_func_self, obj);
pid = _Pickle_FastCall(func, obj);
if (pid == NULL) if (pid == NULL)
return -1; return -1;
...@@ -3831,7 +3897,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save) ...@@ -3831,7 +3897,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
0 if it did nothing successfully; 0 if it did nothing successfully;
1 if a persistent id was saved. 1 if a persistent id was saved.
*/ */
if ((status = save_pers(self, obj, self->pers_func)) != 0) if ((status = save_pers(self, obj)) != 0)
goto done; goto done;
} }
...@@ -4246,14 +4312,11 @@ _pickle_Pickler___init___impl(PicklerObject *self, PyObject *file, ...@@ -4246,14 +4312,11 @@ _pickle_Pickler___init___impl(PicklerObject *self, PyObject *file,
self->fast_nesting = 0; self->fast_nesting = 0;
self->fast_memo = NULL; self->fast_memo = NULL;
self->pers_func = _PyObject_GetAttrId((PyObject *)self, if (init_method_ref((PyObject *)self, &PyId_persistent_id,
&PyId_persistent_id); &self->pers_func, &self->pers_func_self) < 0)
if (self->pers_func == NULL) { {
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
return -1; return -1;
} }
PyErr_Clear();
}
self->dispatch_table = _PyObject_GetAttrId((PyObject *)self, self->dispatch_table = _PyObject_GetAttrId((PyObject *)self,
&PyId_dispatch_table); &PyId_dispatch_table);
...@@ -4519,11 +4582,11 @@ Pickler_set_memo(PicklerObject *self, PyObject *obj) ...@@ -4519,11 +4582,11 @@ Pickler_set_memo(PicklerObject *self, PyObject *obj)
static PyObject * static PyObject *
Pickler_get_persid(PicklerObject *self) Pickler_get_persid(PicklerObject *self)
{ {
if (self->pers_func == NULL) if (self->pers_func == NULL) {
PyErr_SetString(PyExc_AttributeError, "persistent_id"); PyErr_SetString(PyExc_AttributeError, "persistent_id");
else return NULL;
Py_INCREF(self->pers_func); }
return self->pers_func; return reconstruct_method(self->pers_func, self->pers_func_self);
} }
static int static int
...@@ -4540,6 +4603,7 @@ Pickler_set_persid(PicklerObject *self, PyObject *value) ...@@ -4540,6 +4603,7 @@ Pickler_set_persid(PicklerObject *self, PyObject *value)
return -1; return -1;
} }
self->pers_func_self = NULL;
Py_INCREF(value); Py_INCREF(value);
Py_XSETREF(self->pers_func, value); Py_XSETREF(self->pers_func, value);
...@@ -5489,7 +5553,7 @@ load_stack_global(UnpicklerObject *self) ...@@ -5489,7 +5553,7 @@ load_stack_global(UnpicklerObject *self)
static int static int
load_persid(UnpicklerObject *self) load_persid(UnpicklerObject *self)
{ {
PyObject *pid; PyObject *pid, *obj;
Py_ssize_t len; Py_ssize_t len;
char *s; char *s;
...@@ -5509,13 +5573,12 @@ load_persid(UnpicklerObject *self) ...@@ -5509,13 +5573,12 @@ load_persid(UnpicklerObject *self)
return -1; return -1;
} }
/* This does not leak since _Pickle_FastCall() steals the reference obj = call_method(self->pers_func, self->pers_func_self, pid);
to pid first. */ Py_DECREF(pid);
pid = _Pickle_FastCall(self->pers_func, pid); if (obj == NULL)
if (pid == NULL)
return -1; return -1;
PDATA_PUSH(self->stack, pid, -1); PDATA_PUSH(self->stack, obj, -1);
return 0; return 0;
} }
else { else {
...@@ -5530,20 +5593,19 @@ load_persid(UnpicklerObject *self) ...@@ -5530,20 +5593,19 @@ load_persid(UnpicklerObject *self)
static int static int
load_binpersid(UnpicklerObject *self) load_binpersid(UnpicklerObject *self)
{ {
PyObject *pid; PyObject *pid, *obj;
if (self->pers_func) { if (self->pers_func) {
PDATA_POP(self->stack, pid); PDATA_POP(self->stack, pid);
if (pid == NULL) if (pid == NULL)
return -1; return -1;
/* This does not leak since _Pickle_FastCall() steals the obj = call_method(self->pers_func, self->pers_func_self, pid);
reference to pid first. */ Py_DECREF(pid);
pid = _Pickle_FastCall(self->pers_func, pid); if (obj == NULL)
if (pid == NULL)
return -1; return -1;
PDATA_PUSH(self->stack, pid, -1); PDATA_PUSH(self->stack, obj, -1);
return 0; return 0;
} }
else { else {
...@@ -6690,14 +6752,11 @@ _pickle_Unpickler___init___impl(UnpicklerObject *self, PyObject *file, ...@@ -6690,14 +6752,11 @@ _pickle_Unpickler___init___impl(UnpicklerObject *self, PyObject *file,
self->fix_imports = fix_imports; self->fix_imports = fix_imports;
self->pers_func = _PyObject_GetAttrId((PyObject *)self, if (init_method_ref((PyObject *)self, &PyId_persistent_load,
&PyId_persistent_load); &self->pers_func, &self->pers_func_self) < 0)
if (self->pers_func == NULL) { {
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
return -1; return -1;
} }
PyErr_Clear();
}
self->stack = (Pdata *)Pdata_New(); self->stack = (Pdata *)Pdata_New();
if (self->stack == NULL) if (self->stack == NULL)
...@@ -6983,11 +7042,11 @@ Unpickler_set_memo(UnpicklerObject *self, PyObject *obj) ...@@ -6983,11 +7042,11 @@ Unpickler_set_memo(UnpicklerObject *self, PyObject *obj)
static PyObject * static PyObject *
Unpickler_get_persload(UnpicklerObject *self) Unpickler_get_persload(UnpicklerObject *self)
{ {
if (self->pers_func == NULL) if (self->pers_func == NULL) {
PyErr_SetString(PyExc_AttributeError, "persistent_load"); PyErr_SetString(PyExc_AttributeError, "persistent_load");
else return NULL;
Py_INCREF(self->pers_func); }
return self->pers_func; return reconstruct_method(self->pers_func, self->pers_func_self);
} }
static int static int
...@@ -7005,6 +7064,7 @@ Unpickler_set_persload(UnpicklerObject *self, PyObject *value) ...@@ -7005,6 +7064,7 @@ Unpickler_set_persload(UnpicklerObject *self, PyObject *value)
return -1; return -1;
} }
self->pers_func_self = NULL;
Py_INCREF(value); Py_INCREF(value);
Py_XSETREF(self->pers_func, value); Py_XSETREF(self->pers_func, value);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment