Commit 1c70c3ed authored by Julien Muchembled's avatar Julien Muchembled

Bugfixes, cleanup and improvements

parent 43b5e5e6
...@@ -136,7 +136,8 @@ gateway1.screen('miniupnpd -d -f miniupnpd.conf -P miniupnpd.pid -a 10.1.1.1' ...@@ -136,7 +136,8 @@ gateway1.screen('miniupnpd -d -f miniupnpd.conf -P miniupnpd.pid -a 10.1.1.1'
' -i %s' % g1_if_0_name) ' -i %s' % g1_if_0_name)
if 1: if 1:
registry.screen('../re6stnet @registry/re6stnet.conf --ip 10.0.0.2 -v%u' % VERBOSE, registry.screen('../re6stnet @registry/re6stnet.conf --ip 10.0.0.2 -v%u' % VERBOSE,
'../re6st-registry @registry/re6st-registry.conf') '../re6st-registry @registry/re6st-registry.conf -v%u'
' --mailhost %s' % (VERBOSE, os.path.abspath('mbox')))
machine1.screen('../re6stnet @m1/re6stnet.conf -v%u' % VERBOSE) machine1.screen('../re6stnet @m1/re6stnet.conf -v%u' % VERBOSE)
machine2.screen('../re6stnet @m2/re6stnet.conf -v%u' % VERBOSE) machine2.screen('../re6stnet @m2/re6stnet.conf -v%u' % VERBOSE)
machine3.screen('../re6stnet @m3/re6stnet.conf -v%u -i%s' % (VERBOSE, m3_if_0.name)) machine3.screen('../re6stnet @m3/re6stnet.conf -v%u -i%s' % (VERBOSE, m3_if_0.name))
......
db registry/registry.db db registry/registry.db
ca ca.crt ca ca.crt
key registry/ca.key key registry/ca.key
mailhost localhost
private 2001:db8:42:8::1 private 2001:db8:42:8::1
logfile registry/registry.log
...@@ -25,14 +25,11 @@ in re6stnet. ...@@ -25,14 +25,11 @@ in re6stnet.
USAGE USAGE
===== =====
re6st-conf requires data about a distant server running re6st-registry. re6st-conf needs address of node running re6st-registry.
--server address --registry address
Ip address of the machine running the re6stnet server. Both ipv4 Public HTTP URL of the registry, which is used for bootstrapping
and ipv6 addresses are supported. and delivering certificates.
--port port
Port to connect to on the machine running the re6stnet server.
Commands Commands
-------- --------
......
#!/usr/bin/env python #!/usr/bin/env python
import argparse, os, subprocess, sqlite3, sys, xmlrpclib import argparse, os, subprocess, sqlite3, sys, xmlrpclib
from OpenSSL import crypto from OpenSSL import crypto
from re6st import utils
def create(path, text, mode=0666):
fd = os.open(path, os.O_CREAT | os.O_WRONLY | os.O_EXCL, mode)
try:
os.write(fd, text)
finally:
os.close(fd)
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
...@@ -8,10 +16,8 @@ def main(): ...@@ -8,10 +16,8 @@ def main():
_ = parser.add_argument _ = parser.add_argument
_('--ca-only', action='store_true', _('--ca-only', action='store_true',
help='To only get CA form server') help='To only get CA form server')
_('--server', required=True, _('--registry', required=True,
help='Address of the server delivering certifiactes') help='HTTP URL of the server delivering certificates')
_('--port', required=True, type=int,
help='Port to which connect on the server')
_('-d', '--dir', default='/etc/re6stnet', _('-d', '--dir', default='/etc/re6stnet',
help='Directory where the key and certificate will be stored') help='Directory where the key and certificate will be stored')
_('-r', '--req', nargs=2, action='append', _('-r', '--req', nargs=2, action='append',
...@@ -19,20 +25,22 @@ def main(): ...@@ -19,20 +25,22 @@ def main():
_('--email', help='Your email address') _('--email', help='Your email address')
_('--token', help='The token you received') _('--token', help='The token you received')
config = parser.parse_args() config = parser.parse_args()
ca_path = os.path.join(config.dir, 'ca.pem') if config.dir:
cert_path = os.path.join(config.dir, 'cert.crt') os.chdir(config.dir)
key_path = os.path.join(config.dir, 'cert.key') ca_path = 'ca.crt'
cert_path = 'cert.crt'
key_path = 'cert.key'
dh_path = 'dh2048.pem'
# Establish connection with server # Establish connection with server
s = xmlrpclib.ServerProxy('http://%s:%u' % (config.server, config.port)) s = xmlrpclib.ServerProxy(config.registry)
# Get CA # Get CA
ca = s.getCa() ca = s.getCa()
with open(ca_path, 'w') as f: create(ca_path, ca)
f.write(ca)
if config.ca_only: if config.ca_only:
sys.exit(0) sys.exit()
# Get token # Get token
if not config.token: if not config.token:
...@@ -59,21 +67,21 @@ def main(): ...@@ -59,21 +67,21 @@ def main():
cert = s.requestCertificate(config.token, req) cert = s.requestCertificate(config.token, req)
# Store cert and key # Store cert and key
with open(key_path, 'w') as f: create(key_path, key, 0600)
f.write(key) create(cert_path, cert)
with open(cert_path, 'w') as f:
f.write(cert)
# Generating dh file # Generating dh file
if not os.access(os.path.join(config.dir, 'dh2048.pem'), os.F_OK): if not os.access(dh_path, os.F_OK):
subprocess.call(['openssl', 'dhparam', '-out', os.path.join(config.dir, 'dh2048.pem'), '2048']) r = subprocess.call(('openssl', 'dhparam', '-out', dh_path, '2048'))
if r:
sys.exit(r)
print "Certificate setup complete." print "Certificate setup complete."
network = utils.networkFromCa(ca_path) cn = utils.subnetFromCert(cert_path)
internal_ip, prefix = utils.ipFromCert(network, cert_path) subnet = utils.networkFromCa(ca_path) + utils.binFromSubnet(cn)
print "Your re6st ip : %s" % internal_ip print "Your subnet: %s/%u (CN=%s)" \
print "Your prefix : %s" % prefix % (utils.ipFromBin(subnet), len(subnet), cn)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
#!/usr/bin/env python #!/usr/bin/env python
import random, select, smtplib, sqlite3, string, socket import errno, logging, mailbox, os, random, select
import subprocess, time, threading, traceback, errno, logging, os, xmlrpclib import smtplib, socket, sqlite3, string, subprocess, sys
import threading, time, traceback, xmlrpclib
from collections import deque from collections import deque
from SimpleXMLRPCServer import SimpleXMLRPCServer, SimpleXMLRPCRequestHandler from SimpleXMLRPCServer import SimpleXMLRPCServer, SimpleXMLRPCRequestHandler
from email.mime.text import MIMEText from email.mime.text import MIMEText
...@@ -46,7 +47,6 @@ class main(object): ...@@ -46,7 +47,6 @@ class main(object):
self.refresh_interval = 600 self.refresh_interval = 600
self.last_refresh = time.time() self.last_refresh = time.time()
utils.setupLog(3)
# Command line parsing # Command line parsing
parser = utils.ArgParser(fromfile_prefix_chars='@', parser = utils.ArgParser(fromfile_prefix_chars='@',
...@@ -60,11 +60,20 @@ class main(object): ...@@ -60,11 +60,20 @@ class main(object):
_('--key', required=True, _('--key', required=True,
help='Path to certificate key') help='Path to certificate key')
_('--mailhost', required=True, _('--mailhost', required=True,
help='SMTP server mail host') help='SMTP server mail host; for debugging purpose, it can also'
' be an absolute or existing path to a mailbox file')
_('--private', _('--private',
help='VPN IP of the node on which runs the registry') help='VPN IP of the node on which runs the registry')
_('--prefix-length', default=16,
help='Default length of allocated prefixes')
_('-l', '--logfile', default='/var/log/re6stnet/registry.log',
help='Path to logging file')
_('-v', '--verbose', default=1, type=int,
help='Log level')
self.config = parser.parse_args() self.config = parser.parse_args()
utils.setupLog(self.config.verbose, self.config.logfile)
if self.config.private: if self.config.private:
self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
else: else:
...@@ -97,7 +106,9 @@ class main(object): ...@@ -97,7 +106,9 @@ class main(object):
self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read()) self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read())
# Get vpn network prefix # Get vpn network prefix
self.network = bin(self.ca.get_serial_number())[3:] self.network = bin(self.ca.get_serial_number())[3:]
logging.info("Network prefix : %s/%u" % (self.network, len(self.network))) logging.info("Network: %s/%u", utils.ipFromBin(self.network),
len(self.network))
self._email = self.ca.get_subject().emailAddress
# Starting server # Starting server
server4 = SimpleXMLRPCServer4(('0.0.0.0', self.config.port), requestHandler=RequestHandler, allow_none=True) server4 = SimpleXMLRPCServer4(('0.0.0.0', self.config.port), requestHandler=RequestHandler, allow_none=True)
...@@ -108,8 +119,8 @@ class main(object): ...@@ -108,8 +119,8 @@ class main(object):
# Main loop # Main loop
while True: while True:
try: try:
r, w, e = select.select([server4, server6], [], []) r = select.select([server4, server6], [], [])[0]
except (OSError, select.error) as e: except select.error as e:
if e.args[0] != errno.EINTR: if e.args[0] != errno.EINTR:
raise raise
else: else:
...@@ -120,22 +131,32 @@ class main(object): ...@@ -120,22 +131,32 @@ class main(object):
while True: while True:
# Generating token # Generating token
token = ''.join(random.sample(string.ascii_lowercase, 8)) token = ''.join(random.sample(string.ascii_lowercase, 8))
args = token, email, self.config.prefix_length, int(time.time())
# Updating database # Updating database
try: try:
self.db.execute("INSERT INTO token VALUES (?,?,?,?)", (token, email, 16, int(time.time()))) self.db.execute("INSERT INTO token VALUES (?,?,?,?)", args)
break break
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
pass pass
# Creating and sending email # Creating and sending email
s = smtplib.SMTP(self.config.mailhost) msg = MIMEText('Hello, your token to join re6st network is: %s\n'
me = 'postmaster@re6st.net' % token)
msg = MIMEText('Hello world !\nYour token : %s' % (token,)) # XXX
msg['Subject'] = '[re6stnet] Token Request' msg['Subject'] = '[re6stnet] Token Request'
msg['From'] = me if self._email:
msg['From'] = self._email
msg['To'] = email msg['To'] = email
s.sendmail(me, email, msg.as_string()) if os.path.isabs(self.config.mailhost) or \
s.quit() os.path.isfile(self.config.mailhost):
m = mailbox.mbox(self.config.mailhost)
try:
m.add(msg)
finally:
m.close()
else:
s = smtplib.SMTP(self.config.mailhost)
s.sendmail(self._email, email, msg.as_string())
s.quit()
def _getPrefix(self, prefix_len): def _getPrefix(self, prefix_len):
max_len = 128 - len(self.network) max_len = 128 - len(self.network)
...@@ -144,7 +165,7 @@ class main(object): ...@@ -144,7 +165,7 @@ class main(object):
prefix, = self.db.execute("""SELECT prefix FROM cert WHERE length(prefix) <= ? AND cert is null prefix, = self.db.execute("""SELECT prefix FROM cert WHERE length(prefix) <= ? AND cert is null
ORDER BY length(prefix) DESC""", (prefix_len,)).next() ORDER BY length(prefix) DESC""", (prefix_len,)).next()
except StopIteration: except StopIteration:
logging.error('There are no more free /%s prefix available' % (prefix_len,)) logging.error('No more free /%u prefix available', prefix_len)
raise raise
while len(prefix) < prefix_len: while len(prefix) < prefix_len:
self.db.execute("UPDATE cert SET prefix = ? WHERE prefix = ?", (prefix + '1', prefix)) self.db.execute("UPDATE cert SET prefix = ? WHERE prefix = ?", (prefix + '1', prefix))
...@@ -162,7 +183,7 @@ class main(object): ...@@ -162,7 +183,7 @@ class main(object):
try: try:
token, email, prefix_len, _ = self.db.execute("SELECT * FROM token WHERE token = ?", (token,)).next() token, email, prefix_len, _ = self.db.execute("SELECT * FROM token WHERE token = ?", (token,)).next()
except StopIteration: except StopIteration:
logging.exception('Bad token (%s) in request' % (token,)) logging.exception('Bad token (%s) in request', token)
raise raise
self.db.execute("DELETE FROM token WHERE token = ?", (token,)) self.db.execute("DELETE FROM token WHERE token = ?", (token,))
...@@ -186,8 +207,9 @@ class main(object): ...@@ -186,8 +207,9 @@ class main(object):
self.db.execute("UPDATE cert SET email = ?, cert = ? WHERE prefix = ?", (email, cert, prefix)) self.db.execute("UPDATE cert SET email = ?, cert = ? WHERE prefix = ?", (email, cert, prefix))
return cert return cert
except: except Exception:
traceback.print_exc() f = traceback.format_exception(*sys.exc_info())
logging.error('%s%s', f.pop(), ''.join(f))
raise raise
def getCa(self, handler): def getCa(self, handler):
...@@ -210,7 +232,7 @@ class main(object): ...@@ -210,7 +232,7 @@ class main(object):
except IndexError: except IndexError:
peer = '' peer = ''
if peer is None: if peer is None:
raise EnvironmentError("Timeout while querying [%s]:%u", *address) raise EnvironmentError("Timeout while querying [%s]:%u" % address)
if not peer or peer.split()[0] == client_prefix: if not peer or peer.split()[0] == client_prefix:
raise LookupError("No bootstrap peer found") raise LookupError("No bootstrap peer found")
logging.info("Sending bootstrap peer: %s", peer) logging.info("Sending bootstrap peer: %s", peer)
......
...@@ -2,57 +2,24 @@ ...@@ -2,57 +2,24 @@
import os import os
import sys import sys
# example of os.environ
{'X509_0_C': 'FR',
'X509_0_CN': 'ulm',
'X509_0_O': 'Guillaume Bury',
'X509_0_OU': 'VPN',
'X509_1_C': 'FR',
'X509_1_CN': 'Guillaume Bury CA',
'X509_1_O': 'Guillaume Bury',
'X509_1_OU': 'VPN',
'common_name': 'ulm',
'daemon': '0',
'daemon_log_redirect': '0',
'daemon_pid': '11637',
'daemon_start_time': '1341568405',
'dev': 're6stnet',
'link_mtu': '1573',
'local_port_1': '1194',
'proto_1': 'udp',
'remote_port_1': '1194',
'script_context': 'init',
'script_type': 'client-connect',
'time_ascii': 'Fri Jul 6 11:53:31 2012',
'time_unix': '1341568411',
'tls_digest_0': '2d:eb:f3:05:5d:bf:17:62:dd:ef:d4:bb:30:c0:5b:b7:ef:e3:e8:a6',
'tls_digest_1': '43:1c:a1:22:ca:c0:a0:f5:b0:c6:65:6f:33:29:b2:bb:1d:04:43:9a',
'tls_id_0': '/C=FR/O=Guillaume_Bury/OU=VPN/CN=ulm',
'tls_id_1': '/C=FR/O=Guillaume_Bury/OU=VPN/CN=Guillaume_Bury_CA',
'tls_serial_0': '02',
'tls_serial_1': 'CC3019BC1CFA5141',
'trusted_ip': '192.0.2.25',
'trusted_port': '59345',
'tun_mtu': '1500',
'untrusted_ip': '192.0.2.25',
'untrusted_port': '59345',
'verb': '3'}
script_type = os.environ['script_type'] script_type = os.environ['script_type']
if script_type == 'up': if script_type == 'up':
from subprocess import call import subprocess
def call(*args):
r = subprocess.call(args)
if r:
sys.exit(r)
dev = os.environ['dev'] dev = os.environ['dev']
call('ip', 'link', 'set', dev, 'up')
if sys.argv[1] != 'none': if sys.argv[1] != 'none':
sys.exit(call(('ip', 'link', 'set', dev, 'up')) call('ip', 'addr', 'add', sys.argv[1], 'dev', dev)
or call(('ip', 'addr', 'add', sys.argv[1], 'dev', dev)))
else:
sys.exit(call(('ip', 'link', 'set', dev, 'up')))
if script_type == 'client-connect': else:
# Send client its external ip address if script_type == 'client-connect':
with open(sys.argv[2], 'w') as f: # Send client its external ip address
f.write('push "setenv-safe external_ip %s"\n' with open(sys.argv[2], 'w') as f:
% os.environ['trusted_ip']) f.write('push "setenv-safe external_ip %s"\n'
% os.environ['trusted_ip'])
# Write into pipe connect/disconnect events # Write into pipe connect/disconnect events
os.write(int(sys.argv[1]), '%(script_type)s %(common_name)s\n' % os.environ) os.write(int(sys.argv[1]), '%(script_type)s %(common_name)s\n' % os.environ)
...@@ -17,17 +17,13 @@ def openvpn(iface, hello_interval, encrypt, *args, **kw): ...@@ -17,17 +17,13 @@ def openvpn(iface, hello_interval, encrypt, *args, **kw):
'--persist-key', '--persist-key',
'--script-security', '2', '--script-security', '2',
'--ping-exit', str(4 * hello_interval), '--ping-exit', str(4 * hello_interval),
'--log-append', os.path.join(log, '%s.log' % iface),
#'--user', 'nobody', '--group', 'nogroup', #'--user', 'nobody', '--group', 'nogroup',
] + list(args) ] + list(args)
if not encrypt: if not encrypt:
args.extend(['--cipher', 'none']) args += '--cipher', 'none'
logging.debug('%r', args) logging.debug('%r', args)
fd = os.open(os.path.join(log, '%s.log' % iface), return subprocess.Popen(args, **kw)
os.O_WRONLY | os.O_CREAT | os.O_APPEND, 0666)
try:
return subprocess.Popen(args, stdout=fd, stderr=subprocess.STDOUT, **kw)
finally:
os.close(fd)
def server(iface, server_ip, ip_length, max_clients, dh_path, pipe_fd, port, proto, hello_interval, encrypt, *args, **kw): def server(iface, server_ip, ip_length, max_clients, dh_path, pipe_fd, port, proto, hello_interval, encrypt, *args, **kw):
...@@ -58,8 +54,8 @@ def client(iface, server_address, pipe_fd, hello_interval, encrypt, *args, **kw) ...@@ -58,8 +54,8 @@ def client(iface, server_address, pipe_fd, hello_interval, encrypt, *args, **kw)
remote += '--remote', ip, port, \ remote += '--remote', ip, port, \
'tcp-client' if proto == 'tcp' else proto 'tcp-client' if proto == 'tcp' else proto
except ValueError, e: except ValueError, e:
logging.warning('Error "%s" in unpacking address %s for openvpn client' logging.warning("Failed to parse node address %r (%s)",
% (e, server_address,)) server_address, e)
remote += args remote += args
return openvpn(iface, hello_interval, encrypt, *remote, **kw) return openvpn(iface, hello_interval, encrypt, *remote, **kw)
...@@ -84,6 +80,7 @@ def router(network, subnet, subnet_size, interface_list, ...@@ -84,6 +80,7 @@ def router(network, subnet, subnet_size, interface_list,
'-d', str(verbose), '-d', str(verbose),
'-h', str(hello_interval), '-h', str(hello_interval),
'-H', str(hello_interval), '-H', str(hello_interval),
'-L', os.path.join(log, 'babeld.log'),
'-S', state_path, '-S', state_path,
'-s', '-s',
] ]
......
...@@ -13,7 +13,7 @@ class Connection: ...@@ -13,7 +13,7 @@ class Connection:
ovpn_args): ovpn_args):
self.process = plib.client(iface, address, write_pipe, hello, encrypt, self.process = plib.client(iface, address, write_pipe, hello, encrypt,
'--tls-remote', '%u/%u' % (int(prefix, 2), len(prefix)), '--tls-remote', '%u/%u' % (int(prefix, 2), len(prefix)),
'--connect-retry-max', '3', *ovpn_args) '--connect-retry-max', '3', '--tls-exit', *ovpn_args)
self.iface = iface self.iface = iface
self.routes = 0 self.routes = 0
self._prefix = prefix self._prefix = prefix
...@@ -21,8 +21,8 @@ class Connection: ...@@ -21,8 +21,8 @@ class Connection:
def refresh(self): def refresh(self):
# Check that the connection is alive # Check that the connection is alive
if self.process.poll() != None: if self.process.poll() != None:
logging.info('Connection with %s has failed with return code %s' logging.info('Connection with %s has failed with return code %s',
% (self._prefix, self.process.returncode)) self._prefix, self.process.returncode)
return False return False
return True return True
...@@ -51,6 +51,8 @@ class TunnelManager(object): ...@@ -51,6 +51,8 @@ class TunnelManager(object):
self._served = set() self._served = set()
self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
# See also http://stackoverflow.com/questions/597225/
# about binding and anycast.
self.sock.bind(('::', PORT)) self.sock.bind(('::', PORT))
self.next_refresh = time.time() self.next_refresh = time.time()
...@@ -87,8 +89,8 @@ class TunnelManager(object): ...@@ -87,8 +89,8 @@ class TunnelManager(object):
self._kill(prefix) self._kill(prefix)
def _kill(self, prefix, kill=False): def _kill(self, prefix, kill=False):
logging.info('Killing the connection with %s/%u...' logging.info('Killing the connection with %u/%u...',
% (hex(int(prefix, 2))[2:], len(prefix))) int(prefix, 2), len(prefix))
connection = self._connection_dict.pop(prefix) connection = self._connection_dict.pop(prefix)
try: try:
getattr(connection.process, 'kill' if kill else 'terminate')() getattr(connection.process, 'kill' if kill else 'terminate')()
...@@ -97,16 +99,16 @@ class TunnelManager(object): ...@@ -97,16 +99,16 @@ class TunnelManager(object):
pass pass
self.free_interface_set.add(connection.iface) self.free_interface_set.add(connection.iface)
del self._iface_to_prefix[connection.iface] del self._iface_to_prefix[connection.iface]
logging.trace('Connection with %s/%u killed' logging.trace('Connection with %u/%u killed',
% (hex(int(prefix, 2))[2:], len(prefix))) int(prefix, 2), len(prefix))
def _makeTunnel(self, prefix, address): def _makeTunnel(self, prefix, address):
assert len(self._connection_dict) < self._client_count, (prefix, self.__dict__) assert len(self._connection_dict) < self._client_count, (prefix, self.__dict__)
if prefix in self._served or prefix in self._connection_dict: if prefix in self._served or prefix in self._connection_dict:
return False return False
assert prefix != self._prefix, self.__dict__ assert prefix != self._prefix, self.__dict__
logging.info('Establishing a connection with %s/%u', logging.info('Establishing a connection with %u/%u',
hex(int(prefix, 2))[2:], len(prefix)) int(prefix, 2), len(prefix))
iface = self.free_interface_set.pop() iface = self.free_interface_set.pop()
self._connection_dict[prefix] = Connection(address, self._write_pipe, self._connection_dict[prefix] = Connection(address, self._write_pipe,
self._hello, iface, prefix, self._encrypt, self._ovpn_args) self._hello, iface, prefix, self._encrypt, self._ovpn_args)
...@@ -239,13 +241,14 @@ class TunnelManager(object): ...@@ -239,13 +241,14 @@ class TunnelManager(object):
def handleTunnelEvent(self, msg): def handleTunnelEvent(self, msg):
try: try:
script_type, arg = msg.split(None, 1) msg = msg.rstrip()
m = getattr(self, '_ovpn_' + script_type.replace('-', '_')) args = msg.split()
m = getattr(self, '_ovpn_' + args.pop(0).replace('-', '_'))
except (AttributeError, ValueError): except (AttributeError, ValueError):
logging.warning("Unknown message received from OpenVPN: %s", msg) logging.warning("Unknown message received from OpenVPN: %s", msg)
else: else:
logging.debug('%s: %s', script_type, arg) logging.debug(msg)
m(arg) m(*args)
def _ovpn_client_connect(self, common_name): def _ovpn_client_connect(self, common_name):
prefix = utils.binFromSubnet(common_name) prefix = utils.binFromSubnet(common_name)
...@@ -258,14 +261,15 @@ class TunnelManager(object): ...@@ -258,14 +261,15 @@ class TunnelManager(object):
prefix = utils.binFromSubnet(common_name) prefix = utils.binFromSubnet(common_name)
self._served.remove(prefix) self._served.remove(prefix)
def _ovpn_route_up(self, arg): def _ovpn_route_up(self, common_name, ip):
common_name, ip = arg.split()
self._peer_db.connecting(utils.binFromSubnet(common_name), 0) self._peer_db.connecting(utils.binFromSubnet(common_name), 0)
if self._ip_changed: if self._ip_changed:
self._address = utils.address_str(self._ip_changed(ip)) self._address = utils.address_str(self._ip_changed(ip))
def handlePeerEvent(self): def handlePeerEvent(self):
msg, address = self.sock.recvfrom(1<<16) msg, address = self.sock.recvfrom(1<<16)
if not utils.binFromIp(address[0]).startswith(self._network):
return
code = ord(msg[0]) code = ord(msg[0])
if code == 1: # answer if code == 1: # answer
# TODO: do not fail if message contains garbage # TODO: do not fail if message contains garbage
...@@ -300,7 +304,7 @@ class TunnelManager(object): ...@@ -300,7 +304,7 @@ class TunnelManager(object):
msg = ['\xfe%s%u/%u\n%u\n' % (msg[1:], msg = ['\xfe%s%u/%u\n%u\n' % (msg[1:],
int(self._prefix, 2), len(self._prefix), int(self._prefix, 2), len(self._prefix),
len(self._connection_dict))] len(self._connection_dict))]
msg.extend('%s/%s\n' % (int(x, 2), len(x)) msg.extend('%u/%u\n' % (int(x, 2), len(x))
for x in (self._connection_dict, self._served) for x in (self._connection_dict, self._served)
for x in x) for x in x)
try: try:
......
import argparse, time, struct, socket, logging import argparse, errno, logging, os, signal, struct, socket, time
from OpenSSL import crypto from OpenSSL import crypto
logging_levels = logging.WARNING, logging.INFO, logging.DEBUG, 5 logging_levels = logging.WARNING, logging.INFO, logging.DEBUG, 5
class FileHandler(logging.FileHandler):
def setupLog(log_level, **kw): _reopen = False
def release(self):
try:
if self._reopen:
self._reopen = False
self.close()
self._open()
finally:
self.lock.release()
# In the rare case _reopen is set just before the lock was released
if self._reopen and self.lock.acquire(0):
self.release()
def async_reopen(self, *_):
self._reopen = True
if self.lock.acquire(0):
self.release()
def setupLog(log_level, filename=None, **kw):
if log_level and filename:
makedirs(os.path.dirname(filename))
handler = FileHandler(filename)
sig = handler.async_reopen
else:
handler = logging.StreamHandler()
sig = signal.SIG_IGN
handler.setFormatter(logging.Formatter(
'%(asctime)s %(levelname)-9s %(message)s', '%d-%m-%Y %H:%M:%S'))
root = logging.getLogger()
root.addHandler(handler)
signal.signal(signal.SIGUSR1, sig)
if log_level: if log_level:
logging.basicConfig(level=logging_levels[log_level-1], root.setLevel(logging_levels[log_level-1])
format='%(asctime)s %(levelname)-9s %(message)s',
datefmt='%d-%m-%Y %H:%M:%S', **kw)
else: else:
logging.disable(logging.CRITICAL) logging.disable(logging.CRITICAL)
logging.addLevelName(5, 'TRACE') logging.addLevelName(5, 'TRACE')
...@@ -26,6 +56,13 @@ class ArgParser(argparse.ArgumentParser): ...@@ -26,6 +56,13 @@ class ArgParser(argparse.ArgumentParser):
if arg.strip(): if arg.strip():
yield arg yield arg
def makedirs(path):
try:
os.makedirs(path)
except OSError, e:
if e.errno != errno.EEXIST:
raise
def binFromIp(ip): def binFromIp(ip):
ip1, ip2 = struct.unpack('>QQ', socket.inet_pton(socket.AF_INET6, ip)) ip1, ip2 = struct.unpack('>QQ', socket.inet_pton(socket.AF_INET6, ip))
return bin(ip1)[2:].rjust(64, '0') + bin(ip2)[2:].rjust(64, '0') return bin(ip1)[2:].rjust(64, '0') + bin(ip2)[2:].rjust(64, '0')
...@@ -42,14 +79,11 @@ def networkFromCa(ca_path): ...@@ -42,14 +79,11 @@ def networkFromCa(ca_path):
ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read()) ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
return bin(ca.get_serial_number())[3:] return bin(ca.get_serial_number())[3:]
def ipFromCert(network, cert_path): def subnetFromCert(cert_path):
# Get ip from cert.crt # Get ip from cert.crt
with open(cert_path, 'r') as f: with open(cert_path, 'r') as f:
cert = crypto.load_certificate(crypto.FILETYPE_PEM, f.read()) cert = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
subject = cert.get_subject() return cert.get_subject().CN
prefix, prefix_len = subject.CN.split('/')
prefix = bin(int(prefix))[2:].rjust(int(prefix_len), '0')
return ipFromBin(network + prefix, '1'), prefix
def address_str(address): def address_str(address):
return ';'.join(map(','.join, address)) return ';'.join(map(','.join, address))
......
#!/usr/bin/env python #!/usr/bin/env python
import atexit, os, sys, select, time import argparse, atexit, errno, logging, os
import argparse, signal, subprocess, sqlite3, logging, traceback import select, signal, sqlite3, sys, time, traceback
from re6st import plib, utils, db, tunnel from re6st import plib, utils, db, tunnel
def ovpnArgs(optional_args, ca_path, cert_path, key_path): def ovpnArgs(optional_args, ca_path, cert_path, key_path):
...@@ -27,12 +27,12 @@ def getConfig(): ...@@ -27,12 +27,12 @@ def getConfig():
_('--registry', required=True, _('--registry', required=True,
help="HTTP URL of the discovery peer server," help="HTTP URL of the discovery peer server,"
" with public host (default port: 80)") " with public host (default port: 80)")
_('-l', '--log', default='/var/log', _('-l', '--log', default='/var/log/re6stnet',
help='Path to re6stnet logs directory') help='Path to re6stnet logs directory')
_('-s', '--state', default='/var/lib/re6stnet', _('-s', '--state', default='/var/lib/re6stnet',
help='Path to re6stnet state directory') help='Path to re6stnet state directory')
_('-v', '--verbose', default=1, type=int, _('-v', '--verbose', default=1, type=int,
help='Log level of re6st itself') help='Log level of re6stnet itself')
_('-i', '--interface', action='append', dest='iface_list', default=[], _('-i', '--interface', action='append', dest='iface_list', default=[],
help='Extra interface for LAN discovery') help='Extra interface for LAN discovery')
...@@ -75,18 +75,15 @@ def main(): ...@@ -75,18 +75,15 @@ def main():
# Get arguments # Get arguments
config = getConfig() config = getConfig()
network = utils.networkFromCa(config.ca) network = utils.networkFromCa(config.ca)
internal_ip, prefix = utils.ipFromCert(network, config.cert) prefix = utils.binFromSubnet(utils.subnetFromCert(config.cert))
openvpn_args = ovpnArgs(config.openvpn_args, config.ca, config.cert, openvpn_args = ovpnArgs(config.openvpn_args, config.ca, config.cert,
config.key) config.key)
db_path = os.path.join(config.state, 'peers.db')
# Set logging # Set logging
utils.setupLog(config.verbose, utils.setupLog(config.verbose, os.path.join(config.log, 're6stnet.log'))
filename=os.path.join(config.log, 're6stnet.log'))
logging.trace("Configuration :\n%s" % config)
# Set global variables logging.trace("Configuration:\n%r", config)
utils.makedirs(config.state)
db_path = os.path.join(config.state, 'peers.db')
plib.log = tunnel.log = config.log plib.log = tunnel.log = config.log
# Create and open read_only pipe to get server events # Create and open read_only pipe to get server events
...@@ -95,8 +92,8 @@ def main(): ...@@ -95,8 +92,8 @@ def main():
read_pipe = os.fdopen(r_pipe) read_pipe = os.fdopen(r_pipe)
signal.signal(signal.SIGHUP, lambda *args: sys.exit(-1)) signal.signal(signal.SIGHUP, lambda *args: sys.exit(-1))
signal.signal(signal.SIGTERM, lambda *args: sys.exit())
# Init db and tunnels
address = [] address = []
if config.pp: if config.pp:
pp = [(int(port), proto) for port, proto in config.pp] pp = [(int(port), proto) for port, proto in config.pp]
...@@ -123,32 +120,30 @@ def main(): ...@@ -123,32 +120,30 @@ def main():
if address: if address:
ip_changed = None ip_changed = None
peer_db = db.PeerDB(db_path, config.registry, config.key, prefix)
tunnel_manager = tunnel.TunnelManager(write_pipe, peer_db, openvpn_args,
config.hello, config.tunnel_refresh, config.connection_count,
config.iface_list, network, prefix, address, ip_changed,
config.encrypt)
# Launch routing protocol. WARNING : you have to be root to start babeld
server_tunnels = {}
for x in pp:
server_tunnels.setdefault('re6stnet-' + x[1], x)
interface_list = list(tunnel_manager.free_interface_set) \
+ config.iface_list + server_tunnels.keys()
subnet = utils.ipFromBin((network + prefix).ljust(128, '0'))
router = plib.router(network, subnet, len(prefix) + len(network),
interface_list, config.wireless, config.hello, config.babel_verb,
config.babel_pidfile, os.path.join(config.state, 'babeld.state'),
stdout=os.open(os.path.join(config.log, 'babeld.log'),
os.O_WRONLY | os.O_CREAT | os.O_APPEND, 0666), stderr=subprocess.STDOUT)
# main loop
try: try:
# Init db and tunnels
peer_db = db.PeerDB(db_path, config.registry, config.key, prefix)
tunnel_manager = tunnel.TunnelManager(write_pipe, peer_db, openvpn_args,
config.hello, config.tunnel_refresh, config.connection_count,
config.iface_list, network, prefix, address, ip_changed,
config.encrypt)
server_tunnels = {}
for x in pp:
server_tunnels.setdefault('re6stnet-' + x[1], x)
interface_list = list(tunnel_manager.free_interface_set) \
+ config.iface_list + server_tunnels.keys()
subnet = network + prefix
router = plib.router(network, utils.ipFromBin(subnet), len(subnet),
interface_list, config.wireless, config.hello, config.babel_verb,
config.babel_pidfile, os.path.join(config.state, 'babeld.state'))
# main loop
try: try:
server_process = [] server_process = []
for iface, (port, proto) in server_tunnels.iteritems(): for iface, (port, proto) in server_tunnels.iteritems():
server_process.append(plib.server(iface, server_process.append(plib.server(iface,
internal_ip if proto == pp[0][1] else None, utils.ipFromBin(subnet, '1') if proto == pp[0][1] else None,
len(network) + len(prefix), len(network) + len(prefix),
config.connection_count, config.dh, write_pipe, port, config.connection_count, config.dh, write_pipe, port,
proto, config.hello, config.encrypt, *openvpn_args)) proto, config.hello, config.encrypt, *openvpn_args))
...@@ -157,7 +152,12 @@ def main(): ...@@ -157,7 +152,12 @@ def main():
if forwarder: if forwarder:
next = min(next, forwarder.next_refresh) next = min(next, forwarder.next_refresh)
r = [read_pipe, tunnel_manager.sock] r = [read_pipe, tunnel_manager.sock]
r = select.select(r, [], [], max(0, next - time.time()))[0] try:
r = select.select(r, [], [], max(0, next - time.time()))[0]
except select.error as e:
if e.args[0] != errno.EINTR:
raise
continue
if read_pipe in r: if read_pipe in r:
tunnel_manager.handleTunnelEvent(read_pipe.readline()) tunnel_manager.handleTunnelEvent(read_pipe.readline())
if tunnel_manager.sock in r: if tunnel_manager.sock in r:
...@@ -167,10 +167,6 @@ def main(): ...@@ -167,10 +167,6 @@ def main():
tunnel_manager.refresh() tunnel_manager.refresh()
if forwarder and t >= forwarder.next_refresh: if forwarder and t >= forwarder.next_refresh:
forwarder.refresh() forwarder.refresh()
except Exception:
f = traceback.format_exception(*sys.exc_info())
logging.error('%s%s', f.pop(), ''.join(f))
raise
finally: finally:
router.terminate() router.terminate()
for p in server_process: for p in server_process:
...@@ -183,6 +179,7 @@ def main(): ...@@ -183,6 +179,7 @@ def main():
except: except:
pass pass
except sqlite3.Error: except sqlite3.Error:
logging.exception("Restarting with empty cache")
os.rename(db_path, db_path + '.bak') os.rename(db_path, db_path + '.bak')
try: try:
sys.exitfunc() sys.exitfunc()
...@@ -190,6 +187,10 @@ def main(): ...@@ -190,6 +187,10 @@ def main():
os.execvp(sys.argv[0], sys.argv) os.execvp(sys.argv[0], sys.argv)
except KeyboardInterrupt: except KeyboardInterrupt:
return 0 return 0
except Exception:
f = traceback.format_exception(*sys.exc_info())
logging.error('%s%s', f.pop(), ''.join(f))
sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment