Commit c2cc0c11 authored by Julien Muchembled's avatar Julien Muchembled

mysql: fail instead of silently reconnect if there is any pending change

parent 069b95e5
...@@ -82,13 +82,13 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -82,13 +82,13 @@ class MySQLDatabaseManager(DatabaseManager):
while True: while True:
try: try:
self.conn = MySQLdb.connect(**kwd) self.conn = MySQLdb.connect(**kwd)
break
except Exception: except Exception:
if timeout_at is not None and time.time() >= timeout_at: if timeout_at is not None and time.time() >= timeout_at:
raise raise
logging.exception('Connection to MySQL failed, retrying.') logging.exception('Connection to MySQL failed, retrying.')
time.sleep(1) time.sleep(1)
else: self._active = 0
break
self.conn.autocommit(False) self.conn.autocommit(False)
self.conn.query("SET SESSION group_concat_max_len = %u" % (2**32-1)) self.conn.query("SET SESSION group_concat_max_len = %u" % (2**32-1))
self.conn.set_sql_mode("TRADITIONAL,NO_ENGINE_SUBSTITUTION") self.conn.set_sql_mode("TRADITIONAL,NO_ENGINE_SUBSTITUTION")
...@@ -96,6 +96,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -96,6 +96,7 @@ class MySQLDatabaseManager(DatabaseManager):
def commit(self): def commit(self):
logging.debug('committing...') logging.debug('committing...')
self.conn.commit() self.conn.commit()
self._active = 0
def query(self, query): def query(self, query):
"""Query data from a database.""" """Query data from a database."""
...@@ -111,25 +112,23 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -111,25 +112,23 @@ class MySQLDatabaseManager(DatabaseManager):
logging.debug('querying %s...', query_part) logging.debug('querying %s...', query_part)
conn.query(query) conn.query(query)
if query.startswith("SELECT "):
r = conn.store_result() r = conn.store_result()
if r is not None: return tuple([
new_r = [] tuple([d.tostring() if isinstance(d, array) else d
for row in r.fetch_row(r.num_rows()): for d in row])
new_row = [] for row in r.fetch_row(r.num_rows())])
for d in row: r = query.split(None, 1)[0]
if isinstance(d, array): if r in ("INSERT", "REPLACE", "DELETE", "UPDATE"):
d = d.tostring() self._active = 1
new_row.append(d) else:
new_r.append(tuple(new_row)) assert r in ("ALTER", "CREATE", "DROP", "TRUNCATE"), query
r = tuple(new_r)
except OperationalError, m: except OperationalError, m:
if m[0] in (SERVER_GONE_ERROR, SERVER_LOST): if m[0] in (SERVER_GONE_ERROR, SERVER_LOST) and not self._active:
logging.info('the MySQL server is gone; reconnecting') logging.info('the MySQL server is gone; reconnecting')
self._connect() self._connect()
return self.query(query) return self.query(query)
raise DatabaseFailure('MySQL error %d: %s' % (m[0], m[1])) raise DatabaseFailure('MySQL error %d: %s' % (m[0], m[1]))
return r
@property @property
def escape(self): def escape(self):
......
...@@ -75,14 +75,14 @@ class StorageTests(NEOFunctionalTest): ...@@ -75,14 +75,14 @@ class StorageTests(NEOFunctionalTest):
# wait for the sql transaction to be commited # wait for the sql transaction to be commited
def callback(last_try): def callback(last_try):
# One revision per object and two for the root, before and after # One revision per object and two for the root, before and after
(object_number,), = db.query('select count(*) from obj') (object_number,), = db.query('SELECT count(*) FROM obj')
return object_number == OBJECT_NUMBER + 2, object_number return object_number == OBJECT_NUMBER + 2, object_number
self.neo.expectCondition(callback) self.neo.expectCondition(callback)
# no more temporarily objects # no more temporarily objects
(t_objects,), = db.query('select count(*) from tobj') (t_objects,), = db.query('SELECT count(*) FROM tobj')
self.assertEqual(t_objects, 0) self.assertEqual(t_objects, 0)
# One object more for the root # One object more for the root
query = 'select count(*) from (select * from obj group by oid) as t' query = 'SELECT count(*) FROM (SELECT * FROM obj GROUP BY oid) AS t'
(objects,), = db.query(query) (objects,), = db.query(query)
self.assertEqual(objects, OBJECT_NUMBER + 1) self.assertEqual(objects, OBJECT_NUMBER + 1)
# Check object content # Check object content
......
...@@ -55,11 +55,11 @@ class StorageMySQSLdbTests(StorageDBTests): ...@@ -55,11 +55,11 @@ class StorageMySQSLdbTests(StorageDBTests):
(1, 2, '\x01\x02', ), (1, 2, '\x01\x02', ),
) )
self.db.conn = Mock({ 'store_result': result_object }) self.db.conn = Mock({ 'store_result': result_object })
result = self.db.query('QUERY') result = self.db.query('SELECT ')
self.assertEqual(result, expected_result) self.assertEqual(result, expected_result)
calls = self.db.conn.mockGetNamedCalls('query') calls = self.db.conn.mockGetNamedCalls('query')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs('QUERY') calls[0].checkArgs('SELECT ')
def test_query2(self): def test_query2(self):
# test the OperationalError exception # test the OperationalError exception
...@@ -73,12 +73,12 @@ class StorageMySQSLdbTests(StorageDBTests): ...@@ -73,12 +73,12 @@ class StorageMySQSLdbTests(StorageDBTests):
self.connect_called = False self.connect_called = False
def connect_hook(): def connect_hook():
# mock object, break raise/connect loop # mock object, break raise/connect loop
self.db.conn = Mock({'num_rows': 0}) self.db.conn = Mock()
self.connect_called = True self.connect_called = True
self.db._connect = connect_hook self.db._connect = connect_hook
# make a query, exception will be raised then connect() will be # make a query, exception will be raised then connect() will be
# called and the second query will use the mock object # called and the second query will use the mock object
self.db.query('QUERY') self.db.query('INSERT')
self.assertTrue(self.connect_called) self.assertTrue(self.connect_called)
def test_query3(self): def test_query3(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment