Commit 3422c99d authored by Jeremy Hylton's avatar Jeremy Hylton

Raise PicklingError when __reduce__() fails, and

add memoize() helper function to update the memo.

The first element of the tuple returned by __reduce__() must be a
callable.  If it isn't the Unpickler will raise an error.  Catch this
error in the pickler and raise the error there.

The memoize() helper also has a comment explaining how the memo
works.  So methods can't use memoize() because the write funny codes.
parent 234d9a9e
...@@ -167,6 +167,22 @@ class Pickler: ...@@ -167,6 +167,22 @@ class Pickler:
self.save(object) self.save(object)
self.write(STOP) self.write(STOP)
def memoize(self, obj):
"""Store an object in the memo."""
# The memo is a dictionary mapping object ids to 2-tuples
# that contains the memo value and the object being memoized.
# The memo value is written to the pickle and will become
# the key in the Unpickler's memo. The object is stored in the
# memo so that transient objects are kept alive during pickling.
# The use of the memo length as the memo value is just a convention.
# The only requirement is that the memo values by unique.
d = id(obj)
memo_len = len(self.memo)
self.write(self.put(memo_len))
self.memo[d] = memo_len, obj
def put(self, i): def put(self, i):
if self.bin: if self.bin:
s = mdumps(i)[1:] s = mdumps(i)[1:]
...@@ -280,11 +296,15 @@ class Pickler: ...@@ -280,11 +296,15 @@ class Pickler:
self.save(pid) self.save(pid)
self.write(BINPERSID) self.write(BINPERSID)
def save_reduce(self, callable, arg_tup, state = None): def save_reduce(self, acallable, arg_tup, state = None):
write = self.write write = self.write
save = self.save save = self.save
save(callable) if not callable(acallable):
raise PicklingError("__reduce__() must return callable as "
"first argument, not %s" % `acallable`)
save(acallable)
save(arg_tup) save(arg_tup)
write(REDUCE) write(REDUCE)
...@@ -340,9 +360,6 @@ class Pickler: ...@@ -340,9 +360,6 @@ class Pickler:
dispatch[FloatType] = save_float dispatch[FloatType] = save_float
def save_string(self, object): def save_string(self, object):
d = id(object)
memo = self.memo
if self.bin: if self.bin:
l = len(object) l = len(object)
s = mdumps(l)[1:] s = mdumps(l)[1:]
...@@ -352,16 +369,10 @@ class Pickler: ...@@ -352,16 +369,10 @@ class Pickler:
self.write(BINSTRING + s + object) self.write(BINSTRING + s + object)
else: else:
self.write(STRING + `object` + '\n') self.write(STRING + `object` + '\n')
self.memoize(object)
memo_len = len(memo)
self.write(self.put(memo_len))
memo[d] = (memo_len, object)
dispatch[StringType] = save_string dispatch[StringType] = save_string
def save_unicode(self, object): def save_unicode(self, object):
d = id(object)
memo = self.memo
if self.bin: if self.bin:
encoding = object.encode('utf-8') encoding = object.encode('utf-8')
l = len(encoding) l = len(encoding)
...@@ -371,17 +382,12 @@ class Pickler: ...@@ -371,17 +382,12 @@ class Pickler:
object = object.replace("\\", "\\u005c") object = object.replace("\\", "\\u005c")
object = object.replace("\n", "\\u000a") object = object.replace("\n", "\\u000a")
self.write(UNICODE + object.encode('raw-unicode-escape') + '\n') self.write(UNICODE + object.encode('raw-unicode-escape') + '\n')
self.memoize(object)
memo_len = len(memo)
self.write(self.put(memo_len))
memo[d] = (memo_len, object)
dispatch[UnicodeType] = save_unicode dispatch[UnicodeType] = save_unicode
if StringType == UnicodeType: if StringType == UnicodeType:
# This is true for Jython # This is true for Jython
def save_string(self, object): def save_string(self, object):
d = id(object)
memo = self.memo
unicode = object.isunicode() unicode = object.isunicode()
if self.bin: if self.bin:
...@@ -404,14 +410,10 @@ class Pickler: ...@@ -404,14 +410,10 @@ class Pickler:
self.write(UNICODE + object + '\n') self.write(UNICODE + object + '\n')
else: else:
self.write(STRING + `object` + '\n') self.write(STRING + `object` + '\n')
self.memoize(object)
memo_len = len(memo)
self.write(self.put(memo_len))
memo[d] = (memo_len, object)
dispatch[StringType] = save_string dispatch[StringType] = save_string
def save_tuple(self, object): def save_tuple(self, object):
write = self.write write = self.write
save = self.save save = self.save
memo = self.memo memo = self.memo
...@@ -434,6 +436,7 @@ class Pickler: ...@@ -434,6 +436,7 @@ class Pickler:
memo_len = len(memo) memo_len = len(memo)
self.write(TUPLE + self.put(memo_len)) self.write(TUPLE + self.put(memo_len))
memo[d] = (memo_len, object) memo[d] = (memo_len, object)
dispatch[TupleType] = save_tuple dispatch[TupleType] = save_tuple
def save_empty_tuple(self, object): def save_empty_tuple(self, object):
...@@ -451,9 +454,7 @@ class Pickler: ...@@ -451,9 +454,7 @@ class Pickler:
else: else:
write(MARK + LIST) write(MARK + LIST)
memo_len = len(memo) self.memoize(object)
write(self.put(memo_len))
memo[d] = (memo_len, object)
using_appends = (self.bin and (len(object) > 1)) using_appends = (self.bin and (len(object) > 1))
...@@ -471,20 +472,15 @@ class Pickler: ...@@ -471,20 +472,15 @@ class Pickler:
dispatch[ListType] = save_list dispatch[ListType] = save_list
def save_dict(self, object): def save_dict(self, object):
d = id(object)
write = self.write write = self.write
save = self.save save = self.save
memo = self.memo
if self.bin: if self.bin:
write(EMPTY_DICT) write(EMPTY_DICT)
else: else:
write(MARK + DICT) write(MARK + DICT)
memo_len = len(memo) self.memoize(object)
self.write(self.put(memo_len))
memo[d] = (memo_len, object)
using_setitems = (self.bin and (len(object) > 1)) using_setitems = (self.bin and (len(object) > 1))
...@@ -529,6 +525,8 @@ class Pickler: ...@@ -529,6 +525,8 @@ class Pickler:
for arg in args: for arg in args:
save(arg) save(arg)
# This method does not use memoize() so that it can handle
# the special case for non-binary mode.
memo_len = len(memo) memo_len = len(memo)
if self.bin: if self.bin:
write(OBJ + self.put(memo_len)) write(OBJ + self.put(memo_len))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment