Commit bff42749 authored by Tres Seaver's avatar Tres Seaver

Test cPersistence handling of derived classes w/ slots.

Match its behavior in pyPersistence.
parent a60a8426
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# #
############################################################################## ##############################################################################
from copy_reg import __newobj__ from copy_reg import __newobj__
from copy_reg import _slotnames
import sys import sys
from zope.interface import implements from zope.interface import implements
...@@ -252,34 +253,51 @@ class Persistent(object): ...@@ -252,34 +253,51 @@ class Persistent(object):
_OGA(self, '_p_register')() _OGA(self, '_p_register')()
object.__delattr__(self, name) object.__delattr__(self, name)
def _slotnames(self):
slotnames = _slotnames(type(self))
return [x for x in slotnames
if not x.startswith('_p_') and
not x.startswith('_v_') and
not x.startswith('_Persistent__') and
x not in Persistent.__slots__]
def __getstate__(self): def __getstate__(self):
""" See IPersistent. """ See IPersistent.
""" """
idict = getattr(self, '__dict__', None) idict = getattr(self, '__dict__', None)
slotnames = self._slotnames()
if idict is not None: if idict is not None:
return dict([x for x in idict.items() d = dict([x for x in idict.items()
if not x[0].startswith('_p_') and if not x[0].startswith('_p_') and
not x[0].startswith('_v_')]) not x[0].startswith('_v_')])
slots = getattr(type(self), '__slots__', None) else:
if slots is not None: d = None
slots = [x for x in slots if slotnames:
if not x.startswith('_p_') and s = {}
not x.startswith('_v_') and for slotname in slotnames:
x not in Persistent.__slots__] value = getattr(self, slotname, self)
if slots: if value is not self:
return None, dict([(x, getattr(self, x)) for x in slots]) s[slotname] = value
return None return d, s
return d
def __setstate__(self, state): def __setstate__(self, state):
""" See IPersistent. """ See IPersistent.
""" """
try:
inst_dict, slots = state
except:
inst_dict, slots = state, ()
idict = getattr(self, '__dict__', None) idict = getattr(self, '__dict__', None)
if idict is not None: if inst_dict is not None:
if idict is None:
raise TypeError('No instance dict')
idict.clear() idict.clear()
idict.update(state) idict.update(inst_dict)
else: slotnames = self._slotnames()
if state != None: if slotnames:
raise ValueError('No state allowed on base Persistent class') for k, v in slots.items():
setattr(self, k, v)
def __reduce__(self): def __reduce__(self):
""" See IPersistent. """ See IPersistent.
......
...@@ -704,6 +704,30 @@ class _Persistent_Base(object): ...@@ -704,6 +704,30 @@ class _Persistent_Base(object):
inst._v_qux = 'spam' inst._v_qux = 'spam'
self.assertEqual(inst.__getstate__(), (None, {'foo': 'bar'})) self.assertEqual(inst.__getstate__(), (None, {'foo': 'bar'}))
def test___getstate___derived_w_slots_in_base_and_derived(self):
class Base(self._getTargetClass()):
__slots__ = ('foo',)
class Derived(Base):
__slots__ = ('baz', 'qux',)
inst = Derived()
inst.foo = 'bar'
inst.baz = 'bam'
inst.qux = 'spam'
self.assertEqual(inst.__getstate__(),
(None, {'foo': 'bar', 'baz': 'bam', 'qux': 'spam'}))
def test___getstate___derived_w_slots_in_base_but_not_derived(self):
class Base(self._getTargetClass()):
__slots__ = ('foo',)
class Derived(Base):
pass
inst = Derived()
inst.foo = 'bar'
inst.baz = 'bam'
inst.qux = 'spam'
self.assertEqual(inst.__getstate__(),
({'baz': 'bam', 'qux': 'spam'}, {'foo': 'bar'}))
def test___setstate___empty(self): def test___setstate___empty(self):
inst = self._makeOne() inst = self._makeOne()
inst.__setstate__(None) # doesn't raise, but doesn't change anything inst.__setstate__(None) # doesn't raise, but doesn't change anything
...@@ -727,6 +751,35 @@ class _Persistent_Base(object): ...@@ -727,6 +751,35 @@ class _Persistent_Base(object):
inst.__setstate__({'baz': 'bam'}) inst.__setstate__({'baz': 'bam'})
self.assertEqual(inst.__dict__, {'baz': 'bam'}) self.assertEqual(inst.__dict__, {'baz': 'bam'})
def test___setstate___derived_w_slots(self):
class Derived(self._getTargetClass()):
__slots__ = ('foo', '_p_baz', '_v_qux')
inst = Derived()
inst.__setstate__((None, {'foo': 'bar'}))
self.assertEqual(inst.foo, 'bar')
def test___setstate___derived_w_slots_in_base_classes(self):
class Base(self._getTargetClass()):
__slots__ = ('foo',)
class Derived(Base):
__slots__ = ('baz', 'qux',)
inst = Derived()
inst.__setstate__((None, {'foo': 'bar', 'baz': 'bam', 'qux': 'spam'}))
self.assertEqual(inst.foo, 'bar')
self.assertEqual(inst.baz, 'bam')
self.assertEqual(inst.qux, 'spam')
def test___setstate___derived_w_slots_in_base_but_not_derived(self):
class Base(self._getTargetClass()):
__slots__ = ('foo',)
class Derived(Base):
pass
inst = Derived()
inst.__setstate__(({'baz': 'bam', 'qux': 'spam'}, {'foo': 'bar'}))
self.assertEqual(inst.foo, 'bar')
self.assertEqual(inst.baz, 'bam')
self.assertEqual(inst.qux, 'spam')
def test___reduce__(self): def test___reduce__(self):
from copy_reg import __newobj__ from copy_reg import __newobj__
inst = self._makeOne() inst = self._makeOne()
...@@ -1067,8 +1120,11 @@ class PyPersistentTests(unittest.TestCase, _Persistent_Base): ...@@ -1067,8 +1120,11 @@ class PyPersistentTests(unittest.TestCase, _Persistent_Base):
jar._cache._mru[:] = [] jar._cache._mru[:] = []
import os try:
if os.environ.get('run_C_tests'): from persistent import cPersistence
except ImportError:
pass
else:
class CPersistentTests(unittest.TestCase, _Persistent_Base): class CPersistentTests(unittest.TestCase, _Persistent_Base):
def _getTargetClass(self): def _getTargetClass(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment