Commit bff42749 authored by Tres Seaver's avatar Tres Seaver

Test cPersistence handling of derived classes w/ slots.

Match its behavior in pyPersistence.
parent a60a8426
......@@ -12,6 +12,7 @@
#
##############################################################################
from copy_reg import __newobj__
from copy_reg import _slotnames
import sys
from zope.interface import implements
......@@ -252,34 +253,51 @@ class Persistent(object):
_OGA(self, '_p_register')()
object.__delattr__(self, name)
def _slotnames(self):
slotnames = _slotnames(type(self))
return [x for x in slotnames
if not x.startswith('_p_') and
not x.startswith('_v_') and
not x.startswith('_Persistent__') and
x not in Persistent.__slots__]
def __getstate__(self):
""" See IPersistent.
"""
idict = getattr(self, '__dict__', None)
slotnames = self._slotnames()
if idict is not None:
return dict([x for x in idict.items()
if not x[0].startswith('_p_') and
not x[0].startswith('_v_')])
slots = getattr(type(self), '__slots__', None)
if slots is not None:
slots = [x for x in slots
if not x.startswith('_p_') and
not x.startswith('_v_') and
x not in Persistent.__slots__]
if slots:
return None, dict([(x, getattr(self, x)) for x in slots])
return None
d = dict([x for x in idict.items()
if not x[0].startswith('_p_') and
not x[0].startswith('_v_')])
else:
d = None
if slotnames:
s = {}
for slotname in slotnames:
value = getattr(self, slotname, self)
if value is not self:
s[slotname] = value
return d, s
return d
def __setstate__(self, state):
""" See IPersistent.
"""
try:
inst_dict, slots = state
except:
inst_dict, slots = state, ()
idict = getattr(self, '__dict__', None)
if idict is not None:
if inst_dict is not None:
if idict is None:
raise TypeError('No instance dict')
idict.clear()
idict.update(state)
else:
if state != None:
raise ValueError('No state allowed on base Persistent class')
idict.update(inst_dict)
slotnames = self._slotnames()
if slotnames:
for k, v in slots.items():
setattr(self, k, v)
def __reduce__(self):
""" See IPersistent.
......
......@@ -704,6 +704,30 @@ class _Persistent_Base(object):
inst._v_qux = 'spam'
self.assertEqual(inst.__getstate__(), (None, {'foo': 'bar'}))
def test___getstate___derived_w_slots_in_base_and_derived(self):
class Base(self._getTargetClass()):
__slots__ = ('foo',)
class Derived(Base):
__slots__ = ('baz', 'qux',)
inst = Derived()
inst.foo = 'bar'
inst.baz = 'bam'
inst.qux = 'spam'
self.assertEqual(inst.__getstate__(),
(None, {'foo': 'bar', 'baz': 'bam', 'qux': 'spam'}))
def test___getstate___derived_w_slots_in_base_but_not_derived(self):
class Base(self._getTargetClass()):
__slots__ = ('foo',)
class Derived(Base):
pass
inst = Derived()
inst.foo = 'bar'
inst.baz = 'bam'
inst.qux = 'spam'
self.assertEqual(inst.__getstate__(),
({'baz': 'bam', 'qux': 'spam'}, {'foo': 'bar'}))
def test___setstate___empty(self):
inst = self._makeOne()
inst.__setstate__(None) # doesn't raise, but doesn't change anything
......@@ -727,6 +751,35 @@ class _Persistent_Base(object):
inst.__setstate__({'baz': 'bam'})
self.assertEqual(inst.__dict__, {'baz': 'bam'})
def test___setstate___derived_w_slots(self):
class Derived(self._getTargetClass()):
__slots__ = ('foo', '_p_baz', '_v_qux')
inst = Derived()
inst.__setstate__((None, {'foo': 'bar'}))
self.assertEqual(inst.foo, 'bar')
def test___setstate___derived_w_slots_in_base_classes(self):
class Base(self._getTargetClass()):
__slots__ = ('foo',)
class Derived(Base):
__slots__ = ('baz', 'qux',)
inst = Derived()
inst.__setstate__((None, {'foo': 'bar', 'baz': 'bam', 'qux': 'spam'}))
self.assertEqual(inst.foo, 'bar')
self.assertEqual(inst.baz, 'bam')
self.assertEqual(inst.qux, 'spam')
def test___setstate___derived_w_slots_in_base_but_not_derived(self):
class Base(self._getTargetClass()):
__slots__ = ('foo',)
class Derived(Base):
pass
inst = Derived()
inst.__setstate__(({'baz': 'bam', 'qux': 'spam'}, {'foo': 'bar'}))
self.assertEqual(inst.foo, 'bar')
self.assertEqual(inst.baz, 'bam')
self.assertEqual(inst.qux, 'spam')
def test___reduce__(self):
from copy_reg import __newobj__
inst = self._makeOne()
......@@ -1067,8 +1120,11 @@ class PyPersistentTests(unittest.TestCase, _Persistent_Base):
jar._cache._mru[:] = []
import os
if os.environ.get('run_C_tests'):
try:
from persistent import cPersistence
except ImportError:
pass
else:
class CPersistentTests(unittest.TestCase, _Persistent_Base):
def _getTargetClass(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment