Commit 0d9eedb9 authored by Marius Wachtler's avatar Marius Wachtler

Allow changing class.__bases__

parent c6788275
......@@ -2647,6 +2647,246 @@ static void remove_subclass(PyTypeObject* base, PyTypeObject* type) noexcept {
}
}
static int equiv_structs(PyTypeObject* a, PyTypeObject* b) noexcept {
// Pyston change: added attrs_offset equality check
// return a == b || (a != NULL && b != NULL && a->tp_basicsize == b->tp_basicsize
// && a->tp_itemsize == b->tp_itemsize
// && a->tp_dictoffset == b->tp_dictoffset && a->tp_weaklistoffset == b->tp_weaklistoffset
// && ((a->tp_flags & Py_TPFLAGS_HAVE_GC) == (b->tp_flags & Py_TPFLAGS_HAVE_GC)));
return a == b || (a != NULL && b != NULL && a->tp_basicsize == b->tp_basicsize && a->tp_itemsize == b->tp_itemsize
&& a->tp_dictoffset == b->tp_dictoffset && a->tp_weaklistoffset == b->tp_weaklistoffset
&& a->attrs_offset == b->attrs_offset
&& ((a->tp_flags & Py_TPFLAGS_HAVE_GC) == (b->tp_flags & Py_TPFLAGS_HAVE_GC)));
}
static void update_all_slots(PyTypeObject* type) noexcept {
slotdef* p;
init_slotdefs();
for (p = slotdefs; p->name; p++) {
/* update_slot returns int but can't actually fail */
update_slot(type, p->name);
}
}
static int same_slots_added(PyTypeObject* a, PyTypeObject* b) noexcept {
PyTypeObject* base = a->tp_base;
Py_ssize_t size;
PyObject* slots_a, *slots_b;
assert(base == b->tp_base);
size = base->tp_basicsize;
if (a->tp_dictoffset == size && b->tp_dictoffset == size)
size += sizeof(PyObject*);
// Pyston change: have to check attrs_offset
if (a->attrs_offset == size && b->attrs_offset == size)
size += sizeof(HCAttrs);
if (a->tp_weaklistoffset == size && b->tp_weaklistoffset == size)
size += sizeof(PyObject*);
/* Check slots compliance */
slots_a = ((PyHeapTypeObject*)a)->ht_slots;
slots_b = ((PyHeapTypeObject*)b)->ht_slots;
if (slots_a && slots_b) {
if (PyObject_Compare(slots_a, slots_b) != 0)
return 0;
size += sizeof(PyObject*) * PyTuple_GET_SIZE(slots_a);
}
return size == a->tp_basicsize && size == b->tp_basicsize;
}
static int compatible_for_assignment(PyTypeObject* oldto, PyTypeObject* newto, const char* attr) noexcept {
PyTypeObject* newbase, *oldbase;
if (newto->tp_dealloc != oldto->tp_dealloc || newto->tp_free != oldto->tp_free) {
PyErr_Format(PyExc_TypeError, "%s assignment: "
"'%s' deallocator differs from '%s'",
attr, newto->tp_name, oldto->tp_name);
return 0;
}
newbase = newto;
oldbase = oldto;
while (equiv_structs(newbase, newbase->tp_base))
newbase = newbase->tp_base;
while (equiv_structs(oldbase, oldbase->tp_base))
oldbase = oldbase->tp_base;
if (newbase != oldbase && (newbase->tp_base != oldbase->tp_base || !same_slots_added(newbase, oldbase))) {
PyErr_Format(PyExc_TypeError, "%s assignment: "
"'%s' object layout differs from '%s'",
attr, newto->tp_name, oldto->tp_name);
return 0;
}
return 1;
}
static int mro_subclasses(PyTypeObject* type, PyObject* temp) noexcept {
PyTypeObject* subclass;
PyObject* ref, *subclasses, *old_mro;
Py_ssize_t i, n;
subclasses = type->tp_subclasses;
if (subclasses == NULL)
return 0;
assert(PyList_Check(subclasses));
n = PyList_GET_SIZE(subclasses);
for (i = 0; i < n; i++) {
ref = PyList_GET_ITEM(subclasses, i);
assert(PyWeakref_CheckRef(ref));
subclass = (PyTypeObject*)PyWeakref_GET_OBJECT(ref);
assert(subclass != NULL);
if ((PyObject*)subclass == Py_None)
continue;
assert(PyType_Check(subclass));
old_mro = subclass->tp_mro;
if (mro_internal(subclass) < 0) {
subclass->tp_mro = old_mro;
return -1;
} else {
PyObject* tuple;
tuple = PyTuple_Pack(2, subclass, old_mro);
Py_DECREF(old_mro);
if (!tuple)
return -1;
if (PyList_Append(temp, tuple) < 0)
return -1;
Py_DECREF(tuple);
}
if (mro_subclasses(subclass, temp) < 0)
return -1;
}
return 0;
}
int type_set_bases(PyTypeObject* type, PyObject* value, void* context) noexcept {
Py_ssize_t i;
int r = 0;
PyObject* ob, *temp;
PyTypeObject* new_base, *old_base;
PyObject* old_bases, *old_mro;
if (!(type->tp_flags & Py_TPFLAGS_HEAPTYPE)) {
PyErr_Format(PyExc_TypeError, "can't set %s.__bases__", type->tp_name);
return -1;
}
if (!value) {
PyErr_Format(PyExc_TypeError, "can't delete %s.__bases__", type->tp_name);
return -1;
}
if (!PyTuple_Check(value)) {
PyErr_Format(PyExc_TypeError, "can only assign tuple to %s.__bases__, not %s", type->tp_name,
Py_TYPE(value)->tp_name);
return -1;
}
if (PyTuple_GET_SIZE(value) == 0) {
PyErr_Format(PyExc_TypeError, "can only assign non-empty tuple to %s.__bases__, not ()", type->tp_name);
return -1;
}
for (i = 0; i < PyTuple_GET_SIZE(value); i++) {
ob = PyTuple_GET_ITEM(value, i);
if (!PyClass_Check(ob) && !PyType_Check(ob)) {
PyErr_Format(PyExc_TypeError, "%s.__bases__ must be tuple of old- or new-style classes, not '%s'",
type->tp_name, Py_TYPE(ob)->tp_name);
return -1;
}
if (PyType_Check(ob)) {
if (PyType_IsSubtype((PyTypeObject*)ob, type)) {
PyErr_SetString(PyExc_TypeError, "a __bases__ item causes an inheritance cycle");
return -1;
}
}
}
new_base = best_base(value);
if (!new_base) {
return -1;
}
if (!compatible_for_assignment(type->tp_base, new_base, "__bases__"))
return -1;
Py_INCREF(new_base);
Py_INCREF(value);
old_bases = type->tp_bases;
old_base = type->tp_base;
old_mro = type->tp_mro;
type->tp_bases = value;
type->tp_base = new_base;
if (mro_internal(type) < 0) {
goto bail;
}
temp = PyList_New(0);
if (!temp)
goto bail;
r = mro_subclasses(type, temp);
if (r < 0) {
for (i = 0; i < PyList_Size(temp); i++) {
PyTypeObject* cls;
PyObject* mro;
PyArg_UnpackTuple(PyList_GET_ITEM(temp, i), "", 2, 2, &cls, &mro);
Py_INCREF(mro);
ob = cls->tp_mro;
cls->tp_mro = mro;
Py_DECREF(ob);
}
Py_DECREF(temp);
goto bail;
}
Py_DECREF(temp);
/* any base that was in __bases__ but now isn't, we
need to remove |type| from its tp_subclasses.
conversely, any class now in __bases__ that wasn't
needs to have |type| added to its subclasses. */
/* for now, sod that: just remove from all old_bases,
add to all new_bases */
for (i = PyTuple_GET_SIZE(old_bases) - 1; i >= 0; i--) {
ob = PyTuple_GET_ITEM(old_bases, i);
if (PyType_Check(ob)) {
remove_subclass((PyTypeObject*)ob, type);
}
}
for (i = PyTuple_GET_SIZE(value) - 1; i >= 0; i--) {
ob = PyTuple_GET_ITEM(value, i);
if (PyType_Check(ob)) {
if (add_subclass((PyTypeObject*)ob, type) < 0)
r = -1;
}
}
update_all_slots(type);
Py_DECREF(old_bases);
Py_DECREF(old_base);
Py_DECREF(old_mro);
return r;
bail:
Py_DECREF(type->tp_bases);
Py_DECREF(type->tp_base);
if (type->tp_mro != old_mro) {
Py_DECREF(type->tp_mro);
}
type->tp_bases = old_bases;
type->tp_base = old_base;
type->tp_mro = old_mro;
return -1;
}
// commonClassSetup is for the common code between PyType_Ready (which is just for extension classes)
// and our internal type-creation endpoints (BoxedClass::BoxedClass()).
// TODO: Move more of the duplicated logic into here.
......
......@@ -32,6 +32,7 @@ void commonClassSetup(BoxedClass* cls);
// We could probably unify things more but that's for later.
PyTypeObject* best_base(PyObject* bases) noexcept;
PyObject* mro_external(PyObject* self) noexcept;
int type_set_bases(PyTypeObject* type, PyObject* value, void* context) noexcept;
}
#endif
......@@ -2082,8 +2082,11 @@ static Box* typeBases(Box* b, void*) {
return type->tp_bases;
}
static void typeSetBases(Box* b, Box* v, void*) {
Py_FatalError("unimplemented");
static void typeSetBases(Box* b, Box* v, void* c) {
RELEASE_ASSERT(isSubclass(b->cls, type_cls), "");
BoxedClass* type = static_cast<BoxedClass*>(b);
if (type_set_bases(type, v, c) == -1)
throwCAPIException();
}
// cls should be obj->cls.
......
class A(object):
def foo(self):
print "foo"
class B(object):
def bar(self):
print "bar"
class C(object):
def baz(self):
print "baz"
class D(C):
pass
print D.__bases__ == (C,)
print hasattr(D, "bar")
try:
D.__bases__ += (A, B, C)
except Exception as e:
print e # duplicate base class C
D.__bases__ += (A, B)
print D.__bases__ == (C, A, B, C)
print D.__base__ == (C)
D().foo(), D().bar(), D().baz()
D.__bases__ = (C,)
print D.__bases__ == (C,)
print hasattr(D, "foo"), hasattr(D, "bar"), hasattr(D, "baz")
D().baz()
# inheritance circle:
try:
C.__bases__ = (D,)
except TypeError as e:
print e
class Slots(object):
__slots__ = ["a", "b", "c"]
def __init__(self):
self.a = 1
self.b = 2
self.c = 3
try:
Slots.__bases__ = (A,)
except TypeError:
print "cought TypeError exception" # pyston and cpython throw a different exception
# This tests are copied from CPython an can be removed when we support running test/cpython/test_descr.py
class CPythonTests(object):
def assertEqual(self, x, y):
assert x == y
def fail(self, msg):
print "Error", msg
def test_mutable_bases(self):
# Testing mutable bases...
# stuff that should work:
class C(object):
pass
class C2(object):
def __getattribute__(self, attr):
if attr == 'a':
return 2
else:
return super(C2, self).__getattribute__(attr)
def meth(self):
return 1
class D(C):
pass
class E(D):
pass
d = D()
e = E()
D.__bases__ = (C,)
D.__bases__ = (C2,)
self.assertEqual(d.meth(), 1)
self.assertEqual(e.meth(), 1)
self.assertEqual(d.a, 2)
self.assertEqual(e.a, 2)
self.assertEqual(C2.__subclasses__(), [D])
try:
del D.__bases__
except (TypeError, AttributeError):
pass
else:
self.fail("shouldn't be able to delete .__bases__")
try:
D.__bases__ = ()
except TypeError, msg:
if str(msg) == "a new-style class can't have only classic bases":
self.fail("wrong error message for .__bases__ = ()")
else:
self.fail("shouldn't be able to set .__bases__ to ()")
try:
D.__bases__ = (D,)
except TypeError:
pass
else:
# actually, we'll have crashed by here...
self.fail("shouldn't be able to create inheritance cycles")
try:
D.__bases__ = (C, C)
except TypeError:
pass
else:
self.fail("didn't detect repeated base classes")
try:
D.__bases__ = (E,)
except TypeError:
pass
else:
self.fail("shouldn't be able to create inheritance cycles")
# let's throw a classic class into the mix:
class Classic:
def meth2(self):
return 3
D.__bases__ = (C, Classic)
self.assertEqual(d.meth2(), 3)
self.assertEqual(e.meth2(), 3)
try:
d.a
except AttributeError:
pass
else:
self.fail("attribute should have vanished")
try:
D.__bases__ = (Classic,)
except TypeError:
pass
else:
self.fail("new-style class must have a new-style base")
def test_mutable_bases_catch_mro_conflict(self):
# Testing mutable bases catch mro conflict...
class A(object):
pass
class B(object):
pass
class C(A, B):
pass
class D(A, B):
pass
class E(C, D):
pass
try:
C.__bases__ = (B, A)
except TypeError:
pass
else:
self.fail("didn't catch MRO conflict")
tests = CPythonTests()
tests.test_mutable_bases()
tests.test_mutable_bases_catch_mro_conflict()
print
print "Finished"
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment