Commit 39fb497b authored by Jim Fulton's avatar Jim Fulton

Fixed a bug in file pool: it didn't properly handle multiple write

locks.

In fixing, also made it work with with.
parent 37f26139
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
"""Storage implementation using a log written to a single file. """Storage implementation using a log written to a single file.
""" """
from __future__ import with_statement
from cPickle import Pickler, Unpickler, loads from cPickle import Pickler, Unpickler, loads
from persistent.TimeStamp import TimeStamp from persistent.TimeStamp import TimeStamp
from struct import pack, unpack from struct import pack, unpack
...@@ -32,6 +34,7 @@ from ZODB.utils import p64, u64, z64 ...@@ -32,6 +34,7 @@ from ZODB.utils import p64, u64, z64
import base64 import base64
import BTrees.OOBTree import BTrees.OOBTree
import contextlib
import errno import errno
import logging import logging
import os import os
...@@ -409,8 +412,7 @@ class FileStorage( ...@@ -409,8 +412,7 @@ class FileStorage(
"""Return pickle data and serial number.""" """Return pickle data and serial number."""
assert not version assert not version
_file = self._files.get() with self._files.get() as _file:
try:
pos = self._lookup_pos(oid) pos = self._lookup_pos(oid)
h = self._read_data_header(pos, oid, _file) h = self._read_data_header(pos, oid, _file)
if h.plen: if h.plen:
...@@ -423,8 +425,6 @@ class FileStorage( ...@@ -423,8 +425,6 @@ class FileStorage(
return data, h.tid return data, h.tid
else: else:
raise POSKeyError(oid) raise POSKeyError(oid)
finally:
self._files.put(_file)
def loadSerial(self, oid, serial): def loadSerial(self, oid, serial):
self._lock_acquire() self._lock_acquire()
...@@ -445,8 +445,7 @@ class FileStorage( ...@@ -445,8 +445,7 @@ class FileStorage(
self._lock_release() self._lock_release()
def loadBefore(self, oid, tid): def loadBefore(self, oid, tid):
_file = self._files.get() with self._files.get() as _file:
try:
pos = self._lookup_pos(oid) pos = self._lookup_pos(oid)
end_tid = None end_tid = None
while True: while True:
...@@ -464,8 +463,6 @@ class FileStorage( ...@@ -464,8 +463,6 @@ class FileStorage(
return data, h.tid, end_tid return data, h.tid, end_tid
else: else:
return _file.read(h.plen), h.tid, end_tid return _file.read(h.plen), h.tid, end_tid
finally:
self._files.put(_file)
def store(self, oid, oldserial, data, version, transaction): def store(self, oid, oldserial, data, version, transaction):
if self._is_read_only: if self._is_read_only:
...@@ -718,12 +715,8 @@ class FileStorage( ...@@ -718,12 +715,8 @@ class FileStorage(
self._lock_release() self._lock_release()
def tpc_finish(self, transaction, f=None): def tpc_finish(self, transaction, f=None):
with self._files.write_lock():
# Get write lock with self._lock:
self._files.write_lock()
try:
self._lock_acquire()
try:
if transaction is not self._transaction: if transaction is not self._transaction:
raise POSException.StorageTransactionError( raise POSException.StorageTransactionError(
"tpc_finish called with wrong transaction") "tpc_finish called with wrong transaction")
...@@ -737,11 +730,6 @@ class FileStorage( ...@@ -737,11 +730,6 @@ class FileStorage(
self._ude = None self._ude = None
self._transaction = None self._transaction = None
self._commit_lock_release() self._commit_lock_release()
finally:
self._lock_release()
finally:
self._files.write_unlock()
def _finish(self, tid, u, d, e): def _finish(self, tid, u, d, e):
# If self._nextpos is 0, then the transaction didn't write any # If self._nextpos is 0, then the transaction didn't write any
...@@ -1139,25 +1127,21 @@ class FileStorage( ...@@ -1139,25 +1127,21 @@ class FileStorage(
return return
have_commit_lock = True have_commit_lock = True
opos, index = pack_result opos, index = pack_result
self._files.write_lock() with self._files.write_lock():
self._lock_acquire() with self._lock:
try: self._files.empty()
self._files.empty() self._file.close()
self._file.close() try:
try: os.rename(self._file_name, oldpath)
os.rename(self._file_name, oldpath) except Exception:
except Exception: self._file = open(self._file_name, 'r+b')
self._file = open(self._file_name, 'r+b') raise
raise
# OK, we're beyond the point of no return # OK, we're beyond the point of no return
os.rename(self._file_name + '.pack', self._file_name) os.rename(self._file_name + '.pack', self._file_name)
self._file = open(self._file_name, 'r+b') self._file = open(self._file_name, 'r+b')
self._initIndex(index, self._tindex) self._initIndex(index, self._tindex)
self._pos = opos self._pos = opos
finally:
self._files.write_unlock()
self._lock_release()
# We're basically done. Now we need to deal with removed # We're basically done. Now we need to deal with removed
# blobs and removing the .old file (see further down). # blobs and removing the .old file (see further down).
...@@ -2053,6 +2037,7 @@ class FilePool: ...@@ -2053,6 +2037,7 @@ class FilePool:
closed = False closed = False
writing = False writing = False
writers = 0
def __init__(self, file_name): def __init__(self, file_name):
self.name = file_name self.name = file_name
...@@ -2060,26 +2045,31 @@ class FilePool: ...@@ -2060,26 +2045,31 @@ class FilePool:
self._out = [] self._out = []
self._cond = threading.Condition() self._cond = threading.Condition()
@contextlib.contextmanager
def write_lock(self): def write_lock(self):
self._cond.acquire() with self._cond:
try: self.writers += 1
self.writing = True while self.writing or self._out:
while self._out:
self._cond.wait() self._cond.wait()
finally: if self.closed:
self._cond.release() raise ValueError('closed')
self.writing = True
def write_unlock(self): try:
self._cond.acquire() yield None
self.writing = False finally:
self._cond.notifyAll() with self._cond:
self._cond.release() self.writing = False
if self.writers > 0:
self.writers -= 1
self._cond.notifyAll()
@contextlib.contextmanager
def get(self): def get(self):
self._cond.acquire() with self._cond:
try: while self.writers:
while self.writing:
self._cond.wait() self._cond.wait()
assert not self.writing
if self.closed: if self.closed:
raise ValueError('closed') raise ValueError('closed')
...@@ -2088,32 +2078,25 @@ class FilePool: ...@@ -2088,32 +2078,25 @@ class FilePool:
except IndexError: except IndexError:
f = open(self.name, 'rb') f = open(self.name, 'rb')
self._out.append(f) self._out.append(f)
return f
finally:
self._cond.release()
def put(self, f): try:
self._out.remove(f) yield f
self._files.append(f) finally:
if not self._out: self._out.remove(f)
self._cond.acquire() self._files.append(f)
try: if not self._out:
if self.writing and not self._out: with self._cond:
self._cond.notifyAll() if self.writers and not self._out:
finally: self._cond.notifyAll()
self._cond.release()
def empty(self): def empty(self):
while self._files: while self._files:
self._files.pop().close() self._files.pop().close()
def close(self): def close(self):
self._cond.acquire() with self._cond:
self.closed = True self.closed = True
self._cond.release() while self._out:
self._out.pop().close()
self.write_lock()
try:
self.empty() self.empty()
finally: self.writing = self.writers = 0
self.write_unlock()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment