Commit 25b90700 authored by Jeremy Hylton's avatar Jeremy Hylton

Variant of a patch from Gintautas Miliauskas.

The old code used itertools.chain(), which didn't work as expected.
If the first of the two iterator grew after the chain had started
consuming the second iterator, the new element(s) in the first
iterator would never be consumed.

Fix this bug by eliminating the need for chaining.  Process all the
_added_during_commit objects at the end, each with its own
ObjectWriter instance.
parent d9451206
...@@ -13,12 +13,11 @@ ...@@ -13,12 +13,11 @@
############################################################################## ##############################################################################
"""Database connection support """Database connection support
$Id: Connection.py,v 1.151 2004/04/16 19:07:00 tim_one Exp $""" $Id: Connection.py,v 1.152 2004/04/16 19:55:04 jeremy Exp $"""
import logging import logging
import sys import sys
import threading import threading
import itertools
import warnings import warnings
from time import time from time import time
from utils import u64 from utils import u64
...@@ -568,6 +567,14 @@ class Connection(ExportImport, object): ...@@ -568,6 +567,14 @@ class Connection(ExportImport, object):
self._importDuringCommit(transaction, *self._import) self._importDuringCommit(transaction, *self._import)
self._import = None self._import = None
# Just in case an object is added as a side-effect of storing
# a modified object. If, for example, a __getstate__() method
# calls add(), the newly added objects will show up in
# _added_during_commit. This sounds insane, but has actually
# happened.
self._added_during_commit = []
for obj in self._registered_objects: for obj in self._registered_objects:
oid = obj._p_oid oid = obj._p_oid
assert oid assert oid
...@@ -592,9 +599,12 @@ class Connection(ExportImport, object): ...@@ -592,9 +599,12 @@ class Connection(ExportImport, object):
self._store_objects(ObjectWriter(obj), transaction) self._store_objects(ObjectWriter(obj), transaction)
for obj in self._added_during_commit:
self._storage_objects(ObjectWriter(obj), transaction)
self._added_during_commit = None
def _store_objects(self, writer, transaction): def _store_objects(self, writer, transaction):
self._added_during_commit = [] for obj in writer:
for obj in itertools.chain(writer, self._added_during_commit):
oid = obj._p_oid oid = obj._p_oid
serial = getattr(obj, "_p_serial", z64) serial = getattr(obj, "_p_serial", z64)
...@@ -625,7 +635,6 @@ class Connection(ExportImport, object): ...@@ -625,7 +635,6 @@ class Connection(ExportImport, object):
raise raise
self._handle_serial(s, oid) self._handle_serial(s, oid)
self._added_during_commit = None
def commit_sub(self, t): def commit_sub(self, t):
"""Commit all work done in all subtransactions for this transaction.""" """Commit all work done in all subtransactions for this transaction."""
......
...@@ -115,7 +115,9 @@ class ConnectionDotAdd(unittest.TestCase): ...@@ -115,7 +115,9 @@ class ConnectionDotAdd(unittest.TestCase):
self.assertEquals(self.db._storage._finished, [oid]) self.assertEquals(self.db._storage._finished, [oid])
def checkModifyOnGetstate(self): def checkModifyOnGetstate(self):
member = StubObject()
subobj = StubObject() subobj = StubObject()
subobj.member = member
obj = ModifyOnGetStateObject(subobj) obj = ModifyOnGetStateObject(subobj)
self.datamgr.add(obj) self.datamgr.add(obj)
self.datamgr.tpc_begin(self.transaction) self.datamgr.tpc_begin(self.transaction)
...@@ -125,6 +127,7 @@ class ConnectionDotAdd(unittest.TestCase): ...@@ -125,6 +127,7 @@ class ConnectionDotAdd(unittest.TestCase):
self.assert_(obj._p_oid in storage._stored, "object was not stored") self.assert_(obj._p_oid in storage._stored, "object was not stored")
self.assert_(subobj._p_oid in storage._stored, self.assert_(subobj._p_oid in storage._stored,
"subobject was not stored") "subobject was not stored")
self.assert_(member._p_oid in storage._stored, "member was not stored")
self.assert_(self.datamgr._added_during_commit is None) self.assert_(self.datamgr._added_during_commit is None)
def checkUnusedAddWorks(self): def checkUnusedAddWorks(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment