Commit e80ca878 authored by Thomas Gambier's avatar Thomas Gambier 🚴🏼

Port re6stnet to python3

See merge request !46
parents f2fd7247 227d63d4
...@@ -5,3 +5,11 @@ ...@@ -5,3 +5,11 @@
/build/ /build/
/dist/ /dist/
/re6stnet.egg-info/ /re6stnet.egg-info/
*.log
*.pid
*.db
*.state
*.crt
*.pem
demo/mbox
*.sock
...@@ -52,7 +52,7 @@ easily scalable to tens of thousand of nodes. ...@@ -52,7 +52,7 @@ easily scalable to tens of thousand of nodes.
Requirements Requirements
============ ============
- Python 2.7 - Python 3.11
- OpenSSL binary and development libraries - OpenSSL binary and development libraries
- OpenVPN 2.4.* - OpenVPN 2.4.*
- Babel_ (with Nexedi patches) - Babel_ (with Nexedi patches)
......
Demo
====
Usage
-----
To run the demo, make sure all the dependencies are installed
and run ``./demo 8000`` (or any port).
Troubleshooting
---------------
If the demo crashes and fails to clean up its resources properly,
run the following commands::
for b in $(sudo ip l | grep -Po 'NETNS\w\w[\d\-a-f]+'); do sudo ip l del $b; done
pkill screen
killall python
killall python3
find . -name '*.crt' -delete; find . -name '*.db' -delete; find . -name '*.log' -delete
.. warning::
This will kill all Python processes. These commands assume you're running
the demo on a dedicated machine with nothing else on it.
#!/usr/bin/python2 #!/usr/bin/env python3
import argparse, math, nemu, os, re, signal import argparse, math, nemu, os, re, shlex, signal
import socket, sqlite3, subprocess, sys, time, weakref import socket, sqlite3, subprocess, sys, time, weakref
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from threading import Thread from threading import Thread
from typing import Optional
IPTABLES = 'iptables' IPTABLES = 'iptables'
SCREEN = 'screen' SCREEN = 'screen'
VERBOSE = 4 VERBOSE = 4
...@@ -14,9 +16,9 @@ REGISTRY2_SERIAL = '0x120010db80043' ...@@ -14,9 +16,9 @@ REGISTRY2_SERIAL = '0x120010db80043'
CA_DAYS = 1000 CA_DAYS = 1000
# Quick check to avoid wasting time if there is an error. # Quick check to avoid wasting time if there is an error.
with open(os.devnull, "wb") as f: for x in 're6stnet', 're6st-conf', 're6st-registry':
for x in 're6stnet', 're6st-conf', 're6st-registry': subprocess.check_call(('./py', x, '--help'), stdout=subprocess.DEVNULL)
subprocess.check_call(('./py', x, '--help'), stdout=f)
# #
# Underlying network: # Underlying network:
# #
...@@ -55,14 +57,14 @@ def _add_interface(node, iface): ...@@ -55,14 +57,14 @@ def _add_interface(node, iface):
nemu.Node._add_interface = _add_interface nemu.Node._add_interface = _add_interface
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('port', type = int, parser.add_argument('port', type=int,
help = 'port used to display tunnels') help='port used to display tunnels')
parser.add_argument('-d', '--duration', type = int, parser.add_argument('-d', '--duration', type=int,
help = 'time of the demo execution in seconds') help='time of the demo execution in seconds')
parser.add_argument('-p', '--ping', action = 'store_true', parser.add_argument('-p', '--ping', action='store_true',
help = 'execute ping utility') help='execute ping utility')
parser.add_argument('-m', '--hmac', action = 'store_true', parser.add_argument('-m', '--hmac', action='store_true',
help = 'execute HMAC test') help='execute HMAC test')
args = parser.parse_args() args = parser.parse_args()
def handler(signum, frame): def handler(signum, frame):
...@@ -72,33 +74,55 @@ if args.duration: ...@@ -72,33 +74,55 @@ if args.duration:
signal.signal(signal.SIGALRM, handler) signal.signal(signal.SIGALRM, handler)
signal.alarm(args.duration) signal.alarm(args.duration)
execfile("fixnemu.py")
# create nodes class Re6stNode(nemu.Node):
for name in """internet=I registry=R name: str
gateway1=g1 machine1=1 machine2=2 short: str
gateway2=g2 machine3=3 machine4=4 machine5=5 re6st_cmdline: Optional[list[str]]
machine6=6 machine7=7 machine8=8 machine9=9
registry2=R2 machine10=10 def __init__(self, name, short):
""".split(): super().__init__()
name, short = name.split('=') self.name = name
globals()[name] = node = nemu.Node() self.short = short
node.name = name self.Popen(('sysctl', '-q',
node.short = short
node.Popen(('sysctl', '-q',
'net.ipv4.icmp_echo_ignore_broadcasts=0')).wait() 'net.ipv4.icmp_echo_ignore_broadcasts=0')).wait()
node._screen = node.Popen((SCREEN, '-DmS', name)) self._screen = self.Popen((SCREEN, '-DmS', name))
node.screen = (lambda name: lambda *cmd: self.re6st_cmdline = None
subprocess.call([SCREEN, '-r', name, '-X', 'eval'] + map(
"""screen sh -c 'set %s; "\$@"; echo "\$@"; exec $SHELL'""" def screen(self, command: list[str]):
.__mod__, cmd)))(name) runner_cmd = ('set -- %s; "\\$@"; echo "\\$@"; exec $SHELL' %
' '.join(map(shlex.quote, command)))
inner_cmd = [
'screen', 'sh', '-c', runner_cmd
]
cmd = [
SCREEN, '-r', self.name, '-X', 'eval', shlex.join(inner_cmd)
]
return subprocess.call(cmd)
# create nodes
internet = Re6stNode('internet', 'I')
registry = Re6stNode('registry', 'R')
gateway1 = Re6stNode('gateway1', 'g1')
machine1 = Re6stNode('machine1', '1')
machine2 = Re6stNode('machine2', '2')
gateway2 = Re6stNode('gateway2', 'g2')
machine3 = Re6stNode('machine3', '3')
machine4 = Re6stNode('machine4', '4')
machine5 = Re6stNode('machine5', '5')
machine6 = Re6stNode('machine6', '6')
machine7 = Re6stNode('machine7', '7')
machine8 = Re6stNode('machine8', '8')
machine9 = Re6stNode('machine9', '9')
registry2 = Re6stNode('registry2', 'R2')
machine10 = Re6stNode('machine10', '10')
# create switch # create switch
switch1 = nemu.Switch() switch1 = nemu.Switch()
switch2 = nemu.Switch() switch2 = nemu.Switch()
switch3 = nemu.Switch() switch3 = nemu.Switch()
#create interfaces # create interfaces
re_if_0, in_if_0 = nemu.P2PInterface.create_pair(registry, internet) re_if_0, in_if_0 = nemu.P2PInterface.create_pair(registry, internet)
in_if_1, g1_if_0 = nemu.P2PInterface.create_pair(internet, gateway1) in_if_1, g1_if_0 = nemu.P2PInterface.create_pair(internet, gateway1)
in_if_2, g2_if_0 = nemu.P2PInterface.create_pair(internet, gateway2) in_if_2, g2_if_0 = nemu.P2PInterface.create_pair(internet, gateway2)
...@@ -205,19 +229,19 @@ for m in machine6, machine7, machine8: ...@@ -205,19 +229,19 @@ for m in machine6, machine7, machine8:
# Test connectivity first. Run process, hide output and check # Test connectivity first. Run process, hide output and check
# return code # return code
null = file(os.devnull, "r+")
for ip in '10.1.1.2', '10.1.1.3', '10.2.1.2', '10.2.1.3': for ip in '10.1.1.2', '10.1.1.3', '10.2.1.2', '10.2.1.3':
if machine1.Popen(('ping', '-c1', ip), stdout=null).wait(): if machine1.Popen(('ping', '-c1', ip), stdout=subprocess.DEVNULL).wait():
print 'Failed to ping %s' % ip print('Failed to ping', ip)
break break
else: else:
print "Connectivity IPv4 OK!" print("Connectivity IPv4 OK!")
nodes: list[Re6stNode] = []
gateway1.screen(['miniupnpd', '-d', '-f', 'miniupnpd.conf', '-P',
'miniupnpd.pid', '-a', g1_if_1.name, '-i', g1_if_0_name])
nodes = []
gateway1.screen('miniupnpd -d -f miniupnpd.conf -P miniupnpd.pid'
' -a %s -i %s' % (g1_if_1.name, g1_if_0_name))
@contextmanager @contextmanager
def new_network(registry, reg_addr, serial, ca): def new_network(registry: Re6stNode, reg_addr: str, serial: str, ca: str):
from OpenSSL import crypto from OpenSSL import crypto
import hashlib, sqlite3 import hashlib, sqlite3
os.path.exists(ca) or subprocess.check_call( os.path.exists(ca) or subprocess.check_call(
...@@ -225,16 +249,18 @@ def new_network(registry, reg_addr, serial, ca): ...@@ -225,16 +249,18 @@ def new_network(registry, reg_addr, serial, ca):
" -subj /CN=re6st.example.com/emailAddress=re6st@example.com" " -subj /CN=re6st.example.com/emailAddress=re6st@example.com"
" -set_serial %s -days %u" " -set_serial %s -days %u"
% (registry.name, ca, serial, CA_DAYS), shell=True) % (registry.name, ca, serial, CA_DAYS), shell=True)
with open(ca) as f: with open(ca, "rb") as f:
cert = crypto.load_certificate(crypto.FILETYPE_PEM, f.read()) cert = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
fingerprint = "sha256:" + hashlib.sha256( fingerprint = "sha256:" + hashlib.sha256(
crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)).hexdigest() crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)).hexdigest()
db_path = "%s/registry.db" % registry.name db_path = "%s/registry.db" % registry.name
registry.screen("./py re6st-registry @%s/re6st-registry.conf" registry.screen([
" --db %s --mailhost %s -v%u" sys.executable, './py', 're6st-registry',
% (registry.name, db_path, os.path.abspath('mbox'), VERBOSE)) '@%s/re6st-registry.conf' % registry.name, '--db', db_path,
'--mailhost', os.path.abspath('mbox'), '-v%u' % VERBOSE,
])
registry_url = 'http://%s/' % reg_addr registry_url = 'http://%s/' % reg_addr
registry.Popen(('python', '-c', """if 1: registry.Popen((sys.executable, '-c', """if 1:
import socket, time import socket, time
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
while True: while True:
...@@ -245,16 +271,21 @@ def new_network(registry, reg_addr, serial, ca): ...@@ -245,16 +271,21 @@ def new_network(registry, reg_addr, serial, ca):
time.sleep(.1) time.sleep(.1)
""")).wait() """)).wait()
db = sqlite3.connect(db_path, isolation_level=None) db = sqlite3.connect(db_path, isolation_level=None)
def new_node(node, folder, args='', prefix_len=None, registry=registry_url):
def new_node(node: Re6stNode, folder: str, args: list[str]=[],
prefix_len: Optional[int] = None, registry=registry_url):
nodes.append(node) nodes.append(node)
if not os.path.exists(folder + '/cert.crt'): if not os.path.exists(folder + '/cert.crt'):
dh_path = folder + '/dh2048.pem' dh_path = folder + '/dh2048.pem'
if not os.path.exists(dh_path): if not os.path.exists(dh_path):
os.symlink('../dh2048.pem', dh_path) os.symlink('../dh2048.pem', dh_path)
email = node.name + '@example.com' email = node.name + '@example.com'
p = node.Popen(('../py', 're6st-conf', '--registry', registry, p = node.Popen((
'--email', email, '--fingerprint', fingerprint), sys.executable, '../py', 're6st-conf',
stdin=subprocess.PIPE, cwd=folder) '--registry', registry,
'--email', email,
'--fingerprint', fingerprint,
), stdin=subprocess.PIPE, cwd=folder)
token = None token = None
while not token: while not token:
time.sleep(.1) time.sleep(.1)
...@@ -266,27 +297,30 @@ def new_network(registry, reg_addr, serial, ca): ...@@ -266,27 +297,30 @@ def new_network(registry, reg_addr, serial, ca):
p.communicate(str(token[0])) p.communicate(str(token[0]))
os.remove(dh_path) os.remove(dh_path)
os.remove(folder + '/ca.crt') os.remove(folder + '/ca.crt')
node.re6st_cmdline = ( node.re6st_cmdline = [
'./py re6stnet @%s/re6stnet.conf -v%u --registry %s' sys.executable, './py', 're6stnet', '@%s/re6stnet.conf' % folder,
' --console %s/run/console.sock %s' '-v%u' % VERBOSE, '--registry', registry, '--console',
) % (folder, VERBOSE, registry, folder, args) '%s/run/console.sock' % folder, *args,
]
node.screen(node.re6st_cmdline) node.screen(node.re6st_cmdline)
new_node(registry, registry.name, '--ip ' + reg_addr, registry='http://localhost/')
new_node(registry, registry.name, ['--ip', reg_addr],
registry='http://localhost/')
yield new_node yield new_node
db.close() db.close()
with new_network(registry, REGISTRY, REGISTRY_SERIAL, 'ca.crt') as new_node: with new_network(registry, REGISTRY, REGISTRY_SERIAL, 'ca.crt') as new_node:
new_node(machine1, 'm1', '-I%s' % m1_if_0.name) new_node(machine1, 'm1', ['-I%s' % m1_if_0.name])
new_node(machine2, 'm2', '--remote-gateway 10.1.1.1', prefix_len=77) new_node(machine2, 'm2', ['--remote-gateway', '10.1.1.1'], prefix_len=77)
new_node(machine3, 'm3', '-i%s' % m3_if_0.name) new_node(machine3, 'm3', ['-i%s' % m3_if_0.name])
new_node(machine4, 'm4', '-i%s' % m4_if_0.name) new_node(machine4, 'm4', ['-i%s' % m4_if_0.name])
new_node(machine5, 'm5', '-i%s' % m5_if_0.name) new_node(machine5, 'm5', ['-i%s' % m5_if_0.name])
new_node(machine6, 'm6', '-I%s' % m6_if_1.name) new_node(machine6, 'm6', ['-I%s' % m6_if_1.name])
new_node(machine7, 'm7') new_node(machine7, 'm7')
new_node(machine8, 'm8') new_node(machine8, 'm8')
with new_network(registry2, REGISTRY2, REGISTRY2_SERIAL, 'ca2.crt') as new_node: with new_network(registry2, REGISTRY2, REGISTRY2_SERIAL, 'ca2.crt') as new_node:
new_node(machine10, 'm10', '-i%s' % m10_if_0.name) new_node(machine10, 'm10', ['-i%s' % m10_if_0.name])
if args.ping: if args.ping:
for j, machine in enumerate(nodes): for j, machine in enumerate(nodes):
...@@ -297,10 +331,10 @@ if args.ping: ...@@ -297,10 +331,10 @@ if args.ping:
'2001:db8:43:1::1' if i == 10 else '2001:db8:43:1::1' if i == 10 else
# Only 1 address for machine2 because prefix_len = 80,+48 = 128 # Only 1 address for machine2 because prefix_len = 80,+48 = 128
'2001:db8:42:%s::1' % i '2001:db8:42:%s::1' % i
for i in xrange(11) for i in range(11)
if i != j] if i != j]
name = machine.name if machine.short[0] == 'R' else 'm' + machine.short name = machine.name if machine.short[0] == 'R' else 'm' + machine.short
machine.screen('python ping.py {} {}'.format(name, ' '.join(ips))) machine.screen(['python', 'ping.py', name] + ips)
class testHMAC(Thread): class testHMAC(Thread):
...@@ -314,30 +348,30 @@ class testHMAC(Thread): ...@@ -314,30 +348,30 @@ class testHMAC(Thread):
reg1_db.text_factory = reg2_db.text_factory = str reg1_db.text_factory = reg2_db.text_factory = str
m_net1 = 'registry', 'm1', 'm2', 'm3', 'm4', 'm5', 'm6', 'm7', 'm8' m_net1 = 'registry', 'm1', 'm2', 'm3', 'm4', 'm5', 'm6', 'm7', 'm8'
m_net2 = 'registry2', 'm10' m_net2 = 'registry2', 'm10'
print 'Testing HMAC, letting the time to machines to create tunnels...' print('Testing HMAC, letting the time to machines to create tunnels...')
time.sleep(45) time.sleep(45)
print 'Check that the initial HMAC config is deployed on network 1' print('Check that the initial HMAC config is deployed on network 1')
test_hmac.checkHMAC(reg1_db, m_net1) test_hmac.checkHMAC(reg1_db, m_net1)
print 'Test that a HMAC update works with nodes that are up' print('Test that a HMAC update works with nodes that are up')
registry.backticks_raise(updateHMAC) registry.backticks_raise(updateHMAC)
print 'Updated HMAC (config = hmac0 & hmac1), waiting...' print('Updated HMAC (config = hmac0 & hmac1), waiting...')
time.sleep(60) time.sleep(60)
print 'Checking HMAC on machines connected to registry 1...' print('Checking HMAC on machines connected to registry 1...')
test_hmac.checkHMAC(reg1_db, m_net1) test_hmac.checkHMAC(reg1_db, m_net1)
print ('Test that machines can update upon reboot ' + print('Test that machines can update upon reboot '
'when they were off during a HMAC update.') 'when they were off during a HMAC update.')
test_hmac.killRe6st(machine1) test_hmac.killRe6st(machine1)
print 'Re6st on machine 1 is stopped' print('Re6st on machine 1 is stopped')
time.sleep(5) time.sleep(5)
registry.backticks_raise(updateHMAC) registry.backticks_raise(updateHMAC)
print 'Updated HMAC on registry (config = hmac1 & hmac2), waiting...' print('Updated HMAC on registry (config = hmac1 & hmac2), waiting...')
time.sleep(60) time.sleep(60)
machine1.screen(machine1.re6st_cmdline) machine1.screen(machine1.re6st_cmdline)
print 'Started re6st on machine 1, waiting for it to get new conf' print('Started re6st on machine 1, waiting for it to get new conf')
time.sleep(60) time.sleep(60)
print 'Checking HMAC on machines connected to registry 1...' print('Checking HMAC on machines connected to registry 1...')
test_hmac.checkHMAC(reg1_db, m_net1) test_hmac.checkHMAC(reg1_db, m_net1)
print 'Testing of HMAC done!' print('Testing of HMAC done!')
# TODO: missing last step # TODO: missing last step
reg1_db.close() reg1_db.close()
reg2_db.close() reg2_db.close()
...@@ -349,8 +383,9 @@ if args.hmac: ...@@ -349,8 +383,9 @@ if args.hmac:
t.start() t.start()
del t del t
_ll = {} _ll: dict[str, tuple[Re6stNode, bool]] = {}
def node_by_ll(addr):
def node_by_ll(addr: str) -> tuple[Re6stNode, bool]:
try: try:
return _ll[addr] return _ll[addr]
except KeyError: except KeyError:
...@@ -368,27 +403,28 @@ def node_by_ll(addr): ...@@ -368,27 +403,28 @@ def node_by_ll(addr):
if a.startswith('10.42.'): if a.startswith('10.42.'):
assert not p % 8 assert not p % 8
_ll[socket.inet_ntoa(socket.inet_aton( _ll[socket.inet_ntoa(socket.inet_aton(
a)[:p/8].ljust(4, '\0'))] = n, t a)[:p//8].ljust(4, b'\0'))] = n, t
elif a.startswith('2001:db8:'): elif a.startswith('2001:db8:'):
assert not p % 8 assert not p % 8
a = socket.inet_ntop(socket.AF_INET6, a = socket.inet_ntop(socket.AF_INET6,
socket.inet_pton(socket.AF_INET6, socket.inet_pton(socket.AF_INET6,
a)[:p/8].ljust(16, '\0')) a)[:p // 8].ljust(16, b'\0'))
elif not a.startswith('fe80::'): elif not a.startswith('fe80::'):
continue continue
_ll[a] = n, t _ll[a] = n, t
return _ll[addr] return _ll[addr]
def route_svg(ipv4, z = 4, default = type('', (), {'short': None})):
graph = {} def route_svg(ipv4, z=4):
graph: dict[Re6stNode, dict[tuple[Re6stNode, bool], list[Re6stNode]]] = {}
for n in nodes: for n in nodes:
g = graph[n] = defaultdict(list) g = graph[n] = defaultdict(list)
for r in n.get_routes(): for r in n.get_routes():
if (r.prefix and r.prefix.startswith('10.42.') if ipv4 else if (r.prefix and r.prefix.startswith('10.42.') if ipv4 else
r.prefix is None or r.prefix.startswith('2001:db8:')): r.prefix is None or r.prefix.startswith('2001:db8:')):
try: try:
g[node_by_ll(r.nexthop)].append( if r.prefix:
node_by_ll(r.prefix)[0] if r.prefix else default) g[node_by_ll(r.nexthop)].append(node_by_ll(r.prefix)[0])
except KeyError: except KeyError:
pass pass
gv = ["digraph { splines = true; edge[color=grey, labelangle=0];"] gv = ["digraph { splines = true; edge[color=grey, labelangle=0];"]
...@@ -399,37 +435,37 @@ def route_svg(ipv4, z = 4, default = type('', (), {'short': None})): ...@@ -399,37 +435,37 @@ def route_svg(ipv4, z = 4, default = type('', (), {'short': None})):
gv.append('%s[pos="%s,%s!"];' gv.append('%s[pos="%s,%s!"];'
% (n.name, z * math.cos(a * i), z * math.sin(a * i))) % (n.name, z * math.cos(a * i), z * math.sin(a * i)))
l = [] l = []
for p, r in graph[n].iteritems(): for p, r in graph[n].items():
j = abs(nodes.index(p[0]) - i) j = abs(nodes.index(p[0]) - i)
l.append((min(j, N - j), p, r)) l.append((min(j, N - j), p, r))
for j, (l, (p, t), r) in enumerate(sorted(l)): for j, (_, (p2, t), r) in enumerate(sorted(l, key=lambda x: x[0])):
l = [] l2 = []
arrowhead = 'none' arrowhead = 'none'
for r in sorted(r.short for r in r): for r2 in sorted(r2.short or '' for r2 in r):
if r: if r2:
if r == p.short: if r2 == p2.short:
r = '<font color="grey">%s</font>' % r r2 = '<font color="grey">%s</font>' % r2
l.append(r) l2.append(r2)
else: else:
arrowhead = 'dot' arrowhead = 'dot'
if (n.name, p.name) in edges: if (n.name, p2.name) in edges:
r = 'penwidth=0' r3 = 'penwidth=0'
else: else:
edges.add((p.name, n.name)) edges.add((p2.name, n.name))
r = 'style=solid' if t else 'style=dashed' r3 = 'style=solid' if t else 'style=dashed'
gv.append( gv.append(
'%s -> %s [labeldistance=%u, headlabel=<%s>, arrowhead=%s, %s];' '%s -> %s [labeldistance=%u, headlabel=<%s>, arrowhead=%s, %s];'
% (p.name, n.name, 1.5 * math.sqrt(j) + 2, ','.join(l), % (p2.name, n.name, 1.5 * math.sqrt(j) + 2, ','.join(l2),
arrowhead, r)) arrowhead, r3))
gv.append('}\n') gv.append('}\n')
return subprocess.Popen(('neato', '-Tsvg'), return subprocess.run(
stdin=subprocess.PIPE, stdout=subprocess.PIPE, ('neato', '-Tsvg'), check=True, text=True, capture_output=True,
).communicate('\n'.join(gv))[0] input='\n'.join(gv)).stdout
if args.port: if args.port:
import SimpleHTTPServer, SocketServer import http.server, socketserver
class Handler(SimpleHTTPServer.SimpleHTTPRequestHandler): class Handler(http.server.SimpleHTTPRequestHandler):
_path_match = re.compile('/(.+)\.(html|svg)$').match _path_match = re.compile('/(.+)\.(html|svg)$').match
pages = 'ipv6', 'ipv4', 'tunnels' pages = 'ipv6', 'ipv4', 'tunnels'
...@@ -439,7 +475,7 @@ if args.port: ...@@ -439,7 +475,7 @@ if args.port:
try: try:
name, ext = self._path_match(self.path).groups() name, ext = self._path_match(self.path).groups()
page = self.pages.index(name) page = self.pages.index(name)
except AttributeError, ValueError: except (AttributeError, ValueError):
if self.path == '/': if self.path == '/':
self.send_response(302) self.send_response(302)
self.send_header('Location', self.pages[0] + '.html') self.send_header('Location', self.pages[0] + '.html')
...@@ -450,36 +486,46 @@ if args.port: ...@@ -450,36 +486,46 @@ if args.port:
if page < 2: if page < 2:
body = route_svg(page) body = route_svg(page)
else: else:
body = registry.Popen(('python', '-c', r"""if 1: out, err = (registry.Popen(('python3', '-c', r"""if 1:
import math, json import math, json
from re6st.registry import RegistryClient from re6st.registry import RegistryClient
g = json.loads(RegistryClient( topo = RegistryClient('http://localhost/').topology()
'http://localhost/').topology()) g = json.loads(topo)
if not g:
print('digraph { "empty topology" [shape="none"] }')
exit()
r = set(g.pop('', ())) r = set(g.pop('', ()))
a = set() a = set()
for v in g.itervalues(): for v in g.values():
a.update(v) a.update(v)
g.update(dict.fromkeys(a.difference(g), ())) g.update(dict.fromkeys(a.difference(g), ()))
print 'digraph {' print('digraph {')
a = 2 * math.pi / len(g) a = 2 * math.pi / len(g)
z = 4 z = 4
m2 = '%u/80' % (2 << 64) m2 = '%u/80' % (2 << 64)
title = lambda n: '2|80' if n == m2 else n title = lambda n: '2|80' if n == m2 else n
g = sorted((title(k), k in r, v) for k, v in g.iteritems()) g = sorted((title(k), k in r, v) for k, v in g.items())
for i, (n, r, v) in enumerate(g): for i, (n, r, v) in enumerate(g):
print '"%s"[pos="%s,%s!"%s];' % (title(n), print('"%s"[pos="%s,%s!"%s];' % (title(n),
z * math.cos(a * i), z * math.sin(a * i), z * math.cos(a * i), z * math.sin(a * i),
'' if r else ', style=dashed') '' if r else ', style=dashed'))
for v in v: for v in v:
print '"%s" -> "%s";' % (n, title(v)) print('"%s" -> "%s";' % (n, title(v)))
print '}' print('}')
"""), stdout=subprocess.PIPE, cwd="..").communicate()[0] """), stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd="..")
if body: .communicate())
body = subprocess.Popen(('neato', '-Tsvg'), if err:
stdin=subprocess.PIPE, stdout=subprocess.PIPE, self.send_error(500, explain='SVG generation failed: '
).communicate(body)[0] + err.decode(errors='replace'))
if not body: return
self.send_error(500) graph_body = out.decode("utf-8")
try:
body = subprocess.run(
('neato', '-Tsvg'), check=True, text=True,
capture_output=True,
input=graph_body).stdout
except subprocess.CalledProcessError as e:
self.send_error(500, explain='neato failed: ' + e.stderr)
return return
if ext == 'svg': if ext == 'svg':
mt = 'image/svg+xml' mt = 'image/svg+xml'
...@@ -508,14 +554,15 @@ if args.port: ...@@ -508,14 +554,15 @@ if args.port:
for i, x in enumerate(self.pages)), for i, x in enumerate(self.pages)),
body[body.find('<svg'):]) body[body.find('<svg'):])
self.send_response(200) self.send_response(200)
self.send_header('Content-Length', len(body)) body = body.encode("utf-8")
self.send_header('Content-Length', str(len(body)))
self.send_header('Content-type', mt + '; charset=utf-8') self.send_header('Content-type', mt + '; charset=utf-8')
self.end_headers() self.end_headers()
self.wfile.write(body) self.wfile.write(body)
class TCPServer(SocketServer.TCPServer): class TCPServer(socketserver.TCPServer):
allow_reuse_address = True allow_reuse_address = True
TCPServer(('', args.port), Handler).serve_forever() TCPServer(('', args.port), Handler).serve_forever()
import pdb; pdb.set_trace() breakpoint()
# -*- coding: utf-8 -*-
# Copyright 2010, 2011 INRIA
# Copyright 2011 Martín Ferrari <martin.ferrari@gmail.com>
#
# This file is contains patches to Nemu.
#
# Nemu is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License version 2, as published by the Free
# Software Foundation.
#
# Nemu is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# Nemu. If not, see <http://www.gnu.org/licenses/>.
import re
import os
from new import function
from nemu.iproute import backticks, get_if_data, route, \
get_addr_data, get_all_route_data, interface
from nemu.interface import Switch, Interface
def _get_all_route_data():
ipdata = backticks([IP_PATH, "-o", "route", "list"]) # "table", "all"
ipdata += backticks([IP_PATH, "-o", "-f", "inet6", "route", "list"])
ifdata = get_if_data()[1]
ret = []
for line in ipdata.split("\n"):
if line == "":
continue
# PATCH: parse 'from'
# PATCH: 'dev' is missing on 'unreachable' ipv4 routes
match = re.match('(?:(unicast|local|broadcast|multicast|throw|'
r'unreachable|prohibit|blackhole|nat) )?(\S+)(?: from (\S+))?'
r'(?: via (\S+))?(?: dev (\S+))?.*(?: metric (\d+))?', line)
if not match:
raise RuntimeError("Invalid output from `ip route': `%s'" % line)
tipe = match.group(1) or "unicast"
prefix = match.group(2)
#src = match.group(3)
nexthop = match.group(4)
interface = ifdata[match.group(5) or "lo"]
metric = match.group(6)
if prefix == "default" or re.search(r'/0$', prefix):
prefix = None
prefix_len = 0
else:
match = re.match(r'([0-9a-f:.]+)(?:/(\d+))?$', prefix)
prefix = match.group(1)
prefix_len = int(match.group(2) or 32)
ret.append(route(tipe, prefix, prefix_len, nexthop, interface.index,
metric))
return ret
get_all_route_data.func_code = _get_all_route_data.func_code
interface__init__ = interface.__init__
def __init__(self, *args, **kw):
interface__init__(self, *args, **kw)
if self.name:
self.name = self.name.split('@',1)[0]
interface.__init__ = __init__
get_addr_data.orig = function(get_addr_data.func_code,
get_addr_data.func_globals)
def _get_addr_data():
byidx, bynam = get_addr_data.orig()
return byidx, {name.split('@',1)[0]: a for name, a in bynam.iteritems()}
get_addr_data.func_code = _get_addr_data.func_code
@staticmethod
def _gen_if_name():
n = Interface._gen_next_id()
# Max 15 chars
# XXX: We truncate pid to not exceed IFNAMSIZ on systems with 32-bits pids
# but we should find something better to avoid possible collision.
return "NETNSif-%.4x%.3x" % (os.getpid() % 0xffff, n)
Interface._gen_if_name = _gen_if_name
@staticmethod
def _gen_br_name():
n = Switch._gen_next_id()
# XXX: same as for _gen_if_name
return "NETNSbr-%.4x%.3x" % (os.getpid() % 0xffff, n)
Switch._gen_br_name = _gen_br_name
#!/usr/bin/env python #!/usr/bin/env python3
def __file__(): def __file__():
import argparse, os, sys import argparse, os, sys
sys.dont_write_bytecode = True sys.dont_write_bytecode = True
...@@ -30,4 +30,5 @@ def __file__(): ...@@ -30,4 +30,5 @@ def __file__():
return os.path.join(sys.path[0], sys.argv[1]) return os.path.join(sys.path[0], sys.argv[1])
__file__ = __file__() __file__ = __file__()
execfile(__file__) with open(__file__) as f:
exec(compile(f.read(), __file__, 'exec'))
...@@ -34,7 +34,7 @@ def checkHMAC(db, machines): ...@@ -34,7 +34,7 @@ def checkHMAC(db, machines):
else: else:
i = 0 if hmac[0] else 1 i = 0 if hmac[0] else 1
if hmac[i] != sign or hmac[i+1] != accept: if hmac[i] != sign or hmac[i+1] != accept:
print 'HMAC config wrong for in %s' % args print('HMAC config wrong for in %s' % args)
rc = False rc = False
if rc: if rc:
print('All nodes use Babel with the correct HMAC configuration') print('All nodes use Babel with the correct HMAC configuration')
......
...@@ -5,7 +5,7 @@ if 're6st' not in sys.modules: ...@@ -5,7 +5,7 @@ if 're6st' not in sys.modules:
from re6st import utils, x509 from re6st import utils, x509
from OpenSSL import crypto from OpenSSL import crypto
with open("/etc/re6stnet/ca.crt") as f: with open("/etc/re6stnet/ca.crt", "rb") as f:
ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read()) ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
network = x509.networkFromCa(ca) network = x509.networkFromCa(ca)
......
import json, logging, os, sqlite3, socket, subprocess, sys, time, zlib import base64, json, logging, os, sqlite3, socket, subprocess, sys, time, zlib
from itertools import chain from itertools import chain
from .registry import RegistryClient from .registry import RegistryClient
from . import utils, version, x509 from . import utils, version, x509
class Cache(object): class Cache:
def __init__(self, db_path, registry, cert, db_size=200): def __init__(self, db_path: str, registry, cert: x509.Cert, db_size=200):
self._prefix = cert.prefix self._prefix = cert.prefix
self._db_size = db_size self._db_size = db_size
self._decrypt = cert.decrypt self._decrypt = cert.decrypt
...@@ -50,7 +50,7 @@ class Cache(object): ...@@ -50,7 +50,7 @@ class Cache(object):
self.warnProtocol() self.warnProtocol()
logging.info("Cache initialized.") logging.info("Cache initialized.")
def _open(self, path): def _open(self, path: str) -> sqlite3.Connection:
db = sqlite3.connect(path, isolation_level=None) db = sqlite3.connect(path, isolation_level=None)
db.text_factory = str db.text_factory = str
db.execute("PRAGMA synchronous = OFF") db.execute("PRAGMA synchronous = OFF")
...@@ -64,9 +64,8 @@ class Cache(object): ...@@ -64,9 +64,8 @@ class Cache(object):
return db return db
@staticmethod @staticmethod
def _selectConfig(execute): # BBB: blob def _selectConfig(execute):
return ((k, str(v) if type(v) is buffer else v) return execute("SELECT * FROM config")
for k, v in execute("SELECT * FROM config"))
def _loadConfig(self, config): def _loadConfig(self, config):
cls = self.__class__ cls = self.__class__
...@@ -89,24 +88,23 @@ class Cache(object): ...@@ -89,24 +88,23 @@ class Cache(object):
logging.info("Getting new network parameters from registry...") logging.info("Getting new network parameters from registry...")
try: try:
# TODO: When possible, the registry should be queried via the re6st. # TODO: When possible, the registry should be queried via the re6st.
x = json.loads(zlib.decompress( network_config = self._registry.getNetworkConfig(self._prefix)
self._registry.getNetworkConfig(self._prefix))) logging.debug('getNetworkConfig result: %r', network_config)
base64 = x.pop('', ()) x = json.loads(zlib.decompress(network_config))
base64_list = x.pop('', ())
config = {} config = {}
for k, v in x.iteritems(): for k, v in x.items():
k = str(k) k = str(k)
if k.startswith('babel_hmac'): if k.startswith('babel_hmac'):
if v: if v:
v = self._decrypt(v.decode('base64')) v = self._decrypt(base64.b64decode(v))
elif k in base64: elif k in base64_list:
v = v.decode('base64') v = base64.b64decode(v)
elif type(v) is unicode:
v = str(v)
elif isinstance(v, (list, dict)): elif isinstance(v, (list, dict)):
k += ':json' k += ':json'
v = json.dumps(v) v = json.dumps(v)
config[k] = v config[k] = v
except socket.error, e: except socket.error as e:
logging.warning(e) logging.warning(e)
return return
except Exception: except Exception:
...@@ -130,16 +128,12 @@ class Cache(object): ...@@ -130,16 +128,12 @@ class Cache(object):
remove.append(k) remove.append(k)
db.execute("DELETE FROM config WHERE name in ('%s')" db.execute("DELETE FROM config WHERE name in ('%s')"
% "','".join(remove)) % "','".join(remove))
# BBB: Use buffer because of http://bugs.python.org/issue13676
# on Python 2.6
db.executemany("INSERT OR REPLACE INTO config VALUES(?,?)", db.executemany("INSERT OR REPLACE INTO config VALUES(?,?)",
((k, buffer(v) if k in base64 or config.items())
k.startswith('babel_hmac') else v) self._loadConfig(config.items())
for k, v in config.iteritems()))
self._loadConfig(config.iteritems())
return [k[:-5] if k.endswith(':json') else k return [k[:-5] if k.endswith(':json') else k
for k in chain(remove, (k for k in chain(remove, (k
for k, v in config.iteritems() for k, v in config.items()
if k not in old or old[k] != v))] if k not in old or old[k] != v))]
def warnProtocol(self): def warnProtocol(self):
...@@ -147,7 +141,7 @@ class Cache(object): ...@@ -147,7 +141,7 @@ class Cache(object):
logging.warning("There's a new version of re6stnet:" logging.warning("There's a new version of re6stnet:"
" you should update.") " you should update.")
def getDh(self, path): def getDh(self, path: str):
# We'd like to do a full check here but # We'd like to do a full check here but
# from OpenSSL import SSL # from OpenSSL import SSL
# SSL.Context(SSL.TLSv1_METHOD).load_tmp_dh(path) # SSL.Context(SSL.TLSv1_METHOD).load_tmp_dh(path)
...@@ -179,11 +173,11 @@ class Cache(object): ...@@ -179,11 +173,11 @@ class Cache(object):
logging.trace("- %s: %s%s", prefix, address, logging.trace("- %s: %s%s", prefix, address,
' (blacklisted)' if _try else '') ' (blacklisted)' if _try else '')
def cacheMinimize(self, size): def cacheMinimize(self, size: int):
with self._db: with self._db:
self._cacheMinimize(size) self._cacheMinimize(size)
def _cacheMinimize(self, size): def _cacheMinimize(self, size: int):
a = self._db.execute( a = self._db.execute(
"SELECT peer FROM volatile.stat ORDER BY try, RANDOM() LIMIT ?,-1", "SELECT peer FROM volatile.stat ORDER BY try, RANDOM() LIMIT ?,-1",
(size,)).fetchall() (size,)).fetchall()
...@@ -192,26 +186,26 @@ class Cache(object): ...@@ -192,26 +186,26 @@ class Cache(object):
q("DELETE FROM peer WHERE prefix IN (?)", a) q("DELETE FROM peer WHERE prefix IN (?)", a)
q("DELETE FROM volatile.stat WHERE peer IN (?)", a) q("DELETE FROM volatile.stat WHERE peer IN (?)", a)
def connecting(self, prefix, connecting): def connecting(self, prefix: str, connecting: bool):
self._db.execute("UPDATE volatile.stat SET try=? WHERE peer=?", self._db.execute("UPDATE volatile.stat SET try=? WHERE peer=?",
(connecting, prefix)) (connecting, prefix))
def resetConnecting(self): def resetConnecting(self):
self._db.execute("UPDATE volatile.stat SET try=0") self._db.execute("UPDATE volatile.stat SET try=0")
def getAddress(self, prefix): def getAddress(self, prefix: str) -> bool:
r = self._db.execute("SELECT address FROM peer, volatile.stat" r = self._db.execute("SELECT address FROM peer, volatile.stat"
" WHERE prefix=? AND prefix=peer AND try=0", " WHERE prefix=? AND prefix=peer AND try=0",
(prefix,)).fetchone() (prefix,)).fetchone()
return r and r[0] return r and r[0]
@property @property
def my_address(self): def my_address(self) -> str:
for x, in self._db.execute("SELECT address FROM peer WHERE prefix=''"): for x, in self._db.execute("SELECT address FROM peer WHERE prefix=''"):
return x return x
@my_address.setter @my_address.setter
def my_address(self, value): def my_address(self, value: str):
if value: if value:
with self._db as db: with self._db as db:
db.execute("INSERT OR REPLACE INTO peer VALUES ('', ?)", db.execute("INSERT OR REPLACE INTO peer VALUES ('', ?)",
...@@ -229,18 +223,20 @@ class Cache(object): ...@@ -229,18 +223,20 @@ class Cache(object):
# IOW, one should probably always put our own address there. # IOW, one should probably always put our own address there.
_get_peer_sql = "SELECT %s FROM peer, volatile.stat" \ _get_peer_sql = "SELECT %s FROM peer, volatile.stat" \
" WHERE prefix=peer AND prefix!=? AND try=?" " WHERE prefix=peer AND prefix!=? AND try=?"
def getPeerList(self, failed=0, __sql=_get_peer_sql % "prefix, address" def getPeerList(self, failed=False, __sql=_get_peer_sql % "prefix, address"
+ " ORDER BY RANDOM()"): + " ORDER BY RANDOM()"):
return self._db.execute(__sql, (self._prefix, failed)) return self._db.execute(__sql, (self._prefix, failed))
def getPeerCount(self, failed=0, __sql=_get_peer_sql % "COUNT(*)"):
def getPeerCount(self, failed=False, __sql=_get_peer_sql % "COUNT(*)") \
-> int:
return self._db.execute(__sql, (self._prefix, failed)).next()[0] return self._db.execute(__sql, (self._prefix, failed)).next()[0]
def getBootstrapPeer(self): def getBootstrapPeer(self) -> tuple[str, str]:
logging.info('Getting Boot peer...') logging.info('Getting Boot peer...')
try: try:
bootpeer = self._registry.getBootstrapPeer(self._prefix) bootpeer = self._registry.getBootstrapPeer(self._prefix)
prefix, address = self._decrypt(bootpeer).split() prefix, address = self._decrypt(bootpeer).decode().split()
except (socket.error, subprocess.CalledProcessError, ValueError), e: except (socket.error, subprocess.CalledProcessError, ValueError) as e:
logging.warning('Failed to bootstrap (%s)', logging.warning('Failed to bootstrap (%s)',
e if bootpeer else 'no peer returned') e if bootpeer else 'no peer returned')
else: else:
...@@ -249,7 +245,7 @@ class Cache(object): ...@@ -249,7 +245,7 @@ class Cache(object):
return prefix, address return prefix, address
logging.warning('Buggy registry sent us our own address') logging.warning('Buggy registry sent us our own address')
def addPeer(self, prefix, address, set_preferred=False): def addPeer(self, prefix: str, address: str, set_preferred=False):
logging.debug('Adding peer %s: %s', prefix, address) logging.debug('Adding peer %s: %s', prefix, address)
with self._db: with self._db:
q = self._db.execute q = self._db.execute
...@@ -273,8 +269,8 @@ class Cache(object): ...@@ -273,8 +269,8 @@ class Cache(object):
q("INSERT OR REPLACE INTO peer VALUES (?,?)", (prefix, address)) q("INSERT OR REPLACE INTO peer VALUES (?,?)", (prefix, address))
q("INSERT OR REPLACE INTO volatile.stat VALUES (?,0)", (prefix,)) q("INSERT OR REPLACE INTO volatile.stat VALUES (?,0)", (prefix,))
def getCountry(self, ip): def getCountry(self, ip: str) -> str:
try: try:
return self._registry.getCountry(self._prefix, ip) return self._registry.getCountry(self._prefix, ip).decode()
except socket.error, e: except socket.error as e:
logging.warning('Failed to get country (%s)', ip) logging.warning('Failed to get country (%s)', ip)
#!/usr/bin/python2 #!/usr/bin/env python3
import argparse, atexit, binascii, errno, hashlib import argparse, atexit, binascii, errno, hashlib
import os, subprocess, sqlite3, sys, time import os, subprocess, sqlite3, sys, time
from OpenSSL import crypto from OpenSSL import crypto
...@@ -6,14 +6,14 @@ if 're6st' not in sys.modules: ...@@ -6,14 +6,14 @@ if 're6st' not in sys.modules:
sys.path[0] = os.path.dirname(os.path.dirname(sys.path[0])) sys.path[0] = os.path.dirname(os.path.dirname(sys.path[0]))
from re6st import registry, utils, x509 from re6st import registry, utils, x509
def create(path, text=None, mode=0666): def create(path, text=None, mode=0o666):
fd = os.open(path, os.O_CREAT | os.O_WRONLY | os.O_TRUNC, mode) fd = os.open(path, os.O_CREAT | os.O_WRONLY | os.O_TRUNC, mode)
try: try:
os.write(fd, text) os.write(fd, text)
finally: finally:
os.close(fd) os.close(fd)
def loadCert(pem): def loadCert(pem: bytes):
return crypto.load_certificate(crypto.FILETYPE_PEM, pem) return crypto.load_certificate(crypto.FILETYPE_PEM, pem)
def main(): def main():
...@@ -68,17 +68,18 @@ def main(): ...@@ -68,17 +68,18 @@ def main():
fingerprint = binascii.a2b_hex(fingerprint) fingerprint = binascii.a2b_hex(fingerprint)
if hashlib.new(alg).digest_size != len(fingerprint): if hashlib.new(alg).digest_size != len(fingerprint):
raise ValueError("wrong size") raise ValueError("wrong size")
except StandardError, e: except Exception as e:
parser.error("invalid fingerprint: %s" % e) parser.error("invalid fingerprint: %s" % e)
if x509.fingerprint(ca, alg).digest() != fingerprint: if x509.fingerprint(ca, alg).digest() != fingerprint:
sys.exit("CA fingerprint doesn't match") sys.exit("CA fingerprint doesn't match")
else: else:
print "WARNING: it is strongly recommended to use --fingerprint option." print("WARNING: it is strongly recommended to use --fingerprint option.")
network = x509.networkFromCa(ca) network = x509.networkFromCa(ca)
if config.is_needed: if config.is_needed:
route, err = subprocess.Popen(('ip', '-6', '-o', 'route', 'get', with subprocess.Popen(('ip', '-6', '-o', 'route', 'get',
utils.ipFromBin(network)), utils.ipFromBin(network)),
stdout=subprocess.PIPE).communicate() stdout=subprocess.PIPE) as proc:
route, err = proc.communicate()
sys.exit(err or route and sys.exit(err or route and
utils.binFromIp(route.split()[8]).startswith(network)) utils.binFromIp(route.split()[8]).startswith(network))
...@@ -89,19 +90,20 @@ def main(): ...@@ -89,19 +90,20 @@ def main():
reserved = 'CN', 'serial' reserved = 'CN', 'serial'
req = crypto.X509Req() req = crypto.X509Req()
try: try:
with open(cert_path) as f: with open(cert_path, "rb") as f:
cert = loadCert(f.read()) cert = loadCert(f.read())
components = dict(cert.get_subject().get_components()) components = \
{k.decode(): v for k, v in cert.get_subject().get_components()}
for k in reserved: for k in reserved:
components.pop(k, None) components.pop(k, None)
except IOError, e: except IOError as e:
if e.errno != errno.ENOENT: if e.errno != errno.ENOENT:
raise raise
components = {} components = {}
if config.req: if config.req:
components.update(config.req) components.update(config.req)
subj = req.get_subject() subj = req.get_subject()
for k, v in components.iteritems(): for k, v in components.items():
if k in reserved: if k in reserved:
sys.exit(k + " field is reserved.") sys.exit(k + " field is reserved.")
if v: if v:
...@@ -116,35 +118,35 @@ def main(): ...@@ -116,35 +118,35 @@ def main():
token = '' token = ''
elif not token: elif not token:
if not config.email: if not config.email:
config.email = raw_input('Please enter your email address: ') config.email = input('Please enter your email address: ')
s.requestToken(config.email) s.requestToken(config.email)
token_advice = "Use --token to retry without asking a new token\n" token_advice = "Use --token to retry without asking a new token\n"
while not token: while not token:
token = raw_input('Please enter your token: ') token = input('Please enter your token: ')
try: try:
with open(key_path) as f: with open(key_path) as f:
pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read()) pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read())
key = None key = None
print "Reusing existing key." print("Reusing existing key.")
except IOError, e: except IOError as e:
if e.errno != errno.ENOENT: if e.errno != errno.ENOENT:
raise raise
bits = ca.get_pubkey().bits() bits = ca.get_pubkey().bits()
print "Generating %s-bit key ..." % bits print("Generating %s-bit key ..." % bits)
pkey = crypto.PKey() pkey = crypto.PKey()
pkey.generate_key(crypto.TYPE_RSA, bits) pkey.generate_key(crypto.TYPE_RSA, bits)
key = crypto.dump_privatekey(crypto.FILETYPE_PEM, pkey) key = crypto.dump_privatekey(crypto.FILETYPE_PEM, pkey)
create(key_path, key, 0600) create(key_path, key, 0o600)
req.set_pubkey(pkey) req.set_pubkey(pkey)
req.sign(pkey, 'sha512') req.sign(pkey, 'sha512')
req = crypto.dump_certificate_request(crypto.FILETYPE_PEM, req) req = crypto.dump_certificate_request(crypto.FILETYPE_PEM, req).decode()
# First make sure we can open certificate file for writing, # First make sure we can open certificate file for writing,
# to avoid using our token for nothing. # to avoid using our token for nothing.
cert_fd = os.open(cert_path, os.O_CREAT | os.O_WRONLY, 0666) cert_fd = os.open(cert_path, os.O_CREAT | os.O_WRONLY, 0o666)
print "Requesting certificate ..." print("Requesting certificate ...")
if config.location: if config.location:
cert = s.requestCertificate(token, req, location=config.location) cert = s.requestCertificate(token, req, location=config.location)
else: else:
...@@ -173,7 +175,7 @@ def main(): ...@@ -173,7 +175,7 @@ def main():
key_path)) key_path))
if not os.path.lexists(conf_path): if not os.path.lexists(conf_path):
create(conf_path, """\ create(conf_path, ("""\
registry %s registry %s
ca %s ca %s
cert %s cert %s
...@@ -187,14 +189,14 @@ key %s ...@@ -187,14 +189,14 @@ key %s
#O--verb #O--verb
#O3 #O3
""" % (config.registry, ca_path, cert_path, key_path, """ % (config.registry, ca_path, cert_path, key_path,
('country ' + config.location.split(',', 1)[0]) \ ('country ' + config.location.split(',', 1)[0])
if config.location else '')) if config.location else '')).encode())
print "Sample configuration file created." print("Sample configuration file created.")
cn = x509.subnetFromCert(cert) cn = x509.subnetFromCert(cert)
subnet = network + utils.binFromSubnet(cn) subnet = network + utils.binFromSubnet(cn)
print "Your subnet: %s/%u (CN=%s)" \ print("Your subnet: %s/%u (CN=%s)"
% (utils.ipFromBin(subnet), len(subnet), cn) % (utils.ipFromBin(subnet), len(subnet), cn))
if __name__ == "__main__": if __name__ == "__main__":
main() main()
#!/usr/bin/python2 #!/usr/bin/env python3
import atexit, errno, logging, os, shutil, signal import atexit, errno, logging, os, shutil, signal
import socket, struct, subprocess, sys import socket, struct, subprocess, sys
from collections import deque from collections import deque
...@@ -246,7 +246,7 @@ def main(): ...@@ -246,7 +246,7 @@ def main():
try: try:
from re6st.upnpigd import Forwarder from re6st.upnpigd import Forwarder
forwarder = Forwarder('re6stnet openvpn server') forwarder = Forwarder('re6stnet openvpn server')
except Exception, e: except Exception as e:
if ipv4: if ipv4:
raise raise
logging.info("%s: assume we are not NATed", e) logging.info("%s: assume we are not NATed", e)
...@@ -266,19 +266,13 @@ def main(): ...@@ -266,19 +266,13 @@ def main():
def call(cmd): def call(cmd):
logging.debug('%r', cmd) logging.debug('%r', cmd)
p = subprocess.Popen(cmd, stdout=subprocess.PIPE, return subprocess.run(cmd, capture_output=True, check=True).stdout
stderr=subprocess.PIPE) def ip4(object: str, *args):
stdout, stderr = p.communicate()
if p.returncode:
raise EnvironmentError("%r failed with error %u\n%s"
% (' '.join(cmd), p.returncode, stderr))
return stdout
def ip4(object, *args):
args = ['ip', '-4', object, 'add'] + list(args) args = ['ip', '-4', object, 'add'] + list(args)
call(args) call(args)
args[3] = 'del' args[3] = 'del'
cleanup.append(lambda: subprocess.call(args)) cleanup.append(lambda: subprocess.call(args))
def ip(object, *args): def ip(object: str, *args):
args = ['ip', '-6', object, 'add'] + list(args) args = ['ip', '-6', object, 'add'] + list(args)
call(args) call(args)
args[3] = 'del' args[3] = 'del'
...@@ -299,7 +293,7 @@ def main(): ...@@ -299,7 +293,7 @@ def main():
timeout = 4 * cache.hello timeout = 4 * cache.hello
cleanup = [lambda: cache.cacheMinimize(config.client_count), cleanup = [lambda: cache.cacheMinimize(config.client_count),
lambda: shutil.rmtree(config.run, True)] lambda: shutil.rmtree(config.run, True)]
utils.makedirs(config.run, 0700) utils.makedirs(config.run, 0o700)
control_socket = os.path.join(config.run, 'babeld.sock') control_socket = os.path.join(config.run, 'babeld.sock')
if config.client_count and not config.client: if config.client_count and not config.client:
tunnel_manager = tunnel.TunnelManager(control_socket, tunnel_manager = tunnel.TunnelManager(control_socket,
...@@ -362,7 +356,7 @@ def main(): ...@@ -362,7 +356,7 @@ def main():
if not dh: if not dh:
dh = os.path.join(config.state, "dh.pem") dh = os.path.join(config.state, "dh.pem")
cache.getDh(dh) cache.getDh(dh)
for iface, (port, proto) in server_tunnels.iteritems(): for iface, (port, proto) in server_tunnels.items():
r, x = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM) r, x = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)
utils.setCloexec(r) utils.setCloexec(r)
cleanup.append(plib.server(iface, config.max_clients, cleanup.append(plib.server(iface, config.max_clients,
...@@ -442,7 +436,7 @@ def main(): ...@@ -442,7 +436,7 @@ def main():
except: except:
pass pass
exit.release() exit.release()
except ReexecException, e: except ReexecException as e:
logging.info(e) logging.info(e)
except Exception: except Exception:
utils.log_exception() utils.log_exception()
...@@ -455,7 +449,7 @@ def main(): ...@@ -455,7 +449,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
try: try:
main() main()
except SystemExit, e: except SystemExit as e:
if type(e.code) is str: if type(e.code) is str:
if hasattr(logging, 'trace'): # utils.setupLog called if hasattr(logging, 'trace'): # utils.setupLog called
logging.critical(e.code) logging.critical(e.code)
......
#!/usr/bin/python2 #!/usr/bin/env python3
import httplib, logging, os, socket, sys import http.client, logging, os, socket, sys
from BaseHTTPServer import BaseHTTPRequestHandler from http.server import BaseHTTPRequestHandler
from SocketServer import ThreadingTCPServer from socketserver import ThreadingTCPServer
from urlparse import parse_qsl from urllib.parse import parse_qsl
if 're6st' not in sys.modules: if 're6st' not in sys.modules:
sys.path[0] = os.path.dirname(os.path.dirname(sys.path[0])) sys.path[0] = os.path.dirname(os.path.dirname(sys.path[0]))
from re6st import registry, utils, version from re6st import registry, utils, version
...@@ -29,14 +29,14 @@ class RequestHandler(BaseHTTPRequestHandler): ...@@ -29,14 +29,14 @@ class RequestHandler(BaseHTTPRequestHandler):
path = self.path path = self.path
query = {} query = {}
else: else:
query = dict(parse_qsl(query, keep_blank_values=1, query = dict(parse_qsl(query, keep_blank_values=True,
strict_parsing=1)) strict_parsing=True))
_, path = path.split('/') _, path = path.split('/')
if not _: if not _:
return self.server.handle_request(self, path, query) return self.server.handle_request(self, path, query)
except Exception: except Exception:
logging.info(self.requestline, exc_info=1) logging.info(self.requestline, exc_info=True)
self.send_error(httplib.BAD_REQUEST) self.send_error(http.client.BAD_REQUEST)
def log_error(*args): def log_error(*args):
pass pass
......
...@@ -5,7 +5,7 @@ from . import utils ...@@ -5,7 +5,7 @@ from . import utils
uint16 = struct.Struct("!H") uint16 = struct.Struct("!H")
header = struct.Struct("!HI") header = struct.Struct("!HI")
class Struct(object): class Struct:
def __init__(self, format, *args): def __init__(self, format, *args):
if args: if args:
...@@ -29,39 +29,39 @@ class Struct(object): ...@@ -29,39 +29,39 @@ class Struct(object):
self.encode = encode self.encode = encode
self.decode = decode self.decode = decode
class Array(object): class Array:
def __init__(self, item): def __init__(self, item):
self._item = item self._item = item
def encode(self, buffer, value): def encode(self, buffer: bytes, value: list):
buffer += uint16.pack(len(value)) buffer += uint16.pack(len(value))
encode = self._item.encode encode = self._item.encode
for value in value: for value in value:
encode(buffer, value) encode(buffer, value)
def decode(self, buffer, offset=0): def decode(self, buffer: bytes, offset=0) -> tuple[int, list]:
r = [] r = []
o = offset + 2 o = offset + 2
decode = self._item.decode decode = self._item.decode
for i in xrange(*uint16.unpack_from(buffer, offset)): for i in range(*uint16.unpack_from(buffer, offset)):
o, x = decode(buffer, o) o, x = decode(buffer, o)
r.append(x) r.append(x)
return o, r return o, r
class String(object): class String:
@staticmethod @staticmethod
def encode(buffer, value): def encode(buffer: bytes, value: str):
buffer += value + "\0" buffer += value.encode("utf-8") + b'\0'
@staticmethod @staticmethod
def decode(buffer, offset=0): def decode(buffer: bytes, offset=0) -> tuple[int, str]:
i = buffer.index("\0", offset) i = buffer.index(0, offset)
return i + 1, buffer[offset:i] return i + 1, buffer[offset:i].decode("utf-8")
class Buffer(object): class Buffer:
def __init__(self): def __init__(self):
self._buf = bytearray() self._buf = bytearray()
...@@ -104,21 +104,6 @@ class Buffer(object): ...@@ -104,21 +104,6 @@ class Buffer(object):
self._seek(r) self._seek(r)
return value return value
try: # BBB: Python < 2.7.4 (http://bugs.python.org/issue10212)
uint16.unpack_from(bytearray(uint16.size))
except TypeError:
def unpack_from(self, struct):
r = self._r
x = r + struct.size
value = struct.unpack(buffer(self._buf)[r:x])
self._seek(x)
return value
def decode(self, decode):
r = self._r
size, value = decode(buffer(self._buf)[r:])
self._seek(r + size)
return value
# writing # writing
def send(self, socket, *args): def send(self, socket, *args):
...@@ -129,7 +114,7 @@ class Buffer(object): ...@@ -129,7 +114,7 @@ class Buffer(object):
struct.pack_into(self._buf, offset, *args) struct.pack_into(self._buf, offset, *args)
class Packet(object): class Packet:
response_dict = {} response_dict = {}
...@@ -149,7 +134,7 @@ class Packet(object): ...@@ -149,7 +134,7 @@ class Packet(object):
logging.trace('send %s%r', self.__class__.__name__, logging.trace('send %s%r', self.__class__.__name__,
(self.id,) + self.args) (self.id,) + self.args)
offset = len(buffer) offset = len(buffer)
buffer += '\0' * header.size buffer += bytes(header.size)
r = self.request r = self.request
if isinstance(r, Struct): if isinstance(r, Struct):
r.encode(buffer, self.args) r.encode(buffer, self.args)
...@@ -182,11 +167,11 @@ class ConnectionClosed(BabelException): ...@@ -182,11 +167,11 @@ class ConnectionClosed(BabelException):
return "connection to babeld closed (%s)" % self.args return "connection to babeld closed (%s)" % self.args
class Babel(object): class Babel:
_decode = None _decode = None
def __init__(self, socket_path, handler, network): def __init__(self, socket_path: str, handler, network: str):
self.socket_path = socket_path self.socket_path = socket_path
self.handler = handler self.handler = handler
self.network = network self.network = network
...@@ -206,11 +191,11 @@ class Babel(object): ...@@ -206,11 +191,11 @@ class Babel(object):
def select(*args): def select(*args):
try: try:
s.connect(self.socket_path) s.connect(self.socket_path)
except socket.error, e: except socket.error as e:
logging.debug("Can't connect to %r (%r)", self.socket_path, e) logging.debug("Can't connect to %r (%r)", self.socket_path, e)
return e return e
s.send("\1") s.send(b'\1')
s.setblocking(0) s.setblocking(False)
del self.select del self.select
self.socket = s self.socket = s
return self.select(*args) return self.select(*args)
...@@ -267,15 +252,18 @@ class Babel(object): ...@@ -267,15 +252,18 @@ class Babel(object):
unidentified = set(n) unidentified = set(n)
self.neighbours = neighbours = {} self.neighbours = neighbours = {}
a = len(self.network) a = len(self.network)
logging.info("Routes: %r", routes)
for route in routes: for route in routes:
assert route.flags & 1, route # installed assert route.flags & 1, route # installed
if route.prefix.startswith('\0\0\0\0\0\0\0\0\0\0\xff\xff'): if route.prefix.startswith(b'\0\0\0\0\0\0\0\0\0\0\xff\xff'):
logging.warning("Ignoring IPv4 route: %r", route)
continue continue
assert route.neigh_address == route.nexthop, route assert route.neigh_address == route.nexthop, route
address = route.neigh_address, route.ifindex address = route.neigh_address, route.ifindex
neigh_routes = n[address] neigh_routes = n[address]
ip = utils.binFromRawIp(route.prefix) ip = utils.binFromRawIp(route.prefix)
if ip[:a] == self.network: if ip[:a] == self.network:
logging.debug("Route is on the network: %r", route)
prefix = ip[a:route.plen] prefix = ip[a:route.plen]
if prefix and not route.refmetric: if prefix and not route.refmetric:
neighbours[prefix] = neigh_routes neighbours[prefix] = neigh_routes
...@@ -290,7 +278,9 @@ class Babel(object): ...@@ -290,7 +278,9 @@ class Babel(object):
socket.inet_ntop(socket.AF_INET6, route.prefix), socket.inet_ntop(socket.AF_INET6, route.prefix),
route.plen) route.plen)
else: else:
logging.debug("Route is not on the network: %r", route)
prefix = None prefix = None
logging.debug("Adding route %r to %r", route, neigh_routes)
neigh_routes[1][prefix] = route neigh_routes[1][prefix] = route
self.locked.clear() self.locked.clear()
if unidentified: if unidentified:
...@@ -310,11 +300,11 @@ class Babel(object): ...@@ -310,11 +300,11 @@ class Babel(object):
pass pass
class iterRoutes(object): class iterRoutes:
_waiting = True _waiting = True
def __new__(cls, control_socket, network): def __new__(cls, control_socket: str, network: str):
self = object.__new__(cls) self = object.__new__(cls)
c = Babel(control_socket, self, network) c = Babel(control_socket, self, network)
c.request_dump() c.request_dump()
...@@ -323,7 +313,7 @@ class iterRoutes(object): ...@@ -323,7 +313,7 @@ class iterRoutes(object):
c.select(*args) c.select(*args)
utils.select(*args) utils.select(*args)
return (prefix return (prefix
for neigh_routes in c.neighbours.itervalues() for neigh_routes in c.neighbours.values()
for prefix in neigh_routes[1] for prefix in neigh_routes[1]
if prefix) if prefix)
......
import errno, os, socket, stat, threading import errno, os, socket, stat, threading
class Socket(object): class Socket:
def __init__(self, socket): def __init__(self, socket: socket.socket):
# In case that the default timeout is not None. # In case that the default timeout is not None.
socket.settimeout(None) socket.settimeout(None)
self._socket = socket self._socket = socket
self._buf = '' self._buf = b''
def close(self): def close(self):
self._socket.close() self._socket.close()
def write(self, data): def write(self, data: bytes):
self._socket.send(data) self._socket.send(data)
def readline(self): def readline(self) -> bytes:
recv = self._socket.recv recv = self._socket.recv
data = self._buf data = self._buf
while True: while True:
i = 1 + data.find('\n') i = 1 + data.find(b'\n')
if i: if i:
self._buf = data[i:] self._buf = data[i:]
return data[:i] return data[:i]
d = recv(4096) d = recv(4096)
data += d data += d
if not d: if not d:
self._buf = '' self._buf = b''
return data return data
def flush(self): def flush(self):
...@@ -37,14 +37,14 @@ class Socket(object): ...@@ -37,14 +37,14 @@ class Socket(object):
try: try:
self._socket.recv(0) self._socket.recv(0)
return True return True
except socket.error, (err, _): except socket.error as e:
if err != errno.EAGAIN: if e.errno != errno.EAGAIN:
raise raise
self._socket.setblocking(1) self._socket.setblocking(1)
return False return False
class Console(object): class Console:
def __init__(self, path, pdb): def __init__(self, path, pdb):
self.path = path self.path = path
...@@ -52,7 +52,7 @@ class Console(object): ...@@ -52,7 +52,7 @@ class Console(object):
socket.SOCK_STREAM | socket.SOCK_CLOEXEC) socket.SOCK_STREAM | socket.SOCK_CLOEXEC)
try: try:
self._removeSocket() self._removeSocket()
except OSError, e: except OSError as e:
if e.errno != errno.ENOENT: if e.errno != errno.ENOENT:
raise raise
s.bind(path) s.bind(path)
......
...@@ -43,7 +43,7 @@ freeifaddrs = libc.freeifaddrs ...@@ -43,7 +43,7 @@ freeifaddrs = libc.freeifaddrs
freeifaddrs.restype = None freeifaddrs.restype = None
freeifaddrs.argtypes = [POINTER(struct_ifaddrs)] freeifaddrs.argtypes = [POINTER(struct_ifaddrs)]
class unpacker(object): class unpacker:
def __init__(self, buf): def __init__(self, buf):
self._buf = buf self._buf = buf
...@@ -55,7 +55,7 @@ class unpacker(object): ...@@ -55,7 +55,7 @@ class unpacker(object):
self._offset += s.size self._offset += s.size
return result return result
class PimDm(object): class PimDm:
def __init__(self): def __init__(self):
s_netlink = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE) s_netlink = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE)
......
#!/usr/bin/python2 -S #!/usr/bin/env -S python3 -S
import os, sys import os, sys
script_type = os.environ['script_type'] script_type = os.environ['script_type']
...@@ -14,4 +14,5 @@ if script_type == 'up': ...@@ -14,4 +14,5 @@ if script_type == 'up':
if script_type == 'route-up': if script_type == 'route-up':
import time import time
os.write(int(sys.argv[1]), repr((os.environ['common_name'], time.time(), os.write(int(sys.argv[1]), repr((os.environ['common_name'], time.time(),
int(os.environ['tls_serial_0']), os.environ['OPENVPN_external_ip']))) int(os.environ['tls_serial_0']), os.environ['OPENVPN_external_ip']))
.encode())
#!/usr/bin/python2 -S #!/usr/bin/env -S python3 -S
import os, sys import os, sys
script_type = os.environ['script_type'] script_type = os.environ['script_type']
...@@ -7,10 +7,11 @@ external_ip = os.getenv('trusted_ip') or os.environ['trusted_ip6'] ...@@ -7,10 +7,11 @@ external_ip = os.getenv('trusted_ip') or os.environ['trusted_ip6']
# Write into pipe connect/disconnect events # Write into pipe connect/disconnect events
fd = int(sys.argv[1]) fd = int(sys.argv[1])
os.write(fd, repr((script_type, (os.environ['common_name'], os.environ['dev'], os.write(fd, repr((script_type, (os.environ['common_name'], os.environ['dev'],
int(os.environ['tls_serial_0']), external_ip)))) int(os.environ['tls_serial_0']), external_ip)))
.encode("utf-8"))
if script_type == 'client-connect': if script_type == 'client-connect':
if os.read(fd, 1) == '\0': if os.read(fd, 1) == b'\0':
sys.exit(1) sys.exit(1)
# Send client its external ip address # Send client its external ip address
with open(sys.argv[2], 'w') as f: with open(sys.argv[2], 'w') as f:
......
import binascii
import logging, errno, os import logging, errno, os
from typing import Optional
from . import utils from . import utils
here = os.path.realpath(os.path.dirname(__file__)) here = os.path.realpath(os.path.dirname(__file__))
ovpn_server = os.path.join(here, 'ovpn-server') ovpn_server = os.path.join(here, 'ovpn-server')
ovpn_client = os.path.join(here, 'ovpn-client') ovpn_client = os.path.join(here, 'ovpn-client')
ovpn_log = None ovpn_log: Optional[str] = None
def openvpn(iface, encrypt, *args, **kw): def openvpn(iface: str, encrypt, *args, **kw) -> utils.Popen:
args = ['openvpn', args = ['openvpn',
'--dev-type', 'tap', '--dev-type', 'tap',
'--dev', iface, '--dev', iface,
...@@ -19,13 +21,16 @@ def openvpn(iface, encrypt, *args, **kw): ...@@ -19,13 +21,16 @@ def openvpn(iface, encrypt, *args, **kw):
if ovpn_log: if ovpn_log:
args += '--log-append', os.path.join(ovpn_log, '%s.log' % iface), args += '--log-append', os.path.join(ovpn_log, '%s.log' % iface),
if not encrypt: if not encrypt:
# TODO: --ncp-disable was deprecated in OpenVPN 2.5 and removed in 2.6
# and is no longer necessary in those versions.
args += '--cipher', 'none', '--ncp-disable' args += '--cipher', 'none', '--ncp-disable'
logging.debug('%r', args) logging.debug('%r', args)
return utils.Popen(args, **kw) return utils.Popen(args, **kw)
ovpn_link_mtu_dict = {'udp4': 1432, 'udp6': 1450} ovpn_link_mtu_dict = {'udp4': 1432, 'udp6': 1450}
def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw): def server(iface: str, max_clients: int, dh_path: str, fd: int,
port: int, proto: str, encrypt: bool, *args, **kw) -> utils.Popen:
if proto == 'udp': if proto == 'udp':
proto = 'udp4' proto = 'udp4'
client_script = '%s %s' % (ovpn_server, fd) client_script = '%s %s' % (ovpn_server, fd)
...@@ -43,10 +48,11 @@ def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw): ...@@ -43,10 +48,11 @@ def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw):
'--max-clients', str(max_clients), '--max-clients', str(max_clients),
'--port', str(port), '--port', str(port),
'--proto', proto, '--proto', proto,
*args, **kw) *args, pass_fds=[fd], **kw)
def client(iface, address_list, encrypt, *args, **kw): def client(iface: str, address_list: list[tuple[str, int, str]],
encrypt: bool, *args, **kw) -> utils.Popen:
remote = ['--nobind', '--client'] remote = ['--nobind', '--client']
# XXX: We'd like to pass <connection> sections at command-line. # XXX: We'd like to pass <connection> sections at command-line.
link_mtu = set() link_mtu = set()
...@@ -62,8 +68,10 @@ def client(iface, address_list, encrypt, *args, **kw): ...@@ -62,8 +68,10 @@ def client(iface, address_list, encrypt, *args, **kw):
return openvpn(iface, encrypt, *remote, **kw) return openvpn(iface, encrypt, *remote, **kw)
def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile, def router(ip: tuple[str, int], ip4, rt6: tuple[str, bool, bool],
control_socket, default, hmac, *args, **kw): hello_interval: int, log_path: str, state_path: str, pidfile: str,
control_socket: str, default: str,
hmac: tuple[bytes | None, bytes | None], *args, **kw) -> utils.Popen:
network, gateway, has_ipv6_subtrees = rt6 network, gateway, has_ipv6_subtrees = rt6
network_mask = int(network[network.index('/')+1:]) network_mask = int(network[network.index('/')+1:])
ip, n = ip ip, n = ip
...@@ -80,9 +88,9 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile, ...@@ -80,9 +88,9 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
'-C', 'redistribute local deny', '-C', 'redistribute local deny',
'-C', 'redistribute ip %s/%s eq %s' % (ip, n, n)] '-C', 'redistribute ip %s/%s eq %s' % (ip, n, n)]
if hmac_sign: if hmac_sign:
def key(cmd, id, value): def key(cmd: list[str], id: str, value: bytes):
cmd += '-C', ('key type blake2s128 id %s value %s' % cmd += '-C', ('key type blake2s128 id %s value %s' %
(id, value.encode('hex'))) (id, binascii.hexlify(value).decode()))
key(cmd, 'sign', hmac_sign) key(cmd, 'sign', hmac_sign)
default += ' key sign' default += ' key sign'
if hmac_accept is not None: if hmac_accept is not None:
...@@ -132,7 +140,7 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile, ...@@ -132,7 +140,7 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
# WKRD: babeld fails to start if pidfile already exists # WKRD: babeld fails to start if pidfile already exists
try: try:
os.remove(pidfile) os.remove(pidfile)
except OSError, e: except OSError as e:
if e.errno != errno.ENOENT: if e.errno != errno.ENOENT:
raise raise
logging.info('%r', cmd) logging.info('%r', cmd)
......
...@@ -18,16 +18,19 @@ Authenticated communication: ...@@ -18,16 +18,19 @@ Authenticated communication:
- the last one that was really used by the client (!hello) - the last one that was really used by the client (!hello)
- the one of the last handshake (hello) - the one of the last handshake (hello)
""" """
import base64, hmac, hashlib, httplib, inspect, json, logging import base64, hmac, hashlib, http.client, inspect, json, logging
import mailbox, os, platform, random, select, smtplib, socket, sqlite3 import mailbox, os, platform, random, select, smtplib, socket, sqlite3
import string, sys, threading, time, weakref, zlib import string, sys, threading, time, weakref, zlib
from collections import defaultdict, deque from collections import defaultdict, deque
from collections.abc import Iterator
from datetime import datetime from datetime import datetime
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler from http.server import HTTPServer, BaseHTTPRequestHandler
from email.mime.text import MIMEText from email.mime.text import MIMEText
from operator import itemgetter from operator import itemgetter
from typing import Tuple
from OpenSSL import crypto from OpenSSL import crypto
from urllib import splittype, splithost, unquote, urlencode from urllib.parse import urlparse, unquote, urlencode
from . import ctl, tunnel, utils, version, x509 from . import ctl, tunnel, utils, version, x509
HMAC_HEADER = "Re6stHMAC" HMAC_HEADER = "Re6stHMAC"
...@@ -35,13 +38,13 @@ RENEW_PERIOD = 30 * 86400 ...@@ -35,13 +38,13 @@ RENEW_PERIOD = 30 * 86400
BABEL_HMAC = 'babel_hmac0', 'babel_hmac1', 'babel_hmac2' BABEL_HMAC = 'babel_hmac0', 'babel_hmac1', 'babel_hmac2'
def rpc(f): def rpc(f):
args, varargs, varkw, defaults = inspect.getargspec(f) argspec = inspect.getfullargspec(f)
assert not (varargs or varkw), f assert not (argspec.varargs or argspec.varkw), f
if not defaults: sig = inspect.signature(f)
defaults = () sig = sig.replace(parameters=[v.replace(annotation=inspect.Parameter.empty)
i = len(args) - len(defaults) for v in sig.parameters.values()][1:],
f.getcallargs = eval("lambda %s: locals()" % ','.join(args[1:i] return_annotation=inspect.Signature.empty)
+ map("%s=%r".__mod__, zip(args[i:], defaults)))) f.getcallargs = eval("lambda %s: locals()" % str(sig)[1:-1])
return f return f
def rpc_private(f): def rpc_private(f):
...@@ -53,13 +56,15 @@ class HTTPError(Exception): ...@@ -53,13 +56,15 @@ class HTTPError(Exception):
pass pass
class RegistryServer(object): class RegistryServer:
peers = 0, () peers = 0, ()
cert_duration = 365 * 86400 cert_duration = 365 * 86400
sessions: dict[str, list[tuple[bytes, int]]]
def _geoiplookup(self, ip): def _geoiplookup(self, ip):
raise HTTPError(httplib.BAD_REQUEST) raise HTTPError(http.client.BAD_REQUEST)
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
...@@ -76,7 +81,7 @@ class RegistryServer(object): ...@@ -76,7 +81,7 @@ class RegistryServer(object):
if x and not x.startswith('#'): if x and not x.startswith('#'):
x = x.split() x = x.split()
self.community_map[x.pop(0)] = x self.community_map[x.pop(0)] = x
if sum('*' in x for x in self.community_map.itervalues()) != 1: if sum('*' in x for x in self.community_map.values()) != 1:
sys.exit("Invalid community configuration: missing or multiple default location ('*')") sys.exit("Invalid community configuration: missing or multiple default location ('*')")
else: else:
self.community_map[''] = '*' self.community_map[''] = '*'
...@@ -91,7 +96,7 @@ class RegistryServer(object): ...@@ -91,7 +96,7 @@ class RegistryServer(object):
"name TEXT PRIMARY KEY NOT NULL", "name TEXT PRIMARY KEY NOT NULL",
"value") "value")
self.prefix = self.getConfig("prefix", None) self.prefix = self.getConfig("prefix", None)
self.version = str(self.getConfig("version", "\0")) # BBB: blob self.version = self.getConfig("version", b'\0')
utils.sqliteCreateTable(self.db, "token", utils.sqliteCreateTable(self.db, "token",
"token TEXT PRIMARY KEY NOT NULL", "token TEXT PRIMARY KEY NOT NULL",
"email TEXT NOT NULL", "email TEXT NOT NULL",
...@@ -102,6 +107,7 @@ class RegistryServer(object): ...@@ -102,6 +107,7 @@ class RegistryServer(object):
"email TEXT", "email TEXT",
"cert TEXT") "cert TEXT")
if not self.db.execute("SELECT 1 FROM cert LIMIT 1").fetchone(): if not self.db.execute("SELECT 1 FROM cert LIMIT 1").fetchone():
logging.debug("No existing certs found; creating an unallocated cert")
self.db.execute("INSERT INTO cert VALUES ('',null,null)") self.db.execute("INSERT INTO cert VALUES ('',null,null)")
prev = '-' prev = '-'
...@@ -139,10 +145,10 @@ class RegistryServer(object): ...@@ -139,10 +145,10 @@ class RegistryServer(object):
if self.geoip_db: if self.geoip_db:
from geoip2 import database, errors from geoip2 import database, errors
country = database.Reader(self.geoip_db).country country = database.Reader(self.geoip_db).country
def geoiplookup(ip): def geoiplookup(ip: str) -> Tuple[str, str]:
try: try:
req = country(ip) req = country(ip)
return req.country.iso_code.encode(), req.continent.code.encode() return req.country.iso_code, req.continent.code
except (errors.AddressNotFoundError, ValueError): except (errors.AddressNotFoundError, ValueError):
return '*', '*' return '*', '*'
self._geoiplookup = geoiplookup self._geoiplookup = geoiplookup
...@@ -157,6 +163,11 @@ class RegistryServer(object): ...@@ -157,6 +163,11 @@ class RegistryServer(object):
else: else:
self.newHMAC(0) self.newHMAC(0)
def close(self):
self.sock.close()
self.db.close()
self.ctl.close()
def getConfig(self, name, *default): def getConfig(self, name, *default):
r, = next(self.db.execute( r, = next(self.db.execute(
"SELECT value FROM config WHERE name=?", (name,)), default) "SELECT value FROM config WHERE name=?", (name,)), default)
...@@ -169,8 +180,8 @@ class RegistryServer(object): ...@@ -169,8 +180,8 @@ class RegistryServer(object):
def updateNetworkConfig(self, _it0=itemgetter(0)): def updateNetworkConfig(self, _it0=itemgetter(0)):
kw = { kw = {
'babel_default': 'max-rtt-penalty 5000 rtt-max 500 rtt-decay 125', 'babel_default': 'max-rtt-penalty 5000 rtt-max 500 rtt-decay 125',
'crl': map(_it0, self.db.execute( 'crl': list(map(_it0, self.db.execute(
"SELECT serial FROM crl ORDER BY serial")), "SELECT serial FROM crl ORDER BY serial"))),
'protocol': version.protocol, 'protocol': version.protocol,
'registry_prefix': self.prefix, 'registry_prefix': self.prefix,
} }
...@@ -184,32 +195,32 @@ class RegistryServer(object): ...@@ -184,32 +195,32 @@ class RegistryServer(object):
config = json.dumps(kw, sort_keys=True) config = json.dumps(kw, sort_keys=True)
if config != self.getConfig('last_config', None): if config != self.getConfig('last_config', None):
self.increaseVersion() self.increaseVersion()
# BBB: Use buffer because of http://bugs.python.org/issue13676 self.setConfig('version', self.version)
# on Python 2.6
self.setConfig('version', buffer(self.version))
self.setConfig('last_config', config) self.setConfig('last_config', config)
self.sendto(self.prefix, 0) self.sendto(self.prefix, 0)
# The following entry lists values that are base64-encoded. # The following entry lists values that are base64-encoded.
kw[''] = 'version', kw[''] = 'version',
kw['version'] = self.version.encode('base64') kw['version'] = base64.b64encode(self.version).decode()
self.network_config = kw self.network_config = kw
def increaseVersion(self): def increaseVersion(self):
x = utils.packInteger(1 + utils.unpackInteger(self.version)[0]) x = utils.packInteger(1 + utils.unpackInteger(self.version)[0])
self.version = x + self.cert.sign(x) self.version = x + self.cert.sign(x)
def sendto(self, prefix, code): def sendto(self, prefix: str, code: int):
self.sock.sendto("%s\0%c" % (prefix, code), ('::1', tunnel.PORT)) self.sock.sendto(prefix.encode() + bytes((0, code)),
('::1', tunnel.PORT))
def recv(self, code): def recv(self, code: int) -> tuple[str, str] | tuple[None, None]:
try: try:
prefix, msg = self.sock.recv(1<<16).split('\0', 1) prefix, msg = self.sock.recv(1 << 16).split(b'\0', 1)
int(prefix, 2) int(prefix, 2)
except ValueError: except ValueError:
pass pass
else: else:
if msg and ord(msg[0]) == code: if len(msg) >= 1 and msg[0] == code:
return prefix, msg[1:] return prefix.decode(), msg[1:].decode()
logging.error("Invalid message or unexpected code: %r", msg)
return None, None return None, None
def select(self, r, w, t): def select(self, r, w, t):
...@@ -235,7 +246,7 @@ class RegistryServer(object): ...@@ -235,7 +246,7 @@ class RegistryServer(object):
def babel_dump(self): def babel_dump(self):
self._wait_dump = False self._wait_dump = False
def iterCert(self): def iterCert(self) -> Iterator[Tuple[crypto.X509, str, str]]:
for prefix, email, cert in self.db.execute( for prefix, email, cert in self.db.execute(
"SELECT * FROM cert WHERE cert IS NOT NULL"): "SELECT * FROM cert WHERE cert IS NOT NULL"):
try: try:
...@@ -263,8 +274,10 @@ class RegistryServer(object): ...@@ -263,8 +274,10 @@ class RegistryServer(object):
x = x509.notAfter(cert) x = x509.notAfter(cert)
if x <= old: if x <= old:
if prefix == self.prefix: if prefix == self.prefix:
logging.critical("Refuse to delete certificate" logging.critical(
" of main node: wrong clock ?") "Refuse to delete certificate of main node:"
" wrong clock ? Alternatively, the database"
" might be in an inconsistent state.")
sys.exit(1) sys.exit(1)
logging.info("Delete %s: %s (invalid since %s)", logging.info("Delete %s: %s (invalid since %s)",
"certificate requested by '%s'" % email "certificate requested by '%s'" % email
...@@ -286,14 +299,15 @@ class RegistryServer(object): ...@@ -286,14 +299,15 @@ class RegistryServer(object):
x_forwarded_for = request.headers.get('X-Forwarded-For') x_forwarded_for = request.headers.get('X-Forwarded-For')
if request.client_address[0] not in authorized_origin or \ if request.client_address[0] not in authorized_origin or \
x_forwarded_for and x_forwarded_for not in authorized_origin: x_forwarded_for and x_forwarded_for not in authorized_origin:
return request.send_error(httplib.FORBIDDEN) return request.send_error(http.client.FORBIDDEN)
key = m.getcallargs(**kw).get('cn') key = m.getcallargs(**kw).get('cn')
if key: if key:
h = base64.b64decode(request.headers[HMAC_HEADER]) h = base64.b64decode(request.headers[HMAC_HEADER])
with self.lock: with self.lock:
session = self.sessions[key] session = self.sessions[key]
for key, protocol in session: for key, protocol in session:
if h == hmac.HMAC(key, request.path, hashlib.sha1).digest(): if h == hmac.HMAC(key, request.path.encode(), hashlib.sha1
).digest():
break break
else: else:
raise Exception("Wrong HMAC") raise Exception("Wrong HMAC")
...@@ -313,29 +327,31 @@ class RegistryServer(object): ...@@ -313,29 +327,31 @@ class RegistryServer(object):
request.headers.get("host")) request.headers.get("host"))
try: try:
result = m(**kw) result = m(**kw)
except HTTPError, e: except HTTPError as e:
return request.send_error(*e.args) return request.send_error(*e.args)
except: except:
logging.warning(request.requestline, exc_info=1) logging.warning(request.requestline, exc_info=True)
return request.send_error(httplib.INTERNAL_SERVER_ERROR) return request.send_error(http.client.INTERNAL_SERVER_ERROR)
if result: if result:
request.send_response(httplib.OK) if type(result) is str:
result = result.encode("utf-8")
request.send_response(http.client.OK)
request.send_header("Content-Length", str(len(result))) request.send_header("Content-Length", str(len(result)))
else: else:
request.send_response(httplib.NO_CONTENT) request.send_response(http.client.NO_CONTENT)
if key: if key:
request.send_header(HMAC_HEADER, base64.b64encode( request.send_header(HMAC_HEADER, base64.b64encode(
hmac.HMAC(key, result, hashlib.sha1).digest())) hmac.HMAC(key, result, hashlib.sha1).digest()).decode("ascii"))
request.end_headers() request.end_headers()
if result: if result:
request.wfile.write(result) request.wfile.write(result)
def getPeerProtocol(self, cn): def getPeerProtocol(self, cn: str) -> int:
session, = self.sessions[cn] session, = self.sessions[cn]
return session[1] return session[1]
@rpc @rpc
def hello(self, client_prefix, protocol='1'): def hello(self, client_prefix: str, protocol='1') -> bytes:
with self.lock: with self.lock:
cert = self.getCert(client_prefix) cert = self.getCert(client_prefix)
key = utils.newHmacSecret() key = utils.newHmacSecret()
...@@ -345,29 +361,32 @@ class RegistryServer(object): ...@@ -345,29 +361,32 @@ class RegistryServer(object):
assert len(key) == len(sign) assert len(key) == len(sign)
return key + sign return key + sign
def getCert(self, client_prefix): def getCert(self, client_prefix: str) -> bytes:
assert self.lock.locked() assert self.lock.locked()
return self.db.execute("SELECT cert FROM cert" cert = self.db.execute("SELECT cert FROM cert"
" WHERE prefix=? AND cert IS NOT NULL", " WHERE prefix=? AND cert IS NOT NULL",
(client_prefix,)).next()[0] (client_prefix,)).fetchone()
assert cert, (f"No cert result for prefix '{client_prefix}';"
f" this should not happen, DB is inconsistent")
return cert[0]
@rpc_private @rpc_private
def isToken(self, token): def isToken(self, token: str):
with self.lock: with self.lock:
if self.db.execute("SELECT 1 FROM token WHERE token = ?", if self.db.execute("SELECT 1 FROM token WHERE token = ?",
(token,)).fetchone(): (token,)).fetchone():
return "1" return b"1"
@rpc_private @rpc_private
def deleteToken(self, token): def deleteToken(self, token: str):
with self.lock: with self.lock:
self.db.execute("DELETE FROM token WHERE token = ?", (token,)) self.db.execute("DELETE FROM token WHERE token = ?", (token,))
@rpc_private @rpc_private
def addToken(self, email, token): def addToken(self, email: str, token: str | None) -> str:
prefix_len = self.config.prefix_length prefix_len = self.config.prefix_length
if not prefix_len: if not prefix_len:
raise HTTPError(httplib.FORBIDDEN) raise HTTPError(http.client.FORBIDDEN)
request = token is None request = token is None
with self.lock: with self.lock:
while True: while True:
...@@ -381,7 +400,7 @@ class RegistryServer(object): ...@@ -381,7 +400,7 @@ class RegistryServer(object):
break break
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
if not request: if not request:
raise HTTPError(httplib.CONFLICT) raise HTTPError(http.client.CONFLICT)
self.timeout = 1 self.timeout = 1
if request: if request:
return token return token
...@@ -389,7 +408,7 @@ class RegistryServer(object): ...@@ -389,7 +408,7 @@ class RegistryServer(object):
@rpc @rpc
def requestToken(self, email): def requestToken(self, email):
if not self.config.mailhost: if not self.config.mailhost:
raise HTTPError(httplib.FORBIDDEN) raise HTTPError(http.client.FORBIDDEN)
token = self.addToken(email, None) token = self.addToken(email, None)
...@@ -418,11 +437,11 @@ class RegistryServer(object): ...@@ -418,11 +437,11 @@ class RegistryServer(object):
s.quit() s.quit()
def getCommunity(self, country, continent): def getCommunity(self, country, continent):
for prefix, location_list in self.community_map.iteritems(): for prefix, location_list in self.community_map.items():
if country in location_list: if country in location_list:
return prefix return prefix
default = '' default = ''
for prefix, location_list in self.community_map.iteritems(): for prefix, location_list in self.community_map.items():
if continent in location_list: if continent in location_list:
return prefix return prefix
if '*' in location_list: if '*' in location_list:
...@@ -430,13 +449,14 @@ class RegistryServer(object): ...@@ -430,13 +449,14 @@ class RegistryServer(object):
return default return default
def mergePrefixes(self): def mergePrefixes(self):
logging.debug("Merging prefixes")
q = self.db.execute q = self.db.execute
prev_prefix = None prev_prefix = None
max_len = 128, max_len = 128,
while True: while True:
max_len = q("SELECT max(length(prefix)) FROM cert" max_len = q("SELECT max(length(prefix)) FROM cert"
" WHERE cert is null AND length(prefix) < ?", " WHERE cert is null AND length(prefix) < ?",
max_len).next() max_len).fetchone()
if not max_len[0]: if not max_len[0]:
break break
for prefix, in q("SELECT prefix FROM cert" for prefix, in q("SELECT prefix FROM cert"
...@@ -452,6 +472,7 @@ class RegistryServer(object): ...@@ -452,6 +472,7 @@ class RegistryServer(object):
prev_prefix = prefix prev_prefix = prefix
def newPrefix(self, prefix_len, community): def newPrefix(self, prefix_len, community):
logging.info("Allocating /%u prefix for %s", prefix_len, community)
community_len = len(community) community_len = len(community)
prefix_len += community_len prefix_len += community_len
max_len = 128 - len(self.network) max_len = 128 - len(self.network)
...@@ -460,25 +481,25 @@ class RegistryServer(object): ...@@ -460,25 +481,25 @@ class RegistryServer(object):
while True: while True:
try: try:
# Find longest free prefix whithin community. # Find longest free prefix whithin community.
prefix, = q( prefix, = next(q(
"SELECT prefix FROM cert" "SELECT prefix FROM cert"
" WHERE prefix LIKE ?" " WHERE prefix LIKE ?"
" AND length(prefix) <= ? AND cert is null" " AND length(prefix) <= ? AND cert is null"
" ORDER BY length(prefix) DESC", " ORDER BY length(prefix) DESC",
(community + '%', prefix_len)).next() (community + '%', prefix_len)))
except StopIteration: except StopIteration:
# Community not yet allocated? # Community not yet allocated?
# There should be exactly 1 row whose # There should be exactly 1 row whose
# prefix is the beginning of community. # prefix is the beginning of community.
prefix, x = q("SELECT prefix, cert FROM cert" prefix, x = next(q("SELECT prefix, cert FROM cert"
" WHERE substr(?,1,length(prefix)) = prefix", " WHERE substr(?,1,length(prefix)) = prefix",
(community,)).next() (community,)))
if x is not None: if x is not None:
logging.error('No more free /%u prefix available', logging.error('No more free /%u prefix available',
prefix_len) prefix_len)
raise raise
# Split the tree until prefix has wanted length. # Split the tree until prefix has wanted length.
for x in xrange(len(prefix), prefix_len): for x in range(len(prefix), prefix_len):
# Prefix starts with community, then we complete with 0. # Prefix starts with community, then we complete with 0.
x = community[x] if x < community_len else '0' x = community[x] if x < community_len else '0'
q("UPDATE cert SET prefix = ? WHERE prefix = ?", q("UPDATE cert SET prefix = ? WHERE prefix = ?",
...@@ -490,17 +511,19 @@ class RegistryServer(object): ...@@ -490,17 +511,19 @@ class RegistryServer(object):
q("UPDATE cert SET cert = 'reserved' WHERE prefix = ?", (prefix,)) q("UPDATE cert SET cert = 'reserved' WHERE prefix = ?", (prefix,))
@rpc @rpc
def requestCertificate(self, token, req, location='', ip=''): def requestCertificate(self, token: str | None, req: bytes,
location: str='', ip: str=''):
logging.debug("Requesting certificate with token %s", token)
req = crypto.load_certificate_request(crypto.FILETYPE_PEM, req) req = crypto.load_certificate_request(crypto.FILETYPE_PEM, req)
with self.lock: with self.lock:
with self.db: with self.db:
if token: if token:
if not self.config.prefix_length: if not self.config.prefix_length:
raise HTTPError(httplib.FORBIDDEN) raise HTTPError(http.client.FORBIDDEN)
try: try:
token, email, prefix_len, _ = self.db.execute( token, email, prefix_len, _ = next(self.db.execute(
"SELECT * FROM token WHERE token = ?", "SELECT * FROM token WHERE token = ?",
(token,)).next() (token,)))
except StopIteration: except StopIteration:
return return
self.db.execute("DELETE FROM token WHERE token = ?", self.db.execute("DELETE FROM token WHERE token = ?",
...@@ -508,7 +531,7 @@ class RegistryServer(object): ...@@ -508,7 +531,7 @@ class RegistryServer(object):
else: else:
prefix_len = self.config.anonymous_prefix_length prefix_len = self.config.anonymous_prefix_length
if not prefix_len: if not prefix_len:
raise HTTPError(httplib.FORBIDDEN) raise HTTPError(http.client.FORBIDDEN)
email = None email = None
country, continent = '*', '*' country, continent = '*', '*'
if self.geoip_db: if self.geoip_db:
...@@ -563,7 +586,7 @@ class RegistryServer(object): ...@@ -563,7 +586,7 @@ class RegistryServer(object):
return cert return cert
@rpc @rpc
def renewCertificate(self, cn): def renewCertificate(self, cn: str) -> bytes:
with self.lock: with self.lock:
with self.db as db: with self.db as db:
pem = self.getCert(cn) pem = self.getCert(cn)
...@@ -579,26 +602,28 @@ class RegistryServer(object): ...@@ -579,26 +602,28 @@ class RegistryServer(object):
cert.get_subject(), cert.get_pubkey(), not_after) cert.get_subject(), cert.get_pubkey(), not_after)
@rpc @rpc
def getCa(self): def getCa(self) -> bytes:
return crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert.ca) return crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert.ca)
@rpc @rpc
def getDh(self, cn): def getDh(self, cn: str) -> bytes:
with open(self.config.dh) as f: with open(self.config.dh, "rb") as f:
return f.read() return f.read()
@rpc @rpc
def getNetworkConfig(self, cn): def getNetworkConfig(self, cn: str) -> bytes:
with self.lock: with self.lock:
cert = self.getCert(cn) cert = self.getCert(cn)
config = self.network_config.copy() config = self.network_config.copy()
hmac = [self.getConfig(k, None) for k in BABEL_HMAC] hmac = [self.getConfig(k, None) for k in BABEL_HMAC]
for i, v in enumerate(v for v in hmac if v is not None): for i, v in enumerate(v for v in hmac if v is not None):
config[('babel_hmac_sign', 'babel_hmac_accept')[i]] = \ config[('babel_hmac_sign', 'babel_hmac_accept')[i]] = \
v and x509.encrypt(cert, v).encode('base64') v and base64.b64encode(x509.encrypt(cert, v)).decode()
return zlib.compress(json.dumps(config)) return zlib.compress(json.dumps(config).encode("utf-8"))
def _queryAddress(self, peer): def _queryAddress(self, peer: str) -> str:
logging.info("Querying address for %s/%s %r",
int(peer, 2), len(peer), peer)
self.sendto(peer, 1) self.sendto(peer, 1)
s = self.sock, s = self.sock,
timeout = 3 timeout = 3
...@@ -606,6 +631,7 @@ class RegistryServer(object): ...@@ -606,6 +631,7 @@ class RegistryServer(object):
# Loop because there may be answers from previous requests. # Loop because there may be answers from previous requests.
while select.select(s, (), (), timeout)[0]: while select.select(s, (), (), timeout)[0]:
prefix, msg = self.recv(1) prefix, msg = self.recv(1)
logging.info("* received: %r - %r", prefix, msg)
if prefix == peer: if prefix == peer:
return msg return msg
timeout = max(0, end - time.time()) timeout = max(0, end - time.time())
...@@ -613,23 +639,25 @@ class RegistryServer(object): ...@@ -613,23 +639,25 @@ class RegistryServer(object):
int(peer, 2), len(peer)) int(peer, 2), len(peer))
@rpc @rpc
def getCountry(self, cn, address): def getCountry(self, cn: str, address: str) -> str | None:
country = self._geoiplookup(address)[0] country = self._geoiplookup(address)[0]
return None if country == '*' else country return None if country == '*' else country
@rpc @rpc
def getBootstrapPeer(self, cn): def getBootstrapPeer(self, cn: str) -> bytes | None:
logging.info("Answering bootstrap peer for %s", cn)
with self.peers_lock: with self.peers_lock:
age, peers = self.peers age, peers = self.peers
if age < time.time() or not peers: if age < time.time() or not peers:
self.request_dump() self.request_dump()
peers = [prefix peers = [prefix
for neigh_routes in self.ctl.neighbours.itervalues() for neigh_routes in self.ctl.neighbours.values()
for prefix in neigh_routes[1] for prefix in neigh_routes[1]
if prefix] if prefix]
peers.append(self.prefix) peers.append(self.prefix)
random.shuffle(peers) random.shuffle(peers)
self.peers = time.time() + 60, peers self.peers = time.time() + 60, peers
logging.debug("peers: %r", peers)
peer = peers.pop() peer = peers.pop()
if peer == cn: if peer == cn:
# Very unlikely (e.g. peer restarted with empty cache), # Very unlikely (e.g. peer restarted with empty cache),
...@@ -639,6 +667,7 @@ class RegistryServer(object): ...@@ -639,6 +667,7 @@ class RegistryServer(object):
with self.lock: with self.lock:
msg = self._queryAddress(peer) msg = self._queryAddress(peer)
if msg is None: if msg is None:
logging.info("No address for %s, returning None", peer)
return return
# Remove country for old nodes # Remove country for old nodes
if self.getPeerProtocol(cn) < 7: if self.getPeerProtocol(cn) < 7:
...@@ -647,10 +676,10 @@ class RegistryServer(object): ...@@ -647,10 +676,10 @@ class RegistryServer(object):
cert = self.getCert(cn) cert = self.getCert(cn)
msg = "%s %s" % (peer, msg) msg = "%s %s" % (peer, msg)
logging.info("Sending bootstrap peer: %s", msg) logging.info("Sending bootstrap peer: %s", msg)
return x509.encrypt(cert, msg) return x509.encrypt(cert, msg.encode())
@rpc_private @rpc_private
def revoke(self, cn_or_serial): def revoke(self, cn_or_serial: int | str):
with self.lock, self.db: with self.lock, self.db:
q = self.db.execute q = self.db.execute
try: try:
...@@ -671,12 +700,12 @@ class RegistryServer(object): ...@@ -671,12 +700,12 @@ class RegistryServer(object):
q("INSERT INTO crl VALUES (?,?)", (serial, not_after)) q("INSERT INTO crl VALUES (?,?)", (serial, not_after))
self.updateNetworkConfig() self.updateNetworkConfig()
def newHMAC(self, i, key=None): def newHMAC(self, i: int, key: bytes=None):
if key is None: if key is None:
key = buffer(os.urandom(16)) key = os.urandom(16)
self.setConfig(BABEL_HMAC[i], key) self.setConfig(BABEL_HMAC[i], key)
def delHMAC(self, i): def delHMAC(self, i: int):
self.db.execute("DELETE FROM config WHERE name=?", (BABEL_HMAC[i],)) self.db.execute("DELETE FROM config WHERE name=?", (BABEL_HMAC[i],))
@rpc_private @rpc_private
...@@ -696,25 +725,26 @@ class RegistryServer(object): ...@@ -696,25 +725,26 @@ class RegistryServer(object):
else: else:
# Initialization of HMAC on the network # Initialization of HMAC on the network
self.newHMAC(1) self.newHMAC(1)
self.newHMAC(2, '') self.newHMAC(2, b'')
self.increaseVersion() self.increaseVersion()
self.setConfig('version', buffer(self.version)) self.setConfig('version', self.version)
self.network_config['version'] = self.version.encode('base64') self.network_config['version'] = base64.b64encode(self.version)
self.sendto(self.prefix, 0) self.sendto(self.prefix, 0)
@rpc_private @rpc_private
def getNodePrefix(self, email): def getNodePrefix(self, email: str) -> str | None:
with self.lock, self.db: with self.lock, self.db:
try: try:
cert, = self.db.execute("SELECT cert FROM cert WHERE email = ?", cert, = next(
(email,)).next() self.db.execute("SELECT cert FROM cert WHERE email = ?",
(email,)))
except StopIteration: except StopIteration:
return return
certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert) certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
return x509.subnetFromCert(certificate) return x509.subnetFromCert(certificate)
@rpc_private @rpc_private
def getIPv6Address(self, email): def getIPv6Address(self, email: str) -> str:
cn = self.getNodePrefix(email) cn = self.getNodePrefix(email)
if cn: if cn:
return utils.ipFromBin( return utils.ipFromBin(
...@@ -722,13 +752,13 @@ class RegistryServer(object): ...@@ -722,13 +752,13 @@ class RegistryServer(object):
+ utils.binFromSubnet(cn)) + utils.binFromSubnet(cn))
@rpc_private @rpc_private
def getIPv4Information(self, email): def getIPv4Information(self, email: str) -> str | None:
peer = self.getNodePrefix(email) peer = self.getNodePrefix(email)
if peer: if peer:
peer = utils.binFromSubnet(peer) peer = utils.binFromSubnet(peer)
with self.peers_lock: with self.peers_lock:
self.request_dump() self.request_dump()
for neigh_routes in self.ctl.neighbours.itervalues(): for neigh_routes in self.ctl.neighbours.values():
for prefix in neigh_routes[1]: for prefix in neigh_routes[1]:
if prefix == peer: if prefix == peer:
break break
...@@ -736,16 +766,16 @@ class RegistryServer(object): ...@@ -736,16 +766,16 @@ class RegistryServer(object):
return return
logging.info("%s %s", email, peer) logging.info("%s %s", email, peer)
with self.lock: with self.lock:
msg = self._queryAddress(peer) msg = self._queryAddress(peer).decode()
if msg: if msg:
return msg.split(',')[0] return msg.split(',')[0]
@rpc_private @rpc_private
def versions(self): def versions(self) -> str:
with self.peers_lock: with self.peers_lock:
self.request_dump() self.request_dump()
peers = {prefix peers = {prefix
for neigh_routes in self.ctl.neighbours.itervalues() for neigh_routes in self.ctl.neighbours.values()
for prefix in neigh_routes[1] for prefix in neigh_routes[1]
if prefix} if prefix}
peers.add(self.prefix) peers.add(self.prefix)
...@@ -767,7 +797,8 @@ class RegistryServer(object): ...@@ -767,7 +797,8 @@ class RegistryServer(object):
return json.dumps(peer_dict) return json.dumps(peer_dict)
@rpc_private @rpc_private
def topology(self): def topology(self) -> str:
logging.debug("Computing topology")
p = lambda p: '%s/%s' % (int(p, 2), len(p)) p = lambda p: '%s/%s' % (int(p, 2), len(p))
peers = deque((p(self.prefix),)) peers = deque((p(self.prefix),))
graph = defaultdict(set) graph = defaultdict(set)
...@@ -777,7 +808,8 @@ class RegistryServer(object): ...@@ -777,7 +808,8 @@ class RegistryServer(object):
r, w, _ = select.select(s, s if peers else (), (), 3) r, w, _ = select.select(s, s if peers else (), (), 3)
if r: if r:
prefix, x = self.recv(5) prefix, x = self.recv(5)
if prefix and x: logging.debug("Received %s %s", prefix, x)
if prefix:
prefix = p(prefix) prefix = p(prefix)
x = x.split() x = x.split()
try: try:
...@@ -791,35 +823,46 @@ class RegistryServer(object): ...@@ -791,35 +823,46 @@ class RegistryServer(object):
graph[x].add(prefix) graph[x].add(prefix)
graph[''].add(prefix) graph[''].add(prefix)
if w: if w:
self.sendto(utils.binFromSubnet(peers.popleft()), 5) first = peers.popleft()
logging.debug("Sending %s", first)
self.sendto(utils.binFromSubnet(first), 5)
elif not r: elif not r:
logging.debug("No more sockets, stopping")
break break
return json.dumps({k: list(v) for k, v in graph.iteritems()}) return json.dumps({k: list(v) for k, v in graph.items()})
class RegistryClient:
"""
Client for the re6st registry.
class RegistryClient(object): Method calls are forwarded to the registry server.
String results are always returned as bytes.
"""
_hmac = None _hmac = None
user_agent = "re6stnet/%s, %s" % (version.version, platform.platform()) user_agent = "re6stnet/%s, %s" % (version.version, platform.platform())
def __init__(self, url, cert=None, auto_close=True): def __init__(self, url: str, cert: x509.Cert=None, auto_close=True):
self.cert = cert self.cert = cert
self.auto_close = auto_close self.auto_close = auto_close
scheme, host = splittype(url) url_parsed = urlparse(url)
host, path = splithost(host) scheme = url_parsed.scheme
self._conn = dict(http=httplib.HTTPConnection, host = url_parsed.netloc
https=httplib.HTTPSConnection, path = url_parsed.path
self._conn = dict(http=http.client.HTTPConnection,
https=http.client.HTTPSConnection,
)[scheme](unquote(host), timeout=60) )[scheme](unquote(host), timeout=60)
self._path = path.rstrip('/') self._path = path.rstrip('/')
def __getattr__(self, name): def __getattr__(self, name: str):
getcallargs = getattr(RegistryServer, name).getcallargs getcallargs = getattr(RegistryServer, name).getcallargs
def rpc(*args, **kw): def rpc(*args, **kw) -> bytes:
kw = getcallargs(*args, **kw) kw = getcallargs(*args, **kw)
query = '/' + name query = '/' + name
if kw: if kw:
if any(type(v) is not str for v in kw.itervalues()): if any(not isinstance(v, (str, bytes)) for v in kw.values()):
raise TypeError raise TypeError(kw)
query += '?' + urlencode(kw) query += '?' + urlencode(kw)
url = self._path + query url = self._path + query
client_prefix = kw.get('cn') client_prefix = kw.get('cn')
...@@ -834,7 +877,8 @@ class RegistryClient(object): ...@@ -834,7 +877,8 @@ class RegistryClient(object):
n = len(h) // 2 n = len(h) // 2
self.cert.verify(h[n:], h[:n]) self.cert.verify(h[n:], h[:n])
key = self.cert.decrypt(h[:n]) key = self.cert.decrypt(h[:n])
h = hmac.HMAC(key, query, hashlib.sha1).digest() h = (hmac.HMAC(key, query.encode(), hashlib.sha1)
.digest())
key = hashlib.sha1(key).digest() key = hashlib.sha1(key).digest()
self._hmac = hashlib.sha1(key).digest() self._hmac = hashlib.sha1(key).digest()
else: else:
...@@ -846,14 +890,15 @@ class RegistryClient(object): ...@@ -846,14 +890,15 @@ class RegistryClient(object):
self._conn.endheaders() self._conn.endheaders()
response = self._conn.getresponse() response = self._conn.getresponse()
body = response.read() body = response.read()
if response.status in (httplib.OK, httplib.NO_CONTENT): if response.status in (http.client.OK,
http.client.NO_CONTENT):
if (not client_prefix or if (not client_prefix or
hmac.HMAC(key, body, hashlib.sha1).digest() == hmac.HMAC(key, body, hashlib.sha1).digest() ==
base64.b64decode(response.msg[HMAC_HEADER])): base64.b64decode(response.msg[HMAC_HEADER])):
if self.auto_close and name != 'hello': if self.auto_close and name != 'hello':
self._conn.close() self._conn.close()
return body return body
elif response.status == httplib.FORBIDDEN: elif response.status == http.client.FORBIDDEN:
# XXX: We should improve error handling, while making # XXX: We should improve error handling, while making
# sure re6st nodes don't crash on temporary errors. # sure re6st nodes don't crash on temporary errors.
# This is currently good enough for re6st-conf, to # This is currently good enough for re6st-conf, to
...@@ -864,7 +909,7 @@ class RegistryClient(object): ...@@ -864,7 +909,7 @@ class RegistryClient(object):
except HTTPError: except HTTPError:
raise raise
except Exception: except Exception:
logging.info(url, exc_info=1) logging.info(url, exc_info=True)
else: else:
logging.info('%s\nUnexpected response %s %s', logging.info('%s\nUnexpected response %s %s',
url, response.status, response.reason) url, response.status, response.reason)
......
from pathlib2 import Path from pathlib import Path
DEMO_PATH = Path(__file__).resolve().parent.parent.parent / "demo" DEMO_PATH = Path(__file__).resolve().parent.parent.parent / "demo"
...@@ -15,7 +15,7 @@ from re6st.tests import DEMO_PATH ...@@ -15,7 +15,7 @@ from re6st.tests import DEMO_PATH
DH_FILE = DEMO_PATH / "dh2048.pem" DH_FILE = DEMO_PATH / "dh2048.pem"
class DummyNode(object): class DummyNode:
"""fake node to reuse Re6stRegistry """fake node to reuse Re6stRegistry
error: node.Popen has destory method which not in subprocess.Popen error: node.Popen has destory method which not in subprocess.Popen
...@@ -29,19 +29,20 @@ class TestRegistryClientInteract(unittest.TestCase): ...@@ -29,19 +29,20 @@ class TestRegistryClientInteract(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
re6st_wrap.initial()
# if running in net ns, set lo up # if running in net ns, set lo up
subprocess.check_call(("ip", "link", "set", "lo", "up")) subprocess.check_call(("ip", "link", "set", "lo", "up"))
def setUp(self): def setUp(self):
re6st_wrap.initial()
self.port = 18080 self.port = 18080
self.url = "http://localhost:{}/".format(self.port) self.url = "http://localhost:{}/".format(self.port)
# not important, used in network_config check # not important, used in network_config check
self.max_clients = 10 self.max_clients = 10
def tearDown(self): def tearDown(self):
self.server.proc.terminate() with self.server.proc as p:
p.terminate()
def test_1_main(self): def test_1_main(self):
""" a client interact a server, no re6stnet node test basic function""" """ a client interact a server, no re6stnet node test basic function"""
...@@ -60,7 +61,7 @@ class TestRegistryClientInteract(unittest.TestCase): ...@@ -60,7 +61,7 @@ class TestRegistryClientInteract(unittest.TestCase):
# read token from db # read token from db
db = sqlite3.connect(str(self.server.db), isolation_level=None) db = sqlite3.connect(str(self.server.db), isolation_level=None)
token = None token = None
for _ in xrange(100): for _ in range(100):
time.sleep(.1) time.sleep(.1)
token = db.execute("SELECT token FROM token WHERE email=?", token = db.execute("SELECT token FROM token WHERE email=?",
(email,)).fetchone() (email,)).fetchone()
...@@ -70,7 +71,7 @@ class TestRegistryClientInteract(unittest.TestCase): ...@@ -70,7 +71,7 @@ class TestRegistryClientInteract(unittest.TestCase):
self.fail("Request token failed, no token in database") self.fail("Request token failed, no token in database")
# token: tuple[unicode,] # token: tuple[unicode,]
token = str(token[0]) token = str(token[0])
self.assertEqual(client.isToken(token), "1") self.assertEqual(client.isToken(token), b"1")
# request ca # request ca
ca = client.getCa() ca = client.getCa()
...@@ -78,7 +79,7 @@ class TestRegistryClientInteract(unittest.TestCase): ...@@ -78,7 +79,7 @@ class TestRegistryClientInteract(unittest.TestCase):
# request a cert and get cn # request a cert and get cn
key, csr = tools.generate_csr() key, csr = tools.generate_csr()
cert = client.requestCertificate(token, csr) cert = client.requestCertificate(token, csr)
self.assertEqual(client.isToken(token), '', "token should be deleted") self.assertEqual(client.isToken(token), b'', "token should be deleted")
# creat x509.cert object # creat x509.cert object
def write_to_temp(text): def write_to_temp(text):
...@@ -97,7 +98,7 @@ class TestRegistryClientInteract(unittest.TestCase): ...@@ -97,7 +98,7 @@ class TestRegistryClientInteract(unittest.TestCase):
# verfiy cn and prefix # verfiy cn and prefix
prefix = client.cert.prefix prefix = client.cert.prefix
cn = client.getNodePrefix(email) cn = client.getNodePrefix(email).decode()
self.assertEqual(tools.prefix2cn(prefix), cn) self.assertEqual(tools.prefix2cn(prefix), cn)
# simulate the process in cache # simulate the process in cache
...@@ -108,7 +109,7 @@ class TestRegistryClientInteract(unittest.TestCase): ...@@ -108,7 +109,7 @@ class TestRegistryClientInteract(unittest.TestCase):
# no re6stnet, empty result # no re6stnet, empty result
bootpeer = client.getBootstrapPeer(prefix) bootpeer = client.getBootstrapPeer(prefix)
self.assertEqual(bootpeer, "") self.assertEqual(bootpeer, b"")
# server should not die # server should not die
self.assertIsNone(self.server.proc.poll()) self.assertIsNone(self.server.proc.poll())
......
...@@ -3,14 +3,9 @@ import logging ...@@ -3,14 +3,9 @@ import logging
import nemu import nemu
import time import time
import weakref import weakref
from subprocess import PIPE from subprocess import DEVNULL, PIPE
from pathlib2 import Path from pathlib import Path
from re6st.tests import DEMO_PATH
fix_file = DEMO_PATH / "fixnemu.py"
# execfile(str(fix_file)) Removed in python3
exec(open(str(fix_file)).read())
IPTABLES = 'iptables-nft' IPTABLES = 'iptables-nft'
class ConnectableError(Exception): class ConnectableError(Exception):
...@@ -50,7 +45,7 @@ class Node(nemu.Node): ...@@ -50,7 +45,7 @@ class Node(nemu.Node):
if_s.add_v4_address(ip, prefix_len=prefix_len) if_s.add_v4_address(ip, prefix_len=prefix_len)
return if_s return if_s
class NetManager(object): class NetManager:
"""contain all the nemu object created, so they can live more time""" """contain all the nemu object created, so they can live more time"""
def __init__(self): def __init__(self):
self.object = [] self.object = []
...@@ -60,9 +55,10 @@ class NetManager(object): ...@@ -60,9 +55,10 @@ class NetManager(object):
Raise: Raise:
AssertionError AssertionError
""" """
for reg, nodes in self.registries.iteritems(): for reg, nodes in self.registries.items():
for node in nodes: for node in nodes:
app0 = node.Popen(["ping", "-c", "1", reg.ip], stdout=PIPE) with node.Popen(["ping", "-c", "1", reg.ip],
stdout=DEVNULL) as app0:
ret = app0.wait() ret = app0.wait()
if ret: if ret:
raise ConnectableError( raise ConnectableError(
......
...@@ -6,13 +6,15 @@ import ipaddress ...@@ -6,13 +6,15 @@ import ipaddress
import json import json
import logging import logging
import re import re
import shlex
import shutil import shutil
import sqlite3 import sqlite3
import sys
import tempfile import tempfile
import time import time
import weakref import weakref
from subprocess import PIPE from subprocess import PIPE
from pathlib2 import Path from pathlib import Path
from re6st.tests import tools from re6st.tests import tools
from re6st.tests import DEMO_PATH from re6st.tests import DEMO_PATH
...@@ -20,13 +22,15 @@ from re6st.tests import DEMO_PATH ...@@ -20,13 +22,15 @@ from re6st.tests import DEMO_PATH
WORK_DIR = Path(__file__).parent / "temp_net_test" WORK_DIR = Path(__file__).parent / "temp_net_test"
DH_FILE = DEMO_PATH / "dh2048.pem" DH_FILE = DEMO_PATH / "dh2048.pem"
RE6STNET = "python -m re6st.cli.node" PYTHON = shlex.quote(sys.executable)
RE6ST_REGISTRY = "python -m re6st.cli.registry" RE6STNET = PYTHON + " -m re6st.cli.node"
RE6ST_CONF = "python -m re6st.cli.conf" RE6ST_REGISTRY = PYTHON + " -m re6st.cli.registry"
RE6ST_CONF = PYTHON + " -m re6st.cli.conf"
def initial(): def initial():
"""create the workplace""" """create the workplace"""
if not WORK_DIR.exists(): if WORK_DIR.exists():
shutil.rmtree(str(WORK_DIR))
WORK_DIR.mkdir() WORK_DIR.mkdir()
def ip_to_serial(ip6): def ip_to_serial(ip6):
...@@ -36,7 +40,7 @@ def ip_to_serial(ip6): ...@@ -36,7 +40,7 @@ def ip_to_serial(ip6):
return int(ip6, 16) return int(ip6, 16)
class Re6stRegistry(object): class Re6stRegistry:
"""class run a re6st-registry service on a namespace""" """class run a re6st-registry service on a namespace"""
registry_seq = 0 registry_seq = 0
...@@ -72,7 +76,7 @@ class Re6stRegistry(object): ...@@ -72,7 +76,7 @@ class Re6stRegistry(object):
self.run() self.run()
# wait the servcice started # wait the servcice started
p = self.node.Popen(['python', '-c', """if 1: p = self.node.Popen([sys.executable, '-c', """if 1:
import socket, time import socket, time
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
while True: while True:
...@@ -115,7 +119,7 @@ class Re6stRegistry(object): ...@@ -115,7 +119,7 @@ class Re6stRegistry(object):
'--client-count', (self.client_number+1)//2, '--port', self.port] '--client-count', (self.client_number+1)//2, '--port', self.port]
#PY3: convert PosixPath to str, can be remove in Python 3 #PY3: convert PosixPath to str, can be remove in Python 3
cmd = map(str, cmd) cmd = list(map(str, cmd))
cmd[:0] = RE6ST_REGISTRY.split() cmd[:0] = RE6ST_REGISTRY.split()
...@@ -131,15 +135,19 @@ class Re6stRegistry(object): ...@@ -131,15 +135,19 @@ class Re6stRegistry(object):
if e.errno != errno.ENOENT: if e.errno != errno.ENOENT:
raise raise
def __del__(self): def terminate(self):
try: try:
logging.debug("teminate process %s", self.proc.pid) logging.debug("teminate process %s", self.proc.pid)
self.proc.destroy() with self.proc as p:
p.destroy()
except: except:
pass pass
def __del__(self):
self.terminate()
class Re6stNode(object):
class Re6stNode:
"""class run a re6stnet service on a namespace""" """class run a re6stnet service on a namespace"""
node_seq = 0 node_seq = 0
...@@ -210,7 +218,7 @@ class Re6stNode(object): ...@@ -210,7 +218,7 @@ class Re6stNode(object):
# read token # read token
db = sqlite3.connect(str(self.registry.db), isolation_level=None) db = sqlite3.connect(str(self.registry.db), isolation_level=None)
token = None token = None
for _ in xrange(100): for _ in range(100):
time.sleep(.1) time.sleep(.1)
token = db.execute("SELECT token FROM token WHERE email=?", token = db.execute("SELECT token FROM token WHERE email=?",
(self.email,)).fetchone() (self.email,)).fetchone()
...@@ -223,7 +231,8 @@ class Re6stNode(object): ...@@ -223,7 +231,8 @@ class Re6stNode(object):
out, _ = p.communicate(str(token[0])) out, _ = p.communicate(str(token[0]))
# logging.debug("re6st-conf output: {}".format(out)) # logging.debug("re6st-conf output: {}".format(out))
# find the ipv6 subnet of node # find the ipv6 subnet of node
self.ip6 = re.search('(?<=subnet: )[0-9:a-z]+', out).group(0) self.ip6 = re.search('(?<=subnet: )[0-9:a-z]+',
out.decode("utf-8")).group(0)
data = {'ip6': self.ip6, 'hash': self.registry.ident} data = {'ip6': self.ip6, 'hash': self.registry.ident}
with open(str(self.data_file), 'w') as f: with open(str(self.data_file), 'w') as f:
json.dump(data, f) json.dump(data, f)
...@@ -236,7 +245,7 @@ class Re6stNode(object): ...@@ -236,7 +245,7 @@ class Re6stNode(object):
'--key', self.key, '-v4', '--registry', self.registry.url, '--key', self.key, '-v4', '--registry', self.registry.url,
'--console', self.console] '--console', self.console]
#PY3: same as for Re6stRegistry.run #PY3: same as for Re6stRegistry.run
cmd = map(str, cmd) cmd = list(map(str, cmd))
cmd[:0] = RE6STNET.split() cmd[:0] = RE6STNET.split()
cmd += args cmd += args
...@@ -260,7 +269,8 @@ class Re6stNode(object): ...@@ -260,7 +269,8 @@ class Re6stNode(object):
def stop(self): def stop(self):
"""stop running re6stnet process""" """stop running re6stnet process"""
logging.debug("%s teminate process %s", self.name, self.proc.pid) logging.debug("%s teminate process %s", self.name, self.proc.pid)
self.proc.destroy() with self.proc as p:
p.destroy()
def __del__(self): def __del__(self):
"""teminate process and rm temp dir""" """teminate process and rm temp dir"""
......
"""contain ping-test for re6set net""" """contain ping-test for re6set net"""
import os import os
import sys
import unittest import unittest
import time import time
import psutil import psutil
import logging import logging
import random import random
from pathlib2 import Path from pathlib import Path
import network_build from . import network_build, re6st_wrap
import re6st_wrap
PING_PATH = str(Path(__file__).parent.resolve() / "ping.py") PING_PATH = str(Path(__file__).parent.resolve() / "ping.py")
def deploy_re6st(nm, recreate=False):
net = nm.registries
nodes = []
registries = []
re6st_wrap.Re6stRegistry.registry_seq = 0
re6st_wrap.Re6stNode.node_seq = 0
for registry in net:
reg = re6st_wrap.Re6stRegistry(registry, "2001:db8:42::", len(net[registry]),
recreate=recreate)
reg_node = re6st_wrap.Re6stNode(registry, reg, name=reg.name)
registries.append(reg)
reg_node.run("--gateway", "--disable-proto", "none", "--ip", registry.ip)
nodes.append(reg_node)
for m in net[registry]:
node = re6st_wrap.Re6stNode(m, reg)
node.run("-i" + m.iface.name)
nodes.append(node)
return nodes, registries
def wait_stable(nodes, timeout=240): def wait_stable(nodes, timeout=240):
"""try use ping6 from each node to the other until ping success to all the """try use ping6 from each node to the other until ping success to all the
other nodes other nodes
...@@ -47,12 +28,13 @@ def wait_stable(nodes, timeout=240): ...@@ -47,12 +28,13 @@ def wait_stable(nodes, timeout=240):
for node in nodes: for node in nodes:
sub_ips = set(ips) - {node.ip6} sub_ips = set(ips) - {node.ip6}
node.ping_proc = node.node.Popen( node.ping_proc = node.node.Popen(
["python", PING_PATH, '--retry', '-a'] + list(sub_ips)) [sys.executable, PING_PATH, '--retry', '-a'] + list(sub_ips),
env=os.environ)
# check all the node network can ping each other, in order reverse # check all the node network can ping each other, in order reverse
unfinished = list(nodes) unfinished = list(nodes)
while unfinished: while unfinished:
for i in xrange(len(unfinished)-1, -1, -1): for i in range(len(unfinished)-1, -1, -1):
node = unfinished[i] node = unfinished[i]
if node.ping_proc.poll() is not None: if node.ping_proc.poll() is not None:
logging.debug("%s 's network is stable", node.name) logging.debug("%s 's network is stable", node.name)
...@@ -75,8 +57,42 @@ class TestNet(unittest.TestCase): ...@@ -75,8 +57,42 @@ class TestNet(unittest.TestCase):
def setUpClass(cls): def setUpClass(cls):
"""create work dir""" """create work dir"""
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
def setUp(self):
re6st_wrap.initial() re6st_wrap.initial()
def deploy_re6st(self, nm, recreate=False):
net = nm.registries
nodes = []
registries = []
re6st_wrap.Re6stRegistry.registry_seq = 0
re6st_wrap.Re6stNode.node_seq = 0
for registry in net:
reg = re6st_wrap.Re6stRegistry(registry, "2001:db8:42::",
len(net[registry]),
recreate=recreate)
reg_node = re6st_wrap.Re6stNode(registry, reg, name=reg.name)
registries.append(reg)
reg_node.run("--gateway", "--disable-proto", "none",
"--ip", registry.ip)
nodes.append(reg_node)
for m in net[registry]:
node = re6st_wrap.Re6stNode(m, reg)
node.run("-i" + m.iface.name)
nodes.append(node)
def clean_re6st():
for node in nodes:
node.node.destroy()
node.stop()
for reg in registries:
reg.terminate()
self.addCleanup(clean_re6st)
return nodes, registries
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
"""watch any process leaked after tests""" """watch any process leaked after tests"""
...@@ -94,7 +110,7 @@ class TestNet(unittest.TestCase): ...@@ -94,7 +110,7 @@ class TestNet(unittest.TestCase):
"""create a network in a net segment, test the connectivity by ping """create a network in a net segment, test the connectivity by ping
""" """
nm = network_build.net_route() nm = network_build.net_route()
nodes, _ = deploy_re6st(nm) nodes, _ = self.deploy_re6st(nm)
wait_stable(nodes, 40) wait_stable(nodes, 40)
time.sleep(10) time.sleep(10)
...@@ -107,7 +123,7 @@ class TestNet(unittest.TestCase): ...@@ -107,7 +123,7 @@ class TestNet(unittest.TestCase):
then test if network recover, this test seems always failed then test if network recover, this test seems always failed
""" """
nm = network_build.net_demo() nm = network_build.net_demo()
nodes, _ = deploy_re6st(nm) nodes, _ = self.deploy_re6st(nm)
wait_stable(nodes, 100) wait_stable(nodes, 100)
...@@ -126,7 +142,7 @@ class TestNet(unittest.TestCase): ...@@ -126,7 +142,7 @@ class TestNet(unittest.TestCase):
then test if network recover, then test if network recover,
""" """
nm = network_build.net_route() nm = network_build.net_route()
nodes, _ = deploy_re6st(nm) nodes, _ = self.deploy_re6st(nm)
wait_stable(nodes, 40) wait_stable(nodes, 40)
......
#!/usr/bin/python2 #!/usr/bin/env python3
""" unit test for re6st-conf """ unit test for re6st-conf
""" """
...@@ -6,7 +6,7 @@ import os ...@@ -6,7 +6,7 @@ import os
import sys import sys
import unittest import unittest
from shutil import rmtree from shutil import rmtree
from StringIO import StringIO from io import StringIO
from mock import patch from mock import patch
from OpenSSL import crypto from OpenSSL import crypto
...@@ -36,7 +36,7 @@ class TestConf(unittest.TestCase): ...@@ -36,7 +36,7 @@ class TestConf(unittest.TestCase):
# mocked server cert and pkey # mocked server cert and pkey
cls.pkey, cls.cert = create_ca_file(os.devnull, os.devnull) cls.pkey, cls.cert = create_ca_file(os.devnull, os.devnull)
cls.fingerprint = "".join( cls.cert.digest("sha1").split(":")) cls.fingerprint = "".join(cls.cert.digest("sha1").decode().split(":"))
# client.getCa should return a string form cert # client.getCa should return a string form cert
cls.cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cls.cert) cls.cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cls.cert)
...@@ -72,7 +72,7 @@ class TestConf(unittest.TestCase): ...@@ -72,7 +72,7 @@ class TestConf(unittest.TestCase):
# go back to original dir # go back to original dir
os.chdir(self.origin_dir) os.chdir(self.origin_dir)
@patch("__builtin__.raw_input") @patch("builtins.input")
def test_basic(self, mock_raw_input): def test_basic(self, mock_raw_input):
""" go through all the step """ go through all the step
getCa, requestToken, requestCertificate getCa, requestToken, requestCertificate
......
...@@ -3,7 +3,7 @@ import os ...@@ -3,7 +3,7 @@ import os
import random import random
import string import string
import json import json
import httplib import http.client
import base64 import base64
import unittest import unittest
import hmac import hmac
...@@ -11,18 +11,21 @@ import hashlib ...@@ -11,18 +11,21 @@ import hashlib
import time import time
import tempfile import tempfile
from argparse import Namespace from argparse import Namespace
from sqlite3 import Cursor
from OpenSSL import crypto from OpenSSL import crypto
from mock import Mock, patch from mock import Mock, patch
from pathlib2 import Path from pathlib import Path
from re6st import registry from re6st import registry, x509
from re6st.tests.tools import * from re6st.tests.tools import *
from re6st.tests import DEMO_PATH from re6st.tests import DEMO_PATH
# TODO test for request_dump, requestToken, getNetworkConfig, getBoostrapPeer # TODO test for request_dump, requestToken, getNetworkConfig, getBoostrapPeer
# getIPV4Information, versions # getIPV4Information, versions
def load_config(filename="registry.json"): def load_config(filename: str="registry.json") -> Namespace:
with open(filename) as f: with open(filename) as f:
config = json.load(f) config = json.load(f)
config["dh"] = DEMO_PATH / "dh2048.pem" config["dh"] = DEMO_PATH / "dh2048.pem"
...@@ -36,23 +39,25 @@ def load_config(filename="registry.json"): ...@@ -36,23 +39,25 @@ def load_config(filename="registry.json"):
return Namespace(**config) return Namespace(**config)
def get_cert(cur, prefix): def get_cert(cur: Cursor, prefix: str):
res = cur.execute( res = cur.execute(
"SELECT cert FROM cert WHERE prefix=?", (prefix,)).fetchone() "SELECT cert FROM cert WHERE prefix=?", (prefix,)).fetchone()
return res[0] return res[0]
def insert_cert(cur, ca, prefix, not_after=None, email=None): def insert_cert(cur: Cursor, ca: x509.Cert, prefix: str,
not_after=None, email=None):
key, csr = generate_csr() key, csr = generate_csr()
cert = generate_cert(ca.ca, ca.key, csr, prefix, insert_cert.serial, not_after) cert = generate_cert(ca.ca, ca.key, csr, prefix, insert_cert.serial, not_after)
cur.execute("INSERT INTO cert VALUES (?,?,?)", (prefix, email, cert)) cur.execute("INSERT INTO cert VALUES (?,?,?)", (prefix, email, cert))
insert_cert.serial += 1 insert_cert.serial += 1
return key, cert return key, cert
insert_cert.serial = 0 insert_cert.serial = 0
def delete_cert(cur, prefix): def delete_cert(cur: Cursor, prefix: str):
cur.execute("DELETE FROM cert WHERE prefix = ?", (prefix,)) cur.execute("DELETE FROM cert WHERE prefix = ?", (prefix,))
...@@ -68,6 +73,7 @@ class TestRegistryServer(unittest.TestCase): ...@@ -68,6 +73,7 @@ class TestRegistryServer(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
cls.server.close()
# remove database # remove database
for file in [cls.config.db, cls.config.ca, cls.config.key]: for file in [cls.config.db, cls.config.ca, cls.config.key]:
try: try:
...@@ -80,14 +86,23 @@ class TestRegistryServer(unittest.TestCase): ...@@ -80,14 +86,23 @@ class TestRegistryServer(unittest.TestCase):
+ "@mail.com" + "@mail.com"
def test_recv(self): def test_recv(self):
recv = self.server.sock.recv = Mock() side_effect = iter([
recv.side_effect = [
"0001001001001a_msg", "0001001001001a_msg",
"0001001001002\0001dqdq", "0001001001002\0001dqdq",
"0001001001001\000a_msg", "0001001001001\000a_msg",
"0001001001001\000\4a_msg", "0001001001001\000\4a_msg",
"0000000000000\0" # ERROR, IndexError: msg is null "0000000000000\0" # ERROR, IndexError: msg is null
] ])
class SocketProxy:
def __init__(self, wrappee):
self.wrappee = wrappee
self.recv = lambda _: next(side_effect)
def __getattr__(self, attr):
return getattr(self.wrappee, attr)
self.server.sock = SocketProxy(self.server.sock)
try: try:
res1 = self.server.recv(4) res1 = self.server.recv(4)
...@@ -115,7 +130,7 @@ class TestRegistryServer(unittest.TestCase): ...@@ -115,7 +130,7 @@ class TestRegistryServer(unittest.TestCase):
now = int(time.time()) - self.config.grace_period + 20 now = int(time.time()) - self.config.grace_period + 20
# makeup data # makeup data
insert_cert(cur, self.server.cert, prefix_old, 1) insert_cert(cur, self.server.cert, prefix_old, 1)
insert_cert(cur, self.server.cert, prefix, now -1) insert_cert(cur, self.server.cert, prefix, now - 1)
cur.execute("INSERT INTO token VALUES (?,?,?,?)", cur.execute("INSERT INTO token VALUES (?,?,?,?)",
(token_old, self.email, 4, 2)) (token_old, self.email, 4, 2))
cur.execute("INSERT INTO token VALUES (?,?,?,?)", cur.execute("INSERT INTO token VALUES (?,?,?,?)",
...@@ -143,16 +158,16 @@ class TestRegistryServer(unittest.TestCase): ...@@ -143,16 +158,16 @@ class TestRegistryServer(unittest.TestCase):
prefix = "0000000011111111" prefix = "0000000011111111"
method = "func" method = "func"
protocol = 7 protocol = 7
params = {"cn" : prefix, "a" : 1, "b" : 2} params = {"cn": prefix, "a": 1, "b": 2}
func.getcallargs.return_value = params func.getcallargs.return_value = params
del func._private del func._private
func.return_value = result = "this_is_a_result" func.return_value = result = b"this_is_a_result"
key = "this_is_a_key" key = b"this_is_a_key"
self.server.sessions[prefix] = [(key, protocol)] self.server.sessions[prefix] = [(key, protocol)]
request = Mock() request = Mock()
request.path = "/func?a=1&b=2&cn=0000000011111111" request.path = "/func?a=1&b=2&cn=0000000011111111"
request.headers = {registry.HMAC_HEADER: base64.b64encode( request.headers = {registry.HMAC_HEADER: base64.b64encode(
hmac.HMAC(key, request.path, hashlib.sha1).digest())} hmac.HMAC(key, request.path.encode(), hashlib.sha1).digest())}
self.server.handle_request(request, method, params) self.server.handle_request(request, method, params)
...@@ -162,11 +177,12 @@ class TestRegistryServer(unittest.TestCase): ...@@ -162,11 +177,12 @@ class TestRegistryServer(unittest.TestCase):
[(hashlib.sha1(key).digest(), protocol)]) [(hashlib.sha1(key).digest(), protocol)])
func.assert_called_once_with(**params) func.assert_called_once_with(**params)
# http response check # http response check
request.send_response.assert_called_once_with(httplib.OK) request.send_response.assert_called_once_with(http.client.OK)
request.send_header.assert_any_call("Content-Length", str(len(result))) request.send_header.assert_any_call("Content-Length", str(len(result)))
request.send_header.assert_any_call( request.send_header.assert_any_call(
registry.HMAC_HEADER, registry.HMAC_HEADER,
base64.b64encode(hmac.HMAC(key, result, hashlib.sha1).digest())) base64.b64encode(hmac.HMAC(key, result, hashlib.sha1).digest())
.decode("ascii"))
request.wfile.write.assert_called_once_with(result) request.wfile.write.assert_called_once_with(result)
# remove the create session \n # remove the create session \n
...@@ -176,7 +192,7 @@ class TestRegistryServer(unittest.TestCase): ...@@ -176,7 +192,7 @@ class TestRegistryServer(unittest.TestCase):
def test_handle_request_private(self, func): def test_handle_request_private(self, func):
"""case request with _private attr""" """case request with _private attr"""
method = "func" method = "func"
params = {"a" : 1, "b" : 2} params = {"a": 1, "b": 2}
func.getcallargs.return_value = params func.getcallargs.return_value = params
func.return_value = None func.return_value = None
request_good = Mock() request_good = Mock()
...@@ -189,8 +205,8 @@ class TestRegistryServer(unittest.TestCase): ...@@ -189,8 +205,8 @@ class TestRegistryServer(unittest.TestCase):
self.server.handle_request(request_bad, method, params) self.server.handle_request(request_bad, method, params)
func.assert_called_once_with(**params) func.assert_called_once_with(**params)
request_bad.send_error.assert_called_once_with(httplib.FORBIDDEN) request_bad.send_error.assert_called_once_with(http.client.FORBIDDEN)
request_good.send_response.assert_called_once_with(httplib.NO_CONTENT) request_good.send_response.assert_called_once_with(http.client.NO_CONTENT)
# will cause valueError, if a node send hello twice to a registry # will cause valueError, if a node send hello twice to a registry
def test_getPeerProtocol(self): def test_getPeerProtocol(self):
...@@ -213,7 +229,7 @@ class TestRegistryServer(unittest.TestCase): ...@@ -213,7 +229,7 @@ class TestRegistryServer(unittest.TestCase):
res = self.server.hello(prefix, protocol=protocol) res = self.server.hello(prefix, protocol=protocol)
# decrypt # decrypt
length = len(res)/2 length = len(res) // 2
key, sign = res[:length], res[length:] key, sign = res[:length], res[length:]
key = decrypt(pkey, key) key = decrypt(pkey, key)
self.assertEqual(self.server.sessions[prefix][-1][0], key, self.assertEqual(self.server.sessions[prefix][-1][0], key,
...@@ -282,7 +298,7 @@ class TestRegistryServer(unittest.TestCase): ...@@ -282,7 +298,7 @@ class TestRegistryServer(unittest.TestCase):
nb_less = 0 nb_less = 0
for cert in self.server.iterCert(): for cert in self.server.iterCert():
s = cert[0].get_subject().serialNumber s = cert[0].get_subject().serialNumber
if(s and int(s) <= serial): if s and int(s) <= serial:
nb_less += 1 nb_less += 1
self.assertEqual(nb_less, serial) self.assertEqual(nb_less, serial)
...@@ -378,7 +394,7 @@ class TestRegistryServer(unittest.TestCase): ...@@ -378,7 +394,7 @@ class TestRegistryServer(unittest.TestCase):
hmacs = get_hmac() hmacs = get_hmac()
key_1 = hmacs[1] key_1 = hmacs[1]
self.assertEqual(hmacs, [None, key_1, '']) self.assertEqual(hmacs, [None, key_1, b''])
# step 2 # step 2
self.server.updateHMAC() self.server.updateHMAC()
...@@ -397,12 +413,11 @@ class TestRegistryServer(unittest.TestCase): ...@@ -397,12 +413,11 @@ class TestRegistryServer(unittest.TestCase):
self.assertEqual(get_hmac(), [None, key_2, key_1]) self.assertEqual(get_hmac(), [None, key_2, key_1])
#setp 5 # step 5
self.server.updateHMAC() self.server.updateHMAC()
self.assertEqual(get_hmac(), [key_2, None, None]) self.assertEqual(get_hmac(), [key_2, None, None])
def test_getNodePrefix(self): def test_getNodePrefix(self):
# prefix in short format # prefix in short format
prefix = "0000000101" prefix = "0000000101"
...@@ -426,20 +441,38 @@ class TestRegistryServer(unittest.TestCase): ...@@ -426,20 +441,38 @@ class TestRegistryServer(unittest.TestCase):
('0000000000000001', '2 0/16 6/16') ('0000000000000001', '2 0/16 6/16')
] ]
recv.side_effect = recv_case recv.side_effect = recv_case
def side_effct(rlist, wlist, elist, timeout): def side_effct(rlist, wlist, elist, timeout):
# rlist is true until the len(recv_case)th call # rlist is true until the len(recv_case)th call
side_effct.i -= side_effct.i > 0 side_effct.i -= side_effct.i > 0
return [side_effct.i, wlist, None] return [side_effct.i, wlist, None]
side_effct.i = len(recv_case) + 1 side_effct.i = len(recv_case) + 1
select.side_effect = side_effct select.side_effect = side_effct
res = self.server.topology() res = self.server.topology()
expect_res = '{"36893488147419103232/80": ["0/16", "7/16"], ' \ class CustomDecoder(json.JSONDecoder):
'"": ["36893488147419103232/80", "3/16", "1/16", "0/16", "7/16"], ' \ def __init__(self, **kwargs):
'"4/16": ["0/16"], "3/16": ["0/16", "7/16"], "0/16": ["6/16", "7/16"], '\ super().__init__(**kwargs)
'"1/16": ["6/16", "0/16"], "7/16": ["6/16", "4/16"]}''' self.parse_array = self.JSONArray
self.assertEqual(res, expect_res) self.scan_once = json.scanner.py_make_scanner(self)
def JSONArray(self, *args, **kw):
values, end = json.decoder.JSONArray(*args, **kw)
return set(values), end
res = json.loads(res, cls=CustomDecoder)
self.assertEqual(res, {
"": {"36893488147419103232/80", "3/16", "1/16", "0/16", "7/16"},
"0/16": {"6/16", "7/16"},
"1/16": {"6/16", "0/16"},
"36893488147419103232/80": {"0/16", "7/16"},
"3/16": {"0/16", "7/16"},
"4/16": {"0/16"},
"7/16": {"6/16", "4/16"},
})
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -2,7 +2,7 @@ import sys ...@@ -2,7 +2,7 @@ import sys
import os import os
import unittest import unittest
import hmac import hmac
import httplib import http.client
import base64 import base64
import hashlib import hashlib
from mock import Mock, patch from mock import Mock, patch
...@@ -26,15 +26,15 @@ class TestRegistryClient(unittest.TestCase): ...@@ -26,15 +26,15 @@ class TestRegistryClient(unittest.TestCase):
self.assertEqual(client1._path, "/example") self.assertEqual(client1._path, "/example")
self.assertEqual(client1._conn.host, "localhost") self.assertEqual(client1._conn.host, "localhost")
self.assertIsInstance(client1._conn, httplib.HTTPSConnection) self.assertIsInstance(client1._conn, http.client.HTTPSConnection)
self.assertIsInstance(client2._conn, httplib.HTTPConnection) self.assertIsInstance(client2._conn, http.client.HTTPConnection)
def test_rpc_hello(self): def test_rpc_hello(self):
prefix = "0000000011111111" prefix = "0000000011111111"
protocol = "7" protocol = "7"
body = "a_hmac_key" body = "a_hmac_key"
query = "/hello?client_prefix=0000000011111111&protocol=7" query = "/hello?client_prefix=0000000011111111&protocol=7"
response = fakeResponse(body, httplib.OK) response = fakeResponse(body, http.client.OK)
self.client._conn.getresponse.return_value = response self.client._conn.getresponse.return_value = response
res = self.client.hello(prefix, protocol) res = self.client.hello(prefix, protocol)
...@@ -52,14 +52,15 @@ class TestRegistryClient(unittest.TestCase): ...@@ -52,14 +52,15 @@ class TestRegistryClient(unittest.TestCase):
self.client._hmac = None self.client._hmac = None
self.client.hello = Mock(return_value = "aaabbb") self.client.hello = Mock(return_value = "aaabbb")
self.client.cert = Mock() self.client.cert = Mock()
key = "this_is_a_key" key = b"this_is_a_key"
self.client.cert.decrypt.return_value = key self.client.cert.decrypt.return_value = key
h = hmac.HMAC(key, query, hashlib.sha1).digest() h = hmac.HMAC(key, query.encode(), hashlib.sha1).digest()
key = hashlib.sha1(key).digest() key = hashlib.sha1(key).digest()
# response part # response part
body = None body = b'this is a body'
response = fakeResponse(body, httplib.NO_CONTENT) response = fakeResponse(body, http.client.NO_CONTENT)
response.msg = dict(Re6stHMAC=hmac.HMAC(key, body, hashlib.sha1).digest()) response.msg = dict(Re6stHMAC=base64.b64encode(
hmac.HMAC(key, body, hashlib.sha1).digest()))
self.client._conn.getresponse.return_value = response self.client._conn.getresponse.return_value = response
res = self.client.getNetworkConfig(cn) res = self.client.getNetworkConfig(cn)
......
#!/usr/bin/python2 #!/usr/bin/env python3
import os import os
import sys import sys
import unittest import unittest
...@@ -67,7 +67,7 @@ class testBaseTunnelManager(unittest.TestCase): ...@@ -67,7 +67,7 @@ class testBaseTunnelManager(unittest.TestCase):
# @patch("re6st.tunnel.BaseTunnelManager._makeTunnel", create=True) # @patch("re6st.tunnel.BaseTunnelManager._makeTunnel", create=True)
# def test_processPacket_address_with_msg_peer(self, makeTunnel): # def test_processPacket_address_with_msg_peer(self, makeTunnel):
# """code is 1, peer and msg not none """ # """code is 1, peer and msg not none """
# c = chr(1) # c = b"\x01"
# msg = "address" # msg = "address"
# peer = x509.Peer("000001") # peer = x509.Peer("000001")
# self.tunnel._connecting = {peer} # self.tunnel._connecting = {peer}
...@@ -81,7 +81,7 @@ class testBaseTunnelManager(unittest.TestCase): ...@@ -81,7 +81,7 @@ class testBaseTunnelManager(unittest.TestCase):
def test_processPacket_address(self): def test_processPacket_address(self):
"""code is 1, for address. And peer or msg are none""" """code is 1, for address. And peer or msg are none"""
c = chr(1) c = b"\x01"
self.tunnel._address = {1: "1,1", 2: "2,2"} self.tunnel._address = {1: "1,1", 2: "2,2"}
res = self.tunnel._processPacket(c) res = self.tunnel._processPacket(c)
...@@ -95,7 +95,7 @@ class testBaseTunnelManager(unittest.TestCase): ...@@ -95,7 +95,7 @@ class testBaseTunnelManager(unittest.TestCase):
and each address join by ; and each address join by ;
it will truncate address which has more than 3 element it will truncate address which has more than 3 element
""" """
c = chr(1) c = b"\x01"
peer = x509.Peer("000001") peer = x509.Peer("000001")
peer.protocol = 1 peer.protocol = 1
self.tunnel._peers.append(peer) self.tunnel._peers.append(peer)
...@@ -111,11 +111,11 @@ class testBaseTunnelManager(unittest.TestCase): ...@@ -111,11 +111,11 @@ class testBaseTunnelManager(unittest.TestCase):
"""code is 0, for network version, peer is not none """code is 0, for network version, peer is not none
2 case, one modify the version, one not 2 case, one modify the version, one not
""" """
c = chr(0) c = b"\x00"
peer = x509.Peer("000001") peer = x509.Peer("000001")
version1 = "00003" version1 = b"00003"
version2 = "00007" version2 = b"00007"
self.tunnel._version = version3 = "00005" self.tunnel._version = version3 = b"00005"
self.tunnel._peers.append(peer) self.tunnel._peers.append(peer)
res = self.tunnel._processPacket(c + version1, peer) res = self.tunnel._processPacket(c + version1, peer)
......
#!/usr/bin/python2 #!/usr/bin/env python3
import os import os
import sys import sys
import unittest import unittest
......
...@@ -30,9 +30,9 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None): ...@@ -30,9 +30,9 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
return return
crypto.X509Cert in pem format crypto.X509Cert in pem format
""" """
if type(ca) is str: if type(ca) is bytes:
ca = crypto.load_certificate(crypto.FILETYPE_PEM, ca) ca = crypto.load_certificate(crypto.FILETYPE_PEM, ca)
if type(ca_key) is str: if type(ca_key) is bytes:
ca_key = crypto.load_privatekey(crypto.FILETYPE_PEM, ca_key) ca_key = crypto.load_privatekey(crypto.FILETYPE_PEM, ca_key)
req = crypto.load_certificate_request(crypto.FILETYPE_PEM, csr) req = crypto.load_certificate_request(crypto.FILETYPE_PEM, csr)
...@@ -40,7 +40,7 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None): ...@@ -40,7 +40,7 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
cert.gmtime_adj_notBefore(0) cert.gmtime_adj_notBefore(0)
if not_after: if not_after:
cert.set_notAfter( cert.set_notAfter(
time.strftime("%Y%m%d%H%M%SZ", time.gmtime(not_after))) time.strftime("%Y%m%d%H%M%SZ", time.gmtime(not_after)).encode())
else: else:
cert.gmtime_adj_notAfter(registry.RegistryServer.cert_duration) cert.gmtime_adj_notAfter(registry.RegistryServer.cert_duration)
subject = req.get_subject() subject = req.get_subject()
...@@ -56,9 +56,9 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None): ...@@ -56,9 +56,9 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
def create_cert_file(pkey_file, cert_file, ca, ca_key, prefix, serial): def create_cert_file(pkey_file, cert_file, ca, ca_key, prefix, serial):
pkey, csr = generate_csr() pkey, csr = generate_csr()
cert = generate_cert(ca, ca_key, csr, prefix, serial) cert = generate_cert(ca, ca_key, csr, prefix, serial)
with open(pkey_file, 'w') as f: with open(pkey_file, 'wb') as f:
f.write(pkey) f.write(pkey)
with open(cert_file, 'w') as f: with open(cert_file, 'wb') as f:
f.write(cert) f.write(cert)
return pkey, cert return pkey, cert
...@@ -84,26 +84,23 @@ def create_ca_file(pkey_file, cert_file, serial=0x120010db80042): ...@@ -84,26 +84,23 @@ def create_ca_file(pkey_file, cert_file, serial=0x120010db80042):
cert.set_pubkey(key) cert.set_pubkey(key)
cert.sign(key, "sha512") cert.sign(key, "sha512")
with open(pkey_file, 'w') as pkey_file: with open(pkey_file, 'wb') as pkey_file:
pkey_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key)) pkey_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key))
with open(cert_file, 'w') as cert_file: with open(cert_file, 'wb') as cert_file:
cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert)) cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
return key, cert return key, cert
def prefix2cn(prefix): def prefix2cn(prefix: str) -> str:
return "%u/%u" % (int(prefix, 2), len(prefix)) return "%u/%u" % (int(prefix, 2), len(prefix))
def serial2prefix(serial): def serial2prefix(serial: int) -> str:
return bin(serial)[2:].rjust(16, '0') return bin(serial)[2:].rjust(16, '0')
# pkey: private key # pkey: private key
def decrypt(pkey, incontent): def decrypt(pkey: bytes, incontent: bytes) -> bytes:
with open("node.key", 'w') as f: with open("node.key", 'wb') as f:
f.write(pkey) f.write(pkey)
args = "openssl rsautl -decrypt -inkey node.key".split() args = "openssl rsautl -decrypt -inkey node.key".split()
p = subprocess.Popen( return subprocess.run(args, input=incontent, stdout=subprocess.PIPE).stdout
args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
outcontent, err = p.communicate(incontent)
return outcontent
...@@ -2,8 +2,13 @@ import errno, json, logging, os, platform, random, socket ...@@ -2,8 +2,13 @@ import errno, json, logging, os, platform, random, socket
import subprocess, struct, sys, time, weakref import subprocess, struct, sys, time, weakref
from collections import defaultdict, deque from collections import defaultdict, deque
from bisect import bisect, insort from bisect import bisect, insort
from collections.abc import Iterator, Sequence
from typing import Callable, TYPE_CHECKING
from OpenSSL import crypto from OpenSSL import crypto
from . import ctl, plib, utils, version, x509 from . import ctl, plib, utils, version, x509
if TYPE_CHECKING:
from . import cache
PORT = 326 PORT = 326
...@@ -21,7 +26,8 @@ proto_dict = { ...@@ -21,7 +26,8 @@ proto_dict = {
proto_dict['tcp'] = proto_dict['tcp4'] proto_dict['tcp'] = proto_dict['tcp4']
proto_dict['udp'] = proto_dict['udp4'] proto_dict['udp'] = proto_dict['udp4']
def resolve(ip, port, proto): def resolve(ip, port, proto: str) \
-> tuple[socket.AddressFamily | None, Iterator[str]]:
try: try:
family, proto = proto_dict[proto] family, proto = proto_dict[proto]
except KeyError: except KeyError:
...@@ -31,16 +37,16 @@ def resolve(ip, port, proto): ...@@ -31,16 +37,16 @@ def resolve(ip, port, proto):
class MultiGatewayManager(dict): class MultiGatewayManager(dict):
def __init__(self, gateway): def __init__(self, gateway: Callable[[str], str]):
self._gw = gateway self._gw = gateway
def _route(self, cmd, dest, gw): def _route(self, cmd: str, dest: str, gw: str):
if gw: if gw:
cmd = 'ip', '-4', 'route', cmd, '%s/32' % dest, 'via', gw cmd = 'ip', '-4', 'route', cmd, '%s/32' % dest, 'via', gw
logging.trace('%r', cmd) logging.trace('%r', cmd)
subprocess.check_call(cmd) subprocess.check_call(cmd)
def add(self, dest, route): def add(self, dest: str, route: bool):
try: try:
self[dest][1] += 1 self[dest][1] += 1
except KeyError: except KeyError:
...@@ -48,7 +54,7 @@ class MultiGatewayManager(dict): ...@@ -48,7 +54,7 @@ class MultiGatewayManager(dict):
self[dest] = [gw, 0] self[dest] = [gw, 0]
self._route('add', dest, gw) self._route('add', dest, gw)
def remove(self, dest): def remove(self, dest: str):
gw, count = self[dest] gw, count = self[dest]
if count: if count:
self[dest][1] = count - 1 self[dest][1] = count - 1
...@@ -59,13 +65,14 @@ class MultiGatewayManager(dict): ...@@ -59,13 +65,14 @@ class MultiGatewayManager(dict):
except: except:
pass pass
class Connection(object): class Connection:
_retry = 0 _retry = 0
serial = None serial = None
time = float('inf') time = float('inf')
def __init__(self, tunnel_manager, address_list, iface, prefix): def __init__(self, tunnel_manager: "TunnelManager",
address_list, iface, prefix):
self.tunnel_manager = tunnel_manager self.tunnel_manager = tunnel_manager
self.address_list = address_list self.address_list = address_list
self.iface = iface self.iface = iface
...@@ -94,7 +101,7 @@ class Connection(object): ...@@ -94,7 +101,7 @@ class Connection(object):
'--remap-usr1', 'SIGTERM', '--remap-usr1', 'SIGTERM',
'--ping-exit', str(tm.timeout), '--ping-exit', str(tm.timeout),
'--route-up', '%s %u' % (plib.ovpn_client, tm.write_sock.fileno()), '--route-up', '%s %u' % (plib.ovpn_client, tm.write_sock.fileno()),
*tm.ovpn_args) *tm.ovpn_args, pass_fds=[tm.write_sock.fileno()])
tm.resetTunnelRefresh() tm.resetTunnelRefresh()
self._retry += 1 self._retry += 1
...@@ -109,7 +116,7 @@ class Connection(object): ...@@ -109,7 +116,7 @@ class Connection(object):
if i: if i:
cache.addPeer(self._prefix, ','.join(self.address_list[i]), True) cache.addPeer(self._prefix, ','.join(self.address_list[i]), True)
else: else:
cache.connecting(self._prefix, 0) cache.connecting(self._prefix, False)
def close(self): def close(self):
try: try:
...@@ -132,7 +139,7 @@ class Connection(object): ...@@ -132,7 +139,7 @@ class Connection(object):
self.open() self.open()
return 0 return 0
class TunnelKiller(object): class TunnelKiller:
state = None state = None
...@@ -169,7 +176,7 @@ class TunnelKiller(object): ...@@ -169,7 +176,7 @@ class TunnelKiller(object):
if (self.address, self.ifindex) in tm.ctl.locked: if (self.address, self.ifindex) in tm.ctl.locked:
self.state = 'locked' self.state = 'locked'
self.timeout = time.time() + 2 * tm.timeout self.timeout = time.time() + 2 * tm.timeout
tm.sendto(self.peer, '\2' if self.client else '\3') tm.sendto(self.peer, b'\2' if self.client else b'\3')
else: else:
self.timeout = 0 self.timeout = 0
...@@ -186,7 +193,7 @@ class TunnelKiller(object): ...@@ -186,7 +193,7 @@ class TunnelKiller(object):
locked = unlocking = lambda _: None locked = unlocking = lambda _: None
class BaseTunnelManager(object): class BaseTunnelManager:
# TODO: To minimize downtime when network parameters change, we should do # TODO: To minimize downtime when network parameters change, we should do
# our best to not restart any process. Ideally, this list should be # our best to not restart any process. Ideally, this list should be
...@@ -198,7 +205,8 @@ class BaseTunnelManager(object): ...@@ -198,7 +205,8 @@ class BaseTunnelManager(object):
_geoiplookup = None _geoiplookup = None
_forward = None _forward = None
def __init__(self, control_socket, cache, cert, conf_country, address=()): def __init__(self, control_socket, cache: "cache.Cache", cert: x509.Cert,
conf_country, address=()):
self.cert = cert self.cert = cert
self._network = cert.network self._network = cert.network
self._prefix = cert.prefix self._prefix = cert.prefix
...@@ -242,14 +250,14 @@ class BaseTunnelManager(object): ...@@ -242,14 +250,14 @@ class BaseTunnelManager(object):
self._country = {} self._country = {}
address_dict = {family: self._updateCountry(address) address_dict = {family: self._updateCountry(address)
for family, address in address_dict.iteritems()} for family, address in address_dict.items()}
elif cache.same_country: elif cache.same_country:
sys.exit("Can not respect 'same_country' network configuration" sys.exit("Can not respect 'same_country' network configuration"
" (GEOIP2_MMDB not set)") " (GEOIP2_MMDB not set)")
self._address = {family: utils.dump_address(address) self._address = {family: utils.dump_address(address)
for family, address in address_dict.iteritems() for family, address in address_dict.items()
if address} if address}
cache.my_address = ';'.join(self._address.itervalues()) cache.my_address = ';'.join(self._address.values())
self.sock = socket.socket(socket.AF_INET6, self.sock = socket.socket(socket.AF_INET6,
socket.SOCK_DGRAM | socket.SOCK_CLOEXEC) socket.SOCK_DGRAM | socket.SOCK_CLOEXEC)
...@@ -329,7 +337,7 @@ class BaseTunnelManager(object): ...@@ -329,7 +337,7 @@ class BaseTunnelManager(object):
def _getPeer(self, prefix): def _getPeer(self, prefix):
return self._peers[bisect(self._peers, prefix) - 1] return self._peers[bisect(self._peers, prefix) - 1]
def sendto(self, prefix, msg): def sendto(self, prefix: str, msg):
to = utils.ipFromBin(self._network + prefix), PORT to = utils.ipFromBin(self._network + prefix), PORT
peer = self._getPeer(prefix) peer = self._getPeer(prefix)
if peer.prefix != prefix: if peer.prefix != prefix:
...@@ -344,9 +352,11 @@ class BaseTunnelManager(object): ...@@ -344,9 +352,11 @@ class BaseTunnelManager(object):
peer.hello0Sent() peer.hello0Sent()
def _sendto(self, to, msg, peer=None): def _sendto(self, to, msg, peer=None):
if type(msg) is str:
msg = msg.encode()
try: try:
r = self.sock.sendto(peer.encode(msg) if peer else msg, to) r = self.sock.sendto(peer.encode(msg) if peer else msg, to)
except socket.error, e: except socket.error as e:
(logging.info if e.errno == errno.ENETUNREACH else logging.error)( (logging.info if e.errno == errno.ENETUNREACH else logging.error)(
'Failed to send message to %s (%s)', to, e) 'Failed to send message to %s (%s)', to, e)
return return
...@@ -359,19 +369,20 @@ class BaseTunnelManager(object): ...@@ -359,19 +369,20 @@ class BaseTunnelManager(object):
to = address[:2] to = address[:2]
if address[0] == '::1': if address[0] == '::1':
try: try:
prefix, msg = msg.split('\0', 1) prefix, msg = msg.split(b'\0', 1)
prefix = prefix.decode()
int(prefix, 2) int(prefix, 2)
except ValueError: except ValueError:
return return
if msg: if msg:
self._forward = to self._forward = to
code = ord(msg[0]) code = msg[0]
if prefix == self._prefix: if prefix == self._prefix:
msg = self._processPacket(msg) msg = self._processPacket(msg)
if msg: if msg:
self._sendto(to, '%s\0%c%s' % (prefix, code, msg)) self._sendto(to, '%s\0%c%s' % (prefix, code, msg))
else: else:
self.sendto(prefix, chr(code | 0x80) + msg[1:]) self.sendto(prefix, bytes([code | 0x80]) + msg[1:])
return return
try: try:
sender = utils.binFromIp(address[0]) sender = utils.binFromIp(address[0])
...@@ -384,7 +395,7 @@ class BaseTunnelManager(object): ...@@ -384,7 +395,7 @@ class BaseTunnelManager(object):
msg = peer.decode(msg) msg = peer.decode(msg)
if type(msg) is tuple: if type(msg) is tuple:
seqno, msg, protocol = msg seqno, msg, protocol = msg
def handleHello(peer, seqno, msg, retry): def handleHello(peer, seqno, msg: bytes, retry):
if seqno == 2: if seqno == 2:
i = len(msg) // 2 i = len(msg) // 2
h = msg[:i] h = msg[:i]
...@@ -394,10 +405,10 @@ class BaseTunnelManager(object): ...@@ -394,10 +405,10 @@ class BaseTunnelManager(object):
except (AttributeError, crypto.Error, x509.NewSessionError, except (AttributeError, crypto.Error, x509.NewSessionError,
subprocess.CalledProcessError): subprocess.CalledProcessError):
logging.debug('ignored new session key from %r', logging.debug('ignored new session key from %r',
address, exc_info=1) address, exc_info=True)
return return
peer.version = self._version \ peer.version = self._version \
if self._sendto(to, '\0' + self._version, peer) else '' if self._sendto(to, b'\0' + self._version, peer) else b''
return return
if seqno: if seqno:
h = x509.fingerprint(self.cert.cert).digest() h = x509.fingerprint(self.cert.cert).digest()
...@@ -410,7 +421,7 @@ class BaseTunnelManager(object): ...@@ -410,7 +421,7 @@ class BaseTunnelManager(object):
serial = cert.get_serial_number() serial = cert.get_serial_number()
if serial in self.cache.crl: if serial in self.cache.crl:
raise ValueError("revoked") raise ValueError("revoked")
except (x509.VerifyError, ValueError), e: except (x509.VerifyError, ValueError) as e:
if retry: if retry:
return True return True
logging.debug('ignored invalid certificate from %r (%s)', logging.debug('ignored invalid certificate from %r (%s)',
...@@ -444,10 +455,12 @@ class BaseTunnelManager(object): ...@@ -444,10 +455,12 @@ class BaseTunnelManager(object):
# We got a valid and non-empty message. Always reply # We got a valid and non-empty message. Always reply
# something so that the sender knows we're still connected. # something so that the sender knows we're still connected.
answer = self._processPacket(msg, peer.prefix) answer = self._processPacket(msg, peer.prefix)
self._sendto(to, msg[0] + answer if answer else "", peer) self._sendto(to, msg[0:1] + answer.encode() if answer else b'',
peer)
def _processPacket(self, msg, peer=None): def _processPacket(self, msg: bytes, peer: x509.Peer|str=None):
c = ord(msg[0]) c = msg[0]
msg = msg[1:] msg = msg[1:]
code = c & 0x7f code = c & 0x7f
if c > 0x7f and msg: if c > 0x7f and msg:
...@@ -456,6 +469,7 @@ class BaseTunnelManager(object): ...@@ -456,6 +469,7 @@ class BaseTunnelManager(object):
elif code == 1: # address elif code == 1: # address
if msg: if msg:
if peer: if peer:
msg = msg.decode()
self.cache.addPeer(peer, msg) self.cache.addPeer(peer, msg)
try: try:
self._connecting.remove(peer) self._connecting.remove(peer)
...@@ -467,8 +481,8 @@ class BaseTunnelManager(object): ...@@ -467,8 +481,8 @@ class BaseTunnelManager(object):
# Don't send country to old nodes # Don't send country to old nodes
if self._getPeer(peer).protocol < 7: if self._getPeer(peer).protocol < 7:
return ';'.join(','.join(a.split(',')[:3]) for a in return ';'.join(','.join(a.split(',')[:3]) for a in
';'.join(self._address.itervalues()).split(';')) ';'.join(self._address.values()).split(';'))
return ';'.join(self._address.itervalues()) return ';'.join(self._address.values())
elif not code: # network version elif not code: # network version
if peer: if peer:
try: try:
...@@ -526,7 +540,7 @@ class BaseTunnelManager(object): ...@@ -526,7 +540,7 @@ class BaseTunnelManager(object):
if peer.prefix != prefix: if peer.prefix != prefix:
self.sendto(prefix, None) self.sendto(prefix, None)
elif (peer.version < self._version and elif (peer.version < self._version and
self.sendto(prefix, '\0' + self._version)): self.sendto(prefix, b'\0' + self._version)):
peer.version = self._version peer.version = self._version
def broadcastNewVersion(self): def broadcastNewVersion(self):
...@@ -553,18 +567,18 @@ class BaseTunnelManager(object): ...@@ -553,18 +567,18 @@ class BaseTunnelManager(object):
if (not self.NEED_RESTART.isdisjoint(changed) if (not self.NEED_RESTART.isdisjoint(changed)
or version.protocol < self.cache.min_protocol or version.protocol < self.cache.min_protocol
# TODO: With --management, we could kill clients without restarting. # TODO: With --management, we could kill clients without restarting.
or not all(crl.isdisjoint(serials.itervalues()) or not all(crl.isdisjoint(serials.values())
for serials in self._served.itervalues())): for serials in self._served.values())):
# Wait at least 1 second to broadcast new version to neighbours. # Wait at least 1 second to broadcast new version to neighbours.
self.selectTimeout(time.time() + 1 + self.cache.delay_restart, self.selectTimeout(time.time() + 1 + self.cache.delay_restart,
self._restart) self._restart)
def handleServerEvent(self, sock): def handleServerEvent(self, sock: socket.socket):
event, args = eval(sock.recv(65536)) event, args = eval(sock.recv(65536))
logging.debug("%s%r", event, args) logging.debug("%s%r", event, args)
r = getattr(self, '_ovpn_' + event.replace('-', '_'))(*args) r = getattr(self, '_ovpn_' + event.replace('-', '_'))(*args)
if r is not None: if r is not None:
sock.send(chr(r)) sock.send(bytes([r]))
def _ovpn_client_connect(self, common_name, iface, serial, trusted_ip): def _ovpn_client_connect(self, common_name, iface, serial, trusted_ip):
if serial in self.cache.crl: if serial in self.cache.crl:
...@@ -576,7 +590,7 @@ class BaseTunnelManager(object): ...@@ -576,7 +590,7 @@ class BaseTunnelManager(object):
self._gateway_manager.add(trusted_ip, False) self._gateway_manager.add(trusted_ip, False)
if prefix in self._connection_dict and self._prefix < prefix: if prefix in self._connection_dict and self._prefix < prefix:
self._kill(prefix) self._kill(prefix)
self.cache.connecting(prefix, 0) self.cache.connecting(prefix, False)
return True return True
def _ovpn_client_disconnect(self, common_name, iface, serial, trusted_ip): def _ovpn_client_disconnect(self, common_name, iface, serial, trusted_ip):
...@@ -606,7 +620,7 @@ class BaseTunnelManager(object): ...@@ -606,7 +620,7 @@ class BaseTunnelManager(object):
with open('/proc/net/ipv6_route', "r", 4096) as f: with open('/proc/net/ipv6_route', "r", 4096) as f:
try: try:
routing_table = f.read() routing_table = f.read()
except IOError, e: except IOError as e:
# ???: If someone can explain why the kernel sometimes fails # ???: If someone can explain why the kernel sometimes fails
# even when there's a lot of free memory. # even when there's a lot of free memory.
if e.errno != errno.ENOMEM: if e.errno != errno.ENOMEM:
...@@ -635,7 +649,7 @@ class BaseTunnelManager(object): ...@@ -635,7 +649,7 @@ class BaseTunnelManager(object):
logging.error("%s. Flushing...", msg) logging.error("%s. Flushing...", msg)
subprocess.call(("ip", "-6", "route", "flush", "cached")) subprocess.call(("ip", "-6", "route", "flush", "cached"))
self.sendto(self.cache.registry_prefix, self.sendto(self.cache.registry_prefix,
'\7%s (%s)' % (msg, os.uname()[2])) b'\7%s (%s)' % (msg, os.uname()[2]))
break break
def _updateCountry(self, address): def _updateCountry(self, address):
...@@ -660,7 +674,8 @@ class TunnelManager(BaseTunnelManager): ...@@ -660,7 +674,8 @@ class TunnelManager(BaseTunnelManager):
def __init__(self, control_socket, cache, cert, openvpn_args, def __init__(self, control_socket, cache, cert, openvpn_args,
timeout, client_count, iface_list, conf_country, address, timeout, client_count, iface_list, conf_country, address,
ip_changed, remote_gateway, disable_proto, neighbour_list=()): ip_changed, remote_gateway: Callable[[str], str],
disable_proto: Sequence[str], neighbour_list=()):
super(TunnelManager, self).__init__(control_socket, super(TunnelManager, self).__init__(control_socket,
cache, cert, conf_country, address) cache, cert, conf_country, address)
self.ovpn_args = openvpn_args self.ovpn_args = openvpn_args
...@@ -683,7 +698,7 @@ class TunnelManager(BaseTunnelManager): ...@@ -683,7 +698,7 @@ class TunnelManager(BaseTunnelManager):
self._client_count = client_count self._client_count = client_count
self.new_iface_list = deque('re6stnet' + str(i) self.new_iface_list = deque('re6stnet' + str(i)
for i in xrange(1, self._client_count + 1)) for i in range(1, self._client_count + 1))
self._free_iface_list = [] self._free_iface_list = []
def close(self): def close(self):
...@@ -752,7 +767,7 @@ class TunnelManager(BaseTunnelManager): ...@@ -752,7 +767,7 @@ class TunnelManager(BaseTunnelManager):
def babel_dump(self): def babel_dump(self):
t = time.time() t = time.time()
if self._killing: if self._killing:
for prefix, tunnel_killer in self._killing.items(): for prefix, tunnel_killer in list(self._killing.items()):
if tunnel_killer.timeout < t: if tunnel_killer.timeout < t:
if tunnel_killer.state != 'unlocking': if tunnel_killer.state != 'unlocking':
logging.info( logging.info(
...@@ -780,7 +795,7 @@ class TunnelManager(BaseTunnelManager): ...@@ -780,7 +795,7 @@ class TunnelManager(BaseTunnelManager):
def _cleanDeads(self): def _cleanDeads(self):
disconnected = False disconnected = False
for prefix in self._connection_dict.keys(): for prefix in list(self._connection_dict):
status = self._connection_dict[prefix].refresh() status = self._connection_dict[prefix].refresh()
if status: if status:
disconnected |= status > 0 disconnected |= status > 0
...@@ -872,7 +887,7 @@ class TunnelManager(BaseTunnelManager): ...@@ -872,7 +887,7 @@ class TunnelManager(BaseTunnelManager):
address_list.append((ip, x[1], x[2])) address_list.append((ip, x[1], x[2]))
continue continue
address_list.append(x[:3]) address_list.append(x[:3])
self.cache.connecting(prefix, 1) self.cache.connecting(prefix, True)
if not address_list: if not address_list:
return False return False
logging.info('Establishing a connection with %u/%u', logging.info('Establishing a connection with %u/%u',
...@@ -902,7 +917,7 @@ class TunnelManager(BaseTunnelManager): ...@@ -902,7 +917,7 @@ class TunnelManager(BaseTunnelManager):
neighbours = self.ctl.neighbours neighbours = self.ctl.neighbours
# Collect all nodes known by Babel # Collect all nodes known by Babel
peers = {prefix peers = {prefix
for neigh_routes in neighbours.itervalues() for neigh_routes in neighbours.values()
for prefix in neigh_routes[1] for prefix in neigh_routes[1]
if prefix} if prefix}
# Keep only distant peers. # Keep only distant peers.
...@@ -957,7 +972,7 @@ class TunnelManager(BaseTunnelManager): ...@@ -957,7 +972,7 @@ class TunnelManager(BaseTunnelManager):
address = self.cache.getAddress(peer) address = self.cache.getAddress(peer)
if address: if address:
count -= self._makeTunnel(peer, address) count -= self._makeTunnel(peer, address)
elif self.sendto(peer, '\1'): elif self.sendto(peer, b'\1'):
self._connecting.add(peer) self._connecting.add(peer)
count -= 1 count -= 1
elif distant_peers is None: elif distant_peers is None:
...@@ -987,7 +1002,7 @@ class TunnelManager(BaseTunnelManager): ...@@ -987,7 +1002,7 @@ class TunnelManager(BaseTunnelManager):
break break
def killAll(self): def killAll(self):
for prefix in self._connection_dict.keys(): for prefix in list(self._connection_dict):
self._kill(prefix) self._kill(prefix)
def handleClientEvent(self): def handleClientEvent(self):
...@@ -999,7 +1014,7 @@ class TunnelManager(BaseTunnelManager): ...@@ -999,7 +1014,7 @@ class TunnelManager(BaseTunnelManager):
if c and c.time < float(time): if c and c.time < float(time):
try: try:
c.connected(serial) c.connected(serial)
except (KeyError, TypeError), e: except (KeyError, TypeError) as e:
logging.error("%s (route_up %s)", e, common_name) logging.error("%s (route_up %s)", e, common_name)
else: else:
logging.info("ignore route_up notification for %s %r", logging.info("ignore route_up notification for %s %r",
...@@ -1010,10 +1025,10 @@ class TunnelManager(BaseTunnelManager): ...@@ -1010,10 +1025,10 @@ class TunnelManager(BaseTunnelManager):
if self.cache.same_country: if self.cache.same_country:
address = self._updateCountry(address) address = self._updateCountry(address)
self._address[family] = utils.dump_address(address) self._address[family] = utils.dump_address(address)
self.cache.my_address = ';'.join(self._address.itervalues()) self.cache.my_address = ';'.join(self._address.values())
def broadcastNewVersion(self): def broadcastNewVersion(self):
self._babel_dump_new_version() self._babel_dump_new_version()
for prefix, c in self._connection_dict.items(): for prefix, c in list(self._connection_dict.items()):
if c.serial in self.cache.crl: if c.serial in self.cache.crl:
self._kill(prefix) self._kill(prefix)
...@@ -7,7 +7,7 @@ class UPnPException(Exception): ...@@ -7,7 +7,7 @@ class UPnPException(Exception):
pass pass
class Forwarder(object): class Forwarder:
""" """
External port is chosen randomly between 32768 & 49151 included. External port is chosen randomly between 32768 & 49151 included.
""" """
...@@ -17,7 +17,7 @@ class Forwarder(object): ...@@ -17,7 +17,7 @@ class Forwarder(object):
_lcg_n = 0 _lcg_n = 0
@classmethod @classmethod
def _getExternalPort(cls): def _getExternalPort(cls) -> int:
# Since _refresh() does not test all ports in a row, we prefer to # Since _refresh() does not test all ports in a row, we prefer to
# return random ports to maximize the chance to find a free port. # return random ports to maximize the chance to find a free port.
# A linear congruential generator should be random enough, without # A linear congruential generator should be random enough, without
...@@ -35,12 +35,12 @@ class Forwarder(object): ...@@ -35,12 +35,12 @@ class Forwarder(object):
self._u.discoverdelay = 200 self._u.discoverdelay = 200
self._rules = [] self._rules = []
def __getattr__(self, name): def __getattr__(self, name: str):
wrapped = getattr(self._u, name) wrapped = getattr(self._u, name)
def wrapper(*args, **kw): def wrapper(*args, **kw):
try: try:
return wrapped(*args, **kw) return wrapped(*args, **kw)
except Exception, e: except Exception as e:
raise UPnPException(str(e)) raise UPnPException(str(e))
return wraps(wrapped)(wrapper) return wraps(wrapped)(wrapper)
...@@ -68,14 +68,14 @@ class Forwarder(object): ...@@ -68,14 +68,14 @@ class Forwarder(object):
else: else:
try: try:
return self._refresh() return self._refresh()
except UPnPException, e: except UPnPException as e:
logging.debug("UPnP failure", exc_info=1) logging.debug("UPnP failure", exc_info=True)
self.clear() self.clear()
try: try:
self.discover() self.discover()
self.selectigd() self.selectigd()
return self._refresh() return self._refresh()
except UPnPException, e: except UPnPException as e:
self.next_refresh = self._next_retry = time.time() + 60 self.next_refresh = self._next_retry = time.time() + 60
logging.info(str(e)) logging.info(str(e))
self.clear() self.clear()
...@@ -109,7 +109,7 @@ class Forwarder(object): ...@@ -109,7 +109,7 @@ class Forwarder(object):
try: try:
self.addportmapping(port, *args) self.addportmapping(port, *args)
break break
except UPnPException, e: except UPnPException as e:
if str(e) != 'ConflictInMappingEntry': if str(e) != 'ConflictInMappingEntry':
raise raise
port = None port = None
......
import argparse, errno, fcntl, hashlib, logging, os, select as _select import argparse, errno, fcntl, hashlib, logging, os, select as _select
import shlex, signal, socket, sqlite3, struct, subprocess import shlex, signal, socket, sqlite3, struct, subprocess
import sys, textwrap, threading, time, traceback import sys, textwrap, threading, time, traceback
from collections.abc import Iterator, Mapping
# PY3: It will be even better to use Popen(pass_fds=...), HMAC_LEN = len(hashlib.sha1(b'').digest())
# and then socket.SOCK_CLOEXEC will be useless.
# (We already follow the good practice that consists in not
# relying on the GC for the closing of file descriptors.)
socket.SOCK_CLOEXEC = 0x80000
HMAC_LEN = len(hashlib.sha1('').digest())
class ReexecException(Exception): class ReexecException(Exception):
pass pass
...@@ -37,15 +32,15 @@ class FileHandler(logging.FileHandler): ...@@ -37,15 +32,15 @@ class FileHandler(logging.FileHandler):
finally: finally:
self.lock.release() self.lock.release()
# In the rare case _reopen is set just before the lock was released # In the rare case _reopen is set just before the lock was released
if self._reopen and self.lock.acquire(0): if self._reopen and self.lock.acquire(False):
self.release() self.release()
def async_reopen(self, *_): def async_reopen(self, *_):
self._reopen = True self._reopen = True
if self.lock.acquire(0): if self.lock.acquire(False):
self.release() self.release()
def setupLog(log_level, filename=None, **kw): def setupLog(log_level: int, filename: str | None=None, **kw):
if log_level and filename: if log_level and filename:
makedirs(os.path.dirname(filename)) makedirs(os.path.dirname(filename))
handler = FileHandler(filename) handler = FileHandler(filename)
...@@ -119,7 +114,7 @@ class ArgParser(argparse.ArgumentParser): ...@@ -119,7 +114,7 @@ class ArgParser(argparse.ArgumentParser):
ca /etc/re6stnet/ca.crt""", **kw) ca /etc/re6stnet/ca.crt""", **kw)
class exit(object): class exit:
status = None status = None
...@@ -150,7 +145,7 @@ class exit(object): ...@@ -150,7 +145,7 @@ class exit(object):
def handler(*args): def handler(*args):
if self.status is None: if self.status is None:
self.status = status self.status = status
if self.acquire(0): if self.acquire(False):
self.release() self.release()
for sig in sigs: for sig in sigs:
signal.signal(sig, handler) signal.signal(sig, handler)
...@@ -164,7 +159,7 @@ class Popen(subprocess.Popen): ...@@ -164,7 +159,7 @@ class Popen(subprocess.Popen):
self._args = tuple(args[0] if args else kw['args']) self._args = tuple(args[0] if args else kw['args'])
try: try:
super(Popen, self).__init__(*args, **kw) super(Popen, self).__init__(*args, **kw)
except OSError, e: except OSError as e:
if e.errno != errno.ENOMEM: if e.errno != errno.ENOMEM:
raise raise
self.returncode = -1 self.returncode = -1
...@@ -179,9 +174,9 @@ class Popen(subprocess.Popen): ...@@ -179,9 +174,9 @@ class Popen(subprocess.Popen):
self.terminate() self.terminate()
t = threading.Timer(5, self.kill) t = threading.Timer(5, self.kill)
t.start() t.start()
# PY3: use waitid(WNOWAIT) and call self.poll() after t.cancel() r = os.waitid(os.P_PID, self.pid, os.WNOWAIT)
r = self.wait()
t.cancel() t.cancel()
self.poll()
return r return r
...@@ -189,7 +184,7 @@ def setCloexec(fd): ...@@ -189,7 +184,7 @@ def setCloexec(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFD) flags = fcntl.fcntl(fd, fcntl.F_GETFD)
fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC) fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)
def select(R, W, T): def select(R: Mapping, W: Mapping, T):
try: try:
r, w, _ = _select.select(R, W, (), r, w, _ = _select.select(R, W, (),
max(0, min(T)[0] - time.time()) if T else None) max(0, min(T)[0] - time.time()) if T else None)
...@@ -209,19 +204,19 @@ def select(R, W, T): ...@@ -209,19 +204,19 @@ def select(R, W, T):
def makedirs(*args): def makedirs(*args):
try: try:
os.makedirs(*args) os.makedirs(*args)
except OSError, e: except OSError as e:
if e.errno != errno.EEXIST: if e.errno != errno.EEXIST:
raise raise
def binFromIp(ip): def binFromIp(ip: str) -> str:
return binFromRawIp(socket.inet_pton(socket.AF_INET6, ip)) return binFromRawIp(socket.inet_pton(socket.AF_INET6, ip))
def binFromRawIp(ip): def binFromRawIp(ip: bytes) -> str:
ip1, ip2 = struct.unpack('>QQ', ip) ip1, ip2 = struct.unpack('>QQ', ip)
return bin(ip1)[2:].rjust(64, '0') + bin(ip2)[2:].rjust(64, '0') return bin(ip1)[2:].rjust(64, '0') + bin(ip2)[2:].rjust(64, '0')
def ipFromBin(ip, suffix=''): def ipFromBin(ip: str, suffix='') -> str:
suffix_len = 128 - len(ip) suffix_len = 128 - len(ip)
if suffix_len > 0: if suffix_len > 0:
ip += suffix.rjust(suffix_len, '0') ip += suffix.rjust(suffix_len, '0')
...@@ -230,30 +225,32 @@ def ipFromBin(ip, suffix=''): ...@@ -230,30 +225,32 @@ def ipFromBin(ip, suffix=''):
return socket.inet_ntop(socket.AF_INET6, return socket.inet_ntop(socket.AF_INET6,
struct.pack('>QQ', int(ip[:64], 2), int(ip[64:], 2))) struct.pack('>QQ', int(ip[:64], 2), int(ip[64:], 2)))
def dump_address(address): def dump_address(address: str) -> str:
return ';'.join(map(','.join, address)) return ';'.join(map(','.join, address))
# Yield ip, port, protocol, and country if it is in the address # Yield ip, port, protocol, and country if it is in the address
def parse_address(address_list): def parse_address(address_list: str) -> Iterator[tuple[str, str, str, str]]:
for address in address_list.split(';'): for address in address_list.split(';'):
try: try:
a = address.split(',') a = address.split(',')
int(a[1]) # Check if port is an int int(a[1]) # Check if port is an int
yield tuple(a[:4]) yield tuple(a[:4])
except ValueError, e: except ValueError as e:
logging.warning("Failed to parse node address %r (%s)", logging.warning("Failed to parse node address %r (%s)",
address, e) address, e)
def binFromSubnet(subnet): def binFromSubnet(subnet: str) -> str:
p, l = subnet.split('/') p, l = subnet.split('/')
return bin(int(p))[2:].rjust(int(l), '0') return bin(int(p))[2:].rjust(int(l), '0')
def newHmacSecret(): def _newHmacSecret():
from random import getrandbits as g from random import getrandbits as g
pack = struct.Struct(">QQI").pack pack = struct.Struct(">QQI").pack
assert len(pack(0,0,0)) == HMAC_LEN assert len(pack(0,0,0)) == HMAC_LEN
# A closure is built to avoid rebuilding the `pack` function at each call.
return lambda x=None: pack(g(64) if x is None else x, g(64), g(32)) return lambda x=None: pack(g(64) if x is None else x, g(64), g(32))
newHmacSecret = newHmacSecret()
newHmacSecret = _newHmacSecret() # https://github.com/python/mypy/issues/1174
### Integer serialization ### Integer serialization
# - supports values from 0 to 0x202020202020201f # - supports values from 0 to 0x202020202020201f
...@@ -261,21 +258,21 @@ newHmacSecret = newHmacSecret() ...@@ -261,21 +258,21 @@ newHmacSecret = newHmacSecret()
# - there's always a unique way to encode a value # - there's always a unique way to encode a value
# - the 3 first bits code the number of bytes # - the 3 first bits code the number of bytes
def packInteger(i): def packInteger(i: int) -> bytes:
for n in xrange(8): for n in range(8):
x = 32 << 8 * n x = 32 << 8 * n
if i < x: if i < x:
return struct.pack("!Q", i + n * x)[7-n:] return struct.pack("!Q", i + n * x)[7-n:]
i -= x i -= x
raise OverflowError raise OverflowError
def unpackInteger(x): def unpackInteger(x: bytes) -> tuple[int, int] | None:
n = ord(x[0]) >> 5 n = x[0] >> 5
try: try:
i, = struct.unpack("!Q", '\0' * (7 - n) + x[:n+1]) i, = struct.unpack("!Q", b'\0' * (7 - n) + x[:n+1])
except struct.error: except struct.error:
return return
return sum((32 << 8 * i for i in xrange(n)), return sum((32 << 8 * i for i in range(n)),
i - (n * 32 << 8 * n)), n + 1 i - (n * 32 << 8 * n)), n + 1
### ###
......
...@@ -40,4 +40,4 @@ protocol = 8 ...@@ -40,4 +40,4 @@ protocol = 8
min_protocol = 1 min_protocol = 1
if __name__ == "__main__": if __name__ == "__main__":
print version print(version)
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import calendar, hashlib, hmac, logging, os, struct, subprocess, threading, time import calendar, hashlib, hmac, logging, os, struct, subprocess, threading, time
from typing import Callable, Any
from OpenSSL import crypto from OpenSSL import crypto
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.x509 import load_pem_x509_certificate
from . import utils from . import utils
from .version import protocol from .version import protocol
def newHmacSecret(): def newHmacSecret() -> bytes:
return utils.newHmacSecret(int(time.time() * 1000000)) return utils.newHmacSecret(int(time.time() * 1000000))
def networkFromCa(ca): def networkFromCa(ca: crypto.X509) -> str:
# TODO: will be ca.serial_number after migration to cryptography
return bin(ca.get_serial_number())[3:] return bin(ca.get_serial_number())[3:]
def subnetFromCert(cert): def subnetFromCert(cert: crypto.X509) -> str:
return cert.get_subject().CN return cert.get_subject().CN
def notBefore(cert): def notBefore(cert: crypto.X509) -> int:
return calendar.timegm(time.strptime(cert.get_notBefore(),'%Y%m%d%H%M%SZ')) return calendar.timegm(time.strptime(cert.get_notBefore().decode(),
'%Y%m%d%H%M%SZ'))
def notAfter(cert): def notAfter(cert: crypto.X509) -> int:
return calendar.timegm(time.strptime(cert.get_notAfter(),'%Y%m%d%H%M%SZ')) return calendar.timegm(time.strptime(cert.get_notAfter().decode(),
'%Y%m%d%H%M%SZ'))
def openssl(*args): def openssl(*args: str, fds=[]) -> utils.Popen:
return utils.Popen(('openssl',) + args, return utils.Popen(('openssl',) + args,
stdin=subprocess.PIPE, stdin=subprocess.PIPE,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE) stderr=subprocess.PIPE, pass_fds=fds)
def encrypt(cert, data): def encrypt(cert: bytes, data: bytes) -> bytes:
r, w = os.pipe() r, w = os.pipe()
try: try:
threading.Thread(target=os.write, args=(w, cert)).start() threading.Thread(target=os.write, args=(w, cert)).start()
p = openssl('rsautl', '-encrypt', '-certin', p = openssl('rsautl', '-encrypt', '-certin',
'-inkey', '/proc/self/fd/%u' % r) '-inkey', '/proc/self/fd/%u' % r, fds=[r])
out, err = p.communicate(data) out, err = p.communicate(data)
finally: finally:
os.close(r) os.close(r)
...@@ -39,10 +49,12 @@ def encrypt(cert, data): ...@@ -39,10 +49,12 @@ def encrypt(cert, data):
raise subprocess.CalledProcessError(p.returncode, 'openssl', err) raise subprocess.CalledProcessError(p.returncode, 'openssl', err)
return out return out
def fingerprint(cert, alg='sha1'): def fingerprint(cert: crypto.X509, alg='sha1'):
return hashlib.new(alg, crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)) return hashlib.new(alg, crypto.dump_certificate(crypto.FILETYPE_ASN1, cert))
def maybe_renew(path, cert, info, renew, force=False): def maybe_renew(path: str, cert: crypto.X509, info: str,
renew: Callable[[], bytes],
force=False) -> tuple[crypto.X509, int]:
from .registry import RENEW_PERIOD from .registry import RENEW_PERIOD
while True: while True:
if force: if force:
...@@ -62,7 +74,7 @@ def maybe_renew(path, cert, info, renew, force=False): ...@@ -62,7 +74,7 @@ def maybe_renew(path, cert, info, renew, force=False):
exc_info = 1 exc_info = 1
break break
new_path = path + '.new' new_path = path + '.new'
with open(new_path, 'w') as f: with open(new_path, 'wb') as f:
f.write(pem) f.write(pem)
try: try:
s = os.stat(path) s = os.stat(path)
...@@ -84,39 +96,44 @@ class NewSessionError(Exception): ...@@ -84,39 +96,44 @@ class NewSessionError(Exception):
pass pass
class Cert(object): class Cert:
def __init__(self, ca, key, cert=None): def __init__(self, ca: str, key: str, cert: str | None=None):
self.ca_path = ca self.ca_path = ca
self.cert_path = cert self.cert_path = cert
self.key_path = key self.key_path = key
with open(ca) as f: # TODO: finish migration from old OpenSSL module to cryptography
self.ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read()) with open(ca, "rb") as f:
with open(key) as f: ca_pem = f.read()
self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read()) self.ca = crypto.load_certificate(crypto.FILETYPE_PEM, ca_pem)
self.ca_crypto = load_pem_x509_certificate(ca_pem)
with open(key, "rb") as f:
key_pem = f.read()
self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, key_pem)
self.key_crypto = load_pem_private_key(key_pem, password=None)
if cert: if cert:
with open(cert) as f: with open(cert) as f:
self.cert = self.loadVerify(f.read()) self.cert = self.loadVerify(f.read().encode())
@property @property
def prefix(self): def prefix(self) -> str:
return utils.binFromSubnet(subnetFromCert(self.cert)) return utils.binFromSubnet(subnetFromCert(self.cert))
@property @property
def network(self): def network(self) -> str:
return networkFromCa(self.ca) return networkFromCa(self.ca)
@property @property
def subject_serial(self): def subject_serial(self) -> int:
return int(self.cert.get_subject().serialNumber) return int(self.cert.get_subject().serialNumber)
@property @property
def openvpn_args(self): def openvpn_args(self) -> tuple[str, ...]:
return ('--ca', self.ca_path, return ('--ca', self.ca_path,
'--cert', self.cert_path, '--cert', self.cert_path,
'--key', self.key_path) '--key', self.key_path)
def maybeRenew(self, registry, crl): def maybeRenew(self, registry, crl) -> int:
self.cert, next_renew = maybe_renew(self.cert_path, self.cert, self.cert, next_renew = maybe_renew(self.cert_path, self.cert,
"Certificate", lambda: registry.renewCertificate(self.prefix), "Certificate", lambda: registry.renewCertificate(self.prefix),
self.cert.get_serial_number() in crl) self.cert.get_serial_number() in crl)
...@@ -143,21 +160,31 @@ class Cert(object): ...@@ -143,21 +160,31 @@ class Cert(object):
"error running openssl, assuming cert is invalid") "error running openssl, assuming cert is invalid")
# BBB: With old versions of openssl, detailed # BBB: With old versions of openssl, detailed
# error is printed to standard output. # error is printed to standard output.
for err in err, out: for stream in err, out:
for x in err.splitlines(): for x in stream.decode(errors='replace').splitlines():
if x.startswith('error '): if x.startswith('error '):
x, msg = x.split(':', 1) x, msg = x.split(':', 1)
_, code, _, depth, _ = x.split(None, 4) _, code, _, depth, _ = x.split(None, 4)
raise VerifyError(int(code), int(depth), msg.strip()) raise VerifyError(int(code), int(depth), msg.strip())
return r return r
def verify(self, sign, data): def verify(self, sign: bytes, data: bytes):
crypto.verify(self.ca, sign, data, 'sha512') pub_key = self.ca_crypto.public_key()
pub_key.verify(
def sign(self, data): sign,
return crypto.sign(self.key, data, 'sha512') data,
padding.PKCS1v15(),
def decrypt(self, data): hashes.SHA512()
)
def sign(self, data: bytes) -> bytes:
return self.key_crypto.sign(
data,
padding.PKCS1v15(),
hashes.SHA512()
)
def decrypt(self, data: bytes) -> bytes:
p = openssl('rsautl', '-decrypt', '-inkey', self.key_path) p = openssl('rsautl', '-decrypt', '-inkey', self.key_path)
out, err = p.communicate(data) out, err = p.communicate(data)
if p.returncode: if p.returncode:
...@@ -166,7 +193,7 @@ class Cert(object): ...@@ -166,7 +193,7 @@ class Cert(object):
def verifyVersion(self, version): def verifyVersion(self, version):
try: try:
n = 1 + (ord(version[0]) >> 5) n = 1 + (version[0] >> 5)
self.verify(version[n:], version[:n]) self.verify(version[n:], version[:n])
except (IndexError, crypto.Error): except (IndexError, crypto.Error):
raise VerifyError(None, None, 'invalid network version') raise VerifyError(None, None, 'invalid network version')
...@@ -175,7 +202,7 @@ class Cert(object): ...@@ -175,7 +202,7 @@ class Cert(object):
PACKED_PROTOCOL = utils.packInteger(protocol) PACKED_PROTOCOL = utils.packInteger(protocol)
class Peer(object): class Peer:
""" """
UDP: A ─────────────────────────────────────────────> B UDP: A ─────────────────────────────────────────────> B
...@@ -206,9 +233,10 @@ class Peer(object): ...@@ -206,9 +233,10 @@ class Peer(object):
_key = newHmacSecret() _key = newHmacSecret()
serial = None serial = None
stop_date = float('inf') stop_date = float('inf')
version = '' version = b''
cert: crypto.X509
def __init__(self, prefix): def __init__(self, prefix: str):
self.prefix = prefix self.prefix = prefix
@property @property
...@@ -224,35 +252,35 @@ class Peer(object): ...@@ -224,35 +252,35 @@ class Peer(object):
def __lt__(self, other): def __lt__(self, other):
return self.prefix < (other if type(other) is str else other.prefix) return self.prefix < (other if type(other) is str else other.prefix)
def hello0(self, cert): def hello0(self, cert: crypto.X509) -> bytes:
if self._hello < time.time(): if self._hello < time.time():
try: try:
# Always assume peer is not old, in case it has just upgraded, # Always assume peer is not old, in case it has just upgraded,
# else we would be stuck with the old protocol. # else we would be stuck with the old protocol.
msg = ('\0\0\0\1' msg = (b'\0\0\0\1'
+ PACKED_PROTOCOL + PACKED_PROTOCOL
+ fingerprint(self.cert).digest()) + fingerprint(self.cert).digest())
except AttributeError: except AttributeError:
msg = '\0\0\0\0' msg = b'\0\0\0\0'
return msg + crypto.dump_certificate(crypto.FILETYPE_ASN1, cert) return msg + crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)
def hello0Sent(self): def hello0Sent(self):
self._hello = time.time() + 60 self._hello = time.time() + 60
def hello(self, cert, protocol): def hello(self, cert: Cert, protocol: int) -> bytes:
key = self._key = newHmacSecret() key = self._key = newHmacSecret()
h = encrypt(crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert), h = encrypt(crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert),
key) key)
self._i = self._j = 2 self._i = self._j = 2
self._last = 0 self._last = 0
self.protocol = protocol self.protocol = protocol
return ''.join(('\0\0\0\2', PACKED_PROTOCOL if protocol else '', return b''.join((b'\0\0\0\2', PACKED_PROTOCOL if protocol else b'',
h, cert.sign(h))) h, cert.sign(h)))
def _hmac(self, msg): def _hmac(self, msg: bytes) -> bytes:
return hmac.HMAC(self._key, msg, hashlib.sha1).digest() return hmac.HMAC(self._key, msg, hashlib.sha1).digest()
def newSession(self, key, protocol): def newSession(self, key: bytes, protocol: int):
if key <= self._key: if key <= self._key:
raise NewSessionError(self._key, key) raise NewSessionError(self._key, key)
self._key = key self._key = key
...@@ -260,12 +288,13 @@ class Peer(object): ...@@ -260,12 +288,13 @@ class Peer(object):
self._last = None self._last = None
self.protocol = protocol self.protocol = protocol
def verify(self, sign, data): def verify(self, sign: bytes, data: bytes):
crypto.verify(self.cert, sign, data, 'sha512') crypto.verify(self.cert, sign, data, 'sha512')
seqno_struct = struct.Struct("!L") seqno_struct = struct.Struct("!L")
def decode(self, msg, _unpack=seqno_struct.unpack): def decode(self, msg: bytes, _unpack=seqno_struct.unpack) \
-> tuple[int, bytes, int | None] | bytes:
seqno, = _unpack(msg[:4]) seqno, = _unpack(msg[:4])
if seqno <= 2: if seqno <= 2:
msg = msg[4:] msg = msg[4:]
...@@ -281,8 +310,10 @@ class Peer(object): ...@@ -281,8 +310,10 @@ class Peer(object):
self._i = seqno self._i = seqno
return msg[4:i] return msg[4:i]
def encode(self, msg, _pack=seqno_struct.pack): def encode(self, msg: str | bytes, _pack=seqno_struct.pack) -> bytes:
self._j += 1 self._j += 1
if type(msg) is str:
msg = msg.encode()
msg = _pack(self._j) + msg msg = _pack(self._j) + msg
return msg + self._hmac(msg) return msg + self._hmac(msg)
......
...@@ -7,21 +7,23 @@ from setuptools.command import sdist as _sdist, build_py as _build_py ...@@ -7,21 +7,23 @@ from setuptools.command import sdist as _sdist, build_py as _build_py
from distutils import log from distutils import log
version = {"__file__": "re6st/version.py"} version = {"__file__": "re6st/version.py"}
execfile(version["__file__"], version) with open(version["__file__"]) as f:
code = compile(f.read(), version["__file__"], 'exec')
exec(code, version)
def copy_file(self, infile, outfile, *args, **kw): def copy_file(self, infile, outfile, *args, **kw):
if infile == version["__file__"]: if infile == version["__file__"]:
if not self.dry_run: if not self.dry_run:
log.info("generating %s -> %s", infile, outfile) log.info("generating %s -> %s", infile, outfile)
with open(outfile, "wb") as f: with open(outfile, "w") as f:
for x in sorted(version.iteritems()): for x in sorted(version.items()):
if not x[0].startswith("_"): if not x[0].startswith("_"):
f.write("%s = %r\n" % x) f.write("%s = %r\n" % x)
return outfile, 1 return outfile, 1
elif isinstance(self, build_py) and \ elif isinstance(self, build_py) and \
os.stat(infile).st_mode & stat.S_IEXEC: os.stat(infile).st_mode & stat.S_IEXEC:
if os.path.isdir(infile) and os.path.isdir(outfile): if os.path.isdir(infile) and os.path.isdir(outfile):
return (outfile, 0) return outfile, 0
# Adjust interpreter of OpenVPN hooks. # Adjust interpreter of OpenVPN hooks.
with open(infile) as src: with open(infile) as src:
first_line = src.readline() first_line = src.readline()
...@@ -33,7 +35,7 @@ def copy_file(self, infile, outfile, *args, **kw): ...@@ -33,7 +35,7 @@ def copy_file(self, infile, outfile, *args, **kw):
patched += src.read() patched += src.read()
dst = os.open(outfile, os.O_CREAT | os.O_WRONLY | os.O_TRUNC) dst = os.open(outfile, os.O_CREAT | os.O_WRONLY | os.O_TRUNC)
try: try:
os.write(dst, patched) os.write(dst, patched.encode())
finally: finally:
os.close(dst) os.close(dst)
return outfile, 1 return outfile, 1
...@@ -51,7 +53,8 @@ Environment :: Console ...@@ -51,7 +53,8 @@ Environment :: Console
License :: OSI Approved :: GNU General Public License (GPL) License :: OSI Approved :: GNU General Public License (GPL)
Natural Language :: English Natural Language :: English
Operating System :: POSIX :: Linux Operating System :: POSIX :: Linux
Programming Language :: Python :: 2.7 Programming Language :: Python :: 3
Programming Language :: Python :: 3.11
Topic :: Internet Topic :: Internet
Topic :: System :: Networking Topic :: System :: Networking
""" """
...@@ -73,6 +76,7 @@ setup( ...@@ -73,6 +76,7 @@ setup(
license = 'GPL 2+', license = 'GPL 2+',
platforms = ["any"], platforms = ["any"],
classifiers=classifiers.splitlines(), classifiers=classifiers.splitlines(),
python_requires = '>=3.11',
long_description = ".. contents::\n\n" + open('README.rst').read() long_description = ".. contents::\n\n" + open('README.rst').read()
+ "\n" + open('CHANGES.rst').read() + git_rev, + "\n" + open('CHANGES.rst').read() + git_rev,
packages = find_packages(), packages = find_packages(),
...@@ -95,7 +99,7 @@ setup( ...@@ -95,7 +99,7 @@ setup(
extras_require = { extras_require = {
'geoip': ['geoip2'], 'geoip': ['geoip2'],
'multicast': ['PyYAML'], 'multicast': ['PyYAML'],
'test': ['mock', 'pathlib2', 'nemu', 'python-unshare', 'python-passfd', 'multiping'] 'test': ['mock', 'nemu3', 'unshare', 'multiping']
}, },
#dependency_links = [ #dependency_links = [
# "http://miniupnp.free.fr/files/download.php?file=miniupnpc-1.7.20120714.tar.gz#egg=miniupnpc-1.7", # "http://miniupnp.free.fr/files/download.php?file=miniupnpc-1.7.20120714.tar.gz#egg=miniupnpc-1.7",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment