Commit 1d879f68 authored by Raymond Hettinger's avatar Raymond Hettinger

Backport r87613 to make OrderedDict subclassing match dict subclassing.

parent db0ef2b5
...@@ -21,7 +21,7 @@ from itertools import repeat as _repeat, chain as _chain, starmap as _starmap ...@@ -21,7 +21,7 @@ from itertools import repeat as _repeat, chain as _chain, starmap as _starmap
class _Link(object): class _Link(object):
__slots__ = 'prev', 'next', 'key', '__weakref__' __slots__ = 'prev', 'next', 'key', '__weakref__'
class OrderedDict(dict, MutableMapping): class OrderedDict(dict):
'Dictionary that remembers insertion order' 'Dictionary that remembers insertion order'
# An inherited dict maps keys to values. # An inherited dict maps keys to values.
# The inherited dict provides __getitem__, __len__, __contains__, and get. # The inherited dict provides __getitem__, __len__, __contains__, and get.
...@@ -50,7 +50,7 @@ class OrderedDict(dict, MutableMapping): ...@@ -50,7 +50,7 @@ class OrderedDict(dict, MutableMapping):
self.__root = root = _Link() # sentinel node for the doubly linked list self.__root = root = _Link() # sentinel node for the doubly linked list
root.prev = root.next = root root.prev = root.next = root
self.__map = {} self.__map = {}
self.update(*args, **kwds) self.__update(*args, **kwds)
def clear(self): def clear(self):
'od.clear() -> None. Remove all items from od.' 'od.clear() -> None. Remove all items from od.'
...@@ -109,13 +109,29 @@ class OrderedDict(dict, MutableMapping): ...@@ -109,13 +109,29 @@ class OrderedDict(dict, MutableMapping):
return (self.__class__, (items,), inst_dict) return (self.__class__, (items,), inst_dict)
return self.__class__, (items,) return self.__class__, (items,)
setdefault = MutableMapping.setdefault update = __update = MutableMapping.update
update = MutableMapping.update
pop = MutableMapping.pop
keys = MutableMapping.keys keys = MutableMapping.keys
values = MutableMapping.values values = MutableMapping.values
items = MutableMapping.items items = MutableMapping.items
__marker = object()
def pop(self, key, default=__marker):
if key in self:
result = self[key]
del self[key]
return result
if default is self.__marker:
raise KeyError(key)
return default
def setdefault(self, key, default=None):
'od.setdefault(k[,d]) -> od.get(k,d), also set od[k]=d if k not in od'
if key in self:
return self[key]
self[key] = default
return default
def popitem(self, last=True): def popitem(self, last=True):
'''od.popitem() -> (k, v), return and remove a (key, value) pair. '''od.popitem() -> (k, v), return and remove a (key, value) pair.
Pairs are returned in LIFO order if last is true or FIFO order if false. Pairs are returned in LIFO order if last is true or FIFO order if false.
......
...@@ -792,6 +792,10 @@ class TestOrderedDict(unittest.TestCase): ...@@ -792,6 +792,10 @@ class TestOrderedDict(unittest.TestCase):
self.assertEqual(list(d.items()), self.assertEqual(list(d.items()),
[('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6), ('g', 7)]) [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6), ('g', 7)])
def test_abc(self):
self.assertTrue(isinstance(OrderedDict(), MutableMapping))
self.assertTrue(issubclass(OrderedDict, MutableMapping))
def test_clear(self): def test_clear(self):
pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]
shuffle(pairs) shuffle(pairs)
...@@ -850,6 +854,17 @@ class TestOrderedDict(unittest.TestCase): ...@@ -850,6 +854,17 @@ class TestOrderedDict(unittest.TestCase):
self.assertEqual(len(od), 0) self.assertEqual(len(od), 0)
self.assertEqual(od.pop(k, 12345), 12345) self.assertEqual(od.pop(k, 12345), 12345)
# make sure pop still works when __missing__ is defined
class Missing(OrderedDict):
def __missing__(self, key):
return 0
m = Missing(a=1)
self.assertEqual(m.pop('b', 5), 5)
self.assertEqual(m.pop('a', 6), 1)
self.assertEqual(m.pop('a', 6), 6)
with self.assertRaises(KeyError):
m.pop('a')
def test_equality(self): def test_equality(self):
pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]
shuffle(pairs) shuffle(pairs)
...@@ -934,6 +949,12 @@ class TestOrderedDict(unittest.TestCase): ...@@ -934,6 +949,12 @@ class TestOrderedDict(unittest.TestCase):
# make sure 'x' is added to the end # make sure 'x' is added to the end
self.assertEqual(list(od.items())[-1], ('x', 10)) self.assertEqual(list(od.items())[-1], ('x', 10))
# make sure setdefault still works when __missing__ is defined
class Missing(OrderedDict):
def __missing__(self, key):
return 0
self.assertEqual(Missing().setdefault(5, 9), 9)
def test_reinsert(self): def test_reinsert(self):
# Given insert a, insert b, delete a, re-insert a, # Given insert a, insert b, delete a, re-insert a,
# verify that a is now later than b. # verify that a is now later than b.
...@@ -945,6 +966,13 @@ class TestOrderedDict(unittest.TestCase): ...@@ -945,6 +966,13 @@ class TestOrderedDict(unittest.TestCase):
self.assertEqual(list(od.items()), [('b', 2), ('a', 1)]) self.assertEqual(list(od.items()), [('b', 2), ('a', 1)])
def test_override_update(self):
# Verify that subclasses can override update() without breaking __init__()
class MyOD(OrderedDict):
def update(self, *args, **kwds):
raise Exception()
items = [('a', 1), ('c', 3), ('b', 2)]
self.assertEqual(list(MyOD(items).items()), items)
class GeneralMappingTests(mapping_tests.BasicTestMappingProtocol): class GeneralMappingTests(mapping_tests.BasicTestMappingProtocol):
type2test = OrderedDict type2test = OrderedDict
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment