Commit 8ebdd500 authored by Julien Muchembled's avatar Julien Muchembled

Certificate revocation, with broadcast of CRL

parent f73c51ec
......@@ -20,3 +20,8 @@
- registry: add '--home PATH' command line option so that / display an HTML
page from PATH (use new str.format for templating)
- Better UI to revoke certificates, for example with a HTML form.
Currently, one have to forge the URL manually. Examples:
wget -O /dev/null http://re6st.example.com/revoke?cn_or_serial=123
wget -O /dev/null http://re6st.example.com/revoke?cn_or_serial=4/16
......@@ -4,6 +4,8 @@ from . import utils, version, x509
class Cache(object):
crl = ()
def __init__(self, db_path, registry, cert, db_size=200):
self._prefix = cert.prefix
self._db_size = db_size
......@@ -40,6 +42,7 @@ class Cache(object):
# when it tried to send us new parameters.
or self._prefix == self.registry_prefix):
self.updateConfig()
self.next_renew = cert.maybeRenew(self._registry, self.crl)
if version.protocol < self.min_protocol:
logging.critical("Your version of re6stnet is too old."
" Please update.")
......@@ -64,7 +67,11 @@ class Cache(object):
cls = self.__class__
logging.debug("Loading network parameters:")
for k, v in config:
hasattr(cls, k) or setattr(self, k, v)
if k == 'crl':
v = set(json.loads(v))
elif hasattr(cls, k):
continue
setattr(self, k, v)
logging.debug("- %s: %r", k, v)
def updateConfig(self):
......@@ -77,6 +84,7 @@ class Cache(object):
config = dict((str(k), v.decode('base64') if k in base64 else
str(v) if type(v) is unicode else v)
for k, v in config.iteritems())
config['crl'] = json.dumps(config['crl'])
except socket.error, e:
logging.warning(e)
return
......
......@@ -9,7 +9,7 @@ if script_type == 'up':
os.execlp('ip', 'ip', 'link', 'set', os.environ['dev'], 'up',
'mtu', os.environ['tun_mtu'])
# Write into pipe external ip address received
import time
os.write(int(sys.argv[1]), "%s %s %s %s\n" % (script_type,
os.environ['common_name'], time.time(), os.environ['OPENVPN_external_ip']))
if script_type == 'route-up':
import time
os.write(int(sys.argv[1]), repr((os.environ['common_name'], time.time(),
int(os.environ['tls_serial_0']), os.environ['OPENVPN_external_ip'])))
......@@ -2,15 +2,16 @@
import os, sys
script_type = os.environ['script_type']
external_ip = lambda: os.getenv('trusted_ip') or os.environ['trusted_ip6']
external_ip = os.getenv('trusted_ip') or os.environ['trusted_ip6']
# Write into pipe connect/disconnect events
fd = int(sys.argv[1])
os.write(fd, repr((script_type, (os.environ['common_name'], os.environ['dev'],
int(os.environ['tls_serial_0']), external_ip))))
if script_type == 'client-connect':
if os.read(fd, 1) == '\0':
sys.exit(1)
# Send client its external ip address
with open(sys.argv[2], 'w') as f:
f.write('push "setenv-safe external_ip %s"\n' % external_ip())
# Write into pipe connect/disconnect events
arg1 = sys.argv[1]
if arg1 != 'None':
os.write(int(arg1), '%s %s %s\n' % (
script_type, os.environ['common_name'], external_ip()))
f.write('push "setenv-safe external_ip %s"\n' % external_ip)
......@@ -25,10 +25,8 @@ def openvpn(iface, encrypt, *args, **kw):
ovpn_link_mtu_dict = {'udp': 1481, 'udp6': 1450}
def server(iface, max_clients, dh_path, pipe_fd, port, proto, encrypt, *args, **kw):
client_script = '%s %s' % (ovpn_server, pipe_fd)
if pipe_fd is not None:
args = ('--client-disconnect', client_script) + args
def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw):
client_script = '%s %s' % (ovpn_server, fd)
try:
args = ('--link-mtu', str(ovpn_link_mtu_dict[proto]),
# mtu-disc ignored for udp6 due to a bug in OpenVPN
......@@ -39,6 +37,7 @@ def server(iface, max_clients, dh_path, pipe_fd, port, proto, encrypt, *args, **
'--tls-server',
'--mode', 'server',
'--client-connect', client_script,
'--client-disconnect', client_script,
'--dh', dh_path,
'--max-clients', str(max_clients),
'--port', str(port),
......
......@@ -25,6 +25,7 @@ from collections import defaultdict, deque
from datetime import datetime
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
from email.mime.text import MIMEText
from operator import itemgetter
from OpenSSL import crypto
from urllib import splittype, splithost, splitport, urlencode
from . import ctl, tunnel, utils, version, x509
......@@ -71,6 +72,20 @@ class RegistryServer(object):
"email TEXT",
"cert TEXT"):
self.db.execute("INSERT INTO cert VALUES ('',null,null)")
if utils.sqliteCreateTable(self.db, "crl",
"serial INTEGER PRIMARY KEY NOT NULL",
# Expiration date of revoked certificate.
# TODO: purge rows with dates in the past.
"date INTEGER NOT NULL"):
# Revoke certificates produced by previous version.
# They all have serial 0.
try:
date = max(x509.notAfter(x[0]) for x in self.iterCert())
except ValueError:
pass
else:
if time.time() < date:
self.db.execute("INSERT INTO crl VALUES (0,?)", (date,))
self.cert = x509.Cert(self.config.ca, self.config.key)
# Get vpn network prefix
......@@ -97,9 +112,11 @@ class RegistryServer(object):
self.db.execute("INSERT OR REPLACE INTO config VALUES (?, ?)",
name_value)
def updateNetworkConfig(self):
def updateNetworkConfig(self, _it0=itemgetter(0)):
kw = {
'babel_default': 'max-rtt-penalty 5000 rtt-max 500 rtt-decay 125',
'crl': map(_it0, self.db.execute(
"SELECT serial FROM crl ORDER BY serial")),
'protocol': version.protocol,
'registry_prefix': self.prefix,
}
......@@ -220,7 +237,7 @@ class RegistryServer(object):
def handle_request(self, request, method, kw,
_localhost=('127.0.0.1', '::1')):
m = getattr(self, method)
if method in ('versions', 'topology'):
if method in ('revoke', 'versions', 'topology'):
x_forwarded_for = request.headers.get('X-Forwarded-For')
if request.client_address[0] not in _localhost or \
x_forwarded_for and x_forwarded_for not in _localhost:
......@@ -393,15 +410,16 @@ class RegistryServer(object):
@rpc
def renewCertificate(self, cn):
with self.lock:
with self.db:
with self.db as db:
pem = self.getCert(cn)
cert = crypto.load_certificate(crypto.FILETYPE_PEM, pem)
if x509.notAfter(cert) - RENEW_PERIOD < time.time():
not_after = None
elif cert.get_serial_number():
return pem
else:
elif db.execute("SELECT count(*) FROM crl WHERE serial=?",
(cert.get_serial_number(),)).fetchone()[0]:
not_after = cert.get_notAfter()
else:
return pem
return self.createCertificate(cn,
cert.get_subject(), cert.get_pubkey(), not_after)
......@@ -452,6 +470,29 @@ class RegistryServer(object):
logging.info("Sending bootstrap peer: %s", msg)
return x509.encrypt(cert, msg)
@rpc
def revoke(self, cn_or_serial):
with self.lock:
with self.db:
q = self.db.execute
try:
serial = int(cn_or_serial)
except ValueError:
prefix = utils.binFromSubnet(cn_or_serial)
cert = self.getCert(prefix)
q("UPDATE cert SET email=null, cert=null WHERE prefix=?",
(prefix,))
cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
serial = cert.get_serial_number()
self.sessions.pop(prefix, None)
else:
cert, = (cert for cert, prefix, email in self.iterCert()
if cert.get_serial_number() == serial)
not_after = x509.notAfter(cert)
if time.time() < not_after:
q("INSERT INTO crl VALUES (?,?)", (serial, not_after))
self.updateNetworkConfig()
@rpc
def versions(self):
with self.peers_lock:
......
import errno, logging, os, random, socket, subprocess, time, weakref
import errno, logging, os, random, socket, subprocess, struct, time, weakref
from collections import defaultdict, deque
from bisect import bisect, insort
from OpenSSL import crypto
......@@ -40,6 +40,7 @@ class MultiGatewayManager(dict):
class Connection(object):
_retry = 0
serial = None
time = float('inf')
def __init__(self, tunnel_manager, address_list, iface, prefix):
......@@ -69,15 +70,19 @@ class Connection(object):
'--connect-retry-max', '3', '--tls-exit',
'--remap-usr1', 'SIGTERM',
'--ping-exit', str(tm.timeout),
'--route-up', '%s %u' % (plib.ovpn_client, tm.write_pipe),
'--route-up', '%s %u' % (plib.ovpn_client, tm.write_sock.fileno()),
*tm.ovpn_args)
tm.resetTunnelRefresh()
self._retry += 1
def connected(self):
def connected(self, serial):
cache = self.tunnel_manager.cache
if serial in cache.crl:
self.tunnel_manager._kill(self._prefix)
return
self.serial = serial
i = self._retry - 1
self._retry = None
cache = self.tunnel_manager.cache
if i:
cache.addPeer(self._prefix, ','.join(self.address_list[i]), True)
else:
......@@ -167,14 +172,14 @@ class BaseTunnelManager(object):
_forward = None
def __init__(self, cache, cert, cert_renew, address=()):
def __init__(self, cache, cert, address=()):
self.cert = cert
self._network = cert.network
self._prefix = cert.prefix
self.cache = cache
self._connecting = set()
self._connection_dict = {}
self._served = set()
self._served = defaultdict(dict)
self._version = cache.version
address_dict = defaultdict(list)
......@@ -190,9 +195,9 @@ class BaseTunnelManager(object):
self.sock.bind(('::', PORT))
p = x509.Peer(self._prefix)
p.stop_date = cert_renew
p.stop_date = cache.next_renew
self._peers = [p]
self._timeouts = [(cert_renew, self.invalidatePeers)]
self._timeouts = [(p.stop_date, self.invalidatePeers)]
def select(self, r, w, t):
r[self.sock] = self.handlePeerEvent
......@@ -307,6 +312,9 @@ class BaseTunnelManager(object):
cert = self.cert.loadVerify(msg,
True, crypto.FILETYPE_ASN1)
stop_date = x509.notAfter(cert)
serial = cert.get_serial_number()
if serial in self.cache.crl:
raise ValueError("revoked")
except (x509.VerifyError, ValueError), e:
logging.debug('ignored invalid certificate from %r (%s)',
address, e.args[-1])
......@@ -320,6 +328,7 @@ class BaseTunnelManager(object):
peer = x509.Peer(p)
insort(self._peers, peer)
peer.cert = cert
peer.serial = serial
peer.stop_date = stop_date
self.selectTimeout(stop_date, self.invalidatePeers, False)
if seqno:
......@@ -398,7 +407,7 @@ class BaseTunnelManager(object):
raise utils.ReexecException(
"Restart with new network parameters")
def broadcastVersion(self):
def _newVersion(self):
pass
def newVersion(self):
......@@ -410,32 +419,77 @@ class BaseTunnelManager(object):
logging.info("changed: %r", changed)
self.selectTimeout(None, self.newVersion)
self._version = self.cache.version
self.broadcastVersion()
self._newVersion()
self.cache.warnProtocol()
if not self.NEED_RESTART.isdisjoint(changed) or \
version.protocol < self.cache.min_protocol:
crl = self.cache.crl
for i in reversed([i for i, peer in enumerate(self._peers)
if peer.serial in crl]):
del self._peers[i]
if self.cert.cert.get_serial_number() in crl:
raise utils.ReexecException("Our certificate has just been revoked."
" Let's try to renew it.")
if (not self.NEED_RESTART.isdisjoint(changed)
or version.protocol < self.cache.min_protocol
# TODO: With --management, we could kill clients without restarting.
or not all(crl.isdisjoint(serials.itervalues())
for serials in self._served.itervalues())):
# Wait at least 1 second to broadcast new version to neighbours.
# If re6stnet is too old, don't abort now, because a new version
# may have been installed without restart.
self.selectTimeout(time.time() + 1 + self.cache.delay_restart,
self._restart)
def handleServerEvent(self, sock):
event, args = eval(sock.recv(65536))
logging.debug("%s%r", event, args)
r = getattr(self, '_ovpn_' + event.replace('-', '_'))(*args)
if r is not None:
sock.send(chr(r))
def _ovpn_client_connect(self, common_name, iface, serial, trusted_ip):
if serial in self.cache.crl:
return False
prefix = utils.binFromSubnet(common_name)
self._served[prefix][iface] = serial
if isinstance(self, TunnelManager): # XXX
if self._gateway_manager is not None:
self._gateway_manager.add(trusted_ip, False)
if prefix in self._connection_dict and self._prefix < prefix:
self._kill(prefix)
self.cache.connecting(prefix, 0)
return True
def _ovpn_client_disconnect(self, common_name, iface, serial, trusted_ip):
prefix = utils.binFromSubnet(common_name)
serials = self._served.get(prefix)
try:
del serials[iface]
except (KeyError, TypeError):
logging.exception("ovpn_client_disconnect%r",
(common_name, iface, serial, trusted_ip))
return
if not serials:
del self._served[prefix]
if isinstance(self, TunnelManager): # XXX
self._abortTunnelKiller(prefix, iface)
if self._gateway_manager is not None:
self._gateway_manager.remove(trusted_ip)
class TunnelManager(BaseTunnelManager):
NEED_RESTART = BaseTunnelManager.NEED_RESTART.union((
'client_count', 'max_clients', 'tunnel_refresh'))
def __init__(self, control_socket, cache, cert, cert_renew, openvpn_args,
def __init__(self, control_socket, cache, cert, openvpn_args,
timeout, client_count, iface_list, address, ip_changed,
remote_gateway, disable_proto, neighbour_list=()):
super(TunnelManager, self).__init__(cache, cert, cert_renew, address)
super(TunnelManager, self).__init__(cache, cert, address)
self.ctl = ctl.Babel(control_socket, weakref.proxy(self), self._network)
self.ovpn_args = openvpn_args
self.timeout = timeout
# Create and open read_only pipe to get server events
r, self.write_pipe = os.pipe()
self._read_pipe = os.fdopen(r)
self._read_sock, self.write_sock = socket.socketpair(
socket.AF_UNIX, socket.SOCK_DGRAM)
self._disconnected = 0
self._distant_peers = []
self._iface_to_prefix = {}
......@@ -497,7 +551,7 @@ class TunnelManager(BaseTunnelManager):
def select(self, r, w, t):
super(TunnelManager, self).select(r, w, t)
r[self._read_pipe] = self.handleTunnelEvent
r[self._read_sock] = self.handleClientEvent
if self._next_refresh:
t.append((self._next_refresh, self.refresh))
self.ctl.select(r, w, t)
......@@ -572,11 +626,13 @@ class TunnelManager(BaseTunnelManager):
prefix = min(peer_set, key=self._tunnelScore)
self._killing[prefix] = TunnelKiller(prefix, self, True)
def _abortTunnelKiller(self, prefix):
def _abortTunnelKiller(self, prefix, iface=None):
tunnel_killer = self._killing.get(prefix)
if tunnel_killer:
if tunnel_killer.state:
tunnel_killer.abort()
if not iface or \
iface == self.ctl.interfaces[tunnel_killer.ifindex]:
tunnel_killer.abort()
else:
del self._killing[prefix]
......@@ -719,42 +775,15 @@ class TunnelManager(BaseTunnelManager):
for prefix in self._connection_dict.keys():
self._kill(prefix)
def handleTunnelEvent(self):
try:
msg = self._read_pipe.readline().rstrip()
args = msg.split()
m = getattr(self, '_ovpn_' + args.pop(0).replace('-', '_'))
except (AttributeError, ValueError):
logging.warning("Unknown message received from OpenVPN: %s", msg)
else:
logging.debug(msg)
m(*args)
def _ovpn_client_connect(self, common_name, trusted_ip):
prefix = utils.binFromSubnet(common_name)
self._served.add(prefix)
if self._gateway_manager is not None:
self._gateway_manager.add(trusted_ip, False)
if prefix in self._connection_dict and self._prefix < prefix:
self._kill(prefix)
self.cache.connecting(prefix, 0)
def _ovpn_client_disconnect(self, common_name, trusted_ip):
prefix = utils.binFromSubnet(common_name)
try:
self._served.remove(prefix)
except KeyError:
return
self._abortTunnelKiller(prefix)
if self._gateway_manager is not None:
self._gateway_manager.remove(trusted_ip)
def _ovpn_route_up(self, common_name, time, ip):
def handleClientEvent(self):
msg = self._read_sock.recv(65536)
logging.debug("route_up%s", msg)
common_name, time, serial, ip = eval(msg)
prefix = utils.binFromSubnet(common_name)
c = self._connection_dict.get(prefix)
if c and c.time < float(time):
try:
c.connected()
c.connected(serial)
except (KeyError, TypeError), e:
logging.error("%s (route_up %s)", e, common_name)
else:
......@@ -765,7 +794,7 @@ class TunnelManager(BaseTunnelManager):
if address:
self._address[family] = utils.dump_address(address)
def broadcastVersion(self):
def _newVersion(self):
for prefix in self.ctl.neighbours:
if prefix:
peer = self._getPeer(prefix)
......@@ -774,3 +803,6 @@ class TunnelManager(BaseTunnelManager):
elif (peer.version < self._version and
self.sendto(prefix, '\0' + self._version)):
peer.version = self._version
for prefix, c in self._connection_dict.items():
if c.serial in self.cache.crl:
self._kill(prefix)
......@@ -42,10 +42,12 @@ def encrypt(cert, data):
def fingerprint(cert, alg='sha1'):
return hashlib.new(alg, crypto.dump_certificate(crypto.FILETYPE_ASN1, cert))
def maybe_renew(path, cert, info, renew):
def maybe_renew(path, cert, info, renew, force=False):
from .registry import RENEW_PERIOD
while True:
if cert.get_serial_number():
if force:
force = False
else:
next_renew = notAfter(cert) - RENEW_PERIOD
if time.time() < next_renew:
return cert, next_renew
......@@ -110,11 +112,10 @@ class Cert(object):
'--cert', self.cert_path,
'--key', self.key_path)
def maybeRenew(self, registry):
from .registry import RegistryClient
registry = RegistryClient(registry, self)
def maybeRenew(self, registry, crl):
self.cert, next_renew = maybe_renew(self.cert_path, self.cert,
"Certificate", lambda: registry.renewCertificate(self.prefix))
"Certificate", lambda: registry.renewCertificate(self.prefix),
self.cert.get_serial_number() in crl)
self.ca, ca_renew = maybe_renew(self.ca_path, self.ca,
"CA Certificate", registry.getCa)
return min(next_renew, ca_renew)
......@@ -181,6 +182,7 @@ class Peer(object):
"""
_hello = _last = 0
_key = newHmacSecret()
serial = None
stop_date = float('inf')
version = ''
......
......@@ -2,6 +2,7 @@
import atexit, errno, logging, os, shutil, signal
import socket, subprocess, sys, time, threading
from collections import deque
from functools import partial
from re6st import plib, tunnel, utils, version, x509
from re6st.cache import Cache
from re6st.utils import exit, ReexecException
......@@ -130,7 +131,6 @@ def main():
exit.signal(0, signal.SIGINT, signal.SIGTERM)
exit.signal(-1, signal.SIGHUP, signal.SIGUSR2)
next_renew = cert.maybeRenew(config.registry)
cache = Cache(db_path, config.registry, cert)
network = cert.network
......@@ -249,14 +249,12 @@ def main():
control_socket = os.path.join(config.run, 'babeld.sock')
if config.client_count and not config.client:
tunnel_manager = tunnel.TunnelManager(control_socket,
cache, cert, next_renew, config.openvpn_args, timeout,
cache, cert, config.openvpn_args, timeout,
config.client_count, config.iface_list, address, ip_changed,
remote_gateway, config.disable_proto, config.neighbour)
tunnel_interfaces += tunnel_manager.new_iface_list
write_pipe = tunnel_manager.write_pipe
else:
write_pipe = None
tunnel_manager = tunnel.BaseTunnelManager(cache, cert, next_renew)
tunnel_manager = tunnel.BaseTunnelManager(cache, cert)
cleanup.append(tunnel_manager.sock.close)
try:
......@@ -275,6 +273,7 @@ def main():
# an public IP so Babel must be changed to set a source
# address on routes it installs.
ip('addrlabel', 'prefix', my_network, 'label', '99')
R = {}
# prepare persistent interfaces
if config.client:
address_list = [x for x in utils.parse_address(config.client)
......@@ -288,9 +287,13 @@ def main():
elif server_tunnels:
required('dh')
for iface, (port, proto) in server_tunnels.iteritems():
r, x = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)
cleanup.append(plib.server(iface, config.max_clients,
config.dh, write_pipe, port, proto, cache.encrypt,
'--ping-exit', str(timeout), *config.openvpn_args).stop)
config.dh, x.fileno(), port, proto, cache.encrypt,
'--ping-exit', str(timeout), *config.openvpn_args,
preexec_fn=r.close).stop)
R[r] = partial(tunnel_manager.handleServerEvent, r)
x.close()
ip('addr', my_ip, 'dev', config.main_interface)
if_rt = ['ip', '-6', 'route', 'del',
......@@ -371,7 +374,7 @@ def main():
select_list = [forwarder.select] if forwarder else []
select_list += tunnel_manager.select, utils.select
while True:
args = {}, {}, []
args = R.copy(), {}, []
for s in select_list:
s(*args)
finally:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment