Issue #7105: weak dict iterators are fragile because of unpredictable GC runs

Backport the fix from pyton 3.x for this issue.
parent c289fa75
......@@ -163,7 +163,7 @@ than needed.
.. method:: WeakKeyDictionary.iterkeyrefs()
Return an :term:`iterator` that yields the weak references to the keys.
Return an iterable of the weak references to the keys.
.. versionadded:: 2.5
......@@ -195,7 +195,7 @@ methods of :class:`WeakKeyDictionary` objects.
.. method:: WeakValueDictionary.itervaluerefs()
Return an :term:`iterator` that yields the weak references to the values.
Return an iterable of the weak references to the values.
.. versionadded:: 2.5
......
......@@ -4,6 +4,8 @@ import unittest
import UserList
import weakref
import operator
import contextlib
import copy
from test import test_support
......@@ -903,7 +905,7 @@ class MappingTestCase(TestBase):
def check_len_cycles(self, dict_type, cons):
N = 20
items = [RefCycle() for i in range(N)]
dct = dict_type(cons(o) for o in items)
dct = dict_type(cons(i, o) for i, o in enumerate(items))
# Keep an iterator alive
it = dct.iteritems()
try:
......@@ -913,18 +915,23 @@ class MappingTestCase(TestBase):
del items
gc.collect()
n1 = len(dct)
list(it)
del it
gc.collect()
n2 = len(dct)
# one item may be kept alive inside the iterator
self.assertIn(n1, (0, 1))
# iteration should prevent garbage collection here
# Note that this is a test on an implementation detail. The requirement
# is only to provide stable iteration, not that the size of the container
# stay fixed.
self.assertEqual(n1, 20)
#self.assertIn(n1, (0, 1))
self.assertEqual(n2, 0)
def test_weak_keyed_len_cycles(self):
self.check_len_cycles(weakref.WeakKeyDictionary, lambda k: (k, 1))
self.check_len_cycles(weakref.WeakKeyDictionary, lambda n, k: (k, n))
def test_weak_valued_len_cycles(self):
self.check_len_cycles(weakref.WeakValueDictionary, lambda k: (1, k))
self.check_len_cycles(weakref.WeakValueDictionary, lambda n, k: (n, k))
def check_len_race(self, dict_type, cons):
# Extended sanity checks for len() in the face of cyclic collection
......@@ -1090,6 +1097,86 @@ class MappingTestCase(TestBase):
self.assertEqual(len(values), 0,
"itervalues() did not touch all values")
def check_weak_destroy_while_iterating(self, dict, objects, iter_name):
n = len(dict)
it = iter(getattr(dict, iter_name)())
next(it) # Trigger internal iteration
# Destroy an object
del objects[-1]
gc.collect() # just in case
# We have removed either the first consumed object, or another one
self.assertIn(len(list(it)), [len(objects), len(objects) - 1])
del it
# The removal has been committed
self.assertEqual(len(dict), n - 1)
def check_weak_destroy_and_mutate_while_iterating(self, dict, testcontext):
# Check that we can explicitly mutate the weak dict without
# interfering with delayed removal.
# `testcontext` should create an iterator, destroy one of the
# weakref'ed objects and then return a new key/value pair corresponding
# to the destroyed object.
with testcontext() as (k, v):
self.assertFalse(k in dict)
with testcontext() as (k, v):
self.assertRaises(KeyError, dict.__delitem__, k)
self.assertFalse(k in dict)
with testcontext() as (k, v):
self.assertRaises(KeyError, dict.pop, k)
self.assertFalse(k in dict)
with testcontext() as (k, v):
dict[k] = v
self.assertEqual(dict[k], v)
ddict = copy.copy(dict)
with testcontext() as (k, v):
dict.update(ddict)
self.assertEqual(dict, ddict)
with testcontext() as (k, v):
dict.clear()
self.assertEqual(len(dict), 0)
def test_weak_keys_destroy_while_iterating(self):
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
dict, objects = self.make_weak_keyed_dict()
self.check_weak_destroy_while_iterating(dict, objects, 'iterkeys')
self.check_weak_destroy_while_iterating(dict, objects, 'iteritems')
self.check_weak_destroy_while_iterating(dict, objects, 'itervalues')
self.check_weak_destroy_while_iterating(dict, objects, 'iterkeyrefs')
dict, objects = self.make_weak_keyed_dict()
@contextlib.contextmanager
def testcontext():
try:
it = iter(dict.iteritems())
next(it)
# Schedule a key/value for removal and recreate it
v = objects.pop().arg
gc.collect() # just in case
yield Object(v), v
finally:
it = None # should commit all removals
self.check_weak_destroy_and_mutate_while_iterating(dict, testcontext)
def test_weak_values_destroy_while_iterating(self):
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
dict, objects = self.make_weak_valued_dict()
self.check_weak_destroy_while_iterating(dict, objects, 'iterkeys')
self.check_weak_destroy_while_iterating(dict, objects, 'iteritems')
self.check_weak_destroy_while_iterating(dict, objects, 'itervalues')
self.check_weak_destroy_while_iterating(dict, objects, 'itervaluerefs')
dict, objects = self.make_weak_valued_dict()
@contextlib.contextmanager
def testcontext():
try:
it = iter(dict.iteritems())
next(it)
# Schedule a key/value for removal and recreate it
k = objects.pop().arg
gc.collect() # just in case
yield k, Object(k)
finally:
it = None # should commit all removals
self.check_weak_destroy_and_mutate_while_iterating(dict, testcontext)
def test_make_weak_keyed_dict_from_dict(self):
o = Object(3)
dict = weakref.WeakKeyDictionary({o:364})
......
......@@ -11,6 +11,7 @@ import warnings
import collections
import gc
import contextlib
from UserString import UserString as ustr
class Foo:
......@@ -448,6 +449,54 @@ class TestWeakSet(unittest.TestCase):
self.assertGreaterEqual(n2, 0)
self.assertLessEqual(n2, n1)
def test_weak_destroy_while_iterating(self):
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
# Create new items to be sure no-one else holds a reference
items = [ustr(c) for c in ('a', 'b', 'c')]
s = WeakSet(items)
it = iter(s)
next(it) # Trigger internal iteration
# Destroy an item
del items[-1]
gc.collect() # just in case
# We have removed either the first consumed items, or another one
self.assertIn(len(list(it)), [len(items), len(items) - 1])
del it
# The removal has been committed
self.assertEqual(len(s), len(items))
def test_weak_destroy_and_mutate_while_iterating(self):
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
items = [ustr(c) for c in string.ascii_letters]
s = WeakSet(items)
@contextlib.contextmanager
def testcontext():
try:
it = iter(s)
next(it)
# Schedule an item for removal and recreate it
u = ustr(str(items.pop()))
gc.collect() # just in case
yield u
finally:
it = None # should commit all removals
with testcontext() as u:
self.assertFalse(u in s)
with testcontext() as u:
self.assertRaises(KeyError, s.remove, u)
self.assertFalse(u in s)
with testcontext() as u:
s.add(u)
self.assertTrue(u in s)
t = s.copy()
with testcontext() as u:
s.update(t)
self.assertEqual(len(s), len(t))
with testcontext() as u:
s.clear()
self.assertEqual(len(s), 0)
def test_main(verbose=None):
test_support.run_unittest(TestWeakSet)
......
......@@ -20,7 +20,7 @@ from _weakref import (
ProxyType,
ReferenceType)
from _weakrefset import WeakSet
from _weakrefset import WeakSet, _IterationGuard
from exceptions import ReferenceError
......@@ -48,10 +48,24 @@ class WeakValueDictionary(UserDict.UserDict):
def remove(wr, selfref=ref(self)):
self = selfref()
if self is not None:
del self.data[wr.key]
if self._iterating:
self._pending_removals.append(wr.key)
else:
del self.data[wr.key]
self._remove = remove
# A list of keys to be removed
self._pending_removals = []
self._iterating = set()
UserDict.UserDict.__init__(self, *args, **kw)
def _commit_removals(self):
l = self._pending_removals
d = self.data
# We shouldn't encounter any KeyError, because this method should
# always be called *before* mutating the dict.
while l:
del d[l.pop()]
def __getitem__(self, key):
o = self.data[key]()
if o is None:
......@@ -59,6 +73,11 @@ class WeakValueDictionary(UserDict.UserDict):
else:
return o
def __delitem__(self, key):
if self._pending_removals:
self._commit_removals()
del self.data[key]
def __contains__(self, key):
try:
o = self.data[key]()
......@@ -77,8 +96,15 @@ class WeakValueDictionary(UserDict.UserDict):
return "<WeakValueDictionary at %s>" % id(self)
def __setitem__(self, key, value):
if self._pending_removals:
self._commit_removals()
self.data[key] = KeyedRef(value, self._remove, key)
def clear(self):
if self._pending_removals:
self._commit_removals()
self.data.clear()
def copy(self):
new = WeakValueDictionary()
for key, wr in self.data.items():
......@@ -120,16 +146,18 @@ class WeakValueDictionary(UserDict.UserDict):
return L
def iteritems(self):
for wr in self.data.itervalues():
value = wr()
if value is not None:
yield wr.key, value
with _IterationGuard(self):
for wr in self.data.itervalues():
value = wr()
if value is not None:
yield wr.key, value
def iterkeys(self):
return self.data.iterkeys()
with _IterationGuard(self):
for k in self.data.iterkeys():
yield k
def __iter__(self):
return self.data.iterkeys()
__iter__ = iterkeys
def itervaluerefs(self):
"""Return an iterator that yields the weak references to the values.
......@@ -141,15 +169,20 @@ class WeakValueDictionary(UserDict.UserDict):
keep the values around longer than needed.
"""
return self.data.itervalues()
with _IterationGuard(self):
for wr in self.data.itervalues():
yield wr
def itervalues(self):
for wr in self.data.itervalues():
obj = wr()
if obj is not None:
yield obj
with _IterationGuard(self):
for wr in self.data.itervalues():
obj = wr()
if obj is not None:
yield obj
def popitem(self):
if self._pending_removals:
self._commit_removals()
while 1:
key, wr = self.data.popitem()
o = wr()
......@@ -157,6 +190,8 @@ class WeakValueDictionary(UserDict.UserDict):
return key, o
def pop(self, key, *args):
if self._pending_removals:
self._commit_removals()
try:
o = self.data.pop(key)()
except KeyError:
......@@ -172,12 +207,16 @@ class WeakValueDictionary(UserDict.UserDict):
try:
wr = self.data[key]
except KeyError:
if self._pending_removals:
self._commit_removals()
self.data[key] = KeyedRef(default, self._remove, key)
return default
else:
return wr()
def update(self, dict=None, **kwargs):
if self._pending_removals:
self._commit_removals()
d = self.data
if dict is not None:
if not hasattr(dict, "items"):
......@@ -245,9 +284,29 @@ class WeakKeyDictionary(UserDict.UserDict):
def remove(k, selfref=ref(self)):
self = selfref()
if self is not None:
del self.data[k]
if self._iterating:
self._pending_removals.append(k)
else:
del self.data[k]
self._remove = remove
if dict is not None: self.update(dict)
# A list of dead weakrefs (keys to be removed)
self._pending_removals = []
self._iterating = set()
if dict is not None:
self.update(dict)
def _commit_removals(self):
# NOTE: We don't need to call this method before mutating the dict,
# because a dead weakref never compares equal to a live weakref,
# even if they happened to refer to equal objects.
# However, it means keys may already have been removed.
l = self._pending_removals
d = self.data
while l:
try:
del d[l.pop()]
except KeyError:
pass
def __delitem__(self, key):
del self.data[ref(key)]
......@@ -306,10 +365,11 @@ class WeakKeyDictionary(UserDict.UserDict):
return L
def iteritems(self):
for wr, value in self.data.iteritems():
key = wr()
if key is not None:
yield key, value
with _IterationGuard(self):
for wr, value in self.data.iteritems():
key = wr()
if key is not None:
yield key, value
def iterkeyrefs(self):
"""Return an iterator that yields the weak references to the keys.
......@@ -321,19 +381,23 @@ class WeakKeyDictionary(UserDict.UserDict):
keep the keys around longer than needed.
"""
return self.data.iterkeys()
with _IterationGuard(self):
for wr in self.data.iterkeys():
yield wr
def iterkeys(self):
for wr in self.data.iterkeys():
obj = wr()
if obj is not None:
yield obj
with _IterationGuard(self):
for wr in self.data.iterkeys():
obj = wr()
if obj is not None:
yield obj
def __iter__(self):
return self.iterkeys()
__iter__ = iterkeys
def itervalues(self):
return self.data.itervalues()
with _IterationGuard(self):
for value in self.data.itervalues():
yield value
def keyrefs(self):
"""Return a list of weak references to the keys.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment