Commit a4ebd314 authored by Robert Bradshaw's avatar Robert Bradshaw

Allow self-referential members in default pickling.

parent 0860c918
......@@ -1639,8 +1639,10 @@ if VALUE is not None:
pickle_func = TreeFragment(u"""
def __reduce_cython__(self):
raise TypeError("%s")
""" % msg,
raise TypeError("%(msg)s")
def __setstate_cython__(self, __pyx_state):
raise TypeError("%(msg)s")
""" % {'msg': msg},
level='c_class', pipeline=[NormalizeTree(None)]).substitute({})
pickle_func.analyse_declarations(node.scope)
self.visit(pickle_func)
......@@ -1658,28 +1660,28 @@ if VALUE is not None:
# TODO(robertwb): Move the state into the third argument
# so it can be pickled *after* self is memoized.
unpickle_func = TreeFragment(u"""
def %(unpickle_func_name)s(__pyx_type, long __pyx_checksum, __pyx_state, %(args)s):
def %(unpickle_func_name)s(__pyx_type, long __pyx_checksum, __pyx_state):
if __pyx_checksum != %(checksum)s:
from pickle import PickleError
raise PickleError("Incompatible checksums (%%s vs %(checksum)s = (%(members)s))" %% __pyx_checksum)
cdef %(class_name)s result
result = %(class_name)s.__new__(__pyx_type)
%(assignments)s
if hasattr(result, '__setstate__'):
result.__setstate__(__pyx_state)
elif hasattr(result, '__dict__'):
result.__dict__.update(__pyx_state)
elif __pyx_state is not None:
from pickle import PickleError
raise PickleError("Unexpected state: %%s" %% __pyx_state)
if __pyx_state is not None:
%(unpickle_func_name)s__set_state(<%(class_name)s> result, __pyx_state)
return result
cdef %(unpickle_func_name)s__set_state(%(class_name)s result, tuple __pyx_state):
%(assignments)s
if hasattr(result, '__dict__'):
result.__dict__.update(__pyx_state[%(num_members)s])
""" % {
'unpickle_func_name': unpickle_func_name,
'checksum': checksum,
'members': ', '.join(all_members_names),
'class_name': node.class_name,
'assignments': '; '.join('result.%s = __pyx_arg_%s' % (v, v) for v in all_members_names),
'args': ','.join('__pyx_arg_%s' % v for v in all_members_names),
'assignments': '; '.join(
'result.%s = __pyx_state[%s]' % (v, ix)
for ix, v in enumerate(all_members_names)),
'num_members': len(all_members_names),
}, level='module', pipeline=[NormalizeTree(None)]).substitute({})
unpickle_func.analyse_declarations(node.entry.scope)
self.visit(unpickle_func)
......@@ -1687,14 +1689,28 @@ if VALUE is not None:
pickle_func = TreeFragment(u"""
def __reduce_cython__(self):
if hasattr(self, '__getstate__'):
state = self.__getstate__()
elif hasattr(self, '__dict__'):
state = self.__dict__
cdef bint use_setstate
state = (%(members)s)
_dict = getattr(self, '__dict__', None)
if _dict is not None:
state += _dict,
use_setstate = True
else:
use_setstate = %(any_notnone_members)s
if use_setstate:
return %(unpickle_func_name)s, (type(self), %(checksum)s, None), state
else:
state = None
return %s, (type(self), %s, state, %s)
""" % (unpickle_func_name, checksum, ', '.join('self.%s' % v for v in all_members_names)),
return %(unpickle_func_name)s, (type(self), %(checksum)s, state)
def __setstate_cython__(self, __pyx_state):
%(unpickle_func_name)s__set_state(self, __pyx_state)
""" % {
'unpickle_func_name': unpickle_func_name,
'checksum': checksum,
'members': ', '.join('self.%s' % v for v in all_members_names) + (',' if len(all_members_names) == 1 else ''),
# Even better, we could check PyType_IS_GC.
'any_notnone_members' : ' or '.join(['self.%s is not None' % e.name for e in all_members if e.type.is_pyobject] or ['False']),
},
level='c_class', pipeline=[NormalizeTree(None)]).substitute({})
pickle_func.analyse_declarations(node.scope)
self.visit(pickle_func)
......
......@@ -68,6 +68,8 @@ static int __Pyx_setup_reduce(PyObject* type_obj) {
PyObject *reduce = NULL;
PyObject *reduce_ex = NULL;
PyObject *reduce_cython = NULL;
PyObject *setstate = NULL;
PyObject *setstate_cython = NULL;
if (PyObject_HasAttrString(type_obj, "__getstate__")) goto GOOD;
......@@ -80,23 +82,36 @@ static int __Pyx_setup_reduce(PyObject* type_obj) {
__Pyx_setup_reduce_GET_ATTR_OR_BAD(reduce_ex, type_obj, "__reduce_ex__");
if (reduce_ex == object_reduce_ex) {
__Pyx_setup_reduce_GET_ATTR_OR_BAD(reduce, type_obj, "__reduce__");
__Pyx_setup_reduce_GET_ATTR_OR_BAD(reduce_cython, type_obj, "__reduce_cython__");
if (object_reduce == reduce
|| (strcmp(reduce->ob_type->tp_name, "method_descriptor") == 0
&& strcmp(((PyMethodDescrObject*)reduce)->d_method->ml_name, "__reduce_cython__") == 0)) {
__Pyx_setup_reduce_GET_ATTR_OR_BAD(reduce_cython, type_obj, "__reduce_cython__");
ret = PyDict_SetItemString(((PyTypeObject*)type_obj)->tp_dict, "__reduce__", reduce_cython); if (ret < 0) goto BAD;
ret = PyDict_DelItemString(((PyTypeObject*)type_obj)->tp_dict, "__reduce_cython__");
ret = PyDict_DelItemString(((PyTypeObject*)type_obj)->tp_dict, "__reduce_cython__"); if (ret < 0) goto BAD;
setstate = PyObject_GetAttrString(type_obj, "__setstate__");
if (!setstate) PyErr_Clear();
if (!setstate
|| (strcmp(setstate->ob_type->tp_name, "method_descriptor") == 0
&& strcmp(((PyMethodDescrObject*)setstate)->d_method->ml_name, "__setstate_cython__") == 0)) {
__Pyx_setup_reduce_GET_ATTR_OR_BAD(setstate_cython, type_obj, "__setstate_cython__");
ret = PyDict_SetItemString(((PyTypeObject*)type_obj)->tp_dict, "__setstate__", setstate_cython); if (ret < 0) goto BAD;
ret = PyDict_DelItemString(((PyTypeObject*)type_obj)->tp_dict, "__setstate_cython__"); if (ret < 0) goto BAD;
}
PyType_Modified((PyTypeObject*)type_obj);
}
}
goto GOOD;
BAD:
if (!PyErr_Occurred()) PyErr_Format(PyExc_RuntimeError, "Unable to initialize pickling for %s", ((PyTypeObject*)type_obj)->tp_name);
ret = -1;
GOOD:
Py_XDECREF(builtin_object);
Py_XDECREF(reduce);
Py_XDECREF(reduce_ex);
Py_XDECREF(reduce_cython);
Py_XDECREF(setstate);
Py_XDECREF(setstate_cython);
return ret;
}
......@@ -80,6 +80,8 @@ cdef class DefaultReduce(object):
>>> import pickle
>>> pickle.loads(pickle.dumps(a))
DefaultReduce(i=11, s='abc')
>>> pickle.loads(pickle.dumps(DefaultReduce(i=11, s=None)))
DefaultReduce(i=11, s=None)
"""
cdef readonly int i
......@@ -119,6 +121,11 @@ class DefaultReducePySubclass(DefaultReduce):
>>> import pickle
>>> pickle.loads(pickle.dumps(a))
DefaultReducePySubclass(i=11, s='abc', x=1.5)
>>> a.self_reference = a
>>> a2 = pickle.loads(pickle.dumps(a))
>>> a2.self_reference is a2
True
"""
def __init__(self, **kwargs):
self.x = kwargs.pop('x', 0)
......@@ -148,3 +155,91 @@ cdef class NoReduceDueToNontrivialCInit(object):
"""
def __cinit__(self, arg):
pass
cdef class NoMembers(object):
"""
>>> import pickle
>>> pickle.loads(pickle.dumps(NoMembers()))
NoMembers()
"""
def __repr__(self):
return "NoMembers()"
cdef struct MyStruct:
int i
double x
cdef class NoPyMembers(object):
"""
>>> import pickle
>>> pickle.loads(pickle.dumps(NoPyMembers(2, 1.75)))
NoPyMembers(ii=[2, 4, 8], x=1.75, my_struct=(3, 2.75))
"""
cdef int[3] ii
cdef double x
cdef MyStruct my_struct
def __init__(self, i, x):
self.ii[0] = i
self.ii[1] = i * i
self.ii[2] = i * i * i
self.x = x
self.my_struct = MyStruct(i+1, x+1)
def __repr__(self):
return "NoPyMembers(ii=%s, x=%s, my_struct=(%s, %s))" % (
self.ii, self.x, self.my_struct.i, self.my_struct.x)
class NoPyMembersPySubclass(NoPyMembers):
"""
>>> import pickle
>>> pickle.loads(pickle.dumps(NoPyMembersPySubclass(2, 1.75, 'xyz')))
NoPyMembersPySubclass(ii=[2, 4, 8], x=1.75, my_struct=(3, 2.75), s='xyz')
"""
def __init__(self, i, x, s):
super(NoPyMembersPySubclass, self).__init__(i, x)
self.s = s
def __repr__(self):
return super(NoPyMembersPySubclass, self).__repr__().replace(
'NoPyMembers', 'NoPyMembersPySubclass')[:-1] + ', s=%r)' % self.s
cdef _unset = object()
# Test cyclic references.
cdef class Wrapper(object):
"""
>>> import pickle
>>> w = Wrapper(); w
Wrapper(...)
>>> w2 = pickle.loads(pickle.dumps(w)); w2
Wrapper(...)
>>> w2.ref is w2
True
>>> pickle.loads(pickle.dumps(Wrapper(DefaultReduce(1, 'xyz'))))
Wrapper(DefaultReduce(i=1, s='xyz'))
>>> L = [None]
>>> L[0] = L
>>> w = Wrapper(L)
>>> pickle.loads(pickle.dumps(Wrapper(L)))
Wrapper([[...]])
>>> L[0] = w # Don't print this one out...
>>> w2 = pickle.loads(pickle.dumps(w))
>>> w2.ref[0] is w2
True
"""
cdef public object ref
def __init__(self, ref=_unset):
if ref is _unset:
self.ref = self
else:
self.ref = ref
def __repr__(self):
if self.ref is self:
return "Wrapper(...)"
else:
return "Wrapper(%r)" % self.ref
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment