Commit c7557589 authored by Guido van Rossum's avatar Guido van Rossum

Support all the new stuff supported by the new pickle code:

- subclasses of list or dict
- __reduce__ returning a 4-tuple or 5-tuple
- slots
parent 01892664
...@@ -7,7 +7,7 @@ Interface summary: ...@@ -7,7 +7,7 @@ Interface summary:
x = copy.copy(y) # make a shallow copy of y x = copy.copy(y) # make a shallow copy of y
x = copy.deepcopy(y) # make a deep copy of y x = copy.deepcopy(y) # make a deep copy of y
For module specific errors, copy.error is raised. For module specific errors, copy.Error is raised.
The difference between shallow and deep copying is only relevant for The difference between shallow and deep copying is only relevant for
compound objects (objects that contain other objects, like lists or compound objects (objects that contain other objects, like lists or
...@@ -51,6 +51,7 @@ __getstate__() and __setstate__(). See the documentation for module ...@@ -51,6 +51,7 @@ __getstate__() and __setstate__(). See the documentation for module
# XXX need to support copy_reg here too... # XXX need to support copy_reg here too...
import types import types
from pickle import _slotnames
class Error(Exception): class Error(Exception):
pass pass
...@@ -61,7 +62,7 @@ try: ...@@ -61,7 +62,7 @@ try:
except ImportError: except ImportError:
PyStringMap = None PyStringMap = None
__all__ = ["Error", "error", "copy", "deepcopy"] __all__ = ["Error", "copy", "deepcopy"]
def copy(x): def copy(x):
"""Shallow copy operation on arbitrary Python objects. """Shallow copy operation on arbitrary Python objects.
...@@ -76,18 +77,60 @@ def copy(x): ...@@ -76,18 +77,60 @@ def copy(x):
copier = x.__copy__ copier = x.__copy__
except AttributeError: except AttributeError:
try: try:
reductor = x.__reduce__ reductor = x.__class__.__reduce__
if reductor == object.__reduce__:
reductor = _better_reduce
except AttributeError: except AttributeError:
raise error, \ raise Error("un(shallow)copyable object of type %s" % type(x))
"un(shallow)copyable object of type %s" % type(x)
else: else:
y = _reconstruct(x, reductor(), 0) y = _reconstruct(x, reductor(x), 0)
else: else:
y = copier() y = copier()
else: else:
y = copierfunction(x) y = copierfunction(x)
return y return y
def __newobj__(cls, *args):
return cls.__new__(cls, *args)
def _better_reduce(obj):
cls = obj.__class__
getnewargs = getattr(obj, "__getnewargs__", None)
if getnewargs:
args = getnewargs()
else:
args = ()
getstate = getattr(obj, "__getstate__", None)
if getstate:
try:
state = getstate()
except TypeError, err:
# XXX Catch generic exception caused by __slots__
if str(err) != ("a class that defines __slots__ "
"without defining __getstate__ "
"cannot be pickled"):
raise # Not that specific exception
getstate = None
if not getstate:
state = getattr(obj, "__dict__", None)
names = _slotnames(cls)
if names:
slots = {}
nil = []
for name in names:
value = getattr(obj, name, nil)
if value is not nil:
slots[name] = value
if slots:
state = (state, slots)
listitems = dictitems = None
if isinstance(obj, list):
listitems = iter(obj)
elif isinstance(obj, dict):
dictitems = obj.iteritems()
return __newobj__, (cls, args), state, listitems, dictitems
_copy_dispatch = d = {} _copy_dispatch = d = {}
def _copy_atomic(x): def _copy_atomic(x):
...@@ -175,12 +218,14 @@ def deepcopy(x, memo = None): ...@@ -175,12 +218,14 @@ def deepcopy(x, memo = None):
copier = x.__deepcopy__ copier = x.__deepcopy__
except AttributeError: except AttributeError:
try: try:
reductor = x.__reduce__ reductor = x.__class__.__reduce__
if reductor == object.__reduce__:
reductor = _better_reduce
except AttributeError: except AttributeError:
raise error, \ raise Error("un(shallow)copyable object of type %s" %
"un-deep-copyable object of type %s" % type(x) type(x))
else: else:
y = _reconstruct(x, reductor(), 1, memo) y = _reconstruct(x, reductor(x), 1, memo)
else: else:
y = copier(memo) y = copier(memo)
else: else:
...@@ -331,7 +376,15 @@ def _reconstruct(x, info, deep, memo=None): ...@@ -331,7 +376,15 @@ def _reconstruct(x, info, deep, memo=None):
if hasattr(y, '__setstate__'): if hasattr(y, '__setstate__'):
y.__setstate__(state) y.__setstate__(state)
else: else:
y.__dict__.update(state) if isinstance(state, tuple) and len(state) == 2:
state, slotstate = state
else:
slotstate = None
if state is not None:
y.__dict__.update(state)
if slotstate is not None:
for key, value in slotstate.iteritems():
setattr(y, key, value)
return y return y
del d del d
......
...@@ -41,11 +41,13 @@ class TestCopy(unittest.TestCase): ...@@ -41,11 +41,13 @@ class TestCopy(unittest.TestCase):
self.assert_(y is x) self.assert_(y is x)
def test_copy_cant(self): def test_copy_cant(self):
class C(object): class Meta(type):
def __getattribute__(self, name): def __getattribute__(self, name):
if name == "__reduce__": if name == "__reduce__":
raise AttributeError, name raise AttributeError, name
return object.__getattribute__(self, name) return object.__getattribute__(self, name)
class C:
__metaclass__ = Meta
x = C() x = C()
self.assertRaises(copy.Error, copy.copy, x) self.assertRaises(copy.Error, copy.copy, x)
...@@ -189,11 +191,13 @@ class TestCopy(unittest.TestCase): ...@@ -189,11 +191,13 @@ class TestCopy(unittest.TestCase):
self.assert_(y is x) self.assert_(y is x)
def test_deepcopy_cant(self): def test_deepcopy_cant(self):
class C(object): class Meta(type):
def __getattribute__(self, name): def __getattribute__(self, name):
if name == "__reduce__": if name == "__reduce__":
raise AttributeError, name raise AttributeError, name
return object.__getattribute__(self, name) return object.__getattribute__(self, name)
class C:
__metaclass__ = Meta
x = C() x = C()
self.assertRaises(copy.Error, copy.deepcopy, x) self.assertRaises(copy.Error, copy.deepcopy, x)
...@@ -411,6 +415,45 @@ class TestCopy(unittest.TestCase): ...@@ -411,6 +415,45 @@ class TestCopy(unittest.TestCase):
self.assert_(x is not y) self.assert_(x is not y)
self.assert_(x["foo"] is not y["foo"]) self.assert_(x["foo"] is not y["foo"])
def test_copy_slots(self):
class C(object):
__slots__ = ["foo"]
x = C()
x.foo = [42]
y = copy.copy(x)
self.assert_(x.foo is y.foo)
def test_deepcopy_slots(self):
class C(object):
__slots__ = ["foo"]
x = C()
x.foo = [42]
y = copy.deepcopy(x)
self.assertEqual(x.foo, y.foo)
self.assert_(x.foo is not y.foo)
def test_copy_list_subclass(self):
class C(list):
pass
x = C([[1, 2], 3])
x.foo = [4, 5]
y = copy.copy(x)
self.assertEqual(list(x), list(y))
self.assertEqual(x.foo, y.foo)
self.assert_(x[0] is y[0])
self.assert_(x.foo is y.foo)
def test_deepcopy_list_subclass(self):
class C(list):
pass
x = C([[1, 2], 3])
x.foo = [4, 5]
y = copy.deepcopy(x)
self.assertEqual(list(x), list(y))
self.assertEqual(x.foo, y.foo)
self.assert_(x[0] is not y[0])
self.assert_(x.foo is not y.foo)
def test_main(): def test_main():
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(TestCopy)) suite.addTest(unittest.makeSuite(TestCopy))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment