Commit 42218c87 authored by Kevin Modzelewski's avatar Kevin Modzelewski

Make set more subclassing-friendly

And add set.difference_update() and set.isdisjoint()
parent 063f08fe
......@@ -57,22 +57,22 @@ extern "C" void setIteratorGCHandler(GCVisitor* v, Box* b) {
}
Box* setiteratorHasnext(BoxedSetIterator* self) {
assert(self->cls == set_iterator_cls);
RELEASE_ASSERT(self->cls == set_iterator_cls, "");
return boxBool(self->hasNext());
}
Box* setiteratorNext(BoxedSetIterator* self) {
assert(self->cls == set_iterator_cls);
RELEASE_ASSERT(self->cls == set_iterator_cls, "");
return self->next();
}
Box* setiteratorIter(BoxedSetIterator* self) {
assert(self->cls == set_iterator_cls);
RELEASE_ASSERT(self->cls == set_iterator_cls, "");
return self;
}
Box* setAdd2(Box* _self, Box* b) {
assert(PyAnySet_Check(_self));
RELEASE_ASSERT(isSubclass(_self->cls, set_cls), "");
BoxedSet* self = static_cast<BoxedSet*>(_self);
self->s.insert(b);
......@@ -80,17 +80,17 @@ Box* setAdd2(Box* _self, Box* b) {
}
Box* setNew(Box* _cls, Box* container) {
assert(_cls->cls == type_cls);
RELEASE_ASSERT(_cls->cls == type_cls, "");
BoxedClass* cls = static_cast<BoxedClass*>(_cls);
assert(isSubclass(cls, set_cls) || isSubclass(cls, frozenset_cls));
RELEASE_ASSERT(isSubclass(cls, set_cls) || isSubclass(cls, frozenset_cls), "");
Box* rtn = new (cls) BoxedSet();
BoxedSet* rtn = new (cls) BoxedSet();
if (container == None)
return rtn;
for (Box* e : container->pyElements()) {
setAdd2(rtn, e);
rtn->s.insert(e);
}
return rtn;
......@@ -114,18 +114,18 @@ static Box* _setRepr(BoxedSet* self, const char* type_name) {
}
Box* setRepr(BoxedSet* self) {
assert(self->cls == set_cls);
RELEASE_ASSERT(isSubclass(self->cls, set_cls), "");
return _setRepr(self, "set");
}
Box* frozensetRepr(BoxedSet* self) {
assert(self->cls == frozenset_cls);
RELEASE_ASSERT(isSubclass(self->cls, frozenset_cls), "");
return _setRepr(self, "frozenset");
}
Box* setOrSet(BoxedSet* lhs, BoxedSet* rhs) {
assert(PyAnySet_Check(lhs));
assert(PyAnySet_Check(rhs));
RELEASE_ASSERT(PyAnySet_Check(lhs), "");
RELEASE_ASSERT(PyAnySet_Check(rhs), "");
BoxedSet* rtn = new (lhs->cls) BoxedSet();
......@@ -139,8 +139,8 @@ Box* setOrSet(BoxedSet* lhs, BoxedSet* rhs) {
}
Box* setAndSet(BoxedSet* lhs, BoxedSet* rhs) {
assert(PyAnySet_Check(lhs));
assert(PyAnySet_Check(rhs));
RELEASE_ASSERT(PyAnySet_Check(lhs), "");
RELEASE_ASSERT(PyAnySet_Check(rhs), "");
BoxedSet* rtn = new (lhs->cls) BoxedSet();
......@@ -152,8 +152,8 @@ Box* setAndSet(BoxedSet* lhs, BoxedSet* rhs) {
}
Box* setSubSet(BoxedSet* lhs, BoxedSet* rhs) {
assert(PyAnySet_Check(lhs));
assert(PyAnySet_Check(rhs));
RELEASE_ASSERT(PyAnySet_Check(lhs), "");
RELEASE_ASSERT(PyAnySet_Check(rhs), "");
BoxedSet* rtn = new (lhs->cls) BoxedSet();
......@@ -167,8 +167,8 @@ Box* setSubSet(BoxedSet* lhs, BoxedSet* rhs) {
}
Box* setXorSet(BoxedSet* lhs, BoxedSet* rhs) {
assert(PyAnySet_Check(lhs));
assert(PyAnySet_Check(rhs));
RELEASE_ASSERT(PyAnySet_Check(lhs), "");
RELEASE_ASSERT(PyAnySet_Check(rhs), "");
BoxedSet* rtn = new (lhs->cls) BoxedSet();
......@@ -186,23 +186,42 @@ Box* setXorSet(BoxedSet* lhs, BoxedSet* rhs) {
}
Box* setIter(BoxedSet* self) {
assert(PyAnySet_Check(self));
RELEASE_ASSERT(PyAnySet_Check(self), "");
return new BoxedSetIterator(self);
}
Box* setLen(BoxedSet* self) {
assert(PyAnySet_Check(self));
RELEASE_ASSERT(PyAnySet_Check(self), "");
return boxInt(self->s.size());
}
Box* setAdd(BoxedSet* self, Box* v) {
assert(PyAnySet_Check(self));
RELEASE_ASSERT(isSubclass(self->cls, set_cls), "%s", self->cls->tp_name);
self->s.insert(v);
return None;
}
// Note: PySet_Add is allowed to apply to frozenset objects, though CPython has
// an check to make sure the refcount is 1.
// for example, the marshal library uses this to construct frozenset objects.
extern "C" int PySet_Add(PyObject* set, PyObject* key) noexcept {
if (!PyAnySet_Check(set)) {
PyErr_BadInternalCall();
return -1;
}
try {
static_cast<BoxedSet*>(set)->s.insert(key);
return 0;
} catch (ExcInfo e) {
setCAPIException(e);
return -1;
}
}
Box* setRemove(BoxedSet* self, Box* v) {
assert(self->cls == set_cls);
RELEASE_ASSERT(isSubclass(self->cls, set_cls), "");
auto it = self->s.find(v);
if (it == self->s.end()) {
......@@ -214,7 +233,7 @@ Box* setRemove(BoxedSet* self, Box* v) {
}
Box* setDiscard(BoxedSet* self, Box* v) {
assert(self->cls == set_cls);
RELEASE_ASSERT(isSubclass(self->cls, set_cls), "");
auto it = self->s.find(v);
if (it != self->s.end())
......@@ -224,13 +243,15 @@ Box* setDiscard(BoxedSet* self, Box* v) {
}
Box* setClear(BoxedSet* self, Box* v) {
assert(self->cls == set_cls);
RELEASE_ASSERT(isSubclass(self->cls, set_cls), "");
self->s.clear();
return None;
}
Box* setUpdate(BoxedSet* self, BoxedTuple* args) {
assert(self->cls == set_cls);
RELEASE_ASSERT(isSubclass(self->cls, set_cls), "");
assert(args->cls == tuple_cls);
for (auto l : *args) {
......@@ -248,7 +269,7 @@ Box* setUpdate(BoxedSet* self, BoxedTuple* args) {
}
Box* setUnion(BoxedSet* self, BoxedTuple* args) {
if (!isSubclass(self->cls, set_cls))
if (!PyAnySet_Check(self))
raiseExcHelper(TypeError, "descriptor 'union' requires a 'set' object but received a '%s'", getTypeName(self));
BoxedSet* rtn = new BoxedSet();
......@@ -279,24 +300,27 @@ Box* setDifference(BoxedSet* self, BoxedTuple* args) {
return rtn;
}
static BoxedSet* setIntersection2(BoxedSet* self, Box* container) {
assert(self->cls == set_cls);
Box* setDifferenceUpdate(BoxedSet* self, BoxedTuple* args) {
if (!PySet_Check(self))
raiseExcHelper(TypeError, "descriptor 'difference' requires a 'set' object but received a '%s'",
getTypeName(self));
BoxedSet* rtn = new BoxedSet();
for (auto elt : container->pyElements()) {
if (self->s.count(elt))
rtn->s.insert(elt);
for (auto container : args->pyElements()) {
for (auto elt : container->pyElements()) {
self->s.erase(elt);
}
}
return rtn;
return None;
}
static Box* setIssubset(BoxedSet* self, Box* container) {
assert(self->cls == set_cls);
RELEASE_ASSERT(PyAnySet_Check(self), "");
if (container->cls != set_cls && container->cls != frozenset_cls) {
if (!PyAnySet_Check(container)) {
container = setNew(set_cls, container);
}
assert(container->cls == set_cls || container->cls == frozenset_cls);
assert(PyAnySet_Check(container));
BoxedSet* rhs = static_cast<BoxedSet*>(container);
for (auto e : self->s) {
......@@ -307,12 +331,12 @@ static Box* setIssubset(BoxedSet* self, Box* container) {
}
static Box* setIssuperset(BoxedSet* self, Box* container) {
assert(self->cls == set_cls);
RELEASE_ASSERT(PyAnySet_Check(self), "");
if (container->cls != set_cls && container->cls != frozenset_cls) {
if (!PyAnySet_Check(container)) {
container = setNew(set_cls, container);
}
assert(container->cls == set_cls || container->cls == frozenset_cls);
assert(PyAnySet_Check(container));
BoxedSet* rhs = static_cast<BoxedSet*>(container);
for (auto e : rhs->s) {
......@@ -322,7 +346,28 @@ static Box* setIssuperset(BoxedSet* self, Box* container) {
return True;
}
Box* setIntersection(BoxedSet* self, BoxedTuple* args) {
static Box* setIsdisjoint(BoxedSet* self, Box* container) {
RELEASE_ASSERT(PyAnySet_Check(self), "");
for (auto e : container->pyElements()) {
if (self->s.find(e) != self->s.end())
return False;
}
return True;
}
static BoxedSet* setIntersection2(BoxedSet* self, Box* container) {
RELEASE_ASSERT(PyAnySet_Check(self), "");
BoxedSet* rtn = new BoxedSet();
for (auto elt : container->pyElements()) {
if (self->s.count(elt))
rtn->s.insert(elt);
}
return rtn;
}
static Box* setIntersection(BoxedSet* self, BoxedTuple* args) {
if (!PyAnySet_Check(self))
raiseExcHelper(TypeError, "descriptor 'intersection' requires a 'set' object but received a '%s'",
getTypeName(self));
......@@ -335,7 +380,7 @@ Box* setIntersection(BoxedSet* self, BoxedTuple* args) {
}
Box* setCopy(BoxedSet* self) {
assert(self->cls == set_cls);
RELEASE_ASSERT(PyAnySet_Check(self), "");
BoxedSet* rtn = new BoxedSet();
rtn->s.insert(self->s.begin(), self->s.end());
......@@ -343,7 +388,7 @@ Box* setCopy(BoxedSet* self) {
}
Box* setPop(BoxedSet* self) {
assert(self->cls == set_cls);
RELEASE_ASSERT(isSubclass(self->cls, set_cls), "");
if (!self->s.size())
raiseExcHelper(KeyError, "pop from an empty set");
......@@ -355,13 +400,13 @@ Box* setPop(BoxedSet* self) {
}
Box* setContains(BoxedSet* self, Box* v) {
assert(PyAnySet_Check(self));
RELEASE_ASSERT(PyAnySet_Check(self), "");
return boxBool(self->s.count(v) != 0);
}
Box* setEq(BoxedSet* self, BoxedSet* rhs) {
assert(PyAnySet_Check(self));
if (rhs->cls != set_cls && rhs->cls != frozenset_cls)
RELEASE_ASSERT(PyAnySet_Check(self), "");
if (!PyAnySet_Check(rhs))
return NotImplemented;
if (self->s.size() != rhs->s.size())
......@@ -383,6 +428,7 @@ Box* setNe(BoxedSet* self, BoxedSet* rhs) {
}
Box* setNonzero(BoxedSet* self) {
RELEASE_ASSERT(PyAnySet_Check(self), "");
return boxBool(self->s.size());
}
......@@ -420,21 +466,6 @@ extern "C" PyObject* PyFrozenSet_New(PyObject* iterable) noexcept {
}
}
extern "C" int PySet_Add(PyObject* set, PyObject* key) noexcept {
if (!PyAnySet_Check(set)) {
PyErr_BadInternalCall();
return -1;
}
try {
setAdd((BoxedSet*)set, key);
return 0;
} catch (ExcInfo e) {
setCAPIException(e);
return -1;
}
}
} // namespace set
using namespace pyston::set;
......@@ -517,12 +548,17 @@ void setupSet() {
set_cls->giveAttr("clear", new BoxedFunction(boxRTFunction((void*)setClear, NONE, 1)));
set_cls->giveAttr("update", new BoxedFunction(boxRTFunction((void*)setUpdate, NONE, 1, 0, true, false)));
set_cls->giveAttr("union", new BoxedFunction(boxRTFunction((void*)setUnion, UNKNOWN, 1, 0, true, false)));
frozenset_cls->giveAttr("union", set_cls->getattr("union"));
set_cls->giveAttr("intersection",
new BoxedFunction(boxRTFunction((void*)setIntersection, UNKNOWN, 1, 0, true, false)));
frozenset_cls->giveAttr("intersection", set_cls->getattr("intersection"));
set_cls->giveAttr("difference", new BoxedFunction(boxRTFunction((void*)setDifference, UNKNOWN, 1, 0, true, false)));
frozenset_cls->giveAttr("difference", set_cls->getattr("difference"));
set_cls->giveAttr("difference_update",
new BoxedFunction(boxRTFunction((void*)setDifferenceUpdate, UNKNOWN, 1, 0, true, false)));
set_cls->giveAttr("issubset", new BoxedFunction(boxRTFunction((void*)setIssubset, UNKNOWN, 2)));
set_cls->giveAttr("issuperset", new BoxedFunction(boxRTFunction((void*)setIssuperset, UNKNOWN, 2)));
set_cls->giveAttr("isdisjoint", new BoxedFunction(boxRTFunction((void*)setIsdisjoint, UNKNOWN, 2)));
set_cls->giveAttr("copy", new BoxedFunction(boxRTFunction((void*)setCopy, UNKNOWN, 1)));
set_cls->giveAttr("pop", new BoxedFunction(boxRTFunction((void*)setPop, UNKNOWN, 1)));
......
......@@ -104,7 +104,28 @@ print s
s.discard(1)
print s
s = set(range(5))
s = set(range(10))
print s.difference_update(range(-3, 2), range(7, 23))
print sorted(s)
# Check set subclassing:
class MySet(set):
pass
class MyFrozenset(frozenset):
pass
compare_to = []
for i in xrange(10):
s2 = set(range(i))
print s.issubset(s2), s.issuperset(s2), s == s2, s != s2, s.difference(s2), s.issubset(range(i)), s.issuperset(range(i))
compare_to.append(set(range(i)))
compare_to.append(frozenset(range(i)))
compare_to.append(MySet(range(i)))
compare_to.append(MyFrozenset(range(i)))
compare_to.append(range(i))
compare_to.append(range(i, 10))
for s1 in set(range(5)), frozenset(range(5)):
for s2 in compare_to:
print type(s2), sorted(s2), s.issubset(s2), s.issuperset(s2), s == s2, s != s2, s.difference(s2), s.isdisjoint(s2), sorted(s1.union(s2)), sorted(s1.intersection(s2))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment