Issue #7105: weak dict iterators are fragile because of unpredictable GC runs

Backport the fix from pyton 3.x for this issue.
parent c289fa75
...@@ -163,7 +163,7 @@ than needed. ...@@ -163,7 +163,7 @@ than needed.
.. method:: WeakKeyDictionary.iterkeyrefs() .. method:: WeakKeyDictionary.iterkeyrefs()
Return an :term:`iterator` that yields the weak references to the keys. Return an iterable of the weak references to the keys.
.. versionadded:: 2.5 .. versionadded:: 2.5
...@@ -195,7 +195,7 @@ methods of :class:`WeakKeyDictionary` objects. ...@@ -195,7 +195,7 @@ methods of :class:`WeakKeyDictionary` objects.
.. method:: WeakValueDictionary.itervaluerefs() .. method:: WeakValueDictionary.itervaluerefs()
Return an :term:`iterator` that yields the weak references to the values. Return an iterable of the weak references to the values.
.. versionadded:: 2.5 .. versionadded:: 2.5
......
...@@ -4,6 +4,8 @@ import unittest ...@@ -4,6 +4,8 @@ import unittest
import UserList import UserList
import weakref import weakref
import operator import operator
import contextlib
import copy
from test import test_support from test import test_support
...@@ -903,7 +905,7 @@ class MappingTestCase(TestBase): ...@@ -903,7 +905,7 @@ class MappingTestCase(TestBase):
def check_len_cycles(self, dict_type, cons): def check_len_cycles(self, dict_type, cons):
N = 20 N = 20
items = [RefCycle() for i in range(N)] items = [RefCycle() for i in range(N)]
dct = dict_type(cons(o) for o in items) dct = dict_type(cons(i, o) for i, o in enumerate(items))
# Keep an iterator alive # Keep an iterator alive
it = dct.iteritems() it = dct.iteritems()
try: try:
...@@ -913,18 +915,23 @@ class MappingTestCase(TestBase): ...@@ -913,18 +915,23 @@ class MappingTestCase(TestBase):
del items del items
gc.collect() gc.collect()
n1 = len(dct) n1 = len(dct)
list(it)
del it del it
gc.collect() gc.collect()
n2 = len(dct) n2 = len(dct)
# one item may be kept alive inside the iterator # iteration should prevent garbage collection here
self.assertIn(n1, (0, 1)) # Note that this is a test on an implementation detail. The requirement
# is only to provide stable iteration, not that the size of the container
# stay fixed.
self.assertEqual(n1, 20)
#self.assertIn(n1, (0, 1))
self.assertEqual(n2, 0) self.assertEqual(n2, 0)
def test_weak_keyed_len_cycles(self): def test_weak_keyed_len_cycles(self):
self.check_len_cycles(weakref.WeakKeyDictionary, lambda k: (k, 1)) self.check_len_cycles(weakref.WeakKeyDictionary, lambda n, k: (k, n))
def test_weak_valued_len_cycles(self): def test_weak_valued_len_cycles(self):
self.check_len_cycles(weakref.WeakValueDictionary, lambda k: (1, k)) self.check_len_cycles(weakref.WeakValueDictionary, lambda n, k: (n, k))
def check_len_race(self, dict_type, cons): def check_len_race(self, dict_type, cons):
# Extended sanity checks for len() in the face of cyclic collection # Extended sanity checks for len() in the face of cyclic collection
...@@ -1090,6 +1097,86 @@ class MappingTestCase(TestBase): ...@@ -1090,6 +1097,86 @@ class MappingTestCase(TestBase):
self.assertEqual(len(values), 0, self.assertEqual(len(values), 0,
"itervalues() did not touch all values") "itervalues() did not touch all values")
def check_weak_destroy_while_iterating(self, dict, objects, iter_name):
n = len(dict)
it = iter(getattr(dict, iter_name)())
next(it) # Trigger internal iteration
# Destroy an object
del objects[-1]
gc.collect() # just in case
# We have removed either the first consumed object, or another one
self.assertIn(len(list(it)), [len(objects), len(objects) - 1])
del it
# The removal has been committed
self.assertEqual(len(dict), n - 1)
def check_weak_destroy_and_mutate_while_iterating(self, dict, testcontext):
# Check that we can explicitly mutate the weak dict without
# interfering with delayed removal.
# `testcontext` should create an iterator, destroy one of the
# weakref'ed objects and then return a new key/value pair corresponding
# to the destroyed object.
with testcontext() as (k, v):
self.assertFalse(k in dict)
with testcontext() as (k, v):
self.assertRaises(KeyError, dict.__delitem__, k)
self.assertFalse(k in dict)
with testcontext() as (k, v):
self.assertRaises(KeyError, dict.pop, k)
self.assertFalse(k in dict)
with testcontext() as (k, v):
dict[k] = v
self.assertEqual(dict[k], v)
ddict = copy.copy(dict)
with testcontext() as (k, v):
dict.update(ddict)
self.assertEqual(dict, ddict)
with testcontext() as (k, v):
dict.clear()
self.assertEqual(len(dict), 0)
def test_weak_keys_destroy_while_iterating(self):
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
dict, objects = self.make_weak_keyed_dict()
self.check_weak_destroy_while_iterating(dict, objects, 'iterkeys')
self.check_weak_destroy_while_iterating(dict, objects, 'iteritems')
self.check_weak_destroy_while_iterating(dict, objects, 'itervalues')
self.check_weak_destroy_while_iterating(dict, objects, 'iterkeyrefs')
dict, objects = self.make_weak_keyed_dict()
@contextlib.contextmanager
def testcontext():
try:
it = iter(dict.iteritems())
next(it)
# Schedule a key/value for removal and recreate it
v = objects.pop().arg
gc.collect() # just in case
yield Object(v), v
finally:
it = None # should commit all removals
self.check_weak_destroy_and_mutate_while_iterating(dict, testcontext)
def test_weak_values_destroy_while_iterating(self):
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
dict, objects = self.make_weak_valued_dict()
self.check_weak_destroy_while_iterating(dict, objects, 'iterkeys')
self.check_weak_destroy_while_iterating(dict, objects, 'iteritems')
self.check_weak_destroy_while_iterating(dict, objects, 'itervalues')
self.check_weak_destroy_while_iterating(dict, objects, 'itervaluerefs')
dict, objects = self.make_weak_valued_dict()
@contextlib.contextmanager
def testcontext():
try:
it = iter(dict.iteritems())
next(it)
# Schedule a key/value for removal and recreate it
k = objects.pop().arg
gc.collect() # just in case
yield k, Object(k)
finally:
it = None # should commit all removals
self.check_weak_destroy_and_mutate_while_iterating(dict, testcontext)
def test_make_weak_keyed_dict_from_dict(self): def test_make_weak_keyed_dict_from_dict(self):
o = Object(3) o = Object(3)
dict = weakref.WeakKeyDictionary({o:364}) dict = weakref.WeakKeyDictionary({o:364})
......
...@@ -11,6 +11,7 @@ import warnings ...@@ -11,6 +11,7 @@ import warnings
import collections import collections
import gc import gc
import contextlib import contextlib
from UserString import UserString as ustr
class Foo: class Foo:
...@@ -448,6 +449,54 @@ class TestWeakSet(unittest.TestCase): ...@@ -448,6 +449,54 @@ class TestWeakSet(unittest.TestCase):
self.assertGreaterEqual(n2, 0) self.assertGreaterEqual(n2, 0)
self.assertLessEqual(n2, n1) self.assertLessEqual(n2, n1)
def test_weak_destroy_while_iterating(self):
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
# Create new items to be sure no-one else holds a reference
items = [ustr(c) for c in ('a', 'b', 'c')]
s = WeakSet(items)
it = iter(s)
next(it) # Trigger internal iteration
# Destroy an item
del items[-1]
gc.collect() # just in case
# We have removed either the first consumed items, or another one
self.assertIn(len(list(it)), [len(items), len(items) - 1])
del it
# The removal has been committed
self.assertEqual(len(s), len(items))
def test_weak_destroy_and_mutate_while_iterating(self):
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
items = [ustr(c) for c in string.ascii_letters]
s = WeakSet(items)
@contextlib.contextmanager
def testcontext():
try:
it = iter(s)
next(it)
# Schedule an item for removal and recreate it
u = ustr(str(items.pop()))
gc.collect() # just in case
yield u
finally:
it = None # should commit all removals
with testcontext() as u:
self.assertFalse(u in s)
with testcontext() as u:
self.assertRaises(KeyError, s.remove, u)
self.assertFalse(u in s)
with testcontext() as u:
s.add(u)
self.assertTrue(u in s)
t = s.copy()
with testcontext() as u:
s.update(t)
self.assertEqual(len(s), len(t))
with testcontext() as u:
s.clear()
self.assertEqual(len(s), 0)
def test_main(verbose=None): def test_main(verbose=None):
test_support.run_unittest(TestWeakSet) test_support.run_unittest(TestWeakSet)
......
...@@ -20,7 +20,7 @@ from _weakref import ( ...@@ -20,7 +20,7 @@ from _weakref import (
ProxyType, ProxyType,
ReferenceType) ReferenceType)
from _weakrefset import WeakSet from _weakrefset import WeakSet, _IterationGuard
from exceptions import ReferenceError from exceptions import ReferenceError
...@@ -48,10 +48,24 @@ class WeakValueDictionary(UserDict.UserDict): ...@@ -48,10 +48,24 @@ class WeakValueDictionary(UserDict.UserDict):
def remove(wr, selfref=ref(self)): def remove(wr, selfref=ref(self)):
self = selfref() self = selfref()
if self is not None: if self is not None:
if self._iterating:
self._pending_removals.append(wr.key)
else:
del self.data[wr.key] del self.data[wr.key]
self._remove = remove self._remove = remove
# A list of keys to be removed
self._pending_removals = []
self._iterating = set()
UserDict.UserDict.__init__(self, *args, **kw) UserDict.UserDict.__init__(self, *args, **kw)
def _commit_removals(self):
l = self._pending_removals
d = self.data
# We shouldn't encounter any KeyError, because this method should
# always be called *before* mutating the dict.
while l:
del d[l.pop()]
def __getitem__(self, key): def __getitem__(self, key):
o = self.data[key]() o = self.data[key]()
if o is None: if o is None:
...@@ -59,6 +73,11 @@ class WeakValueDictionary(UserDict.UserDict): ...@@ -59,6 +73,11 @@ class WeakValueDictionary(UserDict.UserDict):
else: else:
return o return o
def __delitem__(self, key):
if self._pending_removals:
self._commit_removals()
del self.data[key]
def __contains__(self, key): def __contains__(self, key):
try: try:
o = self.data[key]() o = self.data[key]()
...@@ -77,8 +96,15 @@ class WeakValueDictionary(UserDict.UserDict): ...@@ -77,8 +96,15 @@ class WeakValueDictionary(UserDict.UserDict):
return "<WeakValueDictionary at %s>" % id(self) return "<WeakValueDictionary at %s>" % id(self)
def __setitem__(self, key, value): def __setitem__(self, key, value):
if self._pending_removals:
self._commit_removals()
self.data[key] = KeyedRef(value, self._remove, key) self.data[key] = KeyedRef(value, self._remove, key)
def clear(self):
if self._pending_removals:
self._commit_removals()
self.data.clear()
def copy(self): def copy(self):
new = WeakValueDictionary() new = WeakValueDictionary()
for key, wr in self.data.items(): for key, wr in self.data.items():
...@@ -120,16 +146,18 @@ class WeakValueDictionary(UserDict.UserDict): ...@@ -120,16 +146,18 @@ class WeakValueDictionary(UserDict.UserDict):
return L return L
def iteritems(self): def iteritems(self):
with _IterationGuard(self):
for wr in self.data.itervalues(): for wr in self.data.itervalues():
value = wr() value = wr()
if value is not None: if value is not None:
yield wr.key, value yield wr.key, value
def iterkeys(self): def iterkeys(self):
return self.data.iterkeys() with _IterationGuard(self):
for k in self.data.iterkeys():
yield k
def __iter__(self): __iter__ = iterkeys
return self.data.iterkeys()
def itervaluerefs(self): def itervaluerefs(self):
"""Return an iterator that yields the weak references to the values. """Return an iterator that yields the weak references to the values.
...@@ -141,15 +169,20 @@ class WeakValueDictionary(UserDict.UserDict): ...@@ -141,15 +169,20 @@ class WeakValueDictionary(UserDict.UserDict):
keep the values around longer than needed. keep the values around longer than needed.
""" """
return self.data.itervalues() with _IterationGuard(self):
for wr in self.data.itervalues():
yield wr
def itervalues(self): def itervalues(self):
with _IterationGuard(self):
for wr in self.data.itervalues(): for wr in self.data.itervalues():
obj = wr() obj = wr()
if obj is not None: if obj is not None:
yield obj yield obj
def popitem(self): def popitem(self):
if self._pending_removals:
self._commit_removals()
while 1: while 1:
key, wr = self.data.popitem() key, wr = self.data.popitem()
o = wr() o = wr()
...@@ -157,6 +190,8 @@ class WeakValueDictionary(UserDict.UserDict): ...@@ -157,6 +190,8 @@ class WeakValueDictionary(UserDict.UserDict):
return key, o return key, o
def pop(self, key, *args): def pop(self, key, *args):
if self._pending_removals:
self._commit_removals()
try: try:
o = self.data.pop(key)() o = self.data.pop(key)()
except KeyError: except KeyError:
...@@ -172,12 +207,16 @@ class WeakValueDictionary(UserDict.UserDict): ...@@ -172,12 +207,16 @@ class WeakValueDictionary(UserDict.UserDict):
try: try:
wr = self.data[key] wr = self.data[key]
except KeyError: except KeyError:
if self._pending_removals:
self._commit_removals()
self.data[key] = KeyedRef(default, self._remove, key) self.data[key] = KeyedRef(default, self._remove, key)
return default return default
else: else:
return wr() return wr()
def update(self, dict=None, **kwargs): def update(self, dict=None, **kwargs):
if self._pending_removals:
self._commit_removals()
d = self.data d = self.data
if dict is not None: if dict is not None:
if not hasattr(dict, "items"): if not hasattr(dict, "items"):
...@@ -245,9 +284,29 @@ class WeakKeyDictionary(UserDict.UserDict): ...@@ -245,9 +284,29 @@ class WeakKeyDictionary(UserDict.UserDict):
def remove(k, selfref=ref(self)): def remove(k, selfref=ref(self)):
self = selfref() self = selfref()
if self is not None: if self is not None:
if self._iterating:
self._pending_removals.append(k)
else:
del self.data[k] del self.data[k]
self._remove = remove self._remove = remove
if dict is not None: self.update(dict) # A list of dead weakrefs (keys to be removed)
self._pending_removals = []
self._iterating = set()
if dict is not None:
self.update(dict)
def _commit_removals(self):
# NOTE: We don't need to call this method before mutating the dict,
# because a dead weakref never compares equal to a live weakref,
# even if they happened to refer to equal objects.
# However, it means keys may already have been removed.
l = self._pending_removals
d = self.data
while l:
try:
del d[l.pop()]
except KeyError:
pass
def __delitem__(self, key): def __delitem__(self, key):
del self.data[ref(key)] del self.data[ref(key)]
...@@ -306,6 +365,7 @@ class WeakKeyDictionary(UserDict.UserDict): ...@@ -306,6 +365,7 @@ class WeakKeyDictionary(UserDict.UserDict):
return L return L
def iteritems(self): def iteritems(self):
with _IterationGuard(self):
for wr, value in self.data.iteritems(): for wr, value in self.data.iteritems():
key = wr() key = wr()
if key is not None: if key is not None:
...@@ -321,19 +381,23 @@ class WeakKeyDictionary(UserDict.UserDict): ...@@ -321,19 +381,23 @@ class WeakKeyDictionary(UserDict.UserDict):
keep the keys around longer than needed. keep the keys around longer than needed.
""" """
return self.data.iterkeys() with _IterationGuard(self):
for wr in self.data.iterkeys():
yield wr
def iterkeys(self): def iterkeys(self):
with _IterationGuard(self):
for wr in self.data.iterkeys(): for wr in self.data.iterkeys():
obj = wr() obj = wr()
if obj is not None: if obj is not None:
yield obj yield obj
def __iter__(self): __iter__ = iterkeys
return self.iterkeys()
def itervalues(self): def itervalues(self):
return self.data.itervalues() with _IterationGuard(self):
for value in self.data.itervalues():
yield value
def keyrefs(self): def keyrefs(self):
"""Return a list of weak references to the keys. """Return a list of weak references to the keys.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment