Commit 86b7ebbd authored by Julien Muchembled's avatar Julien Muchembled

storage: prevent 2 nodes from working with the same database

parent 8d42a2e6
......@@ -14,7 +14,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import struct, threading
import os, errno, socket, struct, sys, threading
from collections import defaultdict
from contextlib import contextmanager
from functools import wraps
......@@ -55,6 +55,10 @@ class DatabaseManager(object):
ENGINES = ()
UNSAFE = False
__lock = None
LOCK = "neostorage"
LOCKED = "error: database is locked"
_deferred = 0
_duplicating = _repairing = None
......@@ -84,6 +88,7 @@ class DatabaseManager(object):
def _duplicate(self):
cls = self.__class__
db = cls.__new__(cls)
db.LOCK = None
db._duplicating = self
try:
db._connect()
......@@ -102,6 +107,26 @@ class DatabaseManager(object):
def _connect(self):
"""Connect to the database"""
def lock(self, db_path):
if self.LOCK:
assert self.__lock is None, self.__lock
# For platforms that don't support anonymous sockets,
# we can either use zc.lockfile or an empty SQLite db
# (with BEGIN EXCLUSIVE).
try:
stat = os.stat(db_path)
except OSError as e:
if e.errno != errno.ENOENT:
raise
return # in-memory or temporary database
s = self.__lock = socket.socket(socket.AF_UNIX)
try:
s.bind('\0%s:%s:%s' % (self.LOCK, stat.st_dev, stat.st_ino))
except socket.error as e:
if e.errno != errno.EADDRINUSE:
raise
sys.exit(self.LOCKED)
@abstract
def erase(self):
""""""
......@@ -152,6 +177,9 @@ class DatabaseManager(object):
def close(self):
self._deferredCommit()
self._close()
if self.__lock:
self.__lock.close()
del self.__lock
def _commit(self):
"""Backend-specific code to commit the pending changes"""
......
......@@ -29,6 +29,7 @@ import os
import re
import string
import struct
import sys
import time
from . import LOG_QUERIES
......@@ -102,9 +103,17 @@ class MySQLDatabaseManager(DatabaseManager):
conn.autocommit(False)
conn.query("SET SESSION group_concat_max_len = %u" % (2**32-1))
conn.set_sql_mode("TRADITIONAL,NO_ENGINE_SUBSTITUTION")
conn.query("SHOW VARIABLES WHERE variable_name='max_allowed_packet'")
r = conn.store_result()
(name, value), = r.fetch_row(r.num_rows())
def query(sql):
conn.query(sql)
r = conn.store_result()
return r.fetch_row(r.num_rows())
if self.LOCK:
(locked,), = query("SELECT GET_LOCK('%s.%s', 0)"
% (self.db, self.LOCK))
if not locked:
sys.exit(self.LOCKED)
(name, value), = query(
"SHOW VARIABLES WHERE variable_name='max_allowed_packet'")
if int(value) < self._max_allowed_packet:
raise DatabaseFailure("Global variable %r is too small."
" Minimal value must be %uk."
......
......@@ -78,6 +78,7 @@ class SQLiteDatabaseManager(DatabaseManager):
def _connect(self):
logging.info('connecting to SQLite database %r', self.db)
self.conn = sqlite3.connect(self.db, check_same_thread=False)
self.lock(self.db)
if self.UNSAFE:
q = self.query
q("PRAGMA synchronous = OFF")
......
......@@ -217,7 +217,8 @@ class NeoUnitTestBase(NeoTestBase):
temp_dir = getTempDirectory()
for i in xrange(number):
try:
os.remove(os.path.join(temp_dir, 'test_neo%s.sqlite' % i))
os.remove(os.path.join(temp_dir,
'%s%s.sqlite' % (prefix, i)))
except OSError, e:
if e.errno != errno.ENOENT:
raise
......
......@@ -37,10 +37,11 @@ from neo.lib import logging
from neo.lib.protocol import ClusterStates, NodeTypes, CellStates, NodeStates, \
UUID_NAMESPACES
from neo.lib.util import dump
from .. import ADDRESS_TYPE, DB_SOCKET, DB_USER, IP_VERSION_FORMAT_DICT, SSL, \
buildUrlFromString, cluster, getTempDirectory, NeoTestBase, setupMySQLdb
from .. import (ADDRESS_TYPE, DB_SOCKET, DB_USER, IP_VERSION_FORMAT_DICT, SSL,
buildUrlFromString, cluster, getTempDirectory, NeoTestBase, Patch,
setupMySQLdb)
from neo.client.Storage import Storage
from neo.storage.database import buildDatabaseManager
from neo.storage.database import manager, buildDatabaseManager
try:
coverage = sys.modules['neo.scripts.runner'].coverage
......@@ -483,7 +484,8 @@ class NEOCluster(object):
def getSQLConnection(self, db):
assert db is not None and db in self.db_list
return buildDatabaseManager(self.adapter, (self.db_template(db),))
with Patch(manager.DatabaseManager, LOCK=None):
return buildDatabaseManager(self.adapter, (self.db_template(db),))
def getMasterProcessList(self):
return self.process_dict.get(NodeTypes.MASTER)
......
......@@ -131,6 +131,15 @@ class StorageDBTests(NeoUnitTestBase):
def checkSet(self, list1, list2):
self.assertEqual(set(list1), set(list2))
def _test_lockDatabase_open(self):
raise NotImplementedError
def test_lockDatabase(self):
db = self._test_lockDatabase_open()
self.assertRaises(SystemExit, self._test_lockDatabase_open)
db.close()
self._test_lockDatabase_open().close()
def test_getUnfinishedTIDDict(self):
tid1, tid2, tid3, tid4 = self.getTIDs(4)
oid1, oid2 = self.getOIDs(2)
......
......@@ -29,11 +29,13 @@ class StorageMySQLdbTests(StorageDBTests):
engine = None
def getDB(self, reset=0):
def _test_lockDatabase_open(self):
self.prepareDatabase(number=1, prefix=DB_PREFIX)
# db manager
database = '%s@%s0%s' % (DB_USER, DB_PREFIX, DB_SOCKET)
db = MySQLDatabaseManager(database, self.engine)
return MySQLDatabaseManager(database, self.engine)
def getDB(self, reset=0):
db = self._test_lockDatabase_open()
self.assertEqual(db.db, DB_PREFIX + '0')
self.assertEqual(db.user, DB_USER)
try:
......@@ -129,11 +131,13 @@ class StorageMySQLdbTests(StorageDBTests):
class StorageMySQLdbRocksDBTests(StorageMySQLdbTests):
engine = "RocksDB"
test_lockDatabase = None
class StorageMySQLdbTokuDBTests(StorageMySQLdbTests):
engine = "TokuDB"
test_lockDatabase = None
del StorageDBTests
......
......@@ -14,17 +14,29 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
import os, unittest
from .. import getTempDirectory, DB_PREFIX
from .testStorageDBTests import StorageDBTests
from neo.storage.database.sqlite import SQLiteDatabaseManager
class StorageSQLiteTests(StorageDBTests):
def _test_lockDatabase_open(self):
db = os.path.join(getTempDirectory(), DB_PREFIX + '0.sqlite')
return SQLiteDatabaseManager(db)
def getDB(self, reset=0):
db = SQLiteDatabaseManager(':memory:')
db.setup(reset)
return db
def test_lockDatabase(self):
super(StorageSQLiteTests, self).test_lockDatabase()
# No lock on temporary databases.
db = self.getDB()
self.getDB().close()
db.close()
del StorageDBTests
if __name__ == "__main__":
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment