Commit c61cab22 authored by Julien Muchembled's avatar Julien Muchembled

Implement automatic renewal of client certificate

parent e24eb3f5
#!/usr/bin/python
import argparse, atexit, errno, os, subprocess, sqlite3, sys
import argparse, atexit, errno, os, subprocess, sqlite3, sys, time
from OpenSSL import crypto
from re6st import registry, utils
......@@ -139,7 +139,12 @@ def main():
os.ftruncate(cert_fd, len(cert))
os.close(cert_fd)
print "Certificate setup complete."
cert = loadCert(cert)
not_after = utils.notAfter(cert)
print("Setup complete. Certificate is valid until %s UTC"
" and will be automatically renewed after %s UTC" % (
time.asctime(time.gmtime(not_after)),
time.asctime(time.gmtime(not_after - registry.RENEW_PERIOD))))
if not os.path.lexists(conf_path):
create(conf_path, """\
......@@ -160,7 +165,7 @@ dh %s
""" % (config.registry, ca_path, cert_path, key_path, dh_path))
print "Sample configuration file created."
cn = utils.subnetFromCert(loadCert(cert))
cn = utils.subnetFromCert(cert)
subnet = network + utils.binFromSubnet(cn)
print "Your subnet: %s/%u (CN=%s)" \
% (utils.ipFromBin(subnet), len(subnet), cn)
......
......@@ -8,6 +8,7 @@ from urllib import splittype, splithost, splitport, urlencode
from . import tunnel, utils
HMAC_HEADER = "Re6stHMAC"
RENEW_PERIOD = 30 * 86400
class getcallargs(type):
......@@ -190,29 +191,39 @@ class RegistryServer(object):
(token,)).next()
except StopIteration:
return
self.db.execute("DELETE FROM token WHERE token = ?", (token,))
# Get a new prefix
self.db.execute("DELETE FROM token WHERE token = ?",
(token,))
prefix = self._getPrefix(prefix_len)
self.db.execute("UPDATE cert SET email = ? WHERE prefix = ?",
(email, prefix))
return self._createCertificate(prefix, req.get_subject(),
req.get_pubkey())
# Create certificate
def _createCertificate(self, client_prefix, subject, pubkey):
cert = crypto.X509()
cert.set_serial_number(0) # required for libssl < 1.0
cert.gmtime_adj_notBefore(0)
cert.gmtime_adj_notAfter(self.cert_duration)
cert.set_issuer(self.ca.get_subject())
subject = req.get_subject()
subject.CN = "%u/%u" % (int(prefix, 2), prefix_len)
subject.CN = "%u/%u" % (int(client_prefix, 2), len(client_prefix))
cert.set_subject(subject)
cert.set_pubkey(req.get_pubkey())
cert.set_pubkey(pubkey)
cert.sign(self.key, 'sha1')
cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
# Insert certificate into db
self.db.execute("UPDATE cert SET email = ?, cert = ? WHERE prefix = ?", (email, cert, prefix))
self.db.execute("UPDATE cert SET cert = ? WHERE prefix = ?",
(cert, client_prefix))
return cert
def renewCertificate(self, cn):
with self.lock:
with self.db:
pem = self._getCert(cn)
cert = crypto.load_certificate(crypto.FILETYPE_PEM, pem)
if utils.notAfter(cert) - RENEW_PERIOD < time.time():
pem = self._createCertificate(cn, cert.get_subject(),
cert.get_pubkey())
return pem
def getCa(self):
return crypto.dump_certificate(crypto.FILETYPE_PEM, self.ca)
......
import argparse, errno, logging, os, shlex, signal, socket
import argparse, calendar, errno, logging, os, shlex, signal, socket
import struct, subprocess, textwrap, threading, time
logging_levels = logging.WARNING, logging.INFO, logging.DEBUG, 5
......@@ -132,6 +132,9 @@ def networkFromCa(ca):
def subnetFromCert(cert):
return cert.get_subject().CN
def notAfter(cert):
return calendar.timegm(time.strptime(cert.get_notAfter(),'%Y%m%d%H%M%SZ'))
def dump_address(address):
return ';'.join(map(','.join, address))
......
......@@ -4,8 +4,10 @@ import sqlite3, subprocess, sys, time, traceback
from collections import deque
from OpenSSL import crypto
from re6st import db, plib, tunnel, utils
from re6st.registry import RegistryClient
from re6st.registry import RegistryClient, RENEW_PERIOD
class ReexecException(Exception):
pass
def getConfig():
parser = utils.ArgParser(fromfile_prefix_chars='@',
......@@ -112,6 +114,34 @@ def getConfig():
return parser.parse_args()
def maybe_renew(path, cert, info, renew):
while True:
next_renew = utils.notAfter(cert) - RENEW_PERIOD
if time.time() < next_renew:
return cert, next_renew
try:
pem = renew()
if not pem or pem == crypto.dump_certificate(
crypto.FILETYPE_PEM, cert):
exc_info = 0
break
cert = crypto.load_certificate(crypto.FILETYPE_PEM, pem)
except Exception:
exc_info = 1
break
new_path = path + '.new'
with open(new_path, 'w') as f:
f.write(pem)
os.rename(new_path, path)
logging.info("%s renewed until %s UTC",
info, time.asctime(time.gmtime(utils.notAfter(cert))))
logging.error("%s not renewed. Will retry tomorrow.",
info, exc_info=exc_info)
return cert, time.time() + 86400
def exit(status):
exit.status = status
os.kill(os.getpid(), signal.SIGTERM)
def main():
# Get arguments
......@@ -142,7 +172,15 @@ def main():
plib.ovpn_log = config.log
signal.signal(signal.SIGHUP, lambda *args: sys.exit(-1))
signal.signal(signal.SIGTERM, lambda *args: sys.exit())
signal.signal(signal.SIGTERM, lambda *args:
sys.exit(getattr(exit, 'status', None)))
registry = RegistryClient(config.registry, config.key, ca)
cert, next_renew = maybe_renew(config.cert, cert, "Certificate",
lambda: registry.renewCertificate(prefix))
ca, ca_renew = maybe_renew(config.ca, ca, "CA Certificate", registry.getCa)
if next_renew > ca_renew:
next_renew = ca_renew
if config.max_clients is None:
config.max_clients = config.client_count * 2
......@@ -232,7 +270,6 @@ def main():
# Create and open read_only pipe to get server events
r_pipe, write_pipe = os.pipe()
read_pipe = os.fdopen(r_pipe)
registry = RegistryClient(config.registry, config.key, ca)
peer_db = db.PeerDB(db_path, registry, config.key, prefix)
tunnel_manager = tunnel.TunnelManager(write_pipe, peer_db,
config.openvpn_args, timeout, config.tunnel_refresh,
......@@ -313,7 +350,12 @@ def main():
# main loop
if tunnel_manager is None:
sys.exit(os.WEXITSTATUS(os.wait()[1]))
t = threading.Thread(target=lambda:
exit(os.WEXITSTATUS(os.wait()[1])))
t.daemon = True
t.start()
time.sleep(max(0, next_renew - time.time()))
raise ReexecException("Restart to renew certificate")
cleanup += tunnel_manager.delInterfaces, tunnel_manager.killAll
while True:
next = tunnel_manager.next_refresh
......@@ -333,6 +375,8 @@ def main():
t = time.time()
if t >= tunnel_manager.next_refresh:
tunnel_manager.refresh()
if t >= next_renew:
raise ReexecException("Restart to renew certificate")
if forwarder and t >= forwarder.next_refresh:
forwarder.refresh()
finally:
......@@ -344,16 +388,18 @@ def main():
except sqlite3.Error:
logging.exception("Restarting with empty cache")
os.rename(db_path, db_path + '.bak')
try:
sys.exitfunc()
finally:
os.execvp(sys.argv[0], sys.argv)
except ReexecException, e:
logging.info(e)
except KeyboardInterrupt:
return 0
except Exception:
f = traceback.format_exception(*sys.exc_info())
logging.error('%s%s', f.pop(), ''.join(f))
sys.exit(1)
try:
sys.exitfunc()
finally:
os.execvp(sys.argv[0], sys.argv)
if __name__ == "__main__":
main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment