Commit c7557589 authored by Guido van Rossum's avatar Guido van Rossum

Support all the new stuff supported by the new pickle code:

- subclasses of list or dict
- __reduce__ returning a 4-tuple or 5-tuple
- slots
parent 01892664
......@@ -7,7 +7,7 @@ Interface summary:
x = copy.copy(y) # make a shallow copy of y
x = copy.deepcopy(y) # make a deep copy of y
For module specific errors, copy.error is raised.
For module specific errors, copy.Error is raised.
The difference between shallow and deep copying is only relevant for
compound objects (objects that contain other objects, like lists or
......@@ -51,6 +51,7 @@ __getstate__() and __setstate__(). See the documentation for module
# XXX need to support copy_reg here too...
import types
from pickle import _slotnames
class Error(Exception):
pass
......@@ -61,7 +62,7 @@ try:
except ImportError:
PyStringMap = None
__all__ = ["Error", "error", "copy", "deepcopy"]
__all__ = ["Error", "copy", "deepcopy"]
def copy(x):
"""Shallow copy operation on arbitrary Python objects.
......@@ -76,18 +77,60 @@ def copy(x):
copier = x.__copy__
except AttributeError:
try:
reductor = x.__reduce__
reductor = x.__class__.__reduce__
if reductor == object.__reduce__:
reductor = _better_reduce
except AttributeError:
raise error, \
"un(shallow)copyable object of type %s" % type(x)
raise Error("un(shallow)copyable object of type %s" % type(x))
else:
y = _reconstruct(x, reductor(), 0)
y = _reconstruct(x, reductor(x), 0)
else:
y = copier()
else:
y = copierfunction(x)
return y
def __newobj__(cls, *args):
return cls.__new__(cls, *args)
def _better_reduce(obj):
cls = obj.__class__
getnewargs = getattr(obj, "__getnewargs__", None)
if getnewargs:
args = getnewargs()
else:
args = ()
getstate = getattr(obj, "__getstate__", None)
if getstate:
try:
state = getstate()
except TypeError, err:
# XXX Catch generic exception caused by __slots__
if str(err) != ("a class that defines __slots__ "
"without defining __getstate__ "
"cannot be pickled"):
raise # Not that specific exception
getstate = None
if not getstate:
state = getattr(obj, "__dict__", None)
names = _slotnames(cls)
if names:
slots = {}
nil = []
for name in names:
value = getattr(obj, name, nil)
if value is not nil:
slots[name] = value
if slots:
state = (state, slots)
listitems = dictitems = None
if isinstance(obj, list):
listitems = iter(obj)
elif isinstance(obj, dict):
dictitems = obj.iteritems()
return __newobj__, (cls, args), state, listitems, dictitems
_copy_dispatch = d = {}
def _copy_atomic(x):
......@@ -175,12 +218,14 @@ def deepcopy(x, memo = None):
copier = x.__deepcopy__
except AttributeError:
try:
reductor = x.__reduce__
reductor = x.__class__.__reduce__
if reductor == object.__reduce__:
reductor = _better_reduce
except AttributeError:
raise error, \
"un-deep-copyable object of type %s" % type(x)
raise Error("un(shallow)copyable object of type %s" %
type(x))
else:
y = _reconstruct(x, reductor(), 1, memo)
y = _reconstruct(x, reductor(x), 1, memo)
else:
y = copier(memo)
else:
......@@ -331,7 +376,15 @@ def _reconstruct(x, info, deep, memo=None):
if hasattr(y, '__setstate__'):
y.__setstate__(state)
else:
if isinstance(state, tuple) and len(state) == 2:
state, slotstate = state
else:
slotstate = None
if state is not None:
y.__dict__.update(state)
if slotstate is not None:
for key, value in slotstate.iteritems():
setattr(y, key, value)
return y
del d
......
......@@ -41,11 +41,13 @@ class TestCopy(unittest.TestCase):
self.assert_(y is x)
def test_copy_cant(self):
class C(object):
class Meta(type):
def __getattribute__(self, name):
if name == "__reduce__":
raise AttributeError, name
return object.__getattribute__(self, name)
class C:
__metaclass__ = Meta
x = C()
self.assertRaises(copy.Error, copy.copy, x)
......@@ -189,11 +191,13 @@ class TestCopy(unittest.TestCase):
self.assert_(y is x)
def test_deepcopy_cant(self):
class C(object):
class Meta(type):
def __getattribute__(self, name):
if name == "__reduce__":
raise AttributeError, name
return object.__getattribute__(self, name)
class C:
__metaclass__ = Meta
x = C()
self.assertRaises(copy.Error, copy.deepcopy, x)
......@@ -411,6 +415,45 @@ class TestCopy(unittest.TestCase):
self.assert_(x is not y)
self.assert_(x["foo"] is not y["foo"])
def test_copy_slots(self):
class C(object):
__slots__ = ["foo"]
x = C()
x.foo = [42]
y = copy.copy(x)
self.assert_(x.foo is y.foo)
def test_deepcopy_slots(self):
class C(object):
__slots__ = ["foo"]
x = C()
x.foo = [42]
y = copy.deepcopy(x)
self.assertEqual(x.foo, y.foo)
self.assert_(x.foo is not y.foo)
def test_copy_list_subclass(self):
class C(list):
pass
x = C([[1, 2], 3])
x.foo = [4, 5]
y = copy.copy(x)
self.assertEqual(list(x), list(y))
self.assertEqual(x.foo, y.foo)
self.assert_(x[0] is y[0])
self.assert_(x.foo is y.foo)
def test_deepcopy_list_subclass(self):
class C(list):
pass
x = C([[1, 2], 3])
x.foo = [4, 5]
y = copy.deepcopy(x)
self.assertEqual(list(x), list(y))
self.assertEqual(x.foo, y.foo)
self.assert_(x[0] is not y[0])
self.assert_(x.foo is not y.foo)
def test_main():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(TestCopy))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment