Commit 5aac4e63 authored by Guido van Rossum's avatar Guido van Rossum

Move _better_reduce from copy.py to copy_reg.py, and also use it in

pickle.py, where it makes save_newobj() unnecessary.  Tests pass.
parent e7ee17c5
...@@ -51,7 +51,7 @@ __getstate__() and __setstate__(). See the documentation for module ...@@ -51,7 +51,7 @@ __getstate__() and __setstate__(). See the documentation for module
# XXX need to support copy_reg here too... # XXX need to support copy_reg here too...
import types import types
from pickle import _slotnames from copy_reg import _better_reduce
class Error(Exception): class Error(Exception):
pass pass
...@@ -89,46 +89,6 @@ def copy(x): ...@@ -89,46 +89,6 @@ def copy(x):
else: else:
y = copierfunction(x) y = copierfunction(x)
return y return y
def __newobj__(cls, *args):
return cls.__new__(cls, *args)
def _better_reduce(obj):
cls = obj.__class__
getnewargs = getattr(obj, "__getnewargs__", None)
if getnewargs:
args = getnewargs()
else:
args = ()
getstate = getattr(obj, "__getstate__", None)
if getstate:
try:
state = getstate()
except TypeError, err:
# XXX Catch generic exception caused by __slots__
if str(err) != ("a class that defines __slots__ "
"without defining __getstate__ "
"cannot be pickled"):
raise # Not that specific exception
getstate = None
if not getstate:
state = getattr(obj, "__dict__", None)
names = _slotnames(cls)
if names:
slots = {}
nil = []
for name in names:
value = getattr(obj, name, nil)
if value is not nil:
slots[name] = value
if slots:
state = (state, slots)
listitems = dictitems = None
if isinstance(obj, list):
listitems = iter(obj)
elif isinstance(obj, dict):
dictitems = obj.iteritems()
return __newobj__, (cls,) + args, state, listitems, dictitems
_copy_dispatch = d = {} _copy_dispatch = d = {}
......
...@@ -69,6 +69,84 @@ def _reduce(self): ...@@ -69,6 +69,84 @@ def _reduce(self):
else: else:
return _reconstructor, args return _reconstructor, args
# A better version of _reduce, used by copy and pickle protocol 2
def __newobj__(cls, *args):
return cls.__new__(cls, *args)
def _better_reduce(obj):
cls = obj.__class__
getnewargs = getattr(obj, "__getnewargs__", None)
if getnewargs:
args = getnewargs()
else:
args = ()
getstate = getattr(obj, "__getstate__", None)
if getstate:
try:
state = getstate()
except TypeError, err:
# XXX Catch generic exception caused by __slots__
if str(err) != ("a class that defines __slots__ "
"without defining __getstate__ "
"cannot be pickled"):
raise # Not that specific exception
getstate = None
if not getstate:
state = getattr(obj, "__dict__", None)
names = _slotnames(cls)
if names:
slots = {}
nil = []
for name in names:
value = getattr(obj, name, nil)
if value is not nil:
slots[name] = value
if slots:
state = (state, slots)
listitems = dictitems = None
if isinstance(obj, list):
listitems = iter(obj)
elif isinstance(obj, dict):
dictitems = obj.iteritems()
return __newobj__, (cls,) + args, state, listitems, dictitems
def _slotnames(cls):
"""Return a list of slot names for a given class.
This needs to find slots defined by the class and its bases, so we
can't simply return the __slots__ attribute. We must walk down
the Method Resolution Order and concatenate the __slots__ of each
class found there. (This assumes classes don't modify their
__slots__ attribute to misrepresent their slots after the class is
defined.)
"""
# Get the value from a cache in the class if possible
names = cls.__dict__.get("__slotnames__")
if names is not None:
return names
# Not cached -- calculate the value
names = []
if not hasattr(cls, "__slots__"):
# This class has no slots
pass
else:
# Slots found -- gather slot names from all base classes
for c in cls.__mro__:
if "__slots__" in c.__dict__:
names += [name for name in c.__dict__["__slots__"]
if name not in ("__dict__", "__weakref__")]
# Cache the outcome in the class if at all possible
try:
cls.__slotnames__ = names
except:
pass # But don't die if we can't
return names
# A registry of extension codes. This is an ad-hoc compression # A registry of extension codes. This is an ad-hoc compression
# mechanism. Whenever a global reference to <module>, <name> is about # mechanism. Whenever a global reference to <module>, <name> is about
# to be pickled, the (<module>, <name>) tuple is looked up here to see # to be pickled, the (<module>, <name>) tuple is looked up here to see
......
...@@ -27,7 +27,7 @@ Misc variables: ...@@ -27,7 +27,7 @@ Misc variables:
__version__ = "$Revision$" # Code version __version__ = "$Revision$" # Code version
from types import * from types import *
from copy_reg import dispatch_table, _reconstructor from copy_reg import dispatch_table, _reconstructor, _better_reduce
from copy_reg import _extension_registry, _inverted_registry, _extension_cache from copy_reg import _extension_registry, _inverted_registry, _extension_cache
import marshal import marshal
import sys import sys
...@@ -310,10 +310,7 @@ class Pickler: ...@@ -310,10 +310,7 @@ class Pickler:
if self.proto >= 2: if self.proto >= 2:
# Protocol 2 can do better than the default __reduce__ # Protocol 2 can do better than the default __reduce__
if reduce is object.__reduce__: if reduce is object.__reduce__:
reduce = None reduce = _better_reduce
if not reduce:
self.save_newobj(obj)
return
if not reduce: if not reduce:
raise PicklingError("Can't pickle %r object: %r" % raise PicklingError("Can't pickle %r object: %r" %
(t.__name__, obj)) (t.__name__, obj))
...@@ -433,86 +430,6 @@ class Pickler: ...@@ -433,86 +430,6 @@ class Pickler:
save(state) save(state)
write(BUILD) write(BUILD)
def save_newobj(self, obj):
# Save a new-style class instance, using protocol 2.
assert self.proto >= 2 # This only works for protocol 2
t = type(obj)
getnewargs = getattr(obj, "__getnewargs__", None)
if getnewargs:
args = getnewargs() # This better not reference obj
else:
args = ()
save = self.save
write = self.write
self.save(t)
save(args)
write(NEWOBJ)
self.memoize(obj)
if isinstance(obj, list):
self._batch_appends(iter(obj))
elif isinstance(obj, dict):
self._batch_setitems(obj.iteritems())
getstate = getattr(obj, "__getstate__", None)
if getstate:
# A class may define both __getstate__ and __getnewargs__.
# If they are the same function, we ignore __getstate__.
# This is for the benefit of protocols 0 and 1, which don't
# use __getnewargs__. Note that the only way to make them
# the same function is something like this:
#
# class C(object):
# def __getstate__(self):
# return ...
# __getnewargs__ = __getstate__
#
# No tricks are needed to ignore __setstate__; it simply
# won't be called when we don't generate BUILD.
# Also note that when __getnewargs__ and __getstate__ are
# the same function, we don't do the default thing of
# looking for __dict__ and slots either -- it is assumed
# that __getnewargs__ returns all the state there is
# (which should be a safe assumption since __getstate__
# returns the *same* state).
if getstate == getnewargs:
return
try:
state = getstate()
except TypeError, err:
# XXX Catch generic exception caused by __slots__
if str(err) != ("a class that defines __slots__ "
"without defining __getstate__ "
"cannot be pickled"):
raise # Not that specific exception
getstate = None
if not getstate:
state = getattr(obj, "__dict__", None)
if not state:
state = None
# If there are slots, the state becomes a tuple of two
# items: the first item the regular __dict__ or None, and
# the second a dict mapping slot names to slot values
names = _slotnames(t)
if names:
slots = {}
nil = []
for name in names:
value = getattr(obj, name, nil)
if value is not nil:
slots[name] = value
if slots:
state = (state, slots)
if state is not None:
save(state)
write(BUILD)
# Methods below this point are dispatched through the dispatch table # Methods below this point are dispatched through the dispatch table
dispatch = {} dispatch = {}
...@@ -713,7 +630,8 @@ class Pickler: ...@@ -713,7 +630,8 @@ class Pickler:
tmp = [] tmp = []
for i in r: for i in r:
try: try:
tmp.append(items.next()) x = items.next()
tmp.append(x)
except StopIteration: except StopIteration:
items = None items = None
break break
...@@ -865,42 +783,6 @@ class Pickler: ...@@ -865,42 +783,6 @@ class Pickler:
# Pickling helpers # Pickling helpers
def _slotnames(cls):
"""Return a list of slot names for a given class.
This needs to find slots defined by the class and its bases, so we
can't simply return the __slots__ attribute. We must walk down
the Method Resolution Order and concatenate the __slots__ of each
class found there. (This assumes classes don't modify their
__slots__ attribute to misrepresent their slots after the class is
defined.)
"""
# Get the value from a cache in the class if possible
names = cls.__dict__.get("__slotnames__")
if names is not None:
return names
# Not cached -- calculate the value
names = []
if not hasattr(cls, "__slots__"):
# This class has no slots
pass
else:
# Slots found -- gather slot names from all base classes
for c in cls.__mro__:
if "__slots__" in c.__dict__:
names += [name for name in c.__dict__["__slots__"]
if name not in ("__dict__", "__weakref__")]
# Cache the outcome in the class if at all possible
try:
cls.__slotnames__ = names
except:
pass # But don't die if we can't
return names
def _keep_alive(x, memo): def _keep_alive(x, memo):
"""Keeps a reference to the object x in the memo. """Keeps a reference to the object x in the memo.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment