Commit 9c47ac05 authored by Antoine Pitrou's avatar Antoine Pitrou

Fix some set algebra methods of WeakSet objects.

parent de89d4b0
...@@ -121,26 +121,14 @@ class WeakSet: ...@@ -121,26 +121,14 @@ class WeakSet:
self.update(other) self.update(other)
return self return self
# Helper functions for simple delegating methods.
def _apply(self, other, method):
if not isinstance(other, self.__class__):
other = self.__class__(other)
newdata = method(other.data)
newset = self.__class__()
newset.data = newdata
return newset
def difference(self, other): def difference(self, other):
return self._apply(other, self.data.difference) newset = self.copy()
newset.difference_update(other)
return newset
__sub__ = difference __sub__ = difference
def difference_update(self, other): def difference_update(self, other):
if self._pending_removals: self.__isub__(other)
self._commit_removals()
if self is other:
self.data.clear()
else:
self.data.difference_update(ref(item) for item in other)
def __isub__(self, other): def __isub__(self, other):
if self._pending_removals: if self._pending_removals:
self._commit_removals() self._commit_removals()
...@@ -151,13 +139,11 @@ class WeakSet: ...@@ -151,13 +139,11 @@ class WeakSet:
return self return self
def intersection(self, other): def intersection(self, other):
return self._apply(other, self.data.intersection) return self.__class__(item for item in other if item in self)
__and__ = intersection __and__ = intersection
def intersection_update(self, other): def intersection_update(self, other):
if self._pending_removals: self.__iand__(other)
self._commit_removals()
self.data.intersection_update(ref(item) for item in other)
def __iand__(self, other): def __iand__(self, other):
if self._pending_removals: if self._pending_removals:
self._commit_removals() self._commit_removals()
...@@ -184,27 +170,24 @@ class WeakSet: ...@@ -184,27 +170,24 @@ class WeakSet:
return self.data == set(ref(item) for item in other) return self.data == set(ref(item) for item in other)
def symmetric_difference(self, other): def symmetric_difference(self, other):
return self._apply(other, self.data.symmetric_difference) newset = self.copy()
newset.symmetric_difference_update(other)
return newset
__xor__ = symmetric_difference __xor__ = symmetric_difference
def symmetric_difference_update(self, other): def symmetric_difference_update(self, other):
if self._pending_removals: self.__ixor__(other)
self._commit_removals()
if self is other:
self.data.clear()
else:
self.data.symmetric_difference_update(ref(item) for item in other)
def __ixor__(self, other): def __ixor__(self, other):
if self._pending_removals: if self._pending_removals:
self._commit_removals() self._commit_removals()
if self is other: if self is other:
self.data.clear() self.data.clear()
else: else:
self.data.symmetric_difference_update(ref(item) for item in other) self.data.symmetric_difference_update(ref(item, self._remove) for item in other)
return self return self
def union(self, other): def union(self, other):
return self._apply(other, self.data.union) return self.__class__(e for s in (self, other) for e in s)
__or__ = union __or__ = union
def isdisjoint(self, other): def isdisjoint(self, other):
......
...@@ -71,6 +71,11 @@ class TestWeakSet(unittest.TestCase): ...@@ -71,6 +71,11 @@ class TestWeakSet(unittest.TestCase):
x = WeakSet(self.items + self.items2) x = WeakSet(self.items + self.items2)
c = C(self.items2) c = C(self.items2)
self.assertEqual(self.s.union(c), x) self.assertEqual(self.s.union(c), x)
del c
self.assertEqual(len(u), len(self.items) + len(self.items2))
self.items2.pop()
gc.collect()
self.assertEqual(len(u), len(self.items) + len(self.items2))
def test_or(self): def test_or(self):
i = self.s.union(self.items2) i = self.s.union(self.items2)
...@@ -78,14 +83,19 @@ class TestWeakSet(unittest.TestCase): ...@@ -78,14 +83,19 @@ class TestWeakSet(unittest.TestCase):
self.assertEqual(self.s | frozenset(self.items2), i) self.assertEqual(self.s | frozenset(self.items2), i)
def test_intersection(self): def test_intersection(self):
i = self.s.intersection(self.items2) s = WeakSet(self.letters)
i = s.intersection(self.items2)
for c in self.letters: for c in self.letters:
self.assertEqual(c in i, c in self.d and c in self.items2) self.assertEqual(c in i, c in self.items2 and c in self.letters)
self.assertEqual(self.s, WeakSet(self.items)) self.assertEqual(s, WeakSet(self.letters))
self.assertEqual(type(i), WeakSet) self.assertEqual(type(i), WeakSet)
for C in set, frozenset, dict.fromkeys, list, tuple: for C in set, frozenset, dict.fromkeys, list, tuple:
x = WeakSet([]) x = WeakSet([])
self.assertEqual(self.s.intersection(C(self.items2)), x) self.assertEqual(i.intersection(C(self.items)), x)
self.assertEqual(len(i), len(self.items2))
self.items2.pop()
gc.collect()
self.assertEqual(len(i), len(self.items2))
def test_isdisjoint(self): def test_isdisjoint(self):
self.assertTrue(self.s.isdisjoint(WeakSet(self.items2))) self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
...@@ -116,6 +126,10 @@ class TestWeakSet(unittest.TestCase): ...@@ -116,6 +126,10 @@ class TestWeakSet(unittest.TestCase):
self.assertEqual(self.s, WeakSet(self.items)) self.assertEqual(self.s, WeakSet(self.items))
self.assertEqual(type(i), WeakSet) self.assertEqual(type(i), WeakSet)
self.assertRaises(TypeError, self.s.symmetric_difference, [[]]) self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
self.assertEqual(len(i), len(self.items) + len(self.items2))
self.items2.pop()
gc.collect()
self.assertEqual(len(i), len(self.items) + len(self.items2))
def test_xor(self): def test_xor(self):
i = self.s.symmetric_difference(self.items2) i = self.s.symmetric_difference(self.items2)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment