Commit c1baa601 authored by Antoine Pitrou's avatar Antoine Pitrou

Issue #7105: Make WeakKeyDictionary and WeakValueDictionary robust against

the destruction of weakref'ed objects while iterating.
parent dc2a6134
......@@ -159,7 +159,7 @@ than needed.
.. method:: WeakKeyDictionary.keyrefs()
Return an :term:`iterator` that yields the weak references to the keys.
Return an iterable of the weak references to the keys.
.. class:: WeakValueDictionary([dict])
......@@ -182,7 +182,7 @@ These method have the same issues as the and :meth:`keyrefs` method of
.. method:: WeakValueDictionary.valuerefs()
Return an :term:`iterator` that yields the weak references to the values.
Return an iterable of the weak references to the values.
.. class:: WeakSet([elements])
......
......@@ -6,22 +6,61 @@ from _weakref import ref
__all__ = ['WeakSet']
class _IterationGuard:
# This context manager registers itself in the current iterators of the
# weak container, such as to delay all removals until the context manager
# exits.
# This technique should be relatively thread-safe (since sets are).
def __init__(self, weakcontainer):
# Don't create cycles
self.weakcontainer = ref(weakcontainer)
def __enter__(self):
w = self.weakcontainer()
if w is not None:
w._iterating.add(self)
return self
def __exit__(self, e, t, b):
w = self.weakcontainer()
if w is not None:
s = w._iterating
s.remove(self)
if not s:
w._commit_removals()
class WeakSet:
def __init__(self, data=None):
self.data = set()
def _remove(item, selfref=ref(self)):
self = selfref()
if self is not None:
self.data.discard(item)
if self._iterating:
self._pending_removals.append(item)
else:
self.data.discard(item)
self._remove = _remove
# A list of keys to be removed
self._pending_removals = []
self._iterating = set()
if data is not None:
self.update(data)
def _commit_removals(self):
l = self._pending_removals
discard = self.data.discard
while l:
discard(l.pop())
def __iter__(self):
for itemref in self.data:
item = itemref()
if item is not None:
yield item
with _IterationGuard(self):
for itemref in self.data:
item = itemref()
if item is not None:
yield item
def __len__(self):
return sum(x() is not None for x in self.data)
......@@ -34,15 +73,21 @@ class WeakSet:
getattr(self, '__dict__', None))
def add(self, item):
if self._pending_removals:
self._commit_removals()
self.data.add(ref(item, self._remove))
def clear(self):
if self._pending_removals:
self._commit_removals()
self.data.clear()
def copy(self):
return self.__class__(self)
def pop(self):
if self._pending_removals:
self._commit_removals()
while True:
try:
itemref = self.data.pop()
......@@ -53,17 +98,24 @@ class WeakSet:
return item
def remove(self, item):
if self._pending_removals:
self._commit_removals()
self.data.remove(ref(item))
def discard(self, item):
if self._pending_removals:
self._commit_removals()
self.data.discard(ref(item))
def update(self, other):
if self._pending_removals:
self._commit_removals()
if isinstance(other, self.__class__):
self.data.update(other.data)
else:
for element in other:
self.add(element)
def __ior__(self, other):
self.update(other)
return self
......@@ -82,11 +134,15 @@ class WeakSet:
__sub__ = difference
def difference_update(self, other):
if self._pending_removals:
self._commit_removals()
if self is other:
self.data.clear()
else:
self.data.difference_update(ref(item) for item in other)
def __isub__(self, other):
if self._pending_removals:
self._commit_removals()
if self is other:
self.data.clear()
else:
......@@ -98,8 +154,12 @@ class WeakSet:
__and__ = intersection
def intersection_update(self, other):
if self._pending_removals:
self._commit_removals()
self.data.intersection_update(ref(item) for item in other)
def __iand__(self, other):
if self._pending_removals:
self._commit_removals()
self.data.intersection_update(ref(item) for item in other)
return self
......@@ -127,11 +187,15 @@ class WeakSet:
__xor__ = symmetric_difference
def symmetric_difference_update(self, other):
if self._pending_removals:
self._commit_removals()
if self is other:
self.data.clear()
else:
self.data.symmetric_difference_update(ref(item) for item in other)
def __ixor__(self, other):
if self._pending_removals:
self._commit_removals()
if self is other:
self.data.clear()
else:
......
......@@ -4,6 +4,8 @@ import unittest
import collections
import weakref
import operator
import contextlib
import copy
from test import support
......@@ -788,6 +790,10 @@ class Object:
self.arg = arg
def __repr__(self):
return "<Object %r>" % self.arg
def __eq__(self, other):
if isinstance(other, Object):
return self.arg == other.arg
return NotImplemented
def __lt__(self, other):
if isinstance(other, Object):
return self.arg < other.arg
......@@ -935,6 +941,87 @@ class MappingTestCase(TestBase):
self.assertFalse(values,
"itervalues() did not touch all values")
def check_weak_destroy_while_iterating(self, dict, objects, iter_name):
n = len(dict)
it = iter(getattr(dict, iter_name)())
next(it) # Trigger internal iteration
# Destroy an object
del objects[-1]
gc.collect() # just in case
# We have removed either the first consumed object, or another one
self.assertIn(len(list(it)), [len(objects), len(objects) - 1])
del it
# The removal has been committed
self.assertEqual(len(dict), n - 1)
def check_weak_destroy_and_mutate_while_iterating(self, dict, testcontext):
# Check that we can explicitly mutate the weak dict without
# interfering with delayed removal.
# `testcontext` should create an iterator, destroy one of the
# weakref'ed objects and then return a new key/value pair corresponding
# to the destroyed object.
with testcontext() as (k, v):
self.assertFalse(k in dict)
with testcontext() as (k, v):
self.assertRaises(KeyError, dict.__delitem__, k)
self.assertFalse(k in dict)
with testcontext() as (k, v):
self.assertRaises(KeyError, dict.pop, k)
self.assertFalse(k in dict)
with testcontext() as (k, v):
dict[k] = v
self.assertEqual(dict[k], v)
ddict = copy.copy(dict)
with testcontext() as (k, v):
dict.update(ddict)
self.assertEqual(dict, ddict)
with testcontext() as (k, v):
dict.clear()
self.assertEqual(len(dict), 0)
def test_weak_keys_destroy_while_iterating(self):
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
dict, objects = self.make_weak_keyed_dict()
self.check_weak_destroy_while_iterating(dict, objects, 'keys')
self.check_weak_destroy_while_iterating(dict, objects, 'items')
self.check_weak_destroy_while_iterating(dict, objects, 'values')
self.check_weak_destroy_while_iterating(dict, objects, 'keyrefs')
dict, objects = self.make_weak_keyed_dict()
@contextlib.contextmanager
def testcontext():
try:
it = iter(dict.items())
next(it)
# Schedule a key/value for removal and recreate it
v = objects.pop().arg
gc.collect() # just in case
yield Object(v), v
finally:
it = None # should commit all removals
self.check_weak_destroy_and_mutate_while_iterating(dict, testcontext)
def test_weak_values_destroy_while_iterating(self):
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
dict, objects = self.make_weak_valued_dict()
self.check_weak_destroy_while_iterating(dict, objects, 'keys')
self.check_weak_destroy_while_iterating(dict, objects, 'items')
self.check_weak_destroy_while_iterating(dict, objects, 'values')
self.check_weak_destroy_while_iterating(dict, objects, 'itervaluerefs')
self.check_weak_destroy_while_iterating(dict, objects, 'valuerefs')
dict, objects = self.make_weak_valued_dict()
@contextlib.contextmanager
def testcontext():
try:
it = iter(dict.items())
next(it)
# Schedule a key/value for removal and recreate it
k = objects.pop().arg
gc.collect() # just in case
yield k, Object(k)
finally:
it = None # should commit all removals
self.check_weak_destroy_and_mutate_while_iterating(dict, testcontext)
def test_make_weak_keyed_dict_from_dict(self):
o = Object(3)
dict = weakref.WeakKeyDictionary({o:364})
......
......@@ -10,6 +10,8 @@ import sys
import warnings
import collections
from collections import UserString as ustr
import gc
import contextlib
class Foo:
......@@ -307,6 +309,54 @@ class TestWeakSet(unittest.TestCase):
self.assertFalse(self.s == WeakSet([Foo]))
self.assertFalse(self.s == 1)
def test_weak_destroy_while_iterating(self):
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
# Create new items to be sure no-one else holds a reference
items = [ustr(c) for c in ('a', 'b', 'c')]
s = WeakSet(items)
it = iter(s)
next(it) # Trigger internal iteration
# Destroy an item
del items[-1]
gc.collect() # just in case
# We have removed either the first consumed items, or another one
self.assertIn(len(list(it)), [len(items), len(items) - 1])
del it
# The removal has been committed
self.assertEqual(len(s), len(items))
def test_weak_destroy_and_mutate_while_iterating(self):
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
items = [ustr(c) for c in string.ascii_letters]
s = WeakSet(items)
@contextlib.contextmanager
def testcontext():
try:
it = iter(s)
next(it)
# Schedule an item for removal and recreate it
u = ustr(str(items.pop()))
gc.collect() # just in case
yield u
finally:
it = None # should commit all removals
with testcontext() as u:
self.assertFalse(u in s)
with testcontext() as u:
self.assertRaises(KeyError, s.remove, u)
self.assertFalse(u in s)
with testcontext() as u:
s.add(u)
self.assertTrue(u in s)
t = s.copy()
with testcontext() as u:
s.update(t)
self.assertEqual(len(s), len(t))
with testcontext() as u:
s.clear()
self.assertEqual(len(s), 0)
def test_main(verbose=None):
support.run_unittest(TestWeakSet)
......
......@@ -18,7 +18,7 @@ from _weakref import (
ProxyType,
ReferenceType)
from _weakrefset import WeakSet
from _weakrefset import WeakSet, _IterationGuard
import collections # Import after _weakref to avoid circular import.
......@@ -46,11 +46,25 @@ class WeakValueDictionary(collections.MutableMapping):
def remove(wr, selfref=ref(self)):
self = selfref()
if self is not None:
del self.data[wr.key]
if self._iterating:
self._pending_removals.append(wr.key)
else:
del self.data[wr.key]
self._remove = remove
# A list of keys to be removed
self._pending_removals = []
self._iterating = set()
self.data = d = {}
self.update(*args, **kw)
def _commit_removals(self):
l = self._pending_removals
d = self.data
# We shouldn't encounter any KeyError, because this method should
# always be called *before* mutating the dict.
while l:
del d[l.pop()]
def __getitem__(self, key):
o = self.data[key]()
if o is None:
......@@ -59,6 +73,8 @@ class WeakValueDictionary(collections.MutableMapping):
return o
def __delitem__(self, key):
if self._pending_removals:
self._commit_removals()
del self.data[key]
def __len__(self):
......@@ -75,6 +91,8 @@ class WeakValueDictionary(collections.MutableMapping):
return "<WeakValueDictionary at %s>" % id(self)
def __setitem__(self, key, value):
if self._pending_removals:
self._commit_removals()
self.data[key] = KeyedRef(value, self._remove, key)
def copy(self):
......@@ -110,24 +128,19 @@ class WeakValueDictionary(collections.MutableMapping):
return o
def items(self):
L = []
for key, wr in self.data.items():
o = wr()
if o is not None:
L.append((key, o))
return L
def items(self):
for wr in self.data.values():
value = wr()
if value is not None:
yield wr.key, value
with _IterationGuard(self):
for k, wr in self.data.items():
v = wr()
if v is not None:
yield k, v
def keys(self):
return iter(self.data.keys())
with _IterationGuard(self):
for k, wr in self.data.items():
if wr() is not None:
yield k
def __iter__(self):
return iter(self.data.keys())
__iter__ = keys
def itervaluerefs(self):
"""Return an iterator that yields the weak references to the values.
......@@ -139,15 +152,20 @@ class WeakValueDictionary(collections.MutableMapping):
keep the values around longer than needed.
"""
return self.data.values()
with _IterationGuard(self):
for wr in self.data.values():
yield wr
def values(self):
for wr in self.data.values():
obj = wr()
if obj is not None:
yield obj
with _IterationGuard(self):
for wr in self.data.values():
obj = wr()
if obj is not None:
yield obj
def popitem(self):
if self._pending_removals:
self._commit_removals()
while 1:
key, wr = self.data.popitem()
o = wr()
......@@ -155,6 +173,8 @@ class WeakValueDictionary(collections.MutableMapping):
return key, o
def pop(self, key, *args):
if self._pending_removals:
self._commit_removals()
try:
o = self.data.pop(key)()
except KeyError:
......@@ -170,12 +190,16 @@ class WeakValueDictionary(collections.MutableMapping):
try:
wr = self.data[key]
except KeyError:
if self._pending_removals:
self._commit_removals()
self.data[key] = KeyedRef(default, self._remove, key)
return default
else:
return wr()
def update(self, dict=None, **kwargs):
if self._pending_removals:
self._commit_removals()
d = self.data
if dict is not None:
if not hasattr(dict, "items"):
......@@ -195,7 +219,7 @@ class WeakValueDictionary(collections.MutableMapping):
keep the values around longer than needed.
"""
return self.data.values()
return list(self.data.values())
class KeyedRef(ref):
......@@ -235,9 +259,29 @@ class WeakKeyDictionary(collections.MutableMapping):
def remove(k, selfref=ref(self)):
self = selfref()
if self is not None:
del self.data[k]
if self._iterating:
self._pending_removals.append(k)
else:
del self.data[k]
self._remove = remove
if dict is not None: self.update(dict)
# A list of dead weakrefs (keys to be removed)
self._pending_removals = []
self._iterating = set()
if dict is not None:
self.update(dict)
def _commit_removals(self):
# NOTE: We don't need to call this method before mutating the dict,
# because a dead weakref never compares equal to a live weakref,
# even if they happened to refer to equal objects.
# However, it means keys may already have been removed.
l = self._pending_removals
d = self.data
while l:
try:
del d[l.pop()]
except KeyError:
pass
def __delitem__(self, key):
del self.data[ref(key)]
......@@ -284,34 +328,26 @@ class WeakKeyDictionary(collections.MutableMapping):
return wr in self.data
def items(self):
for wr, value in self.data.items():
key = wr()
if key is not None:
yield key, value
def keyrefs(self):
"""Return an iterator that yields the weak references to the keys.
The references are not guaranteed to be 'live' at the time
they are used, so the result of calling the references needs
to be checked before being used. This can be used to avoid
creating references that will cause the garbage collector to
keep the keys around longer than needed.
"""
return self.data.keys()
with _IterationGuard(self):
for wr, value in self.data.items():
key = wr()
if key is not None:
yield key, value
def keys(self):
for wr in self.data.keys():
obj = wr()
if obj is not None:
yield obj
with _IterationGuard(self):
for wr in self.data:
obj = wr()
if obj is not None:
yield obj
def __iter__(self):
return iter(self.keys())
__iter__ = keys
def values(self):
return iter(self.data.values())
with _IterationGuard(self):
for wr, value in self.data.items():
if wr() is not None:
yield value
def keyrefs(self):
"""Return a list of weak references to the keys.
......@@ -323,7 +359,7 @@ class WeakKeyDictionary(collections.MutableMapping):
keep the keys around longer than needed.
"""
return self.data.keys()
return list(self.data)
def popitem(self):
while 1:
......
......@@ -194,6 +194,9 @@ C-API
Library
-------
- Issue #7105: Make WeakKeyDictionary and WeakValueDictionary robust against
the destruction of weakref'ed objects while iterating.
- Issue #7455: Fix possible crash in cPickle on invalid input. Patch by
Victor Stinner.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment