Commit 30a02bdc authored by Julien Muchembled

importer: new option to write back new transactions to the source database

By doing the work with secondary connections to the underlying databases,
asynchronously and in a separate process, this should have minimal impact on
the performance of the storage node. Extra complexity comes from backends that
may lose connection to the database (here MySQL): this commit fully implements
reconnection.
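The design can be pictured as a small event-driven pump: the storage node's commit path merely sets an event, while a separate process wakes up and copies whatever is new, so a slow source database never blocks commits. A minimal sketch of the pattern, independent of NEO's actual classes (all names hypothetical):

```python
from multiprocessing import Event, Process

class Pump(object):
    """Forward committed data to a slow sink without blocking the producer."""

    def __init__(self, copy_pending):
        self._copy_pending = copy_pending  # callable doing the real work
        self._event = Event()  # "something new was committed"
        self._stop = Event()   # "shut down once the current batch is done"
        self._process = Process(target=self._run)
        self._process.daemon = True
        self._process.start()

    def committed(self):
        # Called by the producer after each commit: cheap and non-blocking.
        self._event.set()

    def close(self):
        self._stop.set()
        self._event.set()  # wake the worker so it can exit
        self._process.join()

    def _run(self):
        while True:
            self._event.wait()
            self._event.clear()
            self._copy_pending()  # may be slow; runs in its own process
            if self._stop.is_set():
                break
```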
parent 2fae3e54
@@ -45,6 +45,12 @@
 # (instead of adapter=Importer & database=/path_to_this_file).
 adapter=MySQL
 database=neo
+# Keep writing back new transactions to the source database, provided it is
+# not split. In case of any issue, the import can be aborted without losing
+# data. Note however that it is asynchronous, so don't stop the storage node
+# too quickly after the last committed transaction (e.g. check with tools like
+# fstail).
+writeback=true
 # The other sections are for source databases.
 [root]
@@ -52,7 +58,8 @@ database=neo
 # ZEO is possible but less efficient: ZEO servers must be stopped
 # if NEO opens FileStorage DBs directly.
 # Note that NEO uses 'new_oid' method to get the last OID, that's why the
-# source DB can't be open read-only. NEO never modifies a FileStorage DB.
+# source DB can't be open read-only. Unless 'writeback' is enabled, NEO never
+# modifies a FileStorage DB.
 storage=
 <filestorage>
 path /path/to/root.fs
...
@@ -27,8 +27,7 @@ def check_signature(reference, function):
         del a[x:]
         d = d[:x] or None
     elif x: # different signature
-        # We have no need yet to support methods with default parameters.
-        return a == A[:-x] and (b or a and c) and not (d or D)
+        return a == A[:-x] and (b or a and c) and (d or ()) == (D or ())[:-x]
     return a == A and (b or not B) and (c or not C) and d == D

 def implements(obj, ignore=()):
@@ -55,7 +54,7 @@ def implements(obj, ignore=()):
         while 1:
             name, func = base.pop()
             x = getattr(obj, name)
-            if x.im_class is tobj:
+            if type(getattr(x, '__self__', None)) is tobj:
                 x = x.__func__
             if x is func:
                 try:
...
@@ -281,3 +281,16 @@ class NEOLogger(Logger):

 logging = NEOLogger()
 signal.signal(signal.SIGRTMIN, lambda signum, frame: logging.flush())
 signal.signal(signal.SIGRTMIN+1, lambda signum, frame: logging.reopen())
+
+def patch():
+    def fork():
+        with logging:
+            pid = os_fork()
+        if not pid:
+            logging._setup()
+        return pid
+    os_fork = os.fork
+    os.fork = fork
+patch()
+del patch
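The `patch()` above replaces `os.fork()` globally: the fork happens while the logger's lock is held, so no thread can be mid-write into the log database when the process image is duplicated, and the child then re-initializes its own logging (that is what `logging._setup()` does). The pattern in isolation (a sketch; `reopen_log` stands in for the logger re-initialization):

```python
import os, threading

log_lock = threading.RLock()  # guards the shared log handle

def reopen_log():
    pass  # stub: give the child process its own log handle

os_fork = os.fork
def fork():
    with log_lock:     # no thread holds the lock mid-write during fork
        pid = os_fork()
    if not pid:
        reopen_log()   # child only: pid == 0
    return pid
os.fork = fork         # every later fork in this process is now safe
```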
@@ -21,7 +21,8 @@ from collections import deque
 from cStringIO import StringIO
 from ConfigParser import SafeConfigParser
 from ZConfig import loadConfigFile
-from ZODB.config import getStorageSchema
+from ZODB import BaseStorage
+from ZODB.config import getStorageSchema, storageFromString
 from ZODB.POSException import POSKeyError
 from . import buildDatabaseManager, DatabaseFailure
@@ -295,9 +296,12 @@ class ZODBIterator(object):
             and self.zodb.shift_oid < other.zodb.shift_oid

+is_true = ('false', 'true').index
+
 class ImporterDatabaseManager(DatabaseManager):
     """Proxy that transparently imports data from a ZODB storage
     """
+    _writeback = None
     _last_commit = 0

     def __init__(self, *args, **kw):
@@ -315,34 +319,58 @@ class ImporterDatabaseManager(DatabaseManager):
         config.read(os.path.expanduser(database))
         sections = config.sections()
         # XXX: defaults copy & pasted from elsewhere - refactoring needed
-        main = {'adapter': 'MySQL', 'wait': 0}
+        main = self._conf = {'adapter': 'MySQL', 'wait': 0}
         main.update(config.items(sections.pop(0)))
-        self.zodb = ((x, dict(config.items(x))) for x in sections)
+        self.zodb = [(x, dict(config.items(x))) for x in sections]
         x = main.get('compress', 'true')
         try:
-            self.compress = bool(('false', 'true').index(x))
+            self.compress = bool(is_true(x))
         except ValueError:
             self.compress = compress.parseOption(x)
-        self.db = buildDatabaseManager(main['adapter'],
-            (main['database'], main.get('engine'), main['wait']))
+        if is_true(main.get('writeback', 'false')):
+            if len(self.zodb) > 1:
+                raise Exception(
+                    "Can not forward new transactions to splitted DB.")
+            self._writeback = self.zodb[0][1]['storage']
+
+    def _connect(self):
+        conf = self._conf
+        db = self.db = buildDatabaseManager(conf['adapter'],
+            (conf['database'], conf.get('engine'), conf['wait']))
         for x in """getConfiguration _setConfiguration setNumPartitions
-                    query erase getPartitionTable changePartitionTable
+                    query erase getPartitionTable
                     getUnfinishedTIDDict dropUnfinishedData abortTransaction
-                    storeTransaction lockTransaction unlockTransaction
+                    storeTransaction lockTransaction
                     loadData storeData getOrphanList _pruneData deferCommit
                     dropPartitionsTemporary
                  """.split():
-            setattr(self, x, getattr(self.db, x))
+            setattr(self, x, getattr(db, x))
+        if self._writeback:
+            self._writeback = WriteBack(db, self._writeback)
+        db_commit = db.commit
+        def commit():
+            db_commit()
+            self._last_commit = time.time()
+            if self._writeback:
+                self._writeback.committed()
+        self.commit = db.commit = commit

-    def _connect(self):
-        pass
-
-    def commit(self):
-        self.db.commit()
-        # XXX: This misses commits done internally by self.db (lockTransaction).
-        self._last_commit = time.time()
+    def _updateReadable(self):
+        raise AssertionError
+
+    def changePartitionTable(self, *args, **kw):
+        self.db.changePartitionTable(*args, **kw)
+        if self._writeback:
+            self._writeback.changed()
+
+    def unlockTransaction(self, *args):
+        self.db.unlockTransaction(*args)
+        if self._writeback:
+            self._writeback.changed()

     def close(self):
+        if self._writeback:
+            self._writeback.close()
         self.db.close()
         if isinstance(self.zodb, list): # _setup called
             for zodb in self.zodb:
@@ -576,3 +604,120 @@ class ImporterDatabaseManager(DatabaseManager):

     def pack(self, *args, **kw):
         raise BackendNotImplemented(self.pack)
+
+
+class WriteBack(object):
+
+    _changed = False
+    _process = None
+    threading = False
+
+    def __init__(self, db, storage):
+        self._db = db
+        self._storage = storage
+
+    def close(self):
+        if self._process:
+            self._stop.set()
+            self._event.set()
+            self._process.join()
+
+    def changed(self):
+        self._changed = True
+
+    def committed(self):
+        if self._changed:
+            self._changed = False
+            if self._process:
+                self._event.set()
+            else:
+                if self.threading:
+                    from threading import Thread as Process, Event
+                else:
+                    from multiprocessing import Process, Event
+                self._event = Event()
+                self._idle = Event()
+                self._stop = Event()
+                self._np = self._db.getNumPartitions()
+                self._db = cPickle.dumps(self._db, 2)
+                self._process = Process(target=self._run)
+                self._process.daemon = True
+                self._process.start()
+
+    @property
+    def wait(self):
+        # For unit tests.
+        return self._idle.wait
+
+    def _run(self):
+        self._db = cPickle.loads(self._db)
+        try:
+            @self._db.autoReconnect
+            def _():
+                # Unfortunately, copyTransactionsFrom does not abort in case
+                # of failure, so we have to reopen.
+                zodb = storageFromString(self._storage)
+                try:
+                    self.min_tid = util.add64(zodb.lastTransaction(), 1)
+                    zodb.copyTransactionsFrom(self)
+                finally:
+                    zodb.close()
+        finally:
+            self._idle.set()
+            self._db.close()

+    def iterator(self):
+        db = self._db
+        np = self._np
+        chunk_size = max(2, 1000 // np)
+        offset_list = xrange(np)
+        while 1:
+            with db:
+                # Check the partition table at the beginning of every
+                # transaction. Once the import is finished and at least one
+                # cell is replicated, it is possible that some cells of this
+                # node get outdated. In this case, wait for the next PT change.
+                if np == len(db._readable_set):
+                    while 1:
+                        tid_list = []
+                        loop = False
+                        for offset in offset_list:
+                            x = db.getReplicationTIDList(
+                                self.min_tid, MAX_TID, chunk_size, offset)
+                            tid_list += x
+                            if len(x) == chunk_size:
+                                loop = True
+                        if tid_list:
+                            tid_list.sort()
+                            for tid in tid_list:
+                                if self._stop.is_set():
+                                    return
+                                yield TransactionRecord(db, tid)
+                                self.min_tid = util.add64(tid, 1)
+                            if loop:
+                                continue
+                        break
+            if not self._event.is_set():
+                self._idle.set()
+                self._event.wait()
+                self._idle.clear()
+            self._event.clear()
+            if self._stop.is_set():
+                break
+
+
+class TransactionRecord(BaseStorage.TransactionRecord):
+
+    def __init__(self, db, tid):
+        self._oid_list, user, desc, ext, _, _ = db.getTransaction(tid)
+        super(TransactionRecord, self).__init__(tid, ' ', user, desc,
+            cPickle.loads(ext) if ext else {})
+        self._db = db
+
+    def __iter__(self):
+        tid = self.tid
+        for oid in self._oid_list:
+            _, compression, _, data, data_tid = self._db.fetchObject(oid, tid)
+            if data is not None:
+                data = compress.decompress_list[compression](data)
+            yield BaseStorage.DataRecord(oid, tid, data, data_tid)
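`WriteBack` can be handed directly to `copyTransactionsFrom()` because ZODB only expects an `iterator()` yielding transaction records that are themselves iterable. Roughly what the consuming side does with it (simplified from ZODB's `BaseStorage`, not the exact implementation):

```python
def copy_transactions(source, dest):
    # source: a WriteBack instance; dest: the reopened source storage.
    for txn in source.iterator():     # yields TransactionRecord objects
        dest.tpc_begin(txn, txn.tid, txn.status)
        for r in txn:                 # TransactionRecord.__iter__
            dest.restore(r.oid, r.tid, r.data, '', r.data_tid, txn)
        dest.tpc_vote(txn)
        dest.tpc_finish(txn)
```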
@@ -17,6 +17,7 @@
 import os, errno, socket, struct, sys, threading
 from collections import defaultdict
 from contextlib import contextmanager
+from copy import copy
 from functools import wraps
 from neo.lib import logging, util
 from neo.lib.interfaces import abstract, requires
@@ -60,7 +61,7 @@ class DatabaseManager(object):
     LOCKED = "error: database is locked"

     _deferred = 0
-    _duplicating = _repairing = None
+    _repairing = None

     def __init__(self, database, engine=None, wait=None):
         """
@@ -75,30 +76,56 @@ class DatabaseManager(object):
         # But for unit tests, we really want to never retry.
         self._wait = wait or 0
         self._parse(database)
+        self._init_attrs = tuple(self.__dict__)
         self._connect()

-    def __getattr__(self, attr):
-        if self._duplicating is None:
-            return self.__getattribute__(attr)
-        value = getattr(self._duplicating, attr)
-        setattr(self, attr, value)
-        return value
+    def __getstate__(self):
+        state = {x: getattr(self, x) for x in self._init_attrs}
+        assert state # otherwise, __setstate__ is not called
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        # For the moment, no need to duplicate secondary connections.
+        #self._init_attrs = tuple(self.__dict__)
+        # Secondary connections don't lock.
+        self.LOCK = None
+        self._connect()

     @contextmanager
     def _duplicate(self):
-        cls = self.__class__
-        db = cls.__new__(cls)
-        db.LOCK = None
-        db._duplicating = self
-        try:
-            db._connect()
-        finally:
-            del db._duplicating
+        db = copy(self)
         try:
             yield db
         finally:
             db.close()

+    def __getattr__(self, attr):
+        if attr in ('_readable_set', '_getPartition', '_getReadablePartition'):
+            self._updateReadable()
+        return self.__getattribute__(attr)
+
+    def _partitionTableChanged(self):
+        try:
+            del (self._readable_set,
+                 self._getPartition,
+                 self._getReadablePartition)
+        except AttributeError:
+            pass
+
+    def __enter__(self):
+        assert not self.LOCK, "not a secondary connection"
+        # XXX: All config caching should be done in this class,
+        #      rather than in backend classes.
+        self._config.clear()
+        self._partitionTableChanged()
+
+    def __exit__(self, t, v, tb):
+        if v is None:
+            # Deferring commits makes no sense for secondary connections.
+            assert not self._deferred
+            self._commit()
+
     @abstract
     def _parse(self, database):
         """Called during instantiation, to process database parameter."""
@@ -107,6 +134,17 @@ class DatabaseManager(object):
     def _connect(self):
         """Connect to the database"""

+    def autoReconnect(self, f):
+        """
+        Placeholder for backends that may lose connection to the underlying
+        database: although a primary connection is reestablished transparently
+        when possible, secondary connections use transactions and they must
+        restart from the beginning.
+
+        For other backends, there's no expected transient failure so the
+        default implementation is to execute the given task exactly once.
+        """
+        f()
+
     def lock(self, db_path):
         if self.LOCK:
             assert self.__lock is None, self.__lock
@@ -147,7 +185,6 @@ class DatabaseManager(object):
         """
         if reset:
             self.erase()
-        self._readable_set = set()
         self._uncommitted_data = defaultdict(int)
         self._setup(dedup)
@@ -250,10 +287,7 @@ class DatabaseManager(object):
         Store the number of partitions into a database.
         """
         self.setConfiguration('partitions', num_partitions)
-        try:
-            del self._getPartition, self._getReadablePartition
-        except AttributeError:
-            pass
+        self._partitionTableChanged()

     def getNumReplicas(self):
         """
@@ -320,6 +354,15 @@ class DatabaseManager(object):
         is again a tuple of an offset (row ID), the NID of a storage
         node, and a cell state."""

+    def _getAssignedPartitionList(self, *states):
+        nid = self.getUUID()
+        if nid is None:
+            return ()
+        if states:
+            return [nid for nid, state in self.getPartitionTable(nid)
+                        if state in states]
+        return [x[0] for x in self.getPartitionTable(nid)]
+
     @abstract
     def getLastTID(self, max_tid):
         """Return greatest tid in trans table that is <= given 'max_tid'
@@ -492,11 +535,12 @@ class DatabaseManager(object):
         """
         """

-    @requires(_changePartitionTable, _getDataLastId)
-    def changePartitionTable(self, ptid, cell_list, reset=False):
-        readable_set = self._readable_set
-        if reset:
-            readable_set.clear()
+    @requires(_getDataLastId)
+    def _updateReadable(self):
+        try:
+            readable_set = self.__dict__['_readable_set']
+        except KeyError:
+            readable_set = self._readable_set = set()
             np = self.getNumPartitions()
             def _getPartition(x, np=np):
                 return x % np
@@ -511,14 +555,15 @@
             for p in xrange(np):
                 i = self._getDataLastId(p)
                 d.append(p << 48 if i is None else i + 1)
-        me = self.getUUID()
-        for offset, nid, state in cell_list:
-            if nid == me:
-                if CellStates.UP_TO_DATE != state != CellStates.FEEDING:
-                    readable_set.discard(offset)
-                else:
-                    readable_set.add(offset)
+        else:
+            readable_set.clear()
+        readable_set.update(self._getAssignedPartitionList(
+            CellStates.UP_TO_DATE, CellStates.FEEDING))
+
+    @requires(_changePartitionTable)
+    def changePartitionTable(self, ptid, cell_list, reset=False):
         self._changePartitionTable(cell_list, reset)
+        self._updateReadable()
         assert isinstance(ptid, (int, long)), ptid
         self._setConfiguration('ptid', str(ptid))
...
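These `__getstate__`/`__setstate__` hooks are what make both `copy()` (in `_duplicate`) and `cPickle.dumps` (crossing the process boundary in `WriteBack`) yield a working secondary connection: only constructor-derived attributes travel, and the receiving side reconnects by itself. The mechanism in isolation, with a hypothetical minimal backend (`open_connection` is a placeholder):

```python
import copy, pickle

def open_connection(path):
    return object()  # placeholder for a real database handle

class Manager(object):
    LOCK = True  # a primary connection holds the lock

    def __init__(self, path):
        self._path = path
        self._init_attrs = tuple(self.__dict__)  # what survives duplication
        self._connect()

    def _connect(self):
        self.conn = open_connection(self._path)

    def __getstate__(self):
        return {x: getattr(self, x) for x in self._init_attrs}

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.LOCK = None  # secondary connections don't lock
        self._connect()   # fresh connection in the copy or child process

primary = Manager('/tmp/db')
secondary = copy.copy(primary)   # same parameters, new connection
blob = pickle.dumps(primary, 2)  # or ship it to another process
```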
@@ -76,7 +76,9 @@ def auto_reconnect(wrapped):
             if (self._active
                 or SERVER_GONE_ERROR != m.args[0] != SERVER_LOST
                 or not retry):
-                raise MysqlError(m, *args)
+                if self.LOCK:
+                    raise MysqlError(m, *args)
+                raise # caught upper for secondary connections
             logging.info('the MySQL server is gone; reconnecting')
             assert not self._deferred
             self.close()
@@ -112,7 +114,7 @@ class MySQLDatabaseManager(DatabaseManager):
     def __getattr__(self, attr):
         if attr == 'conn':
             self._tryConnect()
-        return DatabaseManager.__getattr__(self, attr)
+        return super(MySQLDatabaseManager, self).__getattr__(attr)

     def _tryConnect(self):
         kwd = {'db' : self.db, 'user' : self.user}
@@ -171,9 +173,30 @@ class MySQLDatabaseManager(DatabaseManager):
                 if e.args[0] != NO_SUCH_TABLE:
                     raise
             self._dedup = None
+        if not self.LOCK:
+            # Prevent automatic reconnection for secondary connections.
+            self._active = 1
+            self._commit = self.conn.commit

     _connect = auto_reconnect(_tryConnect)

+    def autoReconnect(self, f):
+        assert self._active and not self.LOCK
+        @auto_reconnect
+        def try_once(self):
+            if self._active:
+                try:
+                    f()
+                finally:
+                    self._active = 0
+            return True
+        while not try_once(self):
+            # Avoid reconnecting too often.
+            # Since this is used to wrap an arbitrarily long process and
+            # not just a single query, we can't limit the number of retries.
+            time.sleep(5)
+            self._connect()
+
     def _commit(self):
         self.conn.commit()
         self._active = 0
@@ -371,12 +394,6 @@ class MySQLDatabaseManager(DatabaseManager):
             return self.query("SELECT rid, state FROM pt WHERE nid=%u" % nid)
         return self.query("SELECT * FROM pt")

-    def _getAssignedPartitionList(self):
-        nid = self.getUUID()
-        if nid is None:
-            return ()
-        return [p for p, in self.query("SELECT rid FROM pt WHERE nid=%s" % nid)]
-
     def _sqlmax(self, sql, arg_list):
         q = self.query
         x = [x for x in arg_list for x, in q(sql % x) if x is not None]
...
@@ -28,6 +28,7 @@ import weakref
 import MySQLdb
 import transaction
+from ConfigParser import SafeConfigParser
 from cStringIO import StringIO
 try:
     from ZODB._compat import Unpickler
@@ -155,6 +156,18 @@ def setupMySQLdb(db_list, user=DB_USER, password='', clear_databases=True):
         conn.commit()
     conn.close()

+def ImporterConfigParser(adapter, zodb, **kw):
+    cfg = SafeConfigParser()
+    cfg.add_section("neo")
+    cfg.set("neo", "adapter", adapter)
+    for x in kw.iteritems():
+        cfg.set("neo", *x)
+    for name, zodb in zodb:
+        cfg.add_section(name)
+        for x in zodb.iteritems():
+            cfg.set(name, *x)
+    return cfg
+
 class NeoTestBase(unittest.TestCase):

     def setUp(self):
...
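Usage of this helper in the test harnesses below: keyword arguments become extra options of the `[neo]` section, so enabling the new behavior is just a matter of passing `writeback='true'`. For instance (paths are placeholders):

```python
cfg = ImporterConfigParser('MySQL',
    [('root', {'storage': '<filestorage>\npath /tmp/root.fs\n</filestorage>'})],
    writeback='true')
cfg.set("neo", "database", "neo")  # filled in by the caller, as below
with open('importer.cfg', 'w') as f:
    cfg.write(f)
```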
@@ -29,7 +29,6 @@
 import traceback
 import threading
 import psutil
-from ConfigParser import SafeConfigParser
 import neo.scripts
 from neo.neoctl.neoctl import NeoCTL, NotReadyException
@@ -38,8 +37,8 @@ from neo.lib.protocol import ClusterStates, NodeTypes, CellStates, NodeStates, \
     UUID_NAMESPACES
 from neo.lib.util import dump
 from .. import (ADDRESS_TYPE, DB_SOCKET, DB_USER, IP_VERSION_FORMAT_DICT, SSL,
-    buildUrlFromString, cluster, getTempDirectory, NeoTestBase, Patch,
-    setupMySQLdb)
+    buildUrlFromString, cluster, getTempDirectory, setupMySQLdb,
+    ImporterConfigParser, NeoTestBase, Patch)
 from neo.client.Storage import Storage
 from neo.storage.database import manager, buildDatabaseManager
@@ -154,7 +153,7 @@ class Process(object):
         if args:
             os.close(w)
             os.kill(os.getpid(), signal.SIGSTOP)
-        self.pid = logging.fork()
+        self.pid = os.fork()
         if self.pid:
             # Wait that the signal to kill the child is set up.
             os.close(w)
@@ -317,14 +316,8 @@ class NEOCluster(object):
             IP_VERSION_FORMAT_DICT[self.address_type]
         self.setupDB(clear_databases)
         if importer:
-            cfg = SafeConfigParser()
-            cfg.add_section("neo")
-            cfg.set("neo", "adapter", adapter)
+            cfg = ImporterConfigParser(adapter, **importer)
             cfg.set("neo", "database", self.db_template(*db_list))
-            for name, zodb in importer:
-                cfg.add_section(name)
-                for x in zodb.iteritems():
-                    cfg.set(name, *x)
             importer_conf = os.path.join(temp_dir, 'importer.cfg')
             with open(importer_conf, 'w') as f:
                 cfg.write(f)
...
@@ -202,9 +202,9 @@ class ClientTests(NEOFunctionalTest):
         self.neo.stop()
         self.neo = NEOCluster(db_list=['test_neo1'], partitions=3,
-            importer=[("root", {
+            importer={"zodb": [("root", {
                 "storage": "<filestorage>\npath %s\n</filestorage>"
-                % dfs_storage.getName()})],
+                % dfs_storage.getName()})]},
             temp_dir=self.getTempDirectory())
         self.neo.start()
         neo_db, neo_conn = self.neo.getZODBConnection()
...
@@ -227,6 +227,7 @@ class StorageDBTests(NeoUnitTestBase):
     def test_changePartitionTable(self):
         db = self.getDB()
+        db.setNumPartitions(3)
         ptid = 1
         uuid = self.getStorageUUID()
         cell1 = 0, uuid, CellStates.OUT_OF_DATE
...
@@ -15,7 +15,9 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.

 import unittest
+from contextlib import contextmanager
 from MySQLdb import NotSupportedError, OperationalError, ProgrammingError
+from MySQLdb.constants.CR import SERVER_GONE_ERROR
 from MySQLdb.constants.ER import UNKNOWN_STORAGE_ENGINE
 from ..mock import Mock
 from neo.lib.protocol import ZERO_OID
@@ -26,6 +28,19 @@ from neo.storage.database import DatabaseFailure
 from neo.storage.database.mysqldb import MySQLDatabaseManager

+class ServerGone(object):
+
+    @contextmanager
+    def __new__(cls, db):
+        self = object.__new__(cls)
+        with Patch(db, conn=self) as self._p:
+            yield self._p
+
+    def query(self, *args):
+        self._p.revert()
+        raise OperationalError(SERVER_GONE_ERROR, 'this is a test')
+
 class StorageMySQLdbTests(StorageDBTests):

     engine = None
@@ -67,14 +82,7 @@ class StorageMySQLdbTests(StorageDBTests):
         calls[0].checkArgs('SELECT ')

     def test_query2(self):
-        # test the OperationalError exception
-        # fake object, raise exception during the first call
-        from MySQLdb.constants.CR import SERVER_GONE_ERROR
-        class FakeConn(object):
-            def query(*args):
-                p.revert()
-                raise OperationalError(SERVER_GONE_ERROR, 'this is a test')
-        with Patch(self.db, conn=FakeConn()) as p:
+        with ServerGone(self.db) as p:
             self.assertRaises(ProgrammingError, self.db.query, 'QUERY')
         self.assertFalse(p.applied)
...
@@ -19,7 +19,6 @@
 import os, random, select, socket, sys, tempfile
 import thread, threading, time, traceback, weakref
 from collections import deque
-from ConfigParser import SafeConfigParser
 from contextlib import contextmanager
 from itertools import count
 from functools import partial, wraps
@@ -37,8 +36,9 @@ from neo.lib.handler import EventHandler
 from neo.lib.locking import SimpleQueue
 from neo.lib.protocol import ClusterStates, Enum, NodeStates, NodeTypes, Packets
 from neo.lib.util import cached_property, parseMasterList, p64
-from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \
-    ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_SOCKET, DB_USER
+from .. import (getTempDirectory, setupMySQLdb,
+    ImporterConfigParser, NeoTestBase, Patch,
+    ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_SOCKET, DB_USER)

 BIND = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE], 0
 LOCAL_IP = socket.inet_pton(ADDRESS_TYPE, IP_VERSION_FORMAT_DICT[ADDRESS_TYPE])
@@ -685,14 +685,8 @@ class NEOCluster(object):
             else:
                 assert False, adapter
             if importer:
-                cfg = SafeConfigParser()
-                cfg.add_section("neo")
-                cfg.set("neo", "adapter", adapter)
+                cfg = ImporterConfigParser(adapter, **importer)
                 cfg.set("neo", "database", db % tuple(db_list))
-                for name, zodb in importer:
-                    cfg.add_section(name)
-                    for x in zodb.iteritems():
-                        cfg.set(name, *x)
                 db = os.path.join(getTempDirectory(), '%s.conf')
                 with open(db % tuple(db_list), "w") as f:
                     cfg.write(f)
...
@@ -17,13 +17,15 @@
 from cPickle import Pickler, Unpickler
 from cStringIO import StringIO
 from itertools import izip_longest
-import os, random, shutil, unittest
+import os, random, shutil, time, unittest
 import transaction, ZODB
 from neo.client.exception import NEOPrimaryMasterLost
 from neo.lib import logging
 from neo.lib.util import u64
-from neo.storage.database.importer import Repickler
-from .. import expectedFailure, getTempDirectory, random_tree
+from neo.storage.database import getAdapterKlass, manager
+from neo.storage.database.importer import \
+    Repickler, TransactionRecord, WriteBack
+from .. import expectedFailure, getTempDirectory, random_tree, Patch
 from . import NEOCluster, NEOThreadedTest
 from ZODB import serialize
 from ZODB.FileStorage import FileStorage
@@ -159,7 +161,8 @@ class ImporterTests(NEOThreadedTest):
             if r:
                 transaction.commit()
         # Get oids of mount points and close.
-        importer = []
+        zodb = []
+        importer = {'zodb': zodb}
         for db, r, cfg in db_list:
             if db == 'root':
                 if multi:
@@ -169,13 +172,14 @@
                     h = random_tree.hashTree(r)
                     h()
                     self.assertEqual(import_hash, h.hexdigest())
+                    importer['writeback'] = 'true'
             else:
                 cfg["oid"] = str(u64(r[db]._p_oid))
                 db = '_%s' % db
             r._p_jar.db().close()
-            importer.append((db, cfg))
+            zodb.append((db, cfg))
         del db_list, iter_list
-        #del importer[0][1][importer.pop()[0]]
+        #del zodb[0][1][zodb.pop()[0]]
         # Start NEO cluster with transparent import.
         with NEOCluster(importer=importer) as cluster:
             # Suspend import for a while, so that import
@@ -226,13 +230,51 @@
                 assert i < last_import * 3 < 2 * i, (last_import, i)
                 self.assertFalse(cluster.storage.dm._import)
             storage._cache.clear()
-            h = random_tree.hashTree(r)
-            self.assertEqual(93, h())
-            self.assertEqual('6bf0f0cb2d6c1aae9e52c412ef0e25b6', h.hexdigest())
+            def finalCheck(r):
+                h = random_tree.hashTree(r)
+                self.assertEqual(93, h())
+                self.assertEqual('6bf0f0cb2d6c1aae9e52c412ef0e25b6',
+                                 h.hexdigest())
+            finalCheck(r)
+            if dm._writeback:
+                dm.commit()
+                dm._writeback.wait()
+        if dm._writeback:
+            db = ZODB.DB(FileStorage(fs_path, read_only=True))
+            finalCheck(db.open().root()['tree'])
+            db.close()

     def test1(self):
         self._importFromFileStorage()

+    def testThreadedWriteback(self):
+        # Also check reconnection to the underlying DB for relevant backends.
+        tid_list = []
+        def __init__(orig, tr, db, tid):
+            orig(tr, db, tid)
+            tid_list.append(tid)
+        def fetchObject(orig, db, *args):
+            if len(tid_list) == 5:
+                if isinstance(db, getAdapterKlass('MySQL')):
+                    from neo.tests.storage.testStorageMySQL import ServerGone
+                    with ServerGone(db):
+                        orig(db, *args)
+                    self.fail()
+                else:
+                    tid_list.append(None)
+                    p.revert()
+            return orig(db, *args)
+        def sleep(orig, seconds):
+            self.assertEqual(len(tid_list), 5)
+            p.revert()
+        with Patch(WriteBack, threading=True), \
+             Patch(TransactionRecord, __init__=__init__), \
+             Patch(manager.DatabaseManager, fetchObject=fetchObject), \
+             Patch(time, sleep=sleep) as p:
+            self._importFromFileStorage()
+            self.assertFalse(p.applied)
+        self.assertEqual(len(tid_list), 11)
+
     def testMerge(self):
         multi = 1, 2, 3
         self._importFromFileStorage(multi,
...