Commit 42495e18 authored by Julien Muchembled's avatar Julien Muchembled

Implement automatic renewal of client certificate

parent f2e11d86
#!/usr/bin/python
import argparse, atexit, errno, os, subprocess, sqlite3, sys
import argparse, atexit, errno, os, subprocess, sqlite3, sys, time
from OpenSSL import crypto
from re6st import registry, utils
......@@ -139,7 +139,11 @@ def main():
os.ftruncate(cert_fd, len(cert))
os.close(cert_fd)
print "Certificate setup complete."
cert = loadCert(cert)
not_after = utils.notAfter(cert)
print("Setup complete. Certificate is valid until %s"
" and will be automatically renewed after %s" % (
time.ctime(not_after), time.ctime(not_after - registry.RENEW_PERIOD)))
if not os.path.lexists(conf_path):
create(conf_path, """\
......@@ -160,7 +164,7 @@ dh %s
""" % (config.registry, ca_path, cert_path, key_path, dh_path))
print "Sample configuration file created."
cn = utils.subnetFromCert(loadCert(cert))
cn = utils.subnetFromCert(cert)
subnet = network + utils.binFromSubnet(cn)
print "Your subnet: %s/%u (CN=%s)" \
% (utils.ipFromBin(subnet), len(subnet), cn)
......
......@@ -8,6 +8,7 @@ from urllib import splittype, splithost, splitport, urlencode
from . import tunnel, utils
HMAC_HEADER = "Re6stHMAC"
RENEW_PERIOD = 30 * 86400
class getcallargs(type):
......@@ -190,28 +191,38 @@ class RegistryServer(object):
(token,)).next()
except StopIteration:
return
self.db.execute("DELETE FROM token WHERE token = ?", (token,))
# Get a new prefix
self.db.execute("DELETE FROM token WHERE token = ?",
(token,))
prefix = self._getPrefix(prefix_len)
self.db.execute("UPDATE cert SET email = ? WHERE prefix = ?",
(email, prefix))
return self._createCertificate(prefix, req.get_subject(),
req.get_pubkey())
# Create certificate
cert = crypto.X509()
cert.set_serial_number(0) # required for libssl < 1.0
cert.gmtime_adj_notBefore(0)
cert.gmtime_adj_notAfter(self.cert_duration)
cert.set_issuer(self.ca.get_subject())
subject = req.get_subject()
subject.CN = "%u/%u" % (int(prefix, 2), prefix_len)
cert.set_subject(subject)
cert.set_pubkey(req.get_pubkey())
cert.sign(self.key, 'sha1')
cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
# Insert certificate into db
self.db.execute("UPDATE cert SET email = ?, cert = ? WHERE prefix = ?", (email, cert, prefix))
def _createCertificate(self, client_prefix, subject, pubkey):
cert = crypto.X509()
cert.set_serial_number(0) # required for libssl < 1.0
cert.gmtime_adj_notBefore(0)
cert.gmtime_adj_notAfter(self.cert_duration)
cert.set_issuer(self.ca.get_subject())
subject.CN = "%u/%u" % (int(client_prefix, 2), len(client_prefix))
cert.set_subject(subject)
cert.set_pubkey(pubkey)
cert.sign(self.key, 'sha1')
cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
self.db.execute("UPDATE cert SET cert = ? WHERE prefix = ?",
(cert, client_prefix))
return cert
return cert
def renewCertificate(self, cn):
with self.lock:
with self.db:
pem = self._getCert(cn)
cert = crypto.load_certificate(crypto.FILETYPE_PEM, pem)
if utils.notAfter(cert) - RENEW_PERIOD < time.time():
pem = self._createCertificate(cn, cert.get_subject(),
cert.get_pubkey())
return pem
def getCa(self):
return crypto.dump_certificate(crypto.FILETYPE_PEM, self.ca)
......
......@@ -132,6 +132,9 @@ def networkFromCa(ca):
def subnetFromCert(cert):
return cert.get_subject().CN
def notAfter(cert):
return time.mktime(time.strptime(cert.get_notAfter(),'%Y%m%d%H%M%SZ'))
def dump_address(address):
return ';'.join(map(','.join, address))
......
......@@ -4,8 +4,10 @@ import sqlite3, subprocess, sys, time, traceback
from collections import deque
from OpenSSL import crypto
from re6st import db, plib, tunnel, utils
from re6st.registry import RegistryClient
from re6st.registry import RegistryClient, RENEW_PERIOD
class ReexecException(Exception):
pass
def getConfig():
parser = utils.ArgParser(fromfile_prefix_chars='@',
......@@ -112,6 +114,8 @@ def getConfig():
return parser.parse_args()
def renew(*args):
raise ReexecException("Restart to renew certificate")
def main():
# Get arguments
......@@ -144,6 +148,24 @@ def main():
signal.signal(signal.SIGHUP, lambda *args: sys.exit(-1))
signal.signal(signal.SIGTERM, lambda *args: sys.exit())
registry = RegistryClient(config.registry, config.key, ca)
while True:
next_renew = utils.notAfter(cert) - RENEW_PERIOD
if time.time() < next_renew:
break
pem = registry.renewCertificate(prefix)
if not pem or pem == crypto.dump_certificate(crypto.FILETYPE_PEM, cert):
logging.warning("Certificate not renewed. Will retry tomorrow.")
next_renew = time.time() + 86400
break
cert = crypto.load_certificate(crypto.FILETYPE_PEM, pem)
path = config.cert + '.new'
with open(path, 'w') as f:
f.write(pem)
os.rename(path, config.cert)
logging.info("Certificate renewed until %s",
time.ctime(utils.notAfter(cert)))
if config.max_clients is None:
config.max_clients = config.client_count * 2
......@@ -232,7 +254,6 @@ def main():
# Create and open read_only pipe to get server events
r_pipe, write_pipe = os.pipe()
read_pipe = os.fdopen(r_pipe)
registry = RegistryClient(config.registry, config.key, ca)
peer_db = db.PeerDB(db_path, registry, config.key, prefix)
tunnel_manager = tunnel.TunnelManager(write_pipe, peer_db,
config.openvpn_args, timeout, config.tunnel_refresh,
......@@ -313,7 +334,12 @@ def main():
# main loop
if tunnel_manager is None:
sys.exit(os.WEXITSTATUS(os.wait()[1]))
signal.signal(signal.SIGALRM, renew)
signal.alarm(int(next_renew - time.time()))
try:
sys.exit(os.WEXITSTATUS(os.wait()[1]))
finally:
signal.alarm(0)
cleanup += tunnel_manager.delInterfaces, tunnel_manager.killAll
while True:
next = tunnel_manager.next_refresh
......@@ -333,6 +359,8 @@ def main():
t = time.time()
if t >= tunnel_manager.next_refresh:
tunnel_manager.refresh()
if t >= next_renew:
renew()
if forwarder and t >= forwarder.next_refresh:
forwarder.refresh()
finally:
......@@ -344,16 +372,18 @@ def main():
except sqlite3.Error:
logging.exception("Restarting with empty cache")
os.rename(db_path, db_path + '.bak')
try:
sys.exitfunc()
finally:
os.execvp(sys.argv[0], sys.argv)
except ReexecException, e:
logging.info(e)
except KeyboardInterrupt:
return 0
except Exception:
f = traceback.format_exception(*sys.exc_info())
logging.error('%s%s', f.pop(), ''.join(f))
sys.exit(1)
try:
sys.exitfunc()
finally:
os.execvp(sys.argv[0], sys.argv)
if __name__ == "__main__":
main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment