Commit 6153a752 authored by Julien Muchembled's avatar Julien Muchembled

Add support for PyPy & PyMySQL

parent c29b8b2d
......@@ -48,9 +48,10 @@ Requirements
- Python 2.7.x (2.7.9 or later for SSL support)
- For storage nodes using MySQL backend:
- For storage nodes using MySQL, one of the following backends:
- MySQLdb: https://github.com/PyMySQL/mysqlclient-python
- MySQLdb: https://github.com/PyMySQL/mysqlclient
- PyMySQL: https://github.com/PyMySQL/PyMySQL
- For client nodes: ZODB 4.4.5 or later
......
......@@ -33,6 +33,7 @@ def patch():
assert H(Connection.afterCompletion) in (
'cd3a080b80fd957190ff3bb867149448', # Python 2.7
'b1d9685c13967d4b6d74c7ef86f68f17', # PyPy 2.7
)
def afterCompletion(self, *ignored):
......
......@@ -35,7 +35,7 @@ if filter(re.compile(r'--coverage$|-\w*c').match, sys.argv[1:]):
coverage.start()
from neo.lib import logging
from neo.tests import getTempDirectory, NeoTestBase, Patch, \
from neo.tests import adapter, getTempDirectory, NeoTestBase, Patch, \
__dict__ as neo_tests__dict__
from neo.tests.benchmark import BenchmarkRunner
......@@ -216,9 +216,11 @@ class NeoTestRunner(unittest.TextTestResult):
add_status('Directory', self.temp_directory)
if self.testsRun:
add_status('Status', '%.3f%%' % (success * 100.0 / self.testsRun))
for var in os.environ:
if var.startswith('NEO_TEST'):
add_status(var, os.environ[var])
for k, v in os.environ.iteritems():
if k.startswith('NEO_TEST'):
if k == 'NEO_TESTS_ADAPTER' and v == 'MySQL':
from neo.storage.database.mysql import binding_name as v
add_status(k, v)
# visual
header = "%25s | run | unexpected | expected | skipped | time \n" % 'Test Module'
separator = "%25s-+-------+------------+----------+---------+----------\n" % ('-' * 25)
......@@ -318,7 +320,7 @@ class TestRunner(BenchmarkRunner):
" passed.")
parser.epilog = """
Environment Variables:
NEO_PYPY PyPy executable to run master nodes in functional
NEOMASTER_PYPY PyPy executable to run master nodes in functional
tests (and also in zodb tests depending on
NEO_TEST_ZODB_FUNCTIONAL).
NEO_TESTS_ADAPTER Default is SQLite for threaded clusters,
......
......@@ -26,7 +26,7 @@ from neo.lib.pt import PartitionTable
from neo.lib.util import dump
from neo.lib.bootstrap import BootstrapManager
from .checker import Checker
from .database import buildDatabaseManager, DATABASE_MANAGER_DICT
from .database import buildDatabaseManager, DATABASE_MANAGERS
from .handlers import identification, initialization, master
from .replicator import Replicator
from .transactions import TransactionManager
......@@ -37,7 +37,7 @@ option_defaults = {
'adapter': 'MySQL',
'wait': 0,
}
assert option_defaults['adapter'] in DATABASE_MANAGER_DICT
assert option_defaults['adapter'] in DATABASE_MANAGERS
@buildOptionParser
class Application(BaseApplication):
......@@ -52,7 +52,7 @@ class Application(BaseApplication):
cls.addCommonServerOptions('storage', '127.0.0.1')
_ = parser.group('storage')
_('a', 'adapter', choices=sorted(DATABASE_MANAGER_DICT),
_('a', 'adapter', choices=DATABASE_MANAGERS,
help="database adapter to use")
_('d', 'database', required=True,
help="database connections string")
......
......@@ -16,18 +16,51 @@
LOG_QUERIES = False
DATABASE_MANAGER_DICT = {
'Importer': 'importer.ImporterDatabaseManager',
'MySQL': 'mysql.MySQLDatabaseManager',
'SQLite': 'sqlite.SQLiteDatabaseManager',
}
def getAdapterKlass(name):
def useMySQLdb():
import platform
py = platform.python_implementation() == 'PyPy'
try:
module, name = DATABASE_MANAGER_DICT[name or 'MySQL'].split('.')
except KeyError:
raise DatabaseFailure('Cannot find a database adapter <%s>' % name)
return getattr(__import__(module, globals(), level=1), name)
if py:
import pymysql
else:
import MySQLdb
except ImportError:
return py
return not py
class getAdapterKlass(object):
def __new__(cls, name):
try:
m = getattr(cls, name or 'MySQL')
except AttributeError:
raise DatabaseFailure('Cannot find a database adapter <%s>' % name)
return m()
@staticmethod
def Importer():
from .importer import ImporterDatabaseManager as DM
return DM
@classmethod
def MySQL(cls, MySQLdb=None):
if MySQLdb is not None:
global useMySQLdb
useMySQLdb = lambda: MySQLdb
from .mysql import binding_name, MySQLDatabaseManager as DM
assert hasattr(cls, binding_name)
return DM
MySQLdb = classmethod(lambda cls: cls.MySQL(True))
PyMySQL = classmethod(lambda cls: cls.MySQL(False))
@staticmethod
def SQLite():
from .sqlite import SQLiteDatabaseManager as DM
return DM
DATABASE_MANAGERS = tuple(sorted(
x for x in dir(getAdapterKlass) if not x.startswith('_')))
def buildDatabaseManager(name, args=(), kw={}):
return getAdapterKlass(name)(*args, **kw)
......
......@@ -22,7 +22,7 @@ from cStringIO import StringIO
from ConfigParser import SafeConfigParser
from ZConfig import loadConfigFile
from ZODB import BaseStorage
from ZODB._compat import dumps, loads, _protocol
from ZODB._compat import dumps, loads, _protocol, PersistentPickler
from ZODB.config import getStorageSchema, storageFromString
from ZODB.POSException import POSKeyError
from ZODB.FileStorage import FileStorage
......@@ -44,6 +44,35 @@ def transactionAsTuple(txn):
dumps(ext, _protocol) if ext else '',
txn.status == 'p', txn.tid)
@apply
def patch_save_reduce(): # for _noload.__reduce__
Pickler = PersistentPickler(None, StringIO()).__class__
try:
orig_save_reduce = Pickler.save_reduce.__func__
except AttributeError: # both cPickle and C zodbpickle accept
return # that first reduce argument is None
BUILD = pickle.BUILD
REDUCE = pickle.REDUCE
def save_reduce(self, func, args, state=None,
listitems=None, dictitems=None, obj=None):
if func is not None:
return orig_save_reduce(self,
func, args, state, listitems, dictitems, obj)
assert args is ()
save = self.save
write = self.write
save(func)
save(args)
self.write(REDUCE)
if obj is not None:
self.memoize(obj)
self._batch_appends(listitems)
self._batch_setitems(dictitems)
if state is not None:
save(state)
write(BUILD)
Pickler.save_reduce = save_reduce
class Reference(object):
......@@ -59,17 +88,15 @@ class Repickler(pickle.Unpickler):
# Use python implementation for unpickling because loading can not
# be customized enough with cPickle.
pickle.Unpickler.__init__(self, self._f)
# For pickling, it is possible to use the fastest implementation,
# which also generates fewer useless PUT opcodes.
self._p = cPickle.Pickler(self._f, 1)
self.memo = self._p.memo # just a tiny optimization
def persistent_id(obj):
if isinstance(obj, Reference):
r = obj.value
del obj.value # minimize refcnt like for deque+popleft
return r
self._p.inst_persistent_id = persistent_id
# For pickling, it is possible to use the fastest implementation,
# which also generates fewer useless PUT opcodes.
self._p = PersistentPickler(persistent_id, self._f, 1)
self.memo = self._p.memo # just a tiny optimization
def persistent_load(obj):
new_obj = persistent_map(obj)
......@@ -96,8 +123,10 @@ class Repickler(pickle.Unpickler):
self.memo.clear()
if self._changed:
f.truncate(0)
dump = self._p.dump
try:
self._p.dump(classmeta).dump(state)
dump(classmeta)
dump(state)
finally:
self.memo.clear()
return f.getvalue()
......
......@@ -800,10 +800,9 @@ class DatabaseManager(object):
if found_undone_tid is None:
return
if transaction_object:
try:
current_tid = current_data_tid = u64(transaction_object[2])
except struct.error:
current_tid = current_data_tid = tid
transaction_tid = transaction_object[2]
current_tid = current_data_tid = \
tid if transaction_tid is None else u64(transaction_tid)
else:
current_tid, current_data_tid = getDataTID(before_tid=ltid)
if current_tid is None:
......
......@@ -14,25 +14,45 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import os, re, string, struct, sys, time
from binascii import a2b_hex
from collections import OrderedDict
from functools import wraps
import MySQLdb
from MySQLdb import DataError, IntegrityError, \
OperationalError, ProgrammingError
from MySQLdb.constants.CR import SERVER_GONE_ERROR, SERVER_LOST
from MySQLdb.constants.ER import DATA_TOO_LONG, DUP_ENTRY, NO_SUCH_TABLE
from . import useMySQLdb
if useMySQLdb():
binding_name = 'MySQLdb'
from MySQLdb.connections import Connection
from MySQLdb import __version__ as binding_version, DataError, \
IntegrityError, OperationalError, ProgrammingError
InternalOrOperationalError = OperationalError
from MySQLdb.constants.CR import SERVER_GONE_ERROR, SERVER_LOST
from MySQLdb.constants.ER import DATA_TOO_LONG, DUP_ENTRY, NO_SUCH_TABLE
def fetch_all(conn):
r = conn.store_result()
return r.fetch_row(r.num_rows())
# for tests
from MySQLdb import NotSupportedError
from MySQLdb.constants.ER import BAD_DB_ERROR, UNKNOWN_STORAGE_ENGINE
else:
binding_name = 'PyMySQL'
from pymysql.connections import Connection
from pymysql import __version__ as binding_version, DataError, \
IntegrityError, InternalError, OperationalError, ProgrammingError
InternalOrOperationalError = InternalError, OperationalError
from pymysql.constants.CR import (
CR_SERVER_GONE_ERROR as SERVER_GONE_ERROR,
CR_SERVER_LOST as SERVER_LOST)
from pymysql.constants.ER import DATA_TOO_LONG, DUP_ENTRY, NO_SUCH_TABLE
def fetch_all(conn):
return conn._result.rows
# for tests
from pymysql import NotSupportedError
from pymysql.constants.ER import BAD_DB_ERROR, UNKNOWN_STORAGE_ENGINE
# BBB: the following 2 constants were added to mysqlclient 1.3.8
DROP_LAST_PARTITION = 1508
SAME_NAME_PARTITION = 1517
from array import array
from hashlib import sha1
import os
import re
import string
import struct
import sys
import time
from . import LOG_QUERIES, DatabaseFailure
from .manager import DatabaseManager, splitOIDField
......@@ -68,7 +88,7 @@ def auto_reconnect(wrapped):
while 1:
try:
return wrapped(self, *args)
except OperationalError as m:
except InternalOrOperationalError as m:
# IDEA: Is it safe to retry in case of DISK_FULL ?
# XXX: However, this would another case of failure that would
# be unnoticed by other nodes (ADMIN & MASTER). When
......@@ -121,6 +141,7 @@ class MySQLDatabaseManager(DatabaseManager):
return super(MySQLDatabaseManager, self).__getattr__(attr)
def _tryConnect(self):
# BBB: db/passwd are deprecated favour of database/password since 1.3.8
kwd = {'db' : self.db}
if self.user:
kwd['user'] = self.user
......@@ -128,8 +149,8 @@ class MySQLDatabaseManager(DatabaseManager):
kwd['passwd'] = self.passwd
if self.socket:
kwd['unix_socket'] = os.path.expanduser(self.socket)
logging.info('connecting to MySQL on the database %s with user %s',
self.db, self.user)
logging.info('Using %s %s to connect to the database %s with user %s',
binding_name, binding_version, self.db, self.user)
self._active = 0
if self._wait < 0:
timeout_at = None
......@@ -138,7 +159,7 @@ class MySQLDatabaseManager(DatabaseManager):
last = None
while True:
try:
self.conn = MySQLdb.connect(**kwd)
self.conn = Connection(**kwd)
break
except Exception as e:
if None is not timeout_at <= time.time():
......@@ -154,15 +175,15 @@ class MySQLDatabaseManager(DatabaseManager):
self._config = {}
conn = self.conn
conn.autocommit(False)
conn.set_sql_mode("TRADITIONAL,NO_ENGINE_SUBSTITUTION")
conn.query("SET SESSION group_concat_max_len = %u" % (2**32-1))
conn.query("SET"
" SESSION sql_mode = 'TRADITIONAL,NO_ENGINE_SUBSTITUTION',"
" SESSION group_concat_max_len = %u" % (2**32-1))
if self._engine == 'RocksDB':
# Maximum value for _deleteRange.
conn.query("SET SESSION rocksdb_max_row_locks = %u" % 2**30)
def query(sql):
conn.query(sql)
r = conn.store_result()
return r.fetch_row(r.num_rows())
return fetch_all(conn)
if self.LOCK:
(locked,), = query("SELECT GET_LOCK('%s.%s', 0)"
% (self.db, self.LOCK))
......@@ -220,8 +241,7 @@ class MySQLDatabaseManager(DatabaseManager):
conn = self.conn
conn.query(query)
if query.startswith("SELECT "):
r = conn.store_result()
return r.fetch_row(r.num_rows())
return fetch_all(conn)
r = query.split(None, 1)[0]
if r in ("INSERT", "REPLACE", "DELETE", "UPDATE"):
self._active = 1
......
......@@ -83,8 +83,8 @@ class SQLiteDatabaseManager(DatabaseManager):
self.lock(self.db)
if self.UNSAFE:
q = self.query
q("PRAGMA synchronous = OFF")
q("PRAGMA journal_mode = MEMORY")
q("PRAGMA synchronous = OFF").fetchall()
q("PRAGMA journal_mode = MEMORY").fetchall()
self._config = {}
def _getDevPath(self):
......
......@@ -28,7 +28,7 @@ import unittest
import weakref
import transaction
from contextlib import contextmanager
from contextlib import closing, contextmanager
from ConfigParser import SafeConfigParser
from cStringIO import StringIO
try:
......@@ -76,6 +76,12 @@ DB_INSTALL = os.getenv('NEO_DB_INSTALL', 'mysql_install_db')
DB_MYSQLD = os.getenv('NEO_DB_MYSQLD', '/usr/sbin/mysqld')
DB_MYCNF = os.getenv('NEO_DB_MYCNF')
adapter = os.getenv('NEO_TESTS_ADAPTER')
if adapter:
from neo.storage.database import getAdapterKlass
if getAdapterKlass(adapter).__name__ == 'MySQLDatabaseManager':
os.environ['NEO_TESTS_ADAPTER'] = 'MySQL'
IP_VERSION_FORMAT_DICT = {
socket.AF_INET: '127.0.0.1',
socket.AF_INET6: '::1',
......@@ -137,31 +143,28 @@ def getTempDirectory():
print 'Using temp directory %r.' % temp_dir
return temp_dir
def setupMySQLdb(db_list, clear_databases=True):
def setupMySQL(db_list, clear_databases=True):
if mysql_pool:
return mysql_pool.setup(db_list, clear_databases)
import MySQLdb
from MySQLdb.constants.ER import BAD_DB_ERROR
from neo.storage.database.mysql import \
Connection, OperationalError, BAD_DB_ERROR
user = DB_USER
password = ''
kw = {'unix_socket': os.path.expanduser(DB_SOCKET)} if DB_SOCKET else {}
conn = MySQLdb.connect(user=DB_ADMIN, passwd=DB_PASSWD, **kw)
cursor = conn.cursor()
for database in db_list:
try:
conn.select_db(database)
if not clear_databases:
continue
cursor.execute('DROP DATABASE `%s`' % database)
except MySQLdb.OperationalError, (code, _):
if code != BAD_DB_ERROR:
raise
cursor.execute('GRANT ALL ON `%s`.* TO "%s"@"localhost" IDENTIFIED'
# BBB: passwd is deprecated favour of password since 1.3.8
with closing(Connection(user=DB_ADMIN, passwd=DB_PASSWD, **kw)) as conn:
for database in db_list:
try:
conn.select_db(database)
if not clear_databases:
continue
conn.query('DROP DATABASE `%s`' % database)
except OperationalError, (code, _):
if code != BAD_DB_ERROR:
raise
conn.query('GRANT ALL ON `%s`.* TO "%s"@"localhost" IDENTIFIED'
' BY "%s"' % (database, user, password))
cursor.execute('CREATE DATABASE `%s`' % database)
cursor.close()
conn.commit()
conn.close()
conn.query('CREATE DATABASE `%s`' % database)
return '{}:{}@%s{}'.format(user, password, DB_SOCKET).__mod__
class MySQLPool(object):
......@@ -178,7 +181,7 @@ class MySQLPool(object):
self.kill(*self._mysqld_dict)
def setup(self, db_list, clear_databases):
import MySQLdb
from neo.storage.database.mysql import Connection
start_list = set(db_list).difference(self._mysqld_dict)
if start_list:
start_list = sorted(start_list)
......@@ -221,12 +224,11 @@ class MySQLPool(object):
if x is not None:
raise subprocess.CalledProcessError(x, DB_MYSQLD)
for db in db_list:
db = MySQLdb.connect(unix_socket=self._sock_template % db,
user='root')
if clear_databases:
db.query('DROP DATABASE IF EXISTS neo')
db.query('CREATE DATABASE IF NOT EXISTS neo')
db.close()
with closing(Connection(unix_socket=self._sock_template % db,
user='root')) as db:
if clear_databases:
db.query('DROP DATABASE IF EXISTS neo')
db.query('CREATE DATABASE IF NOT EXISTS neo')
return ('root@neo' + self._sock_template).__mod__
def start(self, *db, **kw):
......@@ -274,6 +276,8 @@ class NeoTestBase(unittest.TestCase):
assert self.tearDown.im_func is NeoTestBase.tearDown.im_func
self._tearDown(sys._getframe(1).f_locals['success'])
assert not gc.garbage, gc.garbage
# XXX: I tried the following line to avoid random freezes on PyPy...
gc.collect()
def _tearDown(self, success):
# Kill all unfinished transactions for next test.
......@@ -335,7 +339,7 @@ class NeoUnitTestBase(NeoTestBase):
""" create empty databases """
adapter = os.getenv('NEO_TESTS_ADAPTER', 'MySQL')
if adapter == 'MySQL':
db_template = setupMySQLdb(
db_template = setupMySQL(
[prefix + str(i) for i in xrange(number)])
self.db_template = lambda i: db_template(prefix + str(i))
elif adapter == 'SQLite':
......
......@@ -51,13 +51,20 @@ class BenchmarkRunner(object):
def build_report(self, content):
fmt = "%-25s : %s"
py_impl = platform.python_implementation()
if py_impl == 'PyPy':
info = sys.pypy_version_info
py_impl += ' %s.%s.%s' % info[:3]
kind = info.releaselevel
if kind != 'final':
py_impl += kind[0] + str(info.serial)
status = "\n".join([fmt % item for item in [
('Title', self._config.title),
('Date', datetime.date.today().isoformat()),
('Node', platform.node()),
('Machine', platform.machine()),
('System', platform.system()),
('Python', platform.python_version()),
('Python', '%s [%s]' % (platform.python_version(), py_impl)),
]])
status += '\n\n'
status += "\n".join([fmt % item for item in self._status])
......
......@@ -36,7 +36,7 @@ from neo.lib.protocol import ClusterStates, NodeTypes, CellStates, NodeStates, \
UUID_NAMESPACES
from neo.lib.util import dump, setproctitle
from .. import (ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, SSL,
buildUrlFromString, cluster, getTempDirectory, setupMySQLdb,
buildUrlFromString, cluster, getTempDirectory, setupMySQL,
ImporterConfigParser, NeoTestBase, Patch)
from neo.client.Storage import Storage
from neo.storage.database import manager, buildDatabaseManager
......@@ -55,8 +55,8 @@ command_dict = {
DELAY_SAFETY_MARGIN = 10
MAX_START_TIME = 30
PYPY_EXECUTABLE = os.getenv('NEO_PYPY')
if PYPY_EXECUTABLE:
NEOMASTER_PYPY = os.getenv('NEOMASTER_PYPY')
if NEOMASTER_PYPY:
import neo, msgpack
PYPY_TEMPLATE = """\
import os, signal, sys
......@@ -194,8 +194,8 @@ class Process(object):
from coverage import Coverage
coverage = Coverage(coverage_data_path)
coverage.start()
elif PYPY_EXECUTABLE and command == 'neomaster':
os.execlp(PYPY_EXECUTABLE, PYPY_EXECUTABLE, '-c',
elif NEOMASTER_PYPY and command == 'neomaster':
os.execlp(NEOMASTER_PYPY, NEOMASTER_PYPY, '-c',
PYPY_TEMPLATE % (
w, self._coverage_fd, w,
logging._max_size, logging._max_packet,
......@@ -348,7 +348,7 @@ class NEOCluster(object):
temp_dir = tempfile.mkdtemp(prefix='neo_')
print 'Using temp directory ' + temp_dir
if adapter == 'MySQL':
self.db_template = setupMySQLdb(db_list, clear_databases)
self.db_template = setupMySQL(db_list, clear_databases)
elif adapter == 'SQLite':
self.db_template = (lambda t: lambda db:
':memory:' if db is None else db if os.sep in db else t % db
......
......@@ -47,10 +47,15 @@ class StorageClientHandlerTests(NeoUnitTestBase):
def _getConnection(self, uuid=None):
return self.getFakeConnection(uuid=uuid, address=('127.0.0.1', 1000))
def fakeDM(self, **kw):
self.app.dm.close()
self.app.dm = dm = Mock(kw)
return dm
def test_18_askTransactionInformation1(self):
# transaction does not exists
conn = self._getConnection()
self.app.dm = Mock({'getNumPartitions': 1})
self.fakeDM(getNumPartitions=1)
self.operation.askTransactionInformation(conn, INVALID_TID)
self.checkErrorPacket(conn)
......@@ -58,7 +63,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
# invalid offsets => error
app = self.app
app.pt = Mock()
app.dm = Mock()
self.fakeDM()
conn = self._getConnection()
self.checkProtocolErrorRaised(self.operation.askTIDs, conn, 1, 1, None)
self.assertEqual(len(app.pt.mockGetNamedCalls('getCellList')), 0)
......@@ -68,7 +73,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
# well case => answer
conn = self._getConnection()
self.app.pt = Mock({'getPartitions': 1})
self.app.dm = Mock({'getTIDList': (INVALID_TID, )})
self.fakeDM(getTIDList=(INVALID_TID,))
self.operation.askTIDs(conn, 1, 2, 1)
calls = self.app.dm.mockGetNamedCalls('getTIDList')
self.assertEqual(len(calls), 1)
......@@ -77,12 +82,11 @@ class StorageClientHandlerTests(NeoUnitTestBase):
def test_26_askObjectHistory1(self):
# invalid offsets => error
app = self.app
app.dm = Mock()
dm = self.fakeDM()
conn = self._getConnection()
self.checkProtocolErrorRaised(self.operation.askObjectHistory, conn,
1, 1, None)
self.assertEqual(len(app.dm.mockGetNamedCalls('getObjectHistory')), 0)
self.assertEqual(len(dm.mockGetNamedCalls('getObjectHistory')), 0)
def test_askObjectUndoSerial(self):
conn = self._getConnection(uuid=self.getClientUUID())
......@@ -94,9 +98,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
self.app.tm = Mock({
'getObjectFromTransaction': None,
})
self.app.dm = Mock({
'findUndoTID': ReturnValues((None, None, False), )
})
self.fakeDM(findUndoTID=ReturnValues((None, None, False),))
self.operation.askObjectUndoSerial(conn, tid, ltid, undone_tid, oid_list)
self.checkErrorPacket(conn)
......
......@@ -82,8 +82,9 @@ class StorageMasterHandlerTests(NeoUnitTestBase):
app.pt = PartitionTable(3, 1)
app.pt._id = 1
ptid = 2
app.dm = Mock({ })
app.replicator = Mock({})
app.dm.close()
app.dm = Mock()
app.replicator = Mock()
self.operation.notifyPartitionChanges(conn, ptid, 1, cells)
# ptid set
self.assertEqual(app.pt.getID(), ptid)
......
......@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from binascii import a2b_hex
from contextlib import contextmanager
from contextlib import closing, contextmanager
import unittest
from neo.lib.util import add64, p64, u64
from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID, MAX_TID
......@@ -34,22 +34,18 @@ class StorageDBTests(NeoUnitTestBase):
try:
return self._db
except AttributeError:
self.setNumPartitions(1)
self.setupDB(1)
return self._db
def _tearDown(self, success):
try:
self.__dict__.pop('_db', None).close()
except AttributeError:
pass
NeoUnitTestBase._tearDown(self, success)
def getDB(self, reset=0):
def _getDB(self, reset):
raise NotImplementedError
def setNumPartitions(self, num_partitions, reset=0):
def setupDB(self, num_partitions=None, reset=False):
assert not hasattr(self, '_db')
self._db = db = self.getDB(reset)
self._db = db = self._getDB(reset)
self.addCleanup(db.close)
if num_partitions is None:
return
uuid = self.getStorageUUID()
db.setUUID(uuid)
self.assertEqual(uuid, db.getUUID())
......@@ -80,12 +76,12 @@ class StorageDBTests(NeoUnitTestBase):
self.db.abortTransaction(ttid)
def test_UUID(self):
db = self.getDB()
self.checkConfigEntry(db.getUUID, db.setUUID, 123)
self.setupDB()
self.checkConfigEntry(self.db.getUUID, self.db.setUUID, 123)
def test_Name(self):
db = self.getDB()
self.checkConfigEntry(db.getName, db.setName, 'TEST_NAME')
self.setupDB()
self.checkConfigEntry(self.db.getName, self.db.setName, 'TEST_NAME')
def getOIDs(self, count):
return map(p64, xrange(count))
......@@ -111,9 +107,8 @@ class StorageDBTests(NeoUnitTestBase):
raise NotImplementedError
def test_lockDatabase(self):
db = self._test_lockDatabase_open()
self.assertRaises(SystemExit, self._test_lockDatabase_open)
db.close()
with closing(self._test_lockDatabase_open()) as db:
self.assertRaises(SystemExit, self._test_lockDatabase_open)
self._test_lockDatabase_open().close()
def test_getUnfinishedTIDDict(self):
......@@ -237,7 +232,7 @@ class StorageDBTests(NeoUnitTestBase):
def test_deleteRange(self):
np = 4
self.setNumPartitions(np)
self.setupDB(np)
t1, t2, t3 = map(p64, (1, 2, 3))
oid_list = self.getOIDs(np * 2)
for tid in t1, t2, t3:
......@@ -310,7 +305,7 @@ class StorageDBTests(NeoUnitTestBase):
return tid_list
def test_getTIDList(self):
self.setNumPartitions(2, True)
self.setupDB(2, True)
tid1, tid2, tid3, tid4 = self._storeTransactions(4)
# get tids
# - all partitions
......@@ -330,7 +325,7 @@ class StorageDBTests(NeoUnitTestBase):
self.checkSet(result, [])
def test_getReplicationTIDList(self):
self.setNumPartitions(2, True)
self.setupDB(2, True)
tid1, tid2, tid3, tid4 = self._storeTransactions(4)
# - one partition
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 0)
......@@ -352,7 +347,7 @@ class StorageDBTests(NeoUnitTestBase):
def check(trans, obj, *args):
self.assertEqual(trans, self.db.checkTIDRange(*args))
self.assertEqual(obj, self.db.checkSerialRange(*(args+(ZERO_OID,))))
self.setNumPartitions(2, True)
self.setupDB(2, True)
tid1, tid2, tid3, tid4 = self._storeTransactions(4)
z = 0, ZERO_HASH, ZERO_TID, ZERO_HASH, ZERO_OID
# - one partition
......@@ -380,7 +375,7 @@ class StorageDBTests(NeoUnitTestBase):
check(y, x + y[1:], 1, 1, ZERO_TID, MAX_TID)
def test_findUndoTID(self):
self.setNumPartitions(4, True)
self.setupDB(4, True)
db = self.db
tid1 = self.getNextTID()
tid2 = self.getNextTID()
......
......@@ -15,17 +15,16 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from contextlib import contextmanager
from MySQLdb import NotSupportedError, OperationalError, ProgrammingError
from MySQLdb.constants.CR import SERVER_GONE_ERROR
from MySQLdb.constants.ER import UNKNOWN_STORAGE_ENGINE
from contextlib import closing, contextmanager
from ..mock import Mock
from neo.lib.protocol import ZERO_OID
from neo.lib.util import p64
from .. import DB_PREFIX, DB_USER, Patch, setupMySQLdb
from .. import DB_PREFIX, DB_USER, Patch, setupMySQL
from .testStorageDBTests import StorageDBTests
from neo.storage.database import DatabaseFailure
from neo.storage.database.mysql import MySQLDatabaseManager
from neo.storage.database.mysql import (MySQLDatabaseManager,
NotSupportedError, OperationalError, ProgrammingError,
SERVER_GONE_ERROR, UNKNOWN_STORAGE_ENGINE)
class ServerGone(object):
......@@ -50,17 +49,21 @@ class StorageMySQLdbTests(StorageDBTests):
database = self.db_template(0)
return MySQLDatabaseManager(database, self.engine)
def getDB(self, reset=0):
def _getDB(self, reset):
db = self._test_lockDatabase_open()
self.assertEqual(db.db, DB_PREFIX + '0')
self.assertEqual(db.user, DB_USER)
try:
db.setup(reset, True)
except NotSupportedError as m:
code, m = m.args
if code != UNKNOWN_STORAGE_ENGINE:
raise
raise unittest.SkipTest(m)
self.assertEqual(db.db, DB_PREFIX + '0')
self.assertEqual(db.user, DB_USER)
try:
db.setup(reset, True)
except NotSupportedError as m:
code, m = m.args
if code != UNKNOWN_STORAGE_ENGINE:
raise
raise unittest.SkipTest(m)
except:
db.close()
raise
return db
def test_ServerGone(self):
......@@ -75,8 +78,9 @@ class StorageMySQLdbTests(StorageDBTests):
pass
def query(*args):
raise OperationalError(-1, 'this is a test')
self.db.conn = FakeConn()
self.assertRaises(DatabaseFailure, self.db.query, 'QUERY')
with closing(self.db.conn):
self.db.conn = FakeConn()
self.assertRaises(DatabaseFailure, self.db.query, 'QUERY')
def test_escape(self):
self.assertEqual(self.db.escape('a"b'), 'a\\"b')
......
......@@ -25,7 +25,7 @@ class StorageSQLiteTests(StorageDBTests):
db = os.path.join(getTempDirectory(), DB_PREFIX + '0.sqlite')
return SQLiteDatabaseManager(db)
def getDB(self, reset=0):
def _getDB(self, reset=False):
db = SQLiteDatabaseManager(':memory:')
db.setup(reset, True)
return db
......@@ -33,8 +33,8 @@ class StorageSQLiteTests(StorageDBTests):
def test_lockDatabase(self):
super(StorageSQLiteTests, self).test_lockDatabase()
# No lock on temporary databases.
db = self.getDB()
self.getDB().close()
db = self._getDB()
self._getDB().close()
db.close()
del StorageDBTests
......
......@@ -40,7 +40,7 @@ from neo.lib.protocol import ZERO_OID, ZERO_TID, MAX_TID, uuid_str, \
ClusterStates, Enum, NodeStates, NodeTypes, Packets
from neo.lib.util import cached_property, parseMasterList, p64
from neo.master.recovery import RecoveryManager
from .. import (getTempDirectory, setupMySQLdb,
from .. import (getTempDirectory, setupMySQL,
ImporterConfigParser, NeoTestBase, Patch,
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX)
......@@ -787,7 +787,7 @@ class NEOCluster(object):
db_list = ['%s%u' % (DB_PREFIX, self._allocate('db', index))
for _ in xrange(storage_count)]
if adapter == 'MySQL':
db = setupMySQLdb(db_list, clear_databases)
db = setupMySQL(db_list, clear_databases)
elif adapter == 'SQLite':
db = os.path.join(getTempDirectory(), '%s.sqlite').__mod__
else:
......
......@@ -1663,6 +1663,9 @@ class Test(NEOThreadedTest):
m2c, = cluster.master.getConnectionList(cluster.client)
cluster.client._cache.clear()
c.cacheMinimize()
if not hasattr(sys, 'getrefcount'): # PyPy
# See persistent commit ff64867cca3179b1a6379c93b6ef90db565da36c
import gc; gc.collect()
# Make the master disconnects the client when the latter is about
# to send a AskObject packet to the storage node.
with cluster.client.filterConnection(cluster.storage) as c2s:
......
......@@ -128,7 +128,9 @@ class ImporterTests(NEOThreadedTest):
r5["foo"] = "bar"
state = {r2: r3, r4: r5}
p = StringIO()
Pickler(p, 1).dump(Obj).dump(state)
pickler = Pickler(p, 1)
pickler.dump(Obj)
pickler.dump(state)
p = p.getvalue()
r = DummyRepickler()(p)
load = Unpickler(StringIO(r)).load
......
......@@ -10,6 +10,8 @@ Intended Audience :: Developers
License :: OSI Approved :: GNU General Public License (GPL)
Operating System :: POSIX :: Linux
Programming Language :: Python :: 2.7
Programming Language :: Python :: Implementation :: CPython
Programming Language :: Python :: Implementation :: PyPy
Topic :: Database
Topic :: Software Development :: Libraries :: Python Modules
"""
......@@ -53,6 +55,7 @@ extras_require = {
'master': [],
'storage-sqlite': [],
'storage-mysqldb': ['mysqlclient'],
'storage-pymysql': ['PyMySQL'],
'storage-importer': zodb_require + ['setproctitle'],
}
extras_require['tests'] = ['coverage', 'zope.testing', 'psutil>=2',
......
......@@ -18,7 +18,7 @@ from neo.lib.debug import PdbSocket
from neo.lib.node import Node
from neo.lib.protocol import NodeTypes
from neo.lib.util import datetimeFromTID, p64, u64
from neo.storage.app import DATABASE_MANAGER_DICT, \
from neo.storage.app import DATABASE_MANAGERS, \
Application as StorageApplication
from neo.tests import getTempDirectory, mysql_pool
from neo.tests.ConflictFree import ConflictFreeLog
......@@ -580,7 +580,7 @@ class ArgumentDefaultsHelpFormatter(argparse.HelpFormatter):
def main():
adapters = sorted(DATABASE_MANAGER_DICT)
adapters = list(DATABASE_MANAGERS)
adapters.remove('Importer')
default_adapter = 'SQLite'
assert default_adapter in adapters
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment