Commit 990d6337 authored by Julien Muchembled's avatar Julien Muchembled

Fix 'cannot VACUUM from within a transaction' error in threaded tests

parent 42b89d48
...@@ -66,7 +66,7 @@ class NEOLogger(Logger): ...@@ -66,7 +66,7 @@ class NEOLogger(Logger):
self.parent = root = getLogger() self.parent = root = getLogger()
if not root.handlers: if not root.handlers:
root.addHandler(self.default_root_handler) root.addHandler(self.default_root_handler)
self.db = None self._db = None
self._record_queue = deque() self._record_queue = deque()
self._record_size = 0 self._record_size = 0
self._async = set() self._async = set()
...@@ -82,6 +82,13 @@ class NEOLogger(Logger): ...@@ -82,6 +82,13 @@ class NEOLogger(Logger):
self._release = _release self._release = _release
self.backlog() self.backlog()
def __enter__(self):
self._acquire()
return self._db
def __exit__(self, t, v, tb):
self._release()
def __async(wrapped): def __async(wrapped):
def wrapper(self): def wrapper(self):
self._async.add(wrapped) self._async.add(wrapped)
...@@ -91,21 +98,20 @@ class NEOLogger(Logger): ...@@ -91,21 +98,20 @@ class NEOLogger(Logger):
@__async @__async
def flush(self): def flush(self):
if self.db is None: if self._db is None:
return return
try: try:
self.db.execute("BEGIN") self._db.execute("BEGIN")
for r in self._record_queue: for r in self._record_queue:
self._emit(r) self._emit(r)
finally: finally:
# Always commit, to not lose any record that we could emit. # Always commit, to not lose any record that we could emit.
self.db.commit() self._db.commit()
self._record_queue.clear() self._record_queue.clear()
self._record_size = 0 self._record_size = 0
def backlog(self, max_size=1<<24): def backlog(self, max_size=1<<24):
self._acquire() with self:
try:
self._max_size = max_size self._max_size = max_size
if max_size is None: if max_size is None:
self.flush() self.flush()
...@@ -113,26 +119,23 @@ class NEOLogger(Logger): ...@@ -113,26 +119,23 @@ class NEOLogger(Logger):
q = self._record_queue q = self._record_queue
while max_size < self._record_size: while max_size < self._record_size:
self._record_size -= RECORD_SIZE + len(q.popleft().msg) self._record_size -= RECORD_SIZE + len(q.popleft().msg)
finally:
self._release()
def setup(self, filename=None, reset=False): def setup(self, filename=None, reset=False):
self._acquire() with self:
try:
from . import protocol as p from . import protocol as p
global uuid_str global uuid_str
uuid_str = p.uuid_str uuid_str = p.uuid_str
if self.db is not None: if self._db is not None:
self.db.close() self._db.close()
if not filename: if not filename:
self.db = None self._db = None
self._record_queue.clear() self._record_queue.clear()
self._record_size = 0 self._record_size = 0
return return
if filename: if filename:
self.db = sqlite3.connect(filename, isolation_level=None, self._db = sqlite3.connect(filename, isolation_level=None,
check_same_thread=False) check_same_thread=False)
q = self.db.execute q = self._db.execute
if reset: if reset:
for t in 'log', 'packet': for t in 'log', 'packet':
q('DROP TABLE IF EXISTS ' + t) q('DROP TABLE IF EXISTS ' + t)
...@@ -167,8 +170,6 @@ class NEOLogger(Logger): ...@@ -167,8 +170,6 @@ class NEOLogger(Logger):
break break
else: else:
q("INSERT INTO protocol VALUES (?,?)", (time(), p)) q("INSERT INTO protocol VALUES (?,?)", (time(), p))
finally:
self._release()
__del__ = setup __del__ = setup
def isEnabledFor(self, level): def isEnabledFor(self, level):
...@@ -179,11 +180,11 @@ class NEOLogger(Logger): ...@@ -179,11 +180,11 @@ class NEOLogger(Logger):
ip, port = r.addr ip, port = r.addr
peer = '%s %s (%s:%u)' % ('>' if r.outgoing else '<', peer = '%s %s (%s:%u)' % ('>' if r.outgoing else '<',
uuid_str(r.uuid), ip, port) uuid_str(r.uuid), ip, port)
self.db.execute("INSERT INTO packet VALUES (NULL,?,?,?,?,?,?)", self._db.execute("INSERT INTO packet VALUES (NULL,?,?,?,?,?,?)",
(r.created, r._name, r.msg_id, r.code, peer, buffer(r.msg))) (r.created, r._name, r.msg_id, r.code, peer, buffer(r.msg)))
else: else:
pathname = os.path.relpath(r.pathname, *neo.__path__) pathname = os.path.relpath(r.pathname, *neo.__path__)
self.db.execute("INSERT INTO log VALUES (NULL,?,?,?,?,?,?)", self._db.execute("INSERT INTO log VALUES (NULL,?,?,?,?,?,?)",
(r.created, r._name, r.levelno, pathname, r.lineno, r.msg)) (r.created, r._name, r.levelno, pathname, r.lineno, r.msg))
def _queue(self, record): def _queue(self, record):
...@@ -205,7 +206,7 @@ class NEOLogger(Logger): ...@@ -205,7 +206,7 @@ class NEOLogger(Logger):
self._release() self._release()
def callHandlers(self, record): def callHandlers(self, record):
if self.db is not None: if self._db is not None:
record.msg = record.getMessage() record.msg = record.getMessage()
record.args = None record.args = None
if record.exc_info: if record.exc_info:
...@@ -218,7 +219,7 @@ class NEOLogger(Logger): ...@@ -218,7 +219,7 @@ class NEOLogger(Logger):
self.parent.callHandlers(record) self.parent.callHandlers(record)
def packet(self, connection, packet, outgoing): def packet(self, connection, packet, outgoing):
if self.db is not None: if self._db is not None:
ip, port = connection.getAddress() ip, port = connection.getAddress()
self._queue(PacketRecord( self._queue(PacketRecord(
created=time(), created=time(),
......
...@@ -752,9 +752,9 @@ class NEOThreadedTest(NeoTestBase): ...@@ -752,9 +752,9 @@ class NEOThreadedTest(NeoTestBase):
super(NEOThreadedTest, self)._tearDown(success) super(NEOThreadedTest, self)._tearDown(success)
ServerNode.resetPorts() ServerNode.resetPorts()
if success: if success:
q = logging.db.execute with logging as db:
q("UPDATE packet SET body=NULL") db.execute("UPDATE packet SET body=NULL")
q("VACUUM") db.execute("VACUUM")
def getUnpickler(self, conn): def getUnpickler(self, conn):
reader = conn._reader reader = conn._reader
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment