Commit d95838e2 authored by Cédric Le Ninivin's avatar Cédric Le Ninivin

equeue: Rewrite equeue

Equeue is rewriten extending SocketServer.ThreadingUnixStreamServer (Thanks to Julien M.).
So far outputs of commands invoked by equeue are redirected to /dev/null to avoid locking the whole process.

equeue: Fix identation problem
parent bfd6ed8b
# -*- coding: utf-8 -*-
# vim: set et sts=2:
##############################################################################
#
# Copyright (c) 2010 Vifib SARL and Contributors. All Rights Reserved.
......@@ -33,79 +32,92 @@ import json
import logging
import logging.handlers
import os
import Queue
import select
import StringIO
import socket
import signal
import socket
import subprocess
import sys
import SocketServer
import StringIO
import threading
cleanup_data = {}
def cleanup(signum=None, frame=None):
cleanup_functions = dict(
sockets=lambda sock: sock.close(),
subprocesses=lambda process: process.terminate(),
paths=lambda filename: os.unlink(filename),
)
for data, function in cleanup_functions.iteritems():
for item in cleanup_data.get(data, []):
# XXX will these lists ever have more than 1 element??
# Swallow everything !
try:
function(item)
except:
pass
signal.signal(signal.SIGTERM, cleanup)
# I think this is obvious enough to not require any documentation, but I might
# be wrong.
class EqueueServer(SocketServer.ThreadingUnixStreamServer):
class TaskRunner(object):
daemon_threads = True
def __init__(self):
self._task = None
self._command = None
self._time = None
def __init__(self, *args, **kw):
self.options = kw.pop('equeue_options')
SocketServer.ThreadingUnixStreamServer.__init__(self,
RequestHandlerClass=None,
*args, **kw)
# Equeue Specific elements
self.setLogger(self.options.logfile[0], self.options.loglevel[0])
self.setDB(self.options.database[0])
# Lock to only have one command running at the time
self.lock = threading.Lock()
def has_running_task(self):
return self._task is not None and self._task.poll() is None
def setLogger(self, logfile, loglevel):
self.logger = logging.getLogger("EQueue")
handler = logging.handlers.WatchedFileHandler(logfile, mode='a')
# Natively support logrotate
level = logging._levelNames.get(loglevel, logging.INFO)
self.logger.setLevel(level)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
self.logger.addHandler(handler)
def had_previous_task(self):
return self._task is not None and self._task.poll() is not None
def setDB(self, database):
self.db = gdbm.open(database, 'cs', 0700)
def get_previous_command(self):
if not self.has_running_task():
return self._command
else:
return None
def process_request_thread(self, request, client_address):
# Handle request
self.logger.debug("Connection with file descriptor %d", request.fileno())
request.settimeout(self.options.timeout)
request_string = StringIO.StringIO()
segment = None
try:
while segment != '':
segment = request.recv(1024)
request_string.write(segment)
except socket.timeout:
pass
def get_previous_timestamp(self):
if not self.has_running_task():
return self._time
command = '127'
try:
request_parameters = json.loads(request_string.getvalue())
timestamp = request_parameters['timestamp']
command = str(request_parameters['command'])
self.logger.info("New command %r at %s", command, timestamp)
except (ValueError, IndexError) :
self.logger.warning("Error during the unserialization of json "
"message of %r file descriptor. The message "
"was %r", request.fileno(), request_string.getvalue())
try:
request.send(command)
except:
self.logger.warning("Couldn't respond to %r", request.fileno())
self.close_request(request)
# Run command if needed
with self.lock:
if command not in self.db or timestamp > int(self.db[command]):
self.logger.info("Running %s, %s", command, timestamp)
# XXX stdout and stderr redirected to null as they are not read
with open(os.devnull, 'r+') as f:
status = subprocess.call([command], close_fds=True,
stdin=f, stdout=f, stderr=f)
if status:
self.logger.warning("%s finished with non zero status.",
command)
else:
return None
def get_previous_returncode(self):
return self._task.poll()
def flush(self):
self._task = None
self._command = None
def run(self, command, time):
self._time = time
self._command = command
self._task = subprocess.Popen([command], stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
close_fds=True)
self._task.stdin.flush()
self._task.stdin.close()
cleanup_data.update(subprocesses=[self._task])
def fd(self):
if not self.has_running_task():
raise KeyError("No task is running.")
return self._task.stdout.fileno()
self.logger.info("%s finished successfully.", command)
self.db[command] = timestamp
else:
self.logger.info("%s already runned.", command)
def main():
parser = argparse.ArgumentParser(
......@@ -127,108 +139,14 @@ def main():
args = parser.parse_args()
socketpath = args.socket
level = logging._levelNames.get(args.loglevel[0], logging.INFO)
logger = logging.getLogger("EQueue")
# Natively support logrotate
handler = logging.handlers.WatchedFileHandler(args.logfile[0], mode='a')
logger.setLevel(level)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
unixsocket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
unixsocket.bind(socketpath)
logger.debug("Bind on %r", socketpath)
unixsocket.listen(2)
logger.debug("Listen on socket")
unixsocketfd = unixsocket.fileno()
db = gdbm.open(args.database[0], 'cs', 0700)
signal.signal(signal.SIGHUP, lambda *args: sys.exit(-1))
signal.signal(signal.SIGTERM, lambda *args: sys.exit())
logger.debug("Open timestamp database")
server = EqueueServer(socketpath, **{'equeue_options':args})
server.logger.info("Starting server on %r", socketpath)
server.serve_forever()
cleanup_data.update(sockets=[unixsocket])
cleanup_data.update(paths=[socketpath])
logger.info("Starting server on %r", socketpath)
task_queue = Queue.Queue()
task_running = TaskRunner()
try:
rlist = [unixsocketfd]
while True:
rlist_s, wlist_s, xlist_s = select.select(rlist, [], [])
if unixsocketfd in rlist_s:
conn, addr = unixsocket.accept()
logger.debug("Connection with file descriptor %d", conn.fileno())
conn.settimeout(args.timeout)
request_string = StringIO.StringIO()
segment = None
try:
while segment != '':
segment = conn.recv(1024)
request_string.write(segment)
except socket.timeout:
pass
command = '127'
try:
request = json.loads(request_string.getvalue())
timestamp = request['timestamp']
command = str(request['command'])
task_queue.put([command, timestamp])
logger.info("New command %r at %s", command, timestamp)
except (ValueError, IndexError) :
logger.warning("Error during the unserialization of json "
"message of %r file descriptor. The message "
"was %r", conn.fileno(), request_string.getvalue())
try:
conn.send(command)
conn.close()
except:
logger.warning("Couldn't respond to %r", conn.fileno())
rlist = [unixsocketfd]
if not task_running.has_running_task():
if task_running.had_previous_task():
previous_command = task_running.get_previous_command()
if task_running.get_previous_returncode() != 0:
logger.warning("%s finished with non zero status.",
previous_command)
else:
logger.info("%s finished successfully.", previous_command)
task_running.flush()
db[previous_command] = str(task_running.get_previous_timestamp())
try:
while True:
command, timestamp = task_queue.get(False)
if command not in db or timestamp > int(db[command]):
logger.info("Running %s", command)
task_running.run(command, timestamp)
break
else:
logger.info("%s already runned.", command)
except Queue.Empty:
logger.info("Task queue is empty. Nothing to do...")
try:
rlist.append(task_running.fd())
except KeyError:
pass
finally:
cleanup()
if __name__ == '__main__':
main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment