Commit 9fc779e7 authored by Tres Seaver's avatar Tres Seaver

Coverage for Transacdtion.savepoint and helpers.

parent 88dd3373
......@@ -59,6 +59,7 @@ class TransactionTests(unittest.TestCase):
self.assertEqual(len(t._synchronizers), 0)
self.assertTrue(t._manager is None)
self.assertTrue(t._savepoint2index is None)
self.assertEqual(t._savepoint_index, 0)
self.assertEqual(t._resources, [])
self.assertEqual(t._adapters, {})
self.assertEqual(t._voted, {})
......@@ -182,6 +183,108 @@ class TransactionTests(unittest.TestCase):
t._unjoin(resource)
self.assertEqual(t._resources, [])
def test_savepoint_COMMITFAILED(self):
from transaction.interfaces import TransactionFailedError
from transaction._transaction import Status
class _Traceback(object):
def getvalue(self):
return 'TRACEBACK'
t = self._makeOne()
t.status = Status.COMMITFAILED
t._failure_traceback = _Traceback()
self.assertRaises(TransactionFailedError, t.savepoint)
def test_savepoint_empty(self):
from weakref import WeakKeyDictionary
from transaction import _transaction
from transaction._transaction import Savepoint
from transaction.tests.common import Monkey
logger = DummyLogger()
with Monkey(_transaction, _LOGGER=logger):
t = self._makeOne()
sp = t.savepoint()
self.assertTrue(isinstance(sp, Savepoint))
self.assertTrue(sp.transaction is t)
self.assertEqual(sp._savepoints, [])
self.assertEqual(t._savepoint_index, 1)
self.assertTrue(isinstance(t._savepoint2index, WeakKeyDictionary))
self.assertEqual(t._savepoint2index[sp], 1)
def test_savepoint_non_optimistc_resource_wo_support(self):
from transaction import _transaction
from transaction._transaction import Status
from transaction._compat import StringIO
from transaction.tests.common import Monkey
logger = DummyLogger()
with Monkey(_transaction, _LOGGER=logger):
t = self._makeOne()
logger._clear()
resource = object()
t._resources.append(resource)
self.assertRaises(TypeError, t.savepoint)
self.assertEqual(t.status, Status.COMMITFAILED)
self.assertTrue(isinstance(t._failure_traceback, StringIO))
self.assertTrue('TypeError' in t._failure_traceback.getvalue())
self.assertEqual(len(logger._log), 2)
self.assertEqual(logger._log[0][0], 'error')
self.assertTrue(logger._log[0][1].startswith('Error in abort'))
self.assertEqual(logger._log[1][0], 'error')
self.assertTrue(logger._log[1][1].startswith('Error in tpc_abort'))
def test__remove_and_invalidate_after_miss(self):
from weakref import WeakKeyDictionary
t = self._makeOne()
t._savepoint2index = WeakKeyDictionary()
class _SP(object):
def __init__(self, t):
self.transaction = t
holdme = []
for i in range(10):
sp = _SP(t)
holdme.append(sp) #prevent gc
t._savepoint2index[sp] = i
self.assertEqual(len(t._savepoint2index), 10)
self.assertRaises(KeyError, t._remove_and_invalidate_after, _SP(t))
self.assertEqual(len(t._savepoint2index), 10)
def test__remove_and_invalidate_after_hit(self):
from weakref import WeakKeyDictionary
t = self._makeOne()
t._savepoint2index = WeakKeyDictionary()
class _SP(object):
def __init__(self, t, index):
self.transaction = t
self._index = index
def __repr__(self):
return '_SP: %d' % self._index
holdme = []
for i in range(10):
sp = _SP(t, i)
holdme.append(sp) #prevent gc
t._savepoint2index[sp] = i
self.assertEqual(len(t._savepoint2index), 10)
t._remove_and_invalidate_after(holdme[1])
self.assertEqual(sorted(t._savepoint2index), sorted(holdme[:2]))
def test__invalidate_all_savepoints(self):
from weakref import WeakKeyDictionary
t = self._makeOne()
t._savepoint2index = WeakKeyDictionary()
class _SP(object):
def __init__(self, t, index):
self.transaction = t
self._index = index
def __repr__(self):
return '_SP: %d' % self._index
holdme = []
for i in range(10):
sp = _SP(t, i)
holdme.append(sp) #prevent gc
t._savepoint2index[sp] = i
self.assertEqual(len(t._savepoint2index), 10)
t._invalidate_all_savepoints()
self.assertEqual(list(t._savepoint2index), [])
def test_note(self):
t = self._makeOne()
try:
......@@ -285,6 +388,8 @@ class MiscellaneousTests(unittest.TestCase):
class DummyLogger(object):
def __init__(self):
self._clear()
def _clear(self):
self._log = []
def log(self, level, msg, *args, **kw):
if args:
......@@ -295,6 +400,8 @@ class DummyLogger(object):
self._log.append((level, msg))
def debug(self, msg, *args, **kw):
self.log('DEBUG', msg, *args, **kw)
def error(self, msg, *args, **kw):
self.log('error', msg, *args, **kw)
def test_suite():
return unittest.TestSuite((
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment