Commit 7cc4a19a authored by Julien Muchembled's avatar Julien Muchembled

Try not to break server connections in multi-gateway mode

parent 00fbbfe9
......@@ -10,4 +10,5 @@ if os.environ['script_type'] == 'client-connect':
# Write into pipe connect/disconnect events
arg1 = sys.argv[1]
if arg1 != 'None':
os.write(int(arg1), '%(script_type)s %(common_name)s\n' % os.environ)
os.write(int(arg1), '%(script_type)s %(common_name)s %(trusted_ip)s\n'
% os.environ)
......@@ -12,37 +12,32 @@ RTF_CACHE = 0x01000000 # cache entry
class MultiGatewayManager(dict):
def __init__(self, gateway):
if gateway:
self._gw = gateway
else:
self.add = self.remove = lambda _: None
self._gw = gateway
def _route(self, cmd, dest, gw):
cmd = 'ip', '-4', 'route', cmd, '%s/32' % dest, 'via', gw
logging.trace('%r', cmd)
subprocess.call(cmd)
if gw:
cmd = 'ip', '-4', 'route', cmd, '%s/32' % dest, 'via', gw
logging.trace('%r', cmd)
subprocess.call(cmd)
def add(self, ip_list):
for dest in ip_list:
def add(self, dest, route):
try:
self[dest][1] += 1
except KeyError:
gw = self._gw(dest) if route else None
self[dest] = [gw, 0]
self._route('add', dest, gw)
def remove(self, dest):
gw, count = self[dest]
if count:
self[dest][1] = count - 1
else:
del self[dest]
try:
self[dest][1] += 1
except KeyError:
gw = self._gw(dest)
self[dest] = [gw, 0]
self._route('add', dest, gw)
def remove(self, ip_list):
for dest in ip_list:
gw, count = self[dest]
if count:
self[dest][1] = count - 1
else:
del self[dest]
try:
self._route('del', dest, gw)
except:
pass
self._route('del', dest, gw)
except:
pass
class Connection(object):
......@@ -104,7 +99,8 @@ class TunnelManager(object):
self._address = utils.dump_address(address)
self._ip_changed = ip_changed
self._encrypt = encrypt
self._gateway_manager = MultiGatewayManager(remote_gateway)
self._gateway_manager = MultiGatewayManager(remote_gateway) \
if remote_gateway else None
self._served = set()
self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
......@@ -185,7 +181,9 @@ class TunnelManager(object):
connection = self._connection_dict.pop(prefix)
self.freeInterface(connection.iface)
connection.close()
self._gateway_manager.remove(connection)
if self._gateway_manager is not None:
for ip in connection:
self._gateway_manager.remove(ip)
logging.trace('Connection with %u/%u killed',
int(prefix, 2), len(prefix))
......@@ -198,7 +196,9 @@ class TunnelManager(object):
int(prefix, 2), len(prefix))
iface = self.getFreeInterface(prefix)
self._connection_dict[prefix] = c = Connection(address, iface, prefix)
self._gateway_manager.add(c)
if self._gateway_manager is not None:
for ip in c:
self._gateway_manager.add(ip, True)
c.open(self._write_pipe, self._timeout, self._encrypt, self._ovpn_args)
self._peer_db.connecting(prefix, 1)
return True
......@@ -339,16 +339,23 @@ class TunnelManager(object):
logging.debug(msg)
m(*args)
def _ovpn_client_connect(self, common_name):
def _ovpn_client_connect(self, common_name, trusted_ip):
prefix = utils.binFromSubnet(common_name)
self._served.add(prefix)
if self._gateway_manager is not None:
self._gateway_manager.add(trusted_ip, False)
if prefix in self._connection_dict and self._prefix < prefix:
self._kill(prefix)
self._peer_db.connecting(prefix, 0)
def _ovpn_client_disconnect(self, common_name):
def _ovpn_client_disconnect(self, common_name, trusted_ip):
prefix = utils.binFromSubnet(common_name)
self._served.discard(prefix)
try:
self._served.remove(prefix)
except KeyError:
return
if self._gateway_manager is not None:
self._gateway_manager.remove(trusted_ip)
def _ovpn_route_up(self, common_name, ip):
self._peer_db.connecting(utils.binFromSubnet(common_name), 0)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment