#!/usr/bin/python
import math, nemu, os, signal, socket, subprocess, sys, time, weakref
from collections import defaultdict
# Name of the iptables executable run inside the gateway1 node.
IPTABLES = 'iptables'
# Name of the screen executable; every node gets its own screen session.
SCREEN = 'screen'
# Verbosity level, passed as -v<N> to re6st-registry and re6stnet.
VERBOSE = 4
# IPv4 address of the registry node (also used to build the registry URL).
REGISTRY='10.0.0.2'
# Value given to "openssl req -days" when the CA certificate is generated.
CA_DAYS = 1000

#                       registry
#                           |.2
#                           |10.0.0
#                           |.1
#        ---------------Internet----------------
#        |.1                |.1                |.1
#        |10.1.0            |10.2.0            |
#        |.2                |.2                |
#    gateway1           gateway2           s3:10.0.1
#        |.1                |.1            |.2 |.3 |.4
#    s1:10.1.1          s2:10.2.1          m6  m7  m8
#    |.2     |.3        |.2 |.3 |.4        |
#    m1      m2         m3  m4  m5         m9
#

def disable_signal_on_children(sig):
    """Make signal `sig` effective only in the process calling this function.

    A wrapper handler is installed for `sig`.  Processes forked afterwards
    inherit it; in those, os.getpid() differs from the pid recorded here and
    the signal is silently dropped.  In the original process the signal is
    forwarded to whatever disposition was in place before:

    - a Python callable is called with the usual (signum, frame) arguments;
    - signal.SIG_DFL is emulated by restoring the default disposition and
      re-sending the signal to ourselves;
    - signal.SIG_IGN (or None, i.e. a handler not installed from Python)
      means the signal is ignored.

    The two last cases matter because SIG_DFL / SIG_IGN are not callable:
    forwarding to them blindly raises TypeError in the parent.
    """
    pid = os.getpid()

    def handler(*args):
        if os.getpid() != pid:
            # Forked child: swallow the signal.
            return False
        if callable(previous):
            return previous(*args)
        if previous == signal.SIG_DFL:
            # The default action cannot be invoked directly: put it back
            # and deliver the signal again.
            signal.signal(sig, signal.SIG_DFL)
            os.kill(pid, sig)
        # SIG_IGN / None: nothing to do.
        return None

    previous = signal.signal(sig, handler)
disable_signal_on_children(signal.SIGINT)

# Wrap nemu.Node._add_interface so that every interface remembers the node
# it was added to; add_llrtr() relies on this to run commands through
# iface.node.
Node__add_interface = nemu.Node._add_interface
def _add_interface(node, iface):
    """Attach `node` to `iface` as attribute 'node', then defer to nemu.

    The attribute is written straight into the instance dictionary (so no
    attribute-setting hook of the interface class is involved) and holds
    only a weak proxy -- presumably to avoid a node <-> interface reference
    cycle.
    """
    iface.__dict__.update(node=weakref.proxy(node))
    return Node__add_interface(node, iface)
nemu.Node._add_interface = _add_interface

# create nodes
for name in """internet=I registry=R
                          gateway1=g1 machine1=1 machine2=2
                          gateway2=g2 machine3=3 machine4=4 machine5=5
                          machine6=6 machine7=7 machine8=8 machine9=9
            """.split():
    name, short = name.split('=')
    globals()[name] = node = nemu.Node()
    node.name = name
    node.short = short
    node._screen = node.Popen((SCREEN, '-DmS', name))
    node.screen = (lambda name: lambda *cmd:
        subprocess.call([SCREEN, '-r', name, '-X', 'eval'] + map(
            """screen sh -c 'set %s; "\$@"; echo "\$@"; exec $SHELL'"""
            .__mod__, cmd)))(name)

# create switch
switch1 = nemu.Switch()
switch2 = nemu.Switch()
switch3 = nemu.Switch()

#create interfaces
re_if_0, in_if_0 = nemu.P2PInterface.create_pair(registry, internet)
in_if_1, g1_if_0 = nemu.P2PInterface.create_pair(internet, gateway1)
in_if_2, g2_if_0 = nemu.P2PInterface.create_pair(internet, gateway2)
m6_if_1, m9_if_0 = nemu.P2PInterface.create_pair(machine6, machine9)

g1_if_0_name = g1_if_0.name
gateway1.Popen((IPTABLES, '-t', 'nat', '-A', 'POSTROUTING', '-o', g1_if_0_name, '-j', 'MASQUERADE')).wait()
gateway1.Popen((IPTABLES, '-t', 'nat', '-N', 'MINIUPNPD')).wait()
gateway1.Popen((IPTABLES, '-t', 'nat', '-A', 'PREROUTING', '-i', g1_if_0_name, '-j', 'MINIUPNPD')).wait()
gateway1.Popen((IPTABLES, '-N', 'MINIUPNPD')).wait()
machine9.Popen(('sysctl', 'net.ipv6.conf.%s.accept_ra=2' % m9_if_0.name)).wait()

# One switch-facing interface per node, created in the same order as the
# nodes are listed here.
in_if_3, g1_if_1, g2_if_1 = map(nemu.NodeInterface,
                                (internet, gateway1, gateway2))
m1_if_0, m2_if_0 = map(nemu.NodeInterface, (machine1, machine2))
m3_if_0, m4_if_0, m5_if_0 = map(nemu.NodeInterface,
                                (machine3, machine4, machine5))
m6_if_0, m7_if_0, m8_if_0 = map(nemu.NodeInterface,
                                (machine6, machine7, machine8))

# connect to switches
for switch, ports in (
        (switch1, (g1_if_1, m1_if_0, m2_if_0)),
        (switch2, (g2_if_1, m3_if_0, m4_if_0, m5_if_0)),
        (switch3, (in_if_3, m6_if_0, m7_if_0, m8_if_0)),
        ):
    for port in ports:
        switch.connect(port)

# setting everything up: switches first, then the point-to-point links
# towards the Internet, then every remaining interface
for device in (
        switch1, switch2, switch3,
        re_if_0, in_if_0, in_if_1, g1_if_0, in_if_2, g2_if_0,
        in_if_3, g1_if_1, g2_if_1, m1_if_0, m2_if_0, m3_if_0, m4_if_0,
        m5_if_0, m6_if_0, m6_if_1, m7_if_0, m8_if_0, m9_if_0,
        ):
    device.up = True

# Add addresses (IPv4 everywhere, plus one IPv6 address on the Internet
# side of switch3); see the diagram at the top of the file.
for assign, address, bits in (
        (re_if_0.add_v4_address, REGISTRY, 24),
        (in_if_0.add_v4_address, '10.0.0.1', 24),
        (in_if_1.add_v4_address, '10.1.0.1', 24),
        (in_if_2.add_v4_address, '10.2.0.1', 24),
        (in_if_3.add_v4_address, '10.0.1.1', 24),
        (in_if_3.add_v6_address, '2001:db8::1', 48),
        (g1_if_0.add_v4_address, '10.1.0.2', 24),
        (g1_if_1.add_v4_address, '10.1.1.1', 24),
        (g2_if_0.add_v4_address, '10.2.0.2', 24),
        (g2_if_1.add_v4_address, '10.2.1.1', 24),
        (m1_if_0.add_v4_address, '10.1.1.2', 24),
        (m2_if_0.add_v4_address, '10.1.1.3', 24),
        (m3_if_0.add_v4_address, '10.2.1.2', 24),
        (m4_if_0.add_v4_address, '10.2.1.3', 24),
        (m5_if_0.add_v4_address, '10.2.1.4', 24),
        (m6_if_0.add_v4_address, '10.0.1.2', 24),
        (m7_if_0.add_v4_address, '10.0.1.3', 24),
        (m8_if_0.add_v4_address, '10.0.1.4', 24),
        (m6_if_1.add_v4_address, '192.168.241.1', 24),
        ):
    assign(address=address, prefix_len=bits)

def add_llrtr(iface, peer, dst='default'):
    """Add a static route to `dst` via the link-local address of `peer`.

    iface -- local interface the route goes out of; the 'ip route add'
             command is run through iface.node, the back-reference set by
             the patched Node._add_interface.
    peer  -- interface at the other end of the link; the first of its
             addresses starting with 'fe80:' is used as gateway.
    dst   -- destination of the route ('default' unless specified).

    Returns the exit status of the 'ip' command, or None (without running
    anything) when `peer` has no such address.
    """
    gateway = next((entry['address'] for entry in peer.get_addresses()
                    if entry['address'].startswith('fe80:')), None)
    if gateway is None:
        return None
    return iface.node.Popen(('ip', 'route', 'add', dst, 'via', gateway,
        'proto', 'static', 'dev', iface.name)).wait()

# setup routes
add_llrtr(re_if_0, in_if_0)
add_llrtr(in_if_0, re_if_0, '2001:db8:42::/48')
registry.add_route(prefix='10.0.0.0', prefix_len=8, nexthop='10.0.0.1')
internet.add_route(prefix='10.2.0.0', prefix_len=16, nexthop='10.2.0.2')
gateway1.add_route(prefix='10.0.0.0', prefix_len=8, nexthop='10.1.0.1')
gateway2.add_route(prefix='10.0.0.0', prefix_len=8, nexthop='10.2.0.1')
for m in machine1, machine2:
    m.add_route(nexthop='10.1.1.1')
for m in machine3, machine4, machine5:
    m.add_route(prefix='10.0.0.0', prefix_len=8, nexthop='10.2.1.1')
for m in machine6, machine7, machine8:
    m.add_route(prefix='10.0.0.0', prefix_len=8, nexthop='10.0.1.1')

# Test connectivity first. Run process, hide output and check
# return code
null = file(os.devnull, "r+")
for ip in '10.1.1.2', '10.1.1.3', '10.2.1.2', '10.2.1.3':
    if machine1.Popen(('ping', '-c1', ip), stdout=null).wait():
        print 'Failed to ping %s' % ip
        break
else:
    print "Connectivity IPv4 OK!"

# Nodes on which re6stnet() was called, in call order.  node_by_ll() and
# route_svg() iterate over this list.
nodes = []
# Run miniupnpd in gateway1's screen session, giving it the names of the
# LAN-side (-a) and Internet-side (-i) interfaces of the gateway.
gateway1.screen('miniupnpd -d -f miniupnpd.conf -P miniupnpd.pid'
                ' -a %s -i %s' % (g1_if_1.name, g1_if_0_name))
# Always-true guard: apparently only there to keep the indentation level.
if 1:
    import sqlite3
    # Generate a self-signed CA certificate from registry/ca.key, unless
    # ca.crt already exists.
    os.path.exists('ca.crt') or subprocess.check_call(
        "openssl req -nodes -new -x509 -key registry/ca.key -out ca.crt"
        " -subj /CN=re6st.example.com/emailAddress=re6st@example.com"
        " -set_serial 0x120010db80042 -days %u" % CA_DAYS, shell=True)
    db_path = 'registry/registry.db'
    # Start the registry server in the registry node's screen session.
    registry.screen('./py re6st-registry @registry/re6st-registry.conf'
        ' --db %s --mailhost %s -v%u --control-socket registry/babeld.socket'
        % (db_path, os.path.abspath('mbox'), VERBOSE))
    registry_url = 'http://%s/' % REGISTRY
    # Block until something accepts TCP connections on port 80 inside the
    # registry node (retrying every 0.1s).
    registry.Popen(('python', '-c', """if 1:
        import socket, time
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        while True:
            try:
                s.connect(('localhost', 80))
                break
            except socket.error:
                time.sleep(.1)
        """)).wait()
    # Direct access to the registry database, used below to read (and
    # tweak) the tokens.  isolation_level=None puts sqlite3 in autocommit
    # mode, so the UPDATE below needs no explicit commit.
    db = sqlite3.connect(db_path, isolation_level=None)
    def re6stnet(node, folder, args='', prefix_len=None, registry=registry_url):
        """Start re6stnet on `node`, first requesting a certificate if needed.

        node       -- node to run re6stnet on; it is appended to `nodes`.
        folder     -- directory holding this node's re6stnet.conf,
                      certificate and babeld control socket.
        args       -- extra command-line arguments for re6stnet.
        prefix_len -- if true, written into the node's row of the registry
                      'token' table before the token is submitted.
        registry   -- registry URL given to re6st-conf and re6stnet.

        A certificate is requested only when <folder>/cert.crt is missing.
        """
        nodes.append(node)
        if not os.path.exists(folder + '/cert.crt'):
            # re6st-conf runs with cwd=folder: give it a temporary symlink
            # to the shared DH parameters (removed again below).
            dh_path = folder + '/dh2048.pem'
            if not os.path.exists(dh_path):
                os.symlink('../dh2048.pem', dh_path)
            email = node.name + '@example.com'
            p = node.Popen(('../py', 're6st-conf', '--registry', registry,
                '--email', email), stdin=subprocess.PIPE, cwd=folder)
            # Poll the registry database until a token shows up for this
            # email address.
            # NOTE(review): this loops forever if re6st-conf exits before a
            # token is created -- confirm whether p.poll() should be checked.
            token = None
            while not token:
                time.sleep(.1)
                token = db.execute("SELECT token FROM token WHERE email=?",
                                   (email,)).fetchone()
            if prefix_len:
                db.execute("UPDATE token SET prefix_len=? WHERE token=?",
                           (prefix_len, token[0]))
            # Hand the token to re6st-conf on its stdin and wait for it to
            # finish.
            p.communicate(str(token[0]))
            os.remove(dh_path)
            os.remove(folder + '/ca.crt')
        node.screen('./py re6stnet @%s/re6stnet.conf -v%u --registry %s'
                    ' --control-socket %s/babeld.socket'
                    ' %s' % (folder, VERBOSE, registry, folder, args))
    # The registry node also runs re6stnet, and reaches the registry server
    # through localhost.
    re6stnet(registry, 'registry', '--ip ' + REGISTRY, registry='http://localhost/')
    re6stnet(machine1, 'm1', '-I%s' % m1_if_0.name)
    re6stnet(machine2, 'm2', '--remote-gateway 10.1.1.1', prefix_len=80)
    re6stnet(machine3, 'm3', '-i%s' % m3_if_0.name)
    re6stnet(machine4, 'm4', '-i%s' % m4_if_0.name)
    re6stnet(machine5, 'm5', '-i%s' % m5_if_0.name)
    re6stnet(machine6, 'm6', '-I%s' % m6_if_1.name)
    re6stnet(machine7, 'm7')
    re6stnet(machine8, 'm8')
    db.close()

# Cache filled by node_by_ll(): maps a link-local address, or the network
# address of a 2001:db8: prefix, to (node, found on an imported interface).
_ll = {}
def node_by_ll(addr):
    """Return (node, imported) for the re6st node owning address `addr`.

    `addr` is either a link-local address (fe80::...) or the network
    address of a 2001:db8: prefix, as found in routing tables.  `imported`
    tells whether the interface carrying the address is a
    nemu.interface.ImportedNodeInterface.

    On a cache miss, the interfaces of all nodes in `nodes` are scanned
    again and the cache refreshed.  Raises KeyError if the address is still
    unknown afterwards.
    """
    if addr in _ll:
        return _ll[addr]
    for node in nodes:
        for iface in node.get_interfaces():
            imported = isinstance(iface, nemu.interface.ImportedNodeInterface)
            try:
                addresses = iface.get_addresses()
            except KeyError:
                # Give up on the remaining interfaces of this node.
                break
            for entry in addresses:
                prefix_len = entry['prefix_len']
                address = entry['address']
                if address.startswith('2001:db8:'):
                    # Keep only the network part (whole bytes), zero-padded,
                    # so that the key compares equal to a route prefix.
                    assert not prefix_len % 8
                    packed = socket.inet_pton(socket.AF_INET6, address)
                    network = packed[:prefix_len // 8].ljust(16, '\0')
                    address = socket.inet_ntop(socket.AF_INET6, network)
                elif not address.startswith('fe80::'):
                    continue
                _ll[address] = node, imported
    return _ll[addr]

def route_svg(z=4):
    """Return an SVG picture (output of 'neato -Tsvg') of the re6st routes.

    The nodes of `nodes` are placed on a circle of radius `z`.  For every
    node, each gateway used by its 2001:db8: routes gives one edge from the
    gateway to the node, labelled with the short names of the nodes reached
    through that gateway.
    """
    # via[node] maps (gateway node, imported flag), as returned by
    # node_by_ll(), to the list of nodes whose prefix is routed through it.
    via = {}
    for node in nodes:
        gateways = via[node] = defaultdict(list)
        for route in node.get_routes():
            prefix = route.prefix
            if not (prefix and prefix.startswith('2001:db8:')):
                continue
            try:
                # The gateway is resolved (and its entry created) before
                # the destination: a known gateway towards an unknown
                # prefix still yields an edge, with an empty label.
                gateways[node_by_ll(route.nexthop)].append(
                    node_by_ll(prefix)[0])
            except KeyError:
                pass
    lines = ["digraph { splines = true; edge[color=grey, labelangle=0, arrowhead=dot];"]
    count = len(nodes)
    step = 2 * math.pi / count
    drawn = set()
    for index, node in enumerate(nodes):
        angle = step * index
        lines.append('%s[pos="%s,%s!"];'
            % (node.name, z * math.cos(angle), z * math.sin(angle)))
        # Order the gateways by their distance to `node` along the circle.
        ranked = []
        for gateway, targets in via[node].iteritems():
            gap = abs(nodes.index(gateway[0]) - index)
            ranked.append((min(gap, count - gap), gateway, targets))
        ranked.sort()
        for rank, (_, (peer, imported), targets) in enumerate(ranked):
            # The gateway's own name is greyed out in the label.
            labels = [
                '<font color="grey">%s</font>' % short
                if short == peer.short else short
                for short in sorted(target.short for target in targets)]
            if (node.name, peer.name) in drawn:
                # The opposite edge was already emitted: keep this one only
                # for its label.
                style = 'penwidth=0'
            else:
                drawn.add((peer.name, node.name))
                style = 'style=solid' if imported else 'style=dashed'
            lines.append('%s -> %s [labeldistance=%u, headlabel=<%s>, %s];'
                % (peer.name, node.name, 1.5 * math.sqrt(rank) + 2,
                   ','.join(labels), style))
    lines.append('}\n')
    neato = subprocess.Popen(('neato', '-Tsvg'),
        stdin=subprocess.PIPE, stdout=subprocess.PIPE)
    return neato.communicate('\n'.join(lines))[0]

# With a command-line argument (the TCP port to listen on), serve the
# route and tunnel graphs over HTTP.
if len(sys.argv) > 1:
    import SimpleHTTPServer, SocketServer

    class Handler(SimpleHTTPServer.SimpleHTTPRequestHandler):
        """Serve /route.html and /tunnel.html, each embedding an SVG graph."""

        def do_GET(self):
            """Answer a GET request.

            /route.html  -- graph returned by route_svg()
            /tunnel.html -- graph built from the topology reported by the
                            registry (error 500 if it cannot be produced)
            /            -- redirection to route.html
            anything else is answered with a 404 error.
            """
            svg = None
            if self.path == '/route.html':
                other = 'tunnel'
                svg = route_svg()
            elif self.path == '/tunnel.html':
                other = 'route'
                # Query the topology from a helper process running inside
                # the registry node; it prints a graph in dot format.
                gv = registry.Popen(('python', '-c', r"""if 1:
                    import math
                    from re6st.registry import RegistryClient
                    g = eval(RegistryClient('http://localhost/').topology())
                    print 'digraph {'
                    a = 2 * math.pi / len(g)
                    z = 4
                    m2 = '%u/80' % (2 << 64)
                    title = lambda n: '2|80' if n == m2 else n
                    g = sorted((title(k), v) for k, v in g.iteritems())
                    for i, (n, p) in enumerate(g):
                        print '"%s"[pos="%s,%s!"%s];' % (title(n),
                            z * math.cos(a * i), z * math.sin(a * i),
                            ', style=dashed' if p is None else '')
                        for p in p or ():
                            print '"%s" -> "%s";' % (n, title(p))
                    print '}'
                """), stdout=subprocess.PIPE, cwd="..").communicate()[0]
                if gv:
                    svg = subprocess.Popen(('neato', '-Tsvg'),
                        stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                        ).communicate(gv)[0]
                if not svg:
                    self.send_error(500)
                    return
            else:
                if self.path == '/':
                    self.send_response(302)
                    self.send_header('Location', 'route.html')
                    self.end_headers()
                else:
                    self.send_error(404)
                return
            mt = 'text/html'
            # Page refreshing itself every 10 seconds, with a link to the
            # other view; everything before the <svg tag is dropped from
            # the neato output.
            body = """<html>
<head><meta http-equiv="refresh" content="10"/></head>
<body><a style="position: absolute" href="%s.html">%ss</a>
%s
</body>
</html>""" % (other, other, svg[svg.find('<svg'):])
            self.send_response(200)
            self.send_header('Content-Length', len(body))
            self.send_header('Content-type', mt + '; charset=utf-8')
            self.end_headers()
            self.wfile.write(body)

    class TCPServer(SocketServer.TCPServer):
        # Allow rebinding the port immediately after a restart.
        allow_reuse_address = True

    TCPServer(('', int(sys.argv[1])), Handler).serve_forever()

# Drop into the debugger, which leaves an interactive prompt with all the
# nodes and interfaces above available as variables.
import pdb; pdb.set_trace()