Commit 42218c87 authored by Kevin Modzelewski's avatar Kevin Modzelewski

Make set more subclassing-friendly

And add set.difference_update() and set.isdisjoint()
parent 063f08fe
...@@ -57,22 +57,22 @@ extern "C" void setIteratorGCHandler(GCVisitor* v, Box* b) { ...@@ -57,22 +57,22 @@ extern "C" void setIteratorGCHandler(GCVisitor* v, Box* b) {
} }
Box* setiteratorHasnext(BoxedSetIterator* self) { Box* setiteratorHasnext(BoxedSetIterator* self) {
assert(self->cls == set_iterator_cls); RELEASE_ASSERT(self->cls == set_iterator_cls, "");
return boxBool(self->hasNext()); return boxBool(self->hasNext());
} }
Box* setiteratorNext(BoxedSetIterator* self) { Box* setiteratorNext(BoxedSetIterator* self) {
assert(self->cls == set_iterator_cls); RELEASE_ASSERT(self->cls == set_iterator_cls, "");
return self->next(); return self->next();
} }
Box* setiteratorIter(BoxedSetIterator* self) { Box* setiteratorIter(BoxedSetIterator* self) {
assert(self->cls == set_iterator_cls); RELEASE_ASSERT(self->cls == set_iterator_cls, "");
return self; return self;
} }
Box* setAdd2(Box* _self, Box* b) { Box* setAdd2(Box* _self, Box* b) {
assert(PyAnySet_Check(_self)); RELEASE_ASSERT(isSubclass(_self->cls, set_cls), "");
BoxedSet* self = static_cast<BoxedSet*>(_self); BoxedSet* self = static_cast<BoxedSet*>(_self);
self->s.insert(b); self->s.insert(b);
...@@ -80,17 +80,17 @@ Box* setAdd2(Box* _self, Box* b) { ...@@ -80,17 +80,17 @@ Box* setAdd2(Box* _self, Box* b) {
} }
Box* setNew(Box* _cls, Box* container) { Box* setNew(Box* _cls, Box* container) {
assert(_cls->cls == type_cls); RELEASE_ASSERT(_cls->cls == type_cls, "");
BoxedClass* cls = static_cast<BoxedClass*>(_cls); BoxedClass* cls = static_cast<BoxedClass*>(_cls);
assert(isSubclass(cls, set_cls) || isSubclass(cls, frozenset_cls)); RELEASE_ASSERT(isSubclass(cls, set_cls) || isSubclass(cls, frozenset_cls), "");
Box* rtn = new (cls) BoxedSet(); BoxedSet* rtn = new (cls) BoxedSet();
if (container == None) if (container == None)
return rtn; return rtn;
for (Box* e : container->pyElements()) { for (Box* e : container->pyElements()) {
setAdd2(rtn, e); rtn->s.insert(e);
} }
return rtn; return rtn;
...@@ -114,18 +114,18 @@ static Box* _setRepr(BoxedSet* self, const char* type_name) { ...@@ -114,18 +114,18 @@ static Box* _setRepr(BoxedSet* self, const char* type_name) {
} }
Box* setRepr(BoxedSet* self) { Box* setRepr(BoxedSet* self) {
assert(self->cls == set_cls); RELEASE_ASSERT(isSubclass(self->cls, set_cls), "");
return _setRepr(self, "set"); return _setRepr(self, "set");
} }
Box* frozensetRepr(BoxedSet* self) { Box* frozensetRepr(BoxedSet* self) {
assert(self->cls == frozenset_cls); RELEASE_ASSERT(isSubclass(self->cls, frozenset_cls), "");
return _setRepr(self, "frozenset"); return _setRepr(self, "frozenset");
} }
Box* setOrSet(BoxedSet* lhs, BoxedSet* rhs) { Box* setOrSet(BoxedSet* lhs, BoxedSet* rhs) {
assert(PyAnySet_Check(lhs)); RELEASE_ASSERT(PyAnySet_Check(lhs), "");
assert(PyAnySet_Check(rhs)); RELEASE_ASSERT(PyAnySet_Check(rhs), "");
BoxedSet* rtn = new (lhs->cls) BoxedSet(); BoxedSet* rtn = new (lhs->cls) BoxedSet();
...@@ -139,8 +139,8 @@ Box* setOrSet(BoxedSet* lhs, BoxedSet* rhs) { ...@@ -139,8 +139,8 @@ Box* setOrSet(BoxedSet* lhs, BoxedSet* rhs) {
} }
Box* setAndSet(BoxedSet* lhs, BoxedSet* rhs) { Box* setAndSet(BoxedSet* lhs, BoxedSet* rhs) {
assert(PyAnySet_Check(lhs)); RELEASE_ASSERT(PyAnySet_Check(lhs), "");
assert(PyAnySet_Check(rhs)); RELEASE_ASSERT(PyAnySet_Check(rhs), "");
BoxedSet* rtn = new (lhs->cls) BoxedSet(); BoxedSet* rtn = new (lhs->cls) BoxedSet();
...@@ -152,8 +152,8 @@ Box* setAndSet(BoxedSet* lhs, BoxedSet* rhs) { ...@@ -152,8 +152,8 @@ Box* setAndSet(BoxedSet* lhs, BoxedSet* rhs) {
} }
Box* setSubSet(BoxedSet* lhs, BoxedSet* rhs) { Box* setSubSet(BoxedSet* lhs, BoxedSet* rhs) {
assert(PyAnySet_Check(lhs)); RELEASE_ASSERT(PyAnySet_Check(lhs), "");
assert(PyAnySet_Check(rhs)); RELEASE_ASSERT(PyAnySet_Check(rhs), "");
BoxedSet* rtn = new (lhs->cls) BoxedSet(); BoxedSet* rtn = new (lhs->cls) BoxedSet();
...@@ -167,8 +167,8 @@ Box* setSubSet(BoxedSet* lhs, BoxedSet* rhs) { ...@@ -167,8 +167,8 @@ Box* setSubSet(BoxedSet* lhs, BoxedSet* rhs) {
} }
Box* setXorSet(BoxedSet* lhs, BoxedSet* rhs) { Box* setXorSet(BoxedSet* lhs, BoxedSet* rhs) {
assert(PyAnySet_Check(lhs)); RELEASE_ASSERT(PyAnySet_Check(lhs), "");
assert(PyAnySet_Check(rhs)); RELEASE_ASSERT(PyAnySet_Check(rhs), "");
BoxedSet* rtn = new (lhs->cls) BoxedSet(); BoxedSet* rtn = new (lhs->cls) BoxedSet();
...@@ -186,23 +186,42 @@ Box* setXorSet(BoxedSet* lhs, BoxedSet* rhs) { ...@@ -186,23 +186,42 @@ Box* setXorSet(BoxedSet* lhs, BoxedSet* rhs) {
} }
Box* setIter(BoxedSet* self) { Box* setIter(BoxedSet* self) {
assert(PyAnySet_Check(self)); RELEASE_ASSERT(PyAnySet_Check(self), "");
return new BoxedSetIterator(self); return new BoxedSetIterator(self);
} }
Box* setLen(BoxedSet* self) { Box* setLen(BoxedSet* self) {
assert(PyAnySet_Check(self)); RELEASE_ASSERT(PyAnySet_Check(self), "");
return boxInt(self->s.size()); return boxInt(self->s.size());
} }
Box* setAdd(BoxedSet* self, Box* v) { Box* setAdd(BoxedSet* self, Box* v) {
assert(PyAnySet_Check(self)); RELEASE_ASSERT(isSubclass(self->cls, set_cls), "%s", self->cls->tp_name);
self->s.insert(v); self->s.insert(v);
return None; return None;
} }
// Note: PySet_Add is allowed to apply to frozenset objects, though CPython has
// an check to make sure the refcount is 1.
// for example, the marshal library uses this to construct frozenset objects.
extern "C" int PySet_Add(PyObject* set, PyObject* key) noexcept {
if (!PyAnySet_Check(set)) {
PyErr_BadInternalCall();
return -1;
}
try {
static_cast<BoxedSet*>(set)->s.insert(key);
return 0;
} catch (ExcInfo e) {
setCAPIException(e);
return -1;
}
}
Box* setRemove(BoxedSet* self, Box* v) { Box* setRemove(BoxedSet* self, Box* v) {
assert(self->cls == set_cls); RELEASE_ASSERT(isSubclass(self->cls, set_cls), "");
auto it = self->s.find(v); auto it = self->s.find(v);
if (it == self->s.end()) { if (it == self->s.end()) {
...@@ -214,7 +233,7 @@ Box* setRemove(BoxedSet* self, Box* v) { ...@@ -214,7 +233,7 @@ Box* setRemove(BoxedSet* self, Box* v) {
} }
Box* setDiscard(BoxedSet* self, Box* v) { Box* setDiscard(BoxedSet* self, Box* v) {
assert(self->cls == set_cls); RELEASE_ASSERT(isSubclass(self->cls, set_cls), "");
auto it = self->s.find(v); auto it = self->s.find(v);
if (it != self->s.end()) if (it != self->s.end())
...@@ -224,13 +243,15 @@ Box* setDiscard(BoxedSet* self, Box* v) { ...@@ -224,13 +243,15 @@ Box* setDiscard(BoxedSet* self, Box* v) {
} }
Box* setClear(BoxedSet* self, Box* v) { Box* setClear(BoxedSet* self, Box* v) {
assert(self->cls == set_cls); RELEASE_ASSERT(isSubclass(self->cls, set_cls), "");
self->s.clear(); self->s.clear();
return None; return None;
} }
Box* setUpdate(BoxedSet* self, BoxedTuple* args) { Box* setUpdate(BoxedSet* self, BoxedTuple* args) {
assert(self->cls == set_cls); RELEASE_ASSERT(isSubclass(self->cls, set_cls), "");
assert(args->cls == tuple_cls); assert(args->cls == tuple_cls);
for (auto l : *args) { for (auto l : *args) {
...@@ -248,7 +269,7 @@ Box* setUpdate(BoxedSet* self, BoxedTuple* args) { ...@@ -248,7 +269,7 @@ Box* setUpdate(BoxedSet* self, BoxedTuple* args) {
} }
Box* setUnion(BoxedSet* self, BoxedTuple* args) { Box* setUnion(BoxedSet* self, BoxedTuple* args) {
if (!isSubclass(self->cls, set_cls)) if (!PyAnySet_Check(self))
raiseExcHelper(TypeError, "descriptor 'union' requires a 'set' object but received a '%s'", getTypeName(self)); raiseExcHelper(TypeError, "descriptor 'union' requires a 'set' object but received a '%s'", getTypeName(self));
BoxedSet* rtn = new BoxedSet(); BoxedSet* rtn = new BoxedSet();
...@@ -279,24 +300,27 @@ Box* setDifference(BoxedSet* self, BoxedTuple* args) { ...@@ -279,24 +300,27 @@ Box* setDifference(BoxedSet* self, BoxedTuple* args) {
return rtn; return rtn;
} }
static BoxedSet* setIntersection2(BoxedSet* self, Box* container) { Box* setDifferenceUpdate(BoxedSet* self, BoxedTuple* args) {
assert(self->cls == set_cls); if (!PySet_Check(self))
raiseExcHelper(TypeError, "descriptor 'difference' requires a 'set' object but received a '%s'",
getTypeName(self));
BoxedSet* rtn = new BoxedSet(); for (auto container : args->pyElements()) {
for (auto elt : container->pyElements()) { for (auto elt : container->pyElements()) {
if (self->s.count(elt)) self->s.erase(elt);
rtn->s.insert(elt); }
} }
return rtn;
return None;
} }
static Box* setIssubset(BoxedSet* self, Box* container) { static Box* setIssubset(BoxedSet* self, Box* container) {
assert(self->cls == set_cls); RELEASE_ASSERT(PyAnySet_Check(self), "");
if (container->cls != set_cls && container->cls != frozenset_cls) { if (!PyAnySet_Check(container)) {
container = setNew(set_cls, container); container = setNew(set_cls, container);
} }
assert(container->cls == set_cls || container->cls == frozenset_cls); assert(PyAnySet_Check(container));
BoxedSet* rhs = static_cast<BoxedSet*>(container); BoxedSet* rhs = static_cast<BoxedSet*>(container);
for (auto e : self->s) { for (auto e : self->s) {
...@@ -307,12 +331,12 @@ static Box* setIssubset(BoxedSet* self, Box* container) { ...@@ -307,12 +331,12 @@ static Box* setIssubset(BoxedSet* self, Box* container) {
} }
static Box* setIssuperset(BoxedSet* self, Box* container) { static Box* setIssuperset(BoxedSet* self, Box* container) {
assert(self->cls == set_cls); RELEASE_ASSERT(PyAnySet_Check(self), "");
if (container->cls != set_cls && container->cls != frozenset_cls) { if (!PyAnySet_Check(container)) {
container = setNew(set_cls, container); container = setNew(set_cls, container);
} }
assert(container->cls == set_cls || container->cls == frozenset_cls); assert(PyAnySet_Check(container));
BoxedSet* rhs = static_cast<BoxedSet*>(container); BoxedSet* rhs = static_cast<BoxedSet*>(container);
for (auto e : rhs->s) { for (auto e : rhs->s) {
...@@ -322,7 +346,28 @@ static Box* setIssuperset(BoxedSet* self, Box* container) { ...@@ -322,7 +346,28 @@ static Box* setIssuperset(BoxedSet* self, Box* container) {
return True; return True;
} }
Box* setIntersection(BoxedSet* self, BoxedTuple* args) { static Box* setIsdisjoint(BoxedSet* self, Box* container) {
RELEASE_ASSERT(PyAnySet_Check(self), "");
for (auto e : container->pyElements()) {
if (self->s.find(e) != self->s.end())
return False;
}
return True;
}
static BoxedSet* setIntersection2(BoxedSet* self, Box* container) {
RELEASE_ASSERT(PyAnySet_Check(self), "");
BoxedSet* rtn = new BoxedSet();
for (auto elt : container->pyElements()) {
if (self->s.count(elt))
rtn->s.insert(elt);
}
return rtn;
}
static Box* setIntersection(BoxedSet* self, BoxedTuple* args) {
if (!PyAnySet_Check(self)) if (!PyAnySet_Check(self))
raiseExcHelper(TypeError, "descriptor 'intersection' requires a 'set' object but received a '%s'", raiseExcHelper(TypeError, "descriptor 'intersection' requires a 'set' object but received a '%s'",
getTypeName(self)); getTypeName(self));
...@@ -335,7 +380,7 @@ Box* setIntersection(BoxedSet* self, BoxedTuple* args) { ...@@ -335,7 +380,7 @@ Box* setIntersection(BoxedSet* self, BoxedTuple* args) {
} }
Box* setCopy(BoxedSet* self) { Box* setCopy(BoxedSet* self) {
assert(self->cls == set_cls); RELEASE_ASSERT(PyAnySet_Check(self), "");
BoxedSet* rtn = new BoxedSet(); BoxedSet* rtn = new BoxedSet();
rtn->s.insert(self->s.begin(), self->s.end()); rtn->s.insert(self->s.begin(), self->s.end());
...@@ -343,7 +388,7 @@ Box* setCopy(BoxedSet* self) { ...@@ -343,7 +388,7 @@ Box* setCopy(BoxedSet* self) {
} }
Box* setPop(BoxedSet* self) { Box* setPop(BoxedSet* self) {
assert(self->cls == set_cls); RELEASE_ASSERT(isSubclass(self->cls, set_cls), "");
if (!self->s.size()) if (!self->s.size())
raiseExcHelper(KeyError, "pop from an empty set"); raiseExcHelper(KeyError, "pop from an empty set");
...@@ -355,13 +400,13 @@ Box* setPop(BoxedSet* self) { ...@@ -355,13 +400,13 @@ Box* setPop(BoxedSet* self) {
} }
Box* setContains(BoxedSet* self, Box* v) { Box* setContains(BoxedSet* self, Box* v) {
assert(PyAnySet_Check(self)); RELEASE_ASSERT(PyAnySet_Check(self), "");
return boxBool(self->s.count(v) != 0); return boxBool(self->s.count(v) != 0);
} }
Box* setEq(BoxedSet* self, BoxedSet* rhs) { Box* setEq(BoxedSet* self, BoxedSet* rhs) {
assert(PyAnySet_Check(self)); RELEASE_ASSERT(PyAnySet_Check(self), "");
if (rhs->cls != set_cls && rhs->cls != frozenset_cls) if (!PyAnySet_Check(rhs))
return NotImplemented; return NotImplemented;
if (self->s.size() != rhs->s.size()) if (self->s.size() != rhs->s.size())
...@@ -383,6 +428,7 @@ Box* setNe(BoxedSet* self, BoxedSet* rhs) { ...@@ -383,6 +428,7 @@ Box* setNe(BoxedSet* self, BoxedSet* rhs) {
} }
Box* setNonzero(BoxedSet* self) { Box* setNonzero(BoxedSet* self) {
RELEASE_ASSERT(PyAnySet_Check(self), "");
return boxBool(self->s.size()); return boxBool(self->s.size());
} }
...@@ -420,21 +466,6 @@ extern "C" PyObject* PyFrozenSet_New(PyObject* iterable) noexcept { ...@@ -420,21 +466,6 @@ extern "C" PyObject* PyFrozenSet_New(PyObject* iterable) noexcept {
} }
} }
extern "C" int PySet_Add(PyObject* set, PyObject* key) noexcept {
if (!PyAnySet_Check(set)) {
PyErr_BadInternalCall();
return -1;
}
try {
setAdd((BoxedSet*)set, key);
return 0;
} catch (ExcInfo e) {
setCAPIException(e);
return -1;
}
}
} // namespace set } // namespace set
using namespace pyston::set; using namespace pyston::set;
...@@ -517,12 +548,17 @@ void setupSet() { ...@@ -517,12 +548,17 @@ void setupSet() {
set_cls->giveAttr("clear", new BoxedFunction(boxRTFunction((void*)setClear, NONE, 1))); set_cls->giveAttr("clear", new BoxedFunction(boxRTFunction((void*)setClear, NONE, 1)));
set_cls->giveAttr("update", new BoxedFunction(boxRTFunction((void*)setUpdate, NONE, 1, 0, true, false))); set_cls->giveAttr("update", new BoxedFunction(boxRTFunction((void*)setUpdate, NONE, 1, 0, true, false)));
set_cls->giveAttr("union", new BoxedFunction(boxRTFunction((void*)setUnion, UNKNOWN, 1, 0, true, false))); set_cls->giveAttr("union", new BoxedFunction(boxRTFunction((void*)setUnion, UNKNOWN, 1, 0, true, false)));
frozenset_cls->giveAttr("union", set_cls->getattr("union"));
set_cls->giveAttr("intersection", set_cls->giveAttr("intersection",
new BoxedFunction(boxRTFunction((void*)setIntersection, UNKNOWN, 1, 0, true, false))); new BoxedFunction(boxRTFunction((void*)setIntersection, UNKNOWN, 1, 0, true, false)));
frozenset_cls->giveAttr("intersection", set_cls->getattr("intersection"));
set_cls->giveAttr("difference", new BoxedFunction(boxRTFunction((void*)setDifference, UNKNOWN, 1, 0, true, false))); set_cls->giveAttr("difference", new BoxedFunction(boxRTFunction((void*)setDifference, UNKNOWN, 1, 0, true, false)));
frozenset_cls->giveAttr("difference", set_cls->getattr("difference")); frozenset_cls->giveAttr("difference", set_cls->getattr("difference"));
set_cls->giveAttr("difference_update",
new BoxedFunction(boxRTFunction((void*)setDifferenceUpdate, UNKNOWN, 1, 0, true, false)));
set_cls->giveAttr("issubset", new BoxedFunction(boxRTFunction((void*)setIssubset, UNKNOWN, 2))); set_cls->giveAttr("issubset", new BoxedFunction(boxRTFunction((void*)setIssubset, UNKNOWN, 2)));
set_cls->giveAttr("issuperset", new BoxedFunction(boxRTFunction((void*)setIssuperset, UNKNOWN, 2))); set_cls->giveAttr("issuperset", new BoxedFunction(boxRTFunction((void*)setIssuperset, UNKNOWN, 2)));
set_cls->giveAttr("isdisjoint", new BoxedFunction(boxRTFunction((void*)setIsdisjoint, UNKNOWN, 2)));
set_cls->giveAttr("copy", new BoxedFunction(boxRTFunction((void*)setCopy, UNKNOWN, 1))); set_cls->giveAttr("copy", new BoxedFunction(boxRTFunction((void*)setCopy, UNKNOWN, 1)));
set_cls->giveAttr("pop", new BoxedFunction(boxRTFunction((void*)setPop, UNKNOWN, 1))); set_cls->giveAttr("pop", new BoxedFunction(boxRTFunction((void*)setPop, UNKNOWN, 1)));
......
...@@ -104,7 +104,28 @@ print s ...@@ -104,7 +104,28 @@ print s
s.discard(1) s.discard(1)
print s print s
s = set(range(5))
s = set(range(10))
print s.difference_update(range(-3, 2), range(7, 23))
print sorted(s)
# Check set subclassing:
class MySet(set):
pass
class MyFrozenset(frozenset):
pass
compare_to = []
for i in xrange(10): for i in xrange(10):
s2 = set(range(i)) compare_to.append(set(range(i)))
print s.issubset(s2), s.issuperset(s2), s == s2, s != s2, s.difference(s2), s.issubset(range(i)), s.issuperset(range(i)) compare_to.append(frozenset(range(i)))
compare_to.append(MySet(range(i)))
compare_to.append(MyFrozenset(range(i)))
compare_to.append(range(i))
compare_to.append(range(i, 10))
for s1 in set(range(5)), frozenset(range(5)):
for s2 in compare_to:
print type(s2), sorted(s2), s.issubset(s2), s.issuperset(s2), s == s2, s != s2, s.difference(s2), s.isdisjoint(s2), sorted(s1.union(s2)), sorted(s1.intersection(s2))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment