Commit 28dc1184 authored by Kevin Modzelewski

Fix WeakValueDict/WeakKeyDict for a gc

The CPython implementation relies on immediate destruction.
parent f118d43a
......@@ -49,8 +49,17 @@ class WeakValueDictionary(UserDict.UserDict):
self = selfref()
if self is not None:
if self._iterating:
self._pending_removals.append(wr.key)
self._pending_removals.append(wr)
else:
# Pyston change: adopted this pypy fix:
#
# Changed this for PyPy: made more resistant. The
# issue is that in some corner cases, self.data
# might already be changed or removed by the time
# this weakref's callback is called. If that is
# the case, we don't want to randomly kill an
# unrelated entry.
if self.data.get(wr.key) is wr:
del self.data[wr.key]
self._remove = remove
# A list of keys to be removed
......@@ -64,7 +73,9 @@ class WeakValueDictionary(UserDict.UserDict):
# We shouldn't encounter any KeyError, because this method should
# always be called *before* mutating the dict.
while l:
del d[l.pop()]
wr = l.pop()
if d.get(wr.key) is wr:
del d[wr.key]
def __getitem__(self, key):
o = self.data[key]()
......@@ -280,14 +291,32 @@ class WeakKeyDictionary(UserDict.UserDict):
"""
def __init__(self, dict=None):
# Pyston change:
# This implementation of WeakKeyDictionary originally relied on quick destruction
# of the weakref key objects and the immediate calling of their callback. With a gc,
# there can be multiple key removals before a collection happens, at which point we
# call remove() with keys that are not the most recent version.
#
# The approach here is to check the key in the dict to make sure it is still the same.
# This is a little bit complicated since 1) if the weakref.ref's referent gets freed,
# the ref object is no longer usable as a hash key, and 2) setting a value in a dict
# when the key already exists will not update the key.
#
# So in __setitem__, remove the existing key and replace it with the new one.
# Since there's no way to query for the current key inside a dict, given a lookup key,
# we keep a separate "refs" dict to look it up.
self.data = {}
self.refs = {}
def remove(k, selfref=ref(self)):
    # Weakref callback: `k` is the weakref key whose referent has died.
    # The dictionary itself is only reachable through `selfref`, so the
    # callback does not keep it alive.
    owner = selfref()
    if owner is None:
        # The dictionary is already gone; nothing to clean up.
        return
    assert len(owner.data) == len(owner.refs)
    if owner._iterating:
        # Defer the removal while the dictionary is being iterated.
        owner._pending_removals.append(k)
        return
    # Only drop the entry if `k` is still the key object currently stored;
    # an equal key may have replaced it before this callback ran.
    if owner.refs.get(k) is k:
        del owner.data[k]
        del owner.refs[k]
self._remove = remove
# A list of dead weakrefs (keys to be removed)
self._pending_removals = []
......@@ -302,14 +331,20 @@ class WeakKeyDictionary(UserDict.UserDict):
# However, it means keys may already have been removed.
l = self._pending_removals
d = self.data
r = self.refs
while l:
try:
del d[l.pop()]
k = l.pop()
if self.refs.get(k) is k:
del d[k]
del r[k]
except KeyError:
pass
def __delitem__(self, key):
    """Remove the entry for *key* from both ``data`` and ``refs``.

    Raises KeyError (from the ``data`` deletion) when *key* is absent; in
    that case ``refs`` is left untouched.
    """
    # Build the lookup weakref once and delete exactly once per table, so
    # the two tables stay the same size.
    r = ref(key)
    del self.data[r]
    del self.refs[r]
def __getitem__(self, key):
    """Return the value stored under *key*.

    The lookup goes through a fresh weak reference to *key*; a missing
    entry surfaces as the KeyError raised by the underlying ``data`` dict.
    """
    lookup = ref(key)
    return self.data[lookup]
......@@ -318,7 +353,11 @@ class WeakKeyDictionary(UserDict.UserDict):
return "<WeakKeyDictionary at %s>" % id(self)
def __setitem__(self, key, value):
    """Store *value* under a weak reference to *key*.

    Assigning to an existing dict key keeps the old key object, so any
    existing entry is removed first. That way the stored key (and the
    matching ``refs`` entry) is always the newest weakref, which is the one
    whose callback should be honoured.
    """
    r = ref(key, self._remove)
    # Drop a previous entry for an equal key, if any, in both tables.
    self.data.pop(r, None)
    self.refs.pop(r, None)
    # Insert the new weakref as the key and record it as the current one.
    self.data[r] = value
    self.refs[r] = r
def copy(self):
new = WeakKeyDictionary()
......@@ -421,23 +460,28 @@ class WeakKeyDictionary(UserDict.UserDict):
def popitem(self):
    """Remove and return one ``(key, value)`` pair whose key is still alive.

    Entries whose weakref has already died are discarded and skipped.
    Raises KeyError (from ``refs.popitem()``) once no entries remain.
    """
    while True:
        # Pop from `refs` first so both tables shrink together; the stored
        # value in `refs` is the weakref object used as the key in `data`.
        _, key = self.refs.popitem()
        value = self.data.pop(key)
        o = key()
        if o is not None:
            return o, value
def pop(self, key, *args):
    """Remove *key* and return its value.

    An optional single default may be passed in *args*; it is returned when
    *key* is absent. Without a default, a missing key raises KeyError (from
    ``data.pop``).
    """
    r = ref(key)
    # Keep the bookkeeping table in step with `data`. This must target the
    # `refs` dict: `self.keys` is a method and has no `pop`.
    self.refs.pop(r, None)
    return self.data.pop(r, *args)
def setdefault(self, key, default=None):
    """Return the value for *key*, inserting *default* first if it is absent.

    Both the membership test and the store go through this object's own
    item protocol rather than touching ``data`` directly, so whatever
    bookkeeping ``__setitem__`` does is applied to the new entry as well.
    """
    if key in self:
        return self[key]
    self[key] = default
    return default
def update(self, dict=None, **kwargs):
    """Insert every pair from *dict* and from the keyword arguments.

    *dict* may be a mapping (anything with ``items``) or an iterable of
    ``(key, value)`` pairs. Each pair is stored with ``self[key] = value``
    so that all per-item bookkeeping in ``__setitem__`` is applied.
    """
    if dict is not None:
        if not hasattr(dict, "items"):
            # The parameter shadows the builtin, hence ``type({})``.
            dict = type({})(dict)
        for key, value in dict.items():
            self[key] = value
    if kwargs:
        self.update(kwargs)
from weakref import WeakKeyDictionary, WeakValueDictionary
class S(object):
    """Key/value helper whose hash and equality depend only on ``n``.

    Distinct instances with the same ``n`` are interchangeable as dict
    keys, which is what lets a test replace a key with an equal one.
    """

    def __init__(self, n):
        self.n = n

    def __hash__(self):
        return hash(self.n)

    def __eq__(self, rhs):
        # Operands without an ``n`` attribute are "not comparable" rather
        # than an AttributeError escaping from a dict lookup.
        try:
            other_n = rhs.n
        except AttributeError:
            return NotImplemented
        return self.n == other_n

    def __ne__(self, rhs):
        # Python 2 does not derive != from ==, so define it explicitly.
        result = self.__eq__(rhs)
        if result is NotImplemented:
            return result
        return not result
# Exercise the weak dictionary `d` by repeatedly storing fresh S(0) keys and
# values, forcing a garbage collection, and printing what is left.
# NOTE(review): indentation was lost in this listing, so the exact loop
# nesting of the statements below is not visible here -- confirm against the
# original test file before relying on it.
def test(d):
# Announce which dictionary class is under test.
print "Testing on", d.__class__
# Bind k and v up front; the final print reads k.
k, v = None, None
for i in xrange(10):
print i
for j in xrange(100):
# Every S(0) hashes and compares equal to the previous one, so each store
# addresses the same logical entry with brand-new objects.
k, v = S(0), S(0)
d[k] = v
# Force a collection before inspecting the dictionary.
import gc
gc.collect()
# Report the number of keys and whether the most recent key is present.
print len(d.keys()), k in d
# Run the same scenario against both weak dictionary flavours, key-based first.
for make_dict in (WeakKeyDictionary, WeakValueDictionary):
    test(make_dict())
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment