Commit f643b9a9 authored by Raymond Hettinger's avatar Raymond Hettinger

Issue 8743: Improve interoperability between sets and the collections.Set abstract base class.

parent 92df7529
...@@ -165,12 +165,17 @@ class Set(Sized, Iterable, Container): ...@@ -165,12 +165,17 @@ class Set(Sized, Iterable, Container):
def __gt__(self, other): def __gt__(self, other):
if not isinstance(other, Set): if not isinstance(other, Set):
return NotImplemented return NotImplemented
return other.__lt__(self) return len(self) > len(other) and self.__ge__(other)
def __ge__(self, other): def __ge__(self, other):
if not isinstance(other, Set): if not isinstance(other, Set):
return NotImplemented return NotImplemented
return other.__le__(self) if len(self) < len(other):
return False
for elem in other:
if elem not in self:
return False
return True
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, Set): if not isinstance(other, Set):
...@@ -194,6 +199,8 @@ class Set(Sized, Iterable, Container): ...@@ -194,6 +199,8 @@ class Set(Sized, Iterable, Container):
return NotImplemented return NotImplemented
return self._from_iterable(value for value in other if value in self) return self._from_iterable(value for value in other if value in self)
__rand__ = __and__
def isdisjoint(self, other): def isdisjoint(self, other):
'Return True if two sets have a null intersection.' 'Return True if two sets have a null intersection.'
for value in other: for value in other:
...@@ -207,6 +214,8 @@ class Set(Sized, Iterable, Container): ...@@ -207,6 +214,8 @@ class Set(Sized, Iterable, Container):
chain = (e for s in (self, other) for e in s) chain = (e for s in (self, other) for e in s)
return self._from_iterable(chain) return self._from_iterable(chain)
__ror__ = __or__
def __sub__(self, other): def __sub__(self, other):
if not isinstance(other, Set): if not isinstance(other, Set):
if not isinstance(other, Iterable): if not isinstance(other, Iterable):
...@@ -215,6 +224,14 @@ class Set(Sized, Iterable, Container): ...@@ -215,6 +224,14 @@ class Set(Sized, Iterable, Container):
return self._from_iterable(value for value in self return self._from_iterable(value for value in self
if value not in other) if value not in other)
def __rsub__(self, other):
if not isinstance(other, Set):
if not isinstance(other, Iterable):
return NotImplemented
other = self._from_iterable(other)
return self._from_iterable(value for value in other
if value not in self)
def __xor__(self, other): def __xor__(self, other):
if not isinstance(other, Set): if not isinstance(other, Set):
if not isinstance(other, Iterable): if not isinstance(other, Iterable):
...@@ -222,6 +239,8 @@ class Set(Sized, Iterable, Container): ...@@ -222,6 +239,8 @@ class Set(Sized, Iterable, Container):
other = self._from_iterable(other) other = self._from_iterable(other)
return (self - other) | (other - self) return (self - other) | (other - self)
__rxor__ = __xor__
# Sets are not hashable by default, but subclasses can change this # Sets are not hashable by default, but subclasses can change this
__hash__ = None __hash__ = None
......
...@@ -8,6 +8,7 @@ import pickle, cPickle, copy ...@@ -8,6 +8,7 @@ import pickle, cPickle, copy
from random import randrange, shuffle from random import randrange, shuffle
import keyword import keyword
import re import re
import sets
import sys import sys
from collections import Hashable, Iterable, Iterator from collections import Hashable, Iterable, Iterator
from collections import Sized, Container, Callable from collections import Sized, Container, Callable
...@@ -618,10 +619,173 @@ class TestCollectionABCs(ABCTestCase): ...@@ -618,10 +619,173 @@ class TestCollectionABCs(ABCTestCase):
cs = MyComparableSet() cs = MyComparableSet()
ncs = MyNonComparableSet() ncs = MyNonComparableSet()
self.assertFalse(ncs < cs)
self.assertFalse(ncs <= cs) # Run all the variants to make sure they don't mutually recurse
self.assertFalse(cs > ncs) ncs < cs
self.assertFalse(cs >= ncs) ncs <= cs
ncs > cs
ncs >= cs
cs < ncs
cs <= ncs
cs > ncs
cs >= ncs
def assertSameSet(self, s1, s2):
# coerce both to a real set then check equality
self.assertEqual(set(s1), set(s2))
def test_Set_interoperability_with_real_sets(self):
# Issue: 8743
class ListSet(Set):
def __init__(self, elements=()):
self.data = []
for elem in elements:
if elem not in self.data:
self.data.append(elem)
def __contains__(self, elem):
return elem in self.data
def __iter__(self):
return iter(self.data)
def __len__(self):
return len(self.data)
def __repr__(self):
return 'Set({!r})'.format(self.data)
r1 = set('abc')
r2 = set('bcd')
r3 = set('abcde')
f1 = ListSet('abc')
f2 = ListSet('bcd')
f3 = ListSet('abcde')
l1 = list('abccba')
l2 = list('bcddcb')
l3 = list('abcdeedcba')
p1 = sets.Set('abc')
p2 = sets.Set('bcd')
p3 = sets.Set('abcde')
target = r1 & r2
self.assertSameSet(f1 & f2, target)
self.assertSameSet(f1 & r2, target)
self.assertSameSet(r2 & f1, target)
self.assertSameSet(f1 & p2, target)
self.assertSameSet(p2 & f1, target)
self.assertSameSet(f1 & l2, target)
target = r1 | r2
self.assertSameSet(f1 | f2, target)
self.assertSameSet(f1 | r2, target)
self.assertSameSet(r2 | f1, target)
self.assertSameSet(f1 | p2, target)
self.assertSameSet(p2 | f1, target)
self.assertSameSet(f1 | l2, target)
fwd_target = r1 - r2
rev_target = r2 - r1
self.assertSameSet(f1 - f2, fwd_target)
self.assertSameSet(f2 - f1, rev_target)
self.assertSameSet(f1 - r2, fwd_target)
self.assertSameSet(f2 - r1, rev_target)
self.assertSameSet(r1 - f2, fwd_target)
self.assertSameSet(r2 - f1, rev_target)
self.assertSameSet(f1 - p2, fwd_target)
self.assertSameSet(f2 - p1, rev_target)
self.assertSameSet(p1 - f2, fwd_target)
self.assertSameSet(p2 - f1, rev_target)
self.assertSameSet(f1 - l2, fwd_target)
self.assertSameSet(f2 - l1, rev_target)
target = r1 ^ r2
self.assertSameSet(f1 ^ f2, target)
self.assertSameSet(f1 ^ r2, target)
self.assertSameSet(r2 ^ f1, target)
self.assertSameSet(f1 ^ p2, target)
self.assertSameSet(p2 ^ f1, target)
self.assertSameSet(f1 ^ l2, target)
# proper subset
self.assertTrue(f1 < f3)
self.assertFalse(f1 < f1)
self.assertFalse(f1 < f2)
self.assertTrue(r1 < f3)
self.assertFalse(r1 < f1)
self.assertFalse(r1 < f2)
self.assertTrue(r1 < r3)
self.assertFalse(r1 < r1)
self.assertFalse(r1 < r2)
# python 2 only, cross-type compares will succeed
f1 < l3
f1 < l1
f1 < l2
# any subset
self.assertTrue(f1 <= f3)
self.assertTrue(f1 <= f1)
self.assertFalse(f1 <= f2)
self.assertTrue(r1 <= f3)
self.assertTrue(r1 <= f1)
self.assertFalse(r1 <= f2)
self.assertTrue(r1 <= r3)
self.assertTrue(r1 <= r1)
self.assertFalse(r1 <= r2)
# python 2 only, cross-type compares will succeed
f1 <= l3
f1 <= l1
f1 <= l2
# proper superset
self.assertTrue(f3 > f1)
self.assertFalse(f1 > f1)
self.assertFalse(f2 > f1)
self.assertTrue(r3 > r1)
self.assertFalse(f1 > r1)
self.assertFalse(f2 > r1)
self.assertTrue(r3 > r1)
self.assertFalse(r1 > r1)
self.assertFalse(r2 > r1)
# python 2 only, cross-type compares will succeed
f1 > l3
f1 > l1
f1 > l2
# any superset
self.assertTrue(f3 >= f1)
self.assertTrue(f1 >= f1)
self.assertFalse(f2 >= f1)
self.assertTrue(r3 >= r1)
self.assertTrue(f1 >= r1)
self.assertFalse(f2 >= r1)
self.assertTrue(r3 >= r1)
self.assertTrue(r1 >= r1)
self.assertFalse(r2 >= r1)
# python 2 only, cross-type compares will succeed
f1 >= l3
f1 >=l1
f1 >= l2
# equality
self.assertTrue(f1 == f1)
self.assertTrue(r1 == f1)
self.assertTrue(f1 == r1)
self.assertFalse(f1 == f3)
self.assertFalse(r1 == f3)
self.assertFalse(f1 == r3)
# python 2 only, cross-type compares will succeed
f1 == l3
f1 == l1
f1 == l2
# inequality
self.assertFalse(f1 != f1)
self.assertFalse(r1 != f1)
self.assertFalse(f1 != r1)
self.assertTrue(f1 != f3)
self.assertTrue(r1 != f3)
self.assertTrue(f1 != r3)
# python 2 only, cross-type compares will succeed
f1 != l3
f1 != l1
f1 != l2
def test_Mapping(self): def test_Mapping(self):
for sample in [dict]: for sample in [dict]:
......
...@@ -1017,8 +1017,6 @@ class TestBinaryOps(unittest.TestCase): ...@@ -1017,8 +1017,6 @@ class TestBinaryOps(unittest.TestCase):
# without calling __cmp__. # without calling __cmp__.
self.assertEqual(cmp(a, a), 0) self.assertEqual(cmp(a, a), 0)
self.assertRaises(TypeError, cmp, a, 12)
self.assertRaises(TypeError, cmp, "abc", a)
#============================================================================== #==============================================================================
...@@ -1269,17 +1267,6 @@ class TestOnlySetsInBinaryOps(unittest.TestCase): ...@@ -1269,17 +1267,6 @@ class TestOnlySetsInBinaryOps(unittest.TestCase):
self.assertEqual(self.other != self.set, True) self.assertEqual(self.other != self.set, True)
self.assertEqual(self.set != self.other, True) self.assertEqual(self.set != self.other, True)
def test_ge_gt_le_lt(self):
self.assertRaises(TypeError, lambda: self.set < self.other)
self.assertRaises(TypeError, lambda: self.set <= self.other)
self.assertRaises(TypeError, lambda: self.set > self.other)
self.assertRaises(TypeError, lambda: self.set >= self.other)
self.assertRaises(TypeError, lambda: self.other < self.set)
self.assertRaises(TypeError, lambda: self.other <= self.set)
self.assertRaises(TypeError, lambda: self.other > self.set)
self.assertRaises(TypeError, lambda: self.other >= self.set)
def test_update_operator(self): def test_update_operator(self):
try: try:
self.set |= self.other self.set |= self.other
...@@ -1392,18 +1379,6 @@ class TestOnlySetsDict(TestOnlySetsInBinaryOps): ...@@ -1392,18 +1379,6 @@ class TestOnlySetsDict(TestOnlySetsInBinaryOps):
#------------------------------------------------------------------------------ #------------------------------------------------------------------------------
class TestOnlySetsOperator(TestOnlySetsInBinaryOps):
def setUp(self):
self.set = set((1, 2, 3))
self.other = operator.add
self.otherIsIterable = False
def test_ge_gt_le_lt(self):
with test_support.check_py3k_warnings():
super(TestOnlySetsOperator, self).test_ge_gt_le_lt()
#------------------------------------------------------------------------------
class TestOnlySetsTuple(TestOnlySetsInBinaryOps): class TestOnlySetsTuple(TestOnlySetsInBinaryOps):
def setUp(self): def setUp(self):
self.set = set((1, 2, 3)) self.set = set((1, 2, 3))
...@@ -1801,7 +1776,6 @@ def test_main(verbose=None): ...@@ -1801,7 +1776,6 @@ def test_main(verbose=None):
TestSubsetNonOverlap, TestSubsetNonOverlap,
TestOnlySetsNumeric, TestOnlySetsNumeric,
TestOnlySetsDict, TestOnlySetsDict,
TestOnlySetsOperator,
TestOnlySetsTuple, TestOnlySetsTuple,
TestOnlySetsString, TestOnlySetsString,
TestOnlySetsGenerator, TestOnlySetsGenerator,
......
...@@ -18,6 +18,9 @@ Core and Builtins ...@@ -18,6 +18,9 @@ Core and Builtins
Library Library
------- -------
- Issue #8743: Fix interoperability between set objects and the
collections.Set() abstract base class.
Tests Tests
----- -----
......
...@@ -1796,12 +1796,8 @@ set_richcompare(PySetObject *v, PyObject *w, int op) ...@@ -1796,12 +1796,8 @@ set_richcompare(PySetObject *v, PyObject *w, int op)
PyObject *r1, *r2; PyObject *r1, *r2;
if(!PyAnySet_Check(w)) { if(!PyAnySet_Check(w)) {
if (op == Py_EQ) Py_INCREF(Py_NotImplemented);
Py_RETURN_FALSE; return Py_NotImplemented;
if (op == Py_NE)
Py_RETURN_TRUE;
PyErr_SetString(PyExc_TypeError, "can only compare to a set");
return NULL;
} }
switch (op) { switch (op) {
case Py_EQ: case Py_EQ:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment