Commit 458265c2 authored by Raymond Hettinger's avatar Raymond Hettinger

1. Removed module self test in favor of unittests -- Timbot's suggestion.

2. Replaced calls to Set([]) with Set() -- Timbot's suggestion
3. Fixed subtle bug in sets of sets:

The following code did not work (will add to test suite):
    d = Set('d')
    s = Set([d])  # Stores inner set as an ImmutableSet
    s.remove(d)   # For comparison, wraps d in _TemporarilyImmutableSet

The comparison proceeds by computing the hash of the
_TemporarilyImmutableSet and finding it in the dictionary.
It then verifies equality by calling ImmutableSet.__eq__()
and crashes from the binary sanity check.

The problem is that the code assumed equality would be checked
with _TemporarilyImmutableSet.__eq__().

The solution is to let _TemporarilyImmutableSet derive from BaseSet
so it will pass the sanity check and then to provide it with the
._data element from the wrapped set so that ImmutableSet.__eq__()
will find ._data where it expects.

Since ._data is now provided and because BaseSet is the base class,
_TemporarilyImmutableSet no longer needs .__eq__() or .__ne__().

Note that inheriting all of BaseSet's methods is harmless because
none of those methods (except ones starting with an underscore)
can mutate the .data element.  Also _TemporarilyImmutableSet is only
used internally as is not otherwise visible.
parent e6ddaab6
...@@ -133,7 +133,7 @@ class BaseSet(object): ...@@ -133,7 +133,7 @@ class BaseSet(object):
def copy(self): def copy(self):
"""Return a shallow copy of a set.""" """Return a shallow copy of a set."""
result = self.__class__([]) result = self.__class__()
result._data.update(self._data) result._data.update(self._data)
return result return result
...@@ -147,7 +147,7 @@ class BaseSet(object): ...@@ -147,7 +147,7 @@ class BaseSet(object):
# it can certainly contain an object that has a reference to # it can certainly contain an object that has a reference to
# itself. # itself.
from copy import deepcopy from copy import deepcopy
result = self.__class__([]) result = self.__class__()
memo[id(self)] = result memo[id(self)] = result
data = result._data data = result._data
value = True value = True
...@@ -188,7 +188,7 @@ class BaseSet(object): ...@@ -188,7 +188,7 @@ class BaseSet(object):
little, big = self, other little, big = self, other
else: else:
little, big = other, self little, big = other, self
result = self.__class__([]) result = self.__class__()
data = result._data data = result._data
value = True value = True
for elt in little: for elt in little:
...@@ -210,7 +210,7 @@ class BaseSet(object): ...@@ -210,7 +210,7 @@ class BaseSet(object):
""" """
if not isinstance(other, BaseSet): if not isinstance(other, BaseSet):
return NotImplemented return NotImplemented
result = self.__class__([]) result = self.__class__()
data = result._data data = result._data
value = True value = True
for elt in self: for elt in self:
...@@ -235,7 +235,7 @@ class BaseSet(object): ...@@ -235,7 +235,7 @@ class BaseSet(object):
""" """
if not isinstance(other, BaseSet): if not isinstance(other, BaseSet):
return NotImplemented return NotImplemented
result = self.__class__([]) result = self.__class__()
data = result._data data = result._data
value = True value = True
for elt in self: for elt in self:
...@@ -467,7 +467,7 @@ class Set(BaseSet): ...@@ -467,7 +467,7 @@ class Set(BaseSet):
return _TemporarilyImmutableSet(self) return _TemporarilyImmutableSet(self)
class _TemporarilyImmutableSet(object): class _TemporarilyImmutableSet(BaseSet):
# Wrap a mutable set as if it was temporarily immutable. # Wrap a mutable set as if it was temporarily immutable.
# This only supplies hashing and equality comparisons. # This only supplies hashing and equality comparisons.
...@@ -475,111 +475,9 @@ class _TemporarilyImmutableSet(object): ...@@ -475,111 +475,9 @@ class _TemporarilyImmutableSet(object):
def __init__(self, set): def __init__(self, set):
self._set = set self._set = set
self._data = set._data # Needed by ImmutableSet.__eq__()
def __hash__(self): def __hash__(self):
if self._hashcode is None: if self._hashcode is None:
self._hashcode = self._set._compute_hash() self._hashcode = self._set._compute_hash()
return self._hashcode return self._hashcode
def __eq__(self, other):
return self._set == other
def __ne__(self, other):
return self._set != other
# Rudimentary self-tests
def _test():
# Empty set
red = Set()
assert `red` == "Set([])", "Empty set: %s" % `red`
# Unit set
green = Set((0,))
assert `green` == "Set([0])", "Unit set: %s" % `green`
# 3-element set
blue = Set([0, 1, 2])
assert blue._repr(True) == "Set([0, 1, 2])", "3-element set: %s" % `blue`
# 2-element set with other values
black = Set([0, 5])
assert black._repr(True) == "Set([0, 5])", "2-element set: %s" % `black`
# All elements from all sets
white = Set([0, 1, 2, 5])
assert white._repr(True) == "Set([0, 1, 2, 5])", "4-element set: %s" % `white`
# Add element to empty set
red.add(9)
assert `red` == "Set([9])", "Add to empty set: %s" % `red`
# Remove element from unit set
red.remove(9)
assert `red` == "Set([])", "Remove from unit set: %s" % `red`
# Remove element from empty set
try:
red.remove(0)
assert 0, "Remove element from empty set: %s" % `red`
except LookupError:
pass
# Length
assert len(red) == 0, "Length of empty set"
assert len(green) == 1, "Length of unit set"
assert len(blue) == 3, "Length of 3-element set"
# Compare
assert green == Set([0]), "Equality failed"
assert green != Set([1]), "Inequality failed"
# Union
assert blue | red == blue, "Union non-empty with empty"
assert red | blue == blue, "Union empty with non-empty"
assert green | blue == blue, "Union non-empty with non-empty"
assert blue | black == white, "Enclosing union"
# Intersection
assert blue & red == red, "Intersect non-empty with empty"
assert red & blue == red, "Intersect empty with non-empty"
assert green & blue == green, "Intersect non-empty with non-empty"
assert blue & black == green, "Enclosing intersection"
# Symmetric difference
assert red ^ green == green, "Empty symdiff non-empty"
assert green ^ blue == Set([1, 2]), "Non-empty symdiff"
assert white ^ white == red, "Self symdiff"
# Difference
assert red - green == red, "Empty - non-empty"
assert blue - red == blue, "Non-empty - empty"
assert white - black == Set([1, 2]), "Non-empty - non-empty"
# In-place union
orange = Set([])
orange |= Set([1])
assert orange == Set([1]), "In-place union"
# In-place intersection
orange = Set([1, 2])
orange &= Set([2])
assert orange == Set([2]), "In-place intersection"
# In-place difference
orange = Set([1, 2, 3])
orange -= Set([2, 4])
assert orange == Set([1, 3]), "In-place difference"
# In-place symmetric difference
orange = Set([1, 2, 3])
orange ^= Set([3, 4])
assert orange == Set([1, 2, 4]), "In-place symmetric difference"
print "All tests passed"
if __name__ == "__main__":
_test()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment