Commit 4536ca7e authored by Guido van Rossum's avatar Guido van Rossum

The default __reduce__ on the base object type obscured any

possibility of calling save_reduce().  Add a special hack for this.
The tests for this are much simpler now (no __getstate__ or
__getnewargs__ needed).
parent 9b4365ef
...@@ -27,7 +27,7 @@ Misc variables: ...@@ -27,7 +27,7 @@ Misc variables:
__version__ = "$Revision$" # Code version __version__ = "$Revision$" # Code version
from types import * from types import *
from copy_reg import dispatch_table, safe_constructors from copy_reg import dispatch_table, safe_constructors, _reconstructor
import marshal import marshal
import sys import sys
import struct import struct
...@@ -320,6 +320,13 @@ class Pickler: ...@@ -320,6 +320,13 @@ class Pickler:
raise PicklingError("Tuple returned by %s must have " raise PicklingError("Tuple returned by %s must have "
"exactly two or three elements" % reduce) "exactly two or three elements" % reduce)
# XXX Temporary hack XXX
# Override the default __reduce__ for new-style class instances
if self.proto >= 2:
if func is _reconstructor:
self.save_newobj(obj)
return
# Save the reduce() output and finally memoize the object # Save the reduce() output and finally memoize the object
self.save_reduce(func, args, state) self.save_reduce(func, args, state)
self.memoize(obj) self.memoize(obj)
...@@ -369,14 +376,37 @@ class Pickler: ...@@ -369,14 +376,37 @@ class Pickler:
# Save a new-style class instance, using protocol 2. # Save a new-style class instance, using protocol 2.
# XXX Much of this is still experimental. # XXX Much of this is still experimental.
t = type(obj) t = type(obj)
args = ()
getnewargs = getattr(obj, "__getnewargs__", None) getnewargs = getattr(obj, "__getnewargs__", None)
if getnewargs: if getnewargs:
args = getnewargs() # This better not reference obj args = getnewargs() # This better not reference obj
else:
for cls in int, long, float, complex, str, unicode, tuple:
if isinstance(obj, cls):
args = (cls(obj),)
break
else:
args = ()
save = self.save
write = self.write
self.save_global(t) self.save_global(t)
self.save(args) save(args)
self.write(NEWOBJ) write(NEWOBJ)
self.memoize(obj) self.memoize(obj)
if isinstance(obj, list):
write(MARK)
for x in obj:
save(x)
write(APPENDS)
elif isinstance(obj, dict):
write(MARK)
for k, v in obj.iteritems():
save(k)
save(v)
write(SETITEMS)
getstate = getattr(obj, "__getstate__", None) getstate = getattr(obj, "__getstate__", None)
if getstate: if getstate:
state = getstate() state = getstate()
...@@ -384,9 +414,8 @@ class Pickler: ...@@ -384,9 +414,8 @@ class Pickler:
state = getattr(obj, "__dict__", None) state = getattr(obj, "__dict__", None)
# XXX What about __slots__? # XXX What about __slots__?
if state is not None: if state is not None:
self.save(state) save(state)
self.write(BUILD) write(BUILD)
return
# Methods below this point are dispatched through the dispatch table # Methods below this point are dispatched through the dispatch table
...@@ -1173,6 +1202,8 @@ def encode_long(x): ...@@ -1173,6 +1202,8 @@ def encode_long(x):
'\x7f' '\x7f'
>>> >>>
""" """
# XXX This is still a quadratic algorithm.
# Should use hex() to get started.
digits = [] digits = []
while not -128 <= x < 128: while not -128 <= x < 128:
digits.append(x & 0xff) digits.append(x & 0xff)
...@@ -1195,6 +1226,7 @@ def decode_long(data): ...@@ -1195,6 +1226,7 @@ def decode_long(data):
>>> decode_long("\x7f") >>> decode_long("\x7f")
127L 127L
""" """
# XXX This is quadratic too.
x = 0L x = 0L
i = 0L i = 0L
for c in data: for c in data:
......
...@@ -301,43 +301,34 @@ class AbstractPickleTests(unittest.TestCase): ...@@ -301,43 +301,34 @@ class AbstractPickleTests(unittest.TestCase):
self.assert_(x is y, (proto, x, s, y)) self.assert_(x is y, (proto, x, s, y))
def test_newobj_tuple(self): def test_newobj_tuple(self):
x = MyTuple([1, 2, 3], foo=42, bar="hello") x = MyTuple([1, 2, 3])
x.foo = 42
x.bar = "hello"
s = self.dumps(x, 2) s = self.dumps(x, 2)
y = self.loads(s) y = self.loads(s)
self.assertEqual(tuple(x), tuple(y)) self.assertEqual(tuple(x), tuple(y))
self.assertEqual(x.__dict__, y.__dict__) self.assertEqual(x.__dict__, y.__dict__)
## import pickletools
## print
## pickletools.dis(s)
def test_newobj_list(self): def test_newobj_list(self):
x = MyList([1, 2, 3], foo=42, bar="hello") x = MyList([1, 2, 3])
x.foo = 42
x.bar = "hello"
s = self.dumps(x, 2) s = self.dumps(x, 2)
y = self.loads(s) y = self.loads(s)
self.assertEqual(list(x), list(y)) self.assertEqual(list(x), list(y))
self.assertEqual(x.__dict__, y.__dict__) self.assertEqual(x.__dict__, y.__dict__)
## import pickletools
## print
## pickletools.dis(s)
class MyTuple(tuple): class MyTuple(tuple):
def __new__(cls, *args, **kwds): pass
# Ignore **kwds
return tuple.__new__(cls, *args)
def __getnewargs__(self):
return (tuple(self),)
def __init__(self, *args, **kwds):
for k, v in kwds.items():
setattr(self, k, v)
class MyList(list): class MyList(list):
def __new__(cls, *args, **kwds): pass
# Ignore **kwds
return list.__new__(cls, *args)
def __init__(self, *args, **kwds):
for k, v in kwds.items():
setattr(self, k, v)
def __getstate__(self):
return list(self), self.__dict__
def __setstate__(self, arg):
lst, dct = arg
for x in lst:
self.append(x)
self.__init__(**dct)
class AbstractPickleModuleTests(unittest.TestCase): class AbstractPickleModuleTests(unittest.TestCase):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment