Commit 7cc4a19a authored by Julien Muchembled's avatar Julien Muchembled

Try not to break server connections in multi-gateway mode

parent 00fbbfe9
...@@ -10,4 +10,5 @@ if os.environ['script_type'] == 'client-connect': ...@@ -10,4 +10,5 @@ if os.environ['script_type'] == 'client-connect':
# Write into pipe connect/disconnect events # Write into pipe connect/disconnect events
arg1 = sys.argv[1] arg1 = sys.argv[1]
if arg1 != 'None': if arg1 != 'None':
os.write(int(arg1), '%(script_type)s %(common_name)s\n' % os.environ) os.write(int(arg1), '%(script_type)s %(common_name)s %(trusted_ip)s\n'
% os.environ)
...@@ -12,37 +12,32 @@ RTF_CACHE = 0x01000000 # cache entry ...@@ -12,37 +12,32 @@ RTF_CACHE = 0x01000000 # cache entry
class MultiGatewayManager(dict): class MultiGatewayManager(dict):
def __init__(self, gateway): def __init__(self, gateway):
if gateway: self._gw = gateway
self._gw = gateway
else:
self.add = self.remove = lambda _: None
def _route(self, cmd, dest, gw): def _route(self, cmd, dest, gw):
cmd = 'ip', '-4', 'route', cmd, '%s/32' % dest, 'via', gw if gw:
logging.trace('%r', cmd) cmd = 'ip', '-4', 'route', cmd, '%s/32' % dest, 'via', gw
subprocess.call(cmd) logging.trace('%r', cmd)
subprocess.call(cmd)
def add(self, ip_list): def add(self, dest, route):
for dest in ip_list: try:
self[dest][1] += 1
except KeyError:
gw = self._gw(dest) if route else None
self[dest] = [gw, 0]
self._route('add', dest, gw)
def remove(self, dest):
gw, count = self[dest]
if count:
self[dest][1] = count - 1
else:
del self[dest]
try: try:
self[dest][1] += 1 self._route('del', dest, gw)
except KeyError: except:
gw = self._gw(dest) pass
self[dest] = [gw, 0]
self._route('add', dest, gw)
def remove(self, ip_list):
for dest in ip_list:
gw, count = self[dest]
if count:
self[dest][1] = count - 1
else:
del self[dest]
try:
self._route('del', dest, gw)
except:
pass
class Connection(object): class Connection(object):
...@@ -104,7 +99,8 @@ class TunnelManager(object): ...@@ -104,7 +99,8 @@ class TunnelManager(object):
self._address = utils.dump_address(address) self._address = utils.dump_address(address)
self._ip_changed = ip_changed self._ip_changed = ip_changed
self._encrypt = encrypt self._encrypt = encrypt
self._gateway_manager = MultiGatewayManager(remote_gateway) self._gateway_manager = MultiGatewayManager(remote_gateway) \
if remote_gateway else None
self._served = set() self._served = set()
self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
...@@ -185,7 +181,9 @@ class TunnelManager(object): ...@@ -185,7 +181,9 @@ class TunnelManager(object):
connection = self._connection_dict.pop(prefix) connection = self._connection_dict.pop(prefix)
self.freeInterface(connection.iface) self.freeInterface(connection.iface)
connection.close() connection.close()
self._gateway_manager.remove(connection) if self._gateway_manager is not None:
for ip in connection:
self._gateway_manager.remove(ip)
logging.trace('Connection with %u/%u killed', logging.trace('Connection with %u/%u killed',
int(prefix, 2), len(prefix)) int(prefix, 2), len(prefix))
...@@ -198,7 +196,9 @@ class TunnelManager(object): ...@@ -198,7 +196,9 @@ class TunnelManager(object):
int(prefix, 2), len(prefix)) int(prefix, 2), len(prefix))
iface = self.getFreeInterface(prefix) iface = self.getFreeInterface(prefix)
self._connection_dict[prefix] = c = Connection(address, iface, prefix) self._connection_dict[prefix] = c = Connection(address, iface, prefix)
self._gateway_manager.add(c) if self._gateway_manager is not None:
for ip in c:
self._gateway_manager.add(ip, True)
c.open(self._write_pipe, self._timeout, self._encrypt, self._ovpn_args) c.open(self._write_pipe, self._timeout, self._encrypt, self._ovpn_args)
self._peer_db.connecting(prefix, 1) self._peer_db.connecting(prefix, 1)
return True return True
...@@ -339,16 +339,23 @@ class TunnelManager(object): ...@@ -339,16 +339,23 @@ class TunnelManager(object):
logging.debug(msg) logging.debug(msg)
m(*args) m(*args)
def _ovpn_client_connect(self, common_name): def _ovpn_client_connect(self, common_name, trusted_ip):
prefix = utils.binFromSubnet(common_name) prefix = utils.binFromSubnet(common_name)
self._served.add(prefix) self._served.add(prefix)
if self._gateway_manager is not None:
self._gateway_manager.add(trusted_ip, False)
if prefix in self._connection_dict and self._prefix < prefix: if prefix in self._connection_dict and self._prefix < prefix:
self._kill(prefix) self._kill(prefix)
self._peer_db.connecting(prefix, 0) self._peer_db.connecting(prefix, 0)
def _ovpn_client_disconnect(self, common_name): def _ovpn_client_disconnect(self, common_name, trusted_ip):
prefix = utils.binFromSubnet(common_name) prefix = utils.binFromSubnet(common_name)
self._served.discard(prefix) try:
self._served.remove(prefix)
except KeyError:
return
if self._gateway_manager is not None:
self._gateway_manager.remove(trusted_ip)
def _ovpn_route_up(self, common_name, ip): def _ovpn_route_up(self, common_name, ip):
self._peer_db.connecting(utils.binFromSubnet(common_name), 0) self._peer_db.connecting(utils.binFromSubnet(common_name), 0)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment