Commit 7f3ca710 authored by Tom Niget's avatar Tom Niget Committed by Tom Niget

py2to3: migrate to python3

style: stop using deprecated getargspec for rpc


py2to3: migrate setup.py


style: remove new-style class code from the Python2 days


demo: properly escape screen command parameters



bug: fix wrong sending of HMAC header causing test failure


py2to3: revert changes from 90a624f0 in debian/control and re6stnet.spec

setup: indicate minimum Python versions

doc: update dependencies in readme
parent f2fd7247
......@@ -5,3 +5,4 @@
/build/
/dist/
/re6stnet.egg-info/
.idea
\ No newline at end of file
......@@ -52,7 +52,7 @@ easily scalable to tens of thousand of nodes.
Requirements
============
- Python 2.7
- Python 3.11
- OpenSSL binary and development libraries
- OpenVPN 2.4.*
- Babel_ (with Nexedi patches)
......
#!/usr/bin/python2
import argparse, math, nemu, os, re, signal
#!/usr/bin/env python3
import argparse, math, nemu, os, re, shlex, signal
import socket, sqlite3, subprocess, sys, time, weakref
from collections import defaultdict
from contextlib import contextmanager
from threading import Thread
from typing import Optional
IPTABLES = 'iptables'
SCREEN = 'screen'
VERBOSE = 4
......@@ -14,9 +16,9 @@ REGISTRY2_SERIAL = '0x120010db80043'
CA_DAYS = 1000
# Quick check to avoid wasting time if there is an error.
with open(os.devnull, "wb") as f:
for x in 're6stnet', 're6st-conf', 're6st-registry':
subprocess.check_call(('./py', x, '--help'), stdout=f)
for x in 're6stnet', 're6st-conf', 're6st-registry':
subprocess.check_call(('./py', x, '--help'), stdout=subprocess.DEVNULL)
#
# Underlying network:
#
......@@ -46,23 +48,26 @@ with open(os.devnull, "wb") as f:
def disable_signal_on_children(sig):
pid = os.getpid()
sigint = signal.signal(sig, lambda *x: os.getpid() == pid and sigint(*x))
disable_signal_on_children(signal.SIGINT)
Node__add_interface = nemu.Node._add_interface
def _add_interface(node, iface):
iface.__dict__['node'] = weakref.proxy(node)
return Node__add_interface(node, iface)
nemu.Node._add_interface = _add_interface
parser = argparse.ArgumentParser()
parser.add_argument('port', type = int,
help = 'port used to display tunnels')
parser.add_argument('-d', '--duration', type = int,
help = 'time of the demo execution in seconds')
parser.add_argument('-p', '--ping', action = 'store_true',
help = 'execute ping utility')
parser.add_argument('-m', '--hmac', action = 'store_true',
help = 'execute HMAC test')
parser.add_argument('port', type=int,
help='port used to display tunnels')
parser.add_argument('-d', '--duration', type=int,
help='time of the demo execution in seconds')
parser.add_argument('-p', '--ping', action='store_true',
help='execute ping utility')
parser.add_argument('-m', '--hmac', action='store_true',
help='execute HMAC test')
args = parser.parse_args()
def handler(signum, frame):
......@@ -72,33 +77,56 @@ if args.duration:
signal.signal(signal.SIGALRM, handler)
signal.alarm(args.duration)
execfile("fixnemu.py")
exec(compile(open("fixnemu.py", "rb").read(), "fixnemu.py", 'exec'))
class Re6stNode(nemu.Node):
name: str
short: str
re6st_cmdline: Optional[list[str]]
def __init__(self, name, short):
super().__init__()
self.name = name
self.short = short
self.Popen(('sysctl', '-q',
'net.ipv4.icmp_echo_ignore_broadcasts=0')).wait()
self._screen = self.Popen((SCREEN, '-DmS', name))
self.re6st_cmdline = None
def screen(self, command: list[str]):
runner_cmd = 'set -- %s; "\\$@"; echo "\\$@"; exec $SHELL' % ' '.join(map(shlex.quote, command))
inner_cmd = [
'screen', 'sh', '-c', runner_cmd
]
cmd = [
SCREEN, '-r', self.name, '-X', 'eval', shlex.join(inner_cmd)
]
return subprocess.call(cmd)
# create nodes
for name in """internet=I registry=R
gateway1=g1 machine1=1 machine2=2
gateway2=g2 machine3=3 machine4=4 machine5=5
machine6=6 machine7=7 machine8=8 machine9=9
registry2=R2 machine10=10
""".split():
name, short = name.split('=')
globals()[name] = node = nemu.Node()
node.name = name
node.short = short
node.Popen(('sysctl', '-q',
'net.ipv4.icmp_echo_ignore_broadcasts=0')).wait()
node._screen = node.Popen((SCREEN, '-DmS', name))
node.screen = (lambda name: lambda *cmd:
subprocess.call([SCREEN, '-r', name, '-X', 'eval'] + map(
"""screen sh -c 'set %s; "\$@"; echo "\$@"; exec $SHELL'"""
.__mod__, cmd)))(name)
internet = Re6stNode('internet', 'I')
registry = Re6stNode('registry', 'R')
gateway1 = Re6stNode('gateway1', 'g1')
machine1 = Re6stNode('machine1', '1')
machine2 = Re6stNode('machine2', '2')
gateway2 = Re6stNode('gateway2', 'g2')
machine3 = Re6stNode('machine3', '3')
machine4 = Re6stNode('machine4', '4')
machine5 = Re6stNode('machine5', '5')
machine6 = Re6stNode('machine6', '6')
machine7 = Re6stNode('machine7', '7')
machine8 = Re6stNode('machine8', '8')
machine9 = Re6stNode('machine9', '9')
registry2 = Re6stNode('registry2', 'R2')
machine10 = Re6stNode('machine10', '10')
# create switch
switch1 = nemu.Switch()
switch2 = nemu.Switch()
switch3 = nemu.Switch()
#create interfaces
# create interfaces
re_if_0, in_if_0 = nemu.P2PInterface.create_pair(registry, internet)
in_if_1, g1_if_0 = nemu.P2PInterface.create_pair(internet, gateway1)
in_if_2, g2_if_0 = nemu.P2PInterface.create_pair(internet, gateway2)
......@@ -179,12 +207,14 @@ m6_if_0.add_v6_address(address='fc42:6::1', prefix_len=16)
m7_if_0.add_v6_address(address='fc42:7::1', prefix_len=16)
m8_if_0.add_v6_address(address='fc42:8::1', prefix_len=16)
def add_llrtr(iface, peer, dst='default'):
for a in peer.get_addresses():
a = a['address']
if a.startswith('fe80:'):
return iface.node.Popen(('ip', 'route', 'add', dst, 'via', a,
'proto', 'static', 'dev', iface.name)).wait()
'proto', 'static', 'dev', iface.name)).wait()
# setup routes
add_llrtr(re_if_0, in_if_0)
......@@ -205,19 +235,19 @@ for m in machine6, machine7, machine8:
# Test connectivity first. Run process, hide output and check
# return code
null = file(os.devnull, "r+")
for ip in '10.1.1.2', '10.1.1.3', '10.2.1.2', '10.2.1.3':
if machine1.Popen(('ping', '-c1', ip), stdout=null).wait():
print 'Failed to ping %s' % ip
if machine1.Popen(('ping', '-c1', ip), stdout=subprocess.DEVNULL).wait():
print('Failed to ping', ip)
break
else:
print "Connectivity IPv4 OK!"
print("Connectivity IPv4 OK!")
nodes: list[Re6stNode] = []
gateway1.screen(['miniupnpd', '-d', '-f', 'miniupnpd.conf', '-P', 'miniupnpd.pid',
'-a', g1_if_1.name, '-i', g1_if_0_name])
nodes = []
gateway1.screen('miniupnpd -d -f miniupnpd.conf -P miniupnpd.pid'
' -a %s -i %s' % (g1_if_1.name, g1_if_0_name))
@contextmanager
def new_network(registry, reg_addr, serial, ca):
def new_network(registry: Re6stNode, reg_addr: str, serial: str, ca: str):
from OpenSSL import crypto
import hashlib, sqlite3
os.path.exists(ca) or subprocess.check_call(
......@@ -225,16 +255,17 @@ def new_network(registry, reg_addr, serial, ca):
" -subj /CN=re6st.example.com/emailAddress=re6st@example.com"
" -set_serial %s -days %u"
% (registry.name, ca, serial, CA_DAYS), shell=True)
with open(ca) as f:
with open(ca, "rb") as f:
cert = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
fingerprint = "sha256:" + hashlib.sha256(
crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)).hexdigest()
db_path = "%s/registry.db" % registry.name
registry.screen("./py re6st-registry @%s/re6st-registry.conf"
" --db %s --mailhost %s -v%u"
% (registry.name, db_path, os.path.abspath('mbox'), VERBOSE))
registry.screen([
sys.executable, './py', 're6st-registry', '@%s/re6st-registry.conf' % registry.name,
'--db', db_path, '--mailhost', os.path.abspath('mbox'), '-v%u' % VERBOSE
])
registry_url = 'http://%s/' % reg_addr
registry.Popen(('python', '-c', """if 1:
registry.Popen((sys.executable, '-c', """if 1:
import socket, time
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
while True:
......@@ -245,16 +276,17 @@ def new_network(registry, reg_addr, serial, ca):
time.sleep(.1)
""")).wait()
db = sqlite3.connect(db_path, isolation_level=None)
def new_node(node, folder, args='', prefix_len=None, registry=registry_url):
def new_node(node: Re6stNode, folder: str, args: list[str]=[], prefix_len: Optional[int] = None, registry=registry_url):
nodes.append(node)
if not os.path.exists(folder + '/cert.crt'):
dh_path = folder + '/dh2048.pem'
if not os.path.exists(dh_path):
os.symlink('../dh2048.pem', dh_path)
email = node.name + '@example.com'
p = node.Popen(('../py', 're6st-conf', '--registry', registry,
p = node.Popen((sys.executable, '../py', 're6st-conf', '--registry', registry,
'--email', email, '--fingerprint', fingerprint),
stdin=subprocess.PIPE, cwd=folder)
stdin=subprocess.PIPE, cwd=folder)
token = None
while not token:
time.sleep(.1)
......@@ -266,27 +298,30 @@ def new_network(registry, reg_addr, serial, ca):
p.communicate(str(token[0]))
os.remove(dh_path)
os.remove(folder + '/ca.crt')
node.re6st_cmdline = (
'./py re6stnet @%s/re6stnet.conf -v%u --registry %s'
' --console %s/run/console.sock %s'
) % (folder, VERBOSE, registry, folder, args)
node.re6st_cmdline = [
sys.executable, './py', 're6stnet', '@%s/re6stnet.conf' % folder,
'-v%u' % VERBOSE, '--registry', registry, '--console', '%s/run/console.sock' % folder,
*args
]
node.screen(node.re6st_cmdline)
new_node(registry, registry.name, '--ip ' + reg_addr, registry='http://localhost/')
new_node(registry, registry.name, ['--ip', reg_addr], registry='http://localhost/')
yield new_node
db.close()
with new_network(registry, REGISTRY, REGISTRY_SERIAL, 'ca.crt') as new_node:
new_node(machine1, 'm1', '-I%s' % m1_if_0.name)
new_node(machine2, 'm2', '--remote-gateway 10.1.1.1', prefix_len=77)
new_node(machine3, 'm3', '-i%s' % m3_if_0.name)
new_node(machine4, 'm4', '-i%s' % m4_if_0.name)
new_node(machine5, 'm5', '-i%s' % m5_if_0.name)
new_node(machine6, 'm6', '-I%s' % m6_if_1.name)
new_node(machine1, 'm1', ['-I%s' % m1_if_0.name])
new_node(machine2, 'm2', ['--remote-gateway', '10.1.1.1'], prefix_len=77)
new_node(machine3, 'm3', ['-i%s' % m3_if_0.name])
new_node(machine4, 'm4', ['-i%s' % m4_if_0.name])
new_node(machine5, 'm5', ['-i%s' % m5_if_0.name])
new_node(machine6, 'm6', ['-I%s' % m6_if_1.name])
new_node(machine7, 'm7')
new_node(machine8, 'm8')
with new_network(registry2, REGISTRY2, REGISTRY2_SERIAL, 'ca2.crt') as new_node:
new_node(machine10, 'm10', '-i%s' % m10_if_0.name)
new_node(machine10, 'm10', ['-i%s' % m10_if_0.name])
if args.ping:
for j, machine in enumerate(nodes):
......@@ -297,60 +332,64 @@ if args.ping:
'2001:db8:43:1::1' if i == 10 else
# Only 1 address for machine2 because prefix_len = 80,+48 = 128
'2001:db8:42:%s::1' % i
for i in xrange(11)
for i in range(11)
if i != j]
name = machine.name if machine.short[0] == 'R' else 'm' + machine.short
machine.screen('python ping.py {} {}'.format(name, ' '.join(ips)))
machine.screen(['python', 'ping.py', name] + ips)
class testHMAC(Thread):
def run(self):
updateHMAC = ('python', '-c', "import urllib, sys; sys.exit("
"204 != urllib.urlopen('http://127.0.0.1/updateHMAC').code)")
"204 != urllib.urlopen('http://127.0.0.1/updateHMAC').code)")
reg1_db = sqlite3.connect('registry/registry.db', isolation_level=None,
check_same_thread=False)
check_same_thread=False)
reg2_db = sqlite3.connect('registry2/registry.db', isolation_level=None,
check_same_thread=False)
check_same_thread=False)
reg1_db.text_factory = reg2_db.text_factory = str
m_net1 = 'registry', 'm1', 'm2', 'm3', 'm4', 'm5', 'm6', 'm7', 'm8'
m_net2 = 'registry2', 'm10'
print 'Testing HMAC, letting the time to machines to create tunnels...'
print('Testing HMAC, letting the time to machines to create tunnels...')
time.sleep(45)
print 'Check that the initial HMAC config is deployed on network 1'
print('Check that the initial HMAC config is deployed on network 1')
test_hmac.checkHMAC(reg1_db, m_net1)
print 'Test that a HMAC update works with nodes that are up'
print('Test that a HMAC update works with nodes that are up')
registry.backticks_raise(updateHMAC)
print 'Updated HMAC (config = hmac0 & hmac1), waiting...'
print('Updated HMAC (config = hmac0 & hmac1), waiting...')
time.sleep(60)
print 'Checking HMAC on machines connected to registry 1...'
print('Checking HMAC on machines connected to registry 1...')
test_hmac.checkHMAC(reg1_db, m_net1)
print ('Test that machines can update upon reboot ' +
print('Test that machines can update upon reboot '
'when they were off during a HMAC update.')
test_hmac.killRe6st(machine1)
print 'Re6st on machine 1 is stopped'
print('Re6st on machine 1 is stopped')
time.sleep(5)
registry.backticks_raise(updateHMAC)
print 'Updated HMAC on registry (config = hmac1 & hmac2), waiting...'
print('Updated HMAC on registry (config = hmac1 & hmac2), waiting...')
time.sleep(60)
machine1.screen(machine1.re6st_cmdline)
print 'Started re6st on machine 1, waiting for it to get new conf'
print('Started re6st on machine 1, waiting for it to get new conf')
time.sleep(60)
print 'Checking HMAC on machines connected to registry 1...'
print('Checking HMAC on machines connected to registry 1...')
test_hmac.checkHMAC(reg1_db, m_net1)
print 'Testing of HMAC done!'
print('Testing of HMAC done!')
# TODO: missing last step
reg1_db.close()
reg2_db.close()
if args.hmac:
import test_hmac
t = testHMAC()
t.deamon = 1
t.start()
del t
_ll = {}
def node_by_ll(addr):
_ll: dict[str, tuple[Re6stNode, bool]] = {}
def node_by_ll(addr: str) -> tuple[Re6stNode, bool]:
try:
return _ll[addr]
except KeyError:
......@@ -368,24 +407,26 @@ def node_by_ll(addr):
if a.startswith('10.42.'):
assert not p % 8
_ll[socket.inet_ntoa(socket.inet_aton(
a)[:p/8].ljust(4, '\0'))] = n, t
a)[:p // 8].ljust(4, b'\0'))] = n, t
elif a.startswith('2001:db8:'):
assert not p % 8
a = socket.inet_ntop(socket.AF_INET6,
socket.inet_pton(socket.AF_INET6,
a)[:p/8].ljust(16, '\0'))
socket.inet_pton(socket.AF_INET6,
a)[:p // 8].ljust(16, b'\0'))
elif not a.startswith('fe80::'):
continue
_ll[a] = n, t
return _ll[addr]
def route_svg(ipv4, z = 4, default = type('', (), {'short': None})):
graph = {}
def route_svg(ipv4, z=4, default=type('', (), {'short': None})):
graph: dict[Re6stNode, dict[tuple[Re6stNode, bool], list[Re6stNode]]] = {}
for n in nodes:
g = graph[n] = defaultdict(list)
g: dict[tuple[Re6stNode, bool], list[Re6stNode]]
for r in n.get_routes():
if (r.prefix and r.prefix.startswith('10.42.') if ipv4 else
r.prefix is None or r.prefix.startswith('2001:db8:')):
r.prefix is None or r.prefix.startswith('2001:db8:')):
try:
g[node_by_ll(r.nexthop)].append(
node_by_ll(r.prefix)[0] if r.prefix else default)
......@@ -396,40 +437,42 @@ def route_svg(ipv4, z = 4, default = type('', (), {'short': None})):
a = 2 * math.pi / N
edges = set()
for i, n in enumerate(nodes):
i: int
gv.append('%s[pos="%s,%s!"];'
% (n.name, z * math.cos(a * i), z * math.sin(a * i)))
% (n.name, z * math.cos(a * i), z * math.sin(a * i)))
l = []
for p, r in graph[n].iteritems():
j = abs(nodes.index(p[0]) - i)
for p, r in graph[n].items():
p: tuple[Re6stNode, bool]
r: list[Re6stNode]
j: int = abs(nodes.index(p[0]) - i)
l.append((min(j, N - j), p, r))
for j, (l, (p, t), r) in enumerate(sorted(l)):
l = []
for j, (_, (p2, t), r) in enumerate(sorted(l, key=lambda x: x[0])):
p2: Re6stNode
l2: list[str] = []
arrowhead = 'none'
for r in sorted(r.short for r in r):
if r:
if r == p.short:
r = '<font color="grey">%s</font>' % r
l.append(r)
for r2 in sorted(r2.short for r2 in r):
if r2:
if r2 == p2.short:
r2 = '<font color="grey">%s</font>' % r2
l2.append(r2)
else:
arrowhead = 'dot'
if (n.name, p.name) in edges:
r = 'penwidth=0'
if (n.name, p2.name) in edges:
r3 = 'penwidth=0'
else:
edges.add((p.name, n.name))
r = 'style=solid' if t else 'style=dashed'
edges.add((p2.name, n.name))
r3 = 'style=solid' if t else 'style=dashed'
gv.append(
'%s -> %s [labeldistance=%u, headlabel=<%s>, arrowhead=%s, %s];'
% (p.name, n.name, 1.5 * math.sqrt(j) + 2, ','.join(l),
arrowhead, r))
% (p2.name, n.name, 1.5 * math.sqrt(j) + 2, ','.join(l2),
arrowhead, r3))
gv.append('}\n')
return subprocess.Popen(('neato', '-Tsvg'),
stdin=subprocess.PIPE, stdout=subprocess.PIPE,
).communicate('\n'.join(gv))[0]
return subprocess.run(('neato', '-Tsvg'), check=True, text=True, capture_output=True, input='\n'.join(gv)).stdout
if args.port:
import SimpleHTTPServer, SocketServer
import http.server, socketserver
class Handler(SimpleHTTPServer.SimpleHTTPRequestHandler):
class Handler(http.server.SimpleHTTPRequestHandler):
_path_match = re.compile('/(.+)\.(html|svg)$').match
pages = 'ipv6', 'ipv4', 'tunnels'
......@@ -439,7 +482,7 @@ if args.port:
try:
name, ext = self._path_match(self.path).groups()
page = self.pages.index(name)
except AttributeError, ValueError:
except AttributeError as ValueError:
if self.path == '/':
self.send_response(302)
self.send_header('Location', self.pages[0] + '.html')
......@@ -450,34 +493,34 @@ if args.port:
if page < 2:
body = route_svg(page)
else:
body = registry.Popen(('python', '-c', r"""if 1:
body = registry.Popen(('python3', '-c', r"""if 1:
import math, json
from re6st.registry import RegistryClient
g = json.loads(RegistryClient(
'http://localhost/').topology())
r = set(g.pop('', ()))
a = set()
for v in g.itervalues():
for v in g.values():
a.update(v)
g.update(dict.fromkeys(a.difference(g), ()))
print 'digraph {'
print('digraph {')
a = 2 * math.pi / len(g)
z = 4
m2 = '%u/80' % (2 << 64)
title = lambda n: '2|80' if n == m2 else n
g = sorted((title(k), k in r, v) for k, v in g.iteritems())
g = sorted((title(k), k in r, v) for k, v in g.items())
for i, (n, r, v) in enumerate(g):
print '"%s"[pos="%s,%s!"%s];' % (title(n),
print('"%s"[pos="%s,%s!"%s];' % (title(n),
z * math.cos(a * i), z * math.sin(a * i),
'' if r else ', style=dashed')
'' if r else ', style=dashed'))
for v in v:
print '"%s" -> "%s";' % (n, title(v))
print '}'
"""), stdout=subprocess.PIPE, cwd="..").communicate()[0]
print('"%s" -> "%s";' % (n, title(v)))
print('}')
"""), stdout=subprocess.PIPE, cwd="..").communicate()[0].decode("utf-8")
if body:
body = subprocess.Popen(('neato', '-Tsvg'),
stdin=subprocess.PIPE, stdout=subprocess.PIPE,
).communicate(body)[0]
stdin=subprocess.PIPE, stdout=subprocess.PIPE,
).communicate(body.encode("utf-8"))[0].decode("utf-8")
if not body:
self.send_error(500)
return
......@@ -504,18 +547,19 @@ if args.port:
%s
</body>
</html>""" % (name, ' '.join(x if i == page else
'<a href="%s.html">%s</a>' % (x, x)
for i, x in enumerate(self.pages)),
body[body.find('<svg'):])
'<a href="%s.html">%s</a>' % (x, x)
for i, x in enumerate(self.pages)),
body[body.find('<svg'):])
self.send_response(200)
self.send_header('Content-Length', len(body))
body = body.encode("utf-8")
self.send_header('Content-Length', str(len(body)))
self.send_header('Content-type', mt + '; charset=utf-8')
self.end_headers()
self.wfile.write(body)
class TCPServer(SocketServer.TCPServer):
class TCPServer(socketserver.TCPServer):
allow_reuse_address = True
TCPServer(('', args.port), Handler).serve_forever()
import pdb; pdb.set_trace()
breakpoint()
......@@ -18,10 +18,10 @@
import re
import os
from new import function
from nemu.iproute import backticks, get_if_data, route, \
get_addr_data, get_all_route_data, interface
from nemu.interface import Switch, Interface
from types import FunctionType
def _get_all_route_data():
ipdata = backticks([IP_PATH, "-o", "route", "list"]) # "table", "all"
......@@ -56,7 +56,7 @@ def _get_all_route_data():
metric))
return ret
get_all_route_data.func_code = _get_all_route_data.func_code
get_all_route_data.__code__ = _get_all_route_data.__code__
interface__init__ = interface.__init__
def __init__(self, *args, **kw):
......@@ -65,12 +65,12 @@ def __init__(self, *args, **kw):
self.name = self.name.split('@',1)[0]
interface.__init__ = __init__
get_addr_data.orig = function(get_addr_data.func_code,
get_addr_data.func_globals)
get_addr_data.orig = FunctionType(get_addr_data.__code__,
get_addr_data.__globals__)
def _get_addr_data():
byidx, bynam = get_addr_data.orig()
return byidx, {name.split('@',1)[0]: a for name, a in bynam.iteritems()}
get_addr_data.func_code = _get_addr_data.func_code
return byidx, {name.split('@',1)[0]: a for name, a in bynam.items()}
get_addr_data.__code__ = _get_addr_data.__code__
@staticmethod
def _gen_if_name():
......
#!/usr/bin/env python
#!/usr/bin/env python3
def __file__():
import argparse, os, sys
sys.dont_write_bytecode = True
......@@ -30,4 +30,5 @@ def __file__():
return os.path.join(sys.path[0], sys.argv[1])
__file__ = __file__()
execfile(__file__)
with open(__file__) as f:
exec(compile(f.read(), __file__, 'exec'))
......@@ -34,7 +34,7 @@ def checkHMAC(db, machines):
else:
i = 0 if hmac[0] else 1
if hmac[i] != sign or hmac[i+1] != accept:
print 'HMAC config wrong for in %s' % args
print('HMAC config wrong for in %s' % args)
rc = False
if rc:
print('All nodes use Babel with the correct HMAC configuration')
......
......@@ -11,7 +11,7 @@ from re6st.registry import RegistryServer
@apply
class proxy(object):
class proxy:
def __init__(self):
self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
......
import json, logging, os, sqlite3, socket, subprocess, sys, time, zlib
import base64, json, logging, os, sqlite3, socket, subprocess, sys, time, zlib
from itertools import chain
from .registry import RegistryClient
from . import utils, version, x509
class Cache(object):
class Cache:
def __init__(self, db_path, registry, cert, db_size=200):
def __init__(self, db_path, registry, cert: x509.Cert, db_size=200):
self._prefix = cert.prefix
self._db_size = db_size
self._decrypt = cert.decrypt
......@@ -65,7 +65,7 @@ class Cache(object):
@staticmethod
def _selectConfig(execute): # BBB: blob
return ((k, str(v) if type(v) is buffer else v)
return ((k, str(v) if type(v) is memoryview else v)
for k, v in execute("SELECT * FROM config"))
def _loadConfig(self, config):
......@@ -89,24 +89,24 @@ class Cache(object):
logging.info("Getting new network parameters from registry...")
try:
# TODO: When possible, the registry should be queried via the re6st.
network_config = self._registry.getNetworkConfig(self._prefix)
logging.debug('getNetworkConfig result: %r', network_config)
x = json.loads(zlib.decompress(
self._registry.getNetworkConfig(self._prefix)))
base64 = x.pop('', ())
network_config))
base64_list = x.pop('', ())
config = {}
for k, v in x.iteritems():
for k, v in x.items():
k = str(k)
if k.startswith('babel_hmac'):
if v:
v = self._decrypt(v.decode('base64'))
elif k in base64:
v = v.decode('base64')
elif type(v) is unicode:
v = str(v)
v = self._decrypt(base64.b64decode(v))
elif k in base64_list:
v = base64.b64decode(v)
elif isinstance(v, (list, dict)):
k += ':json'
v = json.dumps(v)
config[k] = v
except socket.error, e:
except socket.error as e:
logging.warning(e)
return
except Exception:
......@@ -133,13 +133,13 @@ class Cache(object):
# BBB: Use buffer because of http://bugs.python.org/issue13676
# on Python 2.6
db.executemany("INSERT OR REPLACE INTO config VALUES(?,?)",
((k, buffer(v) if k in base64 or
((k, memoryview(v) if k in base64_list or
k.startswith('babel_hmac') else v)
for k, v in config.iteritems()))
self._loadConfig(config.iteritems())
for k, v in config.items()))
self._loadConfig(config.items())
return [k[:-5] if k.endswith(':json') else k
for k in chain(remove, (k
for k, v in config.iteritems()
for k, v in config.items()
if k not in old or old[k] != v))]
def warnProtocol(self):
......@@ -232,15 +232,16 @@ class Cache(object):
def getPeerList(self, failed=0, __sql=_get_peer_sql % "prefix, address"
+ " ORDER BY RANDOM()"):
return self._db.execute(__sql, (self._prefix, failed))
def getPeerCount(self, failed=0, __sql=_get_peer_sql % "COUNT(*)"):
def getPeerCount(self, failed=0, __sql=_get_peer_sql % "COUNT(*)") -> int:
return self._db.execute(__sql, (self._prefix, failed)).next()[0]
def getBootstrapPeer(self):
logging.info('Getting Boot peer...')
try:
bootpeer = self._registry.getBootstrapPeer(self._prefix)
prefix, address = self._decrypt(bootpeer).split()
except (socket.error, subprocess.CalledProcessError, ValueError), e:
prefix, address = self._decrypt(bootpeer).decode().split()
except (socket.error, subprocess.CalledProcessError, ValueError) as e:
logging.warning('Failed to bootstrap (%s)',
e if bootpeer else 'no peer returned')
else:
......@@ -275,6 +276,6 @@ class Cache(object):
def getCountry(self, ip):
try:
return self._registry.getCountry(self._prefix, ip)
except socket.error, e:
return self._registry.getCountry(self._prefix, ip).decode()
except socket.error as e:
logging.warning('Failed to get country (%s)', ip)
#!/usr/bin/python2
#!/usr/bin/env python3
import argparse, atexit, binascii, errno, hashlib
import os, subprocess, sqlite3, sys, time
from OpenSSL import crypto
......@@ -6,14 +6,14 @@ if 're6st' not in sys.modules:
sys.path[0] = os.path.dirname(os.path.dirname(sys.path[0]))
from re6st import registry, utils, x509
def create(path, text=None, mode=0666):
def create(path, text=None, mode=0o666):
fd = os.open(path, os.O_CREAT | os.O_WRONLY | os.O_TRUNC, mode)
try:
os.write(fd, text)
finally:
os.close(fd)
def loadCert(pem):
def loadCert(pem: bytes):
return crypto.load_certificate(crypto.FILETYPE_PEM, pem)
def main():
......@@ -68,12 +68,12 @@ def main():
fingerprint = binascii.a2b_hex(fingerprint)
if hashlib.new(alg).digest_size != len(fingerprint):
raise ValueError("wrong size")
except StandardError, e:
except Exception as e:
parser.error("invalid fingerprint: %s" % e)
if x509.fingerprint(ca, alg).digest() != fingerprint:
sys.exit("CA fingerprint doesn't match")
else:
print "WARNING: it is strongly recommended to use --fingerprint option."
print("WARNING: it is strongly recommended to use --fingerprint option.")
network = x509.networkFromCa(ca)
if config.is_needed:
route, err = subprocess.Popen(('ip', '-6', '-o', 'route', 'get',
......@@ -91,17 +91,17 @@ def main():
try:
with open(cert_path) as f:
cert = loadCert(f.read())
components = dict(cert.get_subject().get_components())
components = {k.decode(): v for k, v in cert.get_subject().get_components()}
for k in reserved:
components.pop(k, None)
except IOError, e:
except IOError as e:
if e.errno != errno.ENOENT:
raise
components = {}
if config.req:
components.update(config.req)
subj = req.get_subject()
for k, v in components.iteritems():
for k, v in components.items():
if k in reserved:
sys.exit(k + " field is reserved.")
if v:
......@@ -116,35 +116,35 @@ def main():
token = ''
elif not token:
if not config.email:
config.email = raw_input('Please enter your email address: ')
config.email = input('Please enter your email address: ')
s.requestToken(config.email)
token_advice = "Use --token to retry without asking a new token\n"
while not token:
token = raw_input('Please enter your token: ')
token = input('Please enter your token: ')
try:
with open(key_path) as f:
pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read())
key = None
print "Reusing existing key."
except IOError, e:
print("Reusing existing key.")
except IOError as e:
if e.errno != errno.ENOENT:
raise
bits = ca.get_pubkey().bits()
print "Generating %s-bit key ..." % bits
print("Generating %s-bit key ..." % bits)
pkey = crypto.PKey()
pkey.generate_key(crypto.TYPE_RSA, bits)
key = crypto.dump_privatekey(crypto.FILETYPE_PEM, pkey)
create(key_path, key, 0600)
create(key_path, key, 0o600)
req.set_pubkey(pkey)
req.sign(pkey, 'sha512')
req = crypto.dump_certificate_request(crypto.FILETYPE_PEM, req)
req = crypto.dump_certificate_request(crypto.FILETYPE_PEM, req).decode()
# First make sure we can open certificate file for writing,
# to avoid using our token for nothing.
cert_fd = os.open(cert_path, os.O_CREAT | os.O_WRONLY, 0666)
print "Requesting certificate ..."
cert_fd = os.open(cert_path, os.O_CREAT | os.O_WRONLY, 0o666)
print("Requesting certificate ...")
if config.location:
cert = s.requestCertificate(token, req, location=config.location)
else:
......@@ -173,7 +173,7 @@ def main():
key_path))
if not os.path.lexists(conf_path):
create(conf_path, """\
create(conf_path, ("""\
registry %s
ca %s
cert %s
......@@ -187,14 +187,14 @@ key %s
#O--verb
#O3
""" % (config.registry, ca_path, cert_path, key_path,
('country ' + config.location.split(',', 1)[0]) \
if config.location else ''))
print "Sample configuration file created."
('country ' + config.location.split(',', 1)[0])
if config.location else '')).encode())
print("Sample configuration file created.")
cn = x509.subnetFromCert(cert)
subnet = network + utils.binFromSubnet(cn)
print "Your subnet: %s/%u (CN=%s)" \
% (utils.ipFromBin(subnet), len(subnet), cn)
print("Your subnet: %s/%u (CN=%s)"
% (utils.ipFromBin(subnet), len(subnet), cn))
if __name__ == "__main__":
main()
#!/usr/bin/python2
#!/usr/bin/env python3
import atexit, errno, logging, os, shutil, signal
import socket, struct, subprocess, sys
from collections import deque
......@@ -246,7 +246,7 @@ def main():
try:
from re6st.upnpigd import Forwarder
forwarder = Forwarder('re6stnet openvpn server')
except Exception, e:
except Exception as e:
if ipv4:
raise
logging.info("%s: assume we are not NATed", e)
......@@ -299,7 +299,7 @@ def main():
timeout = 4 * cache.hello
cleanup = [lambda: cache.cacheMinimize(config.client_count),
lambda: shutil.rmtree(config.run, True)]
utils.makedirs(config.run, 0700)
utils.makedirs(config.run, 0o700)
control_socket = os.path.join(config.run, 'babeld.sock')
if config.client_count and not config.client:
tunnel_manager = tunnel.TunnelManager(control_socket,
......@@ -362,7 +362,7 @@ def main():
if not dh:
dh = os.path.join(config.state, "dh.pem")
cache.getDh(dh)
for iface, (port, proto) in server_tunnels.iteritems():
for iface, (port, proto) in server_tunnels.items():
r, x = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)
utils.setCloexec(r)
cleanup.append(plib.server(iface, config.max_clients,
......@@ -442,7 +442,7 @@ def main():
except:
pass
exit.release()
except ReexecException, e:
except ReexecException as e:
logging.info(e)
except Exception:
utils.log_exception()
......@@ -455,7 +455,7 @@ def main():
if __name__ == "__main__":
try:
main()
except SystemExit, e:
except SystemExit as e:
if type(e.code) is str:
if hasattr(logging, 'trace'): # utils.setupLog called
logging.critical(e.code)
......
#!/usr/bin/python2
import httplib, logging, os, socket, sys
from BaseHTTPServer import BaseHTTPRequestHandler
from SocketServer import ThreadingTCPServer
from urlparse import parse_qsl
#!/usr/bin/env python3
import http.client, logging, os, socket, sys
from http.server import BaseHTTPRequestHandler
from socketserver import ThreadingTCPServer
from urllib.parse import parse_qsl
if 're6st' not in sys.modules:
sys.path[0] = os.path.dirname(os.path.dirname(sys.path[0]))
from re6st import registry, utils, version
......@@ -29,14 +29,14 @@ class RequestHandler(BaseHTTPRequestHandler):
path = self.path
query = {}
else:
query = dict(parse_qsl(query, keep_blank_values=1,
strict_parsing=1))
query = dict(parse_qsl(query, keep_blank_values=True,
strict_parsing=True))
_, path = path.split('/')
if not _:
return self.server.handle_request(self, path, query)
except Exception:
logging.info(self.requestline, exc_info=1)
self.send_error(httplib.BAD_REQUEST)
logging.info(self.requestline, exc_info=True)
self.send_error(http.client.BAD_REQUEST)
def log_error(*args):
pass
......
......@@ -5,7 +5,7 @@ from . import utils
uint16 = struct.Struct("!H")
header = struct.Struct("!HI")
class Struct(object):
class Struct:
def __init__(self, format, *args):
if args:
......@@ -29,39 +29,39 @@ class Struct(object):
self.encode = encode
self.decode = decode
class Array(object):
class Array:
def __init__(self, item):
self._item = item
def encode(self, buffer, value):
def encode(self, buffer: bytes, value: list):
buffer += uint16.pack(len(value))
encode = self._item.encode
for value in value:
encode(buffer, value)
def decode(self, buffer, offset=0):
def decode(self, buffer: bytes, offset=0) -> tuple[int, list]:
r = []
o = offset + 2
decode = self._item.decode
for i in xrange(*uint16.unpack_from(buffer, offset)):
for i in range(*uint16.unpack_from(buffer, offset)):
o, x = decode(buffer, o)
r.append(x)
return o, r
class String(object):
class String:
@staticmethod
def encode(buffer, value):
buffer += value + "\0"
def encode(buffer: bytes, value: str):
buffer += value.encode("utf-8") + b'\x00'
@staticmethod
def decode(buffer, offset=0):
i = buffer.index("\0", offset)
return i + 1, buffer[offset:i]
def decode(buffer: bytes, offset=0) -> tuple[int, str]:
i = buffer.index(0, offset)
return i + 1, buffer[offset:i].decode("utf-8")
class Buffer(object):
class Buffer:
def __init__(self):
self._buf = bytearray()
......@@ -104,21 +104,6 @@ class Buffer(object):
self._seek(r)
return value
try: # BBB: Python < 2.7.4 (http://bugs.python.org/issue10212)
uint16.unpack_from(bytearray(uint16.size))
except TypeError:
def unpack_from(self, struct):
r = self._r
x = r + struct.size
value = struct.unpack(buffer(self._buf)[r:x])
self._seek(x)
return value
def decode(self, decode):
r = self._r
size, value = decode(buffer(self._buf)[r:])
self._seek(r + size)
return value
# writing
def send(self, socket, *args):
......@@ -129,7 +114,7 @@ class Buffer(object):
struct.pack_into(self._buf, offset, *args)
class Packet(object):
class Packet:
response_dict = {}
......@@ -149,7 +134,7 @@ class Packet(object):
logging.trace('send %s%r', self.__class__.__name__,
(self.id,) + self.args)
offset = len(buffer)
buffer += '\0' * header.size
buffer += b'\x00' * header.size
r = self.request
if isinstance(r, Struct):
r.encode(buffer, self.args)
......@@ -182,7 +167,7 @@ class ConnectionClosed(BabelException):
return "connection to babeld closed (%s)" % self.args
class Babel(object):
class Babel:
_decode = None
......@@ -206,11 +191,11 @@ class Babel(object):
def select(*args):
try:
s.connect(self.socket_path)
except socket.error, e:
except socket.error as e:
logging.debug("Can't connect to %r (%r)", self.socket_path, e)
return e
s.send("\1")
s.setblocking(0)
s.send(b'\x01')
s.setblocking(False)
del self.select
self.socket = s
return self.select(*args)
......@@ -269,7 +254,7 @@ class Babel(object):
a = len(self.network)
for route in routes:
assert route.flags & 1, route # installed
if route.prefix.startswith('\0\0\0\0\0\0\0\0\0\0\xff\xff'):
if route.prefix.startswith(b'\0\0\0\0\0\0\0\0\0\0\xff\xff'):
continue
assert route.neigh_address == route.nexthop, route
address = route.neigh_address, route.ifindex
......@@ -310,7 +295,7 @@ class Babel(object):
pass
class iterRoutes(object):
class iterRoutes:
_waiting = True
......@@ -323,7 +308,7 @@ class iterRoutes(object):
c.select(*args)
utils.select(*args)
return (prefix
for neigh_routes in c.neighbours.itervalues()
for neigh_routes in c.neighbours.values()
for prefix in neigh_routes[1]
if prefix)
......
import errno, os, socket, stat, threading
class Socket(object):
class Socket:
def __init__(self, socket):
# In case that the default timeout is not None.
......@@ -37,14 +37,14 @@ class Socket(object):
try:
self._socket.recv(0)
return True
except socket.error, (err, _):
if err != errno.EAGAIN:
except socket.error as e:
if e.errno != errno.EAGAIN:
raise
self._socket.setblocking(1)
return False
class Console(object):
class Console:
def __init__(self, path, pdb):
self.path = path
......@@ -52,7 +52,7 @@ class Console(object):
socket.SOCK_STREAM | socket.SOCK_CLOEXEC)
try:
self._removeSocket()
except OSError, e:
except OSError as e:
if e.errno != errno.ENOENT:
raise
s.bind(path)
......
......@@ -43,7 +43,7 @@ freeifaddrs = libc.freeifaddrs
freeifaddrs.restype = None
freeifaddrs.argtypes = [POINTER(struct_ifaddrs)]
class unpacker(object):
class unpacker:
def __init__(self, buf):
self._buf = buf
......@@ -55,7 +55,7 @@ class unpacker(object):
self._offset += s.size
return result
class PimDm(object):
class PimDm:
def __init__(self):
s_netlink = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE)
......
#!/usr/bin/python2 -S
#!/usr/bin/env -S python3 -S
import os, sys
script_type = os.environ['script_type']
......@@ -14,4 +14,4 @@ if script_type == 'up':
if script_type == 'route-up':
import time
os.write(int(sys.argv[1]), repr((os.environ['common_name'], time.time(),
int(os.environ['tls_serial_0']), os.environ['OPENVPN_external_ip'])))
int(os.environ['tls_serial_0']), os.environ['OPENVPN_external_ip'])).encode())
#!/usr/bin/python2 -S
#!/usr/bin/env -S python3 -S
import os, sys
script_type = os.environ['script_type']
......@@ -7,10 +7,10 @@ external_ip = os.getenv('trusted_ip') or os.environ['trusted_ip6']
# Write into pipe connect/disconnect events
fd = int(sys.argv[1])
os.write(fd, repr((script_type, (os.environ['common_name'], os.environ['dev'],
int(os.environ['tls_serial_0']), external_ip))))
int(os.environ['tls_serial_0']), external_ip))).encode("utf-8"))
if script_type == 'client-connect':
if os.read(fd, 1) == '\0':
if os.read(fd, 1) == b'\x00':
sys.exit(1)
# Send client its external ip address
with open(sys.argv[2], 'w') as f:
......
import binascii
import logging, errno, os
from typing import Optional
from . import utils
here = os.path.realpath(os.path.dirname(__file__))
ovpn_server = os.path.join(here, 'ovpn-server')
ovpn_client = os.path.join(here, 'ovpn-client')
ovpn_log = None
ovpn_log: Optional[str] = None
def openvpn(iface, encrypt, *args, **kw):
args = ['openvpn',
......@@ -43,7 +45,7 @@ def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw):
'--max-clients', str(max_clients),
'--port', str(port),
'--proto', proto,
*args, **kw)
*args, pass_fds=[fd], **kw)
def client(iface, address_list, encrypt, *args, **kw):
......@@ -80,9 +82,9 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
'-C', 'redistribute local deny',
'-C', 'redistribute ip %s/%s eq %s' % (ip, n, n)]
if hmac_sign:
def key(cmd, id, value):
def key(cmd, id: str, value):
cmd += '-C', ('key type blake2s128 id %s value %s' %
(id, value.encode('hex')))
(id, binascii.hexlify(value).decode()))
key(cmd, 'sign', hmac_sign)
default += ' key sign'
if hmac_accept is not None:
......@@ -132,7 +134,7 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
# WKRD: babeld fails to start if pidfile already exists
try:
os.remove(pidfile)
except OSError, e:
except OSError as e:
if e.errno != errno.ENOENT:
raise
logging.info('%r', cmd)
......
......@@ -18,16 +18,16 @@ Authenticated communication:
- the last one that was really used by the client (!hello)
- the one of the last handshake (hello)
"""
import base64, hmac, hashlib, httplib, inspect, json, logging
import base64, hmac, hashlib, http.client, inspect, json, logging
import mailbox, os, platform, random, select, smtplib, socket, sqlite3
import string, sys, threading, time, weakref, zlib
from collections import defaultdict, deque
from datetime import datetime
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
from http.server import HTTPServer, BaseHTTPRequestHandler
from email.mime.text import MIMEText
from operator import itemgetter
from OpenSSL import crypto
from urllib import splittype, splithost, unquote, urlencode
from urllib.parse import urlparse, unquote, urlencode
from . import ctl, tunnel, utils, version, x509
HMAC_HEADER = "Re6stHMAC"
......@@ -35,13 +35,11 @@ RENEW_PERIOD = 30 * 86400
BABEL_HMAC = 'babel_hmac0', 'babel_hmac1', 'babel_hmac2'
def rpc(f):
args, varargs, varkw, defaults = inspect.getargspec(f)
assert not (varargs or varkw), f
if not defaults:
defaults = ()
i = len(args) - len(defaults)
f.getcallargs = eval("lambda %s: locals()" % ','.join(args[1:i]
+ map("%s=%r".__mod__, zip(args[i:], defaults))))
argspec = inspect.getfullargspec(f)
assert not (argspec.varargs or argspec.varkw), f
sig = inspect.signature(f)
sig = sig.replace(parameters=[v.replace(annotation=inspect.Parameter.empty) for v in sig.parameters.values()][1:], return_annotation=inspect.Signature.empty)
f.getcallargs = eval("lambda %s: locals()" % str(sig)[1:-1])
return f
def rpc_private(f):
......@@ -53,13 +51,13 @@ class HTTPError(Exception):
pass
class RegistryServer(object):
class RegistryServer:
peers = 0, ()
cert_duration = 365 * 86400
def _geoiplookup(self, ip):
raise HTTPError(httplib.BAD_REQUEST)
raise HTTPError(http.client.BAD_REQUEST)
def __init__(self, config):
self.config = config
......@@ -69,14 +67,14 @@ class RegistryServer(object):
# Parse community file
self.community_map = {}
if config.community:
if config.community:
with open(config.community) as x:
for x in x:
x = x.strip()
if x and not x.startswith('#'):
x = x.split()
self.community_map[x.pop(0)] = x
if sum('*' in x for x in self.community_map.itervalues()) != 1:
if sum('*' in x for x in self.community_map.values()) != 1:
sys.exit("Invalid community configuration: missing or multiple default location ('*')")
else:
self.community_map[''] = '*'
......@@ -91,7 +89,7 @@ class RegistryServer(object):
"name TEXT PRIMARY KEY NOT NULL",
"value")
self.prefix = self.getConfig("prefix", None)
self.version = str(self.getConfig("version", "\0")) # BBB: blob
self.version = self.getConfig("version", b'\x00')
utils.sqliteCreateTable(self.db, "token",
"token TEXT PRIMARY KEY NOT NULL",
"email TEXT NOT NULL",
......@@ -169,8 +167,8 @@ class RegistryServer(object):
def updateNetworkConfig(self, _it0=itemgetter(0)):
kw = {
'babel_default': 'max-rtt-penalty 5000 rtt-max 500 rtt-decay 125',
'crl': map(_it0, self.db.execute(
"SELECT serial FROM crl ORDER BY serial")),
'crl': list(map(_it0, self.db.execute(
"SELECT serial FROM crl ORDER BY serial"))),
'protocol': version.protocol,
'registry_prefix': self.prefix,
}
......@@ -184,31 +182,29 @@ class RegistryServer(object):
config = json.dumps(kw, sort_keys=True)
if config != self.getConfig('last_config', None):
self.increaseVersion()
# BBB: Use buffer because of http://bugs.python.org/issue13676
# on Python 2.6
self.setConfig('version', buffer(self.version))
self.setConfig('version', self.version)
self.setConfig('last_config', config)
self.sendto(self.prefix, 0)
# The following entry lists values that are base64-encoded.
kw[''] = 'version',
kw['version'] = self.version.encode('base64')
kw['version'] = base64.b64encode(self.version).decode()
self.network_config = kw
def increaseVersion(self):
x = utils.packInteger(1 + utils.unpackInteger(self.version)[0])
self.version = x + self.cert.sign(x)
def sendto(self, prefix, code):
self.sock.sendto("%s\0%c" % (prefix, code), ('::1', tunnel.PORT))
def sendto(self, prefix: str, code: int):
self.sock.sendto(prefix.encode() + bytes((0, code)), ('::1', tunnel.PORT))
def recv(self, code):
try:
prefix, msg = self.sock.recv(1<<16).split('\0', 1)
prefix, msg = self.sock.recv(1<<16).split(b'\x00', 1)
int(prefix, 2)
except ValueError:
pass
else:
if msg and ord(msg[0]) == code:
if msg and msg[0:1] == code:
return prefix, msg[1:]
return None, None
......@@ -286,14 +282,14 @@ class RegistryServer(object):
x_forwarded_for = request.headers.get('X-Forwarded-For')
if request.client_address[0] not in authorized_origin or \
x_forwarded_for and x_forwarded_for not in authorized_origin:
return request.send_error(httplib.FORBIDDEN)
return request.send_error(http.client.FORBIDDEN)
key = m.getcallargs(**kw).get('cn')
if key:
h = base64.b64decode(request.headers[HMAC_HEADER])
with self.lock:
session = self.sessions[key]
for key, protocol in session:
if h == hmac.HMAC(key, request.path, hashlib.sha1).digest():
if h == hmac.HMAC(key, request.path.encode(), hashlib.sha1).digest():
break
else:
raise Exception("Wrong HMAC")
......@@ -313,19 +309,21 @@ class RegistryServer(object):
request.headers.get("host"))
try:
result = m(**kw)
except HTTPError, e:
except HTTPError as e:
return request.send_error(*e.args)
except:
logging.warning(request.requestline, exc_info=1)
return request.send_error(httplib.INTERNAL_SERVER_ERROR)
logging.warning(request.requestline, exc_info=True)
return request.send_error(http.client.INTERNAL_SERVER_ERROR)
if result:
request.send_response(httplib.OK)
if type(result) is str:
result = result.encode("utf-8")
request.send_response(http.client.OK)
request.send_header("Content-Length", str(len(result)))
else:
request.send_response(httplib.NO_CONTENT)
request.send_response(http.client.NO_CONTENT)
if key:
request.send_header(HMAC_HEADER, base64.b64encode(
hmac.HMAC(key, result, hashlib.sha1).digest()))
hmac.HMAC(key, result, hashlib.sha1).digest()).decode("ascii"))
request.end_headers()
if result:
request.wfile.write(result)
......@@ -349,14 +347,14 @@ class RegistryServer(object):
assert self.lock.locked()
return self.db.execute("SELECT cert FROM cert"
" WHERE prefix=? AND cert IS NOT NULL",
(client_prefix,)).next()[0]
(client_prefix,)).fetchone()[0]
@rpc_private
def isToken(self, token):
with self.lock:
if self.db.execute("SELECT 1 FROM token WHERE token = ?",
(token,)).fetchone():
return "1"
return b"1"
@rpc_private
def deleteToken(self, token):
......@@ -367,7 +365,7 @@ class RegistryServer(object):
def addToken(self, email, token):
prefix_len = self.config.prefix_length
if not prefix_len:
raise HTTPError(httplib.FORBIDDEN)
raise HTTPError(http.client.FORBIDDEN)
request = token is None
with self.lock:
while True:
......@@ -381,7 +379,7 @@ class RegistryServer(object):
break
except sqlite3.IntegrityError:
if not request:
raise HTTPError(httplib.CONFLICT)
raise HTTPError(http.client.CONFLICT)
self.timeout = 1
if request:
return token
......@@ -389,7 +387,7 @@ class RegistryServer(object):
@rpc
def requestToken(self, email):
if not self.config.mailhost:
raise HTTPError(httplib.FORBIDDEN)
raise HTTPError(http.client.FORBIDDEN)
token = self.addToken(email, None)
......@@ -418,11 +416,11 @@ class RegistryServer(object):
s.quit()
def getCommunity(self, country, continent):
for prefix, location_list in self.community_map.iteritems():
for prefix, location_list in self.community_map.items():
if country in location_list:
return prefix
default = ''
for prefix, location_list in self.community_map.iteritems():
for prefix, location_list in self.community_map.items():
if continent in location_list:
return prefix
if '*' in location_list:
......@@ -436,7 +434,7 @@ class RegistryServer(object):
while True:
max_len = q("SELECT max(length(prefix)) FROM cert"
" WHERE cert is null AND length(prefix) < ?",
max_len).next()
max_len).fetchone()
if not max_len[0]:
break
for prefix, in q("SELECT prefix FROM cert"
......@@ -460,25 +458,25 @@ class RegistryServer(object):
while True:
try:
# Find longest free prefix whithin community.
prefix, = q(
prefix, = next(q(
"SELECT prefix FROM cert"
" WHERE prefix LIKE ?"
" AND length(prefix) <= ? AND cert is null"
" ORDER BY length(prefix) DESC",
(community + '%', prefix_len)).next()
(community + '%', prefix_len)))
except StopIteration:
# Community not yet allocated?
# There should be exactly 1 row whose
# prefix is the beginning of community.
prefix, x = q("SELECT prefix, cert FROM cert"
prefix, x = next(q("SELECT prefix, cert FROM cert"
" WHERE substr(?,1,length(prefix)) = prefix",
(community,)).next()
(community,)))
if x is not None:
logging.error('No more free /%u prefix available',
prefix_len)
raise
# Split the tree until prefix has wanted length.
for x in xrange(len(prefix), prefix_len):
for x in range(len(prefix), prefix_len):
# Prefix starts with community, then we complete with 0.
x = community[x] if x < community_len else '0'
q("UPDATE cert SET prefix = ? WHERE prefix = ?",
......@@ -496,11 +494,11 @@ class RegistryServer(object):
with self.db:
if token:
if not self.config.prefix_length:
raise HTTPError(httplib.FORBIDDEN)
raise HTTPError(http.client.FORBIDDEN)
try:
token, email, prefix_len, _ = self.db.execute(
token, email, prefix_len, _ = next(self.db.execute(
"SELECT * FROM token WHERE token = ?",
(token,)).next()
(token,)))
except StopIteration:
return
self.db.execute("DELETE FROM token WHERE token = ?",
......@@ -508,7 +506,7 @@ class RegistryServer(object):
else:
prefix_len = self.config.anonymous_prefix_length
if not prefix_len:
raise HTTPError(httplib.FORBIDDEN)
raise HTTPError(http.client.FORBIDDEN)
email = None
country, continent = '*', '*'
if self.geoip_db:
......@@ -595,8 +593,8 @@ class RegistryServer(object):
hmac = [self.getConfig(k, None) for k in BABEL_HMAC]
for i, v in enumerate(v for v in hmac if v is not None):
config[('babel_hmac_sign', 'babel_hmac_accept')[i]] = \
v and x509.encrypt(cert, v).encode('base64')
return zlib.compress(json.dumps(config))
v and base64.b64encode(x509.encrypt(cert, v)).decode()
return zlib.compress(json.dumps(config).encode("utf-8"))
def _queryAddress(self, peer):
self.sendto(peer, 1)
......@@ -615,7 +613,7 @@ class RegistryServer(object):
@rpc
def getCountry(self, cn, address):
country = self._geoiplookup(address)[0]
return None if country == '*' else country
return None if country == '*' else country.encode()
@rpc
def getBootstrapPeer(self, cn):
......@@ -624,7 +622,7 @@ class RegistryServer(object):
if age < time.time() or not peers:
self.request_dump()
peers = [prefix
for neigh_routes in self.ctl.neighbours.itervalues()
for neigh_routes in self.ctl.neighbours.values()
for prefix in neigh_routes[1]
if prefix]
peers.append(self.prefix)
......@@ -673,7 +671,7 @@ class RegistryServer(object):
def newHMAC(self, i, key=None):
if key is None:
key = buffer(os.urandom(16))
key = os.urandom(16)
self.setConfig(BABEL_HMAC[i], key)
def delHMAC(self, i):
......@@ -696,18 +694,18 @@ class RegistryServer(object):
else:
# Initialization of HMAC on the network
self.newHMAC(1)
self.newHMAC(2, '')
self.newHMAC(2, b'')
self.increaseVersion()
self.setConfig('version', buffer(self.version))
self.network_config['version'] = self.version.encode('base64')
self.setConfig('version', self.version)
self.network_config['version'] = base64.b64encode(self.version)
self.sendto(self.prefix, 0)
@rpc_private
def getNodePrefix(self, email):
with self.lock, self.db:
try:
cert, = self.db.execute("SELECT cert FROM cert WHERE email = ?",
(email,)).next()
cert, = next(self.db.execute("SELECT cert FROM cert WHERE email = ?",
(email,)))
except StopIteration:
return
certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
......@@ -728,7 +726,7 @@ class RegistryServer(object):
peer = utils.binFromSubnet(peer)
with self.peers_lock:
self.request_dump()
for neigh_routes in self.ctl.neighbours.itervalues():
for neigh_routes in self.ctl.neighbours.values():
for prefix in neigh_routes[1]:
if prefix == peer:
break
......@@ -736,16 +734,16 @@ class RegistryServer(object):
return
logging.info("%s %s", email, peer)
with self.lock:
msg = self._queryAddress(peer)
msg = self._queryAddress(peer).decode()
if msg:
return msg.split(',')[0]
return msg.split(',')[0].encode()
@rpc_private
def versions(self):
with self.peers_lock:
self.request_dump()
peers = {prefix
for neigh_routes in self.ctl.neighbours.itervalues()
for neigh_routes in self.ctl.neighbours.values()
for prefix in neigh_routes[1]
if prefix}
peers.add(self.prefix)
......@@ -794,32 +792,32 @@ class RegistryServer(object):
self.sendto(utils.binFromSubnet(peers.popleft()), 5)
elif not r:
break
return json.dumps({k: list(v) for k, v in graph.iteritems()})
return json.dumps({k: list(v) for k, v in graph.items()})
class RegistryClient(object):
class RegistryClient:
_hmac = None
user_agent = "re6stnet/%s, %s" % (version.version, platform.platform())
def __init__(self, url, cert=None, auto_close=True):
def __init__(self, url, cert: x509.Cert=None, auto_close=True):
self.cert = cert
self.auto_close = auto_close
scheme, host = splittype(url)
host, path = splithost(host)
self._conn = dict(http=httplib.HTTPConnection,
https=httplib.HTTPSConnection,
url_parsed = urlparse(url)
scheme, host, path = url_parsed.scheme, url_parsed.netloc, url_parsed.path
self._conn = dict(http=http.client.HTTPConnection,
https=http.client.HTTPSConnection,
)[scheme](unquote(host), timeout=60)
self._path = path.rstrip('/')
def __getattr__(self, name):
getcallargs = getattr(RegistryServer, name).getcallargs
def rpc(*args, **kw):
def rpc(*args, **kw) -> bytes:
kw = getcallargs(*args, **kw)
query = '/' + name
if kw:
if any(type(v) is not str for v in kw.itervalues()):
raise TypeError
if any(not isinstance(v, (str, bytes)) for v in kw.values()):
raise TypeError(kw)
query += '?' + urlencode(kw)
url = self._path + query
client_prefix = kw.get('cn')
......@@ -834,7 +832,7 @@ class RegistryClient(object):
n = len(h) // 2
self.cert.verify(h[n:], h[:n])
key = self.cert.decrypt(h[:n])
h = hmac.HMAC(key, query, hashlib.sha1).digest()
h = hmac.HMAC(key, query.encode(), hashlib.sha1).digest()
key = hashlib.sha1(key).digest()
self._hmac = hashlib.sha1(key).digest()
else:
......@@ -846,14 +844,14 @@ class RegistryClient(object):
self._conn.endheaders()
response = self._conn.getresponse()
body = response.read()
if response.status in (httplib.OK, httplib.NO_CONTENT):
if response.status in (http.client.OK, http.client.NO_CONTENT):
if (not client_prefix or
hmac.HMAC(key, body, hashlib.sha1).digest() ==
base64.b64decode(response.msg[HMAC_HEADER])):
if self.auto_close and name != 'hello':
self._conn.close()
return body
elif response.status == httplib.FORBIDDEN:
elif response.status == http.client.FORBIDDEN:
# XXX: We should improve error handling, while making
# sure re6st nodes don't crash on temporary errors.
# This is currently good enough for re6st-conf, to
......@@ -864,7 +862,7 @@ class RegistryClient(object):
except HTTPError:
raise
except Exception:
logging.info(url, exc_info=1)
logging.info(url, exc_info=True)
else:
logging.info('%s\nUnexpected response %s %s',
url, response.status, response.reason)
......
from pathlib2 import Path
from pathlib import Path
DEMO_PATH = Path(__file__).resolve().parent.parent.parent / "demo"
......@@ -15,7 +15,7 @@ from re6st.tests import DEMO_PATH
DH_FILE = DEMO_PATH / "dh2048.pem"
class DummyNode(object):
class DummyNode:
"""fake node to reuse Re6stRegistry
error: node.Popen has destory method which not in subprocess.Popen
......@@ -60,7 +60,7 @@ class TestRegistryClientInteract(unittest.TestCase):
# read token from db
db = sqlite3.connect(str(self.server.db), isolation_level=None)
token = None
for _ in xrange(100):
for _ in range(100):
time.sleep(.1)
token = db.execute("SELECT token FROM token WHERE email=?",
(email,)).fetchone()
......@@ -70,7 +70,7 @@ class TestRegistryClientInteract(unittest.TestCase):
self.fail("Request token failed, no token in database")
# token: tuple[unicode,]
token = str(token[0])
self.assertEqual(client.isToken(token), "1")
self.assertEqual(client.isToken(token).decode(), "1")
# request ca
ca = client.getCa()
......@@ -78,7 +78,7 @@ class TestRegistryClientInteract(unittest.TestCase):
# request a cert and get cn
key, csr = tools.generate_csr()
cert = client.requestCertificate(token, csr)
self.assertEqual(client.isToken(token), '', "token should be deleted")
self.assertEqual(client.isToken(token).decode(), '', "token should be deleted")
# creat x509.cert object
def write_to_temp(text):
......@@ -97,18 +97,19 @@ class TestRegistryClientInteract(unittest.TestCase):
# verfiy cn and prefix
prefix = client.cert.prefix
cn = client.getNodePrefix(email)
cn = client.getNodePrefix(email).decode()
self.assertEqual(tools.prefix2cn(prefix), cn)
# simulate the process in cache
# just prove works
net_config = client.getNetworkConfig(prefix)
self.assertIsNotNone(net_config)
net_config = json.loads(zlib.decompress(net_config))
self.assertEqual(net_config[u'max_clients'], self.max_clients)
# no re6stnet, empty result
bootpeer = client.getBootstrapPeer(prefix)
self.assertEqual(bootpeer, "")
self.assertEqual(bootpeer.decode(), "")
# server should not die
self.assertIsNone(self.server.proc.poll())
......
......@@ -4,7 +4,7 @@ import nemu
import time
import weakref
from subprocess import PIPE
from pathlib2 import Path
from pathlib import Path
from re6st.tests import DEMO_PATH
......@@ -50,7 +50,7 @@ class Node(nemu.Node):
if_s.add_v4_address(ip, prefix_len=prefix_len)
return if_s
class NetManager(object):
class NetManager:
"""contain all the nemu object created, so they can live more time"""
def __init__(self):
self.object = []
......@@ -60,7 +60,7 @@ class NetManager(object):
Raise:
AssertionError
"""
for reg, nodes in self.registries.iteritems():
for reg, nodes in self.registries.items():
for node in nodes:
app0 = node.Popen(["ping", "-c", "1", reg.ip], stdout=PIPE)
ret = app0.wait()
......
......@@ -6,13 +6,15 @@ import ipaddress
import json
import logging
import re
import shlex
import shutil
import sqlite3
import sys
import tempfile
import time
import weakref
from subprocess import PIPE
from pathlib2 import Path
from pathlib import Path
from re6st.tests import tools
from re6st.tests import DEMO_PATH
......@@ -20,9 +22,10 @@ from re6st.tests import DEMO_PATH
WORK_DIR = Path(__file__).parent / "temp_net_test"
DH_FILE = DEMO_PATH / "dh2048.pem"
RE6STNET = "python -m re6st.cli.node"
RE6ST_REGISTRY = "python -m re6st.cli.registry"
RE6ST_CONF = "python -m re6st.cli.conf"
PYTHON = shlex.quote(sys.executable)
RE6STNET = PYTHON + " -m re6st.cli.node"
RE6ST_REGISTRY = PYTHON + " -m re6st.cli.registry"
RE6ST_CONF = PYTHON + " -m re6st.cli.conf"
def initial():
"""create the workplace"""
......@@ -36,7 +39,7 @@ def ip_to_serial(ip6):
return int(ip6, 16)
class Re6stRegistry(object):
class Re6stRegistry:
"""class run a re6st-registry service on a namespace"""
registry_seq = 0
......@@ -72,7 +75,7 @@ class Re6stRegistry(object):
self.run()
# wait the servcice started
p = self.node.Popen(['python', '-c', """if 1:
p = self.node.Popen([sys.executable, '-c', """if 1:
import socket, time
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
while True:
......@@ -115,7 +118,7 @@ class Re6stRegistry(object):
'--client-count', (self.client_number+1)//2, '--port', self.port]
#PY3: convert PosixPath to str, can be remove in Python 3
cmd = map(str, cmd)
cmd = list(map(str, cmd))
cmd[:0] = RE6ST_REGISTRY.split()
......@@ -139,7 +142,7 @@ class Re6stRegistry(object):
pass
class Re6stNode(object):
class Re6stNode:
"""class run a re6stnet service on a namespace"""
node_seq = 0
......@@ -210,7 +213,7 @@ class Re6stNode(object):
# read token
db = sqlite3.connect(str(self.registry.db), isolation_level=None)
token = None
for _ in xrange(100):
for _ in range(100):
time.sleep(.1)
token = db.execute("SELECT token FROM token WHERE email=?",
(self.email,)).fetchone()
......@@ -223,7 +226,7 @@ class Re6stNode(object):
out, _ = p.communicate(str(token[0]))
# logging.debug("re6st-conf output: {}".format(out))
# find the ipv6 subnet of node
self.ip6 = re.search('(?<=subnet: )[0-9:a-z]+', out).group(0)
self.ip6 = re.search('(?<=subnet: )[0-9:a-z]+', out.decode("utf-8")).group(0)
data = {'ip6': self.ip6, 'hash': self.registry.ident}
with open(str(self.data_file), 'w') as f:
json.dump(data, f)
......@@ -236,7 +239,7 @@ class Re6stNode(object):
'--key', self.key, '-v4', '--registry', self.registry.url,
'--console', self.console]
#PY3: same as for Re6stRegistry.run
cmd = map(str, cmd)
cmd = list(map(str, cmd))
cmd[:0] = RE6STNET.split()
cmd += args
......
"""contain ping-test for re6set net"""
import os
import sys
import unittest
import time
import psutil
import logging
import random
from pathlib2 import Path
from pathlib import Path
import network_build
import re6st_wrap
from . import network_build, re6st_wrap
PING_PATH = str(Path(__file__).parent.resolve() / "ping.py")
......@@ -47,12 +47,12 @@ def wait_stable(nodes, timeout=240):
for node in nodes:
sub_ips = set(ips) - {node.ip6}
node.ping_proc = node.node.Popen(
["python", PING_PATH, '--retry', '-a'] + list(sub_ips))
[sys.executable, PING_PATH, '--retry', '-a'] + list(sub_ips), env=os.environ)
# check all the node network can ping each other, in order reverse
unfinished = list(nodes)
while unfinished:
for i in xrange(len(unfinished)-1, -1, -1):
for i in range(len(unfinished)-1, -1, -1):
node = unfinished[i]
if node.ping_proc.poll() is not None:
logging.debug("%s 's network is stable", node.name)
......
#!/usr/bin/python2
#!/usr/bin/env python3
""" unit test for re6st-conf
"""
......@@ -6,7 +6,7 @@ import os
import sys
import unittest
from shutil import rmtree
from StringIO import StringIO
from io import StringIO
from mock import patch
from OpenSSL import crypto
......@@ -36,7 +36,7 @@ class TestConf(unittest.TestCase):
# mocked server cert and pkey
cls.pkey, cls.cert = create_ca_file(os.devnull, os.devnull)
cls.fingerprint = "".join( cls.cert.digest("sha1").split(":"))
cls.fingerprint = "".join( cls.cert.digest("sha1").decode().split(":"))
# client.getCa should return a string form cert
cls.cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cls.cert)
......@@ -72,7 +72,7 @@ class TestConf(unittest.TestCase):
# go back to original dir
os.chdir(self.origin_dir)
@patch("__builtin__.raw_input")
@patch("builtins.input")
def test_basic(self, mock_raw_input):
""" go through all the step
getCa, requestToken, requestCertificate
......
......@@ -3,7 +3,7 @@ import os
import random
import string
import json
import httplib
import http.client
import base64
import unittest
import hmac
......@@ -13,12 +13,13 @@ import tempfile
from argparse import Namespace
from OpenSSL import crypto
from mock import Mock, patch
from pathlib2 import Path
from pathlib import Path
from re6st import registry
from re6st.tests.tools import *
from re6st.tests import DEMO_PATH
# TODO test for request_dump, requestToken, getNetworkConfig, getBoostrapPeer
# getIPV4Information, versions
......@@ -49,6 +50,7 @@ def insert_cert(cur, ca, prefix, not_after=None, email=None):
insert_cert.serial += 1
return key, cert
insert_cert.serial = 0
......@@ -77,17 +79,26 @@ class TestRegistryServer(unittest.TestCase):
def setUp(self):
self.email = ''.join(random.sample(string.ascii_lowercase, 4)) \
+ "@mail.com"
+ "@mail.com"
def test_recv(self):
recv = self.server.sock.recv = Mock()
recv.side_effect = [
side_effect = iter([
"0001001001001a_msg",
"0001001001002\0001dqdq",
"0001001001001\000a_msg",
"0001001001001\000\4a_msg",
"0000000000000\0" # ERROR, IndexError: msg is null
]
])
class SocketProxy:
def __init__(self, wrappee):
self.wrappee = wrappee
self.recv = lambda _: next(side_effect)
def __getattr__(self, attr):
return getattr(self.wrappee, attr)
self.server.sock = SocketProxy(self.server.sock)
try:
res1 = self.server.recv(4)
......@@ -115,7 +126,7 @@ class TestRegistryServer(unittest.TestCase):
now = int(time.time()) - self.config.grace_period + 20
# makeup data
insert_cert(cur, self.server.cert, prefix_old, 1)
insert_cert(cur, self.server.cert, prefix, now -1)
insert_cert(cur, self.server.cert, prefix, now - 1)
cur.execute("INSERT INTO token VALUES (?,?,?,?)",
(token_old, self.email, 4, 2))
cur.execute("INSERT INTO token VALUES (?,?,?,?)",
......@@ -143,16 +154,16 @@ class TestRegistryServer(unittest.TestCase):
prefix = "0000000011111111"
method = "func"
protocol = 7
params = {"cn" : prefix, "a" : 1, "b" : 2}
params = {"cn": prefix, "a": 1, "b": 2}
func.getcallargs.return_value = params
del func._private
func.return_value = result = "this_is_a_result"
key = "this_is_a_key"
func.return_value = result = b"this_is_a_result"
key = b"this_is_a_key"
self.server.sessions[prefix] = [(key, protocol)]
request = Mock()
request.path = "/func?a=1&b=2&cn=0000000011111111"
request.headers = {registry.HMAC_HEADER: base64.b64encode(
hmac.HMAC(key, request.path, hashlib.sha1).digest())}
hmac.HMAC(key, request.path.encode(), hashlib.sha1).digest())}
self.server.handle_request(request, method, params)
......@@ -162,11 +173,11 @@ class TestRegistryServer(unittest.TestCase):
[(hashlib.sha1(key).digest(), protocol)])
func.assert_called_once_with(**params)
# http response check
request.send_response.assert_called_once_with(httplib.OK)
request.send_response.assert_called_once_with(http.client.OK)
request.send_header.assert_any_call("Content-Length", str(len(result)))
request.send_header.assert_any_call(
registry.HMAC_HEADER,
base64.b64encode(hmac.HMAC(key, result, hashlib.sha1).digest()))
base64.b64encode(hmac.HMAC(key, result, hashlib.sha1).digest()).decode("ascii"))
request.wfile.write.assert_called_once_with(result)
# remove the create session \n
......@@ -176,12 +187,12 @@ class TestRegistryServer(unittest.TestCase):
def test_handle_request_private(self, func):
"""case request with _private attr"""
method = "func"
params = {"a" : 1, "b" : 2}
params = {"a": 1, "b": 2}
func.getcallargs.return_value = params
func.return_value = None
request_good = Mock()
request_good.client_address = self.config.authorized_origin
request_good.headers = {'X-Forwarded-For':self.config.authorized_origin[0]}
request_good.headers = {'X-Forwarded-For': self.config.authorized_origin[0]}
request_bad = Mock()
request_bad.client_address = ["wrong_address"]
......@@ -189,8 +200,8 @@ class TestRegistryServer(unittest.TestCase):
self.server.handle_request(request_bad, method, params)
func.assert_called_once_with(**params)
request_bad.send_error.assert_called_once_with(httplib.FORBIDDEN)
request_good.send_response.assert_called_once_with(httplib.NO_CONTENT)
request_bad.send_error.assert_called_once_with(http.client.FORBIDDEN)
request_good.send_response.assert_called_once_with(http.client.NO_CONTENT)
# will cause valueError, if a node send hello twice to a registry
def test_getPeerProtocol(self):
......@@ -213,7 +224,7 @@ class TestRegistryServer(unittest.TestCase):
res = self.server.hello(prefix, protocol=protocol)
# decrypt
length = len(res)/2
length = len(res) // 2
key, sign = res[:length], res[length:]
key = decrypt(pkey, key)
self.assertEqual(self.server.sessions[prefix][-1][0], key,
......@@ -282,7 +293,7 @@ class TestRegistryServer(unittest.TestCase):
nb_less = 0
for cert in self.server.iterCert():
s = cert[0].get_subject().serialNumber
if(s and int(s) <= serial):
if s and int(s) <= serial:
nb_less += 1
self.assertEqual(nb_less, serial)
......@@ -378,7 +389,7 @@ class TestRegistryServer(unittest.TestCase):
hmacs = get_hmac()
key_1 = hmacs[1]
self.assertEqual(hmacs, [None, key_1, ''])
self.assertEqual(hmacs, [None, key_1, b''])
# step 2
self.server.updateHMAC()
......@@ -397,12 +408,11 @@ class TestRegistryServer(unittest.TestCase):
self.assertEqual(get_hmac(), [None, key_2, key_1])
#setp 5
# step 5
self.server.updateHMAC()
self.assertEqual(get_hmac(), [key_2, None, None])
def test_getNodePrefix(self):
# prefix in short format
prefix = "0000000101"
......@@ -426,19 +436,33 @@ class TestRegistryServer(unittest.TestCase):
('0000000000000001', '2 0/16 6/16')
]
recv.side_effect = recv_case
def side_effct(rlist, wlist, elist, timeout):
# rlist is true until the len(recv_case)th call
side_effct.i -= side_effct.i > 0
return [side_effct.i, wlist, None]
side_effct.i = len(recv_case) + 1
select.side_effect = side_effct
res = self.server.topology()
expect_res = '{"36893488147419103232/80": ["0/16", "7/16"], ' \
'"": ["36893488147419103232/80", "3/16", "1/16", "0/16", "7/16"], ' \
'"4/16": ["0/16"], "3/16": ["0/16", "7/16"], "0/16": ["6/16", "7/16"], '\
'"1/16": ["6/16", "0/16"], "7/16": ["6/16", "4/16"]}'''
class CustomDecoder(json.JSONDecoder):
def __init__(self, **kwargs):
json.JSONDecoder.__init__(self, **kwargs)
self.parse_array = self.JSONArray
self.scan_once = json.scanner.py_make_scanner(self)
def JSONArray(self, s_and_end, scan_once, **kwargs):
values, end = json.decoder.JSONArray(s_and_end, scan_once, **kwargs)
return set(values), end
res = json.loads(res, cls=CustomDecoder)
expect_res = {"36893488147419103232/80": {"0/16", "7/16"},
"": {"36893488147419103232/80", "3/16", "1/16", "0/16", "7/16"}, "4/16": {"0/16"},
"3/16": {"0/16", "7/16"}, "0/16": {"6/16", "7/16"}, "1/16": {"6/16", "0/16"},
"7/16": {"6/16", "4/16"}}
self.assertEqual(res, expect_res)
......
......@@ -2,7 +2,7 @@ import sys
import os
import unittest
import hmac
import httplib
import http.client
import base64
import hashlib
from mock import Mock, patch
......@@ -26,15 +26,15 @@ class TestRegistryClient(unittest.TestCase):
self.assertEqual(client1._path, "/example")
self.assertEqual(client1._conn.host, "localhost")
self.assertIsInstance(client1._conn, httplib.HTTPSConnection)
self.assertIsInstance(client2._conn, httplib.HTTPConnection)
self.assertIsInstance(client1._conn, http.client.HTTPSConnection)
self.assertIsInstance(client2._conn, http.client.HTTPConnection)
def test_rpc_hello(self):
prefix = "0000000011111111"
protocol = "7"
body = "a_hmac_key"
query = "/hello?client_prefix=0000000011111111&protocol=7"
response = fakeResponse(body, httplib.OK)
response = fakeResponse(body, http.client.OK)
self.client._conn.getresponse.return_value = response
res = self.client.hello(prefix, protocol)
......@@ -52,14 +52,14 @@ class TestRegistryClient(unittest.TestCase):
self.client._hmac = None
self.client.hello = Mock(return_value = "aaabbb")
self.client.cert = Mock()
key = "this_is_a_key"
key = b"this_is_a_key"
self.client.cert.decrypt.return_value = key
h = hmac.HMAC(key, query, hashlib.sha1).digest()
h = hmac.HMAC(key, query.encode(), hashlib.sha1).digest()
key = hashlib.sha1(key).digest()
# response part
body = None
response = fakeResponse(body, httplib.NO_CONTENT)
response.msg = dict(Re6stHMAC=hmac.HMAC(key, body, hashlib.sha1).digest())
body = b'this is a body'
response = fakeResponse(body, http.client.NO_CONTENT)
response.msg = dict(Re6stHMAC=base64.b64encode(hmac.HMAC(key, body, hashlib.sha1).digest()))
self.client._conn.getresponse.return_value = response
res = self.client.getNetworkConfig(cn)
......
#!/usr/bin/python2
#!/usr/bin/env python3
import os
import sys
import unittest
......@@ -67,7 +67,7 @@ class testBaseTunnelManager(unittest.TestCase):
# @patch("re6st.tunnel.BaseTunnelManager._makeTunnel", create=True)
# def test_processPacket_address_with_msg_peer(self, makeTunnel):
# """code is 1, peer and msg not none """
# c = chr(1)
# c = b"\x01"
# msg = "address"
# peer = x509.Peer("000001")
# self.tunnel._connecting = {peer}
......@@ -81,7 +81,7 @@ class testBaseTunnelManager(unittest.TestCase):
def test_processPacket_address(self):
"""code is 1, for address. And peer or msg are none"""
c = chr(1)
c = b"\x01"
self.tunnel._address = {1: "1,1", 2: "2,2"}
res = self.tunnel._processPacket(c)
......@@ -95,7 +95,7 @@ class testBaseTunnelManager(unittest.TestCase):
and each address join by ;
it will truncate address which has more than 3 element
"""
c = chr(1)
c = b"\x01"
peer = x509.Peer("000001")
peer.protocol = 1
self.tunnel._peers.append(peer)
......@@ -111,11 +111,11 @@ class testBaseTunnelManager(unittest.TestCase):
"""code is 0, for network version, peer is not none
2 case, one modify the version, one not
"""
c = chr(0)
c = b"\x00"
peer = x509.Peer("000001")
version1 = "00003"
version2 = "00007"
self.tunnel._version = version3 = "00005"
version1 = b"00003"
version2 = b"00007"
self.tunnel._version = version3 = b"00005"
self.tunnel._peers.append(peer)
res = self.tunnel._processPacket(c + version1, peer)
......
#!/usr/bin/python2
#!/usr/bin/env python3
import os
import sys
import unittest
......
......@@ -30,9 +30,9 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
return
crypto.X509Cert in pem format
"""
if type(ca) is str:
if type(ca) is bytes:
ca = crypto.load_certificate(crypto.FILETYPE_PEM, ca)
if type(ca_key) is str:
if type(ca_key) is bytes:
ca_key = crypto.load_privatekey(crypto.FILETYPE_PEM, ca_key)
req = crypto.load_certificate_request(crypto.FILETYPE_PEM, csr)
......@@ -40,7 +40,7 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
cert.gmtime_adj_notBefore(0)
if not_after:
cert.set_notAfter(
time.strftime("%Y%m%d%H%M%SZ", time.gmtime(not_after)))
time.strftime("%Y%m%d%H%M%SZ", time.gmtime(not_after)).encode())
else:
cert.gmtime_adj_notAfter(registry.RegistryServer.cert_duration)
subject = req.get_subject()
......@@ -56,9 +56,9 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
def create_cert_file(pkey_file, cert_file, ca, ca_key, prefix, serial):
pkey, csr = generate_csr()
cert = generate_cert(ca, ca_key, csr, prefix, serial)
with open(pkey_file, 'w') as f:
with open(pkey_file, 'wb') as f:
f.write(pkey)
with open(cert_file, 'w') as f:
with open(cert_file, 'wb') as f:
f.write(cert)
return pkey, cert
......@@ -84,9 +84,9 @@ def create_ca_file(pkey_file, cert_file, serial=0x120010db80042):
cert.set_pubkey(key)
cert.sign(key, "sha512")
with open(pkey_file, 'w') as pkey_file:
with open(pkey_file, 'wb') as pkey_file:
pkey_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key))
with open(cert_file, 'w') as cert_file:
with open(cert_file, 'wb') as cert_file:
cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
return key, cert
......@@ -101,7 +101,7 @@ def serial2prefix(serial):
# pkey: private key
def decrypt(pkey, incontent):
with open("node.key", 'w') as f:
f.write(pkey)
f.write(pkey.decode())
args = "openssl rsautl -decrypt -inkey node.key".split()
p = subprocess.Popen(
args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
......
......@@ -59,7 +59,7 @@ class MultiGatewayManager(dict):
except:
pass
class Connection(object):
class Connection:
_retry = 0
serial = None
......@@ -94,7 +94,7 @@ class Connection(object):
'--remap-usr1', 'SIGTERM',
'--ping-exit', str(tm.timeout),
'--route-up', '%s %u' % (plib.ovpn_client, tm.write_sock.fileno()),
*tm.ovpn_args)
*tm.ovpn_args, pass_fds=[tm.write_sock.fileno()])
tm.resetTunnelRefresh()
self._retry += 1
......@@ -132,7 +132,7 @@ class Connection(object):
self.open()
return 0
class TunnelKiller(object):
class TunnelKiller:
state = None
......@@ -169,7 +169,7 @@ class TunnelKiller(object):
if (self.address, self.ifindex) in tm.ctl.locked:
self.state = 'locked'
self.timeout = time.time() + 2 * tm.timeout
tm.sendto(self.peer, '\2' if self.client else '\3')
tm.sendto(self.peer, b'\2' if self.client else b'\3')
else:
self.timeout = 0
......@@ -186,7 +186,7 @@ class TunnelKiller(object):
locked = unlocking = lambda _: None
class BaseTunnelManager(object):
class BaseTunnelManager:
# TODO: To minimize downtime when network parameters change, we should do
# our best to not restart any process. Ideally, this list should be
......@@ -242,14 +242,14 @@ class BaseTunnelManager(object):
self._country = {}
address_dict = {family: self._updateCountry(address)
for family, address in address_dict.iteritems()}
for family, address in address_dict.items()}
elif cache.same_country:
sys.exit("Can not respect 'same_country' network configuration"
" (GEOIP2_MMDB not set)")
self._address = {family: utils.dump_address(address)
for family, address in address_dict.iteritems()
for family, address in address_dict.items()
if address}
cache.my_address = ';'.join(self._address.itervalues())
cache.my_address = ';'.join(self._address.values())
self.sock = socket.socket(socket.AF_INET6,
socket.SOCK_DGRAM | socket.SOCK_CLOEXEC)
......@@ -329,7 +329,7 @@ class BaseTunnelManager(object):
def _getPeer(self, prefix):
return self._peers[bisect(self._peers, prefix) - 1]
def sendto(self, prefix, msg):
def sendto(self, prefix: str, msg):
to = utils.ipFromBin(self._network + prefix), PORT
peer = self._getPeer(prefix)
if peer.prefix != prefix:
......@@ -344,9 +344,11 @@ class BaseTunnelManager(object):
peer.hello0Sent()
def _sendto(self, to, msg, peer=None):
if type(msg) is str:
msg = msg.encode()
try:
r = self.sock.sendto(peer.encode(msg) if peer else msg, to)
except socket.error, e:
except socket.error as e:
(logging.info if e.errno == errno.ENETUNREACH else logging.error)(
'Failed to send message to %s (%s)', to, e)
return
......@@ -359,19 +361,20 @@ class BaseTunnelManager(object):
to = address[:2]
if address[0] == '::1':
try:
prefix, msg = msg.split('\0', 1)
prefix, msg = msg.split(b'\0', 1)
prefix = prefix.decode()
int(prefix, 2)
except ValueError:
return
if msg:
self._forward = to
code = ord(msg[0])
code = msg[0]
if prefix == self._prefix:
msg = self._processPacket(msg)
if msg:
self._sendto(to, '%s\0%c%s' % (prefix, code, msg))
else:
self.sendto(prefix, chr(code | 0x80) + msg[1:])
self.sendto(prefix, bytes([code | 0x80]) + msg[1:])
return
try:
sender = utils.binFromIp(address[0])
......@@ -384,7 +387,7 @@ class BaseTunnelManager(object):
msg = peer.decode(msg)
if type(msg) is tuple:
seqno, msg, protocol = msg
def handleHello(peer, seqno, msg, retry):
def handleHello(peer, seqno, msg: bytes, retry):
if seqno == 2:
i = len(msg) // 2
h = msg[:i]
......@@ -394,10 +397,10 @@ class BaseTunnelManager(object):
except (AttributeError, crypto.Error, x509.NewSessionError,
subprocess.CalledProcessError):
logging.debug('ignored new session key from %r',
address, exc_info=1)
address, exc_info=True)
return
peer.version = self._version \
if self._sendto(to, '\0' + self._version, peer) else ''
if self._sendto(to, b'\0' + self._version, peer) else b''
return
if seqno:
h = x509.fingerprint(self.cert.cert).digest()
......@@ -410,7 +413,7 @@ class BaseTunnelManager(object):
serial = cert.get_serial_number()
if serial in self.cache.crl:
raise ValueError("revoked")
except (x509.VerifyError, ValueError), e:
except (x509.VerifyError, ValueError) as e:
if retry:
return True
logging.debug('ignored invalid certificate from %r (%s)',
......@@ -444,10 +447,11 @@ class BaseTunnelManager(object):
# We got a valid and non-empty message. Always reply
# something so that the sender knows we're still connected.
answer = self._processPacket(msg, peer.prefix)
self._sendto(to, msg[0] + answer if answer else "", peer)
self._sendto(to, msg[0:1] + answer.encode() if answer else b'', peer)
def _processPacket(self, msg, peer=None):
c = ord(msg[0])
c = msg[0]
msg = msg[1:]
code = c & 0x7f
if c > 0x7f and msg:
......@@ -456,6 +460,7 @@ class BaseTunnelManager(object):
elif code == 1: # address
if msg:
if peer:
msg = msg.decode()
self.cache.addPeer(peer, msg)
try:
self._connecting.remove(peer)
......@@ -467,8 +472,8 @@ class BaseTunnelManager(object):
# Don't send country to old nodes
if self._getPeer(peer).protocol < 7:
return ';'.join(','.join(a.split(',')[:3]) for a in
';'.join(self._address.itervalues()).split(';'))
return ';'.join(self._address.itervalues())
';'.join(self._address.values()).split(';'))
return ';'.join(self._address.values())
elif not code: # network version
if peer:
try:
......@@ -526,7 +531,7 @@ class BaseTunnelManager(object):
if peer.prefix != prefix:
self.sendto(prefix, None)
elif (peer.version < self._version and
self.sendto(prefix, '\0' + self._version)):
self.sendto(prefix, b'\0' + self._version)):
peer.version = self._version
def broadcastNewVersion(self):
......@@ -553,8 +558,8 @@ class BaseTunnelManager(object):
if (not self.NEED_RESTART.isdisjoint(changed)
or version.protocol < self.cache.min_protocol
# TODO: With --management, we could kill clients without restarting.
or not all(crl.isdisjoint(serials.itervalues())
for serials in self._served.itervalues())):
or not all(crl.isdisjoint(serials.values())
for serials in self._served.values())):
# Wait at least 1 second to broadcast new version to neighbours.
self.selectTimeout(time.time() + 1 + self.cache.delay_restart,
self._restart)
......@@ -606,7 +611,7 @@ class BaseTunnelManager(object):
with open('/proc/net/ipv6_route', "r", 4096) as f:
try:
routing_table = f.read()
except IOError, e:
except IOError as e:
# ???: If someone can explain why the kernel sometimes fails
# even when there's a lot of free memory.
if e.errno != errno.ENOMEM:
......@@ -635,7 +640,7 @@ class BaseTunnelManager(object):
logging.error("%s. Flushing...", msg)
subprocess.call(("ip", "-6", "route", "flush", "cached"))
self.sendto(self.cache.registry_prefix,
'\7%s (%s)' % (msg, os.uname()[2]))
b'\7%s (%s)' % (msg, os.uname()[2]))
break
def _updateCountry(self, address):
......@@ -683,7 +688,7 @@ class TunnelManager(BaseTunnelManager):
self._client_count = client_count
self.new_iface_list = deque('re6stnet' + str(i)
for i in xrange(1, self._client_count + 1))
for i in range(1, self._client_count + 1))
self._free_iface_list = []
def close(self):
......@@ -752,7 +757,7 @@ class TunnelManager(BaseTunnelManager):
def babel_dump(self):
t = time.time()
if self._killing:
for prefix, tunnel_killer in self._killing.items():
for prefix, tunnel_killer in list(self._killing.items()):
if tunnel_killer.timeout < t:
if tunnel_killer.state != 'unlocking':
logging.info(
......@@ -780,7 +785,7 @@ class TunnelManager(BaseTunnelManager):
def _cleanDeads(self):
disconnected = False
for prefix in self._connection_dict.keys():
for prefix in list(self._connection_dict):
status = self._connection_dict[prefix].refresh()
if status:
disconnected |= status > 0
......@@ -902,7 +907,7 @@ class TunnelManager(BaseTunnelManager):
neighbours = self.ctl.neighbours
# Collect all nodes known by Babel
peers = {prefix
for neigh_routes in neighbours.itervalues()
for neigh_routes in neighbours.values()
for prefix in neigh_routes[1]
if prefix}
# Keep only distant peers.
......@@ -957,7 +962,7 @@ class TunnelManager(BaseTunnelManager):
address = self.cache.getAddress(peer)
if address:
count -= self._makeTunnel(peer, address)
elif self.sendto(peer, '\1'):
elif self.sendto(peer, b'\1'):
self._connecting.add(peer)
count -= 1
elif distant_peers is None:
......@@ -987,7 +992,7 @@ class TunnelManager(BaseTunnelManager):
break
def killAll(self):
for prefix in self._connection_dict.keys():
for prefix in list(self._connection_dict):
self._kill(prefix)
def handleClientEvent(self):
......@@ -999,7 +1004,7 @@ class TunnelManager(BaseTunnelManager):
if c and c.time < float(time):
try:
c.connected(serial)
except (KeyError, TypeError), e:
except (KeyError, TypeError) as e:
logging.error("%s (route_up %s)", e, common_name)
else:
logging.info("ignore route_up notification for %s %r",
......@@ -1010,10 +1015,10 @@ class TunnelManager(BaseTunnelManager):
if self.cache.same_country:
address = self._updateCountry(address)
self._address[family] = utils.dump_address(address)
self.cache.my_address = ';'.join(self._address.itervalues())
self.cache.my_address = ';'.join(self._address.values())
def broadcastNewVersion(self):
self._babel_dump_new_version()
for prefix, c in self._connection_dict.items():
for prefix, c in list(self._connection_dict.items()):
if c.serial in self.cache.crl:
self._kill(prefix)
......@@ -7,7 +7,7 @@ class UPnPException(Exception):
pass
class Forwarder(object):
class Forwarder:
"""
External port is chosen randomly between 32768 & 49151 included.
"""
......@@ -40,7 +40,7 @@ class Forwarder(object):
def wrapper(*args, **kw):
try:
return wrapped(*args, **kw)
except Exception, e:
except Exception as e:
raise UPnPException(str(e))
return wraps(wrapped)(wrapper)
......@@ -68,14 +68,14 @@ class Forwarder(object):
else:
try:
return self._refresh()
except UPnPException, e:
logging.debug("UPnP failure", exc_info=1)
except UPnPException as e:
logging.debug("UPnP failure", exc_info=True)
self.clear()
try:
self.discover()
self.selectigd()
return self._refresh()
except UPnPException, e:
except UPnPException as e:
self.next_refresh = self._next_retry = time.time() + 60
logging.info(str(e))
self.clear()
......@@ -109,7 +109,7 @@ class Forwarder(object):
try:
self.addportmapping(port, *args)
break
except UPnPException, e:
except UPnPException as e:
if str(e) != 'ConflictInMappingEntry':
raise
port = None
......
import argparse, errno, fcntl, hashlib, logging, os, select as _select
import shlex, signal, socket, sqlite3, struct, subprocess
import sys, textwrap, threading, time, traceback
from collections.abc import Iterator, Mapping
# PY3: It will be even better to use Popen(pass_fds=...),
# and then socket.SOCK_CLOEXEC will be useless.
# (We already follow the good practice that consists in not
# relying on the GC for the closing of file descriptors.)
socket.SOCK_CLOEXEC = 0x80000
HMAC_LEN = len(hashlib.sha1('').digest())
HMAC_LEN = len(hashlib.sha1(b'').digest())
class ReexecException(Exception):
pass
......@@ -37,12 +32,12 @@ class FileHandler(logging.FileHandler):
finally:
self.lock.release()
# In the rare case _reopen is set just before the lock was released
if self._reopen and self.lock.acquire(0):
if self._reopen and self.lock.acquire(False):
self.release()
def async_reopen(self, *_):
self._reopen = True
if self.lock.acquire(0):
if self.lock.acquire(False):
self.release()
def setupLog(log_level, filename=None, **kw):
......@@ -119,7 +114,7 @@ class ArgParser(argparse.ArgumentParser):
ca /etc/re6stnet/ca.crt""", **kw)
class exit(object):
class exit:
status = None
......@@ -150,7 +145,7 @@ class exit(object):
def handler(*args):
if self.status is None:
self.status = status
if self.acquire(0):
if self.acquire(False):
self.release()
for sig in sigs:
signal.signal(sig, handler)
......@@ -164,7 +159,7 @@ class Popen(subprocess.Popen):
self._args = tuple(args[0] if args else kw['args'])
try:
super(Popen, self).__init__(*args, **kw)
except OSError, e:
except OSError as e:
if e.errno != errno.ENOMEM:
raise
self.returncode = -1
......@@ -179,9 +174,9 @@ class Popen(subprocess.Popen):
self.terminate()
t = threading.Timer(5, self.kill)
t.start()
# PY3: use waitid(WNOWAIT) and call self.poll() after t.cancel()
r = self.wait()
r = os.waitid(os.P_PID, self.pid, os.WNOWAIT)
t.cancel()
self.poll()
return r
......@@ -209,7 +204,7 @@ def select(R, W, T):
def makedirs(*args):
try:
os.makedirs(*args)
except OSError, e:
except OSError as e:
if e.errno != errno.EEXIST:
raise
......@@ -240,7 +235,7 @@ def parse_address(address_list):
a = address.split(',')
int(a[1]) # Check if port is an int
yield tuple(a[:4])
except ValueError, e:
except ValueError as e:
logging.warning("Failed to parse node address %r (%s)",
address, e)
......@@ -261,21 +256,21 @@ newHmacSecret = newHmacSecret()
# - there's always a unique way to encode a value
# - the 3 first bits code the number of bytes
def packInteger(i):
for n in xrange(8):
def packInteger(i: int) -> bytes:
for n in range(8):
x = 32 << 8 * n
if i < x:
return struct.pack("!Q", i + n * x)[7-n:]
i -= x
raise OverflowError
def unpackInteger(x):
n = ord(x[0]) >> 5
def unpackInteger(x: bytes) -> tuple[int, int] | None:
n = x[0] >> 5
try:
i, = struct.unpack("!Q", '\0' * (7 - n) + x[:n+1])
i, = struct.unpack("!Q", b'\0' * (7 - n) + x[:n+1])
except struct.error:
return
return sum((32 << 8 * i for i in xrange(n)),
return sum((32 << 8 * i for i in range(n)),
i - (n * 32 << 8 * n)), n + 1
###
......
......@@ -40,4 +40,4 @@ protocol = 8
min_protocol = 1
if __name__ == "__main__":
print version
print(version)
......@@ -14,23 +14,23 @@ def subnetFromCert(cert):
return cert.get_subject().CN
def notBefore(cert):
return calendar.timegm(time.strptime(cert.get_notBefore(),'%Y%m%d%H%M%SZ'))
return calendar.timegm(time.strptime(cert.get_notBefore().decode(),'%Y%m%d%H%M%SZ'))
def notAfter(cert):
return calendar.timegm(time.strptime(cert.get_notAfter(),'%Y%m%d%H%M%SZ'))
return calendar.timegm(time.strptime(cert.get_notAfter().decode(),'%Y%m%d%H%M%SZ'))
def openssl(*args):
def openssl(*args, fds=[]):
return utils.Popen(('openssl',) + args,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
stderr=subprocess.PIPE, pass_fds=fds)
def encrypt(cert, data):
r, w = os.pipe()
try:
threading.Thread(target=os.write, args=(w, cert)).start()
p = openssl('rsautl', '-encrypt', '-certin',
'-inkey', '/proc/self/fd/%u' % r)
'-inkey', '/proc/self/fd/%u' % r, fds=[r])
out, err = p.communicate(data)
finally:
os.close(r)
......@@ -52,7 +52,7 @@ def maybe_renew(path, cert, info, renew, force=False):
if time.time() < next_renew:
return cert, next_renew
try:
pem = renew()
pem: bytes = renew()
if not pem or pem == crypto.dump_certificate(
crypto.FILETYPE_PEM, cert):
exc_info = 0
......@@ -62,7 +62,7 @@ def maybe_renew(path, cert, info, renew, force=False):
exc_info = 1
break
new_path = path + '.new'
with open(new_path, 'w') as f:
with open(new_path, 'wb') as f:
f.write(pem)
try:
s = os.stat(path)
......@@ -84,19 +84,19 @@ class NewSessionError(Exception):
pass
class Cert(object):
class Cert:
def __init__(self, ca, key, cert=None):
self.ca_path = ca
self.cert_path = cert
self.key_path = key
with open(ca) as f:
with open(ca, "rb") as f:
self.ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
with open(key) as f:
with open(key, "rb") as f:
self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read())
if cert:
with open(cert) as f:
self.cert = self.loadVerify(f.read())
self.cert = self.loadVerify(f.read().encode())
@property
def prefix(self):
......@@ -143,21 +143,21 @@ class Cert(object):
"error running openssl, assuming cert is invalid")
# BBB: With old versions of openssl, detailed
# error is printed to standard output.
for err in err, out:
for x in err.splitlines():
for stream in err, out:
for x in stream.decode(errors='replace').splitlines():
if x.startswith('error '):
x, msg = x.split(':', 1)
_, code, _, depth, _ = x.split(None, 4)
raise VerifyError(int(code), int(depth), msg.strip())
return r
def verify(self, sign, data):
def verify(self, sign: bytes, data):
crypto.verify(self.ca, sign, data, 'sha512')
def sign(self, data):
def sign(self, data) -> bytes:
return crypto.sign(self.key, data, 'sha512')
def decrypt(self, data):
def decrypt(self, data: bytes) -> bytes:
p = openssl('rsautl', '-decrypt', '-inkey', self.key_path)
out, err = p.communicate(data)
if p.returncode:
......@@ -166,7 +166,7 @@ class Cert(object):
def verifyVersion(self, version):
try:
n = 1 + (ord(version[0]) >> 5)
n = 1 + (version[0] >> 5)
self.verify(version[n:], version[:n])
except (IndexError, crypto.Error):
raise VerifyError(None, None, 'invalid network version')
......@@ -175,7 +175,7 @@ class Cert(object):
PACKED_PROTOCOL = utils.packInteger(protocol)
class Peer(object):
class Peer:
"""
UDP: A ─────────────────────────────────────────────> B
......@@ -206,9 +206,9 @@ class Peer(object):
_key = newHmacSecret()
serial = None
stop_date = float('inf')
version = ''
version = b''
def __init__(self, prefix):
def __init__(self, prefix: str):
self.prefix = prefix
@property
......@@ -229,11 +229,11 @@ class Peer(object):
try:
# Always assume peer is not old, in case it has just upgraded,
# else we would be stuck with the old protocol.
msg = ('\0\0\0\1'
msg = (b'\0\0\0\1'
+ PACKED_PROTOCOL
+ fingerprint(self.cert).digest())
except AttributeError:
msg = '\0\0\0\0'
msg = b'\0\0\0\0'
return msg + crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)
def hello0Sent(self):
......@@ -246,13 +246,13 @@ class Peer(object):
self._i = self._j = 2
self._last = 0
self.protocol = protocol
return ''.join(('\0\0\0\2', PACKED_PROTOCOL if protocol else '',
return b''.join((b'\0\0\0\2', PACKED_PROTOCOL if protocol else b'',
h, cert.sign(h)))
def _hmac(self, msg):
return hmac.HMAC(self._key, msg, hashlib.sha1).digest()
def newSession(self, key, protocol):
def newSession(self, key: bytes, protocol):
if key <= self._key:
raise NewSessionError(self._key, key)
self._key = key
......@@ -265,7 +265,7 @@ class Peer(object):
seqno_struct = struct.Struct("!L")
def decode(self, msg, _unpack=seqno_struct.unpack):
def decode(self, msg: bytes, _unpack=seqno_struct.unpack) -> str:
seqno, = _unpack(msg[:4])
if seqno <= 2:
msg = msg[4:]
......@@ -279,10 +279,12 @@ class Peer(object):
if self._hmac(msg[:i]) == msg[i:] and self._i < seqno:
self._last = None
self._i = seqno
return msg[4:i]
return msg[4:i].decode()
def encode(self, msg, _pack=seqno_struct.pack):
def encode(self, msg: str | bytes, _pack=seqno_struct.pack) -> bytes:
self._j += 1
if type(msg) is str:
msg = msg.encode()
msg = _pack(self._j) + msg
return msg + self._hmac(msg)
......
......@@ -7,21 +7,23 @@ from setuptools.command import sdist as _sdist, build_py as _build_py
from distutils import log
version = {"__file__": "re6st/version.py"}
execfile(version["__file__"], version)
with open(version["__file__"]) as f:
code = compile(f.read(), version["__file__"], 'exec')
exec(code, version)
def copy_file(self, infile, outfile, *args, **kw):
if infile == version["__file__"]:
if not self.dry_run:
log.info("generating %s -> %s", infile, outfile)
with open(outfile, "wb") as f:
for x in sorted(version.iteritems()):
with open(outfile, "w") as f:
for x in sorted(version.items()):
if not x[0].startswith("_"):
f.write("%s = %r\n" % x)
return outfile, 1
elif isinstance(self, build_py) and \
os.stat(infile).st_mode & stat.S_IEXEC:
if os.path.isdir(infile) and os.path.isdir(outfile):
return (outfile, 0)
return outfile, 0
# Adjust interpreter of OpenVPN hooks.
with open(infile) as src:
first_line = src.readline()
......@@ -33,7 +35,7 @@ def copy_file(self, infile, outfile, *args, **kw):
patched += src.read()
dst = os.open(outfile, os.O_CREAT | os.O_WRONLY | os.O_TRUNC)
try:
os.write(dst, patched)
os.write(dst, patched.encode())
finally:
os.close(dst)
return outfile, 1
......@@ -51,7 +53,8 @@ Environment :: Console
License :: OSI Approved :: GNU General Public License (GPL)
Natural Language :: English
Operating System :: POSIX :: Linux
Programming Language :: Python :: 2.7
Programming Language :: Python :: 3
Programming Language :: Python :: 3.11
Topic :: Internet
Topic :: System :: Networking
"""
......@@ -73,6 +76,7 @@ setup(
license = 'GPL 2+',
platforms = ["any"],
classifiers=classifiers.splitlines(),
python_requires = '>=3.11',
long_description = ".. contents::\n\n" + open('README.rst').read()
+ "\n" + open('CHANGES.rst').read() + git_rev,
packages = find_packages(),
......@@ -95,7 +99,7 @@ setup(
extras_require = {
'geoip': ['geoip2'],
'multicast': ['PyYAML'],
'test': ['mock', 'pathlib2', 'nemu', 'python-unshare', 'python-passfd', 'multiping']
'test': ['mock', 'nemu3', 'unshare', 'multiping']
},
#dependency_links = [
# "http://miniupnp.free.fr/files/download.php?file=miniupnpc-1.7.20120714.tar.gz#egg=miniupnpc-1.7",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment