Commit f8b3405e authored by Robert Bradshaw's avatar Robert Bradshaw

Don't override unknown inherited __reduce__.

parent 0ef81256
...@@ -2807,6 +2807,15 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -2807,6 +2807,15 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
weakref_entry.cname)) weakref_entry.cname))
else: else:
error(weakref_entry.pos, "__weakref__ slot must be of type 'object'") error(weakref_entry.pos, "__weakref__ slot must be of type 'object'")
if scope.lookup_here("__reduce_cython__"):
# Unfortunately, we cannot reliably detect whether a
# superclass defined __reduce__ at compile time, so we must
# do so at runtime.
code.globalstate.use_utility_code(
UtilityCode.load_cached('SetupReduce', 'ExtensionTypes.c'))
code.putln('if (__Pyx_setup_reduce((PyObject*)&%s) < 0) %s' % (
typeobj_cname,
code.error_goto(entry.pos)))
def generate_exttype_vtable_init_code(self, entry, code): def generate_exttype_vtable_init_code(self, entry, code):
# Generate code to initialise the C method table of an # Generate code to initialise the C method table of an
......
...@@ -1579,12 +1579,18 @@ if VALUE is not None: ...@@ -1579,12 +1579,18 @@ if VALUE is not None:
all_members = [] all_members = []
cls = node.entry.type cls = node.entry.type
cinit = None cinit = None
inherited_reduce = None
while cls is not None: while cls is not None:
all_members.extend(e for e in cls.scope.var_entries if e.name not in ('__weakref__', '__dict__')) all_members.extend(e for e in cls.scope.var_entries if e.name not in ('__weakref__', '__dict__'))
cinit = cinit or cls.scope.lookup('__cinit__') cinit = cinit or cls.scope.lookup('__cinit__')
inherited_reduce = inherited_reduce or cls.scope.lookup('__reduce__') or cls.scope.lookup('__reduce_ex__')
cls = cls.base_type cls = cls.base_type
all_members.sort(key=lambda e: e.name) all_members.sort(key=lambda e: e.name)
if inherited_reduce:
# This is not failsafe, as we may not know whether a cimported class defines a __reduce__.
return
non_py = [ non_py = [
e for e in all_members e for e in all_members
if not e.type.is_pyobject and (not e.type.create_from_py_utility_code(env) if not e.type.is_pyobject and (not e.type.create_from_py_utility_code(env)
...@@ -1601,7 +1607,7 @@ if VALUE is not None: ...@@ -1601,7 +1607,7 @@ if VALUE is not None:
error(node.pos, msg) error(node.pos, msg)
pickle_func = TreeFragment(u""" pickle_func = TreeFragment(u"""
def __reduce__(self): def __reduce_cython__(self):
raise TypeError("%s") raise TypeError("%s")
""" % msg, """ % msg,
level='c_class', pipeline=[NormalizeTree(None)]).substitute({}) level='c_class', pipeline=[NormalizeTree(None)]).substitute({})
...@@ -1643,7 +1649,7 @@ if VALUE is not None: ...@@ -1643,7 +1649,7 @@ if VALUE is not None:
self.extra_module_declarations.append(unpickle_func) self.extra_module_declarations.append(unpickle_func)
pickle_func = TreeFragment(u""" pickle_func = TreeFragment(u"""
def __reduce__(self): def __reduce_cython__(self):
if hasattr(self, '__getstate__'): if hasattr(self, '__getstate__'):
state = self.__getstate__() state = self.__getstate__()
elif hasattr(self, '__dict__'): elif hasattr(self, '__dict__'):
......
...@@ -51,3 +51,57 @@ static void __Pyx_call_next_tp_clear(PyObject* obj, inquiry current_tp_clear) { ...@@ -51,3 +51,57 @@ static void __Pyx_call_next_tp_clear(PyObject* obj, inquiry current_tp_clear) {
if (type && type->tp_clear) if (type && type->tp_clear)
type->tp_clear(obj); type->tp_clear(obj);
} }
/////////////// SetupReduce.proto ///////////////
static int __Pyx_setup_reduce(PyObject* type_obj);
/////////////// SetupReduce ///////////////
#define __Pyx_setup_reduce_GET_ATTR_OR_BAD(res, obj, name) res = PyObject_GetAttrString(obj, name); if (res == NULL) goto BAD;
static int __Pyx_setup_reduce(PyObject* type_obj) {
int ret = 0;
PyObject* builtin_object = NULL;
static PyObject *object_reduce = NULL;
static PyObject *object_reduce_ex = NULL;
PyObject *reduce = NULL;
PyObject *reduce_ex = NULL;
PyObject *reduce_cython = NULL;
PyObject *reduce_name = NULL;
PyObject *reduce_cython_name = NULL;
PyObject *same_name = NULL;
if (object_reduce_ex == NULL) {
__Pyx_setup_reduce_GET_ATTR_OR_BAD(builtin_object, __pyx_b, "object");
__Pyx_setup_reduce_GET_ATTR_OR_BAD(object_reduce, builtin_object, "__reduce__");
__Pyx_setup_reduce_GET_ATTR_OR_BAD(object_reduce_ex, builtin_object, "__reduce_ex__");
}
__Pyx_setup_reduce_GET_ATTR_OR_BAD(reduce_ex, type_obj, "__reduce_ex__");
if (reduce_ex == object_reduce_ex) {
__Pyx_setup_reduce_GET_ATTR_OR_BAD(reduce, type_obj, "__reduce__");
__Pyx_setup_reduce_GET_ATTR_OR_BAD(reduce_cython, type_obj, "__reduce_cython__");
__Pyx_setup_reduce_GET_ATTR_OR_BAD(reduce_name, reduce, "__name__");
__Pyx_setup_reduce_GET_ATTR_OR_BAD(reduce_cython_name, reduce_cython, "__name__");
same_name = PyObject_RichCompare(reduce_name, reduce_cython_name, Py_EQ); if (same_name == NULL) goto BAD;
if (object_reduce == reduce || PyObject_IsTrue(same_name)) {
ret = PyDict_SetItemString(((PyTypeObject*)type_obj)->tp_dict, "__reduce__", reduce_cython); if (ret < 0) goto BAD;
ret = PyDict_DelItemString(((PyTypeObject*)type_obj)->tp_dict, "__reduce_cython__");
PyType_Modified((PyTypeObject*)type_obj);
}
}
goto GOOD;
BAD:
ret = -1;
GOOD:
Py_XDECREF(builtin_object);
Py_XDECREF(reduce);
Py_XDECREF(reduce_ex);
Py_XDECREF(reduce_cython);
Py_XDECREF(reduce_name);
Py_XDECREF(reduce_cython_name);
Py_XDECREF(same_name);
return ret;
}
...@@ -46,18 +46,30 @@ cdef class B: ...@@ -46,18 +46,30 @@ cdef class B:
cdef int x, y cdef int x, y
def __cinit__(self):
self.x = self.y = -1
def __init__(self, x=0, y=0): def __init__(self, x=0, y=0):
self.x = x self.x = x
self.y = y self.y = y
def __repr__(self): def __repr__(self):
return "B(x=%s, y=%s)" % (self.x, self.y) return "%s(x=%s, y=%s)" % (self.__class__.__name__, self.x, self.y)
def __reduce__(self): def __reduce__(self):
return makeB, ({'x': self.x, 'y': self.y},) return makeObj, (type(self), {'x': self.x, 'y': self.y})
def makeObj(obj_type, kwds):
return obj_type(**kwds)
def makeB(kwds): cdef class C(B):
return B(**kwds) """
>>> import pickle
>>> pickle.loads(pickle.dumps(C(x=37, y=389)))
C(x=37, y=389)
"""
pass
@cython.auto_pickle(True) # Not needed, just to test the directive. @cython.auto_pickle(True) # Not needed, just to test the directive.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment