Commit 648e6774 authored by Julien Muchembled's avatar Julien Muchembled

Forget peers whose certificate expires

parent a7a86341
...@@ -164,7 +164,7 @@ class BaseTunnelManager(object): ...@@ -164,7 +164,7 @@ class BaseTunnelManager(object):
_forward = None _forward = None
def __init__(self, peer_db, cert, address=()): def __init__(self, peer_db, cert, cert_renew, address=()):
self.cert = cert self.cert = cert
self._network = cert.network self._network = cert.network
self._prefix = cert.prefix self._prefix = cert.prefix
...@@ -185,11 +185,28 @@ class BaseTunnelManager(object): ...@@ -185,11 +185,28 @@ class BaseTunnelManager(object):
# about binding and anycast. # about binding and anycast.
self.sock.bind(('::', PORT)) self.sock.bind(('::', PORT))
# Initialize with a dummy peer (self) so that '_peers' is never empty. p = x509.Peer(self._prefix)
self._peers = [x509.Peer(self._prefix)] self._next_invalidated = p.stop_date = cert_renew
self._peers = [p]
def select(self, r, w, t): def select(self, r, w, t):
r[self.sock] = self.handlePeerEvent r[self.sock] = self.handlePeerEvent
t.append((self._next_invalidated, self.invalidatePeers))
def invalidatePeers(self):
next = float('inf')
now = time.time()
remove = []
for i, peer in enumerate(self._peers):
if peer.stop_date < now:
if peer.prefix == self._prefix:
raise utils.ReexecException("Restart to renew certificate")
remove.append(i)
elif peer.stop_date < next:
next = peer.stop_date
for i in reversed(remove):
del self._peers[i]
self._next_invalidated = next
def sendto(self, prefix, msg): def sendto(self, prefix, msg):
to = utils.ipFromBin(self._network + prefix), PORT to = utils.ipFromBin(self._network + prefix), PORT
...@@ -266,9 +283,10 @@ class BaseTunnelManager(object): ...@@ -266,9 +283,10 @@ class BaseTunnelManager(object):
try: try:
cert = self.cert.loadVerify(msg, cert = self.cert.loadVerify(msg,
True, crypto.FILETYPE_ASN1) True, crypto.FILETYPE_ASN1)
except x509.VerifyError, e: stop_date = x509.notAfter(cert)
except (x509.VerifyError, ValueError), e:
logging.debug('ignored invalid certificate from %r (%s)', logging.debug('ignored invalid certificate from %r (%s)',
address, e.args[2]) address, e.args[-1])
return return
p = utils.binFromSubnet(x509.subnetFromCert(cert)) p = utils.binFromSubnet(x509.subnetFromCert(cert))
if p != peer.prefix: if p != peer.prefix:
...@@ -279,6 +297,9 @@ class BaseTunnelManager(object): ...@@ -279,6 +297,9 @@ class BaseTunnelManager(object):
peer = x509.Peer(p) peer = x509.Peer(p)
insort(self._peers, peer) insort(self._peers, peer)
peer.cert = cert peer.cert = cert
peer.stop_date = stop_date
if stop_date < self._next_invalidated:
self._next_invalidated = stop_date
if seqno: if seqno:
self._sendto(to, peer.hello(self.cert)) self._sendto(to, peer.hello(self.cert))
else: else:
...@@ -335,10 +356,11 @@ class BaseTunnelManager(object): ...@@ -335,10 +356,11 @@ class BaseTunnelManager(object):
class TunnelManager(BaseTunnelManager): class TunnelManager(BaseTunnelManager):
def __init__(self, control_socket, peer_db, cert, openvpn_args, timeout, def __init__(self, control_socket, peer_db, cert, cert_renew, openvpn_args,
refresh, client_count, iface_list, address, ip_changed, timeout, refresh, client_count, iface_list, address,
encrypt, remote_gateway, disable_proto, neighbour_list=()): ip_changed, encrypt, remote_gateway, disable_proto,
super(TunnelManager, self).__init__(peer_db, cert, address) neighbour_list=()):
super(TunnelManager, self).__init__(peer_db, cert, cert_renew, address)
self.ctl = ctl.Babel(control_socket, weakref.proxy(self), self._network) self.ctl = ctl.Babel(control_socket, weakref.proxy(self), self._network)
self.encrypt = encrypt self.encrypt = encrypt
self.ovpn_args = openvpn_args self.ovpn_args = openvpn_args
...@@ -403,8 +425,8 @@ class TunnelManager(BaseTunnelManager): ...@@ -403,8 +425,8 @@ class TunnelManager(BaseTunnelManager):
del self._iface_to_prefix[iface] del self._iface_to_prefix[iface]
def select(self, r, w, t): def select(self, r, w, t):
super(TunnelManager, self).select(r, w, t)
r[self._read_pipe] = self.handleTunnelEvent r[self._read_pipe] = self.handleTunnelEvent
r[self.sock] = self.handlePeerEvent
if self._next_refresh: if self._next_refresh:
t.append((self._next_refresh, self.refresh)) t.append((self._next_refresh, self.refresh))
self.ctl.select(r, w, t) self.ctl.select(r, w, t)
......
...@@ -3,6 +3,9 @@ import socket, struct, subprocess, sys, textwrap, threading, time, traceback ...@@ -3,6 +3,9 @@ import socket, struct, subprocess, sys, textwrap, threading, time, traceback
HMAC_LEN = len(hashlib.sha1('').digest()) HMAC_LEN = len(hashlib.sha1('').digest())
class ReexecException(Exception):
pass
try: try:
subprocess.CalledProcessError(0, '', '') subprocess.CalledProcessError(0, '', '')
except TypeError: # BBB: Python < 2.7 except TypeError: # BBB: Python < 2.7
......
...@@ -173,6 +173,7 @@ class Peer(object): ...@@ -173,6 +173,7 @@ class Peer(object):
""" """
_hello = _last = 0 _hello = _last = 0
_key = newHmacSecret() _key = newHmacSecret()
stop_date = float('inf')
def __init__(self, prefix): def __init__(self, prefix):
assert len(prefix) == 16 or prefix == ('0' * 14 + '1' + '0' * 65), prefix assert len(prefix) == 16 or prefix == ('0' * 14 + '1' + '0' * 65), prefix
......
...@@ -3,10 +3,7 @@ import atexit, errno, logging, os, signal, socket ...@@ -3,10 +3,7 @@ import atexit, errno, logging, os, signal, socket
import sqlite3, subprocess, sys, time, threading import sqlite3, subprocess, sys, time, threading
from collections import deque from collections import deque
from re6st import ctl, db, plib, tunnel, utils, version, x509 from re6st import ctl, db, plib, tunnel, utils, version, x509
from re6st.utils import exit from re6st.utils import exit, ReexecException
class ReexecException(Exception):
pass
def getConfig(): def getConfig():
parser = utils.ArgParser(fromfile_prefix_chars='@', parser = utils.ArgParser(fromfile_prefix_chars='@',
...@@ -260,7 +257,7 @@ def main(): ...@@ -260,7 +257,7 @@ def main():
cleanup = [lambda: peer_db.cacheMinimize(config.client_count)] cleanup = [lambda: peer_db.cacheMinimize(config.client_count)]
if config.client_count and not config.client: if config.client_count and not config.client:
tunnel_manager = tunnel.TunnelManager(config.control_socket, tunnel_manager = tunnel.TunnelManager(config.control_socket,
peer_db, cert, config.openvpn_args, timeout, peer_db, cert, next_renew, config.openvpn_args, timeout,
config.tunnel_refresh, config.client_count, config.iface_list, config.tunnel_refresh, config.client_count, config.iface_list,
address, ip_changed, config.encrypt, remote_gateway, address, ip_changed, config.encrypt, remote_gateway,
config.disable_proto, config.neighbour) config.disable_proto, config.neighbour)
...@@ -268,7 +265,7 @@ def main(): ...@@ -268,7 +265,7 @@ def main():
write_pipe = tunnel_manager.write_pipe write_pipe = tunnel_manager.write_pipe
else: else:
write_pipe = None write_pipe = None
tunnel_manager = tunnel.BaseTunnelManager(peer_db, cert) tunnel_manager = tunnel.BaseTunnelManager(peer_db, cert, next_renew)
cleanup.append(tunnel_manager.sock.close) cleanup.append(tunnel_manager.sock.close)
try: try:
...@@ -380,12 +377,10 @@ def main(): ...@@ -380,12 +377,10 @@ def main():
# main loop # main loop
exit.release() exit.release()
def renew():
raise ReexecException("Restart to renew certificate")
select_list = [forwarder.select] if forwarder else [] select_list = [forwarder.select] if forwarder else []
select_list += tunnel_manager.select, utils.select select_list += tunnel_manager.select, utils.select
while True: while True:
args = {}, {}, [(next_renew, renew)] args = {}, {}, []
for s in select_list: for s in select_list:
s(*args) s(*args)
finally: finally:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment